In [25]:
import os
import numpy as np
import pandas as pd
from dotenv import load_dotenv
import faiss
from google import genai

load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
client = genai.Client(api_key=GEMINI_API_KEY)


In [4]:
import pandas as pd

# Load dataset
df = pd.read_csv("synthetic_dataset.csv")

# Keep only useful columns
dataset = df[["sql_prompt", "sql_context", "sql"]]

print(dataset.head())


                                          sql_prompt  \
0  What is the total volume of timber sold by eac...   
1  List all the unique equipment types and their ...   
2  How many marine species are found in the South...   
3  What is the total trade value and average pric...   
4  Find the energy efficiency upgrades with the h...   

                                         sql_context  \
0  CREATE TABLE salesperson (salesperson_id INT, ...   
1  CREATE TABLE equipment_maintenance (equipment_...   
2  CREATE TABLE marine_species (name VARCHAR(50),...   
3  CREATE TABLE trade_history (id INT, trader_id ...   
4  CREATE TABLE upgrades (id INT, cost FLOAT, typ...   

                                                 sql  
0  SELECT salesperson_id, name, SUM(volume) as to...  
1  SELECT equipment_type, SUM(maintenance_frequen...  
2  SELECT COUNT(*) FROM marine_species WHERE loca...  
3  SELECT trader_id, stock, SUM(price * quantity)...  
4  SELECT type, cost FROM (SELECT type, cost, ROW..

In [5]:
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss

# Load embedding model
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Encode all prompts
question_embeddings = embedder.encode(
    dataset["sql_prompt"].tolist(),
    convert_to_numpy=True
)

# Build FAISS index
dimension = question_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(question_embeddings)


  from .autonotebook import tqdm as notebook_tqdm
Loading weights: 100%|██████████| 103/103 [00:00<00:00, 1002.69it/s, Materializing param=pooler.dense.weight]                             
[1mBertModel LOAD REPORT[0m from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


In [8]:
def retrieve_few_shot_examples(user_question, k=2):
    user_emb = embedder.encode([user_question], convert_to_numpy=True)
    distances, indices = index.search(user_emb, k)

    examples = []
    for i in indices[0]:
        row = dataset.iloc[i]
        examples.append({
            "question": row["sql_prompt"],
            "context": row["sql_context"],
            "sql": row["sql"]
        })

    return examples


In [9]:
# Simple DANKE-style keyword matcher

def danke_keyword_match(user_question):
    KM = {}

    if "salesperson" in user_question.lower():
        KM["salesperson"] = {"table": "salesperson", "column": "salesperson_id"}

    if "equipment" in user_question.lower():
        KM["equipment"] = {"table": "equipment_maintenance", "column": "equipment_type"}

    if "marine" in user_question.lower():
        KM["marine"] = {"table": "marine_species", "column": "location"}

    return KM


In [None]:


def generate_sql(user_question, use_danke=False, use_dfe=False, k=2):

    # 1. DANKE keyword matching
    KM = danke_keyword_match(user_question) if use_danke else None

    # 2. Few-shot retrieval
    examples = retrieve_few_shot_examples(user_question, k) if use_dfe else []


    # Build prompt

    prompt = "You are an AI assistant that writes correct SQL.\n\n"

    # Add few-shot schema from retrieved examples
    if examples:
        prompt += "Database context examples:\n"
        for ex in examples:
            prompt += f"{ex['context']}\n"

    # Add DANKE hints
    if KM:
        prompt += f"\nRelevant schema hints:\n{KM}\n"

    # Add few-shot Q&A
    if examples:
        prompt += "\nFew-shot examples:\n"
        for ex in examples:
            prompt += f"""
Question: {ex['question']}
SQL: {ex['sql']}
"""

    # Final user question
    prompt += f"\nUser question: {user_question}\nGenerate only the SQL query."


    # 4. call Gemini

    response = client.models.generate_content(
        model="gemini-3-flash-preview",
        contents=prompt,
        config={"temperature": 0}
    )

    return response.text.strip()

In [12]:
q = "How many marine species are found in the Southern Ocean?"

print("LLM only:\n", generate_sql(q))

print("\nLLM + DFE:\n", generate_sql(q, use_dfe=True))

print("\nLLM + DANKE:\n", generate_sql(q, use_danke=True))

print("\nComplete pipeline:\n", generate_sql(q, use_danke=True, use_dfe=True))


LLM only:
 ```sql
SELECT COUNT(*) 
FROM marine_species 
WHERE region = 'Southern Ocean';
```

LLM + DFE:
 SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';

LLM + DANKE:
 ```sql
SELECT count(*) FROM marine_species WHERE location = 'Southern Ocean';
```

Complete pipeline:
 SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';


In [None]:
import pandas as pd

df = pd.read_csv("synthetic_dataset.csv")

actual_sql = df["sql"].tolist()


pred_sql = []

for prompt in df["sql_prompt"]:
    response = client.models.generate_content(
        model="gemini-3-flash-preview",
        contents=f"You are an AI that writes SQL queries.\n\nUser question: {prompt}\nSQL:"
    )

    prediction = response.text
    pred_sql.append(prediction)

df["predicted_sql"] = pred_sql



In [27]:
response = client.models.generate_content(
        model="gemini-3-flash-preview",
        contents=f"You are an AI that writes SQL queries.\n\nUser question: i want to see whole table named after 'Bob ' table name customer\nSQL:"
        
    )
print(response.text)

```sql
SELECT * FROM "Bob customer";
```


In [23]:
df.head()

Unnamed: 0,id,domain,domain_description,sql_complexity,sql_complexity_description,sql_task_type,sql_task_type_description,sql_prompt,sql_context,sql,sql_explanation
0,5097,forestry,Comprehensive data on sustainable forest manag...,single join,"only one join (specify inner, outer, cross)",analytics and reporting,"generating reports, dashboards, and analytical...",What is the total volume of timber sold by eac...,"CREATE TABLE salesperson (salesperson_id INT, ...","SELECT salesperson_id, name, SUM(volume) as to...","Joins timber_sales and salesperson tables, gro..."
1,5098,defense industry,"Defense contract data, military equipment main...",aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",List all the unique equipment types and their ...,CREATE TABLE equipment_maintenance (equipment_...,"SELECT equipment_type, SUM(maintenance_frequen...",This query groups the equipment_maintenance ta...
2,5099,marine biology,"Comprehensive data on marine species, oceanogr...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",How many marine species are found in the South...,"CREATE TABLE marine_species (name VARCHAR(50),...",SELECT COUNT(*) FROM marine_species WHERE loca...,This query counts the number of marine species...
3,5100,financial services,Detailed financial data including investment s...,aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",What is the total trade value and average pric...,"CREATE TABLE trade_history (id INT, trader_id ...","SELECT trader_id, stock, SUM(price * quantity)...",This query calculates the total trade value an...
4,5101,energy,Energy market data covering renewable energy s...,window functions,"window functions (e.g., ROW_NUMBER, LEAD, LAG,...",analytics and reporting,"generating reports, dashboards, and analytical...",Find the energy efficiency upgrades with the h...,"CREATE TABLE upgrades (id INT, cost FLOAT, typ...","SELECT type, cost FROM (SELECT type, cost, ROW...",The SQL query uses the ROW_NUMBER function to ...
