


In [20]:
import sqlite3
import json
import pandas as pd
import os
import textwrap
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from openai import OpenAI

In [21]:
tables_json_path_2 = r"C:\Users\coffe\OneDrive\Desktop\CITS5553 Capstone Project\spider_data\spider_data\train_spider.json"
if os.path.exists(tables_json_path_2):
    with open(tables_json_path_2, 'r', encoding='utf-8') as f:
        train_data = json.load(f)
    print(f"Number of tables/schemas: {len(train_data)}")
    if train_data :
        first_entry = train_data[5] 
        print(f"Keys: {list(first_entry.keys()) if isinstance(first_entry, dict) else 'Not a dictionary'}")
else:
    print("error.")

Number of tables/schemas: 7000
Keys: ['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql']


In [22]:
# Reduce each Spider example to the three fields the evaluation uses; non-dict entries are skipped.
dbid_question_list = [
    {field: entry.get(field) for field in ('db_id', 'question', 'query')}
    for entry in train_data
    if isinstance(entry, dict)
]
print(dbid_question_list[:5])

[{'db_id': 'department_management', 'question': 'How many heads of the departments are older than 56 ?', 'query': 'SELECT count(*) FROM head WHERE age  >  56'}, {'db_id': 'department_management', 'question': 'List the name, born state and age of the heads of departments ordered by age.', 'query': 'SELECT name ,  born_state ,  age FROM head ORDER BY age'}, {'db_id': 'department_management', 'question': 'List the creation year, name and budget of each department.', 'query': 'SELECT creation ,  name ,  budget_in_billions FROM department'}, {'db_id': 'department_management', 'question': 'What are the maximum and minimum budget of the departments?', 'query': 'SELECT max(budget_in_billions) ,  min(budget_in_billions) FROM department'}, {'db_id': 'department_management', 'question': 'What is the average number of employees of the departments whose rank is between 10 and 15?', 'query': 'SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15'}]


In [38]:
# Deterministic systematic sample: every SAMPLE_STRIDE-th example, starting from the first.
# The examples appear to be grouped by database (the first five printed above all share one
# db_id), so a stride spreads the sample across databases instead of taking a contiguous block.
SAMPLE_STRIDE = 150
every_x_th = dbid_question_list[::SAMPLE_STRIDE]
# checking how many in total to be tested
len(every_x_th)

47

In [39]:
# Show each sampled question followed by its gold database id and gold SQL, one per line.
for sample in every_x_th:
    print(sample['question'], sample['db_id'], sample['query'], sep='\n')

How many heads of the departments are older than 56 ?
department_management
SELECT count(*) FROM head WHERE age  >  56
What is the average bike availablility for stations not in Palo Alto?
bike_1
SELECT avg(bikes_available) FROM status WHERE station_id NOT IN (SELECT id FROM station WHERE city  =  "Palo Alto")
Find the maximum and total number of followers of all users.
twitter_1
SELECT max(followers) ,  sum(followers) FROM user_profiles
What is allergy type of a cat allergy?
allergy_1
SELECT allergytype FROM Allergy_type WHERE allergy  =  "Cat"
What si the youngest employee's first and last name?
store_1
SELECT first_name , last_name FROM employees ORDER BY birth_date DESC LIMIT 1;
What are the names and locations of all tracks?
race_track
SELECT name ,  LOCATION FROM track
Find the policy types more than 4 customers use. Show their type code.
insurance_fnol
SELECT policy_type_code FROM available_policies GROUP BY policy_type_code HAVING count(*)  >  4
Find the names of the chip model

In [40]:
from pathlib import Path
import subprocess, sys

# Run Agent A once per sampled question and record its stdout. The recorded output is compared
# with the gold db_id below, so `answers` must stay index-aligned with `every_x_th`.

# 1) Use the same interpreter as the notebook kernel
PY = sys.executable

# 2) Point cwd to your project root (adjust if your notebook isn't in project_root/notebooks/)
project_root = Path.cwd().parent  # e.g., notebooks/ -> project root
print("CWD:", project_root)

answers = []

for idx, item in enumerate(every_x_th):
    result = subprocess.run(
        [PY, "-m", "src.agents.agent_a", "--query", item['question'], "--mode", "light", "--quiet"],
        cwd=project_root,          # critical so -m finds src.agents.agent_a
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merged, so a crash's traceback ends up in the recorded output
    )
    output = result.stdout.strip()
    if result.returncode != 0:
        # Keep the entry (alignment matters), but surface the failure instead of letting a
        # traceback be scored as an ordinary wrong prediction.
        print(f"[agent_a] question #{idx} exited with code {result.returncode}:\n{output}")
    answers.append(output)
    



CWD: c:\Users\coffe\OneDrive\Desktop\CITS5553 Capstone Project\Project_Git\explainable-nl-query-db-agents


