In [1]:
import os

# Later cells use paths relative to the project root (e.g. "configs/database.yml"),
# so step up one level when the kernel was started inside the notebooks folder.
# Compare the last path component rather than a raw string suffix, so that a
# directory such as "old-notebooks" is not mistaken for "notebooks".
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")
print(os.getcwd())

/Users/cmcoutosilva/Projects/github/nl2sql-agent


In [None]:
from datetime import datetime
from functools import partial
from typing import Annotated, Any, Literal

import pandas as pd
from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import BaseMessage, add_messages
from langgraph.types import Command, interrupt
from loguru import logger
from pydantic import BaseModel

from nl2sql.agents.utils import (
    execute_sql_query,
    format_answer,
    format_query_results_for_llm,
    get_chat_history,
    validate_sql_syntax,
)
from nl2sql.config import UNSAFE_SQL_KEYWORDS, load_chat_prompt_template
from nl2sql.database.postgresql import PostgreSQLConnector
from nl2sql.knowledge_base.data_dictionary import DataDictionary
from nl2sql.knowledge_base.sql_examples import SQLExample
from nl2sql.knowledge_base.vector_store import VectorStore

In [3]:
# ===============================
# Vector Store
# ===============================

# The config path is relative, so this cell expects the working directory that
# the first cell sets up (it steps out of a "notebooks" folder when needed).
db_connector = PostgreSQLConnector(config_path="configs/database.yml")
# The vector store receives the connector created above; `sql_generator` later
# queries it through `vector_store.vectorstore.similarity_search`.
vector_store = VectorStore(db_connector)

In [None]:
# ===============================
# Session-based Memory
# ===============================

# Checkpointer handed to `workflow.compile(...)` further down
memory = InMemorySaver()

# Thread ID for this session, derived from the current timestamp
session_id = f"test_{datetime.now().isoformat()}"
thread_config = {"configurable": {"thread_id": session_id}}

# ===============================
# State
# ===============================


class State(BaseModel):
    """Shared graph state for the NL2SQL agent.

    Node functions return partial dicts keyed by these field names. Every
    field except ``messages`` defaults to ``None`` until a node sets it.
    """

    # Conversation messages; annotated with the `add_messages` reducer.
    messages: Annotated[list[BaseMessage], add_messages]
    # Content of the latest message, as recorded by `intent_classifier`.
    user_query: str | None = None
    # Classification written by `intent_classifier`; read by `route_intent`.
    user_intent: Literal["sql", "chat"] | None = None
    # Query text checked by the validators and run by `sql_executor`;
    # `sql_syntax_validator` may replace it with an LLM-corrected version.
    # NOTE(review): presumably first filled from the `sql_generator` LLM response — confirm.
    sql_query: str | None = None
    # NOTE(review): presumably also filled from the `sql_generator` LLM response — confirm.
    sql_explanation: str | None = None
    # Result dict returned by `execute_sql_query`, stored by `sql_executor`.
    sql_execution_result: dict[str, Any] | None = None
    # LLM-written analysis of the results, stored by `sql_result_analyzer`.
    sql_execution_analysis: str | None = None
    # Written by `sql_safety_validator`; read by `check_sql_safety`.
    sql_safety_status: Literal["safe", "unsafe"] | None = None
    # Written by `sql_syntax_validator`; read by `check_sql_syntax`.
    sql_syntax_status: Literal["valid", "invalid"] | None = None
    # Written by `human_feedback`; read by `check_human_feedback`.
    user_feedback_status: Literal["approved", "rejected"] | None = None
    # Written by `sql_executor`; read by `check_sql_execution`.
    sql_execution_status: Literal["success", "failure"] | None = None


# ===============================
# Prompts
# ===============================

# Prompt templates used by the `intent_classifier` and `sql_generator` nodes
intent_classifier_prompt = load_chat_prompt_template(target_prompt="intent_classifier")
sql_generator_prompt = load_chat_prompt_template(target_prompt="sql_generator")

