In [1]:
import time
import json
import openai
import re
import sqlite3
import pandas as pd
import numpy as np

from os import path
from openai.api_requestor import error

In [2]:
def get_schema(db_id):
    PATH = 'spider/database/'
    
    path_to_file   = PATH + db_id + '/schema.sql'
    path_to_file_2 = PATH + db_id + '/' + db_id + '.sqlite'
    
    if path.exists(path_to_file):
        x = open(path_to_file, 'r').read()
        # Remove comment lines
        x = re.sub(r"^/\*.*\n", "", x, flags=re.MULTILINE)
        x = re.sub(r'^--.*(\n|$)', '', x, flags=re.MULTILINE)
        x = re.sub(r"^/\*.*\*/", "", x)
        
        x = re.sub(r'CREATE TABLE \t', 'CREATE TABLE ', x)
        
        arr = x.split(';')
        create_tables = []
        for i, stmt in enumerate(arr):
            stmt = stmt.strip()
            if stmt.lower().startswith('create'):
                create_tables.append(stmt.replace('IF NOT EXISTS ', '').replace('\n', ''))
        return ';'.join(create_tables)
      
    elif path.exists(path_to_file_2):
        # Connect to the sqlite file
        conn = sqlite3.connect(path_to_file_2)

        # Read the schema of the twitter_1.sqlite file into a dataframe
        df_conn = pd.read_sql_query("SELECT sql FROM sqlite_master WHERE type='table';", conn)
        
        # Close the connection
        conn.close()
        
        create_tables = []
        for idx, row in df_conn.iterrows():
            row['sql'] = row['sql'].strip()
            if row['sql'].lower().startswith('create'):
                create_tables.append(row['sql'].replace('IF NOT EXISTS ', '').replace('\n', ''))
        return ';'.join(create_tables)
    
    else: return None

def short_open_ai_prompt(x):
    return f'''Schema: {x['schema']}\nQuestion: {x['question']}\n\n###\n\n'''

# From OpenAI: The completion should start with a whitespace character (` `). 
# This tends to produce better results due to the tokenization we use.
def open_ai_completion(x):
    return f" {x['query']}\n" 

def call_model(row, engine):
    prompt = row["open_ai_prompt"]
    while True:
        try:
            completions = openai.Completion.create(
                engine=engine,
                prompt=prompt,
                max_tokens=1024,
                n=1,
                stop=["\n"],
                temperature=0.5
            )

            print(completions.choices[0].text)
            return completions.choices[0].text
    
        except error.ServiceUnavailableError:
            print('ServiceUnavailableError')
            time.sleep(1)
            continue
            
        except error.InvalidRequestError:
            print('InvalidRequestError: too many tokens')
            time.sleep(1)
            break

# Data Preparation

In [4]:
df = pd.read_json('spider/train_spider.json')
# df = df[~df['query'].str.contains('JOIN')]
# df = df[df['query'].str.count('FROM') <= 1]

# Set schema
df['schema'] = df.apply(lambda x: get_schema(x['db_id']), axis=1)
# Remove REFERENCES sections and NOT NULLs
df['schema'] = df['schema'].apply(lambda x: re.sub('(?i) REFERENCES.*?(;|$)', ';', x)) 
df['schema'] = df['schema'].apply(lambda x: re.sub('NOT NULL', '', x))

# Set Open AI prompt and completion
df['open_ai_prompt'] = df.apply(lambda x: short_open_ai_prompt(x), axis=1)
df['open_ai_completion'] = df.apply(lambda x: open_ai_completion(x), axis=1)

# Randomize at the db level
df_db_id = pd.DataFrame(df['db_id'].unique(), columns=['db_id'])
np.random.seed(240956) #set seed
df_db_id['train_test'] = np.random.choice(['train','test'], df_db_id.shape[0], p=[0.8, 0.2])
df = df.merge(df_db_id, on='db_id')
df_train = df[df['train_test'] == 'train'][['open_ai_prompt', 'open_ai_completion']].copy()
df_test  = df[df['train_test'] == 'test'][['open_ai_prompt', 'open_ai_completion']].copy()

In [5]:
df[df['schema'] == '']['db_id'].unique()

array([], dtype=object)

In [6]:
print(df.iloc[110]['open_ai_prompt'])
print(df.iloc[110]['open_ai_completion'])

