In [None]:
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from typing import TypedDict
import pymongo
from openai import OpenAI
import json
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()
api_key = os.getenv("API_KEY")
if not api_key:
    # Fail fast: with api_key=None the OpenAI client falls back to the
    # OPENAI_API_KEY environment variable and would send that key to x.ai.
    raise RuntimeError("API_KEY is not set; add it to the environment or a .env file")

# Setup OpenAI/Grok model (OpenAI SDK pointed at the x.ai endpoint)
client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")

def generate_grok_response(prompt, context=None, *, model="grok-2-latest", temperature=0):
    """Send one chat-completion request to the Grok model and return the reply text.

    Every call is logged to stdout (prompt, optional context and the reply) so
    the notebook output shows what the model saw and answered.

    Args:
        prompt: Text sent as the user message.
        context: Optional extra data. When truthy it is converted with ``str()``
            and appended to the system message as reference data.
        model: Model name for the chat completions endpoint.
        temperature: Sampling temperature passed to the API (0 by default).

    Returns:
        The model's reply text, or ``""`` if the API returned no content.
    """
    print("\n" + "="*50)
    print("MODEL THINKING:")
    print(f"PROMPT:\n{prompt}")
    if context:
        print(f"\nCONTEXT PROVIDED:\n{context}")

    system_content = "you are a helpful assistant for analyzing procurement data."
    if context:
        # Supply reference data via the system message. Injecting it as an
        # "assistant" turn makes the model treat it as its own prior reply.
        system_content += f"\n\nReference data:\n{context}"
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": prompt},
    ]

    completion = client.chat.completions.create(
        model=model,
        temperature=temperature,
        messages=messages,
    )

    # message.content is optional in the API response; callers call .strip()
    # on the result, so normalize None to "".
    response = completion.choices[0].message.content or ""
    print(f"\nMODEL RESPONSE:\n{response}")
    print("="*50 + "\n")
    return response

# MongoDB connection. Defaults target a local server; override with the
# MONGO_URI / MONGO_DB / MONGO_COLLECTION environment variables.
mongo_client = pymongo.MongoClient(os.getenv("MONGO_URI", "mongodb://localhost:27017/"))
db = mongo_client[os.getenv("MONGO_DB", "procurement_db")]
collection = db[os.getenv("MONGO_COLLECTION", "procurement_data")]

# LangGraph state
class State(TypedDict):
    """Shared state dict passed from node to node in the LangGraph pipeline.

    NOTE(review): on success ``mongo_script`` holds a list (the aggregation
    pipeline), not a dict; the annotation is left unchanged because LangGraph
    reads these annotations when building the graph.
    """

    query: str  # the user's question, as passed to graph.invoke
    query_type: str  # 'procurement_related' or 'out_of_context'
    mongo_script: dict  # aggregation pipeline (list) or {"error": ...}
    mongo_result: dict  # query output, or {"error": ...} on failure / no match
    final_response: str  # answer text produced by the "respond" node

def classify_query(state: State) -> State:
    """Label the user's query as procurement-related or out of context.

    Asks the model for a one-word label and normalizes the reply into
    ``state["query_type"]`` (``"procurement_related"`` or ``"out_of_context"``),
    which the conditional edge after ``classify`` routes on.
    """
    print("\n" + "="*30)
    print("CLASSIFYING QUERY")
    print(f"Original query: {state['query']}")
    
    prompt = f"""Analyze this query and determine if it's about procurement data or not. 
Classify as:
1. 'procurement_related' - if it's about purchases, suppliers, items, spending, contracts
2. 'out_of_context' - if it's unrelated to procurement

Our database contains procurement records with these exact fields:
- Creation Date, Purchase Date, Fiscal Year
- Purchase Order Number, Acquisition Type, Acquisition Method
- Department Name, Supplier Code, Supplier Name
- CalCard, Item Name, Item Description
- Quantity, Unit Price, Total Price
- Various date/time breakdowns (Year, Month, Quarter, etc.)

Query: {state['query']}

Return ONLY the classification ('procurement_related' or 'out_of_context')."""
    
    classification = (generate_grok_response(prompt) or "").strip().lower()
    print(f"Classification result: {classification}")

    # Normalize separators so "out of context" / "out-of-context" match too.
    label = classification.replace("-", "_").replace(" ", "_")
    # Check negative answers first: "unrelated" and "not procurement related"
    # contain "related"/"procurement" and must not be routed to the database.
    negative_markers = ("out_of_context", "unrelated", "not_procurement", "not_related")
    if any(marker in label for marker in negative_markers):
        state["query_type"] = "out_of_context"
    elif "procurement" in label or "related" in label:
        state["query_type"] = "procurement_related"
    else:
        state["query_type"] = "out_of_context"
    
    print(f"Final classification: {state['query_type']}")
    print("="*30 + "\n")
    return state

