# Agent C: SQL Generator
`Generates an SQL query tailored to a natural-language question and the selected database schema.`

### Table of Contents
1. [OpenAI Setup](#OpenAI-Setup)
2. [Testing Setup](#testing-setup) 
3. [Generate SQL](#generate-sql)
4. [Test Results](#Test-results)
5. [Putting it all together](#Putting-it-all-together)

In [32]:
# Imports
import sys
import os
import re
import json
import openai
import random
import pandas as pd
from pathlib import Path
from typing import Union

In [2]:
# Resolve the repository root. This notebook lives three levels below it
# (notebooks/agent-development/Laine), so walk up three parents.
# Guarded so that re-running this cell after the os.chdir() below does not walk up
# three *more* levels from the already-changed working directory.
if "PROJECT_ROOT" not in globals():
    if "__file__" in globals():
        PROJECT_ROOT = Path(__file__).parent.parent.parent
    else:
        PROJECT_ROOT = Path.cwd().parent.parent.parent

# sys.path entries must be strings (non-str entries are ignored during import);
# skip the append if it is already there so re-runs don't accumulate duplicates.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))
os.chdir(PROJECT_ROOT)


Current working directory: /Users/lainemulvay/Desktop/Projects/UNI/Capstone/Explainable-Query-Interface-for-Relational-Databases
Notebook folder: /Users/lainemulvay/Desktop/Projects/UNI/Capstone/Explainable-Query-Interface-for-Relational-Databases


___

## OpenAI Setup

Import OpenAI API Key

In [3]:
from src import config

# Fail early with a clear message instead of a TypeError from slicing None below.
if not config.OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY is not set - add it to your .env file")

# Only show a short prefix so the full secret never ends up in the notebook output.
print("OpenAI API key used:", config.OPENAI_API_KEY[:5] + "****")

OpenAI API key used: sk-pr****


Make a simple chat-completion request to confirm the API key works.

In [4]:
# Point the client at the key loaded from .env, then send a one-message smoke test.
openai.api_key = config.OPENAI_API_KEY

hello_message = {"role": "user", "content": "Hello!"}
response = openai.chat.completions.create(model="gpt-4", messages=[hello_message])

# Show the assistant's reply to confirm the request round-tripped.
reply_text = response.choices[0].message.content
print(reply_text)

Hello! How can I assist you today?


---

# Testing Setup

### Setting up db_id, natural language query and sql answers for testing.

In [5]:
# Paths: Spider training examples (input) and the simplified copy written below (output)
SQL_DATA_PATH = PROJECT_ROOT.joinpath("data", "spider_data", "train_spider.json")
OUTPUT_PATH = PROJECT_ROOT.joinpath("data", "interim", "query_sql_answers.json")


def load_sql_dataset(file_path: Union[str, Path]) -> pd.DataFrame:
    """
    Load the Spider dataset and return a DataFrame of (db_id, query, question).

    Args:
        file_path: Path to a Spider JSON file (e.g. ``train_spider.json``) whose
            top-level value is a list of example records.

    Returns:
        A DataFrame with exactly the columns ``db_id``, ``query`` and ``question``,
        one row per record. Keys missing from a record become missing values. An
        empty record list still yields these three columns, so later column
        lookups do not raise ``KeyError``.

    Raises:
        ValueError: if the JSON top-level value is not a list.
    """
    columns = ["db_id", "query", "question"]
    with Path(file_path).open(encoding="utf-8") as f:
        data = json.load(f)  # load entire JSON list

    if not isinstance(data, list):
        raise ValueError(
            f"Expected a JSON list of records in {file_path}, got {type(data).__name__}"
        )

    records = [{col: rec.get(col) for col in columns} for rec in data]
    # Passing columns explicitly keeps the schema stable even when records is empty.
    return pd.DataFrame(records, columns=columns)

df = load_sql_dataset(SQL_DATA_PATH)
print(df.head())

# Save to JSON so later sections can reload the simplified records without re-parsing Spider.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
df.to_json(OUTPUT_PATH, orient="records", indent=2)
# Report the real destination (relative to the project root) rather than a hard-coded path.
print(f"Saved {len(df)} simplified queries to {OUTPUT_PATH.relative_to(PROJECT_ROOT)}")

                   db_id                                              query  \
0  department_management         SELECT count(*) FROM head WHERE age  >  56   
1  department_management  SELECT name ,  born_state ,  age FROM head ORD...   
2  department_management  SELECT creation ,  name ,  budget_in_billions ...   
3  department_management  SELECT max(budget_in_billions) ,  min(budget_i...   
4  department_management  SELECT avg(num_employees) FROM department WHER...   

                                            question  
0  How many heads of the departments are older th...  
1  List the name, born state and age of the heads...  
2  List the creation year, name and budget of eac...  
3  What are the maximum and minimum budget of the...  
4  What is the average number of employees of the...  
Saved simplified queries to /data/interim


___

## Generate SQL

In [14]:
import sys
from pathlib import Path

# Folder holding this notebook, from which agent_utils is imported.
# Built from PROJECT_ROOT rather than the current working directory so it does not
# silently depend on the earlier os.chdir(). CHANGE WHEN MOVED INTO A SCRIPT
NOTEBOOK_FOLDER = PROJECT_ROOT / "notebooks" / "agent-development" / "Laine"
# Avoid piling up duplicate sys.path entries when the cell is re-run.
if str(NOTEBOOK_FOLDER) not in sys.path:
    sys.path.append(str(NOTEBOOK_FOLDER))
print("Notebook folder added to sys.path:", NOTEBOOK_FOLDER)

from agent_utils import get_schema_text, load_schemas, schema_text

Notebook folder added to sys.path: /Users/lainemulvay/Desktop/Projects/UNI/Capstone/Explainable-Query-Interface-for-Relational-Databases/notebooks/agent-development/Laine


#### Get the schema

In [11]:
# Load every Spider schema once; later cells look databases up by db_id.
SCHEMA_PATH = PROJECT_ROOT.joinpath("data", "spider_data", "tables.json")
schemas = load_schemas(SCHEMA_PATH)

# Preview one database's schema as compact "table(col, ...)" lines.
db_id = "college_2"
schema_preview = get_schema_text(db_id, schemas)
print(schema_preview)

classroom(building, room_number, capacity)
department(dept_name, building, budget)
course(course_id, title, dept_name, credits)
instructor(ID, name, dept_name, salary)
section(course_id, sec_id, semester, year, building, room_number, time_slot_id)
teaches(ID, course_id, sec_id, semester, year)
student(ID, name, dept_name, tot_cred)
takes(ID, course_id, sec_id, semester, year, grade)
advisor(s_ID, i_ID)
time_slot(time_slot_id, day, start_hr, start_min, end_hr, end_min)
prereq(course_id, prereq_id)


#### Feed the question and schema to LLM, get back query

In [None]:

def get_schema_text(db_id: str, schemas: dict) -> str:
    """
    Render the schema for ``db_id`` as text using ``schema_text``.

    Note: this redefinition shadows the ``get_schema_text`` imported from
    ``agent_utils`` in an earlier cell.

    Raises:
        ValueError: if ``db_id`` is not a key of ``schemas``.
    """
    try:
        selected = schemas[db_id]
    except KeyError:
        raise ValueError(f"db_id '{db_id}' not found in schemas") from None
    return schema_text(selected)

def _strip_code_fences(text: str) -> str:
    """Remove one surrounding Markdown code fence (e.g. ```sql ... ```) from ``text``, if present."""
    text = text.strip()
    # The opening fence must be followed by a newline, so an inline ```SELECT 1``` is left alone
    # rather than having its first word mistaken for a language tag.
    fenced = re.fullmatch(r"```[\w-]*[ \t]*\n(.*?)\n?[ \t]*```", text, flags=re.DOTALL)
    return fenced.group(1).strip() if fenced else text


def generate_sql_from_llm(db_id: str, question: str, schemas: dict, model: str="gpt-4") -> str:
    """
    Generate a SQL query using GPT from a db_id and a natural language question.

    Args:
        db_id: Spider database id whose schema is put into the prompt.
        question: Natural-language question the SQL should answer.
        schemas: Mapping of db_id -> schema, as returned by ``load_schemas``.
        model: OpenAI chat model name.

    Returns:
        The model's SQL with surrounding whitespace removed and, if present, one
        surrounding Markdown code fence removed (chat models often add one even
        when told to return only SQL, which would break later comparisons).

    Raises:
        ValueError: if ``get_schema_text`` rejects ``db_id``, or the model
            returns no text content.
    """
    schema_str = get_schema_text(db_id, schemas)
    prompt = f"""
You are an SQL expert. Given the following database schema:

{schema_str}

Write a SQL query that answers the following question:

{question}

Only return the SQL query, nothing else.
"""

    response = openai.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )

    content = response.choices[0].message.content
    if content is None:
        raise ValueError(f"Model {model!r} returned no content for question: {question!r}")
    return _strip_code_fences(content)

In [23]:
# Ask the LLM for SQL answering one Spider example question.
question = "How many heads of the departments are older than 56 ?"
db_id = "department_management"

sql_generated = generate_sql_from_llm(db_id=db_id, question=question, schemas=schemas)
print(sql_generated)

SELECT COUNT(*) 
FROM head 
WHERE age > 56;


---

## Test Results

In [19]:
QUERY_JSON_PATH = PROJECT_ROOT.joinpath("data", "interim", "query_sql_answers.json")

# Read the saved records a single time; the lookups below reuse this frame.
df_queries = pd.read_json(QUERY_JSON_PATH)

def get_result(db_id: str, question: str, queries: Union[pd.DataFrame, None] = None) -> str:
    """
    Look up the ground-truth SQL for a (db_id, question) pair.

    Args:
        db_id: Spider database id.
        question: Natural-language question, matched exactly (case- and whitespace-sensitive).
        queries: Frame with ``db_id``, ``question`` and ``query`` columns to search.
            Defaults to the notebook-level ``df_queries`` frame, so existing calls keep working.

    Returns:
        The ``query`` value of the first matching row (if several rows match, only
        the first is returned).

    Raises:
        ValueError: if no row matches.
    """
    source = df_queries if queries is None else queries
    hits = source[(source["db_id"] == db_id) & (source["question"] == question)]

    if hits.empty:
        raise ValueError(f"No query found for db_id='{db_id}' and question='{question}'")

    return hits.iloc[0]["query"]

# Ground-truth SQL for the same example question that was sent to the LLM.
db_id, question = (
    "department_management",
    "How many heads of the departments are older than 56 ?",
)
true_sql_query = get_result(db_id, question)
print(true_sql_query)

SELECT count(*) FROM head WHERE age  >  56


This will not match the LLM's generated SQL because it is formatted differently, despite having the same content.

In [20]:
# Exact string comparison: sensitive to case, whitespace and the trailing semicolon.
exact_match = sql_generated == true_sql_query
print(exact_match)

False


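Therefore, we normalise both queries (e.g. case, whitespace and the trailing semicolon) before comparing them, so that formatting differences are ignored.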

In [28]:
def normalise(sql: str) -> str:
    # Lowercase
    sql = sql.lower()
    # Remove extra whitespace
    sql = re.sub(r"\s+", " ", sql)
    # Strip leading/trailing spaces and optional trailing semicolon
    sql = sql.strip().rstrip(";")
    return sql

def test_llm_query(query1: str, query2: str) -> bool:
    return normalise(query1) == normalise(query2)

# Example usage: the gold query and an LLM-style reformatting of it should compare equal.
test_answer = "SELECT count(*) FROM head WHERE age  >  56"
llm_answer = "SELECT COUNT(*) \nFROM head \nWHERE age > 56;"
assert test_llm_query(test_answer, llm_answer), "normalise() should ignore formatting differences"

# Same check on the live values from the cells above (LLM output may vary between runs).
print(test_llm_query(true_sql_query, sql_generated))

True


---

## Putting it all together

Plan:
1. Randomly pick `n` questions from the JSON file.
2. For each one, take its `db_id`, natural-language question and ground-truth SQL.
3. Get the schema and give it, with the question, to the LLM to generate SQL.
4. Compare the LLM's answer to the actual answer and record whether they match.
5. Print the results.

In [34]:

# Reload the saved queries so this section can run on its own (the same file is read in "Test Results").
QUERY_JSON_PATH = PROJECT_ROOT.joinpath("data", "interim", "query_sql_answers.json")
df_queries = pd.read_json(QUERY_JSON_PATH)

def run_random_tests(n: int = 5, model: str = "gpt-4", random_state: Union[int, None] = None) -> pd.DataFrame:
    """
    Run n random tests comparing LLM-generated SQL against ground-truth SQL.

    Samples rows from the notebook-level ``df_queries`` frame. For each row, it asks
    ``generate_sql_from_llm`` for SQL (using the notebook-level ``schemas``) and
    compares the result with the ground truth via ``test_llm_query``.

    Args:
        n: Number of questions to sample, capped at the number of available rows.
        model: OpenAI chat model passed through to ``generate_sql_from_llm``.
        random_state: Seed for the sample. ``None`` (default) draws a fresh seed on
            every call, as before; pass an int to make a run reproducible.

    Returns:
        One row per sampled question with columns ``db_id``, ``question``,
        ``true_query``, ``llm_query``, ``match`` and ``error`` (``None`` when the
        row succeeded). The columns are present even when no rows are sampled.
    """
    columns = ["db_id", "question", "true_query", "llm_query", "match", "error"]
    if random_state is None:
        random_state = random.randint(0, 9999)
    # DataFrame.sample raises ValueError when asked for more rows than exist.
    sample_rows = df_queries.sample(min(n, len(df_queries)), random_state=random_state)

    results = []

    for _, row in sample_rows.iterrows():
        db_id = row["db_id"]
        question = row["question"]
        true_query = row["query"]

        print("=" * 80)
        print(f"DB: {db_id}")
        print(f"Question: {question}")
        print(f"True SQL: {true_query}")

        record = {
            "db_id": db_id,
            "question": question,
            "true_query": true_query,
            "llm_query": None,
            "match": False,
            "error": None,
        }
        try:
            # Generate SQL from LLM
            llm_query = generate_sql_from_llm(db_id, question, schemas, model=model)
            print(f"LLM SQL:  {llm_query}")

            # Compare normalised queries
            is_match = test_llm_query(true_query, llm_query)
            print(f"Match? {is_match}")

            record["llm_query"] = llm_query
            record["match"] = is_match
        except Exception as e:
            # Keep going: one failed API call should not abort the whole evaluation run.
            print(f"Error: {e}")
            record["error"] = str(e)

        results.append(record)

    return pd.DataFrame(results, columns=columns)


# Example usage: evaluate 10 randomly sampled Spider questions, then show a compact summary.
results_df = run_random_tests(n=10)
summary_columns = ["db_id", "match"]
print("\nSummary:")
print(results_df[summary_columns])

DB: medicine_enzyme_interaction
Question: What are the ids and trade names of the medicine that can interact with at least 3 enzymes?
True SQL: SELECT T1.id ,  T1.trade_name FROM medicine AS T1 JOIN medicine_enzyme_interaction AS T2 ON T2.medicine_id  =  T1.id GROUP BY T1.id HAVING COUNT(*)  >=  3
LLM SQL:  SELECT m.id, m.Trade_Name
FROM medicine m
JOIN medicine_enzyme_interaction mei ON m.id = mei.medicine_id
GROUP BY m.id, m.Trade_Name
HAVING COUNT(mei.enzyme_id) >= 3
Match? False
DB: student_1
Question: Report the number of students in each classroom.
True SQL: SELECT classroom ,  count(*) FROM list GROUP BY classroom
LLM SQL:  SELECT Classroom, COUNT(*) as NumberOfStudents
FROM list
GROUP BY Classroom
Match? False
DB: chinook_1
Question: Count the number of customers that have an email containing "gmail.com".
True SQL: SELECT COUNT(*) FROM CUSTOMER WHERE Email LIKE "%gmail.com%"
LLM SQL:  SELECT COUNT(*) 
FROM Customer 
WHERE Email LIKE '%gmail.com%';
Match? False
DB: soccer_1
Ques