In [None]:
from typing_extensions import TypedDict, Annotated
from typing import List
import os
import re
import operator

from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langchain_ollama import ChatOllama
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

from dotenv import load_dotenv
load_dotenv()

In [None]:
# =============================================================================
# Database Setup
# =============================================================================

print("Setting up employees database...")

DB_PATH = "db/employees_db-full-1.0.6.db"
# Open the SQLite file read-only (SQLite URI with mode=ro). The agent executes
# model-generated SQL, so the connection itself refuses writes even if the
# validation tool is bypassed. A missing file becomes an immediate error
# instead of a silently created empty database.
DB_URI = f"sqlite:///file:{DB_PATH}?mode=ro&uri=true"

try:
    db = SQLDatabase.from_uri(DB_URI)
    tables = db.get_usable_table_names()
except Exception as e:
    print(f"✗ Database connection failed: {e}")
    # Re-raise: every later cell needs `db`, so continuing would only
    # produce confusing follow-on errors and a false "connected" message.
    raise

print("✓ Database connected successfully")
print(f"✓ Found {len(tables)} tables: {', '.join(tables)}")

# Schema description text, reused by the tools and the agent's system prompt.
SCHEMA = db.get_table_info()
print("✓ Connected to employees database")

In [None]:
# =============================================================================
# Configuration
# =============================================================================

# Overridable through the environment / .env file (load_dotenv() runs in the
# first cell); the defaults are the previously hard-coded values.
LLM_MODEL = os.getenv("OLLAMA_MODEL", "qwen3")
BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

# NOTE(review): reasoning=True presumably makes ChatOllama return the model's
# thinking separately from `.content`, which the SQL tools rely on; confirm
# against the langchain_ollama docs.
llm = ChatOllama(model=LLM_MODEL, base_url=BASE_URL, reasoning=True)

# Smoke test: fail in this cell if the Ollama server or model is unavailable.
response = llm.invoke("Hello, how are you?")
response.pretty_print()
print("✓ Initialized Ollama chat model")

In [None]:
# =============================================================================
# SQL Tools
# =============================================================================

@tool
def get_database_schema(table_name: str = None) -> str:
    """Get database schema information for SQL query generation.
    Use this first to understand table structure before creating queries.

    Args:
        table_name: Optional table name, matched case-insensitively. Several
            tables may be given as a comma-separated list. Omit it (or pass an
            empty string) to get the schema of all tables.
    """
    print(f"[TOOL] Getting schema for: {table_name if table_name else 'all tables'}")

    requested = [name.strip() for name in (table_name or "").split(",") if name.strip()]
    if not requested:
        print("[TOOL] Retrieved full database schema")
        return SCHEMA

    try:
        tables = db.get_usable_table_names()
        # db.get_table_info appears to match names exactly (a wrong-case name
        # raises), so map every case-insensitive hit back to its real spelling.
        real_name = {t.lower(): t for t in tables}
        missing = [name for name in requested if name.lower() not in real_name]
        if missing:
            return f"Error: Table '{', '.join(missing)}' not found. Available tables: {', '.join(tables)}"
        # dict.fromkeys de-duplicates while keeping the caller's order.
        resolved = list(dict.fromkeys(real_name[name.lower()] for name in requested))
        result = db.get_table_info(resolved)
        print(f"[TOOL] Retrieved schema for table: {', '.join(resolved)}")
        return result
    except Exception as e:
        return f"Error getting table info: {e}"

