# LLM Data Assistant Orchestrator (Notebook UI)
Ask questions like: “Why did conversion drop in the last 10 minutes?”
This notebook retrieves context (Vector Search), prompts an LLM (Llama via Model Serving), guards SQL, executes against UC (Gold/Silver), renders charts, and logs interactions.

## 1) Setup: packages and env
Installs and loads the required packages, then reads access tokens (PATs) and endpoints from environment variables or a `.env` file.

In [None]:
# %pip install databricks-sql-connector requests python-dotenv pandas matplotlib plotly
import os, sys, subprocess, json, time, re

# Ensure required packages
def ensure(pkgs):
    """Install ``pkgs`` with pip into the interpreter running this notebook.

    Invokes ``<sys.executable> -m pip install -q <pkgs...>``. If anything
    goes wrong a short message is printed and the original exception is
    re-raised so the calling cell stops.
    """
    try:
        pip_cmd = [sys.executable, "-m", "pip", "install", "-q"] + list(pkgs)
        subprocess.check_call(pip_cmd)
    except Exception as err:
        print("pip install failed:", err)
        raise

try:
    import pandas as pd
except Exception:
    ensure(["pandas"]) ; import pandas as pd

try:
    import matplotlib.pyplot as plt
except Exception:
    ensure(["matplotlib"]) ; import matplotlib.pyplot as plt

try:
    import plotly.express as px
except Exception:
    ensure(["plotly"]) ; import plotly.express as px

try:
    from databricks import sql as dbsql  # used later
except Exception:
    ensure(["databricks-sql-connector"]) ; from databricks import sql as dbsql

try:
    import requests
except Exception:
    ensure(["requests"]) ; import requests

from datetime import datetime, timedelta

try:
    from dotenv import load_dotenv
except Exception:
    ensure(["python-dotenv"]) ; from dotenv import load_dotenv

# Load settings from a .env file. The location can be overridden with the
# DOTENV_PATH environment variable (a leading "~" is expanded); the default is
# the path this notebook was originally written against.
load_dotenv(os.path.expanduser(os.getenv("DOTENV_PATH", "/Users/kritan/data-monorepo/.env")))

# Databricks SQL creds (PAT)
DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_HTTP_PATH = os.getenv("DATABRICKS_HTTP_PATH")
DATABRICKS_TOKEN = os.getenv("DATABRICKS_TOKEN")
# NOTE(review): assert statements are skipped when Python runs with -O.
assert DATABRICKS_HOST and DATABRICKS_HTTP_PATH and DATABRICKS_TOKEN, "Missing DBSQL env vars"

# Model Serving (Llama) endpoint
MODEL_ENDPOINT_URL = os.getenv("MODEL_ENDPOINT_URL")  # e.g., https://.../serving-endpoints/llama/invocations
# Falls back to the Databricks SQL token when MODEL_TOKEN is not set.
MODEL_TOKEN = os.getenv("MODEL_TOKEN", DATABRICKS_TOKEN)
assert MODEL_ENDPOINT_URL and MODEL_TOKEN, "Missing model serving endpoint/token"

# Vector Search endpoint/index
VS_ENDPOINT = os.getenv("VS_ENDPOINT")
VS_INDEX = os.getenv("VS_INDEX")  # e.g., demo.ecommerce_rt.kb_docs_index
assert VS_ENDPOINT and VS_INDEX, "Missing Vector Search endpoint/index"

# Default Unity Catalog location; DB is the "catalog.schema" prefix used below.
CATALOG = os.getenv("CATALOG", "demo")
SCHEMA = os.getenv("SCHEMA", "ecommerce_rt")
DB = f"{CATALOG}.{SCHEMA}"

# Fully qualified names of the objects the assistant reads.
GOLD = f"{DB}.gold_kpis"
SILVER = f"{DB}.silver_events"
V_TIMESERIES = f"{DB}.v_kpi_timeseries"

## 2) Databricks SQL connection and query helpers
Creates a connection helper and the functions that run SQL with guardrails applied.

In [None]:
from databricks import sql as dbsql

def open_conn():
    """Open a Databricks SQL connection from the host, HTTP path and token read from the environment."""
    conn_args = {
        "server_hostname": DATABRICKS_HOST,
        "http_path": DATABRICKS_HTTP_PATH,
        "access_token": DATABRICKS_TOKEN,
    }
    return dbsql.connect(**conn_args)

