
1. **User Question** → input

2. **Schema Linking**:
    * Extract keywords (tables/columns/values)
    * Retrieve most similar example from synthetic dataset
    * Identify minimal tables required

3. **SQL Compilation**:
    * Create a “view” of these tables
    * Provide example SQL
    * Ask Gemini to generate final SQL

4. **Output** → SQL ready to run


In [None]:
"""
Pipeline for Text-to-SQL using Gemini:

1. User Question → input

2. Schema Linking:
    * Extract keywords (tables/columns/values)
    * Retrieve most similar example from synthetic dataset
    * Identify minimal tables required

3. SQL Compilation:
    * Create a “view” of these tables
    * Provide example SQL
    * Ask Gemini to generate final SQL

4. Output → SQL ready to run
"""



In [3]:
import os
import numpy as np
import pandas as pd
from dotenv import load_dotenv
import faiss
from google import genai

load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
client = genai.Client(api_key=GEMINI_API_KEY)

In [4]:
dataset = pd.read_csv("synthetic_dataset.csv")   

In [5]:
dataset

Unnamed: 0,id,domain,domain_description,sql_complexity,sql_complexity_description,sql_task_type,sql_task_type_description,sql_prompt,sql_context,sql,sql_explanation
0,5097,forestry,Comprehensive data on sustainable forest manag...,single join,"only one join (specify inner, outer, cross)",analytics and reporting,"generating reports, dashboards, and analytical...",What is the total volume of timber sold by eac...,"CREATE TABLE salesperson (salesperson_id INT, ...","SELECT salesperson_id, name, SUM(volume) as to...","Joins timber_sales and salesperson tables, gro..."
1,5098,defense industry,"Defense contract data, military equipment main...",aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",List all the unique equipment types and their ...,CREATE TABLE equipment_maintenance (equipment_...,"SELECT equipment_type, SUM(maintenance_frequen...",This query groups the equipment_maintenance ta...
2,5099,marine biology,"Comprehensive data on marine species, oceanogr...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",How many marine species are found in the South...,"CREATE TABLE marine_species (name VARCHAR(50),...",SELECT COUNT(*) FROM marine_species WHERE loca...,This query counts the number of marine species...
3,5100,financial services,Detailed financial data including investment s...,aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",What is the total trade value and average pric...,"CREATE TABLE trade_history (id INT, trader_id ...","SELECT trader_id, stock, SUM(price * quantity)...",This query calculates the total trade value an...
4,5101,energy,Energy market data covering renewable energy s...,window functions,"window functions (e.g., ROW_NUMBER, LEAD, LAG,...",analytics and reporting,"generating reports, dashboards, and analytical...",Find the energy efficiency upgrades with the h...,"CREATE TABLE upgrades (id INT, cost FLOAT, typ...","SELECT type, cost FROM (SELECT type, cost, ROW...",The SQL query uses the ROW_NUMBER function to ...
...,...,...,...,...,...,...,...,...,...,...,...
95,5192,disability services,Comprehensive data on disability accommodation...,basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",List the programs and their budgets for mobili...,"CREATE TABLE Programs (Program VARCHAR(20), Bu...","SELECT Program, Budget FROM Programs WHERE Typ...",This SQL query lists the programs and their bu...
96,5193,space exploration,"Spacecraft manufacturing data, space mission r...",aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",List the space missions that have had astronau...,CREATE TABLE SpaceMissions (mission_name VARCH...,SELECT mission_name FROM SpaceMissions WHERE a...,This query lists the space missions that have ...
97,5194,oceans,"Ocean data on marine conservation, ocean acidi...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",What is the average depth of all marine protec...,CREATE TABLE marine_protected_areas (name VARC...,SELECT AVG(avg_depth) FROM marine_protected_ar...,This query calculates the average depth of all...
98,5195,oil and gas,"Exploration data, production figures, infrastr...",aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",Calculate the total production in the Southern...,"CREATE TABLE production (well_id INT, type VAR...","SELECT type, SUM(quantity) as total_production...",This SQL query groups the production data by o...


In [6]:
from sentence_transformers import SentenceTransformer

# Load any embedding model
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# Convert dataset questions to embeddings
corpus_embeddings = embedder.encode(dataset['sql_prompt'].tolist(), convert_to_numpy=True)

# Build FAISS index
dimension = corpus_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(corpus_embeddings)


  from .autonotebook import tqdm as notebook_tqdm
Loading weights: 100%|██████████| 103/103 [00:00<00:00, 954.32it/s, Materializing param=pooler.dense.weight]                             
[1mBertModel LOAD REPORT[0m from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


In [None]:
def get_similar_example(user_question, k=1):
    user_emb = embedder.encode([user_question], convert_to_numpy=True)
    distances, indices = index.search(user_emb, k)
    return dataset.iloc[indices[0][0]]   


In [9]:
def schema_linking(user_question, database_schema, k=1):
    """
    - Retrieves keywords using LLM (Gemini)
    - Retrieves top-k similar examples from synthetic dataset
    """
    #  Extract keywords using Gemini
    keyword_prompt = f"""
    Extract important keywords (tables, columns, values) from this question:
    Database schema: {database_schema}
    Question: {user_question}
    Return as list: ["keyword1", "keyword2", ...]
    """
    response = client.models.generate_content(
        model="gemini-3-flash-preview",
        contents=keyword_prompt,
        config={"temperature": 0}
    )
    KM = response.text.strip()  # keywords

    # Retrieve most similar example
    example = get_similar_example(user_question, k=k)
    
    # Extract tables from example SQL (simple parsing for FROM clause)
    sql_from = example['sql'].split('FROM')[1].split()[0].replace(',', '')
    T = [sql_from]

    # Step 4: Return schema linking info
    return KM, T, example



In [10]:
def sql_query_compilation(user_question, KM, tables, example, database_schema):
    """
    Compiles final SQL using:
    - View synthesis (tables)
    - Few-shot example (example['sql'])
    """
    # Step 1: Create a view prompt for Gemini
    view_prompt = f"""
    You are an AI assistant that generates SQL.
    Database schema: {database_schema}
    Tables selected: {tables}
    Example question: {example['sql_prompt']}
    Example SQL: {example['sql']}
    User question: {user_question}
    Generate the SQL query using only the selected tables.
    Return ONLY SQL.
    """
    
    response = client.models.generate_content(
        model="gemini-3-flash-preview",
        contents=view_prompt,
        config={"temperature": 0.2}
    )
    return response.text.strip()


In [11]:
# Example database schema
database_schema = """
Tables:
Customer(id, name, city)
Orders(id, customer_id, amount, order_date)
"""

user_question = "Show all customers from New York who made orders above 500."

# Schema Linking
KM, tables, example = schema_linking(user_question, database_schema)

# SQL Compilation
final_sql = sql_query_compilation(user_question, KM, tables, example, database_schema)
print("Generated SQL:\n", final_sql)


Generated SQL:
 SELECT DISTINCT Customer.id, Customer.name, Customer.city 
FROM Customer 
JOIN Orders ON Customer.id = Orders.customer_id 
WHERE Customer.city = 'New York' AND Orders.amount > 500;
