In [None]:
"""
NL → SQL Agent using LangGraph + LangChain + Ollama (SQLite version with Hospital/Doctor sample schema)
--------------------------------------------------

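Uses the schema of a local SQLite database (path from the SQLITE_PATH environment variable,
default `hospital.db`) to translate natural-language questions into SQL. The agent returns the
generated query as text; it does not execute it.
If the database does not exist, it will be created and seeded with hospital/doctor/patient sample data.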
"""
from __future__ import annotations
import json
import os
import re
from typing import Annotated, List, Dict, Any, Literal

from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.tools import tool
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field

from sqlalchemy import create_engine, inspect, text

# -----------------------------
# DB Setup & Introspection
# -----------------------------

def _get_engine():
    """Create a SQLAlchemy engine for the SQLite database, seeding it if it is empty.

    The database path is read from the ``SQLITE_PATH`` environment variable
    (default ``hospital.db``, resolved relative to the current working directory).

    Returns:
        A SQLAlchemy ``Engine`` bound to ``sqlite:///<db_path>``.
    """
    db_path = os.getenv("SQLITE_PATH", "hospital.db")
    engine = create_engine(f"sqlite:///{db_path}")
    # Seed based on the *contents* of the database rather than on file existence.
    # SQLite creates the file on first connect, so a seed attempt that failed (and
    # rolled back), or a zero-byte file, would otherwise never be re-seeded. A
    # database that already contains any table is left untouched.
    if not inspect(engine).get_table_names():
        _seed_database(engine)
    return engine

def _seed_database(engine):
    """Create the sample hospitals/doctors/patients schema and insert demo rows.

    Every statement runs inside a single ``engine.begin()`` block, in order:
    three CREATE TABLE statements followed by three multi-row INSERTs.
    NOTE(review): the hard-coded ``hospital_id`` / ``doctor_id`` values assume the
    INTEGER PRIMARY KEY ids are assigned 1, 2, 3 in insertion order on a fresh DB.
    """
    statements = (
        """
        CREATE TABLE hospitals (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            location TEXT NOT NULL
        );
        """,
        """
        CREATE TABLE doctors (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            specialty TEXT NOT NULL,
            hospital_id INTEGER,
            FOREIGN KEY (hospital_id) REFERENCES hospitals (id)
        );
        """,
        """
        CREATE TABLE patients (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            age INTEGER,
            ailment TEXT,
            doctor_id INTEGER,
            FOREIGN KEY (doctor_id) REFERENCES doctors (id)
        );
        """,
        "INSERT INTO hospitals (name, location) VALUES ('City Hospital', 'Downtown'), ('Green Valley Hospital', 'Suburbs')",
        "INSERT INTO doctors (name, specialty, hospital_id) VALUES ('Dr. Smith', 'Cardiology', 1), ('Dr. Brown', 'Orthopedics', 1), ('Dr. Green', 'Pediatrics', 2)",
        "INSERT INTO patients (name, age, ailment, doctor_id) VALUES ('Alice', 45, 'Heart Disease', 1), ('Bob', 30, 'Fracture', 2), ('Charlie', 5, 'Fever', 3)",
    )
    with engine.begin() as conn:
        for sql in statements:
            conn.execute(text(sql))

def _safe_inspector():
    """Return a new SQLAlchemy inspector over a freshly created engine (see ``_get_engine``)."""
    return inspect(_get_engine())

# -----------------------------
# Tools
# -----------------------------
@tool("read_table_names_with_table_description" , description="Get names and descriptions of all tables in the database.")
def read_table_names_with_table_description() -> str:
    """Return a JSON document listing every table in the database.

    Each entry has the shape ``{"table": <name>, "schema": "main", "comment": None}``;
    the schema is hard-coded to ``"main"`` and no comment lookup is performed.
    """
    inspector = _safe_inspector()
    payload = {
        "tables": [
            {"table": name, "schema": "main", "comment": None}
            for name in inspector.get_table_names()
        ]
    }
    return json.dumps(payload, indent=2)

class ColumnsWithTypesInput(BaseModel):
    """Arguments for the ``read_table_columns_with_data_types`` tool."""

    # The description is part of the tool's argument schema shown to the LLM.
    table_name: str = Field(
        description="Exact name of the table to inspect, as returned by read_table_names_with_table_description."
    )

@tool("read_table_columns_with_data_types", args_schema=ColumnsWithTypesInput , description="Get column names and data types for a specific table.")
def read_table_columns_with_data_types(table_name: str) -> str:
    """Return JSON with the name, type, nullability and default of each column in ``table_name``.

    Returns:
        ``{"table": ..., "columns": [{"name", "type", "nullable", "default"}, ...]}``.
        If SQLAlchemy reports the table as missing (``NoSuchTableError``), returns
        ``{"table": ..., "error": "table not found", "available_tables": [...]}``
        instead of raising, so the agent can retry with a real table name.
    """
    from sqlalchemy.exc import NoSuchTableError  # local import: only needed on the error path

    insp = _safe_inspector()
    try:
        raw_columns = insp.get_columns(table_name)
    except NoSuchTableError:
        return json.dumps(
            {"table": table_name, "error": "table not found", "available_tables": insp.get_table_names()},
            indent=2,
        )
    cols_meta = [
        {
            "name": col.get("name"),
            "type": str(col.get("type")),
            "nullable": bool(col.get("nullable")),
            "default": col.get("default"),
        }
        for col in raw_columns
    ]
    return json.dumps({"table": table_name, "columns": cols_meta}, indent=2)

class ColumnsDescriptionInput(BaseModel):
    """Arguments for the ``read_table_columns_description`` tool."""

    # The description is part of the tool's argument schema shown to the LLM.
    table_name: str = Field(
        description="Exact name of the table whose columns should be described, as returned by read_table_names_with_table_description."
    )


@tool("read_table_columns_description", args_schema=ColumnsDescriptionInput,description="Get column names and descriptions for a specific table.")
def read_table_columns_description(table_name: str) -> str:
    """Return JSON listing each column of ``table_name`` with a description field.

    ``comment`` is always ``None``: no column-comment lookup is performed.

    Returns:
        ``{"table": ..., "columns": [{"name": ..., "comment": None}, ...]}``.
        If SQLAlchemy reports the table as missing (``NoSuchTableError``), returns
        ``{"table": ..., "error": "table not found", "available_tables": [...]}``
        instead of raising, so the agent can retry with a real table name.
    """
    from sqlalchemy.exc import NoSuchTableError  # local import: only needed on the error path

    insp = _safe_inspector()
    try:
        raw_columns = insp.get_columns(table_name)
    except NoSuchTableError:
        return json.dumps(
            {"table": table_name, "error": "table not found", "available_tables": insp.get_table_names()},
            indent=2,
        )
    results = [{"name": col.get("name"), "comment": None} for col in raw_columns]
    return json.dumps({"table": table_name, "columns": results}, indent=2)

# Tools bound to the LLM in agent_node and executed by the ToolNode in build_graph.
TOOLS = [read_table_names_with_table_description, read_table_columns_with_data_types,
         read_table_columns_description]

# -----------------------------
# Prompt & Agent
# -----------------------------
# System prompt for the agent. It must only advertise tools that exist in TOOLS;
# a model told about a nonexistent tool may try to call it.
SYSTEM_PROMPT = """You are an expert SQL query generator. You use database schema metadata to translate
natural language questions into SQLite SQL queries.

You have access to these tools:
1. read_table_names_with_table_description — returns the available tables (and their descriptions, when present).
2. read_table_columns_with_data_types — returns column names, data types, nullability and defaults for one table.
3. read_table_columns_description — returns a description of each column in one table (when present).

Your workflow:
- Always start by identifying which table(s) are relevant.
- Use read_table_columns_with_data_types to find correct column names and types.
- Use read_table_columns_description to confirm the meaning of each column.
- For textual or categorical filters, prefer robust matching (e.g. case-insensitive comparison or LIKE) unless the exact stored value is known.
- Only use columns and tables that exist.
- Avoid guessing — if the question is ambiguous, reply with one short clarifying question instead of SQL.

Output rules:
- Output ONLY the final SQLite query inside a fenced code block:
```sql
SELECT ...
FROM ...
WHERE ...
```
"""

class AgentState(BaseModel):
    """Graph state shared by the agent and tool nodes: the running message history."""

    # ``add_messages`` makes node updates merge into the history instead of replacing it.
    # Without a reducer the field is a last-value channel, and ToolNode's update
    # ({"messages": [ToolMessage, ...]}) would overwrite the whole conversation. The next
    # agent turn would then lose the user's question and the AIMessage that requested
    # the tools. The reducer de-duplicates by message id, so a node may return either
    # only its new messages or the full list.
    messages: Annotated[List[Any], add_messages] = Field(default_factory=list)

def _make_llm() -> ChatOllama:
    """Build the Ollama chat model from environment configuration.

    Reads ``OLLAMA_BASE_URL`` (default ``http://localhost:11434``) and
    ``OLLAMA_MODEL`` (default ``qwen3:1.7b``); temperature is fixed at 0.0.
    """
    settings = {
        "base_url": os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
        "model": os.getenv("OLLAMA_MODEL", "qwen3:1.7b"),
        "temperature": 0.0,
    }
    return ChatOllama(**settings)

def should_continue(state: AgentState) -> Literal["tools", "final"]:
    """Route to the tool node when the latest message is an AIMessage requesting tool calls.

    Returns:
        ``"tools"`` if the last message is an ``AIMessage`` with non-empty
        ``tool_calls``, otherwise ``"final"``.
    """
    latest = state.messages[-1]
    wants_tools = isinstance(latest, AIMessage) and bool(latest.tool_calls)
    return "tools" if wants_tools else "final"

import re

def clean_content(con: AIMessage) -> AIMessage:
    """Return a new AIMessage whose text has every ``<think>...</think>`` section removed.

    The default model (qwen3) prefixes its answers with a ``<think>`` reasoning
    block; only the text after it is useful as the final answer.

    Args:
        con: the model reply to clean.

    Returns:
        A fresh ``AIMessage`` carrying only the cleaned ``content``. Tool calls and
        metadata are not copied. Non-string (multi-part) content is passed through
        unchanged instead of crashing ``re.sub``.
    """
    content = con.content
    if not isinstance(content, str):
        return AIMessage(content=content)
    stripped = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL | re.IGNORECASE)
    return AIMessage(content=stripped.strip())

def agent_node(state: AgentState) -> AgentState:
    """Run one LLM turn: prompt the tool-bound model with the system prompt plus history.

    Prints the latest input and the model's reply for debugging. Returns a new
    state holding the prior history followed by the reply.
    """
    model = _make_llm().bind_tools(TOOLS)
    history = state.messages
    print('input:', history[-1].content if history else "No messages")
    reply = model.invoke([SystemMessage(content=SYSTEM_PROMPT), *history])
    # Strip <think> blocks only from final answers; turns that request tools are kept verbatim.
    if isinstance(reply, AIMessage) and not reply.tool_calls:
        reply = clean_content(reply)
    print('result:', reply.content)
    return AgentState(messages=[*history, reply])

def build_graph():
    """Compile the agent ⇄ tools loop: the agent runs first, tools feed back into it, END on a final answer."""
    workflow = StateGraph(AgentState)
    for node_name, node in (("agent", agent_node), ("tools", ToolNode(tools=TOOLS))):
        workflow.add_node(node_name, node)
    workflow.set_entry_point("agent")
    route_map = {"tools": "tools", "final": END}
    workflow.add_conditional_edges("agent", should_continue, route_map)
    workflow.add_edge("tools", "agent")
    return workflow.compile()

def ask_sql(question: str, recursion_limit: int | None = None) -> str:
    """Run the agent graph on ``question`` and return the text of its last AIMessage.

    Args:
        question: natural-language question about the database.
        recursion_limit: optional LangGraph step limit for the agent/tool loop.
            When ``None`` (default), LangGraph's own default is used.

    Returns:
        Content of the last ``AIMessage`` in the final state (expected to be the
        SQL answer), or ``""`` if the model produced none.
    """
    app = build_graph()
    inputs = AgentState(messages=[HumanMessage(content=question)])
    config = {"recursion_limit": recursion_limit} if recursion_limit is not None else None
    out_state = app.invoke(inputs, config=config)
    # Compiled graphs return the final state as a dict; fall back to attribute access otherwise.
    messages = out_state.get("messages", []) if isinstance(out_state, dict) else out_state.messages
    final_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
    return final_ai.content if final_ai is not None else ""

if __name__ == "__main__":
    # Demo run against the sample hospital database.
    # Another question to try: "list all patients treated by Dr. Smith".
    question = "how many doctors work at City Hospital?"
    print(ask_sql(question))

input: how many doctors work at City Hospital?
result: <think>
Okay, the user is asking how many doctors work at City Hospital. Let me figure out how to get that information.

First, I need to check the database schema to see which tables are involved. The relevant tables would probably be something like "doctors" and "hospitals" or similar. But since I don't have the actual database structure, I need to use the provided tools to find out the tables and their columns.

I should start by calling the read_table_names_with_table_description function to get a list of tables and their descriptions. That will help me identify which tables are related to hospitals and doctors. Let's see what the function returns. Suppose it lists a table called "hospitals" with a description like "Stores information about hospitals." And another table called "doctors" with a description about their details.

Once I have the table names, I need to check the columns. For example, if there's a "hospitals" table,

In [13]:
# Scratch check: send a one-word prompt to the configured Ollama model and keep the raw reply in `con`.
con =  _make_llm().invoke('hi')

In [18]:
# Strip the <think> block from the scratch reply with the clean_content helper defined above.
# Re-defining the helper in this cell would silently shadow the module-level version.
clean_content(con)

AIMessage(content='Hello! How can I assist you today? 😊', additional_kwargs={}, response_metadata={})

In [14]:
# Show the raw reply text, including the model's <think> reasoning block.
con.content

'<think>\nOkay, the user said "hi". I need to respond politely. Let me make sure to acknowledge their greeting and offer assistance. Maybe start with a friendly greeting like "Hello!" and ask how I can help them. Keep it simple and welcoming. Let me check for any possible misunderstandings. They might just want a standard response, so I\'ll stick to that. Alright, ready to respond.\n</think>\n\nHello! How can I assist you today? 😊'

In [17]:
import re
# Same cleanup as clean_content(); reuse the helper instead of duplicating its regex.
content = clean_content(con).content
content

'Hello! How can I assist you today? 😊'