<h1>Excel Worker with LangGraph</h1>

In [None]:
%pip install langchain langchain-openai langgraph pandas numpy openpyxl python-dotenv duckdb

In [None]:
%pip install --upgrade duckdb

In [None]:
from typing import TypedDict, Optional, Dict, List, Union
import logging
import json
import os
import pandas as pd
import numpy as np
import duckdb
from langchain.tools import tool
from langgraph.graph import StateGraph, END
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver

from dotenv import load_dotenv

_ = load_dotenv()

In [None]:
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
QUERY_GENERATION_PROMPT = """You are an Excel analysis expert that generates SQL or Pandas queries to analyze data.

Review the preview_data to understand the available columns and data types.
**Evaluate the user_query. If it is simple and can be executed efficiently using Pandas, prioritize simple_dataframe_query.**
**Use complex_duckdb_query for more complex queries that require SQL operations.**
**Explain in the output `reasoning` value, step by step, how you designed the query itself and NOT how you chose between a Pandas and a DuckDB query.**

Tools Available:
1. simple_dataframe_query: For ANY operation that can be easily done with Pandas
   Input: {"file_name": "example.xlsx", "query": "df.query('column > 0').count()"}
   
2. complex_duckdb_query: For complex SQL operations (GROUP BY, aggregations). ONLY for operations that cannot be done easily in Pandas
   Input: {"file_name": "example.xlsx", "query": "SELECT * FROM data"}

Data Analysis Rules:
1. NULL/Empty Value Handling:
   - When calculating averages across columns:
     * Sum only non-null values
     * Divide by count of non-null values
   - Use CAST(column IS NOT NULL AS INTEGER) to count valid values
   - Always use NULLIF for safe division
   - Handle NULLs before aggregating

2. Query Structure Requirements:
   - Break complex calculations into CTEs
   - Calculate row-level metrics first
   - Then group and aggregate
   - Include validation counts
   - Example:
     WITH base_metrics AS (
       SELECT *,
         COALESCE(value, 0) AS clean_value,
         CAST(value IS NOT NULL AS INTEGER) AS valid_count
       FROM data
     ),
     row_metrics AS (
       SELECT row_id,
         SUM(clean_value) / NULLIF(SUM(valid_count), 0) AS row_avg
       FROM base_metrics
       GROUP BY row_id
     )
     SELECT * FROM row_metrics

3. DuckDB SQL Rules:
   - Quote columns with spaces: "Column Name"
   - Include all non-aggregated columns in GROUP BY
   - Reference table as 'data'
   - Use proper type casting
   - Add range checks for numeric operations

4. Pandas Rules:
   - Reference DataFrame as 'df'
   - NEVER use NumPy operations or `np.any()`, `np.all()`. Use only pure Pandas functions.
   - Handle NULLs with fillna()
   - Use proper data types
   - Always store newly created columns explicitly before using them in calculations.
   - Ensure the query executes in one go without needing multiple steps.
   - Avoid referencing a new column inside the same assign() call if it hasn’t been created yet.
   - The output should always be a valid Pandas statement that can execute without modification.

Pandas Query Patterns - USE THESE:
✓ Filtering: df.query('column == value')
✓ Counting: df.query('condition').count()
✓ Aggregation: df.query('condition').agg({'col': 'mean'})
✓ Multiple conditions: df.query('col1 > 0 and col2 < 10')

NEVER USE THESE PATTERNS:
✗ df[df['column'] == value]  # No bracket filtering
✗ np.any(), np.all()         # No numpy operations
✗ df.loc[], df.iloc[]        # Avoid direct indexing

**EXTREMELY IMPORTANT OUTPUT FORMAT RULES**:
You must return ONLY a raw JSON object. 
DO NOT wrap the JSON in ```json or ``` or any other markers.
DO NOT add any text before or after the JSON. 
DO NOT include any explanations or additional text outside the JSON.

REQUIRED OUTPUT FORMAT:
{
    "query": "<your query>",
    "llm_prompt": "<A clear description of what the query does, reflecting the original user_query.>",
    "tool": "<tool_name>",
    "reasoning": "<Step-by-step explanation of how you designed the query>"
}
"""

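The NULL-handling rules in the prompt above (sum only non-null values, divide by the count of non-null values, guard the division with NULLIF) can be tried out on a tiny table. The sketch below is only an illustration: it uses the standard-library `sqlite3` module as a stand-in for DuckDB so that it runs without extra packages, and the table layout (`ticker`, `m1`..`m3`) is hypothetical, not taken from the real workbook.

```python
import sqlite3

# Hypothetical wide table: one row per ticker, one column per month.
rows = [
    ("AAA", 0.5, None, 0.75),   # one empty month
    ("BBB", 0.25, 0.25, 0.5),   # no empty months
    ("CCC", None, None, None),  # no data at all
]

con = sqlite3.connect(":memory:")  # sqlite3 stands in for DuckDB here
con.execute("CREATE TABLE data (ticker TEXT, m1 REAL, m2 REAL, m3 REAL)")
con.executemany("INSERT INTO data VALUES (?, ?, ?, ?)", rows)

query = """
WITH base_metrics AS (
    SELECT
        ticker,
        -- Sum only the non-null months
        COALESCE(m1, 0) + COALESCE(m2, 0) + COALESCE(m3, 0) AS total,
        -- Count the months that actually hold a value
        CAST(m1 IS NOT NULL AS INTEGER)
          + CAST(m2 IS NOT NULL AS INTEGER)
          + CAST(m3 IS NOT NULL AS INTEGER) AS valid_count
    FROM data
)
SELECT ticker, total / NULLIF(valid_count, 0) AS row_avg
FROM base_metrics
ORDER BY ticker
"""
result = con.execute(query).fetchall()
con.close()

for ticker, row_avg in result:
    print(ticker, row_avg)
```

A ticker with no values at all gets a NULL average instead of a division-by-zero error, which is the behaviour the prompt asks the model to preserve.
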
In [None]:
VALIDATE_QUERY_PROMPT = """You are an AI assistant validating and improving queries for Excel data analysis.

**VERY IMPORTANT RULES**:
1. WHEN execution_result contains an **error** related to Pandas query execution, you MUST generate an SQL query for the given user_query!
2. DO NOT, I REPEAT, DO NOT validate or change Pandas queries! Just forward them in the "query" value in the output JSON.

Tools Available:
1. simple_dataframe_query: (PREFERRED) For ANY operation that can be easily done with Pandas
   Input: {"file_name": "example.xlsx", "query": "df.query('column > 0').count()"}
   
2. complex_duckdb_query: For complex SQL operations (GROUP BY, aggregations). ONLY for operations that cannot be done easily in Pandas
   Input: {"file_name": "example.xlsx", "query": "SELECT * FROM data"}

**SQL Query Validation:**
1. NULL Handling:
   - Verify NULLs are properly excluded from calculations
   - Check for NULLIF in divisions
   - Confirm COALESCE usage for NULL defaults
   - Validate empty value handling in aggregations

2. Aggregation Validation:
   - All non-aggregated columns must be in GROUP BY
   - Verify aggregation functions match business logic
   - Check percentage calculations sum correctly
   - Validate count metrics against row counts

3. Query Structure:
   - Use CTEs for complex calculations
   - Break operations into logical steps:
     * Clean and validate data
     * Calculate row-level metrics
     * Perform grouping and aggregation
   - Include data validation steps
   - Add range checks for numeric operations

4. Error Prevention:
   - Add NULLIF for division operations
   - Include proper type casting
   - Validate numeric operations
   - Handle edge cases explicitly
   
5. Complex Calculation Patterns:
   WITH data_validation AS (
     -- Clean and validate input data
   ),
   row_metrics AS (
     -- Calculate row-level statistics
   ),
   aggregated_results AS (
     -- Perform final aggregations
   )


**EXTREMELY IMPORTANT OUTPUT FORMAT RULES**:
You must return ONLY a raw JSON object. 
DO NOT wrap the JSON in ```json or ``` or any other markers.
DO NOT add any text before or after the JSON. 
DO NOT include any explanations or additional text outside the JSON.

REQUIRED OUTPUT FORMAT:
{
    "tool": "tool_name",
    "query": "<provide here either the original generated query or an improved version>"
}
"""

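Both prompts insist on a raw JSON object with no code fence around it, but models do not always comply. The helper below is a minimal sketch of a tolerant parser, assuming the reply is either bare JSON or JSON wrapped in a single fence. The name `parse_llm_json` is hypothetical and is not part of this notebook or of LangChain.

```python
import json
import re

# Built indirectly so this snippet can itself sit inside a fenced block.
FENCE = "`" * 3


def parse_llm_json(raw: str) -> dict:
    """Parse a JSON object from an LLM reply, tolerating one surrounding code fence."""
    text = raw.strip()
    # Match: fence, optional "json" tag, the payload, closing fence.
    pattern = re.escape(FENCE) + r"(?:json)?\s*(.*?)\s*" + re.escape(FENCE)
    match = re.fullmatch(pattern, text, flags=re.DOTALL)
    if match:
        text = match.group(1)
    parsed = json.loads(text)  # raises json.JSONDecodeError on malformed input
    if not isinstance(parsed, dict):
        raise ValueError("Expected a JSON object")
    return parsed


plain = parse_llm_json('{"tool": "simple_dataframe_query", "query": "df.count()"}')
fenced = parse_llm_json(FENCE + 'json\n{"tool": "complex_duckdb_query", "query": "SELECT 1"}\n' + FENCE)
print(plain["tool"], fenced["query"])
```

Such a helper could replace the direct `json.loads(response.content)` calls in the nodes below if fenced replies turn out to be a problem in practice.
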
In [None]:

class AgentState(TypedDict):
    user_input: str
    query: Optional[str]
    file_name: str
    preview_data: Optional[Dict[str, List]]
    query_result: Optional[Dict]
    llm_prompt: Optional[str]
    tool: Optional[str]  # Track selected tool
    iterations_count: int
    error: Optional[str]  # Track errors

@tool
def load_preview_data(file_name: str) -> dict:
    """Examine Excel file structure and data types."""
    try:
        if not file_name:
            raise ValueError("File name must be provided")
            
        df = pd.read_excel(os.path.join(os.getcwd(), file_name), nrows=1)
        df = df.replace(r'^\s*$', None, regex=True)  # Convert empty strings to None
        df = df.replace(['nan', 'NaN', 'null'], None)  # Convert NaN strings to None
        df = df.where(pd.notnull(df), None)  # Convert pandas NaN to None
        
        return {
            "columns": df.columns.tolist(),
            "dtypes": df.dtypes.astype(str).to_dict(),
            "sample_rows": df.to_dict(orient="records")
        }
    except Exception as e:
        raise ValueError(f"Failed to examine Excel structure: {str(e)}")

@tool
def complex_duckdb_query(input: dict) -> dict:
    """Execute complex SQL operations with DuckDB."""
    try:
        if not isinstance(input, dict):
            return {"error": "Input must be a dictionary"}
            
        file_name = input.get("file_name")
        query = input.get("query", "").strip()
        
        if not file_name or not query:
            return {"error": "Both file_name and query required"}
            
        # Read and preprocess DataFrame
        df = pd.read_excel(os.path.join(os.getcwd(), file_name))
        df = df.replace(r'^\s*$', None, regex=True)
        df = df.replace(['nan', 'NaN', 'null'], None)
        df = df.where(pd.notnull(df), None)
        
        # Connect to DuckDB and register DataFrame
        con = duckdb.connect()
        # Register as 'data' to match the prompt
        con.register("data", df)
        
        try:
            # Execute query
            result = con.execute(query).fetchdf()
            
            # Process results
            if result is None:
                return {"result": None}
                
            # Handle different result types
            if isinstance(result, pd.DataFrame):
                # Clean the results
                result = result.replace([float('inf'), -float('inf')], None)
                result = result.where(pd.notna(result), None)
                
                # Convert object columns to strings where needed
                for col in result.select_dtypes(include=['object']).columns:
                    result[col] = result[col].apply(
                        lambda x: str(x) if x is not None else None
                    )
                
                return {
                    "result": {
                        "columns": result.columns.tolist(),
                        "rows": result.to_dict(orient="records")
                    }
                }
            else:
                return {"result": str(result)}
                
        except Exception as e:
            return {"error": f"DuckDB query error: {str(e)}"}
            
    except Exception as e:
        return {"error": f"Error executing query: {str(e)}"}
    finally:
        if 'con' in locals():
            con.close()

@tool
def simple_dataframe_query(input: dict) -> dict:
    """Execute simple Pandas operations."""
    try:
        if not isinstance(input, dict):
            return {"error": "Input must be a dictionary"}
            
        file_name = input.get("file_name")
        query = input.get("query", "").strip()
        
        if not file_name or not query:
            return {"error": "Both file_name and query required"}
            
        # Read and preprocess DataFrame
        df = pd.read_excel(os.path.join(os.getcwd(), file_name))
        df = df.replace(r'^\s*$', None, regex=True)
        df = df.replace(['nan', 'NaN', 'null'], None)
        df = df.where(pd.notnull(df), None)
        
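        # Build a restricted execution environment. Emptying __builtins__ only
        # limits what the expression can reach; it is not a real sandbox, so
        # run this only on input you trust.
        safe_globals = {
            "df": df,
            "pd": pd,
            "np": np,  # Exposed for calculations
            "__builtins__": {}  # Remove built-ins such as open() and __import__()
        }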
        
        # Execute query
        result = eval(query, safe_globals, {})
        
        # Process different result types
        if isinstance(result, pd.DataFrame):
            result = result.replace([float('inf'), -float('inf')], None)
            result = result.where(pd.notna(result), None)
            
            return {
                "result": {
                    "type": "DataFrame",
                    "columns": result.columns.tolist(),
                    "rows": result.to_dict(orient="records")
                }
            }
        elif isinstance(result, pd.Series):
            result = result.replace([float('inf'), -float('inf')], None)
            result = result.where(pd.notna(result), None)
            
            return {
                "result": {
                    "type": "Series",
                    "name": result.name,
                    "values": result.tolist()
                }
            }
        else:
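            # Handle scalar results (pd.isna returns an array for list-likes,
            # which cannot be used directly in an if statement)
            if pd.api.types.is_scalar(result) and pd.isna(result):
                result = None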
            return {
                "result": {
                    "type": "scalar",
                    "value": result
                }
            }
            
    except Exception as e:
        return {"error": f"Error executing query: {str(e)}"}


def generate_query_node(state: AgentState) -> AgentState:
    """Generates initial query based on user input."""
    try:
        # Ensure required fields exist
        if not state.get("file_name") or not state.get("user_input"):
            raise ValueError("Missing required fields: file_name or user_input")
            
        # Load preview data if not present (tools created with @tool are run via .invoke)
        if state.get("preview_data") is None:
            state["preview_data"] = load_preview_data.invoke({"file_name": state["file_name"]})
            
        # Generate query using LLM
        messages = [
            SystemMessage(content=QUERY_GENERATION_PROMPT),
            HumanMessage(content=json.dumps({
                "preview_data": state["preview_data"],
                "user_query": state["user_input"][0] if isinstance(state["user_input"], tuple) else state["user_input"]
            }, indent=2))
        ]
        
        response = model_with_tools.invoke(messages)
        
        # Check for an empty response before reading its content
        if not response or not response.content:
            raise ValueError("Empty response from LLM")
        
        print(f"===================== Generate Query Node - Raw LLM response\n{response.content}\n=====================")
            
        try:
            query_data = json.loads(response.content)
            state["query"] = query_data.get("query", "").strip()
            state["tool"] = query_data.get("tool")  # Store selected tool
            state["llm_prompt"] = query_data.get("llm_prompt", 
                state["user_input"][0] if isinstance(state["user_input"], tuple) else state["user_input"]
            ).strip()
            
            # Validate required fields
            if not state["query"] or not state["tool"]:
                raise ValueError("Missing required fields in LLM response")
                
        except json.JSONDecodeError:
            raise ValueError("Invalid JSON response from LLM")
            
        return state
        
    except Exception as e:
        logger.error(f"Error in generate_query_node: {str(e)}")
        state["error"] = str(e)
        state["query"] = ""
        state["tool"] = None
        state["llm_prompt"] = state["user_input"][0] if isinstance(state["user_input"], tuple) else state["user_input"]
        return state

def validate_and_execute_query_node(state: AgentState) -> AgentState:
    """Validates and executes query with error handling."""
    try:
        # Check for errors from previous step
        if state.get("error"):
            logger.error(f"Previous error detected: {state['error']}")
            return state
            
        # Ensure required fields exist
        if not state.get("query") or not state.get("tool"):
            state["error"] = "Missing query or tool selection"
            return state
            
        # Ensure preview data exists (tools created with @tool are run via .invoke)
        if state.get("preview_data") is None:
            state["preview_data"] = load_preview_data.invoke({"file_name": state["file_name"]})
            
        # Prepare validation message
        messages = [
            SystemMessage(content=VALIDATE_QUERY_PROMPT),
            HumanMessage(content=json.dumps({
                "preview_data": state["preview_data"],
                "llm_query": state["query"],
                "user_query": state["llm_prompt"],
                "execution_result": state.get("query_result")
            }))
        ]
        print(f"===================== Validate and Execute Query Node - Messages\n{messages}\n=====================")
        
        # Get and parse LLM response
        response = model_with_tools.invoke(messages)
        print(f"===================== Validate and Execute Query Node - Raw LLM response\n{response.content}\n=====================")
        
        try:
            response_dict = json.loads(response.content)
            
            # Update state with validated query and tool
            state["query"] = response_dict.get("query", "").strip()
            state["tool"] = response_dict.get("tool")
            
            if not state["query"] or not state["tool"]:
                raise ValueError("Invalid validation response")
                
        except (json.JSONDecodeError, ValueError) as e:
            state["error"] = f"Validation error: {str(e)}"
            return state
            
        # Execute the query with the selected tool. Both tools declare a single
        # parameter named `input`, so the payload is nested under that key.
        input_data = {
            "input": {
                "file_name": state["file_name"],
                "query": state["query"]
            }
        }
        
        if state["tool"] == "simple_dataframe_query":
            result = simple_dataframe_query.invoke(input_data)
        elif state["tool"] == "complex_duckdb_query":
            result = complex_duckdb_query.invoke(input_data)
        else:
            state["error"] = f"Unknown tool: {state['tool']}"
            return state
            
        # Update state
        state["query_result"] = result
        state["iterations_count"] = state.get("iterations_count", 0) + 1
        
        return state
        
    except Exception as e:
        logger.error(f"Error in validate_and_execute_query_node: {str(e)}")
        state["error"] = str(e)
        return state


def next_step(state: AgentState) -> str:
    """Determines next step in workflow with improved error handling."""
    iterations_count = state.get("iterations_count", 0)
    query_result = state.get("query_result", {})
    
    # Check for errors in query_result
    has_error = (
        isinstance(query_result, dict) and 
        "error" in query_result and 
        query_result["error"] is not None
    )
    
    if has_error and iterations_count <= 3:
        logger.info("Retrying validation (iteration %d)", iterations_count)
        return "validate_and_execute_query_node"
    
    logger.info("Workflow complete")
    return END

# Function to create and initialize the graph
def create_agent_graph(memory: Optional[MemorySaver] = None) -> StateGraph:
    builder = StateGraph(AgentState)
    
    # Add nodes
    builder.add_node("generate_query_node", generate_query_node)
    builder.add_node("validate_and_execute_query_node", validate_and_execute_query_node)
    
    # Add edges
    builder.add_edge("generate_query_node", "validate_and_execute_query_node")
    builder.add_conditional_edges(
        "validate_and_execute_query_node",
        next_step,
        {END: END, "validate_and_execute_query_node": "validate_and_execute_query_node"}
    )
    
    # Set the entry point; the conditional edges above already route to END
    builder.set_entry_point("generate_query_node")
    
    return builder.compile(checkpointer=memory)

# Helper function to pretty print events
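def pretty_print_event(event: dict):
    """Pretty-print a single event (dictionary)."""
    # default=str keeps values json cannot encode (e.g. NumPy integers, timestamps) from raising
    print(json.dumps(event, indent=2, default=str))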

def run_agent_query(user_input: str, file_name: str):
    """Run a query through the agent."""
    memory = MemorySaver()
    graph = create_agent_graph(memory)
    
    thread = {"configurable": {"thread_id": "2"}}
    initial_state = {
        "user_input": user_input,
        "file_name": file_name,
        "iterations_count": 0,
        "query": None,
        "preview_data": None,
        "query_result": None,
        "llm_prompt": None,
        "tool": None,
        "error": None
    }
    
    events = []
    for event in graph.stream(initial_state, thread):
        events.append(event)
        pretty_print_event(event)
        
    # Return final state
    return events[-1] if events else None
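
The retry behaviour encoded in `next_step` can be checked without LangGraph or an LLM. The sketch below replays the same routing rule in plain Python. `route`, `run_with_retries`, and the two fake executors are hypothetical stand-ins for the graph and for the validate-and-execute node.

```python
from typing import Callable, Dict

MAX_RETRIES = 3  # mirrors the `iterations_count <= 3` check in next_step


def route(state: Dict) -> str:
    """Return 'retry' while the last result holds an error and the cap is not exceeded."""
    result = state.get("query_result")
    has_error = isinstance(result, dict) and result.get("error") is not None
    if has_error and state.get("iterations_count", 0) <= MAX_RETRIES:
        return "retry"
    return "end"


def run_with_retries(execute: Callable[[int], Dict]) -> Dict:
    """Call `execute` repeatedly, following the same loop as the graph's conditional edge."""
    state = {"iterations_count": 0, "query_result": None}
    while True:
        state["query_result"] = execute(state["iterations_count"])
        state["iterations_count"] += 1
        if route(state) == "end":
            return state


def flaky(attempt: int) -> Dict:
    """Hypothetical executor that fails twice before succeeding."""
    if attempt < 2:
        return {"error": "simulated failure"}
    return {"result": 42}


def always_fails(attempt: int) -> Dict:
    """Hypothetical executor that never succeeds."""
    return {"error": "simulated failure"}


recovered = run_with_retries(flaky)
exhausted = run_with_retries(always_fails)
print(recovered["iterations_count"], exhausted["iterations_count"])
```

Because the counter is incremented before the check, a query that keeps failing is executed four times in total before the workflow gives up.
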

In [None]:
from IPython.display import Image
graph = create_agent_graph()
Image(graph.get_graph().draw_png())

In [None]:
# Initialize the chat model and bind the two query tools
from langchain_openai import ChatOpenAI
model = ChatOpenAI(model="gpt-4o", temperature=0)
model_with_tools = model.bind_tools(
    [simple_dataframe_query, complex_duckdb_query],
    parallel_tool_calls=False
)

# Run query (no trailing comma after the string, which would turn it into a tuple)
user_query = "Which ticker symbols have an average performance above 50%, given that the performance of a ticker symbol is spread across months 1 through 13? Ignore empty months in the calculation!"
file_name = "ipo_data.xlsx"
run_agent_query(user_query, file_name)
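
`run_agent_query` returns the last streamed event rather than a bare result. The helper below is a rough sketch for pulling the outcome out of such an event, assuming each event is a dictionary keyed by node name whose value carries the state fields (`error`, `query_result`). The name `extract_query_result` and the sample event are hypothetical.

```python
from typing import Any, Dict, Optional


def extract_query_result(event: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return {'ok': bool, 'payload': ...} from a {node_name: state} event."""
    if not event:
        return {"ok": False, "payload": "no events were produced"}
    for node_state in event.values():
        if not isinstance(node_state, dict):
            continue
        # A node-level error takes precedence over any query result.
        if node_state.get("error"):
            return {"ok": False, "payload": node_state["error"]}
        result = node_state.get("query_result") or {}
        if "error" in result:
            return {"ok": False, "payload": result["error"]}
        if "result" in result:
            return {"ok": True, "payload": result["result"]}
    return {"ok": False, "payload": "no query result found"}


# Hypothetical final event, shaped like the dictionaries printed by pretty_print_event.
sample_event = {
    "validate_and_execute_query_node": {
        "error": None,
        "query_result": {"result": {"type": "scalar", "value": 3}},
    }
}
summary = extract_query_result(sample_event)
print(summary)
```
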