## Import model from huggingface

In [1]:
import sentencepiece as spm
print("✅ SentencePiece loaded, version:", spm.__version__)

import torch
print("✅ PyTorch loaded, version:", torch.__version__) 
print("CUDA available:", torch.cuda.is_available()) 



✅ SentencePiece loaded, version: 0.2.1
✅ PyTorch loaded, version: 2.8.0+cpu
CUDA available: False


In [2]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Initialize the tokenizer from Hugging Face Transformers library
tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = T5ForConditionalGeneration.from_pretrained('cssupport/t5-small-awesome-text-to-sql')
model = model.to(device)
model.eval()

def generate_sql(input_prompt):
    # Tokenize the input prompt
    inputs = tokenizer(input_prompt, padding=True, truncation=True, return_tensors="pt").to(device)
    
    # Forward pass
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=512)
    
    # Decode the output IDs to a string (SQL query in this case)
    generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return generated_sql


input_prompt = "tables:\n" + "CREATE TABLE student_course_attendance (student_id VARCHAR); CREATE TABLE students (student_id VARCHAR)" + "\n" + "query for:" + "List the id of students who never attends courses?"

generated_sql = generate_sql(input_prompt)

print(f"The generated SQL query is: {generated_sql}")
#OUTPUT: The generated SQL query is: SELECT student_id FROM students WHERE NOT student_id IN (SELECT student_id FROM student_course_attendance)


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


pytorch_model.bin:   0%|          | 0.00/242M [00:00<?, ?B/s]

The generated SQL query is: SELECT student_id FROM students WHERE NOT student_id IN (SELECT student_id FROM student_course_attendance)


model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

In [None]:
# Test the function
#input_prompt = "tables:\n" + "CREATE TABLE Catalogs (date_of_latest_revision VARCHAR)" + "\n" +"query for: Find the dates on which more than one revisions were made."
#input_prompt = "tables:\n" + "CREATE TABLE table_22767 ( \"Year\" real, \"World\" real, \"Asia\" text, \"Africa\" text, \"Europe\" text, \"Latin America/Caribbean\" text, \"Northern America\" text, \"Oceania\" text )" + "\n" +"query for:what will the population of Asia be when Latin America/Caribbean is 783 (7.5%)?."
#input_prompt = "tables:\n" + "CREATE TABLE procedures ( subject_id text, hadm_id text, icd9_code text, short_title text, long_title text ) CREATE TABLE diagnoses ( subject_id text, hadm_id text, icd9_code text, short_title text, long_title text ) CREATE TABLE lab ( subject_id text, hadm_id text, itemid text, charttime text, flag text, value_unit text, label text, fluid text ) CREATE TABLE demographic ( subject_id text, hadm_id text, name text, marital_status text, age text, dob text, gender text, language text, religion text, admission_type text, days_stay text, insurance text, ethnicity text, expire_flag text, admission_location text, discharge_location text, diagnosis text, dod text, dob_year text, dod_year text, admittime text, dischtime text, admityear text ) CREATE TABLE prescriptions ( subject_id text, hadm_id text, icustay_id text, drug_type text, drug text, formulary_drug_cd text, route text, drug_dose text )" + "\n" +"query for:" + "what is the total number of patients who were diagnosed with icd9 code 2254?"

In [4]:
input_prompt = """
tables:
CREATE TABLE housing (
    longitude REAL,
    latitude REAL,
    housing_median_age REAL,
    total_rooms REAL,
    total_bedrooms REAL,
    population REAL,
    households REAL,
    median_income REAL,
    median_house_value REAL
)
query for: What is the average median_house_value for houses with median_income > 5?
"""

generated_sql = generate_sql(input_prompt)
print(generated_sql)


SELECT AVG(midd_house_value) FROM housing WHERE median_income > 5


### Extract schema directly from SQLite database and feed it into the model prompt automatically.

In [46]:
import sqlite3, pandas as pd, difflib, re, textwrap

db_path = '../data/housing.db'
# 1) read schema from DB
def get_schema_and_columns(db_path: str, table_name: str):
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    cur.execute(f"PRAGMA table_info({table_name});")
    info = cur.fetchall()
    conn.close()
    if not info:
        raise ValueError(f"Table '{table_name}' not found in {db_path}")
    ddl = "CREATE TABLE {t} (\n{cols}\n)".format(
        t=table_name,
        cols=",\n".join([f"    {row[1]} {row[2]}" for row in info])
    )
    columns = [row[1] for row in info]
    return ddl, columns