# Set of "catalog.schema" names; it holds only the default DB.
# NOTE(review): presumably the schemas generated SQL is allowed to read — confirm.
ALLOWED_SCHEMAS = {DB}
# Row cap that ensure_limits appends to a query with neither LIMIT nor GROUP BY.
RAW_LIMIT = 200
# Look-back window, in minutes, that ensure_time_filter injects when a query
# carries no recognised time predicate.
DEFAULT_TIME_WINDOW_MIN = 60*24  # 24h fallback

# Matches any of the listed DDL/DML/privilege keywords as a whole word, ignoring case.
DDL_DML_PATTERN = re.compile(r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|MERGE|GRANT|REVOKE)\b", re.I)
# Matches a semicolon followed by more non-whitespace text (a second statement);
# a lone trailing semicolon does not match.
MULTI_STMT_PATTERN = re.compile(r";\s*\S")

# Matches the time column names `window_start` or `ts` as whole words, ignoring case.
TIME_FILTER_PATTERN = re.compile(r"\b(window_start|ts)\b", re.I)

def enforce_whitelist(sql_text: str) -> str:
    """Reject unsafe SQL and fully qualify the tables it reads.

    The statement is refused when it contains a DDL/DML keyword or more than
    one statement. Every name that follows ``FROM`` or ``JOIN`` is then
    rewritten in place, touching only that occurrence:

    * ``table``          -> ``<DB>.table`` (unless it names a CTE declared
      with ``name AS (`` in the same statement),
    * ``schema.table``   -> ``<CATALOG>.schema.table`` unless the first part
      already equals ``CATALOG``,
    * three-part names are left as written.

    Any resulting three-part name whose ``catalog.schema`` prefix is not in
    ``ALLOWED_SCHEMAS`` (compared case-insensitively) is rejected.

    Args:
        sql_text: SQL produced by the LLM or written by hand.

    Returns:
        The SQL with table references qualified.

    Raises:
        ValueError: for DDL/DML, multiple statements, or a table outside the
            allowed schemas.
    """
    if DDL_DML_PATTERN.search(sql_text):
        raise ValueError("Only SELECT allowed")
    if MULTI_STMT_PATTERN.search(sql_text):
        raise ValueError("Multiple statements not allowed")

    allowed = {schema.lower() for schema in ALLOWED_SCHEMAS}
    # Names introduced by `WITH name AS (...)` are CTEs, not catalog tables,
    # so they must not receive a catalog/schema prefix.
    cte_names = {name.lower() for name in re.findall(r"\b(\w+)\s+as\s*\(", sql_text, re.I)}

    def _qualify(match):
        keyword, tbl = match.group(1), match.group(2)
        parts = tbl.split(".")
        if len(parts) == 1:
            if tbl.lower() in cte_names:
                return match.group(0)
            tbl = f"{DB}.{tbl}"
        elif len(parts) == 2 and parts[0] != CATALOG:
            tbl = f"{CATALOG}.{tbl}"
        qualified = tbl.split(".")
        if len(qualified) == 3 and ".".join(qualified[:2]).lower() not in allowed:
            raise ValueError(f"Table {tbl} is outside the allowed schemas")
        return keyword + tbl

    # A single pass with a callback rewrites each reference exactly once, so a
    # table that appears twice is not prefixed twice and identical words
    # elsewhere in the statement (columns, aliases) are left alone.
    return re.sub(r"\b((?:from|join)\s+)([\w\.]+)", _qualify, sql_text, flags=re.I)

def ensure_limits(sql_text: str, limit=None) -> str:
    """Append a row cap to queries that return raw rows.

    Queries containing ``GROUP BY`` or an existing ``LIMIT`` are returned
    unchanged. Otherwise trailing whitespace and a trailing semicolon are
    removed before ``LIMIT <n>`` is appended, so the clause never lands after
    the statement terminator.

    Args:
        sql_text: The SQL statement to cap.
        limit: Row cap to apply; defaults to ``RAW_LIMIT`` when omitted.

    Returns:
        The SQL, with a ``LIMIT`` line appended when one was needed.
    """
    if re.search(r"\bgroup\s+by\b", sql_text, re.I):
        return sql_text
    if re.search(r"\blimit\b", sql_text, re.I):
        return sql_text
    cap = RAW_LIMIT if limit is None else int(limit)
    statement = sql_text.rstrip().rstrip(";").rstrip()
    return statement + f"\nLIMIT {cap}"

