## Import the model from Hugging Face

In [1]:
import sentencepiece as spm
print("✅ SentencePiece loaded, version:", spm.__version__)

import torch
print("✅ PyTorch loaded, version:", torch.__version__) 
print("CUDA available:", torch.cuda.is_available()) 



✅ SentencePiece loaded, version: 0.2.1
✅ PyTorch loaded, version: 2.8.0+cpu
CUDA available: False


In [2]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The tokenizer is loaded from the base `t5-small` checkpoint.
tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Load the text-to-SQL checkpoint, move it to the chosen device and switch
# the module to evaluation mode.
model = T5ForConditionalGeneration.from_pretrained(
    'cssupport/t5-small-awesome-text-to-sql'
).to(device)
model.eval()

def generate_sql(input_prompt, max_length=512):
    """Generate a SQL query from a text prompt with the loaded T5 model.

    Args:
        input_prompt: Prompt text. The examples in this notebook put the
            table DDL after a ``tables:`` line and the question after
            ``query for:``.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            Defaults to 512, the value that used to be hard-coded.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    # Tokenize the prompt and move the resulting tensors to the model's device.
    inputs = tokenizer(input_prompt, padding=True, truncation=True, return_tensors="pt").to(device)

    # Generation only: gradient tracking is switched off for this call.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length)

    # Decode the output IDs back to a string (the SQL query in this case).
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Prompt layout used throughout: a "tables:" line, the schema DDL, then the
# question appended directly after "query for:".
input_prompt = "\n".join([
    "tables:",
    "CREATE TABLE student_course_attendance (student_id VARCHAR); CREATE TABLE students (student_id VARCHAR)",
    "query for:" + "List the id of students who never attends courses?",
])

generated_sql = generate_sql(input_prompt)
print(f"The generated SQL query is: {generated_sql}")
#OUTPUT: The generated SQL query is: SELECT student_id FROM students WHERE NOT student_id IN (SELECT student_id FROM student_course_attendance)


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


pytorch_model.bin:   0%|          | 0.00/242M [00:00<?, ?B/s]

The generated SQL query is: SELECT student_id FROM students WHERE NOT student_id IN (SELECT student_id FROM student_course_attendance)


model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

In [None]:
# Test the function
#input_prompt = "tables:\n" + "CREATE TABLE Catalogs (date_of_latest_revision VARCHAR)" + "\n" +"query for: Find the dates on which more than one revisions were made."
#input_prompt = "tables:\n" + "CREATE TABLE table_22767 ( \"Year\" real, \"World\" real, \"Asia\" text, \"Africa\" text, \"Europe\" text, \"Latin America/Caribbean\" text, \"Northern America\" text, \"Oceania\" text )" + "\n" +"query for:what will the population of Asia be when Latin America/Caribbean is 783 (7.5%)?."
#input_prompt = "tables:\n" + "CREATE TABLE procedures ( subject_id text, hadm_id text, icd9_code text, short_title text, long_title text ) CREATE TABLE diagnoses ( subject_id text, hadm_id text, icd9_code text, short_title text, long_title text ) CREATE TABLE lab ( subject_id text, hadm_id text, itemid text, charttime text, flag text, value_unit text, label text, fluid text ) CREATE TABLE demographic ( subject_id text, hadm_id text, name text, marital_status text, age text, dob text, gender text, language text, religion text, admission_type text, days_stay text, insurance text, ethnicity text, expire_flag text, admission_location text, discharge_location text, diagnosis text, dod text, dob_year text, dod_year text, admittime text, dischtime text, admityear text ) CREATE TABLE prescriptions ( subject_id text, hadm_id text, icustay_id text, drug_type text, drug text, formulary_drug_cd text, route text, drug_dose text )" + "\n" +"query for:" + "what is the total number of patients who were diagnosed with icd9 code 2254?"

In [4]:
# Hand-written prompt in the same "tables: ... query for: ..." layout, this
# time with a multi-line DDL for a `housing` table.
# NOTE(review): these column names differ from the columns printed later for
# ../data/housing.db (MedInc, HouseAge, ...); this cell only illustrates the
# prompt format and the raw model output.
input_prompt = """
tables:
CREATE TABLE housing (
    longitude REAL,
    latitude REAL,
    housing_median_age REAL,
    total_rooms REAL,
    total_bedrooms REAL,
    population REAL,
    households REAL,
    median_income REAL,
    median_house_value REAL
)
query for: What is the average median_house_value for houses with median_income > 5?
"""

# Send the prompt to the model and show the SQL it produces, unmodified.
generated_sql = generate_sql(input_prompt)
print(generated_sql)


SELECT AVG(midd_house_value) FROM housing WHERE median_income > 5


### Extract schema directly from SQLite database and feed it into the model prompt automatically.

In [118]:
import re, difflib, sqlite3, pandas as pd

# Lower-cased SQL words that enforce_columns() must never treat as column names.
SQL_KEYWORDS = set(
    "select from where and or not in between like is null "
    "group by order limit offset asc desc avg sum min max "
    "count distinct as on join left right inner outer having".split()
)
# Lower-cased SQL function names, skipped by enforce_columns() in the same way.
SQL_FUNCS = set("avg sum min max count abs round upper lower coalesce".split())


In [119]:
import sqlite3, pandas as pd, difflib, re, textwrap

db_path = '../data/housing.db'
# 1) read schema from DB
def get_schema_and_columns(db_path: str, table_name: str):
    """Read a table's column definitions from a SQLite database.

    Args:
        db_path: Path of the SQLite database file.
        table_name: Name of the table to inspect.

    Returns:
        A ``(ddl, columns)`` tuple. ``ddl`` is a ``CREATE TABLE`` statement
        rebuilt from each column's name and declared type; ``columns`` is the
        list of column names in the order SQLite reports them.

    Raises:
        ValueError: If SQLite reports no columns for ``table_name``, i.e. the
            table does not exist in ``db_path``.
    """
    # Quote the identifier (doubling embedded quotes) so names that are
    # reserved words or contain spaces/punctuation neither break the PRAGMA
    # nor let extra SQL be smuggled in through the table name.
    quoted_name = '"' + table_name.replace('"', '""') + '"'

    conn = sqlite3.connect(db_path)
    try:
        info = conn.execute(f"PRAGMA table_info({quoted_name});").fetchall()
    finally:
        # Release the connection even when the PRAGMA itself raises.
        conn.close()

    if not info:
        raise ValueError(f"Table '{table_name}' not found in {db_path}")

    # Each PRAGMA row is (cid, name, type, notnull, dflt_value, pk).
    columns = [row[1] for row in info]
    column_lines = ",\n".join(f"    {row[1]} {row[2]}" for row in info)
    ddl = f"CREATE TABLE {table_name} (\n{column_lines}\n)"
    return ddl, columns

In [120]:
# 2) normalize user synonyms before prompting 
ALIASES = {
    # map synonyms (left) to real column names (right) in your DB
    "median_income": "MedInc",
    "median house value": "MedHouseVal",
    "house age": "HouseAge",
    "avg rooms": "AveRooms",
    "average rooms": "AveRooms",
    "avg bedrooms": "AveBedrms",
    "average bedrooms": "AveBedrms",
    "lat": "Latitude",
    "lng": "Longitude",
}

def normalize_question(question: str) -> str:
    """Rewrite known synonyms in a question to the real column names.

    Aliases are matched as whole words/phrases only, ignoring case, so a short
    alias such as ``lat`` no longer rewrites the middle of unrelated words
    (``population``) or of an already-correct column name (``latitude``).

    Args:
        question: Natural-language question typed by the user.

    Returns:
        The question with every alias from ``ALIASES`` replaced by its column
        name. Text that matches no alias is returned unchanged.
    """
    if not question or not ALIASES:
        return question

    # Longest aliases first so multi-word phrases win over shorter ones.
    ordered_aliases = sorted(ALIASES, key=len, reverse=True)
    alias_pattern = re.compile(
        r"(?<!\w)(?:" + "|".join(re.escape(alias) for alias in ordered_aliases) + r")(?!\w)",
        flags=re.IGNORECASE,
    )
    column_for = {alias.lower(): column for alias, column in ALIASES.items()}

    def _to_column(match):
        # Fall back to the matched text if case folding produced an unknown key.
        return column_for.get(match.group(0).lower(), match.group(0))

    # Single pass: text produced by one replacement is never re-scanned.
    return alias_pattern.sub(_to_column, question)



In [121]:
# --- 3) prompt builder that constrains the model ---
def build_prompt(db_path: str, table: str, question: str) -> tuple:
    """Build the model prompt for a question about one table.

    The prompt contains the table's DDL (read from the database), a short
    list of rules that restricts the model to the real column names, and the
    alias-normalized question.

    Args:
        db_path: Path of the SQLite database file.
        table: Name of the table the question is about.
        question: Natural-language question; passed through
            ``normalize_question`` before being embedded.

    Returns:
        A ``(prompt, cols)`` tuple: the prompt string and the list of column
        names returned by ``get_schema_and_columns``.
    """
    ddl, cols = get_schema_and_columns(db_path, table)
    col_list = ", ".join(cols)
    question = normalize_question(question)
    # dedent() runs on the already-interpolated text; .strip() drops the
    # leading and trailing newlines of the triple-quoted block.
    rules = textwrap.dedent(f"""
    Rules:
    - Use ONLY these columns: {col_list}
    - The table name is exactly `{table}` (no aliases).
    - Return a SINGLE SQL statement, no commentary.
    - Use reasonable numeric ranges (Latitude in [-90, 90], Longitude in [-180, 180]).
    """).strip()
    return f"tables:\n{ddl}\n{rules}\nquery for: {question}", cols



In [122]:
# # Example: inspect the housing.db schema
# db_path = '../data/housing.db'
# table_name = 'housing'
# ddl, columns = get_schema_and_columns(db_path, table_name)
# print(ddl)
# print(columns)


In [123]:
# --- plug-in: strict aliases to override hallucinations before fuzzy match ---
STRICT_ALIASES = {
    # hard typos seen in practice
    "HorseAge": "HouseAge",
    "HypertAge": "HouseAge",
    "House_Age": "HouseAge",
    "Med_Inc": "MedInc",
    "MedianIncome": "MedInc",
    "Median_House_Value": "MedHouseVal",
    "Lat": "Latitude",
    "Long": "Longitude",
}

def apply_strict_aliases(sql: str):
    """Apply every STRICT_ALIASES correction to sql and return the result.

    Each misspelling is matched as a whole word, ignoring case, and the
    corrections are applied one after another in dictionary order.
    """
    corrected = sql
    for wrong_name in STRICT_ALIASES:
        whole_word = re.compile(r"\b" + re.escape(wrong_name) + r"\b", re.I)
        corrected = whole_word.sub(STRICT_ALIASES[wrong_name], corrected)
    return corrected


In [124]:
import re, difflib, sqlite3, pandas as pd

# NOTE(review): SQL_KEYWORDS and SQL_FUNCS are also assigned, with the same
# members, in an earlier cell of this notebook; one copy would be enough.
SQL_KEYWORDS = {
    # clauses
    "select", "from", "where", "group", "by", "having", "order", "limit", "offset",
    # logical operators and predicates
    "and", "or", "not", "in", "between", "like", "is", "null",
    # modifiers
    "asc", "desc", "distinct", "as",
    # joins
    "on", "join", "left", "right", "inner", "outer",
    # aggregates
    "avg", "sum", "min", "max", "count",
}
# Function names that enforce_columns() leaves untouched.
SQL_FUNCS = {"abs", "avg", "coalesce", "count", "lower", "max", "min", "round", "sum", "upper"}

def tokenize_words(sql: str):
    """Return the identifier-like words of sql in order of appearance.

    A word starts with a letter or underscore and continues with letters,
    digits or underscores. Quoting is not interpreted: text inside quotes is
    scanned exactly like the rest of the statement.
    """
    word_pattern = re.compile(r"[A-Za-z_][A-Za-z_0-9]*")
    return [found.group(0) for found in word_pattern.finditer(sql)]

def enforce_columns(sql: str, cols: list, table_name: str):
    """Replace unknown identifiers in sql with the closest valid column name.

    Every word that is not a SQL keyword/function, the table name or an
    existing column (all compared case-insensitively) is fuzzy-matched against
    the column names; when a match scores at least 0.6 the word is replaced by
    that column. Single-quoted string literals are left untouched, so values
    such as ``'latitud'`` are neither analysed nor rewritten.

    Args:
        sql: SQL text to repair.
        cols: Valid column names, in their real spelling.
        table_name: Name of the table, which is never replaced.

    Returns:
        A ``(fixed_sql, note)`` tuple. ``note`` lists the replacements that
        were made, or is ``None`` when nothing was changed.
    """
    col_set = {c.lower(): c for c in cols}
    known_columns = list(col_set)
    table_lc = table_name.lower()

    # Split into alternating segments: even indexes are SQL code, odd indexes
    # are single-quoted literals ('' inside a literal is an escaped quote).
    segments = re.split(r"('(?:[^']|'')*')", sql)

    replacements = {}
    for code_segment in segments[::2]:
        for tok in tokenize_words(code_segment):
            tl = tok.lower()
            if tl in SQL_KEYWORDS or tl in SQL_FUNCS or tl == table_lc or tl in col_set:
                continue
            # unknown identifier -> fuzzy match against the real columns
            cand = difflib.get_close_matches(tl, known_columns, n=1, cutoff=0.6)
            if cand:
                replacements[tok] = col_set[cand[0]]

    for bad, good in replacements.items():
        whole_word = re.compile(rf"\b{re.escape(bad)}\b")
        for idx in range(0, len(segments), 2):
            # A callable replacement inserts the column name literally.
            segments[idx] = whole_word.sub(lambda _m, good=good: good, segments[idx])
    fixed = "".join(segments)

    note = None
    if replacements:
        pairs = ", ".join([f"`{b}`→`{g}`" for b, g in replacements.items()])
        note = f"🩹 Pre-fixed unknown identifiers: {pairs}"
    return fixed, note

def try_explain_sql(conn, sql: str):
    """Run ``EXPLAIN <sql>`` on conn and return the result as a DataFrame.

    Any error raised while reading the query is propagated to the caller.
    """
    explain_statement = "EXPLAIN " + sql
    return pd.read_sql_query(explain_statement, con=conn)

def _retry_with_closest_column(conn, sql: str, error_msg: str, cols: list):
    """Try one fuzzy fix for the column named in a 'no such column' error.

    Returns ``(fixed_sql, note)`` when the patched statement passes EXPLAIN,
    otherwise ``None``.
    """
    m = re.search(r"no such column: ([\w_]+)", error_msg, re.I)
    if not m:
        return None
    bad = m.group(1)
    match = difflib.get_close_matches(bad.lower(), [c.lower() for c in cols], n=1, cutoff=0.6)
    if not match:
        return None
    good = next(c for c in cols if c.lower() == match[0])
    fixed = re.sub(rf"\b{re.escape(bad)}\b", good, sql, flags=re.I)
    try:
        try_explain_sql(conn, fixed)
    except Exception:
        # Best effort: the patched statement is still invalid, so report
        # "no fix" and let the caller surface the original error.
        return None
    return fixed, f"🩹 Replaced `{bad}` → `{good}`"


def validate_and_fix_sql(sql: str, db_path: str, table: str, cols: list):
    """Validate generated SQL against the database, repairing column names.

    Steps, in order:
      0. apply the hard-coded ``STRICT_ALIASES`` corrections;
      1. replace unknown identifiers with their closest valid column;
      2. run ``EXPLAIN`` on the result; if SQLite reports a missing column,
         try one more fuzzy replacement for that column and re-check.

    Args:
        sql: SQL text produced by the model.
        db_path: Path of the SQLite database used for the EXPLAIN check.
        table: Name of the table the query targets.
        cols: Valid column names of that table.

    Returns:
        ``(fixed_sql, note)`` when a statement passes EXPLAIN; ``note``
        describes the repairs or is ``None`` when none were needed. On
        failure returns ``(None, "Validation failed: <error>")``.
    """
    # 0) hard replacements first
    sql0 = apply_strict_aliases(sql)

    # 1) proactive pass: fix unknown identifiers before EXPLAIN
    sql1, pre_note = enforce_columns(sql0, cols, table)

    conn = sqlite3.connect(db_path)
    try:
        try:
            try_explain_sql(conn, sql1)  # if OK, done
            return sql1, pre_note
        except Exception as e:
            msg = str(e)

        # 2) fallback: parse the specific bad column and fuzzy-replace it once
        patched = _retry_with_closest_column(conn, sql1, msg, cols)
        if patched is not None:
            fixed, note2 = patched
            note = f"{pre_note}\n{note2}" if pre_note else note2
            return fixed, note

        # 3) give up with the original error message
        return None, f"Validation failed: {msg}"
    finally:
        conn.close()



In [125]:
# --- 5) end-to-end guarded generation ---
def generate_sql_guarded(question: str, db_path="../data/housing.db", table="housing", max_attempts=2):
    """Generate SQL for a question and validate it against the database.

    The model is prompted once; each candidate is passed through
    ``validate_and_fix_sql``. When validation fails, the model is re-prompted
    with the error, up to ``max_attempts`` times. Every candidate, including
    the last regenerated one, is validated and repaired before being returned.

    Args:
        question: Natural-language question.
        db_path: Path of the SQLite database file.
        table: Name of the table the question is about.
        max_attempts: Maximum number of re-prompts after a failed validation.

    Returns:
        The first candidate that validates (after repairs); otherwise the
        last candidate as-is, as a best effort.
    """
    prompt, cols = build_prompt(db_path, table, question)
    sql = generate_sql(prompt)  # uses your working HF model function

    # max_attempts re-prompts give max_attempts + 1 candidates to validate.
    for attempt in range(max_attempts + 1):
        fixed_sql, note = validate_and_fix_sql(sql, db_path, table, cols)
        if fixed_sql:
            if note:
                print(note)
            return fixed_sql
        if attempt == max_attempts:
            break  # out of retries; do not generate a candidate nobody checks
        # If invalid, re-prompt the model with the error and allowed columns
        error_msg = note or "Invalid SQL."
        reprompt = (
            f"{prompt}\n\nPrevious SQL:\n{sql}\n\n"
            f"Error: {error_msg}\n"
            f"Regenerate a valid SQL using ONLY these columns: {', '.join(cols)}.\n"
            f"Return just the SQL."
        )
        sql = generate_sql(reprompt)

    return sql  # best effort

In [126]:
# 1) See actual columns the validator will enforce:
ddl, cols = get_schema_and_columns("../data/housing.db", "housing")
print(cols)  # must include HouseAge, MedInc, Latitude, etc.

# 2) Generate with guard
q = "Average house age for blocks with MedInc > 5 and Latitude > 35"
sql = generate_sql_guarded(q)
print(sql)

# 3) Ask SQLite for the query plan (EXPLAIN) to confirm the statement is accepted.
import pandas as pd, sqlite3
from contextlib import closing

# `with sqlite3.connect(...)` only manages the transaction; closing() is what
# actually closes the connection when the block ends.
with closing(sqlite3.connect("../data/housing.db")) as conn:
    print(pd.read_sql_query("EXPLAIN " + sql, conn).head())


['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup', 'Latitude', 'Longitude', 'MedHouseVal']
SELECT AVG(HouseAge) FROM housing WHERE MedInc > 5 AND Latitude > 35
   addr    opcode  p1  p2  p3    p4  p5 comment
0     0      Init   0  18   0  None   0    None
1     1      Null   0   1   2  None   0    None
2     2  OpenRead   0   2   0     7   0    None
3     3    Rewind   0  14   0  None   0    None
4     4    Column   0   0   3  None   0    None


In [127]:
# Run the guarded pipeline for the question and keep both values in
# module-level names: `question` and `safe_sql` are read again by the cells
# below (query execution, query_history insert, CSV log).
question = "Average house age for blocks with MedInc > 5 and Latitude > 35"
safe_sql = generate_sql_guarded(question)
print("✅ Final SQL:", safe_sql)


✅ Final SQL: SELECT AVG(HouseAge) FROM housing WHERE MedInc > 5 AND Latitude > 35


### Execute the generated SQL query against housing.db

In [130]:
# execute the generated SQL query against the database
def run_query(db_path, query):
    """Execute query against the SQLite database at db_path.

    Args:
        db_path: Path of the SQLite database file.
        query: SQL statement to execute.

    Returns:
        All result rows, as a list of tuples.

    Raises:
        sqlite3.Error: If SQLite rejects the statement; the connection is
            closed before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(query).fetchall()
    finally:
        # Close the connection on success and on failure alike.
        conn.close()
