In [2]:
from langchain_core.messages import HumanMessage, SystemMessage
from linkedin.langchain.chat_models.proxied_gpt_chat import ProxiedGPTChat

In [3]:
%reload_ext linkedin.lisql
%config SqlMagic.autocommit=False
%manage_trino holdem
import datetime

In [4]:
!pip install langchain-community

In [5]:
from typing import TypedDict, List, Dict
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import StateGraph, START, END
from linkedin.langchain.chat_models.proxied_gpt_chat import ProxiedGPTChat
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain.memory import ConversationBufferMemory
#from langchain_community.stores.message.in_memory import InMemoryChatMessageHistory
from IPython.display import Markdown, display
from langchain.tools import tool
from IPython import get_ipython
import warnings
warnings.filterwarnings('ignore')


```
[User]
   |
   v
[Query Handler / Orchestrator]
   |
   +--> Intent Detection
   |
   +--> Schema Fetcher (DESCRIBE table)
   |
   +--> SQL Generator (LLM)
   |
   +--> SQL Executor (Trino)
   |
   +--> Result Interpreter (LLM)
   |
   +--> Metrics Extractor
   |
   +--> Predictive Agent 
   |
   +--> Memory System (SQLite)
   |
   v
[Final Answer to User]
```

```mermaid
flowchart LR
    %% =========================
    %% USER + EXECUTION LAYER
    %% =========================
    subgraph L1["Execution Layer - Darwin Notebook"]
      U["User / Analyst"] --> UI["Python Loop / CLI Input\nwhile True: user_query"]
      UI --> ORCH["LangGraph Orchestrator\n(StateGraph -> compiled.invoke)"]
    end

    %% =========================
    %% ORCHESTRATION + AGENTS
    %% =========================
    subgraph L2["Agent Layer - LangGraph Nodes"]
      ORCH --> ROUTER{"Router / Intent Detector\nneeds_sql? predictive?"}

      ROUTER -->|Follow-up or Memory| FOLLOW["Follow-up Resolver\n(check recent_turns + metrics cache)"]
      ROUTER -->|Needs SQL| SQLGEN["SQL Generator Agent\n(LLM -> SELECT + LIMIT)"]
      ROUTER -->|Recommendations| PRED["Predictive Reasoning Agent\n(LLM -> next-test suggestions)"]

      FOLLOW --> QA["QA / Insights Agent\n(LLM -> explanation + summary)"]

      SQLGEN --> SAFE{"SQL Safety Gate\nSELECT-only enforcement"}
      SAFE -->|Blocked| ERR["Error Response\n(unsafe SQL or invalid query)"]
      SAFE -->|Allowed| EXEC["SQL Executor Tool Node\nrun_trino_sql()"]

      EXEC --> METRICS["Metric Extraction Agent\n(LLM -> JSON metrics)"]
      METRICS --> QA
      QA --> OUT["Final Answer Returned"]
      ERR --> OUT
      PRED --> OUT
    end

    %% =========================
    %% MODEL LAYER
    %% =========================
    subgraph L3["Model Layer"]
      LLM["LinkedIn ProxiedGPTChat\n(GPT-4.1 deployment)"]
    end

    SQLGEN --> LLM
    QA --> LLM
    METRICS --> LLM
    PRED --> LLM
    FOLLOW --> LLM

    %% =========================
    %% DATA ACCESS LAYER
    %% =========================
    subgraph L4["Data Access Layer"]
      TRINO["Trino / Holdem\n(%manage_trino holdem)"]
      DB["Analytics Tables\n(u_mktgreporting.*)"]
      TRINO --> DB
    end

    EXEC --> TRINO

    %% =========================
    %% MEMORY LAYER
    %% =========================
    subgraph L5["Memory Layer"]
      STM["Short-Term Memory\nIn-state: recent_turns (last 3 turns = 6 messages)\n+ older_summary"]
      LTM["Long-Term Memory\nSQLite store keyword + recency\nUser-scoped memories"]
      CKPT["LangGraph Checkpointer\nSQLite checkpoints (planned: graph.compile() has no checkpointer yet)"]
    end

    ROUTER <-->|reads| STM
    FOLLOW <-->|reads| STM
    QA <-->|reads| STM
    PRED <-->|reads| STM

    ROUTER <-->|retrieves notes| LTM
    OUT -->|write compact QA| LTM
    ORCH -->|persist state| CKPT
    CKPT --> ORCH

    %% =========================
    %% OUTPUT
    %% =========================
    OUT --> UI --> U
```