def generate_mongo_script(state: State) -> State:
    """Ask the model for a MongoDB aggregation pipeline that answers the query.

    On success ``state["mongo_script"]`` holds the pipeline as a non-empty list
    of stage dicts; on failure it holds ``{"error": "..."}``, which the
    execution node does not run.
    """
    print("\n" + "="*30)
    print("GENERATING MONGO QUERY")
    print(f"Working with query: {state['query']}")
    # These are plain (non-f) strings spliced into the f-string below, so their
    # braces reach the prompt verbatim and must NOT be doubled.
    first_stage = """
    {"$match": {...filters...}},
    {"$group": {"_id": null, "maxValue": {"$max": "$FieldName"}}}
"""
    second_stage = """
    {"$match": {"Year": 2013}},
    {"$group": {"_id": null, "maxPrice": {"$max": "$Unit Price"}}}
"""

    prompt = f"""Generate a MongoDB AGGREGATION PIPELINE (as a valid JSON string) for procurement data based on this question:
"{state['query']}"

For questions about maximum/minimum/average values, use this format:
{{
    "pipeline": [
      {first_stage}
    ]
}}

The database collection has these EXACT fields:
- Dates: Creation Date (string), Purchase Date (string), Fiscal Year (string)
- Order Info: Purchase Order Number (string), Acquisition Type (string), Acquisition Method (string)
- Department: Department Name (string)
- Supplier: Supplier Code (float), Supplier Name (string), CalCard (string)
- Items: Item Name (string), Item Description (string)
- Financials: Quantity (float), Unit Price (float), Total Price (float)
- Time Analysis: Year (int), Month (int), Quarter (int), various date parts

Rules:
1. Return ONLY the JSON object with "pipeline" key
2. Use null (not None) for MongoDB
3. No Python syntax or markdown formatting
4. For year filters, use either "Year" (integer) or "Fiscal Year" (string)

Example for "max Unit Price in 2013":
{{
    "pipeline": [
       {second_stage}
    ]
}}"""
    
    response = generate_grok_response(prompt) or ""
    
    print(f"Raw model response: {response}")

    # Tolerate a markdown code fence (``` or ```json) despite rule 3 above.
    text = response.strip()
    if text.startswith("```"):
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rsplit("```", 1)[0].strip()

    try:
        # Parse as JSON rather than evaluating the reply as Python code.
        query_obj = json.loads(text)
        # Accept {"pipeline": [...]} as requested, or a bare [...] array.
        pipeline = query_obj.get("pipeline") if isinstance(query_obj, dict) else query_obj
        # An empty pipeline would match every document, so reject it too.
        if (
            not isinstance(pipeline, list)
            or not pipeline
            or not all(isinstance(stage, dict) for stage in pipeline)
        ):
            raise ValueError("Response should contain a pipeline array of stage objects")

        print(f"Valid MongoDB aggregation generated: {pipeline}")
        state["mongo_script"] = pipeline
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        print(f"Error in query generation: {str(e)}")
        state["mongo_script"] = {"error": f"Failed to generate MongoDB query: {str(e)}"}
    
    print("="*30 + "\n")
    return state

