In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

In [2]:
def load_model(model_path):
    """Load a tokenizer and a seq2seq model from the same checkpoint.

    Args:
        model_path: Identifier handed unchanged to both ``from_pretrained``
            calls (for example a checkpoint directory).

    Returns:
        A ``(tokenizer, model)`` tuple. ``eval()`` has already been called on
        the model, so it is in evaluation mode when returned.
    """
    # The tuple order fixes the load order: tokenizer first, then the model.
    loaders = (AutoTokenizer, AutoModelForSeq2SeqLM)
    text_tokenizer, seq2seq = (
        loader.from_pretrained(model_path) for loader in loaders
    )
    seq2seq.eval()
    return text_tokenizer, seq2seq

In [3]:
def generate_sql(model, tokenizer, sql_prompt, sql_context, max_length=128,
                 max_input_length=512, min_new_tokens=5):
    """Generate a SQL query for a natural-language request and its schema context.

    The model input is the single string
    ``"sql_prompt: <sql_prompt> sql_context: <sql_context>"``.

    Args:
        model: Object exposing ``generate(**inputs, ...)``. If it has a
            ``device`` attribute, the encoded inputs are moved to that device
            before generation.
        tokenizer: Callable that encodes the input text and exposes
            ``decode(token_ids, skip_special_tokens=...)``.
        sql_prompt: Natural-language description of the desired query.
        sql_context: Schema / data context inserted after the prompt.
        max_length: Upper bound on *newly generated* tokens. Despite the name
            it is forwarded as ``max_new_tokens``, not as a total length.
        max_input_length: Token limit applied when truncating the encoded
            input (previously hard-coded to 512).
        min_new_tokens: Lower bound on newly generated tokens (previously
            hard-coded to 5).

    Returns:
        The first generated sequence decoded to text with special tokens
        skipped.
    """
    input_text = f"sql_prompt: {sql_prompt} sql_context: {sql_context}"
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )

    # Keep the inputs on the same device as the model; otherwise generation
    # fails with a device mismatch once the model has been moved to a GPU.
    device = getattr(model, "device", None)
    if device is not None and hasattr(inputs, "to"):
        inputs = inputs.to(device)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=max_length,
            min_new_tokens=min_new_tokens,
        )
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

In [4]:
# Load the fine-tuned tokenizer/model pair used by every evaluation cell below.
# NOTE(review): "nl2sql_epoch2" is presumably a local checkpoint directory - confirm.
tokenizer, model = load_model(model_path="nl2sql_epoch2")

In [5]:
# Load the text-to-SQL dataset. No split is requested here; the evaluation
# cells below pick examples out of dataset["test"].
from datasets import load_dataset
dataset = load_dataset(path="gretelai/synthetic_text_to_sql")

In [8]:
# Run the model on test example 0 and show its SQL next to the reference query.
index = 0
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: What is the average explainability score of creative AI applications in 'Europe' and 'North America' in the 'creative_ai' table?

SQL Context: CREATE TABLE creative_ai (application_id INT, name TEXT, region TEXT, explainability_score FLOAT); INSERT INTO creative_ai (application_id, name, region, explainability_score) VALUES (1, 'ApplicationX', 'Europe', 0.87), (2, 'ApplicationY', 'North America', 0.91), (3, 'ApplicationZ', 'Europe', 0.84), (4, 'ApplicationAA', 'North America', 0.93), (5, 'ApplicationAB', 'Europe', 0.89);

Correct SQL: SELECT AVG(explainability_score) FROM creative_ai WHERE region IN ('Europe', 'North America');

Generated SQL: SELECT AVG(explainability_score) FROM creative_ai WHERE region IN ('Europe', 'North America');


In [9]:
# Run the model on test example 1 and show its SQL next to the reference query.
index = 1
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: Delete all records of rural infrastructure projects in Indonesia that have a completion date before 2010.

SQL Context: CREATE TABLE rural_infrastructure (id INT, project_name TEXT, sector TEXT, country TEXT, completion_date DATE); INSERT INTO rural_infrastructure (id, project_name, sector, country, completion_date) VALUES (1, 'Water Supply Expansion', 'Infrastructure', 'Indonesia', '2008-05-15'), (2, 'Rural Electrification', 'Infrastructure', 'Indonesia', '2012-08-28'), (3, 'Transportation Improvement', 'Infrastructure', 'Indonesia', '2009-12-31');

Correct SQL: DELETE FROM rural_infrastructure WHERE country = 'Indonesia' AND completion_date < '2010-01-01';