In [None]:
%reload_ext linkedin.lisql
%config SqlMagic.autocommit=True
%manage_trino holdem

import base64
import getpass
import hashlib
import hmac
import json
import os
import re
import sqlite3
import time
from contextlib import closing
from typing import TypedDict, Dict, List, Optional, Tuple

from IPython import get_ipython
from IPython.display import Markdown, display
from langchain.tools import tool
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import StateGraph, START, END
from linkedin.langchain.chat_models.proxied_gpt_chat import ProxiedGPTChat

# -----------------------------
# 1) SQL Tool (rollback + retry)
# -----------------------------
# Maximum number of result rows rendered back to the caller (keeps LLM prompts small).
_MAX_RESULT_ROWS = 200


def _sql_result_to_text(result) -> str:
    """Render a ``%sql`` magic result as a Markdown table of at most ``_MAX_RESULT_ROWS`` rows.

    Falls back to ``str(result)`` when the result has no usable ``DataFrame()`` (e.g. a
    status-only result) or when ``to_markdown`` fails (it needs the optional ``tabulate``
    package).
    """
    try:
        return result.DataFrame().head(_MAX_RESULT_ROWS).to_markdown(index=False)
    except Exception:
        return str(result)


@tool
def run_trino_sql(query: str) -> str:
    """
    Execute a SQL query via Darwin %sql magic and return up to 200 rows as a Markdown table.
    On failure returns a string starting with 'SQL_ERROR:'.
    Auto-ROLLBACK + retry once if Trino transaction gets stuck.
    """
    ip = get_ipython()
    if ip is None:
        return "SQL_ERROR: not running inside IPython, so the %sql magic is unavailable"

    try:
        return _sql_result_to_text(ip.run_line_magic("sql", query))
    except Exception as first_error:
        err_lower = str(first_error).lower()
        # Only a stuck/aborted transaction is worth retrying; anything else is reported as-is.
        if "invalid transaction" not in err_lower and "rolled back" not in err_lower:
            return f"SQL_ERROR: {first_error}"

    # The session's transaction is stuck: roll it back and retry the query exactly once.
    try:
        ip.run_line_magic("sql", "ROLLBACK")
        return _sql_result_to_text(ip.run_line_magic("sql", query))
    except Exception as retry_error:
        return f"SQL_ERROR: {retry_error}"

# -----------------------------
# 2) LLM
# -----------------------------
# Shared chat model used by every agent step.
# The deployment identifiers can be overridden via environment variables so the
# notebook does not need editing per org/deployment; the defaults are placeholders.
chat = ProxiedGPTChat(
    resource_id=os.environ.get("PROXIED_GPT_RESOURCE_ID", "yourorg-resource-id"),
    deployment_id=os.environ.get("PROXIED_GPT_DEPLOYMENT_ID", "yourorg-deployment-id"),
    temperature=0.0,  # minimise sampling randomness: SQL and JSON outputs are parsed downstream
    max_tokens=2500,
)

# -----------------------------
# 3) Prompts
# -----------------------------
# System prompt for SQL generation. The generated text is checked for a leading
# SELECT and then sent to Trino verbatim, so the model must not wrap it in Markdown
# fences or append a statement-terminating semicolon.
SQL_GEN_SYSTEM_PROMPT = """
You are a LinkedIn Trino SQL expert.

Rules:
- ONLY write SELECT statements.
- You will be given the exact table schema (columns). Use only those columns.
- If the user asks for "insights" or "show me this table", return:
  SELECT * FROM <table> LIMIT 50
- Always include LIMIT <= 200 unless aggregation is requested.
- Write a single statement with no trailing semicolon.
Return ONLY SQL: no explanation and no Markdown code fences.
"""