def ensure_time_filter(sql_text: str, minutes=None) -> str:
    """Make sure the query is bounded to a recent time window.

    A query that already mentions ``window_start``/``ts`` together with a
    ``now() - INTERVAL`` expression is returned unchanged. Otherwise a
    ``<col> >= now() - INTERVAL <n> minutes`` predicate is added, where
    ``<col>`` is ``window_start`` if the query mentions it (in any case) and
    ``ts`` otherwise:

    * if the query has a ``WHERE``, the predicate is AND-ed in right after
      every ``WHERE`` keyword;
    * if it has none, a new ``WHERE`` clause is inserted before the first
      ``GROUP BY`` / ``HAVING`` / ``ORDER BY`` / ``LIMIT`` (or at the end when
      none is present), after dropping any trailing semicolon, so the clause
      order stays valid.

    The placement is regex-based and therefore best-effort for statements
    with subqueries.

    Args:
        sql_text: The SQL statement to bound.
        minutes: Window length in minutes; defaults to
            ``DEFAULT_TIME_WINDOW_MIN`` when omitted.

    Returns:
        The SQL with a time predicate present.
    """
    if TIME_FILTER_PATTERN.search(sql_text) and re.search(r"now\s*\(\)\s*-\s*INTERVAL", sql_text, re.I):
        return sql_text
    window = DEFAULT_TIME_WINDOW_MIN if minutes is None else int(minutes)
    # Prefer window_start when present; compare case-insensitively to match
    # the case-insensitive detection above.
    col = "window_start" if "window_start" in sql_text.lower() else "ts"
    predicate = f"{col} >= now() - INTERVAL {window} minutes"
    if re.search(r"\bwhere\b", sql_text, re.I):
        return re.sub(r"\bwhere\b", f"WHERE {predicate} AND ", sql_text, flags=re.I)
    statement = sql_text.rstrip().rstrip(";").rstrip()
    tail = re.search(r"\b(group\s+by|having|order\s+by|limit)\b", statement, re.I)
    if tail is None:
        return statement + f"\nWHERE {predicate}"
    head = statement[:tail.start()].rstrip()
    return f"{head}\nWHERE {predicate}\n{statement[tail.start():]}"

def run_sql(sql_text: str) -> pd.DataFrame:
    """Run a guarded SELECT against Databricks SQL and return the rows.

    The statement passes through ``enforce_whitelist``,
    ``ensure_time_filter`` and ``ensure_limits`` (in that order). A fresh
    connection is opened, the default catalog and schema are selected, the
    statement is sent once prefixed with ``EXPLAIN`` and then executed.

    Args:
        sql_text: The SQL to run.

    Returns:
        A DataFrame whose columns come from the cursor description.

    Raises:
        ValueError: when a guardrail rejects the statement.
        Exception: whatever the connector raises for connection or SQL errors.
    """
    sql_text = enforce_whitelist(sql_text)
    sql_text = ensure_time_filter(sql_text)
    sql_text = ensure_limits(sql_text)
    with open_conn() as conn:
        # Manage the cursor as a context manager so it is closed even when
        # one of the statements raises.
        with conn.cursor() as cur:
            cur.execute("USE CATALOG " + CATALOG)
            cur.execute("USE SCHEMA " + SCHEMA)
            cur.execute("EXPLAIN " + sql_text)
            cur.execute(sql_text)
            cols = [d[0] for d in cur.description]
            rows = cur.fetchall()
    return pd.DataFrame(rows, columns=cols)

## 3) Vector Search retrieval
Fetch top-k knowledge snippets to ground the LLM.

In [None]:
import requests

def vs_query(query: str, k: int = 6, kinds=("table","kpi","rule","example")):
    """POST a query to the Vector Search index and return its ``results`` list.

    The request carries the query text, the number of hits ``k`` and a
    ``kind`` filter built from ``kinds``. A non-2xx response raises via
    ``raise_for_status``; a response body without ``results`` yields ``[]``.
    """
    auth_headers = {
        "Authorization": f"Bearer {DATABRICKS_TOKEN}",
        "Content-Type": "application/json",
    }
    payload = {"query": query, "k": k, "filters": {"kind": list(kinds)}}
    response = requests.post(
        f"{VS_ENDPOINT}/indexes/{VS_INDEX}/query",
        headers=auth_headers,
        json=payload,
        timeout=15,
    )
    response.raise_for_status()
    decoded = response.json()
    return decoded.get("results", [])