In [None]:
@tool
def generate_sql_query(question: str, schema_info: str = None) -> str:
    """Generate a SQL SELECT query from a natural language question using database schema.
    Always use this after getting schema information.

    Args:
        question: The natural-language question to answer.
        schema_info: Optional schema text to use; defaults to the full database schema.
    """
    print(f"[TOOL] Generating SQL for: {question[:100]}...")

    schema_to_use = schema_info if schema_info else SCHEMA

    # Flush-left (like the fix_sql_error prompt) so the model is not sent
    # stray indentation on every line.
    prompt = f"""Based on this database schema:
{schema_to_use}

Generate a SQL query to answer this question: {question}

Rules:
- Use only SELECT statements
- Include only existing columns and tables
- Add appropriate WHERE, GROUP BY, ORDER BY clauses as needed
- Limit results to 10 rows unless specified otherwise
- Use proper SQL syntax for SQLite

Return only the SQL query, nothing else."""

    try:
        response = llm.invoke(prompt)
        # Models often ignore "return only the SQL query": drop any <think>
        # block and, when the SQL sits in a ``` fence (optionally with a
        # language tag such as sql/sqlite), keep only the fenced body.
        text = re.sub(r"<think>.*?</think>", "", response.content, flags=re.DOTALL)
        fenced = re.search(r"```[^\n`]*\n(.*?)```", text, flags=re.DOTALL)
        query = (fenced.group(1) if fenced else text).strip()
        print("[TOOL] Generated SQL query")
        return query
    except Exception as e:
        return f"Error generating query: {e}"

In [None]:
@tool
def validate_sql_query(query: str) -> str:
    """Validate SQL query for safety and syntax before execution.
    Returns 'Valid: <query>' if safe or 'Error: <message>' if unsafe."""
    print(f"[TOOL] Validating SQL: {query[:100]}...")

    # Remove markdown code fences the model may have wrapped around the SQL.
    clean_query = re.sub(r'```sql\s*', '', query.strip(), flags=re.IGNORECASE)
    clean_query = re.sub(r'```\s*', '', clean_query)
    # Drop trailing terminators/whitespace. Any ';' left after this separates
    # statements, e.g. "SELECT 1; PRAGMA ..." (previously accepted because it
    # contains only one ';').
    clean_query = re.sub(r'[\s;]+$', '', clean_query.strip())

    # Mask single-quoted string literals ('' is an escaped quote) so a ';'
    # inside data, e.g. WHERE note = 'a;b', is not mistaken for a separator.
    without_literals = re.sub(r"'(?:[^']|'')*'", "''", clean_query)
    if ";" in without_literals:
        return "Error: Multiple statements not allowed"

    # Must be SELECT only
    if not clean_query.lower().startswith("select"):
        return "Error: Only SELECT statements allowed"

    # Block dangerous operations. These run on the raw text on purpose
    # (conservative): a keyword inside a literal is also rejected. Note that
    # REPLACE also blocks SQLite's replace() string function.
    dangerous_patterns = [
        r'\b(INSERT|UPDATE|DELETE|ALTER|DROP|CREATE|REPLACE|TRUNCATE)\b',
        r'\b(EXEC|EXECUTE)\b',
        # SQLite-specific statements that open other files or change engine state.
        r'\b(ATTACH|DETACH|PRAGMA|VACUUM|REINDEX)\b',
        r'--',
        r'/\*',
    ]
    if any(re.search(pattern, clean_query, re.IGNORECASE) for pattern in dangerous_patterns):
        return "Error: Unsafe SQL pattern detected"

    # Nested SELECTs (subqueries, UNION) are allowed: the single-statement
    # rule above already guarantees only one statement runs.
    if clean_query.count('(') != clean_query.count(')'):
        return "Error: Unbalanced parentheses"

    print("[TOOL] Query validation passed")
    return f"Valid: {clean_query}"