# ===============================
# Knowledge Base
# ===============================

# Load the data dictionary and the SQL examples exactly once; repeating the
# loads only re-runs the same I/O and rebinds the same names.
data_dictionary = DataDictionary.load()
sql_examples = SQLExample.from_yaml("knowledge/sql_examples.yml")

# ===============================
# Agent Nodes
# ===============================


def intent_classifier(state: State) -> dict:
    """Classify the latest user message as a chat or SQL request.

    Args:
        state: Current graph state. The last entry of ``state.messages`` is
            the message to classify; the preceding entries are passed to
            ``get_chat_history`` as context.

    Returns:
        A state update with ``user_intent`` ("sql" or "chat") and
        ``user_query`` (content of the latest message). Any model output other
        than "chat" or "sql" is logged and mapped to "chat".
    """
    logger.info("🔄 [Node] Intent Classifier")

    # Retrieve chat history
    chat_history = get_chat_history(state.messages[:-1])
    user_query = state.messages[-1].content
    logger.debug(f"Chat history:\n{chat_history}")
    logger.debug(f"Last message: {user_query}")

    # Classify user intent using LLM
    llm = init_chat_model(model="gpt-4.1-mini", model_provider="openai", temperature=0)
    router_llm_chain = intent_classifier_prompt | llm

    response = router_llm_chain.invoke(
        {"user_message": user_query, "chat_history": chat_history}
    )

    detected_user_intent = response.content.strip().lower()
    logger.debug(f"Detected user intent: {detected_user_intent}")

    # Default to chat if the intent is not chat or sql
    if detected_user_intent not in ["chat", "sql"]:
        logger.warning(
            "⚠️ LLM Router returned unexpected classification: "
            f"'{detected_user_intent}'. Defaulting to chat."
        )
        detected_user_intent = "chat"

    # Record the query on every path, including the fallback, so a value left
    # over from an earlier turn of the same thread does not stay in state.
    return {"user_intent": detected_user_intent, "user_query": user_query}


def chat_agent(state: State, vector_store: VectorStore) -> dict:
    """Answer a chat-intent message with a fixed greeting.

    Neither ``state`` nor ``vector_store`` is read by the current
    implementation; the reply is a constant string.

    Returns:
        A state update whose ``messages`` entry is the greeting as an
        ``AIMessage``.
    """
    logger.info("🔄 [Node] Chat Agent")
    greeting = "Hello, how can I help you today?"
    # Only the first 50 characters of the reply are logged
    logger.debug(f"✅ Chat Agent response: {greeting[:50]}...")
    reply_message = AIMessage(content=greeting)
    return {"messages": reply_message}


