In [1]:
import time
import json
import openai
import re
import sqlite3
import pandas as pd
import numpy as np

from os import path

In [3]:
def _create_statements(statements):
    """Keep only the CREATE statements and join them with ';'.

    Each statement is stripped; those starting with 'create' (case-insensitive)
    are kept, with 'IF NOT EXISTS ' and newline characters removed.
    """
    kept = []
    for stmt in statements:
        stmt = stmt.strip()
        if stmt.lower().startswith('create'):
            kept.append(stmt.replace('IF NOT EXISTS ', '').replace('\n', ''))
    return ';'.join(kept)


def get_schema(db_id, db_root='spider/database/'):
    """Return the CREATE statements of database ``db_id`` as one ';'-joined string.

    Looks for ``<db_root>/<db_id>/schema.sql`` first and falls back to
    ``<db_root>/<db_id>/<db_id>.sqlite``.

    Parameters
    ----------
    db_id : str
        Name of the database directory (and of the sqlite file inside it).
    db_root : str, optional
        Directory containing one sub-directory per database.

    Returns
    -------
    str or None
        The joined CREATE statements, or None when neither file exists.
    """
    schema_file = path.join(db_root, db_id, 'schema.sql')
    sqlite_file = path.join(db_root, db_id, db_id + '.sqlite')

    if path.exists(schema_file):
        # Context manager so the file handle is closed deterministically.
        with open(schema_file, 'r') as fh:
            text = fh.read()

        # Remove comment lines: lines starting with '/*', lines starting with
        # '--', and a '/* ... */' span at the very start of the text.
        text = re.sub(r"^/\*.*\n", "", text, flags=re.MULTILINE)
        text = re.sub(r'^--.*(\n|$)', '', text, flags=re.MULTILINE)
        text = re.sub(r"^/\*.*\*/", "", text)

        # Normalise a tab that directly follows 'CREATE TABLE '.
        text = re.sub(r'CREATE TABLE \t', 'CREATE TABLE ', text)

        return _create_statements(text.split(';'))

    if path.exists(sqlite_file):
        conn = sqlite3.connect(sqlite_file)
        try:
            rows = conn.execute(
                "SELECT sql FROM sqlite_master WHERE type='table';"
            ).fetchall()
        finally:
            # Close the connection even if the query fails.
            conn.close()
        # Skip rows without SQL text so .strip() is never called on None.
        return _create_statements(sql for (sql,) in rows if sql is not None)

    return None

def short_open_ai_prompt(x):
    """Build the fine-tuning prompt from the 'schema' and 'question' entries of ``x``.

    The prompt ends with the '\\n\\n###\\n\\n' separator.
    """
    template = "Schema: {schema}\nQuestion: {question}\n\n###\n\n"
    return template.format(schema=x['schema'], question=x['question'])

# OpenAI guidance: a completion should begin with a whitespace character (` `),
# which tends to give better results because of the tokenization in use.
def open_ai_completion(x):
    """Return the 'query' entry of ``x`` with a leading space and a trailing newline."""
    return " {}\n".format(x['query'])

def call_model(row, engine):
    """Request one completion for ``row['open_ai_prompt']`` from ``engine``.

    Retries after a one-second pause for as long as the service reports
    ServiceUnavailableError.

    Parameters
    ----------
    row : mapping
        Must provide the key 'open_ai_prompt'.
    engine : str
        Engine / fine-tuned model identifier passed to the Completion API.

    Returns
    -------
    str or None
        The text of the first choice, or None when the request is rejected
        with InvalidRequestError.
    """
    prompt = row["open_ai_prompt"]
    while True:
        try:
            completions = openai.Completion.create(
                engine=engine,
                prompt=prompt,
                max_tokens=1024,
                n=1,
                stop=["\n"],
                temperature=0.5
            )
        # The exception classes live in `openai.error`; a bare `error` name is
        # not defined in this notebook and would raise NameError here.
        except openai.error.ServiceUnavailableError:
            print('ServiceUnavailableError')
            time.sleep(1)
            continue
        except openai.error.InvalidRequestError:
            print('InvalidRequestError: too many tokens')
            return None

        text = completions.choices[0].text
        print(text)
        return text

# Data Preparation