In [None]:
@tool
def execute_sql_query(query: str) -> str:
    """Execute a validated SQL query and return results.
    Only use this after validating the query for safety."""
    print(f"[TOOL] Executing SQL: {query[:100]}...")

    clean_query = query.strip()
    # Accept the "Valid: <query>" string returned by validate_sql_query as-is.
    if clean_query.startswith("Valid: "):
        clean_query = clean_query[len("Valid: "):]

    try:
        # Defense in depth: the agent is only *instructed* to validate first,
        # so run the same checks here instead of trusting the model. The
        # validator also strips code fences and trailing ';'.
        verdict = validate_sql_query.invoke({"query": clean_query})
        if not verdict.startswith("Valid: "):
            error_msg = f"Execution Error: query failed validation ({verdict})"
            print(f"[TOOL] {error_msg}")
            return error_msg

        result = db.run(verdict[len("Valid: "):])
        print("[TOOL] Query executed successfully")

        if result:
            return f"Query Results:\n{result}"
        return "Query executed successfully but returned no results."

    except Exception as e:
        error_msg = f"Execution Error: {str(e)}"
        print(f"[TOOL] {error_msg}")
        return error_msg

In [None]:
@tool
def fix_sql_error(original_query: str, error_message: str, question: str) -> str:
    """Fix a failed SQL query by analyzing the error and generating a corrected version.
    Use this when validation or execution fails.

    Args:
        original_query: The SQL that failed validation or execution.
        error_message: The error text reported for that query.
        question: The original natural-language question the SQL must answer.
    """
    print(f"[TOOL] Fixing SQL error: {error_message[:100]}...")

    fix_prompt = f"""The following SQL query failed:
Query: {original_query}
Error: {error_message}
Original Question: {question}

Database Schema:
{SCHEMA}

Analyze the error and provide a corrected SQL query that:
1. Fixes the specific error mentioned
2. Still answers the original question
3. Uses only valid table and column names from the schema
4. Follows SQLite syntax rules

Return only the corrected SQL query, nothing else."""

    try:
        response = llm.invoke(fix_prompt)
        # Keep only the SQL: drop any <think> block and, if the model wrapped
        # the query in a ``` fence (with or without a language tag), keep the
        # fenced body rather than the surrounding explanation.
        text = re.sub(r"<think>.*?</think>", "", response.content, flags=re.DOTALL)
        fenced = re.search(r"```[^\n`]*\n(.*?)```", text, flags=re.DOTALL)
        fixed_query = (fenced.group(1) if fenced else text).strip()
        print("[TOOL] Generated fixed SQL query")
        return fixed_query
    except Exception as e:
        return f"Error generating fix: {e}"

In [None]:
# =============================================================================
# State
# =============================================================================

class AgentState(TypedDict):
    """Graph state: the running conversation.

    ``messages`` uses ``operator.add`` as its reducer, so the lists returned by
    nodes are concatenated onto the existing history rather than replacing it.
    """
    messages: Annotated[List, operator.add]


# Every tool the agent may call. The same list is bound to the model (so it
# sees the tool schemas) and handed to the ToolNode that executes the calls.
tools = [get_database_schema, generate_sql_query, validate_sql_query,
         execute_sql_query, fix_sql_error]

# Chat model able to emit structured tool calls for the tools above.
llm_with_tools = llm.bind_tools(tools)

In [None]:
# =============================================================================
# LangGraph Nodes
# =============================================================================