def sql_generator(state: State, vector_store: VectorStore) -> dict:
    """Generate a SQL query from the user's question using an LLM.

    The prompt is filled with the user query, the chat history, the schema
    context from the data dictionary and example documents retrieved from the
    vector store.

    Args:
        state: Current graph state; ``user_query`` and ``messages`` are read.
        vector_store: Store whose ``vectorstore.similarity_search`` returns
            documents filtered by ``{"type": "example"}``.

    Returns:
        A state update built from the LLM's JSON response. Only keys that are
        ``State`` fields are kept. ``sql_query`` and ``sql_explanation`` are
        always present and are ``None`` when the response did not provide them
        as strings.
    """
    logger.info("🔄 [Node] SQL Generator")
    logger.debug(f"User question: {state.user_query}")

    # Get chat history for context
    chat_history = get_chat_history(state.messages[:-1])
    logger.debug(f"Chat history length: {len(chat_history)}")

    # Format schema context from data dictionary
    schema_context = data_dictionary.format_context()

    # Retrieve similar SQL examples for few-shot prompting
    retrieved_docs = vector_store.vectorstore.similarity_search(
        state.user_query, k=4, filter={"type": "example"}
    )
    logger.debug(f"Retrieved {len(retrieved_docs)} docs")
    sql_examples_context = "\n\n".join([doc.page_content for doc in retrieved_docs])

    # Initialize LLM
    llm = init_chat_model(
        model="gpt-4.1-mini", model_provider="openai", temperature=0
    ).with_structured_output(method="json_mode")  # returns dict directly

    # Create SQL generation chain
    llm_chain = sql_generator_prompt | llm

    # Generate SQL query
    response = llm_chain.invoke(
        {
            "user_query": state.user_query,
            "chat_history": chat_history,
            "schema_context": schema_context,
            "sql_examples": sql_examples_context,
        }
    )

    # Validate the response before it is merged into the graph state
    if not isinstance(response, dict):
        logger.warning(
            f"⚠️ SQL Generator returned a non-dict response: {type(response).__name__}"
        )
        response = {}

    # Keep only keys the state schema knows about
    state_update = {
        key: value for key, value in response.items() if key in State.model_fields
    }

    # Always write both generator fields. A missing or non-string value becomes
    # None so that a query left in state by an earlier turn of the same thread
    # cannot be routed as if it had just been generated.
    for key in ("sql_query", "sql_explanation"):
        if not isinstance(state_update.get(key), str):
            state_update[key] = None

    if state_update["sql_query"] is None:
        logger.warning("⚠️ SQL Generator response did not contain a usable 'sql_query'.")

    return state_update


def sql_safety_validator(state: State) -> dict:
    """Flag the generated SQL as unsafe if it contains a blocked keyword.

    Each entry of ``UNSAFE_SQL_KEYWORDS`` is searched for as a whole word.
    Matching ignores case, because SQL keywords are case-insensitive, and the
    keyword is regex-escaped so it is always matched literally.

    Args:
        state: Current graph state; ``sql_query`` is inspected.

    Returns:
        ``{"sql_safety_status": "safe"}`` when nothing matches. Otherwise a
        state update with ``sql_safety_status`` set to "unsafe" and an
        ``AIMessage`` listing the matched keywords in upper case.
    """
    # Local import: `re` is not part of the notebook's import cell
    import re

    logger.info("🔄 [Node] SQL Safety Validator")

    found_unsafe_keywords = [
        keyword.upper()
        for keyword in UNSAFE_SQL_KEYWORDS
        if re.search(
            rf"\b{re.escape(keyword)}\b", state.sql_query, flags=re.IGNORECASE
        )
    ]

    if found_unsafe_keywords:
        logger.error(f"❌ SQL contains unsafe keywords: {found_unsafe_keywords}")
        ai_message = AIMessage(
            content=f"❌ SQL contains unsafe keywords: {found_unsafe_keywords}"
        )
        return {"messages": [ai_message], "sql_safety_status": "unsafe"}

    logger.debug("✅ SQL is safe")
    return {"sql_safety_status": "safe"}