# System prompt for explaining SQL results (or SQL errors) to the user.
QA_SYSTEM_PROMPT = """
You are a LinkedIn analytics explanation expert.
Explain results clearly, highlight patterns and useful takeaways.
If the SQL errored, explain why and what to do next.
"""

# System prompt for metric extraction. The reply is parsed with json.loads, so it must
# be a bare JSON object (no prose, no fences) whose values are numbers.
METRIC_EXTRACTION_PROMPT = """
Extract key numeric values from the SQL results and return JSON only.
Return a single flat JSON object mapping metric names to numbers, with no
explanation and no Markdown code fences. If there are no numeric values, return {}.
Example: {"total_rows": 50, "pct_change": -15.2}
"""

# System prompt for the recommendation / next-test agent.
PREDICTIVE_SYSTEM_PROMPT = """
You are a senior experimentation strategist.
Use metrics + results context to recommend next tests, guardrails, or next actions.
Stay practical and grounded in the observed table output.
"""

# -----------------------------
# 4) Long-term memory (SQLite keyword + recency)
# -----------------------------
def _pbkdf2_hash_password(password: str, salt: bytes, rounds: int = 200_000) -> bytes:
    return hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, rounds)

def _tokenize(text: str) -> List[str]:
    text = (text or "").lower()
    return re.findall(r"[a-z0-9_]+", text)

def _keyword_score(query: str, doc: str) -> float:
    """Return the fraction of distinct query tokens that also appear in ``doc``.

    The result is in [0.0, 1.0]; a query with no tokens scores 0.0.
    """
    query_terms = set(_tokenize(query))
    doc_terms = set(_tokenize(doc))
    if len(query_terms) == 0:
        return 0.0
    shared = query_terms & doc_terms
    return len(shared) / len(query_terms)

