In [None]:
%%capture --no-stderr
%pip install --quiet -U langgraph langchain-community langchain-openai pymongo[srv]==3.12

In [None]:
from dotenv import load_dotenv

# Load environment variables from .env file
# Later cells read MONGODB_URI (get_sales_data) and AZURE_OPENAI_ENDPOINT /
# AZURE_OPENAI_DEPLOYMENT_NAME (both AzureChatOpenAI constructions).
# NOTE(review): the Azure OpenAI and Tavily clients presumably also expect their
# API keys in the environment — confirm the variable names against those libraries.
load_dotenv()

## Define tools

In [None]:
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool
from pymongo import MongoClient
from datetime import datetime

import os
import json
from bson import json_util

# Define tools
@tool
def get_sales_data(query: dict = None):
    """Get sales data from MongoDB using flexible query parameters.
    
    Args:
        query: Custom MongoDB query dict to execute directly
    """
    # NOTE: `@tool` uses the docstring above as the tool description shown to the
    # model, so its wording is left unchanged; implementation notes live here.
    #
    # Return values (all consumed as tool output):
    #   * str  - "Error: ..." when MONGODB_URI is unset or any exception occurs
    #   * str  - "No sales found for query: ..." when the filter matches nothing
    #   * dict - {"count": int, "query": <filter as given>, "sales": [JSON-safe docs]}
    mongo_query = query

    uri = os.getenv("MONGODB_URI")
    if not uri:
        return "Error: MongoDB URI not found in environment variables"

    client = None
    try:
        client = MongoClient(uri)
        collection = client.get_database("mongodbVSCodePlaygroundDB").get_collection("sales")

        # Materialise the cursor while the connection is still open.
        sales_data = list(collection.find(mongo_query))

        if not sales_data:
            return f"No sales found for query: {mongo_query}"

        # Round-trip through bson.json_util so ObjectId / datetime values become
        # plain JSON-compatible structures.
        results = json.loads(json_util.dumps(sales_data))

        return {
            "count": len(sales_data),
            "query": mongo_query,
            "sales": results
        }

    except Exception as e:
        return f"Error retrieving sales data: {str(e)}"

    finally:
        # Always release the connection pool; previously a failing find() or
        # serialization step skipped client.close() and leaked the client.
        if client is not None:
            client.close()

In [None]:
from langchain_core.messages import AIMessage

# Create a tool node with the get_sales_data tool
tool_node = ToolNode([get_sales_data])

# Smoke-test filter: match sales documents whose item is 'xyz'
query = {'item': 'xyz'}

print("\nMethod using AIMessage with tool calls")
# Hand-built AIMessage carrying a single get_sales_data tool call, so the tool
# node can be exercised without involving an LLM.
tool_call_message = AIMessage(
    content="",
    tool_calls=[
        {
            "id": "1",
            "type": "function",
            "name": "get_sales_data",
            "args": {
                "query": query
            }
        }
    ]
)

# The tool node is invoked with a state dict whose "messages" list ends with the
# message containing the tool call.
result = tool_node.invoke({"messages": [tool_call_message]})
print("Tool node result: " + str(result))

## Define MongoDB Query Agent

In [None]:
from typing import Literal
from langchain_openai import AzureChatOpenAI
from langgraph.types import Command
import re