def sql_syntax_validator(state: State, db_connector: PostgreSQLConnector) -> dict:
    """Validate SQL syntax and try to repair invalid queries with an LLM.

    The query is validated with ``validate_sql_syntax``. On failure the query
    and the reported error are sent to the "sql_syntax_fixer" prompt and the
    model's answer becomes the next candidate. Up to ``max_retries`` repairs
    are attempted, and the result of the last repair is validated as well.

    Args:
        state: Current graph state; ``sql_query`` is the starting candidate.
        db_connector: Connector passed through to ``validate_sql_syntax``.

    Returns:
        On success, ``sql_syntax_status`` "valid" plus the (possibly repaired)
        ``sql_query``. Otherwise ``sql_syntax_status`` "invalid" and an
        ``AIMessage`` carrying the last validation error.
    """
    logger.info("🔄 [Node] SQL Syntax Validator")

    max_retries = 3
    current_query = state.sql_query
    llm_chain = None  # built lazily: only needed once a query fails validation

    # NOTE(review): a repaired query is returned without passing through
    # `sql_safety_validator` again, because that node runs before this one in
    # the graph. Human approval is the only remaining gate.

    # `max_retries` repairs need `max_retries + 1` validations, otherwise the
    # query produced by the final repair would be discarded unchecked.
    for attempt in range(max_retries + 1):
        # Validate current query
        validation_result = validate_sql_syntax(current_query, db_connector)

        # If the query is valid, return success
        if validation_result["valid_syntax"]:
            logger.debug("✅ Syntax Validator: SQL syntax is valid.")
            return {
                "sql_syntax_status": "valid",
                "sql_query": current_query,
            }

        # All repair attempts are used up
        if attempt == max_retries:
            break

        # Log the error and the current query for debugging
        logger.debug(f"Invalid syntax → Entering Fix Attempt #{attempt + 1}...")
        logger.debug(f"Current query: {current_query}")
        logger.debug(f"Error: {validation_result['error']}")

        # Use LLM to fix the error (prompt and model are created once)
        if llm_chain is None:
            sql_syntax_fixer_prompt = load_chat_prompt_template(
                target_prompt="sql_syntax_fixer"
            )
            llm = init_chat_model(
                model="gpt-4.1", model_provider="openai", temperature=0
            )
            llm_chain = sql_syntax_fixer_prompt | llm

        current_query = llm_chain.invoke(
            {
                "query": current_query,
                "error": validation_result["error"],
            }
        ).content

    logger.warning(
        f"⚠️ Syntax Validator: Failed to fix SQL syntax after {max_retries} attempts."
    )
    return {
        # Must be one of the literals declared on State, not a boolean
        "sql_syntax_status": "invalid",
        "messages": [
            AIMessage(
                content=(
                    f"❌ I couldn't generate valid SQL syntax. Error: "
                    f"{validation_result['error']}"
                )
            )
        ],
    }


def human_feedback(state: State) -> dict:
    """Ask the user to confirm the query and pause the graph for the reply.

    The formatted answer plus a yes/no question is sent through
    ``interrupt``. The value supplied on resume is read as the user's reply.

    Args:
        state: Current graph state, passed to ``format_answer``.

    Returns:
        A state update with the question and the normalised reply appended to
        ``messages`` and ``user_feedback_status`` set to "approved" or
        "rejected".
    """
    # Local import: `re` is not part of the notebook's import cell
    import re

    logger.info("🔄 [Node] Human Feedback")

    # Format the answer for the user
    formatted_answer = format_answer(state)
    formatted_answer += "Should I execute this query? Answer with 'yes' or 'no'."

    # Create the AI message
    ai_message = AIMessage(content=formatted_answer)

    # Mark that we're waiting for confirmation (Slack integration will detect this)
    resume_value = interrupt(
        {
            "messages": ai_message,
            "waiting_for_confirmation": True,
        }
    )

    # Accept either a message object (uses `.content`) or a plain string
    reply_text = getattr(resume_value, "content", resume_value)
    human_reply = str(reply_text).strip().lower()

    # Approve only when the reply starts with an affirmative word. A plain
    # substring test for "y" would also approve "no way", "maybe" or "not yet",
    # and approval is what allows the query to run.
    is_approved = re.match(r"(?:y|yes|yeah|yep)\b", human_reply) is not None
    user_feedback_status = "approved" if is_approved else "rejected"

    if is_approved:
        logger.debug(f"✅ Human feedback: {user_feedback_status}")
    else:
        logger.debug(f"❌ Human feedback: {user_feedback_status}")

    return {
        "messages": [ai_message, human_reply],
        "user_feedback_status": user_feedback_status,
    }