def summarize_context(results):
    """Render retrieval hits as a plain-text context block for the prompt.

    Each hit becomes ``[<title>]`` on one line followed by its text. The title
    is taken from ``metadata["title"]``, then ``metadata["table_name"]``, and
    falls back to ``"doc"``; the text is taken from ``text``, then ``body``,
    and falls back to an empty string. Hits are separated by a blank line.

    Args:
        results: Iterable of hit dicts; ``None`` is treated as no hits. A hit
            whose ``metadata`` is missing or ``None`` gets the default title.

    Returns:
        The joined context string (empty when there are no hits).
    """
    chunks = []
    for hit in results or []:
        # `or {}` also covers an explicit null, which .get's default does not.
        meta = hit.get("metadata") or {}
        title = meta.get("title") or meta.get("table_name") or "doc"
        body = hit.get("text") or hit.get("body") or ""
        chunks.append(f"[{title}]\n{body}")
    return "\n\n".join(chunks)

## 4) Prompting and LLM call
Build system/user prompts, call Model Serving, and parse JSON plan.

In [None]:
def call_llm(messages):
    """Send chat messages to the model serving endpoint and return the reply.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The first choice's message content when the response has one;
        otherwise the whole decoded response body, so the caller can adapt to
        a different response shape.

    Raises:
        requests.HTTPError: when the endpoint answers with an error status.
    """
    headers = {"Authorization": f"Bearer {MODEL_TOKEN}", "Content-Type": "application/json"}
    body = {"messages": messages, "max_tokens": 512, "temperature": 0.1}
    r = requests.post(MODEL_ENDPOINT_URL, headers=headers, json=body, timeout=30)
    r.raise_for_status()
    data = r.json()
    # Adapt to endpoint response shape. `or` guards cover a missing key as
    # well as an empty "choices" list or a null "message", which would
    # otherwise raise IndexError / AttributeError.
    choices = data.get("choices") or [{}]
    message = choices[0].get("message") or {}
    return message.get("content") or data

# System prompt for the SQL-planning LLM. Only {gold} and {vts} are
# substituted by str.format; the braces of the JSON example are doubled so
# that format() emits them literally instead of parsing them as fields.
SYSTEM = (
"""You are a Databricks SQL expert assistant.
- Only generate ANSI SQL SELECT statements against Unity Catalog objects in the provided schemas.
- Prefer gold KPIs table: {gold} and views like {vts}.
- Add LIMIT for row outputs; avoid CROSS JOINs unless necessary.
- Output strictly a compact JSON: {{"sql":"...","explanation":"...","chart":{{"type":"line","x":"window_start","y":["gmv"]}}}}
""".format(gold=GOLD, vts=V_TIMESERIES)
)

def build_prompt(user_question: str, context_text: str):
    """Compose the user message: the retrieved context, then the question and a JSON-only instruction."""
    sections = (
        f"Context:\n{context_text}",
        f"Question: {user_question}\nReturn JSON only.",
    )
    return "\n\n".join(sections)

import json

def plan_query(user_question: str):
    """Ask the LLM for a query plan grounded in retrieved context.

    Retrieves context for the question, prompts the model and decodes its
    reply. The reply is parsed as JSON; if that fails, the outermost
    ``{...}`` span in the text is parsed instead. When ``call_llm`` hands back
    an already-decoded dict, it is used as the plan directly.

    Args:
        user_question: The natural-language question.

    Returns:
        The plan dict, with ``chart`` and ``explanation`` defaulted when the
        model omitted them.

    Raises:
        ValueError: when the reply contains no JSON or the JSON is not an
            object.
    """
    ctx = summarize_context(vs_query(user_question))
    prompt = build_prompt(user_question, ctx)
    content = call_llm([
        {"role":"system","content": SYSTEM},
        {"role":"user","content": prompt}
    ])
    if isinstance(content, dict):
        # call_llm returns the decoded body when it finds no message content.
        plan = content
    else:
        try:
            plan = json.loads(content)
        except (TypeError, ValueError):
            # try to extract JSON block
            m = re.search(r"\{[\s\S]*\}", str(content))
            if not m:
                raise ValueError("LLM did not return JSON")
            plan = json.loads(m.group(0))
    if not isinstance(plan, dict):
        # e.g. a bare JSON list or string; setdefault below needs a dict.
        raise ValueError("LLM did not return a JSON object")
    # sanity defaults
    plan.setdefault("chart", {"type":"line","x":"window_start","y":["gmv"]})
    plan.setdefault("explanation", "")
    return plan