def _parse_mongo_query(query_str: str) -> dict:
    """Safely convert the LLM's textual reply into a MongoDB filter dict.

    Only a restricted literal grammar is accepted: dict / list / tuple displays,
    constants, signed numbers, the JSON names true / false / null, and calls to
    `datetime(...)` (or `datetime.datetime(...)`) whose arguments are themselves
    accepted expressions. Anything else is rejected instead of being executed.

    Args:
        query_str: The cleaned-up model reply.

    Returns:
        The parsed filter dictionary.

    Raises:
        ValueError: The text is not valid syntax, uses an unsupported construct,
            or does not evaluate to a dict.
    """
    import ast
    from datetime import datetime

    json_names = {"true": True, "false": False, "null": None}

    def convert(node):
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.Dict):
            if any(key is None for key in node.keys):  # `**other` unpacking
                raise ValueError("Dict unpacking is not allowed in a generated query")
            return {convert(k): convert(v) for k, v in zip(node.keys, node.values)}
        if isinstance(node, ast.List):
            return [convert(element) for element in node.elts]
        if isinstance(node, ast.Tuple):
            return tuple(convert(element) for element in node.elts)
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            operand = convert(node.operand)
            if isinstance(operand, bool) or not isinstance(operand, (int, float)):
                raise ValueError("Unary +/- is only allowed on numbers")
            return -operand if isinstance(node.op, ast.USub) else operand
        if isinstance(node, ast.Name) and node.id in json_names:
            return json_names[node.id]
        if isinstance(node, ast.Call):
            func = node.func
            is_datetime = (isinstance(func, ast.Name) and func.id == "datetime") or (
                isinstance(func, ast.Attribute)
                and func.attr == "datetime"
                and isinstance(func.value, ast.Name)
                and func.value.id == "datetime"
            )
            if is_datetime and all(kw.arg is not None for kw in node.keywords):
                args = [convert(arg) for arg in node.args]
                kwargs = {kw.arg: convert(kw.value) for kw in node.keywords}
                return datetime(*args, **kwargs)
        raise ValueError(f"Unsupported expression in generated query: {ast.dump(node)}")

    try:
        parsed = convert(ast.parse(query_str, mode="eval").body)
    except (SyntaxError, TypeError) as exc:
        raise ValueError(f"Could not parse generated MongoDB query {query_str!r}: {exc}") from exc

    if not isinstance(parsed, dict):
        raise ValueError(f"Generated MongoDB query is not a dictionary: {query_str!r}")
    return parsed


def mongo_query_agent(state) -> Command[Literal["__end__"]]:
    """Agent that processes user queries about sales data and uses LLM to generate a mongo query based on schema.

    The content of the last message in ``state["messages"]`` is sent to an Azure
    OpenAI chat model together with the collection schema. The reply is stripped
    of markdown fences, parsed by ``_parse_mongo_query`` (never ``eval``-ed), and
    wrapped in an AIMessage carrying a single ``get_sales_data`` tool call.

    Args:
        state: Graph state; must contain a "messages" list whose last element
            exposes ``.content``.

    Returns:
        dict with key "messages": the incoming messages plus the new AIMessage.

    Raises:
        KeyError: AZURE_OPENAI_ENDPOINT or AZURE_OPENAI_DEPLOYMENT_NAME is unset.
        ValueError: The model reply is not a supported dictionary literal.
    """
    # NOTE(review): the return annotation says Command[...] but a plain state
    # update dict is returned; the annotation is kept as-is because the graph
    # library may inspect it — confirm before changing.

    # Extract the user message
    user_message = state["messages"][-1].content

    # Define the MongoDB schema for the LLM
    schema_info = """
    MongoDB Collection Schema:
    {
        'item': string,      // Product identifier like 'abc', 'xyz', etc.
        'price': number,     // Price of the item (e.g., 10, 7.5)
        'quantity': number,  // Number of items sold (e.g., 5, 10)
        'date': Date     // Date of the sale in ISO format 
    }
    """

    # Create system prompt for the LLM
    system_prompt = f"""
        You are a MongoDB query assistant. Based on user requests, generate appropriate MongoDB queries.
        {schema_info}

        Analyze the user query, determine what information they need from the sales data, and formulate a MongoDB query dictionary.
        Your response should be a valid Python dictionary that can be used as a MongoDB query.
        
        IMPORTANT: Return ONLY the Python dictionary without any additional text, markdown formatting, or code blocks.
        
        For example, if the user asks about sales in 2014, you might return: 
        {{'date': {{'$gte': datetime(2014, 1, 1), '$lt': datetime(2015, 1, 1)}}}}
        
        If they ask about an item 'xyz', you might return:
        {{'item': 'xyz'}}
        
        Be specific and precise with your queries.
        """

    # Call the LLM to analyze the user query
    llm = AzureChatOpenAI(
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
        openai_api_version="2024-12-01-preview",
    )

    chat_prompt = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Generate a MongoDB query for: {user_message}"},
    ]

    llm_response = llm.invoke(chat_prompt)
    query_str = llm_response.content

    # Clean up the response - remove markdown code blocks if present
    query_str = re.sub(r'```(?:python|json)?\n?', '', query_str)
    query_str = re.sub(r'```\n?', '', query_str)
    query_str = query_str.strip()

    print("Generated MongoDB query:", query_str)

    # SECURITY: the model output is untrusted text. It used to be passed to
    # eval(), which would execute arbitrary code; it is now parsed with a
    # whitelist-based AST walker that still supports the datetime(...) form
    # shown in the prompt example.
    mongo_query = _parse_mongo_query(query_str)

    # Return the result with a tool call to get_sales_data
    return {"messages": state["messages"] + [AIMessage(
        content="",
        tool_calls=[
            {
                "id": "1",
                "type": "function",
                "name": "get_sales_data",
                "args": {
                    "query": mongo_query
                }
            }
        ]
    )]}