def sql_executor(state: State, db_connector: PostgreSQLConnector) -> dict:
    """Run the SQL query and record the outcome in the graph state.

    Args:
        state: Current graph state; ``sql_query`` is executed.
        db_connector: Connector passed through to ``execute_sql_query``.

    Returns:
        A state update with ``sql_execution_status`` ("success" or "failure",
        derived from the result's ``success`` entry) and the raw
        ``sql_execution_result`` dict.
    """
    logger.info("🔄 [Node] SQL Executor")

    # Execute the SQL query
    sql_execution_result = execute_sql_query(state.sql_query, db_connector)
    succeeded = bool(sql_execution_result["success"])
    sql_execution_status = "success" if succeeded else "failure"

    # Branch on the boolean outcome: the status string itself is always truthy,
    # so testing it would send failures down the success branch.
    if succeeded:
        data = sql_execution_result["data"]
        # `data` may be None (sql_result_analyzer guards for that too)
        preview = data[:50] if data is not None else None
        logger.debug(f"✅ SQL execution status: {sql_execution_status}")
        logger.debug(f"✅ SQL result: {preview}...")
    else:
        logger.debug(f"❌ SQL execution status: {sql_execution_status}")
        logger.debug(f"❌ SQL result: {sql_execution_result['error']}")

    return {
        "sql_execution_status": sql_execution_status,
        "sql_execution_result": sql_execution_result,
    }


def sql_result_analyzer(state: State) -> dict:
    """Have an LLM interpret the executed query's results.

    Args:
        state: Current graph state; ``sql_execution_result``, ``user_query``
            and ``sql_query`` are read.

    Returns:
        A state update with the interpretation stored in
        ``sql_execution_analysis`` and appended to ``messages`` as an
        ``AIMessage``.
    """
    logger.info("🔄 [Node] SQL Result Analyser")

    # Work on a shallow copy so the stored result keeps its original "data"
    result_view = state.sql_execution_result.copy()
    rows = result_view["data"]
    if rows is not None:
        # The formatting helper is given a DataFrame rather than the stored value
        result_view["data"] = pd.DataFrame(rows)

    # Render the results as text for the prompt
    query_results_text = format_query_results_for_llm(result_view)

    # Prompt and model for the interpretation step
    result_analyzer_prompt = load_chat_prompt_template(target_prompt="result_analyzer")
    llm = init_chat_model(model="gpt-4.1", model_provider="openai", temperature=0.1)

    prompt_inputs = {
        "user_query": state.user_query,
        "sql_query": state.sql_query,
        "query_results": query_results_text,
    }
    analyzed_result = (result_analyzer_prompt | llm).invoke(prompt_inputs).content
    logger.info("✅ Results analyzed successfully")

    return {
        "messages": [AIMessage(content=analyzed_result)],
        "sql_execution_analysis": analyzed_result,
    }


# ===============================
# Node Routers
# ===============================


def route_intent(state: State) -> Literal["sql", "chat"]:
    """Return the stored user intent as the name of the branch to follow."""
    intent = state.user_intent
    logger.debug(f"→ Routing to {intent}")
    return intent


def check_sql_generation(state: State) -> Literal["success", "failure"]:
    """Report "success" when the state holds a non-blank SQL query, else "failure"."""
    has_query = bool(state.sql_query and state.sql_query.strip())
    outcome = "success" if has_query else "failure"
    logger.debug(f"→ Routing to {outcome}")
    return outcome


def check_sql_safety(state: State) -> Literal["safe", "unsafe"]:
    """Route on the safety status; anything other than "safe" counts as "unsafe"."""
    status = state.sql_safety_status
    logger.debug(f"→ Routing to {status}")
    if status == "safe":
        return "safe"
    return "unsafe"


def check_sql_syntax(state: State) -> Literal["valid", "invalid"]:
    """Route on the syntax status; anything other than "valid" counts as "invalid"."""
    status = state.sql_syntax_status
    logger.debug(f"→ Routing to {status}")
    if status == "valid":
        return "valid"
    return "invalid"


def check_human_feedback(state: State) -> Literal["approved", "rejected"]:
    """Route on the user's decision; anything other than "approved" counts as "rejected"."""
    decision = state.user_feedback_status
    logger.debug(f"→ Routing to {decision}")
    if decision == "approved":
        return "approved"
    return "rejected"