Schema: CREATE TABLE station (    id INTEGER PRIMARY KEY,    name TEXT,    lat NUMERIC,    long NUMERIC,    dock_count INTEGER,    city TEXT,    installation_date TEXT);CREATE TABLE status (    station_id INTEGER,    bikes_available INTEGER,    docks_available INTEGER,    time TEXT,    FOREIGN KEY (station_id);CREATE TABLE trip (    id INTEGER PRIMARY KEY,    duration INTEGER,    start_date TEXT,    start_station_name TEXT, -- this should be removed    start_station_id INTEGER,    end_date TEXT,    end_station_name TEXT, -- this should be removed    end_station_id INTEGER,    bike_id INTEGER,    subscription_type TEXT,    zip_code INTEGER);CREATE TABLE weather (    date TEXT,    max_temperature_f INTEGER,    mean_temperature_f INTEGER,    min_temperature_f INTEGER,    max_dew_point_f INTEGER,    mean_dew_point_f INTEGER,    min_dew_point_f INTEGER,    max_humidity INTEGER,    mean_humidity INTEGER,    min_humidity INTEGER,    max_sea_level_pressure_inches NUMERIC,    mean_sea_level_pre

In [7]:
df_train.shape

(5944, 2)

In [8]:
df_test.shape

(1056, 2)

# Training

In [9]:
# # Put the training data into jsonl format
# data = []
# for idx, row in df_train.iterrows():
#     data.append({"prompt": row["open_ai_prompt"], "completion": row["open_ai_completion"]})

# timestr = time.strftime("%Y%m%d-%H%M%S")
# with open("spider_open_ai_fine_tuning_" + timestr + ".jsonl", "w") as outfile:
#     for obj in data:
#         json.dump(obj, outfile)
#         outfile.write("\n")
        
# print("spider_open_ai_fine_tuning_" + timestr + ".jsonl")

spider_open_ai_fine_tuning_20230126-235556.jsonl

In [10]:
#!openai tools fine_tunes.prepare_data -f spider_open_ai_fine_tuning_20230126-235556.jsonl
# - There are 6 duplicated prompt-completion sets. These are rows: [1296, 2000, 2097, 2984, 2985, 3799]

In [11]:
df_train_dedup = df_train.copy()
df_train_dedup = df_train_dedup.drop(df_train.index[[1296, 2000, 2097, 2984, 2985, 3799]])
df_train_dedup[(df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[1296]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[2000]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[2097]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[2984]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[2985]['open_ai_prompt']) |
               (df_train_dedup['open_ai_prompt'] == df_train_dedup.iloc[3799]['open_ai_prompt'])]

Unnamed: 0,open_ai_prompt,open_ai_completion
1421,Schema: CREATE TABLE classroom\t(building\t\tv...,SELECT title FROM course WHERE course_id NOT ...
2344,"Schema: CREATE TABLE ""Campuses"" (\t""Id"" INTEGE...",SELECT campus FROM campuses WHERE county = ...
2442,Schema: create table Movie(\tmID int primary k...,SELECT count(*) FROM Reviewer\n
3489,Schema: CREATE TABLE `regions` ( `REGION_ID` ...,SELECT job_id FROM employees GROUP BY job_id ...
3490,Schema: CREATE TABLE `regions` ( `REGION_ID` ...,SELECT job_id FROM employees GROUP BY job_id ...
4405,Schema: CREATE TABLE Person ( name varchar(20...,SELECT count(DISTINCT city) FROM Person\n


In [12]:
# # Put the training data into jsonl format
# data = []
# for idx, row in df_train_dedup.iterrows():
#     data.append({"prompt": row["open_ai_prompt"], "completion": row["open_ai_completion"]})

# timestr = time.strftime("%Y%m%d-%H%M%S")
# with open("spider_open_ai_fine_tuning_" + timestr + ".jsonl", "w") as outfile:
#     for obj in data:
#         json.dump(obj, outfile)
#         outfile.write("\n")
        
# print("spider_open_ai_fine_tuning_" + timestr + ".jsonl")

spider_open_ai_fine_tuning_20230126-235844.jsonl

In [13]:
#!openai tools fine_tunes.prepare_data -f spider_open_ai_fine_tuning_20230126-235844.jsonl

In [14]:
# !openai api fine_tunes.create -t "spider_open_ai_fine_tuning_20230126-235844.jsonl" -m davinci
# [2023-01-27 00:01:35] Created fine-tune: ft-xv8wuk75QTAbYGVPgAC2vHYh
# [2023-01-27 00:05:10] Fine-tune costs $376.19
# [2023-01-27 00:05:17] Fine-tune enqueued. Queue number: 0
# [2023-01-27 00:05:27] Fine-tune started

In [15]:
# !openai api fine_tunes.list

# Out of Sample Testing

In [16]:
# tmp = df_test.copy()

# # tmp['model_response'] = ''
# data = []
# for idx, row in tmp.iterrows():
#     print(idx)
#     if row['model_response'] != '':
#         print('already completed')
#         data.append(row['model_response'])
#         continue
    
#     new_response = call_model(row, engine="davinci:ft-mercator-2023-01-27-10-13-07")
#     data.append(new_response)
#     tmp.loc[idx,'model_response'] = new_response

In [22]:
df_test = pd.read_csv('df_test_davinci_ft-mercator-2023-01-27-10-13-07.csv',index_col=0).fillna('')
df_test

Unnamed: 0,open_ai_prompt,open_ai_completion,model_response
301,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM catal...
302,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM catal...
303,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM attribute_def...
304,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM attribute_def...
305,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM attribute_def...
...,...,...,...
6880,Schema: CREATE TABLE routes ( rid integer PR...,SELECT T1.name FROM airlines AS T1 JOIN route...,SELECT T1.name FROM airlines AS T1 JOIN route...
6881,Schema: CREATE TABLE routes ( rid integer PR...,SELECT T1.name FROM airports AS T1 JOIN route...,SELECT src_ap FROM routes WHERE rid IN (SELEC...
6882,Schema: CREATE TABLE routes ( rid integer PR...,SELECT T1.name FROM airports AS T1 JOIN route...,SELECT T1.name FROM airports AS T1 JOIN route...
6883,Schema: CREATE TABLE routes ( rid integer PR...,SELECT T1.name FROM airports AS T1 JOIN route...,SELECT T1.dst_ap FROM routes AS T1 JOIN airpo...


In [23]:
# ignore: 
# - whitespace
# - capitalization
# - trailing semicolons
# - the difference between single and double quotes
# - difference between 'distinct(...)' and 'distinct ...'
# - explictly stating ASC (the ORDER BY default)
tmp = df_test[df_test['model_response'] != ''].copy()

tmp['correct'] = np.where(tmp['open_ai_completion'].str
                                                                         .strip()
                                                                         .str
                                                                         .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                                .rstrip(';')) == 
                                     tmp['model_response'].str
                                                                     .strip()
                                                                     .str
                                                                     .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                            .rstrip(';')), 1, 0)
tmp['correct'].mean()

0.4394250513347023

In [24]:
# 0 leaves off "distinct"
# 10 adds a bad GROUP BY
# 20 - provided answer is bad
# 30 completely wrong
# 40 wrong table
# 50 - right, just reverses join order
# 60 leaves off a join
# 70 leaves off a join
# 80 leaves off a column, adds two joins it doesn't need
# 90 - right, just reverses column and join order
i = 90

print(tmp[tmp['correct'] == 0]['open_ai_prompt'].iloc[i])
print(tmp[tmp['correct'] == 0]['open_ai_completion'].iloc[i])
print(tmp[tmp['correct'] == 0]['model_response'].iloc[i])

Schema: CREATE TABLE "region" ("Region_ID" int,"Region_name" text,"Date" text,"Label" text,"Format" text,"Catalogue" text,PRIMARY KEY ("Region_ID"));CREATE TABLE "party" ("Party_ID" int,"Minister" text,"Took_office" text,"Left_office" text,"Region_ID" int,"Party_name" text,PRIMARY KEY ("Party_ID"),FOREIGN KEY (`Region_ID`);CREATE TABLE "member" ("Member_ID" int,"Member_Name" text,"Party_ID" text,"In_office" text,PRIMARY KEY ("Member_ID"),FOREIGN KEY (`Party_ID`);CREATE TABLE "party_events" ("Event_ID" int,"Event_Name" text,"Party_ID" int,"Member_in_charge_ID" int,PRIMARY KEY ("Event_ID"),FOREIGN KEY (`Party_ID`);
Question: How many members are in each party?

###


 SELECT T2.party_name ,  count(*) FROM Member AS T1 JOIN party AS T2 ON T1.party_id  =  T2.party_id GROUP BY T1.party_id

 SELECT count(*) ,  T2.Party_name FROM party AS T1 JOIN member AS T2 ON T1.Party_ID  =  T2.Party_ID GROUP BY T2.Party_name


In [155]:
df_test

Unnamed: 0,open_ai_prompt,open_ai_completion,model_response
301,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM catal...
302,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM catal...
303,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM attribute_def...
304,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM attribute_def...
305,Schema: CREATE TABLE `Attribute_Definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM attribute_def...
...,...,...,...
6880,Schema: CREATE TABLE routes ( rid integer PR...,SELECT T1.name FROM airlines AS T1 JOIN route...,SELECT T1.name FROM airlines AS T1 JOIN route...
6881,Schema: CREATE TABLE routes ( rid integer PR...,SELECT T1.name FROM airports AS T1 JOIN route...,SELECT src_ap FROM routes WHERE rid IN (SELEC...
6882,Schema: CREATE TABLE routes ( rid integer PR...,SELECT T1.name FROM airports AS T1 JOIN route...,SELECT T1.name FROM airports AS T1 JOIN route...
6883,Schema: CREATE TABLE routes ( rid integer PR...,SELECT T1.name FROM airports AS T1 JOIN route...,SELECT T1.dst_ap FROM routes AS T1 JOIN airpo...


## How does the multi-table query model perform on single table queries?

In [26]:
# ignore: 
# - whitespace
# - capitalization
# - trailing semicolons
# - the difference between single and double quotes
# - coerce 'distinct(...)' to 'distinct ...'
# - explictly stating ASC, the ORDER BY default
df_test_single = df_test[~df_test['open_ai_completion'].str.contains('JOIN')].copy()
df_test_single = df_test_single[df_test_single['open_ai_completion'].str.count('FROM') <= 1]
df_test_single = df_test_single[df_test_single['model_response'] != '']

df_test_single['correct'] = np.where(df_test_single['open_ai_completion'].str
                                                                         .strip()
                                                                         .str
                                                                         .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                                .rstrip(';')) == 
                                     df_test_single['model_response'].str
                                                                     .strip()
                                                                     .str
                                                                     .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                            .rstrip(';')), 1, 0)
df_test_single['correct'].mean()

0.6713709677419355

In [27]:
# 0 leaves off "distinct"
# 10 confuses which column goes in SELECT vs WHERE clause
# 20 wrong table
# 30 wrong table
# 40 sorts by "year" instead of "date" to find "most recent"
# 50 adds a join it doesn't need and sorts by the wrong column
# 60 adds a join it doesn't need
# 70 wrong table, but I'd argue the model gets it right and the provided answer is wrong
# 80 leave off a column, but I'd argue the model gets it right and the provided answer is wrong
# 90 is right, just reverses the column order
i = 90

print(df_test_single[df_test_single['correct'] == 0]['open_ai_prompt'].iloc[i])
print(df_test_single[df_test_single['correct'] == 0]['open_ai_completion'].iloc[i])
print(df_test_single[df_test_single['correct'] == 0]['model_response'].iloc[i])

Schema: create table genre(	g_name varchar2(20) not null,	rating varchar2(10),	most_popular_in varchar2(50),	primary key(g_name));create table artist(	artist_name varchar2(50) not null,	country varchar2(20),	gender varchar2(20),	preferred_genre varchar2(50),	constraint a_name primary key(artist_name),	foreign key(preferred_genre);create table files(	f_id number(10) not null,	artist_name varchar2(50),	file_size varchar2(20),	duration varchar2(20),	formats varchar2(20),	primary key(f_id),	foreign key(artist_name);create table song(	song_name varchar2(50),	artist_name varchar2(50),	country varchar2(20),	f_id number(10),    	genre_is varchar2(20),	rating number(10) check(rating>0 and rating<11),	languages varchar2(20),	releasedate Date, 	resolution number(10) not null,	constraint s_name primary key(song_name),	foreign key(artist_name);
Question: List the names of all genres in alphabetical oder, together with its ratings.

###


 SELECT g_name ,  rating FROM genre ORDER BY g_name

 SELECT 

# How does raw da vinci compare?

In [133]:
def open_ai_prompt(x):
    return f'''
        Convert text to SQL.
        You have the following DDLs:
        ```
        {x['schema']}
        ```
        Write the SQL that answers the following question:
        """{x['question']}"""
        Respond with only one concise SQL statement.
        '''

def call_raw_model(row):
    prompt = row["open_ai_prompt"]

    # parameters chosen to match OpenAI Playground
#     completions = openai.Completion.create(
#         engine="text-davinci-003",
#         prompt=prompt,
#         max_tokens=256,
#         stop=None,
#         temperature=0.7,
#         top_p=1,
#         frequency_penalty=0,
#         presence_penalty=0
#     )
    
    # Duber's parameters
    completions = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        temperature=0.3,
        max_tokens=2000,
        best_of=3,
        frequency_penalty=0,
        presence_penalty=0
    )

    print(completions.choices[0].text)
    # raise Exception()
    time.sleep(15)
    return completions.choices[0].text




In [134]:
# df_long = df.loc[df_test.index.to_list()].copy()
# df_long['model_response'] = ''

#Get the long prompt
# df_long['open_ai_prompt'] = df_long.apply(lambda x: open_ai_prompt(x), axis=1)

data = []
for idx, row in df_long.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    new_response = call_raw_model(row)
    print(row['open_ai_completion'])
    data.append(new_response)
    tmp.loc[idx,'model_response'] = new_response

301
already completed
302
already completed
303
already completed
304
already completed
305
already completed
306
already completed
307
already completed
308
already completed
309
already completed
310
already completed
311
already completed
312
already completed
313
already completed
314
already completed
315
already completed
316
already completed
317
already completed
318
already completed
319
already completed
320
already completed
321
already completed
322
already completed
323
already completed
324
already completed
325
already completed
326
already completed
327
already completed
328
already completed
329
already completed
330
already completed
331
already completed
332
already completed
333
already completed
334
already completed
335
already completed
336
already completed
337
already completed
338
already completed
339
already completed
340
already completed
341
already completed
342
already completed
891
already completed
892
already completed
893
already completed
894
alread

 SELECT Name FROM mountain WHERE Mountain_ID NOT IN (SELECT Mountain_ID FROM climber)

1143

        SELECT Name FROM mountain WHERE Mountain_ID NOT IN (SELECT Mountain_ID FROM climber);
 SELECT Name FROM mountain WHERE Mountain_ID NOT IN (SELECT Mountain_ID FROM climber)

1144

        SELECT DISTINCT Country 
        FROM mountain 
        WHERE Height > 5600 OR Height < 5200;
 SELECT Country FROM mountain WHERE Height  >  5600 INTERSECT SELECT Country FROM mountain WHERE Height  <  5200

1145

        SELECT DISTINCT Country 
        FROM mountain 
        WHERE Height > 5600 OR Height < 5200;
 SELECT Country FROM mountain WHERE Height  >  5600 INTERSECT SELECT Country FROM mountain WHERE Height  <  5200

1146

        SELECT Range, COUNT(*) AS Num_Mountains
        FROM mountain
        GROUP BY Range
        ORDER BY Num_Mountains DESC
        LIMIT 1;
 SELECT Range FROM mountain GROUP BY Range ORDER BY COUNT(*) DESC LIMIT 1

1147

        SELECT Range, COUNT(*) AS Mountain_Count


 SELECT Sponsor_name FROM player WHERE Residence  =  "Brandon" OR Residence  =  "Birtle"

1724

        SELECT Player_name 
        FROM player 
        ORDER BY Votes DESC 
        LIMIT 1;
 SELECT Player_name FROM player ORDER BY Votes DESC LIMIT 1

1725

        SELECT Occupation, COUNT(*) AS 'Number of Players' FROM player GROUP BY Occupation;
 SELECT Occupation ,  COUNT(*) FROM player GROUP BY Occupation

1726

        SELECT Occupation, COUNT(Occupation) AS Occurrence 
        FROM player 
        GROUP BY Occupation 
        ORDER BY Occurrence DESC 
        LIMIT 1;
 SELECT Occupation FROM player GROUP BY Occupation ORDER BY COUNT(*) DESC LIMIT 1

1727

        SELECT Residence
        FROM player
        GROUP BY Residence
        HAVING COUNT(*) >= 2;
 SELECT Residence FROM player GROUP BY Residence HAVING COUNT(*)  >=  2

1728

        SELECT player.Player_name, coach.Coach_name 
        FROM player 
        INNER JOIN player_coach ON player.Player_ID = player_coach.Player_I


        SELECT Party_name FROM party LEFT JOIN member ON party.Party_ID = member.Party_ID WHERE Member_Name IS NULL;
 SELECT party_name FROM party WHERE party_id NOT IN (SELECT party_id FROM Member)

2064

        SELECT Party_name FROM party LEFT JOIN member ON party.Party_ID = member.Party_ID WHERE Member_Name IS NULL;
 SELECT party_name FROM party WHERE party_id NOT IN (SELECT party_id FROM Member)

2065

        SELECT Member_Name 
        FROM member 
        WHERE Party_ID IN (1, 3);
 SELECT member_name FROM member WHERE party_id  =  3 INTERSECT SELECT member_name FROM member WHERE party_id  =  1

2066

        SELECT Member_Name
        FROM member
        WHERE Member_ID IN (SELECT Member_ID FROM member WHERE Party_ID = 3)
        AND Member_ID IN (SELECT Member_ID FROM member WHERE Party_ID = 1);
 SELECT member_name FROM member WHERE party_id  =  3 INTERSECT SELECT member_name FROM member WHERE party_id  =  1

2067

        SELECT Member_Name FROM member WHERE Party_ID NOT IN

 SELECT DISTINCT T1.forename ,  T1.surname FROM drivers AS T1 JOIN laptimes AS T2 ON T1.driverid = T2.driverid WHERE T2.milliseconds < 93000

2163

        SELECT DISTINCT driverId, nationality FROM drivers INNER JOIN lapTimes ON drivers.driverId = lapTimes.driverId WHERE milliseconds > 100000;
 SELECT DISTINCT T1.driverid ,  T1.nationality FROM drivers AS T1 JOIN laptimes AS T2 ON T1.driverid = T2.driverid WHERE T2.milliseconds >  100000

2164

        SELECT driverId, nationality FROM drivers INNER JOIN lapTimes ON drivers.driverId = lapTimes.driverId WHERE lapTimes.milliseconds > 100000;
 SELECT DISTINCT T1.driverid ,  T1.nationality FROM drivers AS T1 JOIN laptimes AS T2 ON T1.driverid = T2.driverid WHERE T2.milliseconds >  100000

2165

        SELECT forename, surname FROM drivers INNER JOIN lapTimes ON drivers.driverId = lapTimes.driverId WHERE lapTimes.milliseconds = (SELECT MIN(milliseconds) FROM lapTimes);
 SELECT T1.forename ,  T1.surname FROM drivers AS T1 JOIN laptimes AS 

 SELECT DISTINCT T1.forename FROM drivers AS T1 JOIN driverstandings AS T2 ON T1.driverid = T2.driverid WHERE T2.position = 1 AND T2.wins = 1

2185

        SELECT DISTINCT forename 
        FROM drivers 
        INNER JOIN driverStandings 
        ON drivers.driverId = driverStandings.driverId 
        WHERE driverStandings.position = 1 
        AND driverStandings.points > 20;
 SELECT DISTINCT T1.forename FROM drivers AS T1 JOIN driverstandings AS T2 ON T1.driverid = T2.driverid WHERE T2.position = 1 AND T2.wins = 1 AND T2.points > 20

2186

        SELECT forename FROM drivers 
        INNER JOIN driverStandings ON drivers.driverId = driverStandings.driverId 
        WHERE driverStandings.position = 1 AND driverStandings.points > 20;
 SELECT DISTINCT T1.forename FROM drivers AS T1 JOIN driverstandings AS T2 ON T1.driverid = T2.driverid WHERE T2.position = 1 AND T2.wins = 1 AND T2.points > 20

2187

        SELECT nationality, COUNT(*) AS 'Number of Constructors' FROM constructors GR

 SELECT T1.driverid ,  T1.surname FROM drivers AS T1 JOIN results AS T2 ON T1.driverid  =  T2.driverid JOIN races AS T3 ON T2.raceid = T3.raceid WHERE T3.year > 2010 GROUP BY T1.driverid ORDER BY count(*) DESC LIMIT 1

2213

        SELECT name FROM circuits WHERE country IN ('UK', 'Malaysia');
 SELECT name FROM circuits WHERE country = "UK" OR country = "Malaysia"

2214

        SELECT name FROM circuits WHERE country IN ('UK', 'Malaysia');
 SELECT name FROM circuits WHERE country = "UK" OR country = "Malaysia"

2215

        SELECT circuitId, location FROM circuits WHERE country IN ('France', 'Belgium');
 SELECT circuitid ,  LOCATION FROM circuits WHERE country = "France" OR country = "Belgium"

2216

        SELECT circuitId, location FROM circuits WHERE country IN ('France', 'Belgium');
 SELECT circuitid ,  LOCATION FROM circuits WHERE country = "France" OR country = "Belgium"

2217

        SELECT name 
        FROM constructors 
        INNER JOIN constructorResults 
        ON c

 SELECT T2.Name FROM entrepreneur AS T1 JOIN people AS T2 ON T1.People_ID  =  T2.People_ID

2274

        SELECT Name FROM people INNER JOIN entrepreneur ON people.People_ID = entrepreneur.People_ID;
 SELECT T2.Name FROM entrepreneur AS T1 JOIN people AS T2 ON T1.People_ID  =  T2.People_ID

2275

        SELECT Name 
        FROM people 
        INNER JOIN entrepreneur 
        ON people.People_ID = entrepreneur.People_ID 
        WHERE Investor != 'Rachel Elnaugh';
 SELECT T2.Name FROM entrepreneur AS T1 JOIN people AS T2 ON T1.People_ID  =  T2.People_ID WHERE T1.Investor != "Rachel Elnaugh"

2276

        SELECT Name 
        FROM people 
        JOIN entrepreneur 
        ON people.People_ID = entrepreneur.People_ID 
        WHERE Investor != 'Rachel Elnaugh';
 SELECT T2.Name FROM entrepreneur AS T1 JOIN people AS T2 ON T1.People_ID  =  T2.People_ID WHERE T1.Investor != "Rachel Elnaugh"

2277

        SELECT MIN(Weight) FROM people;
 SELECT Weight FROM people ORDER BY Height ASC LIM


        SELECT MIN(Crime_rate), MAX(Crime_rate) FROM county_public_safety;
 SELECT min(Crime_rate) ,  max(Crime_rate) FROM county_public_safety

2539

        SELECT MIN(Crime_rate), MAX(Crime_rate) FROM county_public_safety;
 SELECT min(Crime_rate) ,  max(Crime_rate) FROM county_public_safety

2540

        SELECT County_ID, Name, Crime_rate, Police_officers 
        FROM county_public_safety 
        ORDER BY Police_officers ASC;
 SELECT Crime_rate FROM county_public_safety ORDER BY Police_officers ASC

2541

        SELECT County_ID, Name, Crime_rate FROM county_public_safety ORDER BY Police_officers ASC;
 SELECT Crime_rate FROM county_public_safety ORDER BY Police_officers ASC

2542

        SELECT Name FROM city ORDER BY Name ASC;
 SELECT Name FROM city ORDER BY Name ASC

2543

        SELECT Name FROM city ORDER BY Name ASC;
 SELECT Name FROM city ORDER BY Name ASC

2544

        SELECT City.Name, Hispanic * 100.0 / (White + Black + Amerindian + Asian + Multiracial + Hispanic) A


        SELECT Region_name FROM region ORDER BY Region_name ASC;
 SELECT region_name FROM region ORDER BY region_name

2696

        SELECT Region_name FROM region WHERE Region_name != 'Denmark';
 SELECT region_name FROM region WHERE region_name != 'Denmark'

2697

        SELECT Region_name FROM region WHERE Region_name != 'Denmark';
 SELECT region_name FROM region WHERE region_name != 'Denmark'

2698

        SELECT COUNT(Number_Deaths) FROM storm WHERE Number_Deaths > 0;
 SELECT count(*) FROM storm WHERE Number_Deaths  >  0

2699

        SELECT COUNT(*) FROM storm WHERE Number_Deaths > 0;
 SELECT count(*) FROM storm WHERE Number_Deaths  >  0

2700

        SELECT Name, Dates_active, Number_Deaths FROM storm WHERE Number_Deaths > 0;
 SELECT name ,  dates_active ,  number_deaths FROM storm WHERE number_deaths  >=  1

2701

        SELECT Name, Dates_active, Number_Deaths FROM storm WHERE Number_Deaths > 0;
 SELECT name ,  dates_active ,  number_deaths FROM storm WHERE number_deaths 

 SELECT T3.name FROM affected_region AS T1 JOIN region AS T2 ON T1.region_id  =  T2.region_id JOIN storm AS T3 ON T1.storm_id  =  T3.storm_id WHERE T2.region_name  =  'Denmark'

2727

        SELECT Name FROM storm
        INNER JOIN affected_region ON storm.Storm_ID = affected_region.Storm_ID
        INNER JOIN region ON affected_region.Region_id = region.Region_id
        WHERE region.Region_name = 'Denmark';
 SELECT T3.name FROM affected_region AS T1 JOIN region AS T2 ON T1.region_id  =  T2.region_id JOIN storm AS T3 ON T1.storm_id  =  T3.storm_id WHERE T2.region_name  =  'Denmark'

2728

        SELECT Region_name 
        FROM region 
        INNER JOIN affected_region 
        ON region.Region_id = affected_region.Region_id 
        INNER JOIN storm 
        ON affected_region.Storm_ID = storm.Storm_ID 
        GROUP BY Region_name 
        HAVING COUNT(*) >= 2;
 SELECT T1.region_name FROM region AS T1 JOIN affected_region AS T2 ON T1.region_id = T2.region_id GROUP BY T1.region_i

 SELECT venue ,  name FROM event ORDER BY Event_Attendance DESC LIMIT 2

3086

        SELECT COUNT(*) FROM Assessment_Notes;
 SELECT count(*) FROM ASSESSMENT_NOTES

3087

        SELECT date_of_notes FROM Assessment_Notes;
 SELECT date_of_notes FROM Assessment_Notes

3088

        SELECT COUNT(*) FROM Addresses WHERE zip_postcode = '197';
 SELECT count(*) FROM ADDRESSES WHERE zip_postcode  =  "197"

3089

        SELECT COUNT(DISTINCT incident_type_code) FROM Ref_Incident_Type;
 SELECT count(DISTINCT incident_type_code) FROM Behavior_Incident

3090

        SELECT DISTINCT detention_type_code FROM Ref_Detention_Type;
 SELECT DISTINCT detention_type_code FROM Detention

3091

        SELECT date_incident_start, date_incident_end FROM Behavior_Incident WHERE incident_type_code = 'NOISE';
 SELECT date_incident_start ,  date_incident_end FROM Behavior_Incident WHERE incident_type_code  =  "NOISE"

3092

        SELECT detention_summary FROM Detention;
 SELECT detention_summary FROM Detent

 SELECT T2.address_id ,  T1.zip_postcode FROM Addresses AS T1 JOIN Student_Addresses AS T2 ON T1.address_id  =  T2.address_id ORDER BY monthly_rental DESC LIMIT 1

3117

        SELECT cell_mobile_number 
        FROM Students 
        INNER JOIN Student_Addresses 
        ON Students.student_id = Student_Addresses.student_id 
        INNER JOIN Addresses 
        ON Student_Addresses.address_id = Addresses.address_id 
        ORDER BY monthly_rental ASC 
        LIMIT 1;
 SELECT T2.cell_mobile_number FROM Student_Addresses AS T1 JOIN Students AS T2 ON T1.student_id  =  T2.student_id ORDER BY T1.monthly_rental ASC LIMIT 1

3118

        SELECT sa.monthly_rental 
        FROM Student_Addresses sa 
        JOIN Addresses a ON sa.address_id = a.address_id 
        WHERE a.state_province_county = 'Texas';
 SELECT T2.monthly_rental FROM Addresses AS T1 JOIN Student_Addresses AS T2 ON T1.address_id  =  T2.address_id WHERE T1.state_province_county  =  "Texas"

3119

        SELECT first_name,


        SELECT f_id 
        FROM files 
        ORDER BY duration DESC 
        LIMIT 1;
 SELECT f_id FROM files ORDER BY duration DESC LIMIT 1

3533

        SELECT song_name FROM song WHERE languages = 'English';
 SELECT song_name FROM song WHERE languages  =  "english"

3534

        SELECT song_name FROM song WHERE languages = 'English';
 SELECT song_name FROM song WHERE languages  =  "english"

3535

        SELECT f_id FROM files WHERE formats = 'mp3';
 SELECT f_id FROM files WHERE formats  =  "mp3"

3536

        SELECT f_id FROM files WHERE formats = 'mp3';
 SELECT f_id FROM files WHERE formats  =  "mp3"

3537

        SELECT artist_name, country FROM artist JOIN song ON artist.artist_name = song.artist_name WHERE rating > 9;
 SELECT DISTINCT T1.artist_name ,  T1.country FROM artist AS T1 JOIN song AS T2 ON T1.artist_name  =  T2.artist_name WHERE T2.rating  >  9

3538

        SELECT artist_name, country FROM artist INNER JOIN song ON artist.artist_name = song.artist_name WHE

 SELECT DISTINCT song_name FROM song WHERE resolution  >  (SELECT min(resolution) FROM song WHERE languages  =  "english")

3570

        SELECT song_name FROM song WHERE resolution > (SELECT resolution FROM song WHERE languages = 'English');
 SELECT DISTINCT song_name FROM song WHERE resolution  >  (SELECT min(resolution) FROM song WHERE languages  =  "english")

3571

        SELECT song_name
        FROM song
        WHERE rating < (SELECT rating FROM song WHERE genre_is = 'blues');
 SELECT song_name FROM song WHERE rating  <  (SELECT max(rating) FROM song WHERE genre_is  =  "blues")

3572

        SELECT song_name
        FROM song
        WHERE rating < (SELECT rating FROM song WHERE genre_is = 'blues')
 SELECT song_name FROM song WHERE rating  <  (SELECT max(rating) FROM song WHERE genre_is  =  "blues")

3573

        SELECT artist_name, country FROM artist JOIN song ON artist.artist_name = song.artist_name WHERE song_name LIKE '%love%';
 SELECT T1.artist_name ,  T1.country FROM 

 SELECT min(T1.duration) ,  min(T2.rating) ,  T2.genre_is FROM files AS T1 JOIN song AS T2 ON T1.f_id  =  T2.f_id GROUP BY T2.genre_is ORDER BY T2.genre_is

3605

        SELECT artist_name, COUNT(*) 
        FROM song 
        WHERE languages = 'English' 
        GROUP BY artist_name;
 SELECT T1.artist_name ,  count(*) FROM artist AS T1 JOIN song AS T2 ON T1.artist_name  =  T2.artist_name WHERE T2.languages  =  "english" GROUP BY T2.artist_name HAVING count(*)  >=  1

3606

        SELECT artist_name, COUNT(*) AS num_works
        FROM song
        WHERE languages = 'English'
        GROUP BY artist_name;
 SELECT T1.artist_name ,  count(*) FROM artist AS T1 JOIN song AS T2 ON T1.artist_name  =  T2.artist_name WHERE T2.languages  =  "english" GROUP BY T2.artist_name HAVING count(*)  >=  1

3607

        SELECT artist_name, country FROM artist a INNER JOIN song s ON a.artist_name = s.artist_name WHERE resolution > 900;
 SELECT T1.artist_name ,  T1.country FROM artist AS T1 JOIN song AS 

 SELECT Product_Name FROM Products ORDER BY Product_Price DESC LIMIT 1

4589

        SELECT Product_Type_Code, COUNT(Product_Type_Code) FROM Products GROUP BY Product_Type_Code;
 SELECT Product_Type_Code ,  COUNT(*) FROM Products GROUP BY Product_Type_Code

4590

        SELECT Product_Type_Code, COUNT(*) AS Frequency 
        FROM Products 
        GROUP BY Product_Type_Code 
        ORDER BY Frequency DESC 
        LIMIT 1;
 SELECT Product_Type_Code FROM Products GROUP BY Product_Type_Code ORDER BY COUNT(*) DESC LIMIT 1

4591

        SELECT Product_Type_Code
        FROM Products
        GROUP BY Product_Type_Code
        HAVING COUNT(Product_ID) >= 2;
 SELECT Product_Type_Code FROM Products GROUP BY Product_Type_Code HAVING COUNT(*)  >=  2

4592

        SELECT Product_Type_Code 
        FROM Products 
        WHERE Product_Price > 4500 OR Product_Price < 3000;
 SELECT Product_Type_Code FROM Products WHERE Product_Price  >  4500 INTERSECT SELECT Product_Type_Code FROM Products WHE

 SELECT count(*) FROM BOOKINGS

5091

        SELECT COUNT(*) FROM Bookings;
 SELECT count(*) FROM BOOKINGS

5092

        SELECT Order_Date FROM Bookings;
 SELECT Order_Date FROM BOOKINGS

5093

        SELECT Order_Date FROM Bookings;
 SELECT Order_Date FROM BOOKINGS

5094

        SELECT Planned_Delivery_Date, Actual_Delivery_Date FROM Bookings;
 SELECT Planned_Delivery_Date ,  Actual_Delivery_Date FROM BOOKINGS

5095

        SELECT Bookings.Planned_Delivery_Date, Bookings.Actual_Delivery_Date FROM Bookings;
 SELECT Planned_Delivery_Date ,  Actual_Delivery_Date FROM BOOKINGS

5096

        SELECT COUNT(*) FROM Customers;
 SELECT count(*) FROM CUSTOMERS

5097

        SELECT COUNT(*) FROM Customers;
 SELECT count(*) FROM CUSTOMERS

5098

        SELECT Customer_Phone, Customer_Email_Address 
        FROM Customers 
        WHERE Customer_Name = 'Harold';
 SELECT Customer_Phone ,  Customer_Email_Address FROM CUSTOMERS WHERE Customer_Name  =  "Harold"

5099

        SELECT Customer_Ph


        SELECT Service_Type_Code, Service_Type_Description, COUNT(*) FROM Ref_Service_Types GROUP BY Service_Type_Code, Service_Type_Description;
 SELECT T1.Service_Type_Description ,  T2.Service_Type_Code ,  COUNT(*) FROM Ref_Service_Types AS T1 JOIN Services AS T2 ON T1.Service_Type_Code  =  T2.Service_Type_Code GROUP BY T2.Service_Type_Code

5129

        SELECT Service_Type_Description, Service_Type_Code, COUNT(Service_ID) AS Number_of_Services FROM Ref_Service_Types INNER JOIN Services ON Ref_Service_Types.Service_Type_Code = Services.Service_Type_Code GROUP BY Service_Type_Description, Service_Type_Code;
 SELECT T1.Service_Type_Description ,  T2.Service_Type_Code ,  COUNT(*) FROM Ref_Service_Types AS T1 JOIN Services AS T2 ON T1.Service_Type_Code  =  T2.Service_Type_Code GROUP BY T2.Service_Type_Code

5130

        SELECT Service_Type_Description, Service_Type_Code 
        FROM Ref_Service_Types 
        INNER JOIN Services 
        ON Ref_Service_Types.Service_Type_Code = Serv

 SELECT T2.Store_Name FROM Addresses AS T1 JOIN Drama_Workshop_Groups AS T2 ON T1.Address_ID  =  T2.Address_ID WHERE T1.City_Town  =  "Feliciaberg"

5152

        SELECT Store_Email_Address 
        FROM Drama_Workshop_Groups 
        INNER JOIN Addresses 
        ON Drama_Workshop_Groups.Address_ID = Addresses.Address_ID 
        WHERE State_County = 'Alaska';
 SELECT T2.Store_Email_Address FROM Addresses AS T1 JOIN Drama_Workshop_Groups AS T2 ON T1.Address_ID  =  T2.Address_ID WHERE T1.State_County  =  "Alaska"

5153

        SELECT Store_Email_Address 
        FROM Drama_Workshop_Groups 
        INNER JOIN Addresses 
        ON Drama_Workshop_Groups.Address_ID = Addresses.Address_ID 
        WHERE State_County = 'Alaska';
 SELECT T2.Store_Email_Address FROM Addresses AS T1 JOIN Drama_Workshop_Groups AS T2 ON T1.Address_ID  =  T2.Address_ID WHERE T1.State_County  =  "Alaska"

5154

        SELECT City_Town, COUNT(*) AS 'Number of Drama Workshop Groups' 
        FROM Addresses 
      

 SELECT name FROM branch ORDER BY membership_amount DESC LIMIT 3

5406

        SELECT DISTINCT City FROM branch WHERE membership_amount >= 100;
 SELECT DISTINCT city FROM branch WHERE membership_amount  >=  100

5407

        SELECT City
        FROM branch
        WHERE membership_amount > 100;
 SELECT DISTINCT city FROM branch WHERE membership_amount  >=  100

5408

        SELECT Open_year 
        FROM branch 
        GROUP BY Open_year 
        HAVING COUNT(*) >= 2;
 SELECT open_year FROM branch GROUP BY open_year HAVING count(*)  >=  2

5409

        SELECT Open_year
        FROM branch
        GROUP BY Open_year
        HAVING COUNT(*) >= 2;
 SELECT open_year FROM branch GROUP BY open_year HAVING count(*)  >=  2

5410

        SELECT MIN(membership_amount), MAX(membership_amount) FROM branch WHERE Open_year = '2011' OR City = 'London';
 SELECT min(membership_amount) ,  max(membership_amount) FROM branch WHERE open_year  =  2011 OR city  =  'London'

5411

        SELECT MIN(mem

APIConnectionError: Error communicating with OpenAI: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))

In [135]:
len(data)

748

In [136]:
data[0]

'\n        SELECT catalog_entry_name FROM Catalog_Contents;'

In [143]:
df_long[['open_ai_prompt', 'open_ai_completion', 'model_response']].to_csv('davinci-completions.csv')

In [137]:
for i in range(len(data)):
    df_long.iloc[i]['model_response'] = data[i]

In [71]:
df_long = df_long.dropna()

In [138]:
df_long.shape

(1056, 12)

In [73]:
df_long[df_long['model_response'] != ''].shape

(42, 12)

In [74]:
len(data)

42

In [87]:
i = 50
print(df_long[df_long['model_response'] != ''].iloc[i]['open_ai_prompt'] + '\n')
print(df_long[df_long['model_response'] != ''].iloc[i]['open_ai_completion'])
print(df_long[df_long['model_response'] != ''].iloc[i]['model_response'])

IndexError: single positional indexer is out-of-bounds

In [144]:
# ignore: 
# - whitespace
# - capitalization
# - trailing semicolons
# - the difference between single and double quotes
# - difference between 'distinct(...)' and 'distinct ...'
# - explictly stating ASC (the ORDER BY default)
tmp = df_long[df_long['model_response'] != ''].copy()

tmp['correct'] = np.where(tmp['open_ai_completion'].str
                                                                         .strip()
                                                                         .str
                                                                         .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                                .rstrip(';')) == 
                                     tmp['model_response'].str
                                                                     .strip()
                                                                     .str
                                                                     .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                            .rstrip(';')), 1, 0)
tmp['correct'].mean()

0.32219251336898397

In [154]:
i = 50
print(tmp[tmp['correct'] == 0]['open_ai_prompt'].iloc[i])
print(tmp[tmp['correct'] == 0]['open_ai_completion'].iloc[i])
print(tmp[tmp['correct'] == 0]['model_response'].iloc[i])


        Convert text to SQL.
        You have the following DDLs:
        ```
        CREATE TABLE Customers (Customer_ID INTEGER ,Customer_name VARCHAR(40),PRIMARY KEY (Customer_ID));CREATE TABLE Services (Service_ID INTEGER ,Service_name VARCHAR(40),PRIMARY KEY (Service_ID));CREATE TABLE Available_Policies (Policy_ID INTEGER ,policy_type_code CHAR(15),Customer_Phone VARCHAR(255),PRIMARY KEY (Policy_ID),UNIQUE (Policy_ID));CREATE TABLE Customers_Policies (Customer_ID INTEGER ,Policy_ID INTEGER ,Date_Opened DATE,Date_Closed DATE,PRIMARY KEY (Customer_ID, Policy_ID),FOREIGN KEY (Customer_ID);CREATE TABLE First_Notification_of_Loss (FNOL_ID INTEGER ,Customer_ID INTEGER ,Policy_ID INTEGER ,Service_ID INTEGER ,PRIMARY KEY (FNOL_ID),UNIQUE (FNOL_ID),FOREIGN KEY (Service_ID);CREATE TABLE Claims (Claim_ID INTEGER ,FNOL_ID INTEGER ,Effective_Date DATE,PRIMARY KEY (Claim_ID),UNIQUE (Claim_ID),FOREIGN KEY (FNOL_ID);CREATE TABLE Settlements (Settlement_ID INTEGER ,Claim_ID INTEGER,Effective_Date