In [None]:
OPENAI_API_KEY = "<OPENAI API KEY>" # Replace this with your own OpenAI API key

In [1]:
import ast
import glob
import json
import os
import re
import sqlite3
import time
from typing import List, Optional

from func_timeout import func_timeout, FunctionTimedOut
import openai
from openai import APITimeoutError as OpenAITimeout
from openai import APIError
import pandas as pd
from requests.exceptions import Timeout as RequestsTimeout
import sqlglot
from sqlglot import expressions, Expression
import tiktoken

In [2]:
openai.api_key = OPENAI_API_KEY

REFERENCE_SET = "../bird-data-train"
REFERENCE_JSON = "train.json"
REFERENCE_DB = "train_databases"

REFERENCE_JSON = "./" + REFERENCE_SET + "/" + REFERENCE_JSON
REFERENCE_DB_DIR = "./" + REFERENCE_SET + "/" + REFERENCE_DB

SET = "../bird-data-dev"
JSON = "dev.json"
DB = "dev_databases"

JSON_FILE = "./" + SET + "/" + JSON
DB_DIR = "./" + SET + "/" + DB

dryrun = False

if not os.path.exists(DB_DIR):
    raise ValueError(f"DB_DIR {DB_DIR} does not exist.")
if not os.path.exists(JSON_FILE):
    raise ValueError(f"JSON_FILE {JSON_FILE} does not exist.")

# DEFAULT_MODEL = "gpt-4-turbo-preview"
DEFAULT_MODEL = 'gpt-4-0125-preview'

encoding = tiktoken.get_encoding("cl100k_base")

In [3]:
def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3
    return num_tokens