def check_sql_execution(state: State) -> Literal["success", "failure"]:
    """Route on the execution status; anything other than "success" counts as "failure"."""
    status = state.sql_execution_status
    logger.debug(f"→ Routing to {status}")
    if status == "success":
        return "success"
    return "failure"


# ===============================
# Graph
# ===============================

# State Workflow
workflow = StateGraph(State)

# Add nodes.
# Several node functions take a required second argument (`vector_store` or
# `db_connector`). Bind those here with `partial`, so each registered node is
# callable with the state alone.
workflow.add_node("intent_classifier", intent_classifier)
workflow.add_node("chat_agent", partial(chat_agent, vector_store=vector_store))
workflow.add_node("sql_generator", partial(sql_generator, vector_store=vector_store))
workflow.add_node("sql_safety_validator", sql_safety_validator)
workflow.add_node(
    "sql_syntax_validator", partial(sql_syntax_validator, db_connector=db_connector)
)
workflow.add_node("human_feedback", human_feedback)
workflow.add_node("sql_executor", partial(sql_executor, db_connector=db_connector))
workflow.add_node("sql_result_analyzer", sql_result_analyzer)

# Add Edges
workflow.add_edge(START, "intent_classifier")

# Chat requests go to the chat agent, SQL requests enter the SQL pipeline
workflow.add_conditional_edges(
    "intent_classifier",
    route_intent,
    {
        "sql": "sql_generator",
        "chat": "chat_agent",
    },
)

# SQL pipeline: generate → safety check → syntax check → human approval →
# execute → analyse. Every stage ends the run on its failure branch.
workflow.add_conditional_edges(
    "sql_generator",
    check_sql_generation,
    {
        "success": "sql_safety_validator",
        "failure": END,
    },
)

workflow.add_conditional_edges(
    "sql_safety_validator",
    check_sql_safety,
    {
        "safe": "sql_syntax_validator",
        "unsafe": END,
    },
)

workflow.add_conditional_edges(
    "sql_syntax_validator",
    check_sql_syntax,
    {
        "valid": "human_feedback",
        "invalid": END,
    },
)

workflow.add_conditional_edges(
    "human_feedback",
    check_human_feedback,
    {
        "approved": "sql_executor",
        "rejected": END,
    },
)

workflow.add_conditional_edges(
    "sql_executor",
    check_sql_execution,
    {
        "success": "sql_result_analyzer",
        "failure": END,
    },
)

workflow.add_edge("sql_result_analyzer", END)
workflow.add_edge("chat_agent", END)

# Compile Graph (the checkpointer is what allows the run to pause at the
# `interrupt` in `human_feedback` and be resumed on the same thread)
graph = workflow.compile(checkpointer=memory)

# # Display Graph
# display(graph)

In [5]:
# Simulate SQL agent interaction

# Start a run with a single user question; the run pauses at the
# human-feedback interrupt (see the `__interrupt__` entry in the output).
initial_messages = [
    HumanMessage(content="What are the most popular product categories?"),
]
state = State(messages=initial_messages)

graph.invoke(state, config=thread_config)