## 5) Guardrails and execution
Apply the whitelist, row-limit, and time-filter guardrails, run EXPLAIN before executing, and attempt one LLM-assisted repair if execution fails.

In [None]:
def _repaired_sql(content, fallback_sql: str):
    """Return the "sql" field of an LLM repair reply, or ``fallback_sql`` if it has none."""
    if isinstance(content, dict):
        fixed = content
    else:
        try:
            fixed = json.loads(content)
        except (TypeError, ValueError):
            m = re.search(r"\{[\s\S]*\}", str(content))
            if not m:
                raise
            fixed = json.loads(m.group(0))
    if not isinstance(fixed, dict):
        return fallback_sql
    return fixed.get("sql", fallback_sql)


def execute_plan(plan: dict) -> pd.DataFrame:
    """Run the plan's SQL, asking the LLM for one repair if it fails.

    The SQL must start with ``SELECT``. If ``run_sql`` raises, the error and
    the original SQL are sent back to the model together with freshly
    retrieved context; the SQL from the reply is validated the same way and
    executed exactly once. A failure of the repaired query propagates.

    Args:
        plan: Plan dict as produced by ``plan_query``; only ``sql`` and
            ``explanation`` are read.

    Returns:
        The query result as a DataFrame.

    Raises:
        ValueError: when the original or repaired SQL is missing or is not a
            SELECT, or when the repair reply contains no JSON.
        Exception: whatever ``run_sql`` raises for the repaired query.
    """
    sql_text = (plan.get("sql") or "").strip()
    if not sql_text.lower().startswith("select"):
        raise ValueError("Generated SQL must be SELECT")
    try:
        return run_sql(sql_text)
    except Exception as e:
        err = str(e)
        # one-shot repair with error info
        ctx = summarize_context(vs_query(plan.get("explanation", "") or ""))
        repair_prompt = f"Context:\n{ctx}\n\nSQL had error: {err}\nOriginal SQL: {sql_text}\nFix and return JSON only with updated sql."
        content = call_llm([
            {"role":"system","content": SYSTEM},
            {"role":"user","content": repair_prompt}
        ])
        # Parse first, execute afterwards: keeping run_sql out of the parsing
        # try/except means an execution failure is not mistaken for a parse
        # failure and the repaired query is never run a second time.
        fixed_sql = (_repaired_sql(content, sql_text) or "").strip()
        if not fixed_sql.lower().startswith("select"):
            raise ValueError("Generated SQL must be SELECT")
        return run_sql(fixed_sql)

## 6) Interaction logging (Delta)
Log the question, plan JSON, SQL, row count, latency, and any error; scrub email-like strings.

In [None]:
# Fully qualified name of the Delta table that receives one row per logged interaction.
LOG_TABLE = f"{DB}.assistant_logs"

# DDL that log_interaction executes before every insert; IF NOT EXISTS makes
# it a no-op once the table is there. Column order matters: the INSERT in
# log_interaction supplies values positionally (event_time, user_question,
# plan_json, sql_text, row_count, latency_ms, error).
CREATE_LOG_SQL = f"""
CREATE TABLE IF NOT EXISTS {LOG_TABLE} (
  event_time TIMESTAMP,
  user_question STRING,
  plan_json STRING,
  sql_text STRING,
  row_count INT,
  latency_ms BIGINT,
  error STRING
) USING DELTA
"""

# Pattern for email-like substrings: local part, "@", domain, and a TLD of two or more letters.
EMAIL_RE = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")

def scrub_pii(text: str) -> str:
    """Replace every email-like substring of ``text`` with ``<email>``.

    Falsy input (``None`` or an empty string) is handed back untouched.
    """
    return EMAIL_RE.sub("<email>", text) if text else text

def log_interaction(question: str, plan: dict, sql_text: str, row_count: int, latency_ms: int, error: str = None):
    """Append one interaction record to the assistant log table.

    Creates the table if it does not exist, then inserts a row stamped with
    ``current_timestamp()``. Email-like strings are scrubbed from every text
    field, including the serialized plan (which embeds the SQL and the
    explanation). The plan JSON and the SQL are truncated to 900000
    characters.

    Args:
        question: The user's question.
        plan: The plan dict; stored as JSON text.
        sql_text: The SQL that was run; ``None`` or ``""`` is stored as given.
        row_count: Number of rows returned.
        latency_ms: End-to-end latency in milliseconds.
        error: Error message for a failed interaction, if any.
    """
    plan_json = scrub_pii(json.dumps(plan))[:900000]
    safe_sql = scrub_pii(sql_text)[:900000] if sql_text else sql_text
    with open_conn() as conn:
        # Context-managed so the cursor is closed even when a statement fails.
        with conn.cursor() as cur:
            cur.execute("USE CATALOG " + CATALOG)
            cur.execute("USE SCHEMA " + SCHEMA)
            cur.execute(CREATE_LOG_SQL)
            cur.execute(
                f"INSERT INTO {LOG_TABLE} VALUES (current_timestamp(), ?, ?, ?, ?, ?, ?)",
                (
                    scrub_pii(question),
                    plan_json,
                    safe_sql,
                    int(row_count),
                    int(latency_ms),
                    scrub_pii(error) if error else None,
                ),
            )

