In [63]:
# import pyodbc
# import pandas as pd

# conn = pyodbc.connect(
#     "DRIVER={ODBC Driver 18 for SQL Server};"
#     "SERVER=YUEFANG;"
#     "DATABASE=adaptive rag;"
#     "UID=rag_user;"
#     "PWD=Haha100!;"
#     "Encrypt=no;"
#     "TrustServerCertificate=yes;"
# )

# cursor = conn.cursor()
# cursor.execute("SELECT @@VERSION")
# print(cursor.fetchone())

In [64]:
# query = "select * from [adaptive rag]..healthcare_dataset"

# cursor.execute(query)

In [65]:
# healthcare = pd.read_sql_query(query, conn)

In [66]:
# healthcare

In [67]:
import os
import re
from typing import List

import pyodbc
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_ollama import ChatOllama

In [68]:
# =========================
# CONFIG
# =========================

# Name of the Ollama model used to generate SQL.
OLLAMA_MODEL = "llama3:8b"

# ODBC connection string for the SQL Server database.
# Set the RAG_SQL_CONN_STR environment variable to supply the connection
# string (and its credentials) without editing this notebook.
# NOTE(review): the fallback below still embeds a plaintext password so that
# existing setups keep working; remove it (and rotate the password) once the
# environment variable is configured everywhere.
SQL_CONN_STR = os.environ.get("RAG_SQL_CONN_STR") or (
    "DRIVER={ODBC Driver 18 for SQL Server};"
    "SERVER=YUEFANG;"
    "DATABASE=adaptive rag;"
    "UID=rag_user;"
    "PWD=Haha100!;"
    "Encrypt=no;"
    "TrustServerCertificate=yes;"
)

# Directory where the Chroma schema index is persisted.
CHROMA_DIR = "./schema_store"

# Number of schema documents retrieved per question.
TOP_K_SCHEMA = 5
# Maximum number of result rows fetched from a generated query.
MAX_ROWS = 100

In [69]:
# =========================
# 1. LOAD SCHEMA FROM SQL SERVER
# =========================

def load_schema_from_db() -> List[Document]:
    """Read table/column metadata from SQL Server and wrap it in Documents.

    Queries INFORMATION_SCHEMA.COLUMNS using SQL_CONN_STR and groups the
    columns by "<schema>.<table>".

    Returns:
        One Document per table whose page_content has the form::

            Table: <schema>.<table>
            Columns:
            - <column> (<data type>)
            ...

    Raises:
        pyodbc.Error: if connecting or querying fails (propagated from pyodbc).
    """
    conn = pyodbc.connect(SQL_CONN_STR)
    try:
        cursor = conn.cursor()
        # ORDINAL_POSITION keeps the columns of each table in their declared
        # order; without it the order inside a table is unspecified.
        cursor.execute("""
            SELECT
                TABLE_SCHEMA,
                TABLE_NAME,
                COLUMN_NAME,
                DATA_TYPE
            FROM INFORMATION_SCHEMA.COLUMNS
            ORDER BY TABLE_SCHEMA, TABLE_NAME, ORDINAL_POSITION
        """)
        rows = cursor.fetchall()
    finally:
        # Close explicitly so the connection is released even when the query
        # raises.
        conn.close()

    tables = {}
    for schema, table, col, dtype in rows:
        tables.setdefault(f"{schema}.{table}", []).append(f"- {col} ({dtype})")

    return [
        Document(page_content=f"Table: {table}\nColumns:\n" + "\n".join(cols))
        for table, cols in tables.items()
    ]

In [70]:
# =========================
# 2. BUILD / LOAD VECTOR STORE
# =========================