class LongTermMemorySQLite:
    """User-scoped long-term memory stored in a local SQLite file.

    Holds a ``users`` table (PBKDF2-hashed passwords) and a ``memories`` table of free
    text. Retrieval ranks a user's most recent memories by keyword overlap with the
    query plus a small recency bonus; no embeddings are used.

    Every method opens its own connection and always closes it, so no database
    handle stays open between calls.
    """

    def __init__(self, db_path: str):
        """Create the store and ensure its schema exists.

        Args:
            db_path: Path to the SQLite database file (created by sqlite3 if missing).
        """
        self.db_path = db_path
        self._init_db()

    def _conn(self) -> sqlite3.Connection:
        """Open a new connection to ``db_path``. The caller is responsible for closing it."""
        return sqlite3.connect(self.db_path)

    def _init_db(self):
        """Create the ``users`` and ``memories`` tables and the lookup index if absent."""
        # `with sqlite3.connect(...)` only scopes a transaction and never closes the
        # connection, so wrap it in `closing` to release the file handle.
        with closing(self._conn()) as con:
            con.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT UNIQUE NOT NULL,
                    salt_b64 TEXT NOT NULL,
                    pw_hash_b64 TEXT NOT NULL,
                    created_ts REAL NOT NULL
                )
            """)
            # NOTE: SQLite only enforces FOREIGN KEY constraints when
            # `PRAGMA foreign_keys=ON` is set, which this class never does.
            con.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    created_ts REAL NOT NULL,
                    text TEXT NOT NULL,
                    FOREIGN KEY(user_id) REFERENCES users(id)
                )
            """)
            con.execute("CREATE INDEX IF NOT EXISTS idx_memories_user_ts ON memories(user_id, created_ts)")
            con.commit()

    def create_user(self, username: str, password: str) -> int:
        """Insert a new user with a freshly salted PBKDF2 password hash.

        Returns:
            The new user's id.

        Raises:
            sqlite3.IntegrityError: if ``username`` already exists (UNIQUE constraint).
        """
        salt = os.urandom(16)
        pw_hash = _pbkdf2_hash_password(password, salt)
        with closing(self._conn()) as con:
            cur = con.execute(
                "INSERT INTO users(username, salt_b64, pw_hash_b64, created_ts) VALUES (?, ?, ?, ?)",
                (username, base64.b64encode(salt).decode(), base64.b64encode(pw_hash).decode(), time.time())
            )
            con.commit()
            return int(cur.lastrowid)

    def authenticate(self, username: str, password: str) -> Optional[int]:
        """Return the user's id if ``password`` matches, otherwise ``None`` (also for unknown users)."""
        with closing(self._conn()) as con:
            row = con.execute(
                "SELECT id, salt_b64, pw_hash_b64 FROM users WHERE username=?", (username,)
            ).fetchone()
        if not row:
            return None
        uid, salt_b64, pw_hash_b64 = row
        expected = base64.b64decode(pw_hash_b64)
        got = _pbkdf2_hash_password(password, base64.b64decode(salt_b64))
        # Constant-time comparison so response timing does not reveal how many bytes matched.
        return int(uid) if hmac.compare_digest(got, expected) else None

    def create_or_auth(self, username: str, password: str) -> int:
        """Log in as ``username``, creating the account if it does not exist yet.

        Raises:
            sqlite3.IntegrityError: if ``username`` exists but ``password`` is wrong.
        """
        uid = self.authenticate(username, password)
        if uid is not None:
            return uid
        try:
            return self.create_user(username, password)
        except sqlite3.IntegrityError as exc:
            # Authentication failed yet the username is taken: the password was wrong.
            raise sqlite3.IntegrityError(
                f"User {username!r} already exists and the password did not match"
            ) from exc

    def add_memory(self, user_id: int, text: str) -> None:
        """Store ``text`` as a memory for ``user_id``; empty/whitespace/None text is ignored."""
        if not text or not text.strip():
            return
        with closing(self._conn()) as con:
            con.execute(
                "INSERT INTO memories(user_id, created_ts, text) VALUES (?, ?, ?)",
                (user_id, time.time(), text),
            )
            con.commit()

    def search(
        self,
        user_id: int,
        query: str,
        k: int = 5,
        candidate_limit: int = 300,
        recency_weight: float = 0.08,
        recency_scale_days: float = 30.0,
    ) -> List[str]:
        """Return up to ``k`` of the user's memories ranked by relevance to ``query``.

        Only the ``candidate_limit`` most recent memories are considered. Each is scored
        as keyword overlap (fraction of query tokens present) plus a recency bonus of
        ``recency_weight / (1 + age_days / recency_scale_days)``, so the bonus halves
        after ``recency_scale_days`` days.

        Args:
            user_id: Owner of the memories.
            query: Free-text query.
            k: Maximum number of memories to return; ``k <= 0`` returns ``[]``.
            candidate_limit: Number of most recent memories to score.
            recency_weight: Maximum recency bonus (for a brand-new memory).
            recency_scale_days: Age in days at which the recency bonus is halved.
        """
        if k <= 0:
            return []
        with closing(self._conn()) as con:
            rows = con.execute(
                "SELECT created_ts, text FROM memories WHERE user_id=? ORDER BY created_ts DESC LIMIT ?",
                (user_id, candidate_limit),
            ).fetchall()

        now = time.time()
        scored = []
        for ts, text in rows:
            # Clamp at zero so clock skew cannot give a memory a bonus above recency_weight.
            age_days = max((now - float(ts)) / 86400, 0.0)
            recency = recency_weight * (1.0 / (1.0 + age_days / recency_scale_days))
            scored.append((_keyword_score(query, text) + recency, text))
        scored.sort(reverse=True, key=lambda item: item[0])
        return [text for _, text in scored[:k]]

# One shared long-term memory store for the notebook. The relative path resolves
# against the kernel's working directory; sqlite3 creates the file on first connect.
LTM_DB_PATH = "long_term_memory_noembed.sqlite"
lt_store = LongTermMemorySQLite(db_path=LTM_DB_PATH)