# Agent node - LLM with tools
def agent_node(state: AgentState):
    """Agent step: ask the tool-bound LLM what to do next.

    Prepends the system prompt (workflow rules plus the full SCHEMA) to the
    conversation in ``state['messages']`` and invokes ``llm_with_tools``. The
    reply (tool calls or a final answer) is returned as a one-element list
    under ``'messages'``, so the ``operator.add`` reducer on ``AgentState``
    appends it to the history.
    """
    print("[AGENT] Processing request")

    # Written flush-left (like the fix_sql_error prompt) so the model is not
    # sent ~20 spaces of stray indentation on every line.
    # NOTE: "up to 3 times" is only an instruction to the model; this node
    # keeps no retry counter, so the limit is not enforced here.
    system_prompt = f"""You are an expert SQL analyst working with an employees database.

Database Schema:
{SCHEMA}

Your workflow for answering questions:
1. Use `get_database_schema` first to understand available tables and columns (if needed)
2. Use `generate_sql_query` to create SQL based on the question
3. Use `validate_sql_query` to check the query for safety and syntax
4. Use `execute_sql_query` to run the validated query
5. If there's an error, use `fix_sql_error` to correct it and try again (up to 3 times)

Rules:
- Always follow the workflow step by step
- If a query fails, use the fix tool and try again
- Provide clear, informative answers
- Be precise with table and column names
- Handle errors gracefully and try to fix them
- If you fail after 3 attempts, explain what went wrong

Available tools:
- get_database_schema: Get table structure info
- generate_sql_query: Create SQL from question
- validate_sql_query: Check query safety/syntax
- execute_sql_query: Run the query
- fix_sql_error: Fix failed queries

Remember: Always validate queries before executing them for safety."""

    messages = [SystemMessage(system_prompt)] + state['messages']
    response = llm_with_tools.invoke(messages)

    return {'messages': [response]}

In [None]:
# =============================================================================
# Router Logic
# =============================================================================

def should_continue(state: AgentState):
    """Choose the edge taken after the agent node.

    Returns ``"tools"`` when the newest message carries a non-empty
    ``tool_calls`` list, otherwise ``END`` (the agent produced its answer).
    """
    latest = state['messages'][-1]
    pending_calls = getattr(latest, 'tool_calls', None)

    if not pending_calls:
        print("[ROUTER] Ending - answer complete")
        return END

    print("[ROUTER] Continuing to tools")
    return "tools"

In [None]:
# =============================================================================
# Build Graph
# =============================================================================

def create_sql_agent():
    """Build and compile the agent/tools loop.

    Topology: START -> agent; agent -> tools or END (decided by
    should_continue); tools -> agent.
    """
    graph = StateGraph(AgentState)

    graph.add_node('agent', agent_node)
    graph.add_node('tools', ToolNode(tools))

    graph.add_edge(START, 'agent')
    graph.add_conditional_edges('agent', should_continue, ['tools', END])
    graph.add_edge('tools', 'agent')

    return graph.compile()

agent = create_sql_agent()
agent

In [None]:
# =============================================================================
# Query Functions
# =============================================================================

def ask_sql(question: str):
    """Ask the SQL agent a question.

    Runs ``question`` through the compiled graph, pretty-prints the final
    message, and returns the full final state (including ``'messages'``).
    """
    banner = '=' * 60

    print(f"\n{banner}")
    print(f"Question: {question}")
    print(banner)

    final_state = agent.invoke({'messages': [HumanMessage(question)]})

    print(f"\n{banner}")
    print("Answer:")
    print(banner)
    final_state['messages'][-1].pretty_print()

    return final_state

In [None]:
def chat_sql():
    """Interactive SQL chat.

    Reads questions from stdin and answers each one with ``ask_sql``. Blank
    input is ignored. The loop ends on quit/exit/q (any case), when stdin is
    closed (EOF), or on Ctrl+C.
    """
    print("\n🤖 SQL Agent Chat - Type 'quit' to exit")
    print("Ask questions about the employees database")

    while True:
        try:
            question = input("\nYour question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupt ends the chat instead of raising.
            break
        if question.lower() in ['quit', 'exit', 'q']:
            break
        if question:
            ask_sql(question)

# chat_sql()

In [None]:
# =============================================================================
# Demo
# =============================================================================

# ask_sql prints the final answer itself; the trailing ';' stops Jupyter from
# also dumping the returned state dict (the whole message history) as output.
ask_sql("How many employees are there in total?");

In [None]:
# Trailing ';' suppresses display of the returned state; ask_sql already prints the answer.
ask_sql("Show me the top 5 highest paid employees with their titles and salaries");

In [None]:
# Trailing ';' suppresses display of the returned state; ask_sql already prints the answer.
ask_sql("What is the average salary by department?");