## Define the executor agent

In [None]:
from langchain import hub
from langchain_openai import AzureChatOpenAI
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage

from langchain_community.tools.tavily_search import TavilySearchResults
# Tools offered to the executor agent: a web search tool capped at 3 results and
# the MongoDB sales lookup defined earlier in the notebook.
tools = [TavilySearchResults(max_results=3), get_sales_data]

# Custom prompt for MongoDB query processing
# (passed to create_react_agent below via its `prompt` argument)
prompt = """
You are a MongoDB data analyst assistant. Your job is to:
1. Process MongoDB queries provided to you
2. Use the get_sales_data tool to retrieve information from the MongoDB database
3. Analyze the results to provide insights about the sales data
4. If needed, use search tools to provide additional context for the data

When using the get_sales_data tool:
- The tool accepts a MongoDB query dictionary as input
- Format your response to highlight key insights from the data
- Include counts, summaries, and analysis of the data when possible

Be thorough in your analysis but concise in your explanations.
"""

# Choose the LLM that will drive the agent
# Endpoint and deployment name come from the environment; a missing variable
# raises KeyError here because os.environ[...] is used rather than os.getenv.
llm = AzureChatOpenAI(
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
    openai_api_version="2024-12-01-preview",
)

# Prebuilt ReAct-style agent; execute_step() calls agent_executor.invoke(...).
agent_executor = create_react_agent(model=llm, tools=tools, prompt=prompt)

In [None]:
def execute_step(state):
    """Run the generated MongoDB query through the executor agent.

    Looks at the last message in ``state["messages"]`` for a ``get_sales_data``
    tool call, describes its ``query`` argument in a human message, and lets
    ``agent_executor`` fetch and analyse the data.

    Args:
        state: Graph state with a non-empty "messages" list.

    Returns:
        dict with key "messages": the incoming messages plus either the agent's
        final message, or an AIMessage explaining that no query could be
        extracted (its text includes the content of the first message).
    """
    messages = state["messages"]

    # The last message should be the AIMessage produced by mongo_query_agent,
    # carrying the generated filter as a tool-call argument.
    last_message = messages[-1]
    mongo_query = None

    for tool_call in getattr(last_message, "tool_calls", None) or []:
        if tool_call["name"] == "get_sales_data" and "query" in tool_call["args"]:
            mongo_query = tool_call["args"]["query"]
            break

    # Compare against None explicitly: an empty dict is a valid MongoDB filter
    # ("match every document") and must not be mistaken for a missing query.
    if mongo_query is None:
        user_message = messages[0].content
        return {
            "messages": messages + [AIMessage(content=f"Could not extract MongoDB query from previous step. Original query: {user_message}")]
        }

    # Create a description of what we're looking for based on the MongoDB query
    query_description = f"Get sales data matching this MongoDB query: {mongo_query}"

    # Use the agent_executor to run the query through LLM reasoning
    agent_response = agent_executor.invoke({"messages": [HumanMessage(content=query_description)]})

    # Only the agent's final message is forwarded; its intermediate tool
    # traffic stays inside the sub-agent run.
    final_response = agent_response["messages"][-1]

    # Return the original messages plus the agent's response so the approval
    # agent can inspect the result.
    return {
        "messages": messages + [final_response]
    }