# -----------------------------
# 5) Helpers
# -----------------------------
def is_predictive_question(q: str) -> bool:
    """Heuristically detect recommendation / "what next" questions.

    Returns True when any trigger phrase occurs (case-insensitively) as a substring
    of ``q``; ``None`` or empty input returns False.
    """
    lowered = (q or "").lower()
    for phrase in ("suggest", "recommend", "next test", "predict", "improve", "strategy", "what should we do next"):
        if phrase in lowered:
            return True
    return False

# A dotted reference of 2 or 3 identifier segments (schema.table or catalog.schema.table)
# that does not start in the middle of another token.
_TABLE_REF_RE = re.compile(r"(?<![\w.])([A-Za-z0-9_]+(?:\.[A-Za-z0-9_]+){1,2})")


def extract_table_name(user_q: str) -> Optional[str]:
    """Return the first ``schema.table`` (or ``catalog.schema.table``) mentioned in ``user_q``.

    Skips look-alikes that the old single pattern picked up first:
    numbers/versions such as ``15.2`` or ``v1.2`` (a purely numeric segment) and
    abbreviations such as ``e.g`` / ``i.e`` (every segment one character long).
    The result only ever contains letters, digits, underscores and dots.

    Returns:
        The table reference, or ``None`` if the text contains none.
    """
    for match in _TABLE_REF_RE.finditer(user_q or ""):
        candidate = match.group(1)
        segments = candidate.split(".")
        if any(seg.isdigit() for seg in segments):
            continue
        if all(len(seg) == 1 for seg in segments):
            continue
        return candidate
    return None

# -----------------------------
# 6) State
# -----------------------------
class AppState(TypedDict, total=False):
    """State schema for the LangGraph graph whose single node is ``handle_query``.

    ``total=False`` makes every key optional, which is why the nodes read keys with
    ``state.get(...)``.
    """
    user_id: int  # id from lt_store.create_or_auth; scopes long-term memory reads/writes
    user_query: str  # the raw question typed by the user for this turn
    memory: Dict[str, object]  # short-term memory: "older_summary" (str) and "recent_turns" (list of {"role", "content"} dicts)
    metrics: Dict[str, float]  # numeric values extracted from SQL results, accumulated across turns
    answer: str  # final response rendered to the user

# -----------------------------
# 7) Predictive Agent
# -----------------------------
def run_predictive_agent(state: AppState) -> AppState:
    """Answer a recommendation / "what should we test next?" question.

    Builds a prompt from the user's long-term memories, the short-term summary and
    recent turns, and the metrics gathered so far. It then records the exchange in
    short-term memory (last 6 messages, i.e. 3 user/assistant turns) and long-term
    memory.

    Args:
        state: Graph state; reads ``user_id``, ``user_query``, ``memory`` and ``metrics``.
            Any of them may be missing or ``None``.

    Returns:
        Partial state with ``answer`` plus new ``memory`` and ``metrics`` dicts. The
        incoming dicts are never mutated.
    """
    user_id = state.get("user_id")
    user_q = state.get("user_query", "")
    # Copy rather than mutate: the caller (the notebook loop) keeps references to these
    # objects, so in-place edits would leak even if this turn fails later.
    memory = dict(state.get("memory") or {})
    memory.setdefault("older_summary", "")
    recent_turns = list(memory.get("recent_turns") or [])
    metrics = dict(state.get("metrics") or {})

    lt_hits = lt_store.search(user_id, user_q) if user_id is not None else []

    prompt = [
        SystemMessage(content=PREDICTIVE_SYSTEM_PROMPT),
        HumanMessage(content=f"""
Long-term memories:
{lt_hits}

Short summary:
{memory.get('older_summary', '')}

Recent turns:
{recent_turns}

Metrics:
{metrics}

User question:
{user_q}
        """)
    ]
    resp = chat.invoke(prompt).content.strip()

    recent_turns.append({"role": "user", "content": user_q})
    recent_turns.append({"role": "assistant", "content": resp})
    memory["recent_turns"] = recent_turns[-6:]  # keep the last 3 user/assistant turns

    if user_id is not None:
        lt_store.add_memory(user_id, f"Q: {user_q}\nA: {resp}")

    return {"answer": resp, "memory": memory, "metrics": metrics}