def execute_mongo_query(state: State) -> State:
    """Run the generated MongoDB query and store its output in ``state["mongo_result"]``.

    - A list is treated as an aggregation pipeline. A single result document
      is stored as-is; several documents are stored as a list (at most
      ``max_docs``), so multi-row answers are not cut down to the first row.
    - A dict without an ``"error"`` key is used as a ``find_one`` filter.
    - Anything else (e.g. an ``{"error": ...}`` placeholder) is not executed.

    Database errors and empty results are stored as ``{"error": ...}`` dicts.
    """
    from itertools import islice

    # Upper bound on documents forwarded to the answer prompt.
    max_docs = 20

    print("\n" + "="*30)
    print("EXECUTING MONGO QUERY")
    print(f"Query to execute: {state['mongo_script']}")
    
    mongo_query = state["mongo_script"]
    
    if isinstance(mongo_query, list):  # aggregation pipeline
        try:
            # Closing the cursor explicitly matters because islice may stop early.
            with collection.aggregate(mongo_query) as cursor:
                result = list(islice(cursor, max_docs))
            print(f"Aggregation executed successfully. Result: {result}")
            if not result:
                state["mongo_result"] = {"error": "No matching procurement found"}
            elif len(result) == 1:
                state["mongo_result"] = result[0]
            else:
                state["mongo_result"] = result
        except Exception as e:  # boundary: any DB failure becomes an error result
            print(f"MongoDB execution error: {str(e)}")
            state["mongo_result"] = {"error": f"MongoDB error: {str(e)}"}
    elif isinstance(mongo_query, dict) and "error" not in mongo_query:
        try:
            result = collection.find_one(mongo_query)
            print(f"Query executed successfully. Result: {result}")
            state["mongo_result"] = result or {"error": "No matching procurement found"}
        except Exception as e:  # boundary: any DB failure becomes an error result
            print(f"MongoDB execution error: {str(e)}")
            state["mongo_result"] = {"error": f"MongoDB error: {str(e)}"}
    else:
        print("Invalid query - skipping execution")
        state["mongo_result"] = {"error": "Invalid Mongo query - not executed"}
    
    print("="*30 + "\n")
    return state

def generate_response(state: State) -> State:
    """Produce the final user-facing answer and store it in ``state["final_response"]``.

    Out-of-context queries get a brief note that the system handles
    procurement data. Procurement queries get an answer built from
    ``state["mongo_result"]`` (which may be an ``{"error": ...}`` dict).
    """
    print("\n" + "="*30)
    print("GENERATING FINAL RESPONSE")
    print(f"Original query: {state['query']}")
    context = state.get("mongo_result", {})
    print(f"Available context: {context}")
    
    if state["query_type"] == "out_of_context":
        prompt = f"""The user asked this non-procurement question:
{state['query']}

Provide a helpful but brief response explaining this system handles procurement data."""
    else:
        prompt = f"""The user asked this procurement-related question:
{state['query']}

Database results: {context}

Generate a detailed response that:
1. Directly answers the question
2. Formats numbers properly (e.g., $46.56)
3. Includes all relevant fields from the result
4. Explains if no matching procurement was found
5. Is professional and clear

Example good response:
"Found a matching procurement record:
- Item: tape
- Supplier: NATIONAL OFFICE SOLUTIONS
- Purchase Date: June 26, 2012
- Total Price: $46.56
- Department: Developmental Services" """

    # The database results are already embedded in the prompt, so they are not
    # also passed as a separate context argument (that would send them twice).
    state["final_response"] = (generate_grok_response(prompt) or "").strip()
    
    print(f"Final response prepared: {state['final_response']}")
    print("="*30 + "\n")
    return state

# Build LangGraph
# Each node function takes the State dict and returns it after updating it.
builder = StateGraph(State)

builder.add_node("classify", classify_query)
builder.add_node("generate_mongo_script", generate_mongo_script)
builder.add_node("execute_mongo", execute_mongo_query)
builder.add_node("respond", generate_response)