## Define an approval agent

In [None]:
from typing import Literal, TypedDict, List, Any
import json
import re
from langchain_core.tools import tool
from langchain_core.messages import AIMessage
from langgraph.types import Command

# Define a state class to track approval status
# NOTE(review): nothing in the visible notebook references ApprovalState (the graph
# is built on MessagesState), so the field meanings below are inferred from names.
class ApprovalState(TypedDict):
    messages: List  # message history; element type is not constrained here
    approval_status: str  # presumably the decision label — TODO confirm
    reason: str  # presumably the rationale for the decision — TODO confirm
    count: int  # presumably the number of matching records — TODO confirm

# Mock approval/denial API tools
@tool
def approve_request(data: dict):
    """Approve a request based on the given data.
    
    Args:
        data: The data supporting the approval decision
    """
    # NOTE: `@tool` uses the docstring above as the tool description, so its
    # wording is left unchanged.
    #
    # The payload is not guaranteed to carry "count": approval_agent can forward
    # a result dict that only has a "sales" list. Derive the count from that list
    # instead of raising KeyError; fall back to 0 when neither is available.
    count = data.get("count")
    if count is None:
        sales = data.get("sales")
        count = len(sales) if isinstance(sales, list) else 0
    return f"Request APPROVED with {count} matching records found"

@tool
def deny_request(reason: str):
    """Deny a request with a provided reason.
    
    Args:
        reason: The reason for denial
    """
    # Mock denial endpoint: no side effects, it only echoes the reason back in
    # a fixed-format notice. (Docstring kept verbatim: it is the tool description.)
    denial_notice = f"Request DENIED. Reason: {reason}"
    return denial_notice

