In [56]:
!pip install langchain pymongo langchain-openai



In [79]:
import os
import pymongo
import json
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain.callbacks.tracers import ConsoleCallbackHandler
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

In [58]:
# Read the OpenAI key from the Colab `userdata` store (entry name 'OPEN_AI_KEY')
# and expose it through the OPENAI_API_KEY environment variable.
from google.colab import userdata
os.environ["OPENAI_API_KEY"] = userdata.get('OPEN_AI_KEY')
# Drop any organization / project overrides already present in the environment.
# NOTE(review): presumably so the OpenAI client authenticates with the key alone — confirm.
os.environ.pop('OPENAI_ORGANIZATION', None)
os.environ.pop('OPENAI_PROJECT_ID', None)

In [87]:
# --- MongoDB Atlas Configuration ---
# The connection string comes from the Colab `userdata` entry 'MONGO_ATLAS_URI',
# so it is never written into the notebook. `userdata` must already have been
# imported by an earlier cell.
MONGO_URI = userdata.get('MONGO_ATLAS_URI')
# Database and collection names used by get_mongo_collection().
DB_NAME = "organization_db"
COLLECTION_NAME = "employees"

def get_mongo_collection():
    """Return the employees collection, seeding it with sample data when empty.

    The MongoClient is created on the first call and then reused: a client
    owns its own connection pool, so building a fresh one on every call
    would keep opening pools that are never closed.

    Returns:
        The pymongo collection ``DB_NAME.COLLECTION_NAME``, or ``None`` when
        connecting or seeding fails (the error is printed, not raised).
    """
    try:
        # Cache the client on the function object so repeated calls share it.
        client = getattr(get_mongo_collection, "_client", None)
        if client is None:
            client = MongoClient(MONGO_URI, server_api=ServerApi('1'))
            get_mongo_collection._client = client

        collection = client[DB_NAME][COLLECTION_NAME]

        # Seed the sample rows only when the collection holds no documents,
        # so re-running the notebook does not duplicate them.
        if collection.count_documents({}) == 0:
            employees = [
                {'id': 1, 'name': 'Alice', 'department': 'Engineering', 'salary': 120000},
                {'id': 2, 'name': 'Bob', 'department': 'Sales', 'salary': 90000},
                {'id': 3, 'name': 'Charlie', 'department': 'Engineering', 'salary': 150000},
                {'id': 4, 'name': 'David', 'department': 'Marketing', 'salary': 85000},
                {'id': 5, 'name': 'Eve', 'department': 'Sales', 'salary': 95000},
                {'id': 6, 'name': 'Frank', 'department': 'Engineering', 'salary': 130000},
            ]
            collection.insert_many(employees)
            print("Employee collection populated with sample data.")

        return collection

    except Exception as e:
        # Best-effort by design: callers check for None instead of catching.
        print(f"Error connecting to MongoDB: {e}")
        return None

In [None]:
# Run this cell to seed the database (if it is empty) and check that the connection works.
collection = get_mongo_collection()
# get_mongo_collection() returns None (after printing the error) when the
# connection fails, so guard before counting.
collection.count_documents({}) if collection is not None else None

In [88]:
# --- Custom tool for MongoDB interaction ---
def run_mongo_query(query_json: dict) -> str:
  """
  Execute an LLM-generated, read-only query description against MongoDB.

  The description is validated before any database connection is opened,
  so malformed or disallowed requests are rejected cheaply.

  Args:
      query_json: Parsed query description with the keys
          'type'  - "count_query", "find_query" or "aggregator_query";
          'query' - a filter document for count/find queries (defaults to
                    {} when missing), or a list of pipeline stages for
                    aggregator queries.
          Other keys (e.g. 'collection') are ignored; the collection
          returned by get_mongo_collection() is always used.

  Returns:
      The count as a string, a JSON string of the resulting documents, or
      a human-readable error message. Errors are returned, not raised,
      because the text is handed on to the summarising LLM.
  """
  error_prefix = "Error executing MongoDB query: "

  if not isinstance(query_json, dict):
    return f"{error_prefix}expected a JSON object, got {type(query_json).__name__}."

  supported_types = ("count_query", "find_query", "aggregator_query")
  query_type = query_json.get('type')
  if query_type not in supported_types:
    return (
        f"{error_prefix}unsupported query type {query_type!r}; "
        f"expected one of {', '.join(supported_types)}."
    )

  query = query_json.get('query')
  if query_type == "aggregator_query":
    if not isinstance(query, list):
      return f"{error_prefix}an aggregator_query needs a list of pipeline stages."
    # $out and $merge persist pipeline output into a collection, which would
    # break the read-only contract even though the call is "just" aggregate().
    for stage in query:
      if isinstance(stage, dict):
        for write_stage in ("$out", "$merge"):
          if write_stage in stage:
            return (
                f"{error_prefix}pipeline stage {write_stage} writes to the "
                "database; only read operations are allowed."
            )
  else:
    if query is None:
      # A missing filter means "match every document".
      query = {}
    elif not isinstance(query, dict):
      return f"{error_prefix}the query filter must be a JSON object."

  collection = get_mongo_collection()
  if collection is None:
    return "Database connection failed."

  try:
    if query_type == "count_query":
      return str(collection.count_documents(query))

    if query_type == "find_query":
      documents = list(collection.find(query))
    else:
      documents = list(collection.aggregate(query))

    # default=str serialises values json cannot encode natively
    # (such as the ObjectId stored under '_id').
    return json.dumps(documents, default=str)

  except Exception as e:
    return f"{error_prefix}{e}"

