In [1]:
import sqlite3
import pandas as pd
import numpy as np
import sqlglot
# Every benchmark database is stored as <root>/<name>/<name>.sqlite, so the
# mapping is derived from the list of database names.
_DB_ROOT = "D:/NL2SQLDev/backend/data/dbCollection"
_DB_NAMES = (
    "california_schools",
    "card_games",
    "codebase_community",
    "debit_card_specializing",
    "european_football_2",
    "financial",
    "formula_1",
    "student_club",
    "superhero",
    "thrombosis_prediction",
    "toxicology",
)
# Database name -> path of its SQLite file.
dbList = {db_name: f"{_DB_ROOT}/{db_name}/{db_name}.sqlite" for db_name in _DB_NAMES}

In [2]:
# Key into dbList that selects which database the connection cell opens.
task = "california_schools"

In [3]:
# Open the SQLite file selected by `task`; `cursor` is the handle later passed
# to executeSQL in the query cell.
# NOTE(review): no con.close() appears in the visible cells - confirm the
# connection is released elsewhere.
con=sqlite3.connect(dbList[task])
cursor=con.cursor()

In [4]:
def get_table_names(cursor):
    """Return the names of all tables recorded in ``sqlite_master``.

    Args:
        cursor: An open cursor on the SQLite database to inspect.

    Returns:
        list: The first column of every ``type='table'`` row, in the order
        the query returns them.
    """
    listing_query = "SELECT name FROM sqlite_master WHERE type='table';"
    cursor.execute(listing_query)
    names = []
    for record in cursor.fetchall():
        names.append(record[0])
    return names


def executeSQL(cursor,sql):
    """Execute ``sql`` on ``cursor`` and return the whole result set as a DataFrame.

    Args:
        cursor: An open DB-API cursor (sqlite3 in this notebook).
        sql: The SQL text to execute.

    Returns:
        pandas.DataFrame: One row per fetched record, with columns named after
        ``cursor.description``. Statements that produce no result set (DDL,
        INSERT, UPDATE, ...) yield an empty DataFrame.
    """
    cursor.execute(sql)
    description = cursor.description
    if description is None:
        # No result set: description is None, so there are no column names
        # to read and nothing to fetch.
        return pd.DataFrame()
    column_names = [field[0] for field in description]
    return pd.DataFrame(cursor.fetchall(), columns=column_names)
def generate_schema( db_path, num_rows=None):
    """Collect column and foreign-key metadata for every table in a SQLite file.

    Args:
        db_path: Path of the SQLite database file to inspect.
        num_rows: Not used by this function; kept so existing calls that pass
            it continue to work.

    Returns:
        dict: Maps each table name to ``{"columns": DataFrame,
        "foreign_keys": DataFrame}``. ``columns`` holds the rows of
        ``PRAGMA table_info`` and ``foreign_keys`` the rows of
        ``PRAGMA foreign_key_list``. SQLite's internal ``sqlite_sequence``
        table is left out.
    """
    column_fields = ["cid", "name", "type", "notnull", "dflt_value", "pk"]
    fk_fields = [
        "id",
        "seq",
        "table",
        "from",
        "to",
        "on_update",
        "on_delete",
        "match",
    ]
    schemas = {}
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        # Each fetched row is a 1-tuple; unwrap it so the name can be compared
        # and interpolated as a plain string.
        table_names = [row[0] for row in cursor.fetchall()]
        for table_name in table_names:
            if table_name == "sqlite_sequence":
                continue
            # Double any apostrophe so the name stays a valid SQL string literal.
            literal = table_name.replace("'", "''")
            cursor.execute(f"PRAGMA table_info('{literal}')")
            column_df = pd.DataFrame(cursor.fetchall(), columns=column_fields)
            cursor.execute(f"PRAGMA foreign_key_list('{literal}')")
            fk_df = pd.DataFrame(cursor.fetchall(), columns=fk_fields)
            schemas[table_name] = {"columns": column_df, "foreign_keys": fk_df}
    finally:
        # Release the file handle even if one of the PRAGMA calls fails.
        conn.close()
    return schemas

def generate_fk_pairs_list( schema):
    """Group ``(table, column)`` pairs that are linked by foreign keys.

    Args:
        schema: Mapping of table name to a dict whose ``"foreign_keys"`` entry
            is a DataFrame with at least the columns ``"table"``, ``"from"``
            and ``"to"``.

    Returns:
        list[set]: Each set holds the ``(table, column)`` pairs connected to
        one another through foreign keys. Groups appear in the order their
        first member was encountered, and no pair appears in two groups.
    """
    groups = []
    for table_name, table_info in schema.items():
        for _, fk_row in table_info["foreign_keys"].iterrows():
            source = (table_name, fk_row["from"])
            target = (fk_row["table"], fk_row["to"])
            touched = [grp for grp in groups if source in grp or target in grp]
            if not touched:
                groups.append({source, target})
                continue
            primary = touched[0]
            primary.add(source)
            primary.add(target)
            absorbed = touched[1:]
            if absorbed:
                # This foreign key bridges groups that were separate until
                # now: fold them into the first one so they do not overlap.
                for other in absorbed:
                    primary |= other
                groups = [
                    grp for grp in groups
                    if not any(grp is other for other in absorbed)
                ]
    return groups