def build_schema_store():
    """Embed the database schema and store it in a Chroma index.

    Loads the schema documents with load_schema_from_db(), embeds them with
    the sentence-transformers/all-MiniLM-L6-v2 model and writes them to a
    Chroma store rooted at CHROMA_DIR.

    Returns:
        The populated Chroma vector store.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    docs = load_schema_from_db()

    vectordb = Chroma.from_documents(
        docs,
        embedding=embeddings,
        persist_directory=CHROMA_DIR
    )

    # The langchain_chroma integration is expected to write to
    # persist_directory automatically and not to expose persist(); only older
    # Chroma wrappers have the method. Call it only when it exists so the
    # build does not fail with AttributeError.
    persist = getattr(vectordb, "persist", None)
    if callable(persist):
        persist()

    return vectordb

def load_schema_store():
    """Open the Chroma schema index stored under CHROMA_DIR.

    Uses the same sentence-transformers/all-MiniLM-L6-v2 embedding model as
    the build step so that query embeddings match the stored vectors.

    Returns:
        A Chroma vector store bound to CHROMA_DIR.
    """
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    store = Chroma(embedding_function=embedder, persist_directory=CHROMA_DIR)
    return store

In [71]:
# =========================
# 3. SQL SAFETY CHECKS
# =========================

# Keywords that must never appear as whole words in a generated query.
# Word boundaries keep identifiers such as "last_updated" or "is_deleted"
# from being rejected by accident.
_FORBIDDEN_SQL_KEYWORDS = re.compile(
    r"\b(?:insert|update|delete|drop|alter|truncate|exec|execute|merge"
    r"|create|into|grant|revoke|deny)\b"
)

# Wildcard projections: "select *" (optionally after DISTINCT/ALL/TOP n),
# qualified "t.*", and "*" as a later item in the select list (", *").
# "count(*)" and arithmetic such as "a * b" do not match.
_SELECT_WILDCARD = re.compile(
    r"""
    \bselect\s*
    (?:(?:distinct|all)\s*)?
    (?:top\s*\(?\s*\d+\s*\)?\s*(?:percent\s*)?(?:with\s+ties\s*)?)?
    \*
    |
    [\w\]"]\s*\.\s*\*
    |
    ,\s*\*
    """,
    re.VERBOSE,
)


def is_safe_sql(sql: str) -> bool:
    """Return True only if ``sql`` looks like a single read-only SELECT.

    The check is a conservative text filter, not a SQL parser. A query is
    rejected when it:

    * is not a string, or is empty;
    * does not start with the keyword SELECT;
    * contains a semicolon anywhere except as the final character
      (stacked statements);
    * contains a data-modifying / DDL / permission keyword as a whole word
      (including INTO, which turns a SELECT into a table-creating statement);
    * projects a wildcard (``SELECT *``, ``SELECT TOP n *``, ``t.*``, ``, *``).

    Args:
        sql: The SQL text to validate.

    Returns:
        True if none of the rejection rules apply, otherwise False.
    """
    if not isinstance(sql, str):
        return False

    statement = sql.strip().lower()

    # A single trailing terminator is harmless; drop it before looking for
    # statement separators.
    if statement.endswith(";"):
        statement = statement[:-1].rstrip()

    if not re.match(r"select\b", statement):
        return False

    if ";" in statement:
        return False

    if _FORBIDDEN_SQL_KEYWORDS.search(statement):
        return False

    if _SELECT_WILDCARD.search(statement):
        return False

    return True

In [72]:
# =========================
# 4. SQL GENERATION
# =========================

# Instruction text sent to the model. The {schema} and {question}
# placeholders are filled in when the prompt is formatted.
_SQL_PROMPT_TEXT = """
You are a senior data engineer.

Generate ONE valid Microsoft SQL Server query.

Rules:
- SELECT statements only
- Use TOP if result is not aggregated
- No SELECT *
- Use only tables and columns provided
- Do not explain anything
- Do not wrap in markdown

Schema:
{schema}

Question:
{question}

