# Agent Development 
##### *Final Agent Development Notebook*

## Table of Contents

- [0. Set Up Testing](#0-set-up-testing)
  - [0.1 Load Testing JSON data](#0.1-load-testing-json-data)
  - [0.2 Save Testing Data](#0.2-save-testing-data)
- [1. Agent A: Database Selector](#1-agent-a-database-selector)
  - [1.1 OpenAI Setup](#11-openai-setup)
  - [1.2 Load Schemas from tables.json](#12-load-schemas-from-tablesjson)
  - [1.3 Extract Schemas](#13-extract-schemas)
  - [1.4 Redefine Schemas for LLM](#14-redefine-schemas-for-llm)
  - [1.5 Save Redefined Schemas](#15-save-redefined-schemas)
  - [1.6 Create and Save Schema Embeddings](#16-create-and-save-schema-embeddings)
  - [1.7 Set up Agent A](#17-set-up-agent-a)
  - [1.8 Apply Agent A](#18-apply-agent-a)
  - [1.9 Test Agent A](#19-test-agent-a)
- [2. Agent B: Table & Column Selector](#2-agent-b-table--column-selector)
  - [2.1 Setup for Testing](#21-setup-for-testing)
  - [2.2 Setup Agent B](#22-setup-agent-b)
  - [2.3 Apply Agent B](#23-apply-agent-b)
  - [2.4 Test Agent B](#24-test-agent-b)
- [3. Agent C: SQL Generator](#3-agent-c-sql-generator)
  - [3.1 Setup for Testing](#31-setup-for-testing)
  - [3.2 Setup Agent C](#32-setup-agent-c)
  - [3.3 Apply Agent C](#33-apply-agent-c)
- [4. Archive](#4-archive)

In [86]:

import os
import sys
import sqlite3
import json
import pandas as pd
from pathlib import Path
import pprint
import textwrap
from typing import Dict, Any, List, Tuple, Union
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from openai import OpenAI

In [3]:
PROJECT_ROOT = Path(__file__).parent.parent.parent if '__file__' in globals() else Path.cwd().parent.parent.parent  
sys.path.append(PROJECT_ROOT)

## 0 Set Up Testing

### 0.1 Load Testing JSON data

In [223]:
SQL_DATA_PATH = PROJECT_ROOT / "data" / "spider_data" / "train_spider.json"  
SQL_TESTING = PROJECT_ROOT / "data" / "test" / "spider_query_answers.json"

In [139]:
def load_sql_dataset(file_path: Union[str, Path]) -> pd.DataFrame:
    """
    Load the Spider dataset (JSON with a list of records) and return a DataFrame with 
    db_id, query, question.
    """
    file_path = Path(file_path)
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)  # load entire JSON list

    records = [
        {
            "db_id": rec.get("db_id"),
            "question": rec.get("question"),
            "query": rec.get("query"),
            "guery_toks": rec.get("query_toks") # list
        }
        for rec in data
    ]

    df = pd.DataFrame(records)
    return df

sql_answers_df = load_sql_dataset(SQL_DATA_PATH)
sql_answers_df.head()



Unnamed: 0,db_id,question,query,guery_toks
0,department_management,How many heads of the departments are older th...,SELECT count(*) FROM head WHERE age > 56,"[SELECT, count, (, *, ), FROM, head, WHERE, ag..."
1,department_management,"List the name, born state and age of the heads...","SELECT name , born_state , age FROM head ORD...","[SELECT, name, ,, born_state, ,, age, FROM, he..."
2,department_management,"List the creation year, name and budget of eac...","SELECT creation , name , budget_in_billions ...","[SELECT, creation, ,, name, ,, budget_in_billi..."
3,department_management,What are the maximum and minimum budget of the...,"SELECT max(budget_in_billions) , min(budget_i...","[SELECT, max, (, budget_in_billions, ), ,, min..."
4,department_management,What is the average number of employees of the...,SELECT avg(num_employees) FROM department WHER...,"[SELECT, avg, (, num_employees, ), FROM, depar..."


### 0.2 Save Testing Data

In [49]:
SQL_TESTING.parent.mkdir(parents=True, exist_ok=True)
df.to_json(SQL_TESTING, orient="records", indent=2)
print(f"Saved simplified queries to {SQL_TESTING}")

Saved simplified queries to /Users/lainemulvay/Desktop/Projects/UNI/Capstone/Explainable-Query-Interface-for-Relational-Databases/data/test/spider_query_answers.json


## 1 Agent  A:  Database  Selector:
`Chooses  the  most  relevant  database  based  on  the  user’s  natural 
language query.`

### 1.1 OpenAI Setup

Import OpenAI API key

In [4]:
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "not-set")
print("OpenAI API key used:", OPENAI_API_KEY[:5] + "****")

OpenAI API key used: sk-pr****


Test OpenAI API

In [5]:
client = OpenAI()
response = client.responses.create (
    model = "gpt-5-mini",
    input = "how much gold would it take to coat the statue of liberty in a 1mm layer? Answer concisely",
    reasoning = {
        "effort": "minimal"
    }
)
print(response.output[1].content[0].text)

Rough estimate:

- Surface area of Statue of Liberty (figure + torch + pedestal excluded) ≈ 11,000 ft² ≈ 1,022 m² (common cited range 1,000–1,200 m²). Use 1,000 m² for simplicity.
- Volume of 1 mm (0.001 m) layer = area × thickness = 1,000 m² × 0.001 m = 1 m³.
- Gold density ≈ 19,320 kg/m³ → mass ≈ 19,320 kg ≈ 19.3 metric tonnes.
- Current gold price (approx) ≈ $65,000 per kg (price fluctuates; using ≈ $65k/kg → $1.25 billion). More commonly quoted spot ~$60–70k/kg; at $65k/kg 19,320 kg ≈ $1.26 billion.

So about 19 tonnes of gold (≈1 m³), costing on the order of $1–1.5 billion depending on spot price.


### 1.2 Load Schemas from tables.json

In [6]:
SCHEMA_PATH = PROJECT_ROOT / "data" / "spider_data" / "tables.json"

# read schema file with debug prints and file not found handling
def load_schemas(path: Path) -> Dict[str, Dict[str, Any]]:
    """
    Load the schema file from the given path and return a dictionary of schema objects.
    Includes debug output and file not found handling.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            schemas_json = json.load(f)
        print(f"Loaded {len(schemas_json)} schema entries from file '{path}'.")
        return schemas_json
    except FileNotFoundError:
        print(f"Error: Schema file '{path}' not found.")
        return {}

schemas_json = load_schemas(SCHEMA_PATH)

# get db_ids from the schemas_json
db_ids = [entry.get('db_id', 'undefined') for entry in schemas_json if isinstance(entry, dict)]

if db_ids:
    print("First 10 Database ids:", db_ids[:10])
else:
    print("Schema loading failed. No database ids available.")

Loaded 166 schema entries from file '/Users/lainemulvay/Desktop/Projects/UNI/Capstone/Explainable-Query-Interface-for-Relational-Databases/data/spider_data/tables.json'.
First 10 Database ids: ['perpetrator', 'college_2', 'flight_company', 'icfp_1', 'body_builder', 'storm_record', 'pilot_record', 'race_track', 'academic', 'department_store']


### 1.3 Extract Schemas

In [7]:
# Here I've extract only the db_id, table_names, and column_names from tables.json
def extract_essential_schema(tables_data):
    essential_data = []
    for entry in tables_data:
        simplified_entry = {
            'database_name': entry.get('db_id', 'undefined'),
            'table_names': entry.get('table_names', []),
            'column_names': entry.get('column_names', [])
        }
        essential_data.append(simplified_entry)
    return essential_data

# Apply the extraction if tables_data is loaded
if 'schemas_json' in locals():
    essential_schemas = extract_essential_schema(schemas_json)
    print(f" Extracted data for {len(essential_schemas)} database schemas")

print("\nFirst datbase schema:")
essential_schemas[0]

#     # Show example of the simplified structure
#     if essential_schemas:
#         print(f"\n Example of simplified entry:")
#         example = essential_schemas[0]
#         print(f"  database_name: {example['database_name']}")
#         print(f"  table_names: {example['table_names']}")
#         print(f"  column_names (first 3): {example['column_names'][:3]}...")
#         print(f"  Total columns: {len(example['column_names'])}")
    
# else:
#     print("tables_data not found.")

 Extracted data for 166 database schemas

First datbase schema:


{'database_name': 'perpetrator',
 'table_names': ['perpetrator', 'people'],
 'column_names': [[-1, '*'],
  [0, 'perpetrator id'],
  [0, 'people id'],
  [0, 'date'],
  [0, 'year'],
  [0, 'location'],
  [0, 'country'],
  [0, 'killed'],
  [0, 'injured'],
  [1, 'people id'],
  [1, 'name'],
  [1, 'height'],
  [1, 'weight'],
  [1, 'home town']]}

### 1.4 Redefine Schemas for LLM

In [8]:
def reshape_with_headings(essential_schemas):
    """Add descriptive headings to make schema more LLM-friendly"""
    out = {}
    for db in essential_schemas:
        db_name = db.get("database_name", "unknown")
        table_names = list(db.get("table_names", []))
        col_specs = list(db.get("column_names", []))
        tables = []
        
        for idx, table_name in enumerate(table_names):
            cols = []
            for pair in col_specs:
                if not isinstance(pair, (list, tuple)) or len(pair) != 2:
                    continue
                t_idx, col = pair
                try:
                    t_idx = int(t_idx)
                except (ValueError, TypeError):
                    continue
                if t_idx != idx:
                    continue
                if col is None or str(col).strip() == "*" or t_idx < 0:
                    continue
                cols.append(str(col))

            tables.append({
                "table_name": table_name,
                "columns": cols,
            })

        out[db_name] = {
           "database_name": db_name,
            "tables": tables
        }
    return out


# Reshape with headings
reshaped_schemas = reshape_with_headings(essential_schemas)
print(f"Reshaped {len(reshaped_schemas)} databases with descriptive headings")
reshaped_schemas


Reshaped 166 databases with descriptive headings


{'perpetrator': {'database_name': 'perpetrator',
  'tables': [{'table_name': 'perpetrator',
    'columns': ['perpetrator id',
     'people id',
     'date',
     'year',
     'location',
     'country',
     'killed',
     'injured']},
   {'table_name': 'people',
    'columns': ['people id', 'name', 'height', 'weight', 'home town']}]},
 'college_2': {'database_name': 'college_2',
  'tables': [{'table_name': 'classroom',
    'columns': ['building', 'room number', 'capacity']},
   {'table_name': 'department',
    'columns': ['department name', 'building', 'budget']},
   {'table_name': 'course',
    'columns': ['course id', 'title', 'department name', 'credits']},
   {'table_name': 'instructor',
    'columns': ['id', 'name', 'department name', 'salary']},
   {'table_name': 'section',
    'columns': ['course id',
     'section id',
     'semester',
     'year',
     'building',
     'room number',
     'time slot id']},
   {'table_name': 'teaches',
    'columns': ['id', 'course id', 'secti

In [9]:
#Below function to finalise the format of the schema, into DB-TABLE-Column names format
def format_schema_jsonish(reshaped_essential_schemas):
    lines = []
    for _, db in reshaped_essential_schemas.items():
        db_name = db.get("database_name", "unknown")
        for t in db.get("tables", []):
            obj = {
                "database": db_name,
                "table": t.get("table_name", "unknown"),
                "columns": t.get("columns", [])
            }
            lines.append(json.dumps(obj, ensure_ascii=False, separators=(",", ":")))
    return lines

format_schema_jsonish(reshaped_schemas)

['{"database":"perpetrator","table":"perpetrator","columns":["perpetrator id","people id","date","year","location","country","killed","injured"]}',
 '{"database":"perpetrator","table":"people","columns":["people id","name","height","weight","home town"]}',
 '{"database":"college_2","table":"classroom","columns":["building","room number","capacity"]}',
 '{"database":"college_2","table":"department","columns":["department name","building","budget"]}',
 '{"database":"college_2","table":"course","columns":["course id","title","department name","credits"]}',
 '{"database":"college_2","table":"instructor","columns":["id","name","department name","salary"]}',
 '{"database":"college_2","table":"section","columns":["course id","section id","semester","year","building","room number","time slot id"]}',
 '{"database":"college_2","table":"teaches","columns":["id","course id","section id","semester","year"]}',
 '{"database":"college_2","table":"student","columns":["id","name","department name","tota

### 1.5 Save Redefined Schemas

In [None]:
SCHEMA_OUTPUT_DIR = PROJECT_ROOT / "data" / "processed"
SCHEMA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
SCHEMA_PROCESSED_FILE = SCHEMA_OUTPUT_DIR / "spider_schemas_processed.jsonl"  # output

In [11]:
def save_processed_schema(reshaped_essential_schemas, output_file):
    """Save formatted schema into a JSONL file (one table per line)."""
    lines = format_schema_jsonish(reshaped_essential_schemas)
    with open(output_file, "w", encoding="utf-8") as f:
        for line in lines:
            f.write(line + "\n")
    print(f"Saved {len(lines)} schema entries to {output_file}")

save_processed_schema(reshaped_schemas, SCHEMA_PROCESSED_FILE)

Saved 876 schema entries to /Users/lainemulvay/Desktop/Projects/UNI/Capstone/Explainable-Query-Interface-for-Relational-Databases/data/processed/spider_schemas_processed.jsonl


### 1.6 Create and Save Schema Embeddings from Redefined Schema

In [12]:
EMBEDDINGS_FOLDER = SCHEMA_OUTPUT_DIR / "spider_schemas_embeddings"

In [13]:
# Load processed schema from JSONL
def load_processed_schema(input_file):
    """Load processed schema JSONL back into a list of strings."""
    with open(input_file, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]
    return lines

# Use it
final_schema_result = load_processed_schema(SCHEMA_PROCESSED_FILE)
print(f"Loaded {len(final_schema_result)} entries from {SCHEMA_PROCESSED_FILE}")

# Step 1: create embeddings + vectorstore
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.from_texts(final_schema_result, embeddings)

# Step 2: save embeddings
vectorstore.save_local(str(EMBEDDINGS_FOLDER))
print(f"Saved schema embeddings to {EMBEDDINGS_FOLDER}")

Loaded 876 entries from /Users/lainemulvay/Desktop/Projects/UNI/Capstone/Explainable-Query-Interface-for-Relational-Databases/data/processed/spider_schemas_processed.jsonl
Saved schema embeddings to /Users/lainemulvay/Desktop/Projects/UNI/Capstone/Explainable-Query-Interface-for-Relational-Databases/data/processed/spider_schemas_embeddings


### 1.7 Set up Agent A

In [319]:
def create_database_selection_agent(top_k):
    """Create the database selection agent using a prebuilt FAISS vectorstore"""
    
    prompt_db = PromptTemplate(
        input_variables=["query", "retrieved_schema"],
        template="""
Please select the single most relevant database and table to answer the user's query.

User query: {query}
Schema info: {retrieved_schema}

Respond **only** with a valid JSON object (no backticks, no extra text). 
The JSON must include the following keys: "db_name", "tables", "columns", and "reasons". 
Each key should appear on its own line for readability.

Example format:

{{
  "db_name": "...",
  "tables": ["..."],
  "columns": ["..."],
  "reasons": "..."
}}
"""
    )

    db_chain = prompt_db | llm

    def database_selection_agent(user_query, top_k, mode="medium"):
        # Retrieve relevant schemas
        relevant_docs = vectorstore.similarity_search_with_score(user_query, k=top_k)
        retrieved_schema = ""
        for doc, score in relevant_docs:
            retrieved_schema += f"score: {score:.4f}, content: {doc.page_content}\n"

        # Invoke LLM
        response = db_chain.invoke({
            "query": user_query,
            "retrieved_schema": retrieved_schema
        })

        # Parse JSON output safely
        llm_content = response.content if hasattr(response, "content") else str(response)
        try:
            parsed = json.loads(llm_content)
        except json.JSONDecodeError:
            parsed = {}

        # Transform retrieved_schema string into structured list
        structured_schema = []
        for doc, score in relevant_docs:
            try:
                schema_json = json.loads(doc.page_content)
                structured_schema.append({
                    "score": round(score, 4),
                    "database": schema_json.get("database"),
                    "table": schema_json.get("table"),
                    "columns": schema_json.get("columns", [])
                })
            except json.JSONDecodeError:
                structured_schema.append({
                    "score": round(score, 4),
                    "raw_content": doc.page_content
                })

        # Then in your return block
        if mode == "light":
            return parsed.get("db_name")
        elif mode == "medium":
            return {
                "retrieved_schema": structured_schema,
                "db_name": parsed.get("db_name"),
                "tables": parsed.get("tables", []),
                "columns": parsed.get("columns", []),
                "reasons": parsed.get("reasons", ""),
            }
        else:  # heavy
            return {
                "db_name": parsed.get("db_name"),
                "retrieved_schema": structured_schema,
                "tables": parsed.get("tables", []),
                "columns": parsed.get("columns", []),
                "reasons": parsed.get("reasons", ""),
                "llm_raw": response,
            }

    return database_selection_agent, vectorstore

### 1.8 Apply Agent A

In [66]:
test_queries = [
    "Find the name of all students who were in the tryout sorted in alphabetic order",
    "Find the average price of all product clothes.",
    "Show the names of artworks in ascending order of the year they are nominated in.",
    "What is the name of the department with the student that has the lowest GPA?",
    "What are the names and years of the movies that has the top 3 highest rating star?",
    "How many students does each advisor have?",
    "Count flights departing from Dallas in 2017",
    "What are the distinct creation years of the departments managed by a secretary born in state 'Alabama'?",
    "List courses worth more than 3 credits and their departments",
    "For each customer, compute total order value and sort desc.",
    "select all the deaths caused by ship",
    "Show me information about singers and their concerts",
    "I want to see student enrollment data",
    "Find information about car manufacturers and models",
    "What data do you have about movies and actors?",
    "Show me employee salary information",
    "Which produce has the most complaints where the status are still open"
]

In [67]:
# Load embeddings/vectorstore from folder using the new function
vectorstore = load_schema_embeddings(EMBEDDINGS_FOLDER)

# Set up LLM
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)


Embeddings loaded with 876 entries


In [320]:
def apply_database_selector(query_numbers=None, mode="medium"):
    """
    Apply the database selection agent to one or more queries.

    Parameters:
        query_numbers (list or int, optional): 
            - int: apply to that single test query (1-based index)
            - list of ints: apply to multiple test queries
            - None: apply to all test queries
        mode (str): "light", "medium", or "heavy" for the agent output
    """
    # Create the agent
    db_agent, vectorstore = create_database_selection_agent(top_k=5)
    
    print("="*60)
    print("APPLYING DATABASE SELECTION AGENT")
    print("="*60)
    
    # Determine which queries to run
    if query_numbers is None:
        query_indices = range(len(test_queries))
    elif isinstance(query_numbers, int):
        query_indices = [query_numbers - 1]
    else:
        query_indices = [i-1 for i in query_numbers]
    
    # Apply the agent to the selected queries
    for i in query_indices:
        query = test_queries[i]
        print(f"\nApplying to Query {i+1}: {query}")
        print("-" * 50)
        
        # Call the agent with the selected mode
        result = db_agent(query, top_k=5, mode=mode)
        
        print("Agent A Selection:")
        pprint.pprint(result)
        
        print("\n" + "="*60)

apply_database_selector(query_numbers=2, mode="medium")

APPLYING DATABASE SELECTION AGENT

Applying to Query 2: Find the average price of all product clothes.
--------------------------------------------------
Agent A Selection:
{'columns': ['product id',
             'product type code',
             'product name',
             'product price'],
 'db_name': 'department_store',
 'reasons': 'This table contains the product price information specifically '
            'for clothing products, which is needed to calculate the average '
            'price.',
 'retrieved_schema': [{'columns': ['product id',
                                   'color code',
                                   'product category code',
                                   'product name',
                                   'typical buying price',
                                   'typical selling price',
                                   'product description',
                                   'other product details'],
                       'database': 'products_gen

### 1.9 Test Agent A

In [140]:
# Pick every X questions
every_x_th = 250
questions_subset = sql_answers_df.iloc[::every_x_th]["question"]
print(f"Number of questions we are testing: {len(questions_subset)}")
print(f"Total number of questions: {len(sql_answers_df)}")

Number of questions we are testing: 28
Total number of questions: 7000


In [148]:
def apply_agent_a_on_test_set(every_x_th):
    """
    Picks every Xth question from sql_answers_df, applies the database selection agent,
    and returns a list of results with question, db_name, and reason.
    Also prints a progress tracker showing how many questions have been processed out of the total.
    """
    questions_subset = sql_answers_df.iloc[::every_x_th].to_dict(orient="records")
    total = len(questions_subset)
    db_name_results = []

    for idx, item in enumerate(questions_subset, 1):
        question = item['question']
        #print(f"Question: {question}")
        # Database selection
        db_agent, _ = create_database_selection_agent(top_k=5)
        result_1 = db_agent(question, top_k=5)
        llm_output_1 = (
            result_1['llm_selection'].content
            if hasattr(result_1['llm_selection'], 'content')
            else str(result_1['llm_selection'])
        )
        #print(f"LLM Output: {llm_output_1}")
        llm_output_json = json.loads(llm_output_1)
        db_name = str(llm_output_json["db_name"])
        reason = llm_output_json.get("reasons", "")
        #print(f"DB Name: {db_name}")
        # Store results
        db_name_results.append({
            'question': question,
            'db_name': db_name,
            'reason': reason
        })
        # Progress tracker
        print(f"Applied Agent A (selected databases) on {idx}/{total} test questions")

    return db_name_results

# Call the function and print results
db_name_results = apply_agent_a_on_test_set(every_x_th)
print("-" * 50)
pprint.pprint(db_name_results[:5])

Applied Agent A (selected databases) on 1/28 test questions
Applied Agent A (selected databases) on 2/28 test questions
Applied Agent A (selected databases) on 3/28 test questions
Applied Agent A (selected databases) on 4/28 test questions
Applied Agent A (selected databases) on 5/28 test questions
Applied Agent A (selected databases) on 6/28 test questions
Applied Agent A (selected databases) on 7/28 test questions
Applied Agent A (selected databases) on 8/28 test questions
Applied Agent A (selected databases) on 9/28 test questions
Applied Agent A (selected databases) on 10/28 test questions
Applied Agent A (selected databases) on 11/28 test questions
Applied Agent A (selected databases) on 12/28 test questions
Applied Agent A (selected databases) on 13/28 test questions
Applied Agent A (selected databases) on 14/28 test questions
Applied Agent A (selected databases) on 15/28 test questions
Applied Agent A (selected databases) on 16/28 test questions
Applied Agent A (selected databas

In [152]:
# Extract ground truth answers for Agent A - ground truth database names
ground_truth = sql_answers_df.iloc[::every_x_th][["question", "db_id"]].reset_index(drop=True)
# print(ground_truth.head())

# compare results from agent a test set with ground truth
agent_a_test_results_df = pd.DataFrame({
    'question': [row['question'] for row in db_name_results],
    'predicted_db': [row['db_name'] for row in db_name_results],
    'ground_truth_db': ground_truth['db_id'],
})

agent_a_test_results_df['is_correct'] = agent_a_test_results_df['predicted_db'] == agent_a_test_results_df['ground_truth_db']
agent_a_test_results_df = agent_a_test_results_df[['is_correct', 'question', 'predicted_db', 'ground_truth_db']]

agent_a_test_results_df


Unnamed: 0,is_correct,question,predicted_db,ground_truth_db
0,True,How many heads of the departments are older th...,department_management,department_management
1,True,Show names of actors and names of musicals the...,musical,musical
2,False,How many students does each advisor have?,dorm_1,allergy_1
3,True,What are the names and locations of all tracks?,race_track,race_track
4,False,Return the total and minimum enrollments acros...,school_player,university_basketball
5,True,Return all the apartment numbers sorted by the...,apartment_rentals,apartment_rentals
6,True,Show the distinct venues of debates,debate,debate
7,True,What is the age of the tallest person?,gymnast,gymnast
8,False,What are the names and headquarters of all com...,company_office,gas_company
9,True,What is the team with at least 2 technicians?,machine_repair,machine_repair


In [150]:
# print the percentage of correct predictions
print(f"Percentage correct: {agent_a_test_results_df['is_correct'].mean():.2%}")

Percentage correct: 71.43%


In [153]:
# Print all incorrect predictions for database selection using the new results DataFrame
incorrect = agent_a_test_results_df[~agent_a_test_results_df['is_correct']]
for idx, row in incorrect.iterrows():
    question = row['question']
    ground_truth = row['ground_truth_db']
    prediction = row['predicted_db']
    print(f"Q: {question}\nTruth: {ground_truth} | Pred: {prediction}\n{'-'*60}")

Q: How many students does each advisor have?
Truth: allergy_1 | Pred: dorm_1
------------------------------------------------------------
Q: Return the total and minimum enrollments across all schools.
Truth: university_basketball | Pred: school_player
------------------------------------------------------------
Q: What are the names and headquarters of all companies ordered by descending market value?
Truth: gas_company | Pred: company_office
------------------------------------------------------------
Q: What are the names and years of the movies that has the top 3 highest rating star?
Truth: movie_1 | Pred: imdb
------------------------------------------------------------
Q: How many clubs are there?
Truth: club_1 | Pred: riding_club
------------------------------------------------------------
Q: For each advisor, report the total number of students advised by him or her.
Truth: voter_2 | Pred: college_2
------------------------------------------------------------
Q: How many differ

## 2 Agent B: Table & Column Selector
`Identifies the appropriate tables and columns within the selected 
schema that are needed to answer the query.`

### 2.1 Setup for Testing

Not possible yet, no gorund truth for tables and column selection

### 2.2 Setup Agent B

In [202]:
list_tables_prompt = PromptTemplate(
    input_variables=["user_query", "db_schema_json"],
    template = """
Given the relevant database schema, return the tables that are most relevant to the user's query.

User query: {user_query}
DB schema JSON: {db_schema_json}

Respond ONLY with a valid JSON object (no extra text, no backticks). 
The JSON must include the following keys: "db_name", "tables", "columns", and "reasons". 
Each key should appear on its own line for readability. 

Example format (output must match this structure exactly): 

{{
  "relevant_tables": ["..."],
  "reasons": "..."
}}


Do not include any text outside the JSON object.
"""
)
#db_chain_2 = LLMChain(llm=llm, prompt=list_tables_prompt)
db_chain_2 = list_tables_prompt | llm

In [None]:
SCHEMA_PROCESSED_FILE = SCHEMA_OUTPUT_DIR / "spider_schemas_processed.jsonl"

In [287]:
def agent_b(user_query, db_name, mode="medium"):
    """
    mode options:
      - "light": only tables and columns
      - "medium": user query, db_name, tables, columns, reasons
      - "heavy": full schema + raw LLM response
    """

    # Load schema lines
    with open(SCHEMA_PROCESSED_FILE, "r", encoding="utf-8") as f:
        schema_lines = f.readlines()

    # Parse JSON lines and filter for the selected database
    full_schema = [
        json.loads(line)
        for line in schema_lines
        if json.loads(line)["database"] == db_name
    ]

    # Run LLM
    response_2 = db_chain_2.invoke({
        "user_query": user_query,
        "db_schema_json": full_schema
    })

    # Parse LLM output into dict
    llm_selection_content = (
        response_2.content if hasattr(response_2, "content") else str(response_2)
    )
    parsed = json.loads(llm_selection_content)

    if mode == "light":
        return {
            "Tables": parsed.get("tables", []),
            "Columns": parsed.get("columns", []),
        }

    elif mode == "medium":  # fixed spelling
        return {
            "User Query": user_query,
            "Database Name": parsed.get("db_name", db_name),
            "Tables": parsed.get("tables", []),
            "Columns": parsed.get("columns", []),
            "Reasons": parsed.get("reasons", "")
        }

    elif mode == "heavy":
        return {
            "User Query": user_query,
            "Database Name": db_name,
            "schema": full_schema,
            "llm_result_raw": response_2,
        }

    else:
        raise ValueError(f"Unknown mode: {mode}")

### 2.3 Apply Agent B

In [288]:
agent_b("How many heads of the departments are older than 56 ?", 'department_management', mode="heavy")

{'User Query': 'How many heads of the departments are older than 56 ?',
 'Database Name': 'department_management',
 'schema': [{'database': 'department_management',
   'table': 'department',
   'columns': ['department id',
    'name',
    'creation',
    'ranking',
    'budget in billions',
    'num employees']},
  {'database': 'department_management',
   'table': 'head',
   'columns': ['head id', 'name', 'born state', 'age']},
  {'database': 'department_management',
   'table': 'management',
   'columns': ['department id', 'head id', 'temporary acting']}],
 'llm_result_raw': AIMessage(content='{\n  "db_name": "department_management",\n  "tables": ["head"],\n  "columns": ["head id", "name", "age"],\n  "reasons": "The query specifically asks about the heads of departments and their ages, making the \'head\' table the most relevant."\n}', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 61, 'prompt_tokens': 236, 'total_tokens': 297, 'completion

In [237]:
# Pick every X questions
every_y_th = 2500
questions_subset = sql_answers_df.iloc[::every_y_th]
print(f"Number of questions we are testing: {len(questions_subset)}")
print(f"Total number of questions: {len(sql_answers_df)}")

Number of questions we are testing: 3
Total number of questions: 7000


In [289]:
for idx, row in questions_subset.iterrows():
    question = row['question']
    db_id = row['db_id']
    result = agent_b(question, db_id, mode="medium")
    print("\nAgent B Table Selection:")
    pprint.pprint(result, sort_dicts=False)  # keeps insertion order
    print("\n" + "="*60)


Agent B Table Selection:
{'User Query': 'How many heads of the departments are older than 56 ?',
 'Database Name': 'department_management',
 'Tables': ['head'],
 'Columns': ['head id', 'name', 'age'],
 'Reasons': 'The query specifically asks about the heads of departments and '
            "their ages, making the 'head' table the most relevant."}


Agent B Table Selection:
{'User Query': 'What are the names and years of the movies that has the top 3 '
               'highest rating star?',
 'Database Name': 'movie_1',
 'Tables': ['movie', 'rating'],
 'Columns': ['title', 'year', 'rating stars'],
 'Reasons': "The 'movie' table contains the movie titles and years, while the "
            "'rating' table contains the rating stars needed to identify the "
            'top rated movies.'}


Agent B Table Selection:
{'User Query': 'Find the name of all students who were in the tryout sorted in '
               'alphabetic order.',
 'Database Name': 'soccer_2',
 'Tables': ['player', 'tryout'

### 2.4 Test Agent B

Not possible yet, no gorund truth for tables and column selection

## 3 Agent C: SQL Generator
`Generates an SQL query tailored to natural language query and selected schema.`

### 3.1 Setup for Testing

In [234]:
SQL_TESTING = PROJECT_ROOT / "data" / "test" / "spider_query_answers.json"

In [229]:
with open(SQL_TESTING, "r") as f:
    data = json.load(f)
df = pd.DataFrame(data)
df.head()

Unnamed: 0,db_id,question,query,guery_toks
0,department_management,How many heads of the departments are older th...,SELECT count(*) FROM head WHERE age > 56,"[SELECT, count, (, *, ), FROM, head, WHERE, ag..."
1,department_management,"List the name, born state and age of the heads...","SELECT name , born_state , age FROM head ORD...","[SELECT, name, ,, born_state, ,, age, FROM, he..."
2,department_management,"List the creation year, name and budget of eac...","SELECT creation , name , budget_in_billions ...","[SELECT, creation, ,, name, ,, budget_in_billi..."
3,department_management,What are the maximum and minimum budget of the...,"SELECT max(budget_in_billions) , min(budget_i...","[SELECT, max, (, budget_in_billions, ), ,, min..."
4,department_management,What is the average number of employees of the...,SELECT avg(num_employees) FROM department WHER...,"[SELECT, avg, (, num_employees, ), FROM, depar..."


In [None]:
def get_true_query_toks(db_id, question, data=SQL_TESTING):
    """
    Retrieve the 'guery_toks' for a given db_id and question from the test data file.

    Args:
        db_id (str): The database ID to match.
        question (str): The question string to match.
        data_file (str or Path): Path to the JSON test data file.

    Returns:
        list: The list of SQL tokens ('guery_toks') if found, else None.
    """
    import json

    with open(data, "r") as f:
        data = json.load(f)

    for item in data:
        if item.get("db_id") == db_id and item.get("question") == question:
            return item.get("guery_toks")
    return None

In [231]:
def compare_sql_to_ground_truth(sql_query, db_id, question, data=SQL_TESTING):
    """
    Tokenizes the given SQL query and compares its tokens to the ground truth tokens
    for the specified db_id and question.

    Args:
        sql_query (str): The SQL query to compare.
        db_id (str): The database ID.
        question (str): The natural language question.
        data (str or Path): Path to the JSON test data file.

    Returns:
        bool: True if the tokenized SQL matches the ground truth tokens, False otherwise.
    """
    import re

    def tokenize_sql(sql):
        # Simple SQL tokenizer: split on whitespace and punctuation
        return [tok for tok in re.split(r"(\W)", sql) if tok.strip()]

    gt_tokens = get_true_query_toks(db_id, question, data)
    if gt_tokens is None:
        raise ValueError(f"No ground truth found for db_id={db_id}, question={question}")

    sql_tokens = tokenize_sql(sql_query)
    return sql_tokens == gt_tokens

### 3.2 Setup Agent C

In [None]:
generate_sql_prompt = PromptTemplate(
    input_variables=["user_query", "db_schema_json", "reccomended_tables"],
    template = """
You are an SQL expert. Given the following database schema:

DB schema JSON: {db_schema_json}

You are also provided with a list of recommended tables from the schema that are likely to be relevant for answering the question.

Recommended tables: {reccomended_tables}

Write a SQL query that answers the following question:

User query: {user_query}

Only return the SQL query, nothing else.

Make sure your SQL is formatted so that it is easily tokenizable: 
- Use clear spacing around punctuation (e.g., SELECT name , age FROM table WHERE id = 5).
- Avoid unnecessary line breaks or indentation.
- Do not include comments or explanations.
"""
)

db_chain_3 = generate_sql_prompt | llm

In [None]:
SCHEMA_PROCESSED_FILE = SCHEMA_OUTPUT_DIR / "spider_schemas_processed.jsonl"

In [296]:
def clean_sql(sql_string):
    # Remove code fences if present
    sql_string = sql_string.strip()
    if sql_string.startswith("```") and sql_string.endswith("```"):
        sql_string = "\n".join(sql_string.splitlines()[1:-1])  # remove ```sql and ```
    # Remove surrounding parentheses if it's a tuple string
    if sql_string.startswith("(") and sql_string.endswith(")"):
        sql_string = sql_string[1:-1].strip().strip("'").strip('"')
    return sql_string

def agent_c(user_query, db_name, reccomended_tables, mode="medium"):
    """
    mode options:
      - "light": only the generated SQL
      - "medium": user query, db_name, recommended tables, SQL
      - "heavy": everything + schema + raw LLM result
    """

    # Read all lines from the processed schema file
    with open(SCHEMA_PROCESSED_FILE, "r", encoding="utf-8") as f:
        schema_lines = f.readlines()

    # Parse JSON lines and filter for the selected database
    full_schema = [
        json.loads(line)
        for line in schema_lines
        if json.loads(line)["database"] == db_name
    ]

    # Run LLM
    response_3 = db_chain_3.invoke({
        "user_query": user_query,
        "db_schema_json": full_schema,
        "reccomended_tables": reccomended_tables
    })

    # Safely parse LLM output
    llm_selection_content = (
        response_3.content if hasattr(response_3, "content") else str(response_3)
    )
    try:
        parsed = json.loads(llm_selection_content)
        sql = parsed.get("sql", "")
    except json.JSONDecodeError:
        sql = llm_selection_content.strip()

    # Clean SQL
    sql = clean_sql(sql)

    if mode == "light":
        return sql  # just the SQL string

    elif mode == "medium":
        return {
            "User Query": user_query,
            "Database Name": db_name,
            "Reccomended Tables": reccomended_tables,
            "SQL": sql
        }

    elif mode == "heavy":
        return {
            "User Query": user_query,
            "Database Name": db_name,
            "Reccomended Tables": reccomended_tables,
            "SQL": sql,
            "schema": full_schema,
            "llm_result_raw": response_3
        }

    else:
        raise ValueError(f"Unknown mode: {mode}")

### 3.3 Apply Agent C

In [243]:
# Pick every X questions
every_z_th = 2500
questions_subset = sql_answers_df.iloc[::every_z_th]
print(f"Number of questions we are testing: {len(questions_subset)}")
print(f"Total number of questions: {len(sql_answers_df)}")

Number of questions we are testing: 3
Total number of questions: 7000


Note we have to run Agent B to feed the reccomended tables to Agent C 

In [297]:
for idx, row in questions_subset.iterrows():
    question = row['question']
    db_id = row['db_id']
    
    # Step 1: Use Agent B to get recommended tables (light mode for clean output)
    agent_b_result = agent_b(question, db_id, mode="light")
    print("\nAgent B Table Selection:")
    pprint.pprint(agent_b_result)

    # Step 2: Extract recommended tables from Agent B's output
    recommended_tables = agent_b_result.get("Tables", [])

    # Step 3: Use Agent C to generate SQL using recommended tables
    agent_c_result = agent_c(question, db_id, recommended_tables, mode="light")
    print("\nAgent C SQL Generation:")
    pprint.pprint(agent_c_result)

    print("\n" + "="*60)


Agent B Table Selection:
{'Columns': ['head id', 'name', 'age'], 'Tables': ['head']}

Agent C SQL Generation:
'SELECT COUNT(*) FROM head WHERE age > 56;'


Agent B Table Selection:
{'Columns': ['title', 'year', 'rating stars'], 'Tables': ['movie', 'rating']}

Agent C SQL Generation:
('SELECT m.title , m.year FROM movie m JOIN rating r ON m."movie id" = '
 'r."movie id" GROUP BY m."movie id" ORDER BY AVG(r."rating stars") DESC LIMIT '
 '3;')


Agent B Table Selection:
{'Columns': ['player name'], 'Tables': ['player', 'tryout']}

Agent C SQL Generation:
('SELECT DISTINCT player_name FROM player INNER JOIN tryout ON '
 'player.player_id = tryout.player_id ORDER BY player_name;')



# 4. Archive

In [197]:
# def fetch_schema_for_db(db_id, schema_file=SCHEMA_PROCESSED_FILE):
#     """
#     Fetch the schema (list of tables and columns) for a given db_id from the processed schema file.
#     Returns a list of dicts: [{ "table": ..., "columns": [...] }, ...]
#     """
#     schema = []
#     with open(schema_file, "r") as f:
#         for line in f:
#             entry = json.loads(line)
#             if entry["database"] == db_id:
#                 schema.append({
#                     "table": entry["table"],
#                     "columns": entry["columns"]
#                 })
#     return schema

# # Example usage:
# db_schema = fetch_schema_for_db("perpetrator")
# print(db_schema)

[{'table': 'perpetrator', 'columns': ['perpetrator id', 'people id', 'date', 'year', 'location', 'country', 'killed', 'injured']}, {'table': 'people', 'columns': ['people id', 'name', 'height', 'weight', 'home town']}]


In [196]:
# Read the schema file and build a mapping from db_id to list of tables
db_to_tables = {}

# Get all unique db_ids from sql_answers_df
db_ids_needed = set(sql_answers_df["db_id"].tolist())


db_to_tables = {}

with open(SCHEMA_PROCESSED_FILE, "r") as f:
    for line in f:
        entry = json.loads(line)
        db_id = entry["database"]
        table = entry["table"]
        columns = entry["columns"]
        # Only process db_ids that are in our needed set
        if db_id in db_ids_needed:
            if db_id not in db_to_tables:
                db_to_tables[db_id] = []
            db_to_tables[db_id].append({
                "table": table,
                "columns": columns
            })

print(db_to_tables)

{'perpetrator': [{'table': 'perpetrator', 'columns': ['perpetrator id', 'people id', 'date', 'year', 'location', 'country', 'killed', 'injured']}, {'table': 'people', 'columns': ['people id', 'name', 'height', 'weight', 'home town']}], 'college_2': [{'table': 'classroom', 'columns': ['building', 'room number', 'capacity']}, {'table': 'department', 'columns': ['department name', 'building', 'budget']}, {'table': 'course', 'columns': ['course id', 'title', 'department name', 'credits']}, {'table': 'instructor', 'columns': ['id', 'name', 'department name', 'salary']}, {'table': 'section', 'columns': ['course id', 'section id', 'semester', 'year', 'building', 'room number', 'time slot id']}, {'table': 'teaches', 'columns': ['id', 'course id', 'section id', 'semester', 'year']}, {'table': 'student', 'columns': ['id', 'name', 'department name', 'total credits']}, {'table': 'takes classes', 'columns': ['id', 'course id', 'section id', 'semester', 'year', 'grade']}, {'table': 'advisor', 'colum