In [41]:
# Gold database ids, in the same order as the sampled questions.
db_ground_truth = list(map(lambda sample: sample['db_id'], every_x_th))
# Agent A's raw output normalised: trim whitespace, then any surrounding double quotes.
db_prediction = [raw.strip().strip('"') for raw in answers]

In [42]:
# zip() silently truncates to the shorter list, which would hide an interrupted Agent A run or an
# every_x_th that was re-sampled after the agents ran; fail loudly instead.
assert len(db_ground_truth) == len(db_prediction), (
    f"{len(db_ground_truth)} ground-truth ids vs {len(db_prediction)} predictions; "
    "re-run the Agent A cell"
)
# One boolean per sampled question: did Agent A pick the gold database?
comparison_results = [gt == pred for gt, pred in zip(db_ground_truth, db_prediction)]
print(comparison_results)

[True, True, True, True, False, True, False, True, True, False, True, False, True, True, False, True, True, False, True, False, True, True, False, False, True, True, False, True, False, True, True, True, True, True, True, True, True, False, False, True, True, True, False, True, True, False, True]


In [43]:
# Agent A database-selection accuracy on the sample (each True counts as 1); NaN if nothing was sampled.
agent_a_accuracy = (
    sum(comparison_results) / len(comparison_results) if comparison_results else float("nan")
)
agent_a_accuracy

0.6808510638297872

In [44]:
# List every question Agent A routed to the wrong database, with gold and predicted ids.
false_indices = [pos for pos, matched in enumerate(comparison_results) if not matched]
for pos in false_indices:
    print(
        f"Q: {every_x_th[pos]['question']}\n"
        f"Truth: {db_ground_truth[pos]} | Pred: {db_prediction[pos]}\n"
        f"{'-' * 60}"
    )

Q: What si the youngest employee's first and last name?
Truth: store_1 | Pred: college_1
------------------------------------------------------------
Q: Find the policy types more than 4 customers use. Show their type code.
Truth: insurance_fnol | Pred: insurance_and_eClaims
------------------------------------------------------------
Q: How many different instructors have taught some course?
Truth: college_2 | Pred: course_teach
------------------------------------------------------------
Q: Count the number of artists.
Truth: theme_gallery | Pred: store_1
------------------------------------------------------------
Q: What document type codes do we have?
Truth: cre_Doc_Control_Systems | Pred: cre_Doc_Tracking_DB
------------------------------------------------------------
Q: Show names of cities and names of counties they are in.
Truth: county_public_safety | Pred: e_government
------------------------------------------------------------
Q: Find the names and phone numbers of custome

In [58]:
# Sampled examples whose database Agent A predicted correctly; these are passed to Agent B below.
correct_pairs = [pair for pair, matched in zip(every_x_th, comparison_results) if matched]

In [59]:
# Number of sampled questions Agent A routed to the correct database.
n_correct_pairs = len(correct_pairs)
n_correct_pairs

32

In [60]:
# Show each correctly-routed question followed by its gold database id and gold SQL.
for pair in correct_pairs:
    print(pair['question'], pair['db_id'], pair['query'], sep='\n')

How many heads of the departments are older than 56 ?
department_management
SELECT count(*) FROM head WHERE age  >  56
What is the average bike availablility for stations not in Palo Alto?
bike_1
SELECT avg(bikes_available) FROM status WHERE station_id NOT IN (SELECT id FROM station WHERE city  =  "Palo Alto")
Find the maximum and total number of followers of all users.
twitter_1
SELECT max(followers) ,  sum(followers) FROM user_profiles
What is allergy type of a cat allergy?
allergy_1
SELECT allergytype FROM Allergy_type WHERE allergy  =  "Cat"
What are the names and locations of all tracks?
race_track
SELECT name ,  LOCATION FROM track
Find the names of the chip models that are not used by any phone with full accreditation type.
phone_1
SELECT model_name FROM chip_model EXCEPT SELECT chip_model FROM phone WHERE Accreditation_type  =  'Full'
Show the short names of the buildings managed by "Emma".
apartment_rentals
SELECT building_short_name FROM Apartment_Buildings WHERE building_man

In [61]:
# Agent A's predictions for the correctly-routed questions (each equals its gold db_id).
correct_db_prediction = [pred for pred, matched in zip(db_prediction, comparison_results) if matched]

In [62]:
# Display the correctly predicted database ids (a shallow copy renders identically).
list(correct_db_prediction)

