In [26]:
import os
import numpy as np
import pandas as pd
from dotenv import load_dotenv
import faiss
from google import genai
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score



# Read key/value pairs from a local .env file (if any) into the process
# environment so the API key does not have to be written in the notebook.
load_dotenv()

# The key stays None when the variable is not set; it is passed to the
# client constructor as-is.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
client = genai.Client(api_key=GEMINI_API_KEY)


In [27]:
# Load the evaluation set from the working directory.
df = pd.read_csv("synthetic_dataset.csv")

# Two parallel lists taken from the same rows: the "sql_prompt" column and
# the "sql" column (the latter is what the evaluation cell compares against).
questions, gold_sql = (df[column].tolist() for column in ("sql_prompt", "sql"))


In [31]:
#  Helper => LLM SQL generator

def generate_sql(prompt, model="gemini-3-flash-preview"):
    """Ask the Gemini client for SQL answering ``prompt``.

    Args:
        prompt: Text placed after ``Question:`` in the request. Callers pass
            either a bare question or a larger prompt that already embeds
            examples / schema information.
        model: Name of the model to query. Defaults to the model this
            notebook originally hard-coded, so existing calls are unchanged.

    Returns:
        The text of the model's reply with surrounding whitespace removed,
        or an empty string when the reply carries no text.
    """
    response = client.models.generate_content(
        model=model,
        contents=f"You generate only SQL.\nQuestion: {prompt}\nSQL:",
    )
    # `response.text` can be None (for example when the reply has no text
    # part); calling .strip() on it directly would raise AttributeError and
    # abort a whole evaluation run part-way through.
    return (response.text or "").strip()


In [None]:
# Smoke test: one request to confirm the client and helper work end to end.
result = generate_sql(
    "What is the total sales for each product category in the last quarter?"
)
print(result)


In [6]:
# DFE retrieval (simple example)

def retrieve_few_shot(question, df, k=2):
    """Select up to ``k`` example rows whose prompt contains the question's first word.

    The match is a case-insensitive, literal substring test of the first
    whitespace-separated word of ``question`` against ``df["sql_prompt"]``.

    Args:
        question: The natural-language question to find examples for.
        df: DataFrame with a ``sql_prompt`` column; all of its columns are
            kept in the result.
        k: Maximum number of rows to return.

    Returns:
        A DataFrame holding the first ``k`` matching rows in ``df`` order.
        It is empty (same columns, zero rows) when the question has no words
        or nothing matches.
    """
    # A missing or whitespace-only question has no first word; indexing
    # split()[0] would raise IndexError.
    words = question.split() if isinstance(question, str) else []
    if not words:
        return df.head(0)

    prompts = df["sql_prompt"]
    # regex=False: the word is user text, not a pattern. Characters such as
    # "(", "[", "+" or "?" must be matched literally instead of crashing or
    # silently changing what is matched.
    shares_first_word = prompts.str.contains(
        words[0], case=False, na=False, regex=False
    )
    # Never return the row for the question being asked: its "sql" value is
    # the answer itself, so including it leaks the label into the prompt when
    # the questions are drawn from the same frame.
    is_same_question = prompts == question
    return df[shares_first_word & ~is_same_question].head(k)


In [7]:
# DANKE schema + join hint

def danke_hint():
    """Return the fixed schema-and-join text used as the "DANKE" prompt context.

    The text names two tables with their columns and one join condition. It
    starts and ends with a newline.
    """
    hint_lines = (
        "",
        "Tables:",
        "salesperson(salesperson_id, name)",
        "timber_sales(salesperson_id, volume)",
        "",
        "Join:",
        "timber_sales.salesperson_id = salesperson.salesperson_id",
        "",
    )
    return "\n".join(hint_lines)


In [None]:
################################## Four generation modes ##########################

# LLM only (no few-shot examples, no schema hint)
def run_llm():
    """Baseline mode: send each question to the model with no added context.

    Returns:
        A list with one generated SQL string per entry of the module-level
        ``questions`` list, in the same order.
    """
    preds = []
    for q in questions:
        preds.append(generate_sql(q))
    return preds


# LLM + DFE (dynamic few-shot example retrieval)
def run_llm_dfe():
    """Few-shot mode: prepend retrieved example rows to each question.

    For every entry of the module-level ``questions`` list, example rows are
    looked up in ``df`` with ``retrieve_few_shot`` and rendered as a text
    table of their ``sql_prompt`` and ``sql`` columns.

    Returns:
        A list with one generated SQL string per question, in order.
    """
    preds = []
    for q in questions:
        examples = retrieve_few_shot(q, df)

        if examples.empty:
            # to_string() on an empty frame renders an "Empty DataFrame /
            # Columns / Index" message; sending that to the model is noise,
            # so the Examples section is left out entirely.
            examples_section = ""
        else:
            table = examples[["sql_prompt", "sql"]].to_string(index=False)
            examples_section = f"Examples:\n{table}\n\n"

        # With examples present this is the same prompt text as before.
        prompt = f"\n{examples_section}Question: {q}\nSQL:\n"
        preds.append(generate_sql(prompt))
    return preds