Generated SQL: DELETE FROM rural_infrastructure WHERE country = 'Indonesia' AND completion_date < '2010-01-01';


In [11]:
# Run the model on test example 100 and show its SQL next to the reference query.
index = 100
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: Show the sum of investments by year and industry

SQL Context: CREATE TABLE investments (id INT, investment_year INT, industry VARCHAR(255), investment_amount DECIMAL(10,2)); INSERT INTO investments (id, investment_year, industry, investment_amount) VALUES (1, 2020, 'Tech', 50000.00), (2, 2019, 'Biotech', 20000.00), (3, 2020, 'Tech', 75000.00);

Correct SQL: SELECT investment_year, industry, SUM(investment_amount) as total_investments FROM investments GROUP BY investment_year, industry;

Generated SQL: SELECT investment_year, industry, SUM(investment_amount) as total_investment FROM investments GROUP BY investment_ year, industry;


In [13]:
# Run the model on test example 150 and show its SQL next to the reference query.
index = 150
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: What is the total number of customer complaints regarding mobile and broadband services by region?

SQL Context: CREATE TABLE complaints (complaint_id INT, complaint_type VARCHAR(255), region VARCHAR(255)); INSERT INTO complaints (complaint_id, complaint_type, region) VALUES (1, 'Mobile', 'North'), (2, 'Broadband', 'South'), (3, 'Mobile', 'East'), (4, 'Broadband', 'West'), (5, 'Mobile', 'North'), (6, 'Broadband', 'South'), (7, 'Mobile', 'East'), (8, 'Broadband', 'West');

Correct SQL: SELECT region, COUNT(*) AS total_complaints FROM complaints WHERE complaint_type IN ('Mobile', 'Broadband') GROUP BY region;

Generated SQL: SELECT region, COUNT(*) as total_complaints FROM complaints WHERE complaint_type IN ('Mobile', 'Broadband') GROUP BY region;


In [14]:
# Run the model on test example 171 and show its SQL next to the reference query.
index = 171
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: What is the average severity of vulnerabilities in the 'Malware' category?

SQL Context: CREATE TABLE vulnerabilities (id INT, name TEXT, category TEXT, severity TEXT, date_discovered DATE); INSERT INTO vulnerabilities (id, name, category, severity, date_discovered) VALUES (1, 'Remote Code Execution', 'Malware', 'Critical', '2022-01-01');

Correct SQL: SELECT AVG(severity = 'Critical') + AVG(severity = 'High') * 0.75 + AVG(severity = 'Medium') * 0.5 + AVG(severity = 'Low') * 0.25 as average FROM vulnerabilities WHERE category = 'Malware';

Generated SQL: SELECT AVG(severity) FROM vulnerabilities WHERE category = 'Malware';


In [15]:
# Run the model on test example 212 and show its SQL next to the reference query.
index = 212
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: How many research grants were awarded to the Computer Science department in the year 2020?

SQL Context: CREATE TABLE grant (id INT, department VARCHAR(50), amount INT, grant_date DATE); INSERT INTO grant (id, department, amount, grant_date) VALUES (1, 'Computer Science', 50000, '2020-01-01'), (2, 'Computer Science', 75000, '2020-04-15'), (3, 'Mechanical Engineering', 60000, '2019-12-31');

Correct SQL: SELECT COUNT(*) FROM grant WHERE department = 'Computer Science' AND YEAR(grant_date) = 2020;

Generated SQL: SELECT COUNT(*) FROM grant WHERE department = 'Computer Science' AND YEAR(grant_date) = 2020;


In [16]:
# Run the model on test example 243 and show its SQL next to the reference query.
index = 243
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: Which cybersecurity policies in the 'cybersecurity_policies' table were last updated on a specific date?

SQL Context: CREATE TABLE cybersecurity_policies (id INT PRIMARY KEY, policy_name TEXT, policy_text TEXT, last_updated DATE);

Correct SQL: SELECT policy_name, last_updated FROM cybersecurity_policies WHERE last_updated = '2022-01-01';

Generated SQL: SELECT policy_name FROM cybersecurity_policies WHERE last_updated = '2022-01-01';


In [17]:
# Run the model on test example 301 and show its SQL next to the reference query.
index = 301
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: What was the average funding for 'Climate Change' initiatives provided by the US in 2021?

SQL Context: CREATE TABLE USFunding (Funder VARCHAR(50), Sector VARCHAR(50), FundingAmount NUMERIC(15,2), Year INT); INSERT INTO USFunding (Funder, Sector, FundingAmount, Year) VALUES ('US', 'Climate Change', 450000, 2021), ('US', 'Climate Change', 500000, 2021), ('US', 'Climate Change', 350000, 2021);