In [62]:
# --- Query-generation prompt ---
# Asks the model to turn a natural-language question into a JSON query
# description. {question} is the only template variable.
# Downstream, get_actual_query() parses the reply with json.loads and
# run_mongo_query() reads its 'type' and 'query' keys.
# NOTE(review): the refusal text requested for non-read operations is not JSON,
# so get_actual_query() will raise on it — confirm this is the intended outcome.
# NOTE(review): the seeded sample documents also carry an 'id' field that is
# not listed in the schema below.
query_gen_template = """
You are a skilled MongoDB expert and an expert at converting natural language questions into mongodb query json.

The MongoDB collection 'employees' has documents with fields: 'name', 'department', 'salary'.

Use the collection name instead of collection in the actual query.

You need to provide the collection name as a key in output json on which the query needs to be perform, the actual query in the query key and the type of query like count_query, find_query, aggregator_query.
Your final output should be the json string, nothing else.
You should only generate the query for read operations, for other operations you can simply deny and give output "Out of my capacity as of now"

Question: {question}
"""

query_gen_prompt = PromptTemplate.from_template(query_gen_template)

In [63]:
# Wrap run_mongo_query in a RunnableLambda so it can be used as a chain component.
# NOTE(review): full_chain wraps run_mongo_query in its own RunnableLambda and
# does not reference this object in the visible notebook — confirm whether it
# is still needed.
query_execution_runnable = RunnableLambda(run_mongo_query)

In [64]:
# --- Answer-summarisation prompt ---
# Takes the user's original question plus the string returned by
# run_mongo_query() and asks the model for a short natural-language answer.
# Template variables: {question} and {query_result}.
final_summary_template = """
You are a friendly chatbot that answers questions based on a query result.
The user asked: {question}
The query result was: {query_result}

Provide a clear, concise answer.
"""
final_summary_prompt = PromptTemplate.from_template(final_summary_template)

In [65]:
def print_and_pass(input_data):
    """Echo a value between two banner lines and hand it back untouched.

    Intended as a debugging tap inside a chain: the value flowing through
    is printed and then returned as-is so the next step still receives it.
    """
    banner_top = "--- Intermediate Output ---"
    banner_bottom = "-------------------------"
    # One print call with a newline separator emits the same three lines.
    print(banner_top, input_data, banner_bottom, sep="\n")
    return input_data

In [83]:
def get_actual_query(input_dict):
  """Parse the LLM's reply into the JSON query description.

  Args:
      input_dict: Despite the name, the raw reply *string* produced by the
          model. It may be bare JSON or JSON wrapped in a Markdown code
          fence, with or without a ``json`` language tag in any letter case.

  Returns:
      The decoded JSON value (a dict for a well-formed query description).

  Raises:
      json.JSONDecodeError: If the reply is not JSON once the fence is
          removed, e.g. when the model answers with a plain-text refusal.
  """
  reply = input_dict.strip()

  if reply.startswith('```'):
    reply = reply[3:]
    # The language tag after the opening fence is optional.
    if reply[:4].lower() == 'json':
      reply = reply[4:]

  reply = reply.removesuffix('```').strip()
  return json.loads(reply)

In [89]:
# One model instance is shared by the query-generation and summarisation steps.
# temperature=0 keeps sampling randomness to a minimum.
llm = ChatOpenAI(model="gpt-4o", temperature=0)

# End-to-end pipeline: {"question": str} -> final answer string.
# The dict flowing through the chain gains 'parsed_query' and then
# 'query_result' before it is formatted into the summary prompt.
full_chain = (
    # Step 1: Start from the input dict; re-assigning 'question' from itself
    # makes a missing key fail here with a KeyError.
    RunnablePassthrough.assign(question=lambda x: x["question"])

    # Step 2: Ask the LLM for a query description and parse its reply as JSON
    # (stored under 'parsed_query').
    .assign(
        parsed_query=(
            query_gen_prompt
            | llm
            | StrOutputParser()
            | get_actual_query
        )
    )

    # Step 3: Run the parsed query against MongoDB; run_mongo_query returns a
    # string (result or error message), stored under 'query_result'.
    .assign(
        query_result=RunnableLambda(lambda x: run_mongo_query(x["parsed_query"]))
    )

    # Step 4: Fill the summary prompt from 'question' and 'query_result'
    | final_summary_prompt

    # Step 5: Let the LLM phrase the answer
    | llm

    # Step 6: Reduce the model message to a plain string
    | StrOutputParser()

    # Debugging aid: append `| RunnableLambda(print_and_pass)` here to echo the
    # final output.
)

In [90]:
# Minimal chat loop: type a question, or 'exit' to stop.
while True:
    try:
        user_query = input("You: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or an interrupted kernel ends the session quietly
        # instead of surfacing a traceback.
        break

    if user_query.lower() == 'exit':
        break
    if not user_query:
        # Nothing was typed: don't spend two LLM calls on an empty question.
        continue

    try:
        # To trace every chain step, pass
        # config={"callbacks": [ConsoleCallbackHandler()]} to invoke().
        response = full_chain.invoke({"question": user_query})
        print(f"AI: {response.strip()}")
    except Exception as e:
        # Keep the session alive when a single question fails
        # (e.g. the model's reply could not be parsed as JSON).
        print(f"An error occurred: {e}")

You: how many employee are there
AI: There are 6 employees.
You: what is the total salary of all employee
AI: The total salary of all employees is $670,000.
You: how many employee are there in Engineering department
AI: There are 3 employees in the Engineering department.
You: who is the highest paid employee in company
AI: The highest paid employee in the company is Charlie from the Engineering department, with a salary of $150,000.
You: can you give me count of employees in each department
AI: Certainly! Here is the count of employees in each department:

- Engineering: 3 employees
- Sales: 2 employees
- Marketing: 1 employee
You: exit