## 7) Notebook Chat UI
Enter a question, run the assistant, render results and chart.

In [None]:
# Ask a question, then plan → execute → plot → log
question = "Why did conversion drop in the last 10 minutes?"
# NOTE(review): minutes_override is not read anywhere in the visible code —
# confirm whether it was meant to be passed to the time-filter guardrail.
minutes_override = 10

print("Question:", question)

# Defaults let the logging step run even when planning or execution fails,
# so failed interactions end up in the log table with their error message.
plan, df, error, latency_ms = {}, None, None, None
t0 = time.time()
try:
    plan = plan_query(question)
    df = execute_plan(plan)
    latency_ms = int((time.time() - t0) * 1000)
    print("Latency:", latency_ms, "ms")

    # Basic rendering
    if not df.empty:
        print("Rows:", len(df))
        # Choose plotting backend
        if set(["window_start","gmv"]).issubset(df.columns):
            fig = px.line(df, x="window_start", y=[c for c in ["gmv","orders","active_users","conversion_rate"] if c in df.columns])
            fig.update_layout(height=400, width=900)
            fig.show()
        else:
            display_cols = df.columns[:6]
            print(df[display_cols].head(10))
    else:
        print("No results.")
except Exception as e:
    error = str(e)
    print("Assistant error:", e)
    if latency_ms is None:
        # Failure happened before the query finished; record time-to-failure.
        latency_ms = int((time.time() - t0) * 1000)

# Log (successes and failures alike; a logging problem is only reported)
try:
    sql_text = plan.get("sql", "")
    log_interaction(question, plan, sql_text, 0 if df is None else int(len(df)), latency_ms, error)
except Exception as e:
    print("Log error:", e)

## 8) Example Questions & Testing
Sample queries for demo.

In [None]:
# Sample prompts for demos; only the first three are exercised below.
example_questions = [
    "What was GMV over the last 15 minutes?",
    "Show orders per minute for the last hour",
    "Which time window had the highest conversion rate today?",
    "Break down active users by 5-minute windows in the last 30 minutes",
    "Show me 20 sample purchases from the last 10 minutes",
    "Compare GMV vs conversion rate trends over the last 2 hours",
]

# Quick test multiple questions: plan only (no execution) and show a preview
# of the generated SQL and explanation.
for i, q in enumerate(example_questions[:3]):
    header = f"\n[{i+1}] {q}"
    print(header)
    try:
        plan = plan_query(q)
        sql_preview = plan.get("sql", "")[:120]
        print("SQL:", sql_preview, "...")
        note_preview = plan.get("explanation", "")[:80]
        print("Explanation:", note_preview, "...")
    except Exception as e:
        print("Error:", e)

## 9) Logs & Quality Metrics
View recent interactions and performance stats.

In [None]:
# Check recent logs
try:
    logs_df = run_sql(f"""
        SELECT event_time, user_question, sql_text, row_count, latency_ms, error
        FROM {LOG_TABLE}
        ORDER BY event_time DESC
        LIMIT 20
    """)
    if not logs_df.empty:
        print("Recent interactions:")
        print(logs_df[['event_time', 'user_question', 'row_count', 'latency_ms', 'error']].head(10))
        
        # Quick metrics
        print("\nMetrics:")
        print(f"Total interactions: {len(logs_df)}")
        print(f"Avg latency: {logs_df['latency_ms'].mean():.0f}ms")
        print(f"P95 latency: {logs_df['latency_ms'].quantile(0.95):.0f}ms")
        print(f"Error rate: {(logs_df['error'].notna().sum() / len(logs_df) * 100):.1f}%")
    else:
        print("No logs found. Run some questions first.")
except Exception as e:
    print("Log table not ready:", e)