In [4]:
def _extract_table_ddl(schema, table):
    """Return the lower-cased 'create table' statement(s) for ``table``.

    The table name is tried back-quoted, unquoted and double-quoted, in that
    order; for each form a match ending at ';' is preferred over one running
    to the end of the string. The first form that matches wins and all of its
    matches are concatenated. Returns '' when nothing matches.
    """
    schema_lc = schema.lower()
    # Escape the name so it is always matched literally.
    name = re.escape(table)
    # NOTE(review): the unquoted form also matches tables whose name merely
    # starts with `table` -- confirm whether that is intended.
    for quoted in ('`{}`', '{}', '"{}"'):
        for tail in (';', '$'):
            pattern = r'(create table ' + quoted.format(name) + r'.*?' + tail + r')'
            hits = re.findall(pattern, schema_lc)
            if hits:
                return ''.join(hits)
    return ''


df = pd.read_json('spider/train_spider.json')
# Keep only queries without JOIN and with at most one FROM.
df = df[~df['query'].str.contains('JOIN')]
# .copy() so the column assignments below act on an independent frame.
df = df[df['query'].str.count('FROM') <= 1].copy()

# Set schema: read each database once instead of once per row. A database
# with no schema file maps to '' so the empty-schema check can report it
# instead of re.sub crashing on None.
schema_by_db = {db_id: get_schema(db_id) or '' for db_id in df['db_id'].unique()}
df['schema'] = df['db_id'].map(schema_by_db)
# Cut everything from ' REFERENCES' up to the next ';' (or the end of the string).
df['schema'] = df['schema'].apply(lambda x: re.sub(r'(?i) REFERENCES.*?(;|$)', ';', x))
# Extract the name of the table after FROM
df['table'] = (
    df['query']
    .apply(lambda x: re.sub(r'.*FROM', 'FROM', x))
    .apply(lambda x: re.sub(r'^FROM\s+(\S+).*', r'\1', x).rstrip(';'))
    .str.lower()
)
# Extract the 'create table' statement just for that table
df['schema'] = df.apply(lambda x: _extract_table_ddl(x['schema'], x['table']), axis=1)
# NOTE(review): the schema was lower-cased by the extraction step above, so
# this case-sensitive pattern looks like it never matches. Left unchanged so
# the generated prompts stay identical -- confirm whether 'not null' should
# be stripped.
df['schema'] = df['schema'].apply(lambda x: re.sub('NOT NULL', '', x))

# Set Open AI prompt and completion
df['open_ai_prompt'] = df.apply(short_open_ai_prompt, axis=1)
df['open_ai_completion'] = df.apply(open_ai_completion, axis=1)

# Randomize at the db level, so every question of a database lands in the
# same split.
df_db_id = pd.DataFrame(df['db_id'].unique(), columns=['db_id'])
np.random.seed(240956) #set seed
df_db_id['train_test'] = np.random.choice(['train','test'], df_db_id.shape[0], p=[0.8, 0.2])
df = df.merge(df_db_id, on='db_id')
df_train = df[df['train_test'] == 'train'][['open_ai_prompt', 'open_ai_completion']].copy()
df_test  = df[df['train_test'] == 'test'][['open_ai_prompt', 'open_ai_completion']].copy()

In [5]:
# Databases for which no 'create table' statement could be extracted.
df.loc[df['schema'] == '', 'db_id'].unique()

array([], dtype=object)

In [6]:
# Show one prompt / completion pair as it will be sent for fine-tuning.
example = df.iloc[110]
print(example['open_ai_prompt'])
print(example['open_ai_completion'])

Schema: create table weather (    date text,    max_temperature_f integer,    mean_temperature_f integer,    min_temperature_f integer,    max_dew_point_f integer,    mean_dew_point_f integer,    min_dew_point_f integer,    max_humidity integer,    mean_humidity integer,    min_humidity integer,    max_sea_level_pressure_inches numeric,    mean_sea_level_pressure_inches numeric,    min_sea_level_pressure_inches numeric,    max_visibility_miles integer,    mean_visibility_miles integer,    min_visibility_miles integer,    max_wind_speed_mph integer,    mean_wind_speed_mph integer,    max_gust_speed_mph integer,    precipitation_inches integer,    cloud_cover integer,    events text,    wind_dir_degrees integer,    zip_code integer)
Question: What are the dates in which the mean sea level pressure was between 30.3 and 31?