['department_management',
 'bike_1',
 'twitter_1',
 'allergy_1',
 'race_track',
 'phone_1',
 'apartment_rentals',
 'debate',
 'small_bank_1',
 'cinema',
 'machine_repair',
 'candidate_poll',
 'storm_record',
 'sakila_1',
 'assets_maintenance',
 'music_1',
 'program_share',
 'student_1',
 'tracking_grants_for_research',
 'document_management',
 'college_3',
 'aircraft',
 'soccer_2',
 'cre_Drama_Workshop_Groups',
 'music_2',
 'shop_membership',
 'tracking_share_transactions',
 'game_1',
 'music_4',
 'cre_Docs_and_Epenses',
 'train_station',
 'tracking_orders']

In [66]:
# Run Agent B on every question Agent A routed correctly, passing the gold db_id.
# Its stdout is parsed as JSON (with "User Query", "Database Name" and "Tables") in the next cell.
PY = sys.executable

# 2) Point cwd to your project root (adjust if your notebook isn't in project_root/notebooks/)
project_root = Path.cwd().parent  # e.g., notebooks/ -> project root
print("CWD:", project_root)

answers_agent_b = []

for idx, item in enumerate(correct_pairs):
    result = subprocess.run(
        [PY, "-m", "src.agents.agent_b", "--query", item['question'], "--database", item['db_id'], "--mode", "medium", "--quiet"],
        cwd=project_root,          # critical so -m finds src.agents.agent_b
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merged, so a crash's traceback ends up in the recorded output
    )
    output = result.stdout.strip()
    if result.returncode != 0:
        # Keep the entry, but make the failure visible here rather than only as a JSON parse error later.
        print(f"[agent_b] pair #{idx} exited with code {result.returncode}:\n{output}")
    answers_agent_b.append(output)
    

CWD: c:\Users\coffe\OneDrive\Desktop\CITS5553 Capstone Project\Project_Git\explainable-nl-query-db-agents


In [80]:
# Parse Agent B's stdout (one JSON object per run) and keep only well-formed results.
# Both the summary loop below and the Agent C cell index these keys directly.
REQUIRED_AGENT_B_KEYS = ("User Query", "Database Name", "Tables")

parsed_answers = []

for raw in answers_agent_b:
    try:
        data = json.loads(raw)   # convert string → dict
    except json.JSONDecodeError as e:
        print(" Could not parse:", raw)
        print("Error:", e)
        continue
    # Valid JSON that is not an object, or that lacks a required field, would raise later when
    # indexed, so report it and skip it exactly like a parse failure.
    if not isinstance(data, dict):
        print(" Not a JSON object:", raw)
        continue
    missing = [key for key in REQUIRED_AGENT_B_KEYS if key not in data]
    if missing:
        print(" Missing keys", missing, "in:", raw)
        continue
    parsed_answers.append(data)


for ans in parsed_answers:
    print("Query:", ans["User Query"])
    print("Database name:", ans["Database Name"])
    print("Tables:", ans["Tables"])
    print()

Query: How many heads of the departments are older than 56 ?
Database name: department_management
Tables: ['head', 'management']

Query: What is the average bike availablility for stations not in Palo Alto?
Database name: bike_1
Tables: ['station', 'status']

Query: Find the maximum and total number of followers of all users.
Database name: twitter_1
Tables: ['follows', 'user profiles']

Query: What is allergy type of a cat allergy?
Database name: allergy_1
Tables: ['allergy type']

Query: What are the names and locations of all tracks?
Database name: race_track
Tables: ['track']

Query: Find the names of the chip models that are not used by any phone with full accreditation type.
Database name: phone_1
Tables: ['chip model', 'phone']

Query: Show the short names of the buildings managed by "Emma".
Database name: apartment_rentals
Tables: ['apartment buildings']

Query: Show the distinct venues of debates
Database name: debate
Tables: ['debate']

Query: Find the names and total checkin

In [None]:
# Run Agent C on each parsed Agent B result, passing the question, the database and the
# selected tables (JSON-encoded on the command line). Its output format is not shown here.
PY = sys.executable

# 2) Point cwd to your project root (adjust if your notebook isn't in project_root/notebooks/)
project_root = Path.cwd().parent  # e.g., notebooks/ -> project root
print("CWD:", project_root)

answers_agent_c = []

for idx, item in enumerate(parsed_answers):
    result = subprocess.run(
        [PY, "-m", "src.agents.agent_c", "--query", item["User Query"], "--database", item['Database Name'], "--tables", json.dumps(item['Tables']), "--quiet"],
        cwd=project_root,          # critical so -m finds src.agents.agent_c
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merged, so a crash's traceback ends up in the recorded output
    )
    output = result.stdout.strip()
    if result.returncode != 0:
        # Keep the entry so results stay aligned with parsed_answers, but surface the failure.
        print(f"[agent_c] answer #{idx} exited with code {result.returncode}:\n{output}")
    answers_agent_c.append(output)

CWD: c:\Users\coffe\OneDrive\Desktop\CITS5553 Capstone Project\Project_Git\explainable-nl-query-db-agents