# Every run starts by classifying the query.
builder.set_entry_point("classify")

def route_classification(state: State) -> str:
    """Pick the edge label to follow after the ``classify`` node.

    The label is ``state["query_type"]`` unchanged; it is looked up in the
    conditional-edge mapping registered for ``classify``.
    """
    label = state["query_type"]
    divider = "=" * 30
    print("\n" + divider)
    print("ROUTING DECISION")
    print(f"Routing based on query type: {label}")
    print(divider + "\n")
    return label

# After classification, route on state["query_type"]: out-of-context queries go
# straight to "respond"; procurement queries first generate and execute a Mongo
# query. The mapping keys must cover every label route_classification returns.
builder.add_conditional_edges(
    "classify",
    route_classification,
    {
        "out_of_context": "respond",
        "procurement_related": "generate_mongo_script"
    }
)

builder.add_edge("generate_mongo_script", "execute_mongo")
builder.add_edge("execute_mongo", "respond")

# "respond" is the terminal node; compile() yields the graph invoked below.
builder.set_finish_point("respond")
graph = builder.compile()



In [16]:
print(graph.get_graph().draw_mermaid())  # print the compiled graph as Mermaid flowchart source


---
config:
  flowchart:
    curve: linear
---
graph TD;
	__start__([<p>__start__</p>]):::first
	classify(classify)
	generate_mongo_script(generate_mongo_script)
	execute_mongo(execute_mongo)
	respond(respond)
	__end__([<p>__end__</p>]):::last
	__start__ --> classify;
	classify -. &nbsp;procurement_related&nbsp; .-> generate_mongo_script;
	classify -. &nbsp;out_of_context&nbsp; .-> respond;
	execute_mongo --> respond;
	generate_mongo_script --> execute_mongo;
	respond --> __end__;
	classDef default fill:#f2f0ff,line-height:1.2
	classDef first fill-opacity:0
	classDef last fill:#bfb6fc



In [29]:
print(graph.get_graph().draw_ascii())  # print an ASCII rendering of the compiled graph



                     +-----------+          
                     | __start__ |          
                     +-----------+          
                            *               
                            *               
                            *               
                      +----------+          
                      | classify |          
                      +----------+          
                   ...            ...       
                 ..                  ..     
               ..                      ..   
+-----------------------+                .. 
| generate_mongo_script |                 . 
+-----------------------+                 . 
            *                             . 
            *                             . 
            *                             . 
    +---------------+                    .. 
    | execute_mongo |                  ..   
    +---------------+                ..     
                   ***            ...       
          

In [37]:
# Smoke test: run one procurement question end-to-end through the compiled
# graph. Only "query" is supplied; the nodes fill in the other State keys.
print("TESTING THE FLOW")
test_state = {"query": "what is max Unit Price in 2013?"}
result = graph.invoke(test_state)
print("\nFINAL RESULT:")
print(result["final_response"])

TESTING THE FLOW

CLASSIFYING QUERY
Original query: what is max Unit Price in 2013?

MODEL THINKING:
PROMPT:
Analyze this query and determine if it's about procurement data or not. 
Classify as:
1. 'procurement_related' - if it's about purchases, suppliers, items, spending, contracts
2. 'out_of_context' - if it's unrelated to procurement

Our database contains procurement records with these exact fields:
- Creation Date, Purchase Date, Fiscal Year
- Purchase Order Number, Acquisition Type, Acquisition Method
- Department Name, Supplier Code, Supplier Name
- CalCard, Item Name, Item Description
- Quantity, Unit Price, Total Price
- Various date/time breakdowns (Year, Month, Quarter, etc.)

Query: what is max Unit Price in 2013?

Return ONLY the classification ('procurement_related' or 'out_of_context').

MODEL RESPONSE:
procurement_related

Classification result: procurement_related
Final classification: procurement_related


ROUTING DECISION
Routing based on query type: procurement_rel