In [51]:
import time
import json
import openai
import re
import sqlite3
import pandas as pd
import numpy as np

from os import path
from openai.api_requestor import error

In [64]:
def get_schema(db_id):
    PATH = 'spider/database/'
    
    path_to_file   = PATH + db_id + '/schema.sql'
    path_to_file_2 = PATH + db_id + '/' + db_id + '.sqlite'
    
    if path.exists(path_to_file):
        x = open(path_to_file, 'r').read()
        # Remove comment lines
        x = re.sub(r"^/\*.*\n", "", x, flags=re.MULTILINE)
        x = re.sub(r'^--.*(\n|$)', '', x, flags=re.MULTILINE)
        x = re.sub(r"^/\*.*\*/", "", x)
        
        x = re.sub(r'CREATE TABLE \t', 'CREATE TABLE ', x)
        
        arr = x.split(';')
        create_tables = []
        for i, stmt in enumerate(arr):
            stmt = stmt.strip()
            if stmt.lower().startswith('create'):
                create_tables.append(stmt.replace('IF NOT EXISTS ', '').replace('\n', ''))
        return ';'.join(create_tables)
      
    elif path.exists(path_to_file_2):
        # Connect to the sqlite file
        conn = sqlite3.connect(path_to_file_2)

        # Read the schema of the twitter_1.sqlite file into a dataframe
        df_conn = pd.read_sql_query("SELECT sql FROM sqlite_master WHERE type='table';", conn)
        
        # Close the connection
        conn.close()
        
        create_tables = []
        for idx, row in df_conn.iterrows():
            row['sql'] = row['sql'].strip()
            if row['sql'].lower().startswith('create'):
                create_tables.append(row['sql'].replace('IF NOT EXISTS ', '').replace('\n', ''))
        return ';'.join(create_tables)
    
    else: return None

def short_open_ai_prompt(x):
    return f'''Schema: {x['schema']}\nQuestion: {x['question']}\n\n###\n\n'''

# From OpenAI: The completion should start with a whitespace character (` `). 
# This tends to produce better results due to the tokenization we use.
def open_ai_completion(x):
    return f" {x['query']}\n" 

def call_model(row, engine):
    prompt = row["open_ai_prompt"]
    while True:
        try:
            completions = openai.Completion.create(
                engine=engine,
                prompt=prompt,
                max_tokens=1024,
                n=1,
                stop=["\n"],
                temperature=0.5
            )

            print(completions.choices[0].text)
            return completions.choices[0].text
    
        except error.ServiceUnavailableError:
            print('ServiceUnavailableError')
            time.sleep(1)
            continue
            
        except error.InvalidRequestError:
            print('InvalidRequestError: too many tokens')
            time.sleep(1)
            break

# Data Preparation

In [3]:
df = pd.read_json('spider/train_spider.json')
# df = df[~df['query'].str.contains('JOIN')]
# df = df[df['query'].str.count('FROM') <= 1]

# Set schema
df['schema']         = df.apply(lambda x: get_schema(x['db_id']), axis=1)
df['schema'] = df['schema'].apply(lambda x: re.sub('(?i) REFERENCES.*?(;|$)', ';', x))
# # Extract the name of the table after FROM
# df['table']  = df['query'].apply(lambda x: re.sub('.*FROM', 'FROM', x))\
#                           .apply(lambda x: re.sub('^FROM\s+(\S+).*', r'\1', x).rstrip(';')).str.lower()
# # Extract the 'create table' statement just for that table
# df['schema'] = df.apply(lambda x: ''.join(re.findall(r"(create table `{}`.*?;)".format(x["table"]), 
#                                                      x["schema"].lower())) or 
#                                   ''.join(re.findall(r"(create table `{}`.*?$)".format(x["table"]), 
#                                                      x["schema"].lower())) or 
#                                   ''.join(re.findall(r"(create table {}.*?;)".format(x["table"]), 
#                                                      x["schema"].lower())) or 
#                                   ''.join(re.findall(r"(create table {}.*?$)".format(x["table"]), 
#                                                      x["schema"].lower())) or 
#                                   ''.join(re.findall(r"(create table \"{}\".*?;)".format(x["table"]), 
#                                                      x["schema"].lower())) or
#                                   ''.join(re.findall(r"(create table \"{}\".*?$)".format(x["table"]), 
#                                                      x["schema"].lower())), axis=1)
df['schema'] = df['schema'].apply(lambda x: re.sub('NOT NULL', '', x))