# -----------------------------
# 8) Main Handler (FIXED: schema-aware SQL generation)
# -----------------------------
def _strip_code_fences(text: str) -> str:
    """Remove one surrounding Markdown code fence (```, ```sql, ```json, ...) from LLM output."""
    stripped = (text or "").strip()
    match = re.match(r"^```[\w-]*[ \t]*\r?\n(.*?)```$", stripped, re.DOTALL)
    return match.group(1).strip() if match else stripped


def _prepare_select(raw_sql: str) -> Optional[str]:
    """Turn LLM-generated SQL into one executable SELECT, or return None if it is not one.

    Strips Markdown fences and trailing semicolons (Trino does not accept a statement
    terminator in a query). Rejects anything that does not start with SELECT or still
    contains ';' (which would allow a second statement).
    """
    sql = _strip_code_fences(raw_sql).rstrip().rstrip(";").strip()
    if not sql.upper().startswith("SELECT") or ";" in sql:
        return None
    return sql


def _parse_metrics(raw: str) -> Dict[str, float]:
    """Parse the metric-extraction LLM reply into ``{name: float}``.

    Returns ``{}`` for unparseable or non-object JSON. Values that cannot be converted
    with ``float()`` are skipped.
    """
    try:
        parsed = json.loads(_strip_code_fences(raw))
    except (json.JSONDecodeError, TypeError):
        return {}
    if not isinstance(parsed, dict):
        return {}
    numeric: Dict[str, float] = {}
    for key, value in parsed.items():
        try:
            numeric[str(key)] = float(value)
        except (TypeError, ValueError, OverflowError):
            continue  # strings/lists/objects that are not numbers are ignored
    return numeric


def handle_query(state: AppState) -> AppState:
    """LangGraph node: answer one user question, running schema-aware SQL when a table is named.

    Flow:
      1. Recommendation-style questions are delegated to ``run_predictive_agent``.
      2. If the question names a ``schema.table``, the node:
         - fetches the table's columns with DESCRIBE;
         - lets the LLM write a SELECT;
         - runs it only if it passes the SELECT-only gate;
         - extracts numeric metrics from successful results;
         - has the LLM explain the outcome.
      3. Otherwise the LLM answers from memory context alone.
    The exchange is appended to short-term memory (last 6 messages) and long-term memory.

    Args:
        state: Graph state; reads ``user_id``, ``user_query``, ``memory`` and ``metrics``.
            Any of them may be missing or ``None``.

    Returns:
        Partial state with ``answer`` plus new ``memory`` and ``metrics`` dicts. The
        incoming dicts are never mutated.
    """
    user_id = state.get("user_id")
    user_q = state.get("user_query", "") or ""
    # Work on copies: the caller keeps references to the incoming dicts between turns.
    memory = dict(state.get("memory") or {})
    memory.setdefault("older_summary", "")
    recent_turns = list(memory.get("recent_turns") or [])
    metrics = dict(state.get("metrics") or {})

    if is_predictive_question(user_q):
        return run_predictive_agent(state)

    table = extract_table_name(user_q)

    # Long-term and summarized short-term context shared by every LLM call below.
    lt_hits = lt_store.search(user_id, user_q) if user_id is not None else []
    context_text = ""
    if lt_hits:
        context_text += "Long-term context:\n" + "\n".join(f"- {x}" for x in lt_hits) + "\n\n"
    if memory.get("older_summary"):
        context_text += f"Short history summary:\n{memory['older_summary']}\n\n"

    if table:
        # Fetch the real column list first so the generated SQL only uses existing columns.
        # `table` comes from extract_table_name, so it contains only [A-Za-z0-9_.].
        # Tools are called via .invoke; calling a @tool object directly is deprecated.
        schema_md = run_trino_sql.invoke({"query": f"DESCRIBE {table}"})

        sql_msgs = [
            SystemMessage(content=SQL_GEN_SYSTEM_PROMPT),
            HumanMessage(content=f"""
{context_text}

User question:
{user_q}

Target table:
{table}

Table schema:
{schema_md}
            """)
        ]
        raw_sql = chat.invoke(sql_msgs).content.strip()
        executable_sql = _prepare_select(raw_sql)
        if executable_sql is None:
            sql_query = raw_sql
            sql_results = f"SQL_ERROR: Unsafe SQL: {raw_sql}"
        else:
            sql_query = executable_sql
            sql_results = run_trino_sql.invoke({"query": executable_sql})

        # Only successful results carry metrics; extracting from an error message would
        # waste an LLM call and could record junk numbers (line numbers, codes, ...).
        if not str(sql_results).startswith("SQL_ERROR"):
            metric_msgs = [
                SystemMessage(content=METRIC_EXTRACTION_PROMPT),
                HumanMessage(content=f"SQL Results:\n{sql_results}")
            ]
            metrics.update(_parse_metrics(chat.invoke(metric_msgs).content))

        interpret_msgs = [
            SystemMessage(content=QA_SYSTEM_PROMPT),
            HumanMessage(content=f"""
{context_text}

User question:
{user_q}

SQL query:
{sql_query}

SQL results:
{sql_results}

Extracted metrics:
{metrics}
            """)
        ]
        final_answer = chat.invoke(interpret_msgs).content.strip()
    else:
        final_answer = chat.invoke([
            SystemMessage(content=QA_SYSTEM_PROMPT),
            HumanMessage(content=f"{context_text}\nUser question: {user_q}")
        ]).content.strip()

    # Short-term memory keeps the last 3 user/assistant turns (6 messages).
    recent_turns.append({"role": "user", "content": user_q})
    recent_turns.append({"role": "assistant", "content": final_answer})
    memory["recent_turns"] = recent_turns[-6:]

    if user_id is not None:
        lt_store.add_memory(user_id, f"Q: {user_q}\nA: {final_answer}")

    return {"answer": final_answer, "memory": memory, "metrics": metrics}

