In [5]:
import os
import faiss
import numpy as np
from dotenv import load_dotenv
from google import genai
import google.generativeai as gemini

from datasets import load_dataset

from sentence_transformers import SentenceTransformer

load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
client = genai.Client(api_key=GEMINI_API_KEY)

In [None]:
dataset = load_dataset("gretelai/synthetic_text_to_sql", split="train")
dataset = dataset.select(range(100))  

In [40]:
import pandas as pd
print(type(dataset))
df = pd.DataFrame(dataset)
df.head()

<class 'datasets.arrow_dataset.Dataset'>


Unnamed: 0,id,domain,domain_description,sql_complexity,sql_complexity_description,sql_task_type,sql_task_type_description,sql_prompt,sql_context,sql,sql_explanation
0,5097,forestry,Comprehensive data on sustainable forest manag...,single join,"only one join (specify inner, outer, cross)",analytics and reporting,"generating reports, dashboards, and analytical...",What is the total volume of timber sold by eac...,"CREATE TABLE salesperson (salesperson_id INT, ...","SELECT salesperson_id, name, SUM(volume) as to...","Joins timber_sales and salesperson tables, gro..."
1,5098,defense industry,"Defense contract data, military equipment main...",aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",List all the unique equipment types and their ...,CREATE TABLE equipment_maintenance (equipment_...,"SELECT equipment_type, SUM(maintenance_frequen...",This query groups the equipment_maintenance ta...
2,5099,marine biology,"Comprehensive data on marine species, oceanogr...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",How many marine species are found in the South...,"CREATE TABLE marine_species (name VARCHAR(50),...",SELECT COUNT(*) FROM marine_species WHERE loca...,This query counts the number of marine species...
3,5100,financial services,Detailed financial data including investment s...,aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",What is the total trade value and average pric...,"CREATE TABLE trade_history (id INT, trader_id ...","SELECT trader_id, stock, SUM(price * quantity)...",This query calculates the total trade value an...
4,5101,energy,Energy market data covering renewable energy s...,window functions,"window functions (e.g., ROW_NUMBER, LEAD, LAG,...",analytics and reporting,"generating reports, dashboards, and analytical...",Find the energy efficiency upgrades with the h...,"CREATE TABLE upgrades (id INT, cost FLOAT, typ...","SELECT type, cost FROM (SELECT type, cost, ROW...",The SQL query uses the ROW_NUMBER function to ...


In [48]:
df.to_csv("synthetic_dataset.csv", index=False)

In [None]:
embedder = SentenceTransformer('all-MiniLM-L6-v2')  


Loading weights: 100%|██████████| 103/103 [00:00<00:00, 887.19it/s, Materializing param=pooler.dense.weight]                             
[1mBertModel LOAD REPORT[0m from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


In [8]:
embeddings = embedder.encode(dataset["sql_prompt"], convert_to_numpy=True)
dim = embeddings.shape[1]
dim

384

In [9]:
index = faiss.IndexFlatL2(dim)
index.add(embeddings)


In [10]:
user_question = "Which customers placed more than 5 orders last year?"
user_emb = embedder.encode([user_question], convert_to_numpy=True)

In [None]:
k = 1
distances, indices = index.search(user_emb, k)


{'id': 5132,
 'domain': 'cybersecurity',
 'domain_description': 'Threat intelligence data, vulnerability assessments, security incident response metrics, and cybersecurity policy analysis.',
 'sql_complexity': 'aggregation',
 'sql_complexity_description': 'aggregation functions (COUNT, SUM, AVG, MIN, MAX, etc.), and HAVING clause',
 'sql_task_type': 'analytics and reporting',
 'sql_task_type_description': 'generating reports, dashboards, and analytical insights',
 'sql_prompt': 'What is the total number of security incidents that have occurred in each region in the past year?',
 'sql_context': 'CREATE TABLE incident_region(id INT, region VARCHAR(50), incidents INT, incident_date DATE);',
 'sql': 'SELECT region, SUM(incidents) as total_incidents FROM incident_region WHERE incident_date > DATE(NOW()) - INTERVAL 365 DATE GROUP BY region;',
 'sql_explanation': 'The SQL query calculates the total number of security incidents that have occurred in each region in the past year using the SUM()

In [22]:
dataset

Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 100
})

In [21]:
example = dataset[indices[0][0]]
example

{'id': 5132,
 'domain': 'cybersecurity',
 'domain_description': 'Threat intelligence data, vulnerability assessments, security incident response metrics, and cybersecurity policy analysis.',
 'sql_complexity': 'aggregation',
 'sql_complexity_description': 'aggregation functions (COUNT, SUM, AVG, MIN, MAX, etc.), and HAVING clause',
 'sql_task_type': 'analytics and reporting',
 'sql_task_type_description': 'generating reports, dashboards, and analytical insights',
 'sql_prompt': 'What is the total number of security incidents that have occurred in each region in the past year?',
 'sql_context': 'CREATE TABLE incident_region(id INT, region VARCHAR(50), incidents INT, incident_date DATE);',
 'sql': 'SELECT region, SUM(incidents) as total_incidents FROM incident_region WHERE incident_date > DATE(NOW()) - INTERVAL 365 DATE GROUP BY region;',
 'sql_explanation': 'The SQL query calculates the total number of security incidents that have occurred in each region in the past year using the SUM()

In [47]:
user_question = "Show all customers from New York who made orders above 500."
user_emb = embedder.encode([user_question], convert_to_numpy=True)
k = 2
distances, indices = index.search(user_emb, k)
results = dataset[indices[0][1]]
results


{'id': 5171,
 'domain': 'civil engineering',
 'domain_description': 'Infrastructure development data, engineering design standards, public works project information, and resilience metrics.',
 'sql_complexity': 'basic SQL',
 'sql_complexity_description': 'basic SQL with a simple select statement',
 'sql_task_type': 'analytics and reporting',
 'sql_task_type_description': 'generating reports, dashboards, and analytical insights',
 'sql_prompt': 'What is the name, height, and number of stories for all buildings in the city of New York with more than 50 floors?',
 'sql_context': "CREATE TABLE Buildings (id INT, name VARCHAR(100), height FLOAT, num_stories INT, city VARCHAR(50)); INSERT INTO Buildings (id, name, height, num_stories, city) VALUES (1, 'Empire State Building', 381, 102, 'New York');",
 'sql': "SELECT name, height, num_stories FROM Buildings WHERE city = 'New York' AND num_stories > 50;",
 'sql_explanation': "This query selects the name, height, and num_stories columns from th

In [12]:
schema_info = """
Tables:
Customer(id, name, city)
Orders(id, customer_id, amount, order_date)
"""

prompt = f"""
You are an AI assistant that generates SQL queries.

Database schema:
{schema_info}

Example question: {example['sql_prompt']}
Example SQL: {example['sql']}

User question: {user_question}
Generate the corresponding SQL query:
"""

In [16]:
response = client.models.generate_content(
    model="gemini-3-flash-preview", contents=prompt, config={"temperature": 0.2}
)
print(response.text)


```sql
SELECT c.name
FROM Customer c
JOIN Orders o ON c.id = o.customer_id
WHERE o.order_date >= DATE(NOW()) - INTERVAL 1 YEAR
GROUP BY c.id, c.name
HAVING COUNT(o.id) > 5;
```