###


 SELECT date FROM weather WHERE mean_sea_level_pressure_inches BETWEEN 30.3 AND 31



In [7]:
# (rows, columns) of the training split.
df_train.shape

(3071, 2)

In [8]:
# (rows, columns) of the test split.
df_test.shape

(538, 2)

# Training

In [None]:
# Serialise the training split as JSON Lines: one {"prompt", "completion"}
# object per line, in a time-stamped file.
data = [
    {"prompt": prompt, "completion": completion}
    for prompt, completion in zip(df_train["open_ai_prompt"], df_train["open_ai_completion"])
]

timestr = time.strftime("%Y%m%d-%H%M%S")
jsonl_name = "spider_open_ai_fine_tuning_" + timestr + ".jsonl"
with open(jsonl_name, "w") as outfile:
    outfile.writelines(json.dumps(obj) + "\n" for obj in data)

print(jsonl_name)

In [None]:
#!openai tools fine_tunes.prepare_data -f spider_open_ai_fine_tuning_20230125-135407.jsonl
# - There are 3 duplicated prompt-completion sets. These are rows: [1155, 1590, 1591]

In [None]:
# Positions (in df_train order) reported as duplicated prompt-completion rows
# by `openai tools fine_tunes.prepare_data`.
DUPLICATE_POSITIONS = [1155, 1590, 1591]

# Capture the prompts BEFORE dropping: once the rows are removed, the same
# positions in the deduplicated frame point at different, shifted rows.
dropped_prompts = df_train.iloc[DUPLICATE_POSITIONS]['open_ai_prompt']

df_train_dedup = df_train.drop(df_train.index[DUPLICATE_POSITIONS])

# Rows still present that share a prompt with one of the dropped rows.
df_train_dedup[df_train_dedup['open_ai_prompt'].isin(dropped_prompts)]

In [None]:
# Serialise the deduplicated training split as JSON Lines: one
# {"prompt", "completion"} object per line, in a time-stamped file.
data = [
    {"prompt": prompt, "completion": completion}
    for prompt, completion in zip(df_train_dedup["open_ai_prompt"],
                                  df_train_dedup["open_ai_completion"])
]

timestr = time.strftime("%Y%m%d-%H%M%S")
jsonl_name = "spider_open_ai_fine_tuning_" + timestr + ".jsonl"
with open(jsonl_name, "w") as outfile:
    outfile.writelines(json.dumps(obj) + "\n" for obj in data)

print(jsonl_name)

In [None]:
#!openai tools fine_tunes.prepare_data -f spider_open_ai_fine_tuning_20230125-135903.jsonl

In [None]:
#!openai api fine_tunes.create -t "spider_open_ai_fine_tuning_20230125-135903.jsonl" -m davinci

In [9]:
# Shell escape: run the OpenAI CLI `fine_tunes.list` command; its JSON output is shown below.
!openai api fine_tunes.list