# LLM + DANKE

def run_llm_danke():
    """Schema-hint mode: put the ``danke_hint()`` text in front of each question.

    Returns:
        A list with one generated SQL string per entry of the module-level
        ``questions`` list, in the same order.
    """
    # The hint does not depend on the question, so it is fetched once.
    schema_hint = danke_hint()

    def _with_schema(question):
        # Same layout for every request: database info first, then the question.
        return f"\nDatabase info:\n{schema_hint}\n\nQuestion: {question}\nSQL:\n"

    return [generate_sql(_with_schema(question)) for question in questions]




# LLM + DFE + DANKE

def run_full():
    """Combined mode: schema hint plus retrieved example rows for each question.

    For every entry of the module-level ``questions`` list the prompt holds
    the ``danke_hint()`` text, then (when any were found) the example rows
    from ``retrieve_few_shot`` rendered as a text table, then the question.

    Returns:
        A list with one generated SQL string per question, in order.
    """
    preds = []
    hint = danke_hint()

    for q in questions:
        examples = retrieve_few_shot(q, df)

        if examples.empty:
            # to_string() on an empty frame renders an "Empty DataFrame /
            # Columns / Index" message; sending that to the model is noise,
            # so the Examples section is left out entirely.
            examples_section = ""
        else:
            table = examples[["sql_prompt", "sql"]].to_string(index=False)
            examples_section = f"Examples:\n{table}\n\n"

        # With examples present this is the same prompt text as before.
        prompt = (
            f"\nDatabase info:\n{hint}\n\n"
            f"{examples_section}Question: {q}\nSQL:\n"
        )
        preds.append(generate_sql(prompt))
    return preds


In [20]:
def evaluate(preds, gold):
    """Score generated SQL against reference SQL by normalised exact match.

    A prediction counts as correct when it equals its reference after
    collapsing whitespace runs, dropping trailing semicolons and ignoring
    case. Entries that are not strings never count as correct.

    Note on the metrics: every reference label is 1, so there can be no
    false positives. ``recall`` therefore equals ``accuracy`` (the match
    rate), and ``precision`` is 1.0 as soon as one prediction matches and
    0.0 otherwise. ``accuracy`` is the figure that carries information.

    Args:
        preds: Generated SQL strings, one per question.
        gold: Reference SQL strings, aligned with ``preds``.

    Returns:
        Dict with the keys ``accuracy``, ``precision``, ``recall``, ``f1``.

    Raises:
        ValueError: If ``preds`` and ``gold`` differ in length.
    """
    preds = list(preds)
    gold = list(gold)
    # zip() would silently ignore surplus predictions, scoring a misaligned
    # run as if it were valid; fail loudly instead.
    if len(preds) != len(gold):
        raise ValueError(
            f"preds and gold must have the same length, "
            f"got {len(preds)} and {len(gold)}"
        )

    def _normalise(sql):
        # Layout differences (newlines, indentation, a final ';') do not
        # change what a query means, so they must not fail the comparison.
        return " ".join(sql.split()).rstrip("; ").lower()

    def _matches(pred, ref):
        # A missing value (None, or NaN from an empty CSV cell) has no
        # .strip()/.split(); treat it as a miss rather than crashing.
        if not (isinstance(pred, str) and isinstance(ref, str)):
            return False
        return _normalise(pred) == _normalise(ref)

    y_true = [1] * len(gold)
    y_pred = [int(_matches(p, g)) for p, g in zip(preds, gold)]

    # zero_division=0 gives the same numbers as the default when nothing
    # matches, without emitting an undefined-metric warning per call.
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
    }


In [None]:
# Run the four generation modes in turn and score each against gold_sql.
# Entries are stored one at a time, so modes that already finished remain in
# `results` if a later one raises.
results = {}
for label, runner in (
    ("LLM", run_llm),
    ("LLM + DFE", run_llm_dfe),
    ("LLM + DANKE", run_llm_danke),
    ("LLM + DFE + DANKE", run_full),
):
    results[label] = evaluate(runner(), gold_sql)

# One row per mode, one column per metric.
print(pd.DataFrame(results).T)