# -----------------------------
# 9) Build graph
# -----------------------------
# Single-node graph: START -> handle_query -> END. Routing (SQL vs. predictive vs.
# plain answer) happens inside the node rather than as graph edges.
_ENTRY_NODE = "handle_query"

graph = StateGraph(AppState)
graph.add_node(_ENTRY_NODE, handle_query)
graph.add_edge(START, _ENTRY_NODE)
graph.add_edge(_ENTRY_NODE, END)

compiled = graph.compile()

# -----------------------------
# 10) Interactive loop
# -----------------------------
# Credentials for the local long-term-memory store. They are read from the environment
# (or prompted for) rather than hardcoded in the notebook.
# NOTE: if the username already exists, the password must match the one it was created
# with; otherwise create_or_auth raises.
USER = os.environ.get("LTM_USER", "demo_user")
PASS = os.environ.get("LTM_PASSWORD") or getpass.getpass(f"Password for {USER}: ")
user_id = lt_store.create_or_auth(USER, PASS)

# Short-term memory and accumulated metrics, carried between turns of this session.
state = {"memory": {"older_summary": "", "recent_turns": []}, "metrics": {}}

EXIT_COMMANDS = {"quit", "exit", "bye", "q"}

while True:
    try:
        user_input = input("\nAsk your query (exit/quit/bye/q to end): ").strip()
    except (EOFError, KeyboardInterrupt):
        user_input = "quit"  # treat end-of-input / interrupt at the prompt as a normal exit
    if user_input.lower() in EXIT_COMMANDS:
        print("\nSession ended.\n" + "=" * 60)
        break
    if not user_input:
        continue  # don't send blank questions to the LLM

    try:
        out = compiled.invoke({
            "user_id": user_id,
            "user_query": user_input,
            "memory": state["memory"],
            "metrics": state["metrics"]
        })
    except Exception as exc:
        # One failed turn (LLM/proxy/Trino error) should not end the session or lose memory.
        print(f"Error while answering: {exc}")
        continue

    state["memory"] = out.get("memory", state["memory"])
    state["metrics"] = out.get("metrics", state["metrics"])
    display(Markdown(out.get("answer", "")))
