In [1]:
from __future__ import annotations
from typing import Annotated, List, TypedDict, Optional, Any, Dict
import json, re, sqlite3, textwrap

from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
from langchain_community.chat_models import ChatLlamaCpp
from langgraph.graph import StateGraph, START, END

In [2]:
# Global variables (everything a reader might tune lives here)
DB_PATH = "Data/dummy_data.db"  # SQLite file queried by the graph; relative to the kernel's cwd
MODEL_PATH = "LlamaCppModels/Llama-3.2-1B-Instruct/Llama-3.2-1B-Instruct-Q6_K_L.gguf"

# Set up a local LLM using llama.cpp
chat = ChatLlamaCpp(
    model_path=MODEL_PATH,
    n_ctx=4096,  # context window; the load log reports the model was trained with 131072
    n_threads=8,
    temperature=0.4,  # NOTE(review): consider 0 for reproducible SQL generation
    model_kwargs={"chat_format": "llama-3"},
    verbose=False,
)

llama_context: n_batch is less than GGML_KQ_MASK_PAD - increasing to 64
llama_context: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
ggml_metal_init: skipping kernel_get_rows_bf16                     (not supported)
ggml_metal_init: skipping kernel_set_rows_bf16                     (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32                   (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32_c4                (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32_1row              (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32_l4                (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_bf16                  (not supported)
ggml_metal_init: skipping kernel_mul_mv_id_bf16_f32                (not supported)
ggml_metal_init: skipping kernel_mul_mm_bf16_f32                   (not supported)
ggml_metal_init: skipping kernel_mul_mm_id_bf16_f16                

In [None]:
# State object for LangGraph
class ChatState(TypedDict):
    """State passed between the LangGraph nodes plan_sql -> execute_sql -> respond."""

    # Conversation so far. plan_sql reads the latest HumanMessage; respond returns this list with
    # one AIMessage appended.
    # NOTE(review): the string in Annotated is plain metadata, not a reducer function, so presumably
    # LangGraph simply overwrites this key with whatever a node returns — confirm before relying on
    # append semantics.
    messages: Annotated[List[AnyMessage], "Running chat transcript"]
    # Query text produced by plan_sql; execute_sql replaces it with the gated/LIMIT-normalized text.
    sql: Optional[str]
    # plan_sql currently stores "" (None on planner error); respond falls back to "Query executed.".
    sql_explanation: Optional[str]
    # Result rows from run_select as column-name -> value dicts; None when execution failed.
    rows: Optional[List[Dict[str, Any]]]
    # Human-readable failure message. When set, execute_sql returns the state unchanged and
    # respond reports the error instead of a result.
    error: Optional[str]


# Helper functions
def inspect_schema(db_path: str) -> str:
    """Return a compact, LLM-friendly schema description.

    Lists every user table (internal ``sqlite_*`` tables excluded) as
    ``- table(col TYPE, ...)``, followed by one indented ``↳ FK`` line per foreign key.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        The description prefixed with ``"Tables:"``, or ``"(No user tables found.)"``.
    """
    con = sqlite3.connect(db_path)
    try:
        con.row_factory = sqlite3.Row
        cur = con.cursor()

        # List user tables. `_` is a LIKE wildcard, so escape it; otherwise a user table such as
        # "sqlitex" would also be hidden along with the internal sqlite_* tables.
        cur.execute(
            r"SELECT name FROM sqlite_master WHERE type='table' "
            r"AND name NOT LIKE 'sqlite\_%' ESCAPE '\' ORDER BY name;"
        )
        tables = [r["name"] for r in cur.fetchall()]

        parts = []
        for t in tables:
            # PRAGMA arguments cannot be bound as parameters, so quote the identifier instead.
            # This keeps names containing spaces, dots or quotes from breaking the statement.
            quoted = '"' + t.replace('"', '""') + '"'
            cur.execute(f"PRAGMA table_info({quoted});")
            cols_desc = ", ".join(f"{c['name']} {c['type']}" for c in cur.fetchall())
            parts.append(f"- {t}({cols_desc})")
            # brief fk info
            cur.execute(f"PRAGMA foreign_key_list({quoted});")
            for fk in cur.fetchall():
                parts.append(f"    ↳ FK {fk['from']} → {fk['table']}({fk['to']})")
    finally:
        # Close even if a query fails, so the file handle is not leaked.
        con.close()

    if not parts:
        return "(No user tables found.)"
    return "Tables:\n" + "\n".join(parts)

def ensure_readonly_select(sql: str, default_limit: int = 50) -> str:
    """Allow only SELECT queries. Add LIMIT if missing.

    This is a coarse lexical gate, not a full SQL parser. For example, a LIMIT inside a subquery
    counts as "present".

    Args:
        sql: Candidate query text (may carry surrounding whitespace / trailing semicolons).
        default_limit: LIMIT appended when the query has no LIMIT keyword.

    Returns:
        The normalized query, terminated by exactly one ``;``.

    Raises:
        ValueError: if the query does not start with SELECT/WITH, or is a CTE-prefixed
            INSERT/UPDATE/DELETE/REPLACE statement.
    """
    # Drop surrounding whitespace and any run of trailing semicolons/whitespace ("SELECT 1; ;").
    s = sql.strip().rstrip("; \t\r\n\f\v")
    lowered = s.lower()
    # Whole-word keyword match, so "WITH\n..." is accepted just like "WITH ...".
    if not re.match(r"(select|with)\b", lowered):
        raise ValueError("Only read-only SELECT queries are allowed.")
    # SQLite accepts CTE-prefixed DML ("WITH c AS (...) DELETE FROM t"), so a leading WITH does not
    # by itself imply a read. A plain SELECT cannot embed DML, so only WITH queries are scanned.
    if lowered.startswith("with") and re.search(r"\b(insert|update|delete)\b|\breplace\s+into\b", lowered):
        raise ValueError("Only read-only SELECT queries are allowed.")
    # Word-boundary search: a LIMIT preceded by a newline/tab must count as present, otherwise a
    # second LIMIT would be appended and the statement would no longer parse.
    if not re.search(r"\blimit\b", lowered):
        s += f" LIMIT {default_limit}"
    return s + ";"

def run_select(db_path: str, sql: str) -> List[Dict[str, Any]]:
    """Execute ``sql`` against the SQLite file at ``db_path`` and return all rows as dicts.

    The database is opened through a read-only URI (``mode=ro``). Any write attempt is therefore
    rejected by SQLite itself, independent of upstream lexical checks. A missing file raises
    instead of being silently created as an empty database.

    Args:
        db_path: Path to an existing SQLite database file.
        sql: A single SQL statement.

    Returns:
        One ``{column_name: value}`` dict per result row.

    Raises:
        sqlite3.OperationalError: on SQL errors, write attempts, or a missing database file.
    """
    from pathlib import Path  # local import: keeps this cell self-contained

    # as_uri() needs an absolute path and percent-encodes characters such as '?' and '#'.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    con = sqlite3.connect(uri, uri=True)
    try:
        con.row_factory = sqlite3.Row
        return [dict(r) for r in con.execute(sql).fetchall()]
    finally:
        con.close()

def extract_json(text: str) -> Dict[str, Any]:
    """Extract the first JSON object from a model string.

    A fenced code block (```json or bare ```) whose body is an object wins if it decodes.
    Otherwise the text is scanned left to right, and the object at the first ``{`` where a
    complete JSON value can be decoded is returned. Surrounding prose and any later objects
    are ignored.

    Raises:
        ValueError: if no decodable JSON object is present.
    """
    # prefer fenced code blocks with json
    fence = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.S)
    if fence:
        try:
            return json.loads(fence.group(1))
        except json.JSONDecodeError:
            pass  # malformed fenced body: fall back to scanning the whole text
    # Fallback: try each '{' in turn. raw_decode stops at the end of the first complete object, so
    # later objects or stray braces do not break parsing (a greedy "{.*}" regex would).
    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", text):
        try:
            obj, _end = decoder.raw_decode(text, brace.start())
        except json.JSONDecodeError:
            continue
        return obj
    raise ValueError("No JSON object found in model output.")

# Nodes
# System prompt for the planner node. The planner's reply is used verbatim as SQL (after stripping
# any markdown fences), so the prompt must ask for the bare query only — there is no separate
# "explanation" field for the model to fill in.
SYS_PLANNER = SystemMessage(content=textwrap.dedent("""\
You are a careful SQL planning assistant for a SQLite database.
You MUST produce exactly one SQL query.

Hard rules:
- Only SELECT / WITH … SELECT. No INSERT/UPDATE/DELETE/PRAGMA/DROP.
- Prefer explicit column lists over SELECT * when reasonable.
- Use existing table/column names exactly as shown in the schema.
- Add a reasonable LIMIT (e.g., 50) if the question does not require full detail.
- If the question is ambiguous, choose a reasonable interpretation.

Output strictly the raw SQL query: no explanation, no comments, no markdown code fences.
"""))

# Hand-maintained schema summary that plan_sql puts into the planner prompt.
# Keep it in sync with the database at DB_PATH.
# NOTE(review): presumably test_result = 'NOK' marks a failed test; saying so here could help the
# planner translate questions about "failed" tests into a test_result filter — confirm first.
SCHEMA = (
    "Tables:\n"
    "- dummy_data(igef TEXT, test TEXT, test_result TEXT)\n"
    "\n"
    "Available values in columns:\n"
    "- test_result: OK, NOK\n"
)

# Matches a markdown code fence (``` / ```sql / ```sqlite) and captures its body. `sqlite` is tried
# before `sql` so that "```sqlite" does not leave "ite" at the start of the query.
_SQL_FENCE_RE = re.compile(r"```(?:sqlite|sql)?\s*(.*?)\s*```", re.S | re.I)


def _strip_sql_fences(text: Any) -> str:
    """Return the body of the first markdown code fence in `text`, or the stripped text itself."""
    text = text if isinstance(text, str) else str(text)
    match = _SQL_FENCE_RE.search(text)
    return (match.group(1) if match else text).strip()


def plan_sql(state: ChatState) -> ChatState:
    """Ask the LLM to translate the latest user question into one SQLite query.

    Uses the hand-written SCHEMA and the most recent HumanMessage; a generic question is used if
    there is none.

    Returns:
        The state with ``sql`` set to the model's query (markdown fences removed),
        ``sql_explanation=""`` and ``error=None``. If the reply cannot be turned into a
        non-empty query, returns ``sql=None`` and an ``error`` starting with
        ``"Planner parsing error:"``. Exceptions from ``chat.invoke`` itself propagate.
    """
    schema = SCHEMA
    # Last user message
    user_msg = next((m for m in reversed(state["messages"]) if isinstance(m, HumanMessage)), None)
    question = user_msg.content if user_msg else "Show something useful from the database."

    # The system prompt asks for a bare query, so the user turn must not ask for JSON.
    plan_prompt = (
        f"Database schema:\n{schema}\n\n"
        f"User question:\n{question}\n\n"
        "Return the SQL query now."
    )

    resp = chat.invoke([SYS_PLANNER, HumanMessage(content=plan_prompt)])
    try:
        # Small models often wrap the query in ```sql fences, which would otherwise fail the
        # SELECT-only gate downstream.
        sql = _strip_sql_fences(resp.content)
        if not sql:
            raise ValueError("model returned an empty query")
        return {**state, "sql": sql, "sql_explanation": "", "error": None}
    except Exception as e:
        return {**state, "sql": None, "sql_explanation": None, "error": f"Planner parsing error: {e}"}

def execute_sql(state: ChatState) -> ChatState:
    """Gate the planned query through ensure_readonly_select and run it against DB_PATH.

    If an earlier node already recorded an error, the incoming state object is returned as-is.
    On success the normalized SQL and the result rows are stored and ``error`` is cleared.
    On any exception, ``rows`` becomes None and ``error`` holds ``"SQL execution error: ..."``.
    """
    if state.get("error"):
        return state  # upstream failure: nothing to execute

    try:
        checked_sql = ensure_readonly_select(state["sql"] or "")
        result_rows = run_select(DB_PATH, checked_sql)
    except Exception as exc:
        return {**state, "rows": None, "error": f"SQL execution error: {exc}"}
    return {**state, "sql": checked_sql, "rows": result_rows, "error": None}

def _format_markdown_table(rows: List[Dict[str, Any]]) -> str:
    """Render non-empty dict rows as a Markdown pipe table (columns from the first row's keys)."""
    headers = list(rows[0].keys())

    def fmt_row(cells) -> str:
        # Escape literal pipes so a value cannot split a cell. Outer pipes are always emitted so a
        # single-column table is not parsed as a setext heading ("name" followed by "---").
        return "| " + " | ".join(str(c).replace("|", "\\|") for c in cells) + " |"

    lines = [fmt_row(headers), "| " + " | ".join(["---"] * len(headers)) + " |"]
    lines.extend(fmt_row(r.get(h, "") for h in headers) for r in rows)
    return "\n".join(lines)


def respond(state: ChatState) -> ChatState:
    """Append the final AIMessage (error notice or query summary) to the transcript.

    On error, the message is ``"⚠️ <error>"`` and sql/explanation/rows are cleared. Otherwise the
    Markdown answer contains the explanation, the SQL in a ```sql fence, and a preview table of at
    most the first 10 rows.
    """
    if state.get("error"):
        msg = f"⚠️ {state['error']}"
        return {"messages": state["messages"] + [AIMessage(content=msg)], "sql": None, "sql_explanation": None, "rows": None, "error": state["error"]}

    rows = state.get("rows") or []
    expl = state.get("sql_explanation") or "Query executed."
    sql = state.get("sql") or ""

    # Pretty-print a compact table (first 10 rows)
    preview = rows[:10]
    table_md = _format_markdown_table(preview) if preview else "_No rows returned._"

    # Assemble line by line instead of dedenting an f-string. Interpolated multi-line values
    # (the table, a multi-line query) start at column 0, which made textwrap.dedent a no-op and left
    # every template line indented by four spaces, i.e. rendered as a Markdown code block.
    answer = "\n".join([
        expl,
        "",
        "**SQL used**",
        "```sql",
        sql,
        "```",
        "",
        f"**Preview ({len(preview)} of {len(rows)} rows)**",
        "",
        table_md,
    ]) + "\n"
    return {"messages": state["messages"] + [AIMessage(content=answer)], "sql": sql, "sql_explanation": expl, "rows": rows, "error": None}


# Build graph: a straight pipeline START -> plan_sql -> execute_sql -> respond -> END.
# Failures are handled inside the nodes (execute_sql and respond both check state["error"]),
# so no conditional edges are needed.
def _build_graph() -> StateGraph:
    """Register the three nodes, then chain them in order between START and END."""
    graph = StateGraph(ChatState)
    for node_name, node_fn in (("plan_sql", plan_sql), ("execute_sql", execute_sql), ("respond", respond)):
        graph.add_node(node_name, node_fn)
    route = [START, "plan_sql", "execute_sql", "respond", END]
    for src, dst in zip(route, route[1:]):
        graph.add_edge(src, dst)
    return graph


builder = _build_graph()
app = builder.compile()

    Query executed.

    **SQL used**
    ```sql
    SELECT igef FROM dummy_data GROUP BY igef ORDER BY COUNT(test_result) DESC LIMIT 5;
    ```

    **Preview (5 of 5 rows)**
    igef
---
igef_50
igef_454
igef_142
igef_899
igef_801



In [None]:
# Seed the graph with one user question; every other state field starts empty.
initial = ChatState(
    messages=[HumanMessage(content="Show the 5 IGEFs where the most tests failed.")],
    sql=None,
    sql_explanation=None,
    rows=None,
    error=None,
)
result = app.invoke(initial)
# respond appends its answer as the last message of the transcript.
print(result["messages"][-1].content)