[32m2025-07-08 05:10:45.376[0m | [1mINFO    [0m | [36m__main__[0m:[36mintent_classifier[0m:[36m59[0m - [1m🔄 [Node] Intent Classifier[0m
[32m2025-07-08 05:10:45.376[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mintent_classifier[0m:[36m64[0m - [34m[1mChat history:
[0m
[32m2025-07-08 05:10:45.377[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mintent_classifier[0m:[36m65[0m - [34m[1mLast message: What are the most popular product categories?[0m
[32m2025-07-08 05:10:45.733[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mintent_classifier[0m:[36m76[0m - [34m[1mDetected user intent: sql[0m
[32m2025-07-08 05:10:45.734[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mroute_intent[0m:[36m325[0m - [34m[1m→ Routing to sql[0m
[32m2025-07-08 05:10:45.735[0m | [1mINFO    [0m | [36m__main__[0m:[36msql_generator[0m:[36m99[0m - [1m🔄 [Node] SQL Generator[0m
[32m2025-07-08 05:10:45.736[0m | [34m[1mDEBUG   [0m | [36m__main__[

{'messages': [HumanMessage(content='What are the most popular product categories?', additional_kwargs={}, response_metadata={}, id='03413633-c13c-4418-9a01-82fc7b5e370f')],
 'user_query': 'What are the most popular product categories?',
 'user_intent': 'sql',
 'sql_query': 'SELECT p."product_category_name", COUNT(*) AS order_count FROM "ecommerce"."order_items" oi JOIN "ecommerce"."products" p ON oi."product_id" = p."product_id" GROUP BY p."product_category_name" ORDER BY order_count DESC;',
 'sql_explanation': 'This query counts the number of order items for each product category to determine the most popular product categories, ordering the results by the count in descending order.',
 'sql_safety_status': 'safe',
 'sql_syntax_status': 'valid',
 '__interrupt__': [Interrupt(value={'messages': AIMessage(content='**SQL:**\n```sql\nSELECT p."product_category_name", COUNT(*) AS order_count FROM "ecommerce"."order_items" oi JOIN "ecommerce"."products" p ON oi."product_id" = p."product_id" G

In [6]:
# Resume the paused thread, supplying the user's reply as the interrupt value
resume_command = Command(resume=HumanMessage(content="yes"))
graph.invoke(resume_command, config=thread_config)

[32m2025-07-08 05:10:47.943[0m | [1mINFO    [0m | [36m__main__[0m:[36mhuman_feedback[0m:[36m217[0m - [1m🔄 [Node] Human Feedback[0m
[32m2025-07-08 05:10:47.944[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mhuman_feedback[0m:[36m241[0m - [34m[1m✅ Human feedback: approved[0m
[32m2025-07-08 05:10:47.944[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcheck_human_feedback[0m:[36m353[0m - [34m[1m→ Routing to approved[0m
[32m2025-07-08 05:10:47.945[0m | [1mINFO    [0m | [36m__main__[0m:[36msql_executor[0m:[36m253[0m - [1m🔄 [Node] SQL Executor[0m
[32m2025-07-08 05:10:47.945[0m | [34m[1mDEBUG   [0m | [36mnl2sql.agents.utils[0m:[36mexecute_sql_query[0m:[36m81[0m - [34m[1m🔄 Executing SQL query: SELECT p."product_category_name", COUNT(*) AS order_count FROM "ecommerce"."order_items" oi JOIN "ec...[0m
[32m2025-07-08 05:10:47.990[0m | [34m[1mDEBUG   [0m | [36mnl2sql.agents.utils[0m:[36mexecute_sql_query[0m:[36m88[0m - [34m[

{'messages': [HumanMessage(content='What are the most popular product categories?', additional_kwargs={}, response_metadata={}, id='03413633-c13c-4418-9a01-82fc7b5e370f'),
  AIMessage(content='**SQL:**\n```sql\nSELECT p."product_category_name", COUNT(*) AS order_count FROM "ecommerce"."order_items" oi JOIN "ecommerce"."products" p ON oi."product_id" = p."product_id" GROUP BY p."product_category_name" ORDER BY order_count DESC;\n```\n\n**Explanation:**\nThis query counts the number of order items for each product category to determine the most popular product categories, ordering the results by the count in descending order.\nShould I execute this query? Answer with \'yes\' or \'no\'.', additional_kwargs={}, response_metadata={}, id='7b5b843d-42de-401d-ab41-9fd1f2092d92'),
  HumanMessage(content='yes', additional_kwargs={}, response_metadata={}, id='cee58282-c931-45b4-a547-9115d39eb005'),
  AIMessage(content='### Interpretation of Results: Most Popular Product Categories\n\n**Top 5 Prod