TINY SQL EXPERT-SLM

In [1]:
!pip install transformers accelerate bitsandbytes sentencepiece
!pip install sqlparse


Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2


In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

print("Loading model... (this may take 20–40 seconds)")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",   # chooses GPU or CPU automatically
    low_cpu_mem_usage=True
)

def run_model(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.3,
        do_sample=True
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)


Loading model... (this may take 20–40 seconds)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/306 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/599 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

In [3]:
SCHEMA = """
Users (
    user_id INT PRIMARY KEY,
    name TEXT,
    email TEXT
)

Orders (
    order_id INT PRIMARY KEY,
    user_id INT,
    order_date DATE,
    amount DECIMAL,
    FOREIGN KEY(user_id) REFERENCES Users(user_id)
)

Products (
    product_id INT PRIMARY KEY,
    name TEXT,
    price DECIMAL
)

OrderItems (
    item_id INT PRIMARY KEY,
    order_id INT,
    product_id INT,
    quantity INT,
    FOREIGN KEY(order_id) REFERENCES Orders(order_id),
    FOREIGN KEY(product_id) REFERENCES Products(product_id)
)
"""


In [4]:
BASE_PROMPT = f"""
Convert the user question into a SQL query.

Use ONLY the following schema:
{SCHEMA}

RULES:
- Output ONLY SQL.
- No explanation.
- No markdown.
- Use JOINs when needed.
- Do NOT invent columns or tables.
"""


In [30]:
ALLOWED_TABLES = {
    "users": ["user_id", "name", "email"],
    "orders": ["order_id", "user_id", "order_date", "amount"],
    "products": ["product_id", "name", "price"],
    "orderitems": ["item_id", "order_id", "product_id", "quantity"]
}


In [37]:
import re
import sqlparse
#from schema import ALLOWED_TABLES

FORBIDDEN = ["DROP", "DELETE", "ALTER"]

def has_forbidden_words(sql):
    return any(word in sql.upper() for word in FORBIDDEN)

def extract_identifiers(sql):
    tokens = sqlparse.parse(sql)[0].tokens
    identifiers = []
    for token in tokens:
        if token.ttype is None and hasattr(token, "tokens"):
            for sub in token.tokens:
                if sub.ttype is None:
                    identifiers.append(sub.value.lower())
    return identifiers

def validate_schema(sql):
    sql_lower = sql.lower()

    for table, cols in ALLOWED_TABLES.items():
        if table in sql_lower:
            # get all columns used in the SQL for this table
            for col in re.findall(rf"{table}\.\s*(\w+)", sql_lower):
                if col not in cols:
                    return False, f"Column '{col}' does not exist in table '{table}'."

    # Detect any tables used that aren’t in the schema
    for match in re.findall(r"from\s+(\w+)", sql_lower) + re.findall(r"join\s+(\w+)", sql_lower):
        if match not in ALLOWED_TABLES:
            return False, f"Table '{match}' does not exist in schema."

    return True, None

def validate(sql):
    if has_forbidden_words(sql):
        return False, "Forbidden keyword detected."

    try:
        parsed = sqlparse.parse(sql)
        if len(parsed) == 0:
            return False, "Unparsable SQL."
    except Exception:
        return False, "SQL parsing failed."

    # Schema validation
    ok, err = validate_schema(sql)
    if not ok:
        return False, err

    return True, None

In [69]:
#from model_interface import run_model
#from sql_validator import validate
#from schema import BASE_PROMPT

def extract_sql(text):
    idx = text.lower().find("select")
    if idx == -1:
        idx = text.lower().find("with")
    if idx == -1:
        return text
    return text[idx:].strip()

def generate_sql(question):
    prompt = BASE_PROMPT + f"\nQuestion: {question}"

    for attempt in range(3):
        raw_output = run_model(prompt).strip()


        sql = extract_sql(raw_output)

        valid, error = validate(sql)

        if valid:
            print("\nGenerated SQL:\n")
            return sql
        print(f"[Attempt {attempt+1}] SQL INVALID. \nRetrying")


        # retry
        prompt = (
            BASE_PROMPT
            + f"\nThe previous SQL was invalid: {sql}"
            + f"\nError: {error}"
            + f"\nFix the SQL. Output ONLY SQL.\n"
            + f"Question: {question}"
        )

    return "ERROR: Could not generate valid SQL in 3 attempt ."


In [72]:
if __name__ == "__main__":
    q = input("Enter your question: ")
    #print("\nGenerated SQL:\n")
    print(generate_sql(q))

Enter your question: List users who ordered products named 'Laptop' or 'Tablet' in the past 90 days.

Generated SQL:

SELECT DISTINCT u.user_id, u.name, u.email
FROM Users u
JOIN Orders o ON u.user_id = o.user_id
JOIN OrderItems oi ON o.order_id = oi.order_id
JOIN Products p ON oi.product_id = p.product_id
WHERE p.name IN ('Laptop', 'Tablet')
AND o.order_date >= CURRENT_DATE - INTERVAL '90 days';