def jsonFKPairs( fk_pairs):
    """Convert groups of ``(table, column)`` pairs into JSON-friendly lists.

    Args:
        fk_pairs: Iterable of groups; each group is an iterable of indexable
            pairs whose item 0 is the table and item 1 is the column.

    Returns:
        list[list[dict]]: One inner list per group, containing
        ``{"table": ..., "column": ...}`` for every member, in the group's
        iteration order.
    """
    return [
        [{"table": member[0], "column": member[1]} for member in group]
        for group in fk_pairs
    ]
def generateFkpPairsList( schema):
    """Group foreign-key-linked ``(table, column)`` pairs.

    This name is what ``schemaTrans`` calls. The grouping logic lives in
    ``generate_fk_pairs_list``; delegating to it keeps a single
    implementation instead of two copies that can drift apart.

    Args:
        schema: Mapping of table name to a dict whose ``"foreign_keys"`` entry
            is a DataFrame with the columns ``"table"``, ``"from"`` and ``"to"``.

    Returns:
        list[set]: Whatever ``generate_fk_pairs_list(schema)`` returns.
    """
    return generate_fk_pairs_list(schema)


def schemaTrans( schema):
    """Flatten a schema mapping into a plain, serialisable dictionary.

    Args:
        schema: Mapping of table name to ``{"columns": DataFrame,
            "foreign_keys": DataFrame}``.

    Returns:
        dict: With three keys, in this order:
            ``"table_names"`` - table names in iteration order of ``schema``;
            ``"columns"`` - table name -> ``DataFrame.to_dict()`` of its columns frame;
            ``"fk_pairs"`` - output of ``jsonFKPairs(generateFkpPairsList(schema))``.
    """
    table_names = []
    columns_by_table = {}
    for table_name, table_info in schema.items():
        table_names.append(table_name)
        columns_by_table[table_name] = table_info["columns"].to_dict()
    linked_groups = generateFkpPairsList(schema)
    return {
        "table_names": table_names,
        "columns": columns_by_table,
        "fk_pairs": jsonFKPairs(linked_groups),
    }

def dbSchemaPrompt( dbName=None, dbSchema=None):
    """Render a schema dictionary as a plain-text prompt.

    Args:
        dbName: Name inserted into the heading line.
        dbSchema: Dict with ``"columns"`` (table name -> dict whose ``"name"``
            entry maps keys to column names) and ``"fk_pairs"`` (list of
            groups of ``{"table": ..., "column": ...}`` dicts). Although it
            defaults to ``None``, it is subscripted immediately, so a real
            schema must be supplied.

    Returns:
        str: A heading, one ``Table <name>: ('col', ...)`` line per table,
        the line ``The foreign key pairs are:``, then one
        ``t1.'c1' = t2.'c2'`` line per non-empty group. Every line ends with
        a newline.
    """
    lines = [f"The database schema of {dbName} is as follows:"]
    for table_name, table_columns in dbSchema["columns"].items():
        names = table_columns["name"]
        # join() puts separators only between items, so a table without
        # columns renders as "()" rather than clipping the "(" itself.
        column_list = ", ".join(f"'{names[key]}'" for key in names)
        lines.append(f"    Table {table_name}: ({column_list})")
    lines.append("The foreign key pairs are:")
    for group in dbSchema["fk_pairs"]:
        if not group:
            # Nothing to print for an empty group.
            continue
        # Joining on " = " leaves no dangling separator or trailing space.
        lines.append(
            " = ".join(f"{member['table']}.'{member['column']}'" for member in group)
        )
    return "\n".join(lines) + "\n"

In [9]:
# Query under inspection: names of administrators 1-3 for the school(s) whose
# satscores."NumGE1500" equals the table-wide maximum. Each SELECT handles one
# administrator slot; UNION combines them and drops duplicate rows.
# NOTE(review): in SQLite `a || ' ' || b` evaluates to NULL when either operand
# is NULL, so an unused administrator slot contributes a NULL row - presumably
# the blank first row in the recorded output. Confirm whether such rows should
# be filtered out.
SQL = """
-- Step 1: Find the maximum number of students with SAT >= 1500
WITH MaxSAT AS (
    SELECT MAX("NumGE1500") AS MaxNumGE1500
    FROM satscores
)

-- Step 2 & 3: Identify the school(s) with the maximum NumGE1500 and retrieve admin names
SELECT 
    schools."AdmFName1" || ' ' || schools."AdmLName1" AS "Admin_Name"
FROM 
    schools
JOIN 
    satscores ON schools."CDSCode" = satscores."cds"
WHERE 
    satscores."NumGE1500" = (SELECT MaxNumGE1500 FROM MaxSAT)

UNION

SELECT 
    schools."AdmFName2" || ' ' || schools."AdmLName2" AS "Admin_Name"
FROM 
    schools
JOIN 
    satscores ON schools."CDSCode" = satscores."cds"
WHERE 
    satscores."NumGE1500" = (SELECT MaxNumGE1500 FROM MaxSAT)

UNION

SELECT 
    schools."AdmFName3" || ' ' || schools."AdmLName3" AS "Admin_Name"
FROM 
    schools
JOIN 
    satscores ON schools."CDSCode" = satscores."cds"
WHERE 
    satscores."NumGE1500" = (SELECT MaxNumGE1500 FROM MaxSAT);


"""
# Run the query on the notebook-level cursor; the resulting DataFrame is the cell output.
executeSQL(cursor,SQL)

Unnamed: 0,Admin_Name
0,
1,Michelle King