def approval_agent(state) -> Command[Literal["__end__"]]:
    """Agent that evaluates the MongoDB query results and makes approval decisions.

    The decision is heuristic and based on the last message in the state:

    1. If the message has non-empty text content, look for data-related keywords
       and for a phrase like "<N> records"; an explicit count of 0 or a
       standalone "no" next to result/match/found/record means "no results".
    2. If step 1 found nothing and the message exposes ``tool_outputs``, inspect
       the first output's "content" (string, JSON string or dict).

    Args:
        state: Graph state with a non-empty "messages" list.

    Returns:
        dict with key "messages": the incoming messages plus an AIMessage holding
        one tool call, either ``approve_request`` (args ``{"data": {...}}``, always
        including "count") or ``deny_request`` (args ``{"reason": str}``).
    """
    # NOTE(review): the return annotation says Command[...] but a plain state
    # update dict is returned; kept as-is because the graph library may inspect
    # it — confirm before changing.

    # Debug - print the state structure to understand what we're working with
    print("\nApproval agent received state:")
    print(f"State type: {type(state)}")
    print(f"State keys: {state.keys() if isinstance(state, dict) else 'Not a dict'}")
    print(f"Messages count: {len(state['messages'])}")

    # Get the latest message which should contain the tool execution result or agent_executor's response
    last_message = state["messages"][-1]
    print(f"Last message type: {type(last_message)}")

    # Initialize result tracking variables
    has_results = False
    result_count = 0
    error_message = None
    result_data = None

    # --- Step 1: heuristics over free-text content -------------------------
    # Only plain strings are analysed; other content shapes (e.g. a list of
    # content blocks) have no .lower() and are skipped rather than crashing.
    content = getattr(last_message, "content", None)
    if isinstance(content, str) and content:
        print(f"Content preview: {content[:100]}...")
        lowered = content.lower()

        # Look for indicators of successful results in the content
        has_data_indicators = ['found', 'result', 'data', 'sales', 'record', 'item']
        has_results = any(indicator in lowered for indicator in has_data_indicators)

        # Look for numbers that might indicate result count
        count_matches = re.findall(r'(\d+)\s+(?:record|item|result|sale|match)', lowered)
        if count_matches:
            result_count = int(count_matches[0])
            # An explicit "0 records" is a negative outcome, not an approval.
            has_results = result_count > 0
            if not has_results:
                error_message = "No matching records found according to analysis"
        elif re.search(r'\bno\b', lowered) and any(neg in lowered for neg in ['result', 'match', 'found', 'record']):
            # Word-boundary match: a bare substring test would also fire on
            # words such as "another", "known" or "not".
            has_results = False
            error_message = "No matching records found according to analysis"

    # --- Step 2: structured tool outputs, if the message carries any --------
    tool_outputs = getattr(last_message, "tool_outputs", None)
    if not has_results and tool_outputs:
        print(f"Tool outputs: {tool_outputs}")
        tool_output = tool_outputs[0]

        # Extract the content from the tool output
        if "content" in tool_output:
            output_content = tool_output["content"]
            print(f"Content type: {type(output_content)}")
            print(f"Content preview: {str(output_content)[:100]}...")

            if isinstance(output_content, str):
                # Check for error or no results messages
                if "No sales found" in output_content:
                    has_results = False
                    error_message = "No matching records found"
                elif "Error" in output_content:
                    has_results = False
                    error_message = f"Error in query: {output_content}"
                else:
                    # Try to parse as JSON
                    try:
                        result_data = json.loads(output_content)
                        if isinstance(result_data, dict) and "count" in result_data:
                            result_count = result_data["count"]
                            has_results = result_count > 0
                    except json.JSONDecodeError:
                        # Not JSON, check if it looks like it has data
                        has_results = "sales" in output_content.lower() and "No sales" not in output_content
                        if has_results:
                            result_count = 1  # Default count if we can't determine
            elif isinstance(output_content, dict):
                # It's already a dict
                result_data = output_content
                if "count" in result_data:
                    result_count = result_data["count"]
                    has_results = result_count > 0
                elif isinstance(result_data.get("sales"), list):
                    # Try to infer count from sales data
                    result_count = len(result_data["sales"])
                    has_results = result_count > 0

    # Make approval decision based on results
    print(f"Results analysis: has_results={has_results}, count={result_count}")

    if has_results:
        print("DECISION: Approving request")
        # Forward the parsed result when it is a dict, but always guarantee a
        # "count" key because approve_request reports it.
        approval_data = dict(result_data) if isinstance(result_data, dict) else {}
        approval_data.setdefault("count", result_count)
        return {"messages": state["messages"] + [AIMessage(
            content="",
            tool_calls=[{
                "id": "1",
                "type": "function",
                "name": "approve_request",
                "args": {"data": approval_data}
            }]
        )]}

    print("DECISION: Denying request")
    # Send denial with reason
    reason = error_message if error_message else "No matching records found or unable to determine results"
    return {"messages": state["messages"] + [AIMessage(
        content="",
        tool_calls=[{
            "id": "1",
            "type": "function",
            "name": "deny_request",
            "args": {"reason": reason}
        }]
    )]}


## Build the graph

In [None]:
from langgraph.graph import StateGraph, MessagesState, START

# Create a tool node for the approval/denial tools
# (it executes the approve_request / deny_request call emitted by approval_agent)
approval_tool_node = ToolNode([approve_request, deny_request])

# Build the graph
# The state schema is MessagesState, so every node receives and returns a dict
# with a "messages" key.
builder = StateGraph(MessagesState)
builder.add_node("mongo_query_agent", mongo_query_agent)
builder.add_node("execute_step", execute_step)
builder.add_node("approval_agent", approval_agent)
builder.add_node("approval_tool_executor", approval_tool_node)

# Connect the nodes
# Strictly linear pipeline:
#   START -> mongo_query_agent -> execute_step -> approval_agent
#         -> approval_tool_executor -> end
builder.add_edge(START, "mongo_query_agent")
builder.add_edge("mongo_query_agent", "execute_step")
builder.add_edge("execute_step", "approval_agent")
builder.add_edge("approval_agent", "approval_tool_executor")
# "__end__" is written as a string literal here; presumably it equals the
# library's END constant — TODO confirm and prefer importing END.
builder.add_edge("approval_tool_executor", "__end__")