{
  "data": [
    {
      "created_at": 1672357089,
      "fine_tuned_model": "davinci:ft-mercator-2022-12-29-23-47-24",
      "hyperparams": {
        "batch_size": 1,
        "learning_rate_multiplier": 0.1,
        "n_epochs": 4,
        "prompt_loss_weight": 0.01
      },
      "id": "ft-ZMShnMrnhxzay9r0mXXEcxyI",
      "model": "davinci",
      "object": "fine-tune",
      "organization_id": "org-ePmgB4qVo14GgUKdUQci6IGz",
      "result_files": [
        {
          "bytes": 7017,
          "created_at": 1672357645,
          "filename": "compiled_results.csv",
          "id": "file-FdDaGVVXGWC08jN4NwodgYnY",
          "object": "file",
          "purpose": "fine-tune-results",
          "status": "processed",
          "status_details": null
        }
      ],
      "status": "succeeded",
      "training_files": [
        {
          "bytes": 3380,
          "created_at": 1672357089,
          "filename": "openai_classification_fine_tuning.txt",
 

## In-Sample Testing

## Out-of-Sample Testing

In [12]:
# Fine-tuned model used for the out-of-sample run.
ENGINE = "davinci:ft-mercator-2023-01-25-23-01-53"

tmp = df_test.copy()

# The response column starts empty on every run of this cell, so no row can
# already hold a response: every row is sent to the model.
tmp['model_response'] = ''
data = []
for idx, row in tmp.iterrows():
    print(idx)

    # call_model yields None when the request is rejected; store '' so the
    # column stays purely textual for the string-based scoring further down.
    new_response = call_model(row, engine=ENGINE) or ''
    data.append(new_response)
    tmp.loc[idx, 'model_response'] = new_response

173
 SELECT DISTINCT catalog_entry_name FROM CATALOG_CONTENTS
174
 SELECT DISTINCT catalog_entry_name FROM Catalog_Contents
175
 SELECT attribute_data_type FROM attribute_definitions GROUP BY attribute_data_type HAVING count(*)  >  3
176
 SELECT attribute_data_type FROM attribute_definitions GROUP BY attribute_data_type HAVING count(*)  >  3
177
 SELECT attribute_data_type FROM Attribute_Definitions WHERE attribute_name  =  "Green"
178
 SELECT attribute_data_type FROM Attribute_Definitions WHERE attribute_name  =  "Green"
179
 SELECT catalog_level_name ,  catalog_level_number FROM catalog_structure WHERE catalog_level_number BETWEEN 5 AND 10
180
 SELECT catalog_level_name ,  catalog_level_number FROM Catalog_structure WHERE catalog_level_number BETWEEN 5 AND 10
181
 SELECT catalog_publisher FROM catalogs WHERE catalog_name LIKE "%Murray%"
182
 SELECT catalog_publisher FROM catalogs WHERE catalog_publisher LIKE "%Murray%"
183
 SELECT catalog_publisher FROM catalogs GROUP BY catalog_publ

 SELECT dependent_name FROM dependent WHERE relationship  =  'spouse' AND employee_id  !=  0
1178
 SELECT count(*) FROM dependent WHERE sex  =  'F'
1179
 SELECT fname ,  lname FROM Employee WHERE salary  >  30000
1180
 SELECT count(*) ,  sex FROM employee WHERE salary  <  50000 GROUP BY sex
1181
 SELECT fname ,  lname ,  address FROM Employee ORDER BY bdate ASC
1189
 SELECT name FROM races ORDER BY date DESC LIMIT 1
1190
 SELECT name FROM races ORDER BY YEAR DESC LIMIT 1
1191
 SELECT name ,  date FROM races ORDER BY date DESC LIMIT 1
1192
 SELECT name ,  date FROM races ORDER BY date DESC LIMIT 1
1193
 SELECT name FROM races WHERE YEAR  =  2017
1194
 SELECT name FROM races WHERE YEAR  =  2017
1195
 SELECT DISTINCT name FROM races WHERE YEAR BETWEEN 2014 AND 2017
1196
 SELECT DISTINCT name FROM races WHERE YEAR BETWEEN 2014 AND 2017
1197
 SELECT forename ,  surname FROM drivers WHERE nationality  =  'German'
1198
 SELECT Forename ,  Surname FROM drivers WHERE Nationality  =  'German'
11

 SELECT city FROM addresses ORDER BY city
1654
 SELECT first_name ,  last_name FROM Teachers ORDER BY last_name
1655
 SELECT * FROM Student_Addresses ORDER BY monthly_rental DESC
1772
 SELECT sum(Num_of_component) FROM furniture
1773
 SELECT Name ,  Furniture_ID FROM furniture ORDER BY Market_Rate DESC LIMIT 1
1774
 SELECT sum(Market_Rate) FROM furniture GROUP BY Name ORDER BY Market_Rate DESC LIMIT 2
1775
 SELECT Name ,  Num_of_component FROM furniture WHERE Num_of_component  >  10
1776
 SELECT Name ,  Num_of_component FROM furniture ORDER BY Num_of_component ASC LIMIT 1
1777
 SELECT Name ,  Open_Year FROM manufacturer ORDER BY Num_of_Shops DESC LIMIT 1
1778
 SELECT avg(num_of_factories) FROM manufacturer WHERE num_of_shops  >  20
1779
 SELECT Name ,  Manufacturer_ID FROM manufacturer ORDER BY Open_Year
1780
 SELECT Name ,  Open_Year FROM manufacturer WHERE Num_of_Factories  <  10 OR Num_of_Shops  >  10
1781
 SELECT avg(num_of_factories) ,  max(num_of_shops) FROM manufacturer WHERE op

 SELECT avg(order_quantity) FROM invoices WHERE payment_method_code  =  "MasterCard"
2662
 SELECT product_id FROM invoices GROUP BY product_id ORDER BY count(*) DESC LIMIT 1
2663
 SELECT product_id FROM invoices GROUP BY product_id ORDER BY count(*) DESC LIMIT 1
2760
 SELECT name ,  address_road ,  city FROM branch ORDER BY open_year
2761
 SELECT name ,  address_road ,  city FROM branch ORDER BY open_year
2762
 SELECT name FROM branch ORDER BY Membership_amount DESC LIMIT 3
2763
 SELECT name FROM branch ORDER BY Membership_amount DESC LIMIT 3
2764
 SELECT DISTINCT city FROM branch WHERE membership_amount  >=  100
2765
 SELECT DISTINCT city FROM branch WHERE membership_amount  >  100
2766
 SELECT open_year FROM branch GROUP BY open_year HAVING count(*)  >=  2
2767
 SELECT Open_Year FROM branch GROUP BY Open_Year HAVING COUNT(*)  >=  2
2768
 SELECT min(Membership_amount) ,  max(Membership_amount) FROM branch WHERE open_year  =  2011 OR city  =  'London'
2769
 SELECT min(Membership_amount

 SELECT first_name ,  last_name FROM customers
3446
 SELECT first_name ,  last_name FROM customers
3447
 SELECT email_address ,  date_of_birth FROM Customers WHERE first_name  =  "Carole"
3448
 SELECT email_address ,  date_of_birth FROM Customers WHERE first_name  =  "Carole"
3449
 SELECT phone_number ,  email_address FROM customers WHERE amount_outstanding  >  2000
3450
 SELECT phone_number ,  email_address FROM customers WHERE amount_outstanding  >  2000
3451
 SELECT customer_status_code ,  cell_mobile_phone_number ,  email_address FROM Customers WHERE last_name  =  "Kohler" OR first_name  =  "Marina"
3452
 SELECT customer_status_code ,  phone_number ,  email_address FROM Customers WHERE first_name  =  "Marina" OR last_name  =  "Kohler"
3453
 SELECT date_of_birth FROM customers WHERE customer_status_code  =  'Good Customer'
3454
 SELECT date_of_birth FROM customers WHERE customer_status_code  =  'Good Customer'
3455
 SELECT date_became_customer FROM customers WHERE first_name  =  "Ca

In [14]:
# Persist prompts, reference completions and model responses to CSV.
tmp.to_csv('davinci-ft-mercator-2023-01-25-23-01-53-completions.csv')

In [15]:
# Preview the first rows of the evaluation frame.
tmp.head()

Unnamed: 0,open_ai_prompt,open_ai_completion,model_response
173,Schema: create table `catalog_contents` (`cata...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM CATAL...
174,Schema: create table `catalog_contents` (`cata...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM Catal...
175,Schema: create table `attribute_definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM attribute_def...
176,Schema: create table `attribute_definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM attribute_def...
177,Schema: create table `attribute_definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM Attribute_Def...


In [16]:
# (rows, columns) of the evaluation frame.
tmp.shape

(538, 3)

In [17]:
def _normalize_sql(sql):
    """Lower-case ``sql``, remove all whitespace and strip trailing semicolons."""
    return re.sub(r'\s+', '', sql.strip().lower()).rstrip(';')


# ignore white space, case, and trailing semicolons
# Missing values are treated as '' so a row without a model response is
# scored as incorrect instead of crashing the regex.
expected_sql = tmp['open_ai_completion'].fillna('').map(_normalize_sql)
predicted_sql = tmp['model_response'].fillna('').map(_normalize_sql)

tmp['correct'] = np.where(expected_sql == predicted_sql, 1, 0)
# Exact-match accuracy over the test split.
tmp['correct'].mean()

0.7397769516728625