results = run_query(db_path, safe_sql)
print(results)

[(26.02671755725191,)]


### Store query history table

In [135]:
import sqlite3, time
from contextlib import closing

# closing() closes the connection when the block ends; the inner `with conn`
# commits the statements (or rolls them back if one of them raises).
with closing(sqlite3.connect("../data/housing.db")) as conn:
    with conn:
        conn.execute("""
      CREATE TABLE IF NOT EXISTS query_history(
        ts INTEGER, question TEXT, sql TEXT
      )
    """)
        # One row per run: Unix timestamp, the question asked, the validated SQL.
        conn.execute("INSERT INTO query_history VALUES(?,?,?)",
                     (int(time.time()), question, safe_sql))
print("✅ Logged to query_history")


✅ Logged to query_history


### Save outputs

In [134]:
import pandas as pd
from pathlib import Path

results_file = Path("../data/generated_queries.csv")

# Log the question that actually produced `safe_sql`. The earlier
# `input_prompt` variable still holds the hand-written demo prompt from the
# top of the notebook, which belongs to a different query.
log = pd.DataFrame([{
    "input_prompt": question,
    "generated_sql": safe_sql
}])

# Append to the CSV; write the header row only when the file is being created.
log.to_csv(results_file, mode="a", header=not results_file.exists(), index=False)

print("✅ Query saved to", results_file)


✅ Query saved to ..\data\generated_queries.csv