# Set Open AI prompt and completion
df['open_ai_prompt'] = df.apply(lambda x: short_open_ai_prompt(x), axis=1)
df['open_ai_completion'] = df.apply(lambda x: open_ai_completion(x), axis=1)

# Randomize at the db level
df_db_id = pd.DataFrame(df['db_id'].unique(), columns=['db_id'])
np.random.seed(240956) #set seed
df_db_id['train_test'] = np.random.choice(['train','test'], df_db_id.shape[0], p=[0.8, 0.2])
df = df.merge(df_db_id, on='db_id')
df_train = df[df['train_test'] == 'train'][['open_ai_prompt', 'open_ai_completion']].copy()
df_test  = df[df['train_test'] == 'test'][['open_ai_prompt', 'open_ai_completion']].copy()

In [4]:
df[df['schema'] == '']['db_id'].unique()

array([], dtype=object)

In [5]:
print(df.iloc[110]['open_ai_prompt'])
print(df.iloc[110]['open_ai_completion'])

Schema: CREATE TABLE station (    id INTEGER PRIMARY KEY,    name TEXT,    lat NUMERIC,    long NUMERIC,    dock_count INTEGER,    city TEXT,    installation_date TEXT);CREATE TABLE status (    station_id INTEGER,    bikes_available INTEGER,    docks_available INTEGER,    time TEXT,    FOREIGN KEY (station_id);CREATE TABLE trip (    id INTEGER PRIMARY KEY,    duration INTEGER,    start_date TEXT,    start_station_name TEXT, -- this should be removed    start_station_id INTEGER,    end_date TEXT,    end_station_name TEXT, -- this should be removed    end_station_id INTEGER,    bike_id INTEGER,    subscription_type TEXT,    zip_code INTEGER);CREATE TABLE weather (    date TEXT,    max_temperature_f INTEGER,    mean_temperature_f INTEGER,    min_temperature_f INTEGER,    max_dew_point_f INTEGER,    mean_dew_point_f INTEGER,    min_dew_point_f INTEGER,    max_humidity INTEGER,    mean_humidity INTEGER,    min_humidity INTEGER,    max_sea_level_pressure_inches NUMERIC,    mean_sea_level_pre

In [6]:
df_train.shape

(5944, 2)

In [7]:
df_test.shape

(1056, 2)

# Training

In [8]:
# Put the training data into jsonl format
data = []
for idx, row in df_train.iterrows():
    data.append({"prompt": row["open_ai_prompt"], "completion": row["open_ai_completion"]})

timestr = time.strftime("%Y%m%d-%H%M%S")
with open("spider_open_ai_fine_tuning_" + timestr + ".jsonl", "w") as outfile:
    for obj in data:
        json.dump(obj, outfile)
        outfile.write("\n")
        
print("spider_open_ai_fine_tuning_" + timestr + ".jsonl")

spider_open_ai_fine_tuning_20230126-235556.jsonl


In [9]:
#!openai tools fine_tunes.prepare_data -f spider_open_ai_fine_tuning_20230126-235556.jsonl
# - There are 6 duplicated prompt-completion sets. These are rows: [1296, 2000, 2097, 2984, 2985, 3799]

In [10]:
df_train_dedup = df_train.copy()
df_train_dedup = df_train_dedup.drop(df_train.index[[1296, 2000, 2097, 2984, 2985, 3799]])
df_train_dedup[(df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[1296]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[2000]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[2097]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[2984]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[2985]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[3799]['open_ai_prompt'])]

Unnamed: 0,open_ai_prompt,open_ai_completion
1421,Schema: CREATE TABLE classroom\t(building\t\tv...,SELECT title FROM course WHERE course_id NOT ...
2344,"Schema: CREATE TABLE ""Campuses"" (\t""Id"" INTEGE...",SELECT campus FROM campuses WHERE county = ...
2442,Schema: create table Movie(\tmID int primary k...,SELECT count(*) FROM Reviewer\n
3489,Schema: CREATE TABLE `regions` ( `REGION_ID` ...,SELECT job_id FROM employees GROUP BY job_id ...
3490,Schema: CREATE TABLE `regions` ( `REGION_ID` ...,SELECT job_id FROM employees GROUP BY job_id ...
4405,Schema: CREATE TABLE Person ( name varchar(20...,SELECT count(DISTINCT city) FROM Person\n


In [11]:
# Put the training data into jsonl format
data = []
for idx, row in df_train_dedup.iterrows():
    data.append({"prompt": row["open_ai_prompt"], "completion": row["open_ai_completion"]})

timestr = time.strftime("%Y%m%d-%H%M%S")
with open("spider_open_ai_fine_tuning_" + timestr + ".jsonl", "w") as outfile:
    for obj in data:
        json.dump(obj, outfile)
        outfile.write("\n")
        
print("spider_open_ai_fine_tuning_" + timestr + ".jsonl")

spider_open_ai_fine_tuning_20230126-235844.jsonl


In [12]:
#!openai tools fine_tunes.prepare_data -f spider_open_ai_fine_tuning_20230126-235844.jsonl

In [None]:
# !openai api fine_tunes.create -t "spider_open_ai_fine_tuning_20230126-235844.jsonl" -m davinci
# [2023-01-27 00:01:35] Created fine-tune: ft-xv8wuk75QTAbYGVPgAC2vHYh
# [2023-01-27 00:05:10] Fine-tune costs $376.19
# [2023-01-27 00:05:17] Fine-tune enqueued. Queue number: 0
# [2023-01-27 00:05:27] Fine-tune started

In [16]:
!openai api fine_tunes.list

{
  "data": [
    {
      "created_at": 1672357089,
      "fine_tuned_model": "davinci:ft-mercator-2022-12-29-23-47-24",
      "hyperparams": {
        "batch_size": 1,
        "learning_rate_multiplier": 0.1,
        "n_epochs": 4,
        "prompt_loss_weight": 0.01
      },
      "id": "ft-ZMShnMrnhxzay9r0mXXEcxyI",
      "model": "davinci",
      "object": "fine-tune",
      "organization_id": "org-ePmgB4qVo14GgUKdUQci6IGz",
      "result_files": [
        {
          "bytes": 7017,
          "created_at": 1672357645,
          "filename": "compiled_results.csv",
          "id": "file-FdDaGVVXGWC08jN4NwodgYnY",
          "object": "file",
          "purpose": "fine-tune-results",
          "status": "processed",
          "status_details": null
        }
      ],
      "status": "succeeded",
      "training_files": [
        {
          "bytes": 3380,
          "created_at": 1672357089,
          "filename": "openai_classification_fine_tuning.txt",
 

# Out of Sample Testing

In [69]:
tmp = df_test.copy()

# tmp['model_response'] = ''
data = []
for idx, row in tmp.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    new_response = call_model(row, engine="davinci:ft-mercator-2023-01-27-10-13-07")
    data.append(new_response)
    tmp.loc[idx,'model_response'] = new_response

301
already completed
302
already completed
303
already completed
304
already completed
305
already completed
306
already completed
307
already completed
308
already completed
309
already completed
310
already completed
311
already completed
312
already completed
313
already completed
314
already completed
315
already completed
316
already completed
317
already completed
318
already completed
319
already completed
320
already completed
321
already completed
322
already completed
323
already completed
324
already completed
325
already completed
326
already completed
327
already completed
328
already completed
329
already completed
330
already completed
331
already completed
332
already completed
333
already completed
334
already completed
335
already completed
336
already completed
337
already completed
338
already completed
339
already completed
340
already completed
341
already completed
342
already completed
891
already completed
892
already completed
893
already completed
894
alread

 SELECT T1.zip_postcode FROM addresses AS T1 JOIN customers AS T2 ON T1.address_id  =  T2.customer_address_id WHERE T2.first_name  =  "Carole" AND T2.last_name  =  "Bernhard"
6687
 SELECT t1.zip_postcode FROM addresses AS t1 JOIN customers AS t2 ON t1.address_id  =  t2.customer_address_id WHERE t2.first_name  =  "Carole" AND t2.last_name  =  "Bernhard"
6688
 SELECT t3.city FROM customers AS t1 JOIN customer_addresses AS t2 ON t1.customer_address_id  =  t2.address_id JOIN addresses AS t3 ON t2.address_id  =  t3.address_id GROUP BY t3.city ORDER BY count(*) DESC LIMIT 1
6689
 SELECT t3.city FROM customers AS t1 JOIN addresses AS t2 ON t1.customer_address_id  =  t2.address_id JOIN addresses AS t3 ON t2.address_id  =  t3.address_id GROUP BY t3.city ORDER BY count(*) DESC LIMIT 1
6690
 SELECT sum(t1.amount_payment) FROM customer_payments AS t1 JOIN customers AS t2 ON t1.customer_id  =  t2.customer_id WHERE t2.first_name  =  "Carole" AND t2.last_name  =  "Bernhard"
6691
 SELECT sum(t1.amount

 SELECT country ,  count(*) FROM airports GROUP BY country ORDER BY count(*) DESC
6849
 SELECT city ,  count(*) FROM airports WHERE country  =  "United States" GROUP BY city ORDER BY count(*) DESC
6850
 SELECT count(*) ,  city FROM airports WHERE country  =  "United States" GROUP BY city ORDER BY count(*) DESC
6851
 SELECT city FROM airports WHERE country  =  "United States" GROUP BY city HAVING count(*)  >  3
6852
 SELECT count(*) FROM airports WHERE country  =  "United States" AND count(*)  >  3
6853
 SELECT count(*) FROM airports WHERE name LIKE "%%" GROUP BY name HAVING count(*)  >  3
6854
 SELECT count(*) FROM airports WHERE city LIKE "%%" GROUP BY city
6855
 SELECT count(*) ,  name FROM airports GROUP BY name HAVING count(*)  >  1
6856
 SELECT count(*) ,  name FROM airports GROUP BY name HAVING count(*)  >  1
6857
 SELECT city FROM airports GROUP BY city HAVING count(*)  >  2 ORDER BY count(*) DESC
6858
 SELECT city FROM airports GROUP BY city HAVING count(*)  >  2 ORDER BY count

In [71]:
df_test['model_response'] = tmp['model_response']

In [262]:
df_test.to_csv('df_test_davinci_ft-mercator-2023-01-27-10-13-07.csv')

In [72]:
print(tmp.loc[5090]['open_ai_prompt'])

Schema: CREATE TABLE Ref_Payment_Methods (payment_method_code CHAR(10) ,payment_method_description VARCHAR(80),PRIMARY KEY (payment_method_code),UNIQUE (payment_method_code));CREATE TABLE Ref_Service_Types (Service_Type_Code CHAR(15) ,Parent_Service_Type_Code CHAR(15),Service_Type_Description VARCHAR(255),PRIMARY KEY (Service_Type_Code),UNIQUE (Service_Type_Code));CREATE TABLE Addresses (Address_ID VARCHAR(100) ,Line_1 VARCHAR(255),Line_2 VARCHAR(255),City_Town VARCHAR(255),State_County VARCHAR(255),Other_Details VARCHAR(255),PRIMARY KEY (Address_ID),UNIQUE (Address_ID));CREATE TABLE Products (Product_ID VARCHAR(100) ,Product_Name VARCHAR(255),Product_Price DECIMAL(20,4),Product_Description VARCHAR(255),Other_Product_Service_Details VARCHAR(255),PRIMARY KEY (Product_ID),UNIQUE (Product_ID));CREATE TABLE Marketing_Regions (Marketing_Region_Code CHAR(15) ,Marketing_Region_Name VARCHAR(255) ,Marketing_Region_Descriptrion VARCHAR(255) ,Other_Details VARCHAR(255),PRIMARY KEY (Marketing_Regi

In [74]:
df_test['model_response'] = df_test['model_response'].fillna('')

In [242]:
# ignore: 
# - whitespace
# - capitalization
# - trailing semicolons
# - the difference between single and double quotes
# - difference between 'distinct(...)' and 'distinct ...'
# - explictly stating ASC (the ORDER BY default)
tmp = df_test[df_test['model_response'] != ''].copy()

tmp['correct'] = np.where(tmp['open_ai_completion'].str
                                                                         .strip()
                                                                         .str
                                                                         .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                                .rstrip(';')) == 
                                     tmp['model_response'].str
                                                                     .strip()
                                                                     .str
                                                                     .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                            .rstrip(';')), 1, 0)
tmp['correct'].mean()

0.4394250513347023

In [253]:
# 0 leaves off "distinct"
# 10 adds a bad GROUP BY
# 20 - provided answer is bad
# 30 completely wrong
# 40 wrong table
# 50 - right, just reverses join order
# 60 leaves off a join
# 70 leaves off a join
# 80 leaves off a column, adds two joins it doesn't need
# 90 - right, just reverses column and join order
i = 90

print(tmp[tmp['correct'] == 0]['open_ai_prompt'].iloc[i])
print(tmp[tmp['correct'] == 0]['open_ai_completion'].iloc[i])
print(tmp[tmp['correct'] == 0]['model_response'].iloc[i])

Schema: CREATE TABLE "region" ("Region_ID" int,"Region_name" text,"Date" text,"Label" text,"Format" text,"Catalogue" text,PRIMARY KEY ("Region_ID"));CREATE TABLE "party" ("Party_ID" int,"Minister" text,"Took_office" text,"Left_office" text,"Region_ID" int,"Party_name" text,PRIMARY KEY ("Party_ID"),FOREIGN KEY (`Region_ID`);CREATE TABLE "member" ("Member_ID" int,"Member_Name" text,"Party_ID" text,"In_office" text,PRIMARY KEY ("Member_ID"),FOREIGN KEY (`Party_ID`);CREATE TABLE "party_events" ("Event_ID" int,"Event_Name" text,"Party_ID" int,"Member_in_charge_ID" int,PRIMARY KEY ("Event_ID"),FOREIGN KEY (`Party_ID`);
Question: How many members are in each party?

###


 SELECT T2.party_name ,  count(*) FROM Member AS T1 JOIN party AS T2 ON T1.party_id  =  T2.party_id GROUP BY T1.party_id

 SELECT count(*) ,  T2.Party_name FROM party AS T1 JOIN member AS T2 ON T1.Party_ID  =  T2.Party_ID GROUP BY T2.Party_name


## How does the multi-table query model perform on single table queries?

In [223]:
# ignore: 
# - whitespace
# - capitalization
# - trailing semicolons
# - the difference between single and double quotes
# - coerce 'distinct(...)' to 'distinct ...'
# - explictly stating ASC, the ORDER BY default
df_test_single = df_test[~df_test['open_ai_completion'].str.contains('JOIN')].copy()
df_test_single = df_test_single[df_test_single['open_ai_completion'].str.count('FROM') <= 1]
df_test_single = df_test_single[df_test_single['model_response'] != '']

df_test_single['correct'] = np.where(df_test_single['open_ai_completion'].str
                                                                         .strip()
                                                                         .str
                                                                         .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                                .rstrip(';')) == 
                                     df_test_single['model_response'].str
                                                                     .strip()
                                                                     .str
                                                                     .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                            .rstrip(';')), 1, 0)
df_test_single['correct'].mean()

0.6713709677419355

In [241]:
# 0 leaves off "distinct"
# 10 confuses which column goes in SELECT vs WHERE clause
# 20 wrong table
# 30 wrong table
# 40 sorts by "year" instead of "date" to find "most recent"
# 50 adds a join it doesn't need and sorts by the wrong column
# 60 adds a join it doesn't need
# 70 wrong table, but I'd argue the model gets it right and the provided answer is wrong
# 80 leave off a column, but I'd argue the model gets it right and the provided answer is wrong
# 90 is right, just reverses the column order
i = 90

print(df_test_single[df_test_single['correct'] == 0]['open_ai_prompt'].iloc[i])
print(df_test_single[df_test_single['correct'] == 0]['open_ai_completion'].iloc[i])
print(df_test_single[df_test_single['correct'] == 0]['model_response'].iloc[i])

Schema: create table genre(	g_name varchar2(20) not null,	rating varchar2(10),	most_popular_in varchar2(50),	primary key(g_name));create table artist(	artist_name varchar2(50) not null,	country varchar2(20),	gender varchar2(20),	preferred_genre varchar2(50),	constraint a_name primary key(artist_name),	foreign key(preferred_genre);create table files(	f_id number(10) not null,	artist_name varchar2(50),	file_size varchar2(20),	duration varchar2(20),	formats varchar2(20),	primary key(f_id),	foreign key(artist_name);create table song(	song_name varchar2(50),	artist_name varchar2(50),	country varchar2(20),	f_id number(10),    	genre_is varchar2(20),	rating number(10) check(rating>0 and rating<11),	languages varchar2(20),	releasedate Date, 	resolution number(10) not null,	constraint s_name primary key(song_name),	foreign key(artist_name);
Question: List the names of all genres in alphabetical oder, together with its ratings.

###


 SELECT g_name ,  rating FROM genre ORDER BY g_name

 SELECT 

In [221]:
df_test_single[df_test_single['correct'] == 0].iloc[50]

open_ai_prompt        Schema: CREATE TABLE "circuits" ("circuitId" I...
open_ai_completion     SELECT DISTINCT forename FROM drivers ORDER B...
model_response         SELECT DISTINCT forename FROM drivers ORDER B...
correct                                                               0
Name: 2199, dtype: object

In [219]:
df_test_single[df_test_single['correct'] == 0].iloc[50]['open_ai_completion']

' SELECT DISTINCT forename FROM drivers ORDER BY forename ASC\n'

In [220]:
df_test_single[df_test_single['correct'] == 0].iloc[50]['model_response']

' SELECT DISTINCT forename FROM drivers ORDER BY forename'

In [180]:
tmp = df_test_single[(df_test_single['open_ai_completion'].str.contains("ASC")) & (df_test_single['correct'] == 0)]
# tmp = df_test_single[(df_test_single['open_ai_completion'].str.contains("DISTINCT")) & (df_test_single['correct'] == 0)]
tmp
i = 4
print(tmp.iloc[i]['open_ai_completion'])
print(tmp.iloc[i]['model_response'])

 SELECT DISTINCT forename FROM drivers ORDER BY forename ASC

 SELECT DISTINCT forename FROM drivers ORDER BY forename


In [131]:
df_test_single[df_test_single['model_response'] != ''].shape

(496, 4)

# How does raw da vinci compare?

In [260]:
tmp = df_test.copy()

tmp['model_response'] = ''
data = []
for idx, row in tmp.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    new_response = call_model(row, engine="davinci")
    data.append(new_response)
    tmp.loc[idx,'model_response'] = new_response

301


AuthenticationError: Incorrect API key provided: sk-Xwp8P***************************************98o5. You can find your API key at https://beta.openai.com.