In [35]:
# 2) normalize user synonyms before prompting 
ALIASES = {
    # map synonyms (left) to real column names (right) in your DB
    "median_income": "MedInc",
    "median house value": "MedHouseVal",
    "house age": "HouseAge",
    "avg rooms": "AveRooms",
    "average rooms": "AveRooms",
    "avg bedrooms": "AveBedrms",
    "average bedrooms": "AveBedrms",
    "lat": "Latitude",
    "lng": "Longitude",
}

def normalize_question(question: str) -> str:
    for alias, column in ALIASES.items():
        question = question.replace(alias, column)
    return question



In [36]:
# --- 3) prompt builder that constrains the model ---
def build_prompt(db_path: str, table: str, question: str) -> str:
    ddl, cols = get_schema_and_columns(db_path, table)
    col_list = ", ".join(cols)
    question = normalize_question(question)
    rules = textwrap.dedent(f"""
    Rules:
    - Use ONLY these columns: {col_list}
    - The table name is exactly `{table}` (no aliases).
    - Return a SINGLE SQL statement, no commentary.
    - Use reasonable numeric ranges (Latitude in [-90, 90], Longitude in [-180, 180]).
    """).strip()
    return f"tables:\n{ddl}\n{rules}\nquery for: {question}", cols



In [37]:
# # Example: housing.db schema
# db_path = '../data/housing.db'
# table_name = 'housing'
# schema = get_schema(db_path, table_name)
# # 
# print(schema)


In [None]:

# --- 4) validate + auto-repair if a column doesn’t exist ---
def try_explain_sql(conn, sql: str):
    # We use EXPLAIN to validate structure without running the full query
    return pd.read_sql_query("EXPLAIN " + sql, conn)

def validate_and_fix_sql(sql: str, db_path: str, table: str, cols: list):
    conn = sqlite3.connect(db_path)
    try:
        try_explain_sql(conn, sql)
        return sql, None  # valid
    except Exception as e:
        msg = str(e)
        m = re.search(r"no such column: (\w+)", msg, re.I)
        if m:
            bad = m.group(1)
            match = difflib.get_close_matches(bad, cols, n=1, cutoff=0.6)
            if match:
                fixed = re.sub(rf"\b{bad}\b", match[0], sql)
                try:
                    try_explain_sql(conn, fixed)
                    return fixed, f"🩹 Replaced `{bad}` → `{match[0]}`"
                except Exception:
                    pass
        return None, f"Validation failed: {msg}"
    finally:
        conn.close()



In [44]:
# --- 5) end-to-end guarded generation ---
def generate_sql_guarded(question: str, db_path="../data/housing.db", table="housing", max_attempts=2):
    prompt, cols = build_prompt(db_path, table, question)
    sql = generate_sql(prompt)  # uses your working HF model function

    for _ in range(max_attempts):
        fixed_sql, note = validate_and_fix_sql(sql, db_path, table, cols)
        if fixed_sql:
            if note:
                print(note)
            return fixed_sql
        # If invalid, re-prompt the model with the error and allowed columns
        error_msg = note or "Invalid SQL."
        reprompt = (
            f"{prompt}\n\nPrevious SQL:\n{sql}\n\n"
            f"Error: {error_msg}\n"
            f"Regenerate a valid SQL using ONLY these columns: {', '.join(cols)}.\n"
            f"Return just the SQL."
        )
        sql = generate_sql(reprompt)

    return sql  # best effort

In [45]:
question = "Average house age for blocks with MedInc > 5 and Latitude > 35"
safe_sql = generate_sql_guarded(question)
print("✅ Final SQL:", safe_sql)


NameError: name 'difflib' is not defined

In [None]:
# step 2: add the schema to the prompt and generate SQL query
def build_prompt(schema, question):
    prompt = f"tables:\n{schema}\nquery for: {question}"
    return prompt 

In [None]:
# usage
question = "What is the average median_house_value for houses with median_income > 5?"
input_prompt = build_prompt(schema, question)
generated_sql = generate_sql(input_prompt)
print(generated_sql)


SELECT AVG(HorseAge) FROM housing WHERE Latitude > 5


### Execute the suggested SQL query from the housing.db

In [None]:
# execute the generated SQL query against the database
def run_query(db_path, query):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(query)
    results = cursor.fetchall()
    conn.close()
    return results
results = run_query(db_path, generated_sql)
print(results)

OperationalError: no such column: HorseAge

### Save outputs\

In [None]:
import pandas as pd
from pathlib import Path

results_file = Path("../data/generated_queries.csv")

# Example: log prompt + SQL
log = pd.DataFrame([{
    "input_prompt": input_prompt,
    "generated_sql": generated_sql
}])

if results_file.exists():
    log.to_csv(results_file, mode="a", header=False, index=False)
else:
    log.to_csv(results_file, index=False)

print("✅ Query saved to", results_file)


✅ Query saved to ..\data\generated_queries.csv