# Compile the network
network = builder.compile()

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a PNG and show it inline.
# NOTE(review): PNG rendering of Mermaid diagrams may depend on an external
# rendering service or extra packages — confirm before running offline.
display(Image(network.get_graph(xray=True).draw_mermaid_png()))

In [None]:
# Test the agent with some example queries
from langchain_core.messages import HumanMessage

# Example: Query for a specific item in 2014
# Runs the whole graph; this needs the Azure OpenAI and MongoDB settings read by
# the earlier cells.
response = network.invoke({"messages": [HumanMessage(content="How many xyz items were sold in 2014?")]})
print("\nQuery Response:")
# The last message is whatever approval_tool_executor appended, i.e. the
# approve/deny tool result.
print(response["messages"][-1].content)

# Show all messages in the chain to see the approval decision
print("\nComplete message chain:")
for i, msg in enumerate(response["messages"]):
    print(f"\nMessage {i+1} ({msg.__class__.__name__}):")
    if hasattr(msg, "content") and msg.content:
        print(f"Content: {msg.content}")
    if hasattr(msg, "tool_calls") and msg.tool_calls:
        print(f"Tool calls: {msg.tool_calls}")
    # Defensive: nothing in this notebook sets a `tool_outputs` attribute, so this
    # branch only prints if the message class happens to provide one.
    if hasattr(msg, "tool_outputs") and msg.tool_outputs:
        print(f"Tool outputs: {msg.tool_outputs}")

## Test with specific MongoDB query types

In [None]:
# Test with direct MongoDB query examples to test agent_executor's reasoning

# Example 1: Test with a query for specific item type
test_query = {'item': 'xyz'}
print("\nTesting direct MongoDB query processing with agent_executor:")
print(f"Query: {test_query}")

# Manually testing the execute_step function with a mock state
from langchain_core.messages import HumanMessage, AIMessage

# Create a mock state with a tool call containing our test query
# (mirrors the shape mongo_query_agent returns: user message followed by an
# AIMessage whose tool call carries the filter under args["query"])
mock_state = {
    "messages": [HumanMessage(content="Find sales data for xyz items"), 
                AIMessage(
                    content="",
                    tool_calls=[{
                        "id": "1",
                        "type": "function",
                        "name": "get_sales_data",
                        "args": {"query": test_query}
                    }]
                )]
}

# Execute the step with our mock state
# This bypasses the graph and the query-generation LLM call, but still invokes
# agent_executor, so it is not an offline test.
result = execute_step(mock_state)
print("\nAgent executor result:")
print(result["messages"][-1].content)

In [None]:
# Test the complete pipeline with the updated execute_step function
print("\nTesting full pipeline with agent_executor integration:")

# Example with a more complex query that requires reasoning
# NOTE: despite its name, `complex_query` holds the final graph state returned by
# network.invoke(...), not the query text.
complex_query = network.invoke({"messages": [HumanMessage(content="What was the average price of items sold in quantities greater than 5?")]})
print("\nComplex Query Response:")
print(complex_query["messages"][-1].content)

# Show the progression through the pipeline
print("\nComplete message chain for complex query:")
for i, msg in enumerate(complex_query["messages"]):
    print(f"\nMessage {i+1} ({msg.__class__.__name__}):")
    if hasattr(msg, "content") and msg.content:
        # Long contents are cut to 150 characters to keep the cell output readable.
        content_preview = msg.content[:150] + '...' if len(msg.content) > 150 else msg.content
        print(f"Content: {content_preview}")
    if hasattr(msg, "tool_calls") and msg.tool_calls:
        print(f"Tool calls: {msg.tool_calls}")
    # Defensive: only prints when the message object exposes `tool_outputs`; at
    # most the first two entries are shown.
    if hasattr(msg, "tool_outputs") and msg.tool_outputs:
        print(f"Tool outputs: {msg.tool_outputs[:2]}... (truncated)" if len(msg.tool_outputs) > 2 else f"Tool outputs: {msg.tool_outputs}")