Correct SQL: SELECT AVG(FundingAmount) FROM USFunding WHERE Sector = 'Climate Change' AND Year = 2021 AND Funder = 'US';

Generated SQL: SELECT AVG(FundingAmount) FROM USFunding WHERE Funder = 'US' AND Sector = 'Climate Change' AND Year = 2021;


In [19]:
# Run the model on test example 334 and show its SQL next to the reference query.
index = 334
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: List the policy numbers, claim amounts, and claim dates for policies that have more than two claims and the total claim amount exceeds $5000

SQL Context: CREATE TABLE policies (policy_number INT);CREATE TABLE claims (claim_id INT, policy_number INT, claim_amount DECIMAL(10,2), claim_date DATE);

Correct SQL: SELECT p.policy_number, c.claim_amount, c.claim_date FROM policies p INNER JOIN claims c ON p.policy_number = c.policy_number GROUP BY p.policy_number, c.claim_amount, c.claim_date HAVING COUNT(c.claim_id) > 2 AND SUM(c.claim_amount) > 5000;

Generated SQL: SELECT policies.policy_number, claims.claim_amount, claim_date FROM policies INNER JOIN claims ON policy_number IN (SELECT policies.* FROM claims GROUP BY policies.* HAVING COUNT(claim_id) > 2) AND claims.total_claims > 5000;


In [21]:
# Run the model on test example 410 and show its SQL next to the reference query.
index = 410
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: What is the total number of vehicles sold in 'sales_data' view that have a speed greater than or equal to 80 mph?

SQL Context: CREATE VIEW sales_data AS SELECT id, vehicle_type, avg_speed, sales FROM vehicle_sales WHERE sales > 20000;

Correct SQL: SELECT SUM(sales) FROM sales_data WHERE avg_speed >= 80;

Generated SQL: SELECT SUM(sales) FROM sales_data WHERE avg_speed >= 80;


In [22]:
# Run the model on test example 420 and show its SQL next to the reference query.
index = 420
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: What is the combined landfill capacity for 'City A' and 'City B'?

SQL Context: CREATE TABLE landfill_capacity (city VARCHAR(255), capacity INT); INSERT INTO landfill_capacity (city, capacity) VALUES ('City A', 500000), ('City B', 600000);

Correct SQL: SELECT SUM(capacity) FROM (SELECT capacity FROM landfill_capacity WHERE city = 'City A' UNION ALL SELECT capacity FROM landfill_capacity WHERE city = 'City B') AS combined_capacity;

Generated SQL: SELECT SUM(capacity) FROM landfill_capacity WHERE city IN ('City A', 'City B');


In [24]:
# Run the model on test example 520 and show its SQL next to the reference query.
index = 520
test_sample = dataset["test"][index]

# Question, schema/context, and reference SQL for this example.
sql_prompt, sql_context, correct_sql = (
    test_sample[key] for key in ("sql_prompt", "sql_context", "sql")
)
predicted_sql = generate_sql(model, tokenizer, sql_prompt, sql_context)

# sep="\n" together with each trailing "\n" leaves one blank line after every field.
print(
    f"SQL Prompt: {sql_prompt}\n",
    f"SQL Context: {sql_context}\n",
    f"Correct SQL: {correct_sql}\n",
    sep="\n",
)
print("Generated SQL:", predicted_sql)

SQL Prompt: How many vessels have not had an inspection in the past year?

SQL Context: CREATE TABLE safety_records(id INT, vessel_name VARCHAR(50), inspection_date DATE); CREATE TABLE vessels(id INT, name VARCHAR(50), country VARCHAR(50)); INSERT INTO vessels(id, name, country) VALUES (1, 'Vessel A', 'Philippines'), (2, 'Vessel B', 'Philippines'); INSERT INTO safety_records(id, vessel_name, inspection_date) VALUES (1, 'Vessel A', '2022-01-01');

Correct SQL: SELECT COUNT(*) FROM vessels WHERE name NOT IN (SELECT vessel_name FROM safety_records WHERE inspection_date BETWEEN DATE_SUB(NOW(), INTERVAL 1 YEAR) AND NOW());

Generated SQL: SELECT COUNT(*) FROM vessels WHERE id NOT IN (SELECT vessel_id FROM safety_records WHERE inspection_date >= DATE_SUB(CURRENT_DATE, INTERVAL 1 YEAR));