Return SQL only.
"""

# Chat prompt template built from the instruction text above.
SQL_PROMPT = ChatPromptTemplate.from_template(_SQL_PROMPT_TEXT)

# Matches a markdown code fence (optionally tagged sql / tsql / t-sql / mssql)
# and captures its contents; an unterminated fence runs to the end of the text.
_FENCED_SQL_RE = re.compile(
    r"```(?:t-?sql|mssql|sql)?\s*(.*?)\s*(?:```|\Z)",
    re.IGNORECASE | re.DOTALL,
)


def generate_sql(llm, schema_context: str, question: str) -> str:
    """Ask the model for a SQL query answering ``question``.

    Args:
        llm: Chat model exposing ``invoke``; the response must have a
            ``content`` string attribute.
        schema_context: Text description of the relevant tables and columns.
        question: The user's natural-language question.

    Returns:
        The SQL text from the model reply, stripped of surrounding whitespace
        and of a markdown code fence if the model added one.
    """
    # format_messages keeps the prompt as chat messages; format() would
    # flatten it into one string prefixed with "Human: ".
    messages = SQL_PROMPT.format_messages(
        schema=schema_context,
        question=question
    )
    response = llm.invoke(messages)
    text = response.content.strip()

    # Models sometimes wrap the query in ```sql ... ``` despite the prompt;
    # unwrap it so the query is not rejected for not starting with SELECT.
    fenced = _FENCED_SQL_RE.search(text)
    if fenced:
        text = fenced.group(1).strip()

    return text

In [73]:
# =========================
# 5. EXECUTE SQL
# =========================

def run_sql(sql: str):
    """Execute ``sql`` against the database and fetch up to MAX_ROWS rows.

    Args:
        sql: The SQL statement to execute.

    Returns:
        A ``(columns, rows)`` tuple: the list of result column names and the
        list of fetched rows. Both are empty when the statement produces no
        result set.

    Raises:
        pyodbc.Error: if connecting or executing fails (propagated from pyodbc).
    """
    conn = pyodbc.connect(SQL_CONN_STR)
    try:
        cursor = conn.cursor()
        cursor.execute(sql)

        # description is None when the statement returned no result set;
        # iterating over it would raise TypeError.
        if cursor.description is None:
            return [], []

        columns = [c[0] for c in cursor.description]
        rows = cursor.fetchmany(MAX_ROWS)
        return columns, rows
    finally:
        # Always release the connection, including when execute() raises.
        conn.close()

In [74]:
# =========================
# 6. MAIN PIPELINE
# =========================

def ask(question: str):
    """Answer ``question`` by generating, validating and running a SQL query.

    Steps: load (or build) the schema vector store, retrieve the TOP_K_SCHEMA
    most similar schema documents, have the model generate SQL from them,
    validate it with is_safe_sql(), execute it and print the result.

    Args:
        question: The user's natural-language question.

    Raises:
        ValueError: if the generated SQL does not pass is_safe_sql().
    """
    llm = ChatOllama(model=OLLAMA_MODEL, temperature=0)

    # Opening a missing or never-populated directory normally succeeds and
    # yields an empty store rather than raising, so check for content
    # explicitly instead of relying on an exception to trigger the build.
    try:
        vectordb = load_schema_store()
        store_is_empty = not vectordb.get(limit=1)["ids"]
    except Exception as exc:
        print(f"\nCould not load schema store ({exc!r}); rebuilding.")
        store_is_empty = True

    if store_is_empty:
        vectordb = build_schema_store()

    schema_docs = vectordb.similarity_search(question, k=TOP_K_SCHEMA)
    schema_context = "\n\n".join(d.page_content for d in schema_docs)

    sql = generate_sql(llm, schema_context, question)
    print("\nGenerated SQL:\n", sql)

    if not is_safe_sql(sql):
        raise ValueError("Unsafe SQL generated")

    cols, rows = run_sql(sql)

    print("\nResult:")
    print(cols)
    for r in rows:
        print(r)

In [None]:
# =========================
# 7. RUN
# =========================

if __name__ == "__main__":
    # Interactive loop: keep answering questions until the user types "exit",
    # sends EOF, or interrupts the prompt.
    while True:
        try:
            q = input("\nAsk a question (or 'exit'): ").strip()
        except (EOFError, KeyboardInterrupt):
            break

        if not q:
            # Nothing to ask; prompt again instead of querying the model.
            continue
        if q.lower() == "exit":
            break

        # A rejected or failing query should not end the whole session.
        try:
            ask(q)
        except Exception as exc:
            print(f"\nError: {exc}")


Generated SQL:
 SELECT TOP 5 patient_id, billing_amount
FROM visits
ORDER BY billing_amount DESC;

Result:
['patient_id', 'billing_amount']
(8789, 49995.90234375)
(9384, 49994.984375)
(3256, 49985.97265625)
(734, 49974.8046875)
(3526, 49974.30078125)

Generated SQL:
 SELECT TOP 1 patient_id, SUM(billing_amount) AS total_billing
FROM visits
GROUP BY patient_id
ORDER BY total_billing DESC;

Result:
['patient_id', 'total_billing']
(4096, 159668.2607421875)