def get_create_table_and_data(db_path: str, num_rows: int = 5) -> List[str]:
    MAX_TOKENS = 5000
    conn = sqlite3.connect(db_path, timeout=30)
    cursor = conn.cursor()
    while num_rows >= 0:
        # Query the sqlite_master table to get the CREATE TABLE statements
        cursor.execute(
            "SELECT name, sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )
        tables = cursor.fetchall()

        output_statements = []

        for table_name, create_statement in tables:
            # "INTEGER" -> "INT"
            create_statement = create_statement.replace("INTEGER", "INT")

            # remove comments
            create_statement = re.sub(
                r"--.*$", "", create_statement, flags=re.MULTILINE
            )
            create_statement = "\n".join(
                [line for line in create_statement.split("\n") if line.strip()]
            )

            # Condense whitespace
            create_statement = " ".join(create_statement.split())

            # First, add the create statement
            output_statements.append(create_statement + ";")

            # Fetch sample data
            cursor.execute(f"SELECT * FROM `{table_name}` LIMIT ?", (num_rows,))
            sample_rows = cursor.fetchall()

            # For each row, create an INSERT INTO statement
            for row in sample_rows:
                formatted_values = []
                for idx, value in enumerate(row):
                    if isinstance(value, str):
                        formatted_value = value.replace("\n", " ")
                        formatted_value = formatted_value.replace("'", '"')
                        formatted_value = formatted_value[:100]
                        formatted_values.append(f"'{formatted_value}'")
                    elif value is None:
                        formatted_values.append("NULL")
                    else:
                        formatted_values.append(str(value))
                values_str = ",".join(formatted_values)

                # Check if table_name contains a space or dash and wrap it in double quotes if it does
                if " " in table_name or "-" in table_name:
                    formatted_table_name = f'"{table_name}"'
                else:
                    formatted_table_name = table_name

                insert_statement = (
                    f"INSERT INTO {formatted_table_name} VALUES ({values_str});"
                )
                output_statements.append(insert_statement)

        msgs = [{"role": "user", "content": "\n".join(output_statements)}]
        token_count = num_tokens_from_messages(msgs)

        if token_count < MAX_TOKENS:
            cursor.close()
            conn.close()
            return output_statements#, tables
        elif num_rows > 0:
            num_rows -= 1
            continue
        else:
            final_statements = []
            for statement in output_statements:
                final_statements.append(statement)
                msgs = [{"role": "user", "content": "\n".join(final_statements)}]
                token_count = num_tokens_from_messages(msgs)

                if token_count > MAX_TOKENS:
                    cursor.close()
                    conn.close()
                    final_statements.pop()
                    return final_statements#, tables
    cursor.close()
    conn.close()
    raise ValueError(f"Even with 0 rows, token count is too high!")


def clean_creates(sql_text: str) -> str:
    """While these fields might be useful for some purposes, I've honestly
    needed them so rarely as a data scientist that we are going to exclude them
    """

    def replace_(node: Expression) -> Optional[Expression]:
        if isinstance(
            node,
            (
                expressions.ColumnConstraint,
                expressions.PrimaryKey,
                expressions.Constraint,
            ),
        ):
            return None
        return node

    return str(sqlglot.parse_one(sql_text).transform(replace_))


def hard_replace__clean_creates(sql_text: str):
    """The backticks and double-quotes are always equivalent in bird
    # but sqlglot cannot yet handle the backticks
    """
    try:
        return clean_creates(
            sql_text.replace("`", '"')
            .replace("WITHOUT ROWID", "")
            .replace("on update cascade", "")
            .replace("ON UPDATE CASCADE", "")
            .replace("on delete cascade", "")
            .replace("ON DELETE CASCADE", "")
            .replace("references staff", "")
        )
    except Exception:
        raise


def read_in_all_sqlite_dbs(directory):
    """Read in all the sqlite databases from the bird data"""
    dirs = glob.glob(directory + "/*")
    statements = []
    for d in dirs:
        if os.path.isfile(d):
            continue
        dbname = d.split("/")[-1]
        sqlite_db_path = os.path.join(d, dbname + ".sqlite")
        assert os.path.exists(d), f"DB {d} does not exist!"
        ddl_list = get_create_table_and_data(sqlite_db_path)
        for ddl in ddl_list:
            statements.append((dbname, hard_replace__clean_creates(ddl)))

    return statements


def make_x(tables, db_id, ideal_sql, question):
    """Make the x and y for the training data"""
    return tables, db_id, ideal_sql, question

In [4]:
df_reference_embeddings = pd.read_csv('BIRD_train_embeddings.csv')
df_reference_embeddings.head()

Unnamed: 0,question,embeddings
0,Name movie titles released in year 1945. Sort ...,"[-0.006165551487356424, 0.0049751438200473785,..."
1,State the most popular movie? When was it rele...,"[-0.02109631896018982, 0.03676282986998558, -0..."
2,What is the name of the longest movie title? W...,"[0.002490998711436987, 0.03317447006702423, -0..."
3,Name the movie with the most ratings.,"[-0.04145917296409607, 0.019352909177541733, -..."
4,What is the average number of Mubi users who l...,"[-0.0009301478858105838, -0.012216581963002682..."


In [5]:
import chardet

def read_database_descriptions(dir):
    consolidated_df = pd.DataFrame()
    
    dirs = glob.glob(dir + "/*")

    typo_corrections = {
        'original_column_name': ['original_column_name'],
        'column_description': ['column_description', 'column_desription']  # Add known typos here
    }
    
    for d in dirs:
        files = glob.glob(d + "/database_description/*")
        
        for f in files:
            # print(f)
            # Attempt to detect the file encoding
            with open(f, 'rb') as file:
                raw_data = file.read(10000)  # Read first 10000 bytes to guess encoding
                result = chardet.detect(raw_data)
                encoding = result['encoding']
                
            try:
                # Try reading with the detected encoding
                tmp = pd.read_csv(f, encoding=encoding)
            except UnicodeDecodeError:
                try:
                    # Fallback: Try reading with 'latin1'
                    tmp = pd.read_csv(f, encoding='latin1')
                except UnicodeDecodeError:
                    # Last resort: Ignore errors or replace invalid characters
                    tmp = pd.read_csv(f, encoding=encoding, errors='replace')

            # Attempt to correct column names based on known typos
            for correct, possible_typos in typo_corrections.items():
                found = False
                for typo in possible_typos:
                    if typo in tmp.columns:
                        tmp.rename(columns={typo: correct}, inplace=True)
                        found = True
                        break
                if not found:
                    # If no matching columns are found, create an empty column to maintain structure
                    tmp[correct] = None
            
            # Ensure tmp has the columns expected; if not, they are filled with None
            expected_columns = ['original_column_name', 'column_description']
            for col in expected_columns:
                if col not in tmp.columns:
                    tmp[col] = None
            
            tmp = tmp[expected_columns]
            tmp['table'] = f.split('/')[-1].rstrip('.csv')
            tmp['db_id'] = d.split('/')[-1]
            consolidated_df = pd.concat([consolidated_df, tmp], ignore_index=True)
    
    return consolidated_df


In [6]:
def format_ddl(ddl_str, column_descriptions):
    formatted_ddls = []

    # Split the ddl_str by "CREATE TABLE"
    create_tables = re.split(r"(?i)CREATE TABLE", ddl_str)

    for ct in create_tables:
        if not ct.strip():
            continue

        # Extract table name from the current CREATE TABLE section
        table_name_match = re.search(r'^\s*("?[\w\s-]+"?|[\w\s-]+)', ct)
        table_name = (
            table_name_match.group(1).strip() if table_name_match else "Unknown Table"
        )

        # Find column descriptions for this table
        table_column_descriptions = column_descriptions[column_descriptions['table'].str.strip('"') == table_name.strip('"')]

        # Split the current section at "INSERT INTO"
        splits = ct.split("INSERT INTO")

        # Extract column names and remove the table name from it
        columns = re.sub(r"^\s*" + re.escape(table_name), "", splits[0]).strip()
        if columns.startswith("(") and columns.endswith(")"):
            columns = columns[1:-1]

        # Process compound foreign keys
        fk_pattern = r"(FOREIGN KEY \([^)]+\) REFERENCES [\w]+ \([^)]+\))"
        foreign_keys = re.findall(fk_pattern, columns, re.DOTALL)  # Use DOTALL to match across newlines
        # Replace newlines in foreign key definitions to ensure they stay on one line and use placeholders for commas
        placeholder_mapping = {}
        for i, fk in enumerate(foreign_keys):
            placeholder = f"__PLACEHOLDER_{i}__"
            corrected_fk_str = re.sub(r"\s*\n\s*", " ", fk)  # Replace newlines and surrounding spaces with a single space
            corrected_fk_str = corrected_fk_str.replace(",", placeholder)  # Temporarily replace commas
            placeholder_mapping[placeholder] = ","
            columns = columns.replace(fk, corrected_fk_str)

        # Process INSERT statements
        if ' ' in table_name or '-' in table_name:   
            cleaned_table_name = table_name
        else:
            cleaned_table_name = table_name.strip('"')
        insert_statements = [
            split.replace(f"{cleaned_table_name} VALUES", "").strip()
            for split in splits[1:]
        ]
        # cast the strings to tuples
        insert_tuples = []
        for statement in insert_statements:
            corrected_statement = statement.replace('NULL', 'None')
            result = ast.literal_eval(corrected_statement)
            if isinstance(result, tuple):
                insert_tuples.append(result)
            else:
                insert_tuples.append((result,))

        columns = " ".join(columns.split())
        pattern = r', UNIQUE \([^)]*\)'
        columns = re.sub(pattern, '', columns)

        # Append column descriptions and example values
        column_lines = columns.split(", ")
        # Restore commas in foreign keys
        for placeholder, comma in placeholder_mapping.items():
            column_lines = [line.replace(placeholder, comma) for line in column_lines]

        for i, line in enumerate(column_lines):
            column_name_match = re.match(r'\s*"?(\w+)"?\s+', line)
            if column_name_match:
                column_name = column_name_match.group(1)
                description = table_column_descriptions.loc[table_column_descriptions['original_column_name'] == column_name, 'column_description'].values
                if description.size > 0 and not pd.isna(description[0]):
                    column_lines[i] += f" -- {description[0]}"

                example_values = ""
                unique_values = set()
                for insert in insert_tuples:
                    if i < len(insert) and not pd.isna(insert[i]):
                        if ',' in str(insert[i]):
                            example_value = json.dumps(str(insert[i]))
                        else:
                            example_value = f"{insert[i]}"
                        if example_value not in unique_values and len(example_values) < 50:
                            unique_values.add(example_value)
                            example_values += example_value + ", "

                
                # Trim the trailing comma and space
                example_values = example_values[:-2]

                if example_values:
                    column_lines[i] += f" -- {example_values}"

        columns_with_descriptions = "\n\t".join(column_lines)

        # Combine the statements for the current table and append to formatted_ddls
        formatted_ddl = (
            table_name
            + " (\n\t"
            + columns_with_descriptions
            + "\n)"
        )
        # if insert_statements:
        #     formatted_ddl += "\n" + "\n".join(insert_statements)

        formatted_ddls.append(formatted_ddl)

    return "\n".join(formatted_ddls)


In [7]:
ddl_statements = read_in_all_sqlite_dbs(DB_DIR)

df_ddl = pd.DataFrame(ddl_statements)
df_ddl.columns = ["db_id", "ddl"]
df_ddl = df_ddl.groupby("db_id")["ddl"].agg("\n".join).reset_index(name="ddl")

column_descriptions = read_database_descriptions(DB_DIR)
column_descriptions.head()

df_ddl["ddl"] = df_ddl.apply(lambda row: format_ddl(row["ddl"], column_descriptions[column_descriptions['db_id'] == row['db_id']]), axis=1)

df_ddl.head()

Unnamed: 0,db_id,ddl
0,california_schools,frpm (\n\tCDSCode TEXT -- CDSCode -- 011001701...
1,card_games,"""cards"" (\n\tid INT -- 1, 2, 3, 4, 5\n\tartist..."
2,codebase_community,"badges (\n\tId INT -- 1, 2, 3, 4, 5\n\tUserId ..."
3,debit_card_specializing,"customers (\n\tCustomerID INT -- 3, 5, 6, 7, 9..."
4,european_football_2,"""Player_Attributes"" (\n\t""id"" INT -- 1, 2, 3, ..."


In [8]:
reference_ddl_statements = read_in_all_sqlite_dbs(REFERENCE_DB_DIR)

df_reference_ddl = pd.DataFrame(reference_ddl_statements)
df_reference_ddl.columns = ["db_id", "ddl"]
df_reference_ddl = df_reference_ddl.groupby("db_id")["ddl"].agg("\n".join).reset_index(name="ddl")

reference_column_descriptions = read_database_descriptions(REFERENCE_DB_DIR)
reference_column_descriptions.head()

df_reference_ddl["ddl"] = df_reference_ddl.apply(lambda row: format_ddl(row["ddl"], reference_column_descriptions[reference_column_descriptions['db_id'] == row['db_id']]), axis=1)

In [9]:
for ddl in df_reference_ddl["ddl"]:
    if 'SpecialOfferID' in ddl:
        print(ddl)

CountryRegion (
	CountryRegionCode TEXT -- The unique id number identifying Country Region ISO standard code for countries and regions.
	Name TEXT -- Country or region name.
	ModifiedDate DATETIME -- Date and time the record was last updated. 
)
Culture (
	CultureID TEXT -- The unique id string identifying the language in which AdventrueWorks data is stored.
	Name TEXT -- Name of the language.
	ModifiedDate DATETIME -- Date and time the record was last updated. 
)
Currency (
	CurrencyCode TEXT -- The unique id string identifying the currency.
	Name TEXT -- Currency name.
	ModifiedDate DATETIME -- Date and time the record was last updated. 
)
CountryRegionCurrency (
	CountryRegionCode TEXT -- The id number identifying Country Region ISO standard code for countries and regions.
	CurrencyCode TEXT -- ISO standard currency code.
	ModifiedDate DATETIME -- Date and time the record was last updated. 
	FOREIGN KEY (CountryRegionCode) REFERENCES CountryRegion (CountryRegionCode)
	FOREIGN KEY (C

In [10]:
df_reference_questions = pd.read_json(REFERENCE_JSON)

df_reference_questions = pd.merge(df_reference_questions, df_reference_embeddings, on='question')
df_reference_questions = pd.merge(df_reference_questions, df_reference_ddl, on='db_id')

df_reference_questions['user_prompt'] = df_reference_questions.apply(lambda x: x["ddl"] 
                                                                             + "\n\n" 
                                                                             + x["question"] 
                                                                             + ("\n" + x["evidence"] if x["evidence"] else ""), axis=1)

df_reference_questions.head()

Unnamed: 0,db_id,question,evidence,SQL,embeddings,ddl,user_prompt
0,movie_platform,Name movie titles released in year 1945. Sort ...,released in the year 1945 refers to movie_rele...,SELECT movie_title FROM movies WHERE movie_rel...,"[-0.006165551487356424, 0.0049751438200473785,...","""lists"" (\n\tuser_id INT -- 88260493, 45204418...","""lists"" (\n\tuser_id INT -- 88260493, 45204418..."
1,movie_platform,State the most popular movie? When was it rele...,most popular movie refers to MAX(movie_popular...,"SELECT movie_title, movie_release_year, direct...","[-0.02109631896018982, 0.03676282986998558, -0...","""lists"" (\n\tuser_id INT -- 88260493, 45204418...","""lists"" (\n\tuser_id INT -- 88260493, 45204418..."
2,movie_platform,What is the name of the longest movie title? W...,longest movie title refers to MAX(LENGTH(movie...,"SELECT movie_title, movie_release_year FROM mo...","[0.002490998711436987, 0.03317447006702423, -0...","""lists"" (\n\tuser_id INT -- 88260493, 45204418...","""lists"" (\n\tuser_id INT -- 88260493, 45204418..."
3,movie_platform,Name the movie with the most ratings.,movie with the most rating refers to MAX(SUM(r...,SELECT movie_title FROM movies GROUP BY movie_...,"[-0.04145917296409607, 0.019352909177541733, -...","""lists"" (\n\tuser_id INT -- 88260493, 45204418...","""lists"" (\n\tuser_id INT -- 88260493, 45204418..."
4,movie_platform,What is the average number of Mubi users who l...,average = AVG(movie_popularity); number of Mub...,SELECT AVG(movie_popularity) FROM movies WHERE...,"[-0.0009301478858105838, -0.012216581963002682...","""lists"" (\n\tuser_id INT -- 88260493, 45204418...","""lists"" (\n\tuser_id INT -- 88260493, 45204418..."


In [11]:
df_question = pd.read_json(JSON_FILE)
joined_df = pd.merge(df_question, df_ddl, on=["db_id"])

joined_df.head()

Unnamed: 0,question_id,db_id,question,evidence,SQL,difficulty,ddl
0,0,california_schools,What is the highest eligible free rate for K-1...,Eligible free rate for K-12 = `Free Meal Count...,SELECT `Free Meal Count (K-12)` / `Enrollment ...,simple,frpm (\n\tCDSCode TEXT -- CDSCode -- 011001701...
1,1,california_schools,Please list the lowest three eligible free rat...,Eligible free rates for students aged 5-17 = `...,SELECT `Free Meal Count (Ages 5-17)` / `Enroll...,moderate,frpm (\n\tCDSCode TEXT -- CDSCode -- 011001701...
2,2,california_schools,Please list the zip code of all the charter sc...,Charter schools refers to `Charter School (Y/N...,SELECT T2.Zip FROM frpm AS T1 INNER JOIN schoo...,simple,frpm (\n\tCDSCode TEXT -- CDSCode -- 011001701...
3,3,california_schools,What is the unabbreviated mailing address of t...,,SELECT T2.MailStreet FROM frpm AS T1 INNER JOI...,simple,frpm (\n\tCDSCode TEXT -- CDSCode -- 011001701...
4,4,california_schools,Please list the phone numbers of the direct ch...,Charter schools refers to `Charter School (Y/N...,SELECT T2.Phone FROM frpm AS T1 INNER JOIN sch...,moderate,frpm (\n\tCDSCode TEXT -- CDSCode -- 011001701...


In [12]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def find_similar_questions(df_questions, df_reference_questions, num_questions = 3):
    # Convert string representations of embeddings back to numpy arrays
    embeddings_reference = np.array([np.array(ast.literal_eval(emb)) for emb in df_reference_questions['embeddings']])
    embeddings_questions = np.array(df_questions['embeddings'].tolist())

    similarities = cosine_similarity(embeddings_questions, embeddings_reference)
    similar_questions = {}

    for index, row in df_questions.iterrows():
        current_db_id = row['db_id']
        # similarities[index, index] = -np.inf  # Exclude the current question

        # Sort the indexes of other questions based on similarity scores
        sorted_indices = np.argsort(similarities[index])[::-1]
        filtered_indices = []
        seen_db_ids = {current_db_id}

        for i in sorted_indices:
            if len(filtered_indices) >= num_questions:
                break  # Stop when we have found enough questions

            candidate_db_id = df_reference_questions.iloc[i]['db_id']
            if candidate_db_id not in seen_db_ids:
                seen_db_ids.add(candidate_db_id)
                filtered_indices.append(i)

        similar_questions[index] = filtered_indices

    return similar_questions

In [13]:
def get_embeddings(df, column_to_embed):
    embeddings = []
    for text in df[column_to_embed]:
        retries = 0
        while retries < 3:
            try:
                response = openai.embeddings.create(input=text, model="text-embedding-3-large")
                embedding = response.data[0].embedding
                embeddings.append(embedding)
                break
            except Exception as e:
                print(e)
                retries += 1
                time.sleep(5)
                if retries == 3:
                    print(f"Failed to get embedding for {text}")
                    # kill the function
                    raise   

    return embeddings

In [14]:
df = pd.DataFrame(
    joined_df.apply(
        lambda x: make_x(
            x["ddl"]
            + "\n\n"
            + x["question"]            
            + ("\n" + x["evidence"] if x["evidence"] else ""),
            x["db_id"],
            x["SQL"],
            x['question']
        ),
        axis=1,
    ).tolist()
)
df.columns = ["user_prompt", "db_id", "ideal_assistant_response", 'question']

df = df.sample(n=500, random_state=1234).reset_index(drop=True) 
# df = df.tail(10).reset_index(drop=True)

df['embeddings'] = get_embeddings(df, 'question')
start = time.time()
df['similar_questions'] = find_similar_questions(df, df_reference_questions, num_questions=0)
end = time.time()
print(f"Time to find similar questions: {end - start} seconds")

df.head()

Time to find similar questions: 81.02810716629028 seconds


Unnamed: 0,user_prompt,db_id,ideal_assistant_response,question,embeddings,similar_questions
0,Examination (\n\tID INT -- identification of t...,thrombosis_prediction,"SELECT DISTINCT T1.ID, T1.SEX, T1.Birthday FRO...","List ID, sex and date of birth of patient whos...","[0.015778781846165657, 0.005091146100312471, -...",[]
1,frpm (\n\tCDSCode TEXT -- CDSCode -- 011001701...,california_schools,SELECT GSserved FROM schools WHERE City = 'Ade...,What is the most common type of grade span ser...,"[0.014656159095466137, 0.02713850326836109, -0...",[]
2,"""atom"" (\n\t""atom_id"" TEXT -- the unique id of...",toxicology,SELECT COUNT(T.molecule_id) FROM molecule AS T...,Calculate the total carcinogenic molecules for...,"[-0.026426946744322777, -0.053564295172691345,...",[]
3,Examination (\n\tID INT -- identification of t...,thrombosis_prediction,SELECT COUNT(DISTINCT T1.ID) FROM Patient AS T...,"Excluding all P only ANA Pattern patients, how...","[-0.024667911231517792, 0.005138586275279522, ...",[]
4,"badges (\n\tId INT -- 1, 2, 3, 4, 5\n\tUserId ...",codebase_community,SELECT COUNT(id) FROM users WHERE Reputation >...,How many users whose reputations are higher th...,"[-0.015890464186668396, -0.012033837847411633,...",[]


In [15]:
SYSTEM_PROMPT = """You are a principal data engineer. Help users write SQL queries for a SQLite database to answer their questions.

For every user question, you'll be provided with context about their database in the following format:
<table> (
    <column> <data_type> -- <column_description> -- <example_value>, <example_value>, ...
    ...
    <column> <data_type> -- <column_description> -- <example_value>, <example_value>, ...
    FOREIGN KEY (<column>) REFERENCES <table> (<column>)
)

Not all columns have descriptions or example values. Not all tables have foreign keys.

You should only respond in this JSON format:
{
    "sql": "the sql that answers the user's question"
}"""

In [16]:
MAX_RETRIES = 3
RETRY_DELAY = 10  # seconds


def create_response(model, msgs, response_max_tokens=1000):
    for attempt in range(MAX_RETRIES):
        if dryrun:
            return {"choices": [{"message": {"content": "SELECT 1"}}]}
        try:
            return openai.chat.completions.create(
                model=model,
                messages=msgs,
                temperature=0.0,
                n=1,
                max_tokens=response_max_tokens,
                seed=9385
            )
        except (OpenAITimeout, RequestsTimeout, APIError):
            if attempt < MAX_RETRIES - 1:  # i.e. if not the last attempt
                print(f"Timeout occurred, retrying in {RETRY_DELAY} seconds...")
                time.sleep(RETRY_DELAY)
            else:
                raise


def generate_similar_examples(_row, _df_reference_questions):
    similar_questions_indices = _row['similar_questions']
    msgs = []
    for idx in similar_questions_indices:
        ref_row = _df_reference_questions.iloc[idx]
        user_msg = {"role": "user", "content": ref_row['user_prompt']}
        assistant_msg = {
            "role": "assistant",
            "content": json.dumps({
                "sql": ref_row['SQL']
            }, indent=4)
        }
        msgs.append(user_msg)
        msgs.append(assistant_msg)

    return msgs

            
empty_sql_correction_msg = [{"role": "user", "content": "You must reply with SQL."}]

def grab_response_from_chatgpt(model, user_msg, similar_questions):
    user_msg = user_msg
    intro_messages = [{"role": "system", "content": SYSTEM_PROMPT}] + similar_questions
    msgs = intro_messages + [{"role": "user", "content": user_msg}]
    for msg in msgs:
        print(msg['content'])
    response = create_response(model, msgs).choices[0].message.content
    print(response)
    try:
        sql = json.loads(response)["sql"]
    except Exception as e:
        print(e)
        if "Invalid \escape" in str(e):
            print("Invalid escape character detected in response. Attempting to correct...")
            response = response.replace("\\", "\\\\")
            sql = json.loads(response)["sql"]

    if sql == "":
        msgs += [{"role": "assistant", "content": response}]
        msgs += empty_sql_correction_msg
        response = create_response(model, msgs).choices[0].message.content
    return response


def error_correction_from_chatgpt(model, user_msg, predicted_sql, error_message):
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": predicted_sql},
        {
            "role": "user",
            "content": f"That SQL produced this error message: \"{error_message}\". Write a new query. If you receieved a 'no such column' error, consider whether you pulled the column from the correct table. Don't apologize. Respond only with SQL.",
        },
    ]
    return create_response(model, msgs).choices[0].message.content


def execute_query_safe(db_name, query):
    """Executes a SQL query safely and returns the fetched results."""
    path = f"{DB_DIR}/{db_name}/{db_name}.sqlite"
    conn = sqlite3.connect(path, timeout=30)
    cursor = conn.cursor()
    try:
        cursor.execute(query)
        return set(cursor.fetchall()), None
    except sqlite3.OperationalError as e:
        return None, str(e)
    except Exception as e:
        return None, f"Unexpected error: {e}"
    finally:
        conn.close()


def correct_sql_and_execute(db_name, model, user_msg, predicted_query, error_msg, timeout=30):
    """Attempts to correct a SQL query and executes the corrected version."""
    corrected_sql_response = error_correction_from_chatgpt(
        model, user_msg, predicted_query, error_msg
    )
    try:
        corrected_sql = json.loads(corrected_sql_response)["sql"]
        print(f"Corrected SQL: {corrected_sql}")
        result_set, error = func_timeout(timeout, execute_query_safe, args=(db_name, corrected_sql))
        if error:
            print(error)
            return None, corrected_sql
        return result_set, corrected_sql
    except FunctionTimedOut:
        print("SQL correction timed out")
        return None, corrected_sql
    except Exception as e:
        print(str(e))
        return (None, "Failed to correct SQL"), None
    

def check_execution_accuracy(predicted_query, ideal_query, db_name, model, user_msg):
    is_equal = False
    _predicted_set = set()
    predicted_query_final = predicted_query
    error = None

    # Execute the ideal query
    try:
        ideal_set, error = func_timeout(30, execute_query_safe, args=(db_name, ideal_query))
    except FunctionTimedOut:
        error = "Ideal SQL execution timed out"
        ideal_set = set()
    except Exception as e:
        error = f"Unexpected error during ideal SQL execution: {str(e)}"
        ideal_set = set()
    if error:
        print(f"Error with ideal SQL: {error}")
        return is_equal, _predicted_set, predicted_query
    
    error = None
    
    # Execute the predicted query
    try:
        _predicted_set, error = func_timeout(30, execute_query_safe, args=(db_name, predicted_query))
    except FunctionTimedOut:
        error = "Predicted SQL execution timed out"
        _predicted_set = set()
    except Exception as e:
        error = f"Unexpected error during predicted SQL execution: {str(e)}"
        _predicted_set = set()
    if error:
        print(f"Error with predicted SQL: {error}")
        result_set, corrected_sql = correct_sql_and_execute(
            db_name, model, user_msg, predicted_query, error
        )
        _predicted_set = result_set  # Update _predicted_set with result from corrected execution
        if corrected_sql:
            predicted_query_final = corrected_sql

    # Check if correction is needed due to empty result set
    if _predicted_set is not None and len(_predicted_set) == 0:
        print("Empty result set")
        result_set, corrected_sql = correct_sql_and_execute(
            db_name, model, user_msg, predicted_query, "Empty result set"
        )
        _predicted_set = result_set  # Update _predicted_set with result from corrected execution
        if corrected_sql:
            predicted_query_final = corrected_sql

    # Final comparison
    if _predicted_set is not None and ideal_set is not None and _predicted_set == ideal_set:
        is_equal = True
    elif _predicted_set is None and ideal_set is None:
        is_equal = True

    return is_equal, _predicted_set, predicted_query_final


In [17]:
predicted_res = []
first_predicted_queries = []
final_predicted_queries = []
execution_accuracies = []
query_changes = []

df_pred = df.copy()

for index, row in df_pred.iterrows():
    completion = ""
    print(f"processing row {index} of {len(df_pred)}")
    similar_questions = generate_similar_examples(row, df_reference_questions)
    try:
        completion = grab_response_from_chatgpt(
            model=DEFAULT_MODEL, user_msg=row["user_prompt"], similar_questions=similar_questions
        )
    except Exception as e:
        if "Read timed out" in str(e) or "You requested a model that is not compatible with this engine" in str(e):
            print("The request to OpenAI API timed out or the model was incompatible. Retrying...")
            completion = grab_response_from_chatgpt(
                model=DEFAULT_MODEL, user_msg=row["user_prompt"], similar_questions=similar_questions
            )
        else:
            print(f"An unexpected error occurred: {e}")
    print('here 1')
    # first_predicted_query = ast.literal_eval(completion)["sql"]  # type: ignore
    try:
        first_predicted_query = json.loads(completion)["sql"]
    except Exception as e:
        print(e)
        first_predicted_query = ""
    db_name = row["db_id"]
    print('here 2')
    execution_accuracy, predicted_set, final_predicted_query = check_execution_accuracy(
        first_predicted_query,
        row["ideal_assistant_response"],
        db_name,
        DEFAULT_MODEL,
        row["user_prompt"],
    )
    print('here 3')

    execution_accuracies.append(1 if execution_accuracy else 0)
    first_predicted_queries.append(first_predicted_query)
    final_predicted_queries.append(final_predicted_query)
    predicted_res.append(predicted_set)
    print('here 5')


df_pred["first_predicted_query"] = first_predicted_queries
df_pred["final_predicted_query"] = final_predicted_queries
df_pred["predicted_res"] = predicted_res
df_pred["execution_accuracy"] = execution_accuracies

df_pred["formatted_string"] = (
    df_pred["final_predicted_query"] + "\t----- bird -----\t" + df_pred["db_id"]
)

df_pred.head()

processing row 0 of 500
You are a principal data engineer. Help users write SQL queries for a SQLite database to answer their questions.

For every user question, you'll be provided with context about their database in the following format:
<table> (
    <column> <data_type> -- <column_description> -- <example_value>, <example_value>, ...
    ...
    <column> <data_type> -- <column_description> -- <example_value>, <example_value>, ...
    FOREIGN KEY (<column>) REFERENCES <table> (<column>)
)

Not all columns have descriptions or example values. Not all tables have foreign keys.

You should only respond in this JSON format:
{
    "sql": "the sql that answers the user's question"
}
Examination (
	ID INT -- identification of the patient -- 14872, 48473, 102490, 108788, 122405
	"Examination Date" DATE -- 1997-05-27, 1992-12-21, 1995-04-20, 1997-05-06, 1998-04-02
	"aCL IgG" FLOAT -- 1.3, 4.3, 2.3, 0.0
	"aCL IgM" FLOAT -- 1.6, 4.6, 2.5, 0.0, 4.0
	ANA INT -- anti-nucleus antibody concentrati

Unnamed: 0,user_prompt,db_id,ideal_assistant_response,question,embeddings,similar_questions,first_predicted_query,final_predicted_query,predicted_res,execution_accuracy,formatted_string
0,Examination (\n\tID INT -- identification of t...,thrombosis_prediction,"SELECT DISTINCT T1.ID, T1.SEX, T1.Birthday FRO...","List ID, sex and date of birth of patient whos...","[0.015778781846165657, 0.005091146100312471, -...",[],"SELECT Patient.ID, Patient.SEX, Patient.Birthd...","SELECT Patient.ID, Patient.SEX, Patient.Birthd...","{(4862013, F, 1964-01-29), (2355809, F, 1938-0...",1,"SELECT Patient.ID, Patient.SEX, Patient.Birthd..."
1,frpm (\n\tCDSCode TEXT -- CDSCode -- 011001701...,california_schools,SELECT GSserved FROM schools WHERE City = 'Ade...,What is the most common type of grade span ser...,"[0.014656159095466137, 0.02713850326836109, -0...",[],"SELECT GSserved, COUNT(*) AS count FROM school...","SELECT GSserved, COUNT(*) AS count FROM school...","{(K-6, 5)}",0,"SELECT GSserved, COUNT(*) AS count FROM school..."
2,"""atom"" (\n\t""atom_id"" TEXT -- the unique id of...",toxicology,SELECT COUNT(T.molecule_id) FROM molecule AS T...,Calculate the total carcinogenic molecules for...,"[-0.026426946744322777, -0.053564295172691345,...",[],SELECT COUNT(*) AS total_carcinogenic_molecule...,SELECT COUNT(*) AS total_carcinogenic_molecule...,"{(7,)}",1,SELECT COUNT(*) AS total_carcinogenic_molecule...
3,Examination (\n\tID INT -- identification of t...,thrombosis_prediction,SELECT COUNT(DISTINCT T1.ID) FROM Patient AS T...,"Excluding all P only ANA Pattern patients, how...","[-0.024667911231517792, 0.005138586275279522, ...",[],SELECT COUNT(DISTINCT Patient.ID) FROM Patient...,SELECT COUNT(DISTINCT Patient.ID) FROM Patient...,"{(3,)}",1,SELECT COUNT(DISTINCT Patient.ID) FROM Patient...
4,"badges (\n\tId INT -- 1, 2, 3, 4, 5\n\tUserId ...",codebase_community,SELECT COUNT(id) FROM users WHERE Reputation >...,How many users whose reputations are higher th...,"[-0.015890464186668396, -0.012033837847411633,...",[],SELECT COUNT(*) FROM users WHERE Reputation > ...,SELECT COUNT(*) FROM users WHERE Reputation > ...,"{(44,)}",1,SELECT COUNT(*) FROM users WHERE Reputation > ...


In [18]:
df_reference_questions[df_reference_questions['question'].str.contains('How many word that has')].iloc[0]#['question']

db_id                                            language_corpus
question       How many word that has number of different wor...
evidence                                            This is not;
SQL            SELECT COUNT(T2.wid) FROM pages AS T1 INNER JO...
embeddings     [-0.009390588849782944, 0.0025102547369897366,...
ddl            langs (\n\tlid INT -- 1\n\tlang TEXT -- ca\n\t...
user_prompt    langs (\n\tlid INT -- 1\n\tlang TEXT -- ca\n\t...
Name: 5779, dtype: object

In [19]:
row = df_pred[df_pred['execution_accuracy'] == 0].iloc[-2]
# print row's index
print(row.name)
print(row['user_prompt'])
print(row['ideal_assistant_response'])
print(row['first_predicted_query'])

495
frpm (
	CDSCode TEXT -- CDSCode -- 01100170109835, 01100170112607, 01100170118489, 01100170123968
	"Academic Year" TEXT -- 2014-2015
	"County Code" TEXT -- 01
	"District Code" INT -- 10017
	"School Code" TEXT -- 0109835, 0112607, 0118489, 0123968, 0124172
	"County Name" TEXT -- Alameda
	"District Name" TEXT -- Alameda County Office of Education
	"School Name" TEXT -- FAME Public Charter, Envision Academy for Arts & Technology
	"District Type" TEXT -- County Office of Education (COE)
	"School Type" TEXT -- K-12 Schools (Public), High Schools (Public), Elementary Schools (Public)
	"Educational Option Type" TEXT -- Traditional
	"NSLP Provision Status" TEXT -- Breakfast Provision 2
	"Charter School (Y/N)" INT -- 1
	"Charter School Number" TEXT -- 0728, 0811, 1049, 1284, 1296
	"Charter Funding Type" TEXT -- Directly funded
	IRC INT -- 1
	"Low Grade" TEXT -- K, 9
	"High Grade" TEXT -- 12, 8
	"Enrollment (K-12)" FLOAT -- 1087.0, 395.0, 244.0, 191.0, 257.0
	"Free Meal Count (K-12)" FLOAT -

In [20]:
df_pred['execution_accuracy'].mean()

0.582