In [83]:
import time
import json
import openai
import re
import sqlite3
import sqlglot
from textwrap import dedent
import pandas as pd
import numpy as np

from os import path
from openai.api_requestor import error

In [476]:
def get_schema(db_id):
    PATH = 'spider/database/'
    
    path_to_file   = PATH + db_id + '/schema.sql'
    path_to_file_2 = PATH + db_id + '/' + db_id + '.sqlite'
    
    if path.exists(path_to_file):
        x = open(path_to_file, 'r').read()
        # Remove comment lines
        # x = re.sub(r"^/\*.*\n", '', x, flags=re.MULTILINE)
        x = re.sub(r'^--.*(\n|$)', '', x, flags=re.MULTILINE)
        x = re.sub(r',\s*--.*(\n|$)', ',\n', x, flags=re.MULTILINE)
        # x = re.sub(r'^/\*.*\*/', '', x)
        x = re.sub(r'/\*.*?\*/', '', x, flags=re.DOTALL)
        
        x = re.sub(r'CREATE TABLE \t', 'CREATE TABLE ', x)
        
        arr = x.split(';')
        create_tables = []
        for i, stmt in enumerate(arr):
            stmt = stmt.strip()
            if stmt.lower().startswith('create'):
                create_tables.append(stmt.replace('IF NOT EXISTS ', '').replace('\n', ''))
        return ';'.join(create_tables)
      
    elif path.exists(path_to_file_2):
        # Connect to the sqlite file
        conn = sqlite3.connect(path_to_file_2)

        # Read the schema of the twitter_1.sqlite file into a dataframe
        df_conn = pd.read_sql_query("SELECT sql FROM sqlite_master WHERE type='table';", conn)

        # Close the connection
        conn.close()
        
        create_tables = []
        for idx, row in df_conn.iterrows():
            row['sql'] = re.sub(r'--.*$', '', row['sql'], flags=re.MULTILINE)
            #row['sql'] = re.sub(r',\s*--.*(\n|$)', ',\n', row['sql'], flags=re.MULTILINE)
            #row['sql'] = re.sub(r'^\s*--.*(\n|$)', '\n', row['sql'], flags=re.MULTILINE)
            row['sql'] = row['sql'].strip()
            if row['sql'].lower().startswith('create'):
                create_tables.append(row['sql'].replace('IF NOT EXISTS ', '').replace('\n', ''))
        return ';'.join(create_tables)
    
    else: return None
    
def few_shot_prompt(x):
    return dedent(f'''
        Convert text to SQL.
        
        DDL:
        ```
        {x['schema-1']}
        ```
        Question:
        """{x['question-1']}"""
        Answer:
        {x['completion-1']}
        
        DDL:
        {x['schema-2']}
        Question:
        """{x['question-2']}"""
        Answer:
        {x['completion-2']}
        
        DDL:
        ```
        {x['schema']}
        ```
        Question:
        """{x['question']}"""
        Answer:
        ''')

# From OpenAI: The completion should start with a whitespace character (` `). 
# This tends to produce better results due to the tokenization we use.
def open_ai_completion(x):
    return f" {x['query']}\n" 

def call_raw_model(row, engine="text-davinci-003", stop=None, max_tokens=1000):
    while True:
        try:
            prompt = row["open_ai_prompt"]

            # Duber's parameters
            completions = openai.Completion.create(
                engine=engine,
                prompt=prompt,
                temperature=0.3,
                max_tokens=max_tokens,
                best_of=1,
                frequency_penalty=0,
                presence_penalty=0,
                stop=stop
            )

            print(completions.choices[0].text)
            # raise Exception()
            return completions.choices[0].text

        except error.RateLimitError:
            print('RateLimitError')
            time.sleep(15)
            continue
        
        except error.InvalidRequestError:
            print('InvalidRequestError: too many tokens')
            break
            
        except error.APIError:
            print('APIError')
            time.sleep(15)
            continue
            
        except error.Timeout:
            print('TimeoutError')
            time.sleep(1)
            continue
            
        except error.APIConnectionError:
            print('APIConnectionError')
            time.sleep(1)
            continue
            
def execution(row, query_type):
    PATH = 'spider/database/'
    db_id = row['db_id']
    
    query = row[query_type]

    # Connect to the SQLite database file
    conn = sqlite3.connect(PATH + db_id + '/' + db_id + '.sqlite')

    # Create a cursor object
    cursor = conn.cursor()

    # Execute a SELECT statement
    try:
        cursor.execute(query)
    except:
        return "invalid SQL"

    # Fetch the results
    results = cursor.fetchall()

    # Loop through the results and print them
    # for result in results:
    #     print(result)

    # Close the cursor and connection
    cursor.close()
    conn.close()
    
    return results

# Random Examples

In [None]:
def get_examples(x):
    examples = np.random.randint(0,len(df),2)
    for i in range(len(examples)):
        row = df.loc[examples[i]]
        x['question-' + str(i+1)]   = row['question']
        x['schema-' + str(i+1)]     = row['schema']
        x['completion-' + str(i+1)] = row['query']
        
    return x[['schema-1', 'question-1', 'completion-1', 'schema-2', 'question-2', 'completion-2']]

In [5]:
df = pd.read_json('spider/train_spider.json')

# Set schema
df['schema'] = df.apply(lambda x: get_schema(x['db_id']), axis=1)
df['schema'] = df['schema'].apply(lambda x: re.sub('(?i) REFERENCES.*?(;|$)', ';', x)) 
df['schema'] = df['schema'].apply(lambda x: re.sub('NOT NULL', '', x))

# Select the examples to show the model
df[['schema-1', 'question-1', 'completion-1', 'schema-2', 'question-2', 'completion-2']] = df.apply(lambda x: get_examples(x), axis=1)
df['open_ai_prompt'] = df.apply(lambda x: few_shot_prompt(x), axis=1)


# Set Open AI prompt and completion
df['open_ai_completion'] = df.apply(lambda x: open_ai_completion(x), axis=1)

In [6]:
df.columns

Index(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question',
       'question_toks', 'sql', 'schema', 'schema-1', 'question-1',
       'completion-1', 'schema-2', 'question-2', 'completion-2',
       'open_ai_prompt', 'open_ai_completion'],
      dtype='object')

In [7]:
print(df.loc[0]['open_ai_prompt'])


Convert text to SQL.

DDL:
```
CREATE TABLE "county" ("County_Id" int,"County_name" text,"Population" real,"Zip_code" text,PRIMARY KEY ("County_Id"));CREATE TABLE "party" ("Party_ID" int,"Year" real,"Party" text,"Governor" text,"Lieutenant_Governor" text,"Comptroller" text,"Attorney_General" text,"US_Senate" text,PRIMARY KEY ("Party_ID"));CREATE TABLE "election" ("Election_ID" int,"Counties_Represented" text,"District" int,"Delegate" text,"Party" int,"First_Elected" real,"Committee" text,PRIMARY KEY ("Election_ID"),FOREIGN KEY (`Party`);
```
Question:
"""Who are the lieutenant governor and comptroller from the democratic party?"""
Answer:
SELECT Lieutenant_Governor ,  Comptroller FROM party WHERE Party  =  "Democratic"

DDL:
CREATE TABLE "list" ( 	"LastName" TEXT, 	"FirstName" TEXT, 	"Grade" INTEGER, 	"Classroom" INTEGER,	PRIMARY KEY(LastName, FirstName));CREATE TABLE "teachers" ( 	"LastName" TEXT, 	"FirstName" TEXT, 	"Classroom" INTEGER,	PRIMARY KEY(LastName, FirstName))
Question:
""

In [8]:
tmp = df.sample(100)
tmp['model_response'] = ''

data = []
for idx, row in tmp.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    print(row['open_ai_completion'])
    new_response = call_raw_model(row)
    data.append(new_response)
    tmp.loc[idx,'model_response'] = new_response

4447
 SELECT T1.name FROM Person AS T1 JOIN PersonFriend AS T2 ON T1.name  =  T2.name WHERE T2.friend IN (SELECT name FROM Person WHERE age  >  40) EXCEPT SELECT T1.name FROM Person AS T1 JOIN PersonFriend AS T2 ON T1.name  =  T2.name WHERE T2.friend IN (SELECT name FROM Person WHERE age  <  30)

SELECT name FROM Person WHERE age > 40 AND NOT EXISTS (SELECT * FROM PersonFriend WHERE name = Person.name AND friend < 30)
6700
 SELECT T1.lesson_id FROM Lessons AS T1 JOIN Staff AS T2 ON T1.staff_id = T2.staff_id WHERE T2.first_name = "Janessa" AND T2.last_name = "Sawayn" AND nickname LIKE "%s%";

InvalidRequestError: too many tokens
1395
 SELECT max(capacity) ,  avg(capacity) ,  building FROM classroom GROUP BY building

SELECT building, MAX(capacity) as max_capacity, AVG(capacity) as avg_capacity FROM classroom GROUP BY building
1650
 SELECT count(*) FROM artist

SELECT COUNT(*) FROM artist
4922
 SELECT t3.headquartered_city ,  count(*) FROM store AS t1 JOIN store_district AS t2 ON t1.stor

SELECT o.organisation_id, o.organisation_details FROM organisations o INNER JOIN grants g ON o.organisation_id = g.organisation_id WHERE g.grant_amount > 6000;
5915
 SELECT T1.Museum_Details ,  T2.Opening_Hours FROM MUSEUMS AS T1 JOIN TOURIST_ATTRACTIONS AS T2 ON T1.Museum_ID  =  T2.Tourist_Attraction_ID

SELECT Museum_Details, Opening_Hours FROM Museums
2276
 SELECT T2.Name FROM entrepreneur AS T1 JOIN people AS T2 ON T1.People_ID  =  T2.People_ID WHERE T1.Investor != "Rachel Elnaugh"

SELECT Name FROM people p INNER JOIN entrepreneur e ON p.People_ID = e.People_ID WHERE e.Investor != "Rachel Elnaugh"
5717
 SELECT fname ,  lname FROM student WHERE city_code != 'HKG' ORDER BY age

SELECT Fname, Lname FROM Student WHERE city_code != 'HKG' ORDER BY Age ASC
4983
 SELECT DISTINCT cName FROM tryout ORDER BY cName

SELECT cName FROM Tryout ORDER BY cName ASC
5353
 SELECT problem_log_id FROM problem_log ORDER BY log_entry_date DESC LIMIT 1

SELECT problem_log_id FROM Problem_Log ORDER BY log_

SELECT T2.Type, AVG(T2.Tonnage) FROM mission AS T1 JOIN ship AS T2 ON T1.Ship_ID  =  T2.Ship_ID GROUP BY T2.Type
913
 SELECT DISTINCT t3.policy_type_code FROM customers AS t1 JOIN customers_policies AS t2 ON t1.customer_id  =  t2.customer_id JOIN available_policies AS t3 ON t2.policy_id  =  t3.policy_id WHERE t1.customer_name  =  (SELECT t1.customer_name FROM customers AS t1 JOIN customers_policies AS t2 ON t1.customer_id  =  t2.customer_id GROUP BY t1.customer_name ORDER BY count(*) DESC LIMIT 1)

SELECT policy_type_code FROM Available_Policies WHERE Customer_Phone IN (SELECT Customer_Phone FROM Available_Policies GROUP BY Customer_Phone ORDER BY COUNT(*) DESC LIMIT 1)
6159
 SELECT Song FROM volume WHERE Weeks_on_Top  >  1

SELECT Song FROM volume WHERE Weeks_on_Top > 1
1103
 SELECT College FROM match_season GROUP BY College HAVING count(*)  >=  2 ORDER BY College DESC

SELECT college FROM match_season GROUP BY college HAVING COUNT(*)  >=  2 ORDER BY college DESC
4381
 SELECT T1.outco

In [9]:
tmp['model_response']

4447    SELECT name FROM Person WHERE age > 40 AND NOT...
6700                                                 None
1395    SELECT building, MAX(capacity) as max_capacity...
1650                          SELECT COUNT(*) FROM artist
4922    SELECT Headquartered_City, COUNT(Store_ID) AS ...
                              ...                        
1493    SELECT Venue FROM debate ORDER BY Num_of_Audie...
849                                                  None
1553    SELECT account_id, account_name, other_account...
103     SELECT T1.student_details FROM Students AS T1 ...
4952       SELECT COUNT(*) FROM College WHERE enr > 15000
Name: model_response, Length: 100, dtype: object

In [12]:
len(tmp[tmp['model_response'].isna()])

20

In [62]:
df_rand['open_ai_execution'] = df_rand.apply(lambda x: execution(x, 'open_ai_completion'), axis=1)
df_rand['model_response_execution'] = df_rand.apply(lambda x: execution(x, 'model_response'), axis=1)
df_rand['execution_accuracy'] = np.where(df_rand['open_ai_execution'] == df_rand['model_response_execution'], 1, 0)
df_rand[~df_rand['model_response'].isna()]['execution_accuracy'].mean()

0.5625

In [11]:
i = 0
print(tmp[~tmp['model_response'].isna()].iloc[i]['open_ai_prompt'])
print(tmp[~tmp['model_response'].isna()].iloc[i]['open_ai_completion'])
print(tmp[~tmp['model_response'].isna()].iloc[i]['model_response'])



Convert text to SQL.

DDL:
```
create table Student (        StuID        INTEGER PRIMARY KEY,        LName        VARCHAR(12),        Fname        VARCHAR(12),        Age      INTEGER,        Sex      VARCHAR(1),        Major        INTEGER,        Advisor      INTEGER,        city_code    VARCHAR(3) );create table Video_Games (       GameID           INTEGER PRIMARY KEY,       GName            VARCHAR(40),       GType            VARCHAR(40));create table Plays_Games (       StuID                INTEGER,       GameID            INTEGER,       Hours_Played      INTEGER,       FOREIGN KEY(GameID);create table SportsInfo (  StuID INTEGER,  SportName VARCHAR(32),  HoursPerWeek INTEGER,  GamesPlayed INTEGER,  OnScholarship VARCHAR(1),  FOREIGN KEY(StuID);
```
Question:
"""What are the names of all video games that are collectible cards?"""
Answer:
SELECT gname FROM Video_games WHERE gtype  =  "Collectible card game"

DDL:
CREATE TABLE "film" ("Film_ID" int,"Rank_in_series" int,"Number_in_

In [13]:
df_rand = tmp.copy()

In [69]:
df_rand[~df_rand['model_response'].isna()]['model_response']

4447    SELECT name FROM Person WHERE age > 40 AND NOT...
1395    SELECT building, MAX(capacity) as max_capacity...
1650                          SELECT COUNT(*) FROM artist
4922    SELECT Headquartered_City, COUNT(Store_ID) AS ...
1156    SELECT Name FROM people INNER JOIN body_builde...
                              ...                        
2485    SELECT title FROM Movie WHERE mID NOT IN (SELE...
1493    SELECT Venue FROM debate ORDER BY Num_of_Audie...
1553    SELECT account_id, account_name, other_account...
103     SELECT T1.student_details FROM Students AS T1 ...
4952       SELECT COUNT(*) FROM College WHERE enr > 15000
Name: model_response, Length: 80, dtype: object

# Examples from the same db
70% accuracy (69/98 - 2 had too many tokens) (max_tokens = 2000 and best_of=3)  
70% accuracy (70/100) (max_tokens = 1000 and best_of=1)  
With 2 examples, 2 out of 100 prompts have too many tokens

In [351]:
def get_db_examples(x, df):
    db_id = x['db_id']
    tmp = df[df['db_id'] == db_id].copy()
    # Remove the row x from the dataframe tmp
    tmp = tmp[~tmp.index.isin([x.index])]
    
    if len(tmp) < 2:
        print('not enough examples in db')
        return None
    
    # x = {}
    n_examples = 2
    examples = np.random.randint(0,len(tmp),n_examples)
    for i in range(n_examples):
        row_ex = tmp.iloc[examples[i]]
        x['question-' + str(i+1)]   = row_ex['question']
        x['completion-' + str(i+1)] = row_ex['query']

    return x[['question-1', 'completion-1', 'question-2', 'completion-2']]

def same_db_prompt(x):
    return dedent(f'''
        Convert text to SQL.
        
        DDL:
        ```
        {x['schema']}
        ```
        Question:
        """{x['question-1']}"""
        Answer:
        {x['completion-1']}

        Question:
        """{x['question-2']}"""
        Answer:
        {x['completion-2']}
        
        Question:
        """{x['question']}"""
        Answer:
        ''')

In [354]:
df = pd.read_json('spider/train_spider.json')

# Set schema
df['schema'] = df.apply(lambda x: get_schema(x['db_id']), axis=1)
# df['schema'] = df['schema'].apply(lambda x: re.sub('(?i) REFERENCES.*?(;|$)', ';', x)) 
df['schema'] = df['schema'].apply(lambda x: re.sub(r'\s+references\s+.*?[,;)]', '', x)) 
df['schema'] = df['schema'].apply(lambda x: re.sub('NOT NULL', '', x))
df['schema'] = df['schema'].apply(lambda x: re.sub(r',\s*(PRIMARY KEY|FOREIGN KEY).*?(;|$)', ');', x, 
                                                   flags=re.IGNORECASE))
df['schema'] = df['schema'].apply(lambda x: x.replace('`', '"'))
df['schema'] = df['schema'].apply(lambda x: re.sub(r'"(\w+)"', r'\1', x))
df['schema'] = df['schema'].apply(lambda x: x.replace('AUTOINCREMENT', ''))
df['schema'] = df['schema'].apply(lambda x: x.replace('UNSIGNED', ''))
df['schema'] = df['schema'].apply(lambda x: x.replace(' BIT', ' BOOLEAN'))
df['schema'] = df['schema'].apply(lambda x: x.replace('YEAR DEFAULT NULL', 'SMALLINT'))
df['schema'] = df['schema'].apply(lambda x: x.replace('MEDIUMINT', 'SMALLINT'))
df['schema'] = df['schema'].apply(lambda x: x.replace('character varchar', 'varchar'))
df['schema'] = df['schema'].apply(lambda x: re.sub(r'PRIMARY KEY,', ',', x))

# Select the examples to show the model
df[['question-1', 'completion-1', 'question-2', 'completion-2']] =\
    df.apply(lambda x: get_db_examples(x, df), axis=1)
df['open_ai_prompt'] = df.apply(lambda x: same_db_prompt(x), axis=1)
df['open_ai_completion'] = df.apply(lambda x: open_ai_completion(x), axis=1)

In [355]:
# check all the DDLs are valid SQL
for i in range(df.shape[0]):
    print(i)
    sqlglot.transpile(df.iloc[i]['schema'], read='sqlite')

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238


3679
3680
3681
3682
3683
3684
3685
3686
3687
3688
3689
3690
3691
3692
3693
3694
3695
3696
3697
3698
3699
3700
3701
3702
3703
3704
3705
3706
3707
3708
3709
3710
3711
3712
3713
3714
3715
3716
3717
3718
3719
3720
3721
3722
3723
3724
3725
3726
3727
3728
3729
3730
3731
3732
3733
3734
3735
3736
3737
3738
3739
3740
3741
3742
3743
3744
3745
3746
3747
3748
3749
3750
3751
3752
3753
3754
3755
3756
3757
3758
3759
3760
3761
3762
3763
3764
3765
3766
3767
3768
3769
3770
3771
3772
3773
3774
3775
3776
3777
3778
3779
3780
3781
3782
3783
3784
3785
3786
3787
3788
3789
3790
3791
3792
3793
3794
3795
3796
3797
3798
3799
3800
3801
3802
3803
3804
3805
3806
3807
3808
3809
3810
3811
3812
3813
3814
3815
3816
3817
3818
3819
3820
3821
3822
3823
3824
3825
3826
3827
3828
3829
3830
3831
3832
3833
3834
3835
3836
3837
3838
3839
3840
3841
3842
3843
3844
3845
3846
3847
3848
3849
3850
3851
3852
3853
3854
3855
3856
3857
3858
3859
3860
3861
3862
3863
3864
3865
3866
3867
3868
3869
3870
3871
3872
3873
3874
3875
3876
3877
3878


5517
5518
5519
5520
5521
5522
5523
5524
5525
5526
5527
5528
5529
5530
5531
5532
5533
5534
5535
5536
5537
5538
5539
5540
5541
5542
5543
5544
5545
5546
5547
5548
5549
5550
5551
5552
5553
5554
5555
5556
5557
5558
5559
5560
5561
5562
5563
5564
5565
5566
5567
5568
5569
5570
5571
5572
5573
5574
5575
5576
5577
5578
5579
5580
5581
5582
5583
5584
5585
5586
5587
5588
5589
5590
5591
5592
5593
5594
5595
5596
5597
5598
5599
5600
5601
5602
5603
5604
5605
5606
5607
5608
5609
5610
5611
5612
5613
5614
5615
5616
5617
5618
5619
5620
5621
5622
5623
5624
5625
5626
5627
5628
5629
5630
5631
5632
5633
5634
5635
5636
5637
5638
5639
5640
5641
5642
5643
5644
5645
5646
5647
5648
5649
5650
5651
5652
5653
5654
5655
5656
5657
5658
5659
5660
5661
5662
5663
5664
5665
5666
5667
5668
5669
5670
5671
5672
5673
5674
5675
5676
5677
5678
5679
5680
5681
5682
5683
5684
5685
5686
5687
5688
5689
5690
5691
5692
5693
5694
5695
5696
5697
5698
5699
5700
5701
5702
5703
5704
5705
5706
5707
5708
5709
5710
5711
5712
5713
5714
5715
5716


In [356]:
print(df.iloc[0]['open_ai_prompt'])


Convert text to SQL.

DDL:
```
CREATE TABLE department (Department_ID int,Name text,Creation text,Ranking int,Budget_in_Billions real,Num_Employees real);CREATE TABLE head (head_ID int,name text,born_state text,age real);CREATE TABLE management (department_ID int,head_ID int,temporary_acting text);
```
Question:
"""What are the names of the states where at least 3 heads were born?"""
Answer:
SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3

Question:
"""What are the names of the states where at least 3 heads were born?"""
Answer:
SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3

Question:
"""How many heads of the departments are older than 56 ?"""
Answer:



In [362]:
tmp = df.sample(100)
tmp['model_response'] = ''

data = []
for idx, row in tmp.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    print(row['open_ai_completion'])
    new_response = call_raw_model(row)
    data.append(new_response)
    tmp.loc[idx,'model_response'] = new_response

3073
 SELECT DISTINCT T1.cust_name ,  T1.credit_score FROM customer AS T1 JOIN loan AS T2 ON T1.cust_id  =  T2.cust_id

SELECT cust_name, credit_score FROM customer JOIN loan ON customer.cust_ID  =  loan.cust_ID
2310
 SELECT T1.Name FROM people AS T1 JOIN perpetrator AS T2 ON T1.People_ID  =  T2.People_ID WHERE T2.Country != "China"

SELECT Name FROM people JOIN perpetrator ON people.People_ID = perpetrator.People_ID WHERE Country != 'China'
1543
 SELECT count(DISTINCT claim_outcome_code) FROM claims_processing

SELECT count(DISTINCT claim_outcome_code) FROM claims_processing
479
 SELECT Advisor FROM Student WHERE StuID  =  1004

SELECT Advisor FROM Student WHERE StuID = 1004
4672
 SELECT T1.DName FROM DEPARTMENT AS T1 JOIN MINOR_IN AS T2 ON T1.DNO  =  T2.DNO GROUP BY T2.DNO ORDER BY count(*) DESC LIMIT 1

SELECT T2.DName FROM MINOR_IN AS T1 JOIN DEPARTMENT AS T2 ON T1.DNO  =  T2.DNO GROUP BY T2.DName ORDER BY COUNT(*) DESC LIMIT 1
6880
 SELECT T1.name FROM airlines AS T1 JOIN routes A

SELECT COUNT(*) FROM county_public_safety;
1984
 SELECT Carrier ,  COUNT(*) FROM phone GROUP BY Carrier

SELECT Carrier, COUNT(*) FROM phone GROUP BY Carrier
1280
 SELECT T1.season FROM game AS T1 JOIN injury_accident AS T2 ON T1.id  =  T2.game_id WHERE T2.player  =  'Walter Samuel'

SELECT T2.Season FROM game AS T2 JOIN injury_accident AS T1 ON T2.id = T1.game_id WHERE T1.Player = 'Walter Samuel'
2724
 SELECT T2.region_name FROM affected_region AS T1 JOIN region AS T2 ON T1.region_id  =  T2.region_id JOIN storm AS T3 ON T1.storm_id  =  T3.storm_id WHERE T3.number_deaths  >=  10

SELECT region_name FROM region JOIN affected_region ON region.region_id = affected_region.region_id JOIN storm ON affected_region.storm_id = storm.storm_id WHERE number_deaths >= 10;
5477
 SELECT DISTINCT T1.Age FROM STUDENT AS T1 JOIN VOTING_RECORD AS T2 ON T1.StuID  =  T2.Secretary_Vote WHERE T2.Election_Cycle  =  "Fall"

SELECT DISTINCT Age FROM STUDENT AS T1 JOIN VOTING_RECORD AS T2 ON T1.StuID  =  T2.Secr

SELECT outcome_description FROM Research_Outcomes
5633
 SELECT Builder FROM railway ORDER BY Builder ASC

SELECT Builder FROM railway ORDER BY Builder ASC
3149
 SELECT DISTINCT asset_model FROM Assets

SELECT DISTINCT asset_model FROM Assets;
2480
 SELECT DISTINCT T3.name ,  T2.title ,  T1.stars FROM Rating AS T1 JOIN Movie AS T2 ON T1.mID  =  T2.mID JOIN Reviewer AS T3 ON T1.rID  =  T3.rID WHERE T2.director  =  T3.name

SELECT T2.name, T3.title, T1.stars FROM Rating AS T1 JOIN Reviewer AS T2 ON T1.rID  =  T2.rID JOIN Movie AS T3 ON T1.mID  =  T3.mID WHERE T2.name  =  T3.director
463
 SELECT count(DISTINCT advisor) FROM Student

SELECT COUNT(DISTINCT Advisor) FROM Student
3306
 SELECT T2.dept_name ,  T2.dept_address ,  count(*) FROM student AS T1 JOIN department AS T2 ON T1.dept_code  =  T2.dept_code GROUP BY T1.dept_code ORDER BY count(*) DESC LIMIT 3

SELECT T1.dept_name, T1.dept_address, COUNT(T2.stu_num) AS num_students FROM department AS T1 JOIN student AS T2 ON T1.dept_code  =  T

In [363]:
tmp[~(tmp['model_response'] == '')].shape

(100, 15)

In [364]:
tmp['open_ai_execution'] = tmp.apply(lambda x: execution(x, 'open_ai_completion'), axis=1)
tmp['model_response_execution'] = tmp.apply(lambda x: execution(x, 'model_response'), axis=1)
tmp['execution_accuracy'] = np.where(tmp['open_ai_execution'] == tmp['model_response_execution'], 1, 0)
tmp[~(tmp['model_response'] == '') & ~(tmp['model_response'].isna())]['execution_accuracy'].mean()

0.7

In [367]:
df_db_rand = tmp.copy()

In [368]:
i = 0
print(df_db_rand[~(df_db_rand['model_response'] == '') & 
                 (df_db_rand['execution_accuracy'] == 0)].iloc[i]['open_ai_prompt'])
print(df_db_rand[~(df_db_rand['model_response'] == '') & 
                 (df_db_rand['execution_accuracy'] == 0)].iloc[i]['open_ai_completion'])
print(df_db_rand[~(df_db_rand['model_response'] == '') & 
                 (df_db_rand['execution_accuracy'] == 0)].iloc[i]['model_response'])



Convert text to SQL.

DDL:
```
CREATE TABLE bank (branch_ID int ,bname varchar(20),no_of_customers int,city varchar(10),state varchar(20));CREATE TABLE customer (cust_ID varchar(3) ,cust_name varchar(20),acc_type char(1),acc_bal int,no_of_loans int,credit_score int,branch_ID int,state varchar(20));CREATE TABLE loan (loan_ID varchar(3) ,loan_type varchar(15),cust_ID varchar(3),branch_ID varchar(3),amount int);
```
Question:
"""Find the names of bank branches that have provided a loan to any customer whose credit score is below 100."""
Answer:
SELECT T2.bname FROM loan AS T1 JOIN bank AS T2 ON T1.branch_id  =  T2.branch_id JOIN customer AS T3 ON T1.cust_id  =  T3.cust_id WHERE T3.credit_score  <  100

Question:
"""What are the total account balances for each customer from Utah or Texas?"""
Answer:
SELECT sum(acc_bal) FROM customer WHERE state  =  'Utah' OR state  =  'Texas'

Question:
"""What are the different names and credit scores of customers who have taken a loan?"""
Answer:

 SELE

# More examples from the same db

In [381]:
def get_more_db_examples(x, df, num_examples):
    db_id = x['db_id']
    tmp = df[df['db_id'] == db_id].copy()
    # Remove the row x from the dataframe tmp
    tmp = tmp[~tmp.index.isin([x.index])]
    
    if len(tmp) < num_examples:
        print('not enough examples in db')
        return None
    
    examples = np.random.randint(0,len(tmp),num_examples)
    for i in range(num_examples):
        row_ex = tmp.iloc[examples[i]]
        x['question-' + str(i+1)]   = row_ex['question']
        x['completion-' + str(i+1)] = row_ex['query']
        
    return x[['question-' + str(i+1) for i in range(num_examples)] + 
             ['completion-' + str(i+1) for i in range(num_examples)]]

def more_db_examples_prompt(x, num_examples):
    prompt = dedent(f'''
        Convert text to SQL.
        
        DDL:
        ```
        {x['schema']}
        ```
    ''')
    for i in range(1, num_examples + 1):
        question_key = f'question-{i}'
        completion_key = f'completion-{i}'
        if question_key in x and completion_key in x:
            prompt += dedent(f'''
            Question:
            """{x[question_key]}"""
            Answer:
            {x[completion_key]}
            ''')

    prompt += f'''\nQuestion:\n"""{x['question']}"""\nAnswer:\n'''

    return prompt


In [390]:
# We should be able to go up to 6 examples without causing a problem
df.groupby('db_id')['query'].count().min()

7

In [437]:
df = pd.read_json('spider/train_spider.json')

# Set schema
df['schema'] = df.apply(lambda x: get_schema(x['db_id']), axis=1)
df['schema'] = df['schema'].apply(lambda x: re.sub(r'\s+references\s+.*?[,;)]', '', x)) 
df['schema'] = df['schema'].apply(lambda x: re.sub('NOT NULL', '', x))
df['schema'] = df['schema'].apply(lambda x: re.sub(r',\s*(PRIMARY KEY|FOREIGN KEY).*?(;|$)', ');', x, 
                                                   flags=re.IGNORECASE))
df['schema'] = df['schema'].apply(lambda x: x.replace('`', '"'))
df['schema'] = df['schema'].apply(lambda x: re.sub(r'"(\w+)"', r'\1', x))
df['schema'] = df['schema'].apply(lambda x: x.replace('AUTOINCREMENT', ''))
df['schema'] = df['schema'].apply(lambda x: x.replace('UNSIGNED', ''))
df['schema'] = df['schema'].apply(lambda x: x.replace(' BIT', ' BOOLEAN'))
df['schema'] = df['schema'].apply(lambda x: x.replace('YEAR DEFAULT NULL', 'SMALLINT'))
df['schema'] = df['schema'].apply(lambda x: x.replace('MEDIUMINT', 'SMALLINT'))
df['schema'] = df['schema'].apply(lambda x: x.replace('character varchar', 'varchar'))
df['schema'] = df['schema'].apply(lambda x: re.sub(r'PRIMARY KEY,', ',', x))

# Select the examples to show the model
num_examples = 3
df[['question-' + str(i+1) for i in range(num_examples)] + ['completion-' + str(i+1) for i in range(num_examples)]]=\
    df.apply(lambda x: get_more_db_examples(x, df, num_examples), axis=1)
df['open_ai_prompt'] = df.apply(lambda x: more_db_examples_prompt(x, num_examples), axis=1)
df['open_ai_completion'] = df.apply(lambda x: open_ai_completion(x), axis=1)

In [438]:
# check all the DDLs are valid SQL
for i in range(df.shape[0]):
    print(i)
    sqlglot.transpile(df.iloc[i]['schema'], read='sqlite')

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126


3632
3633
3634
3635
3636
3637
3638
3639
3640
3641
3642
3643
3644
3645
3646
3647
3648
3649
3650
3651
3652
3653
3654
3655
3656
3657
3658
3659
3660
3661
3662
3663
3664
3665
3666
3667
3668
3669
3670
3671
3672
3673
3674
3675
3676
3677
3678
3679
3680
3681
3682
3683
3684
3685
3686
3687
3688
3689
3690
3691
3692
3693
3694
3695
3696
3697
3698
3699
3700
3701
3702
3703
3704
3705
3706
3707
3708
3709
3710
3711
3712
3713
3714
3715
3716
3717
3718
3719
3720
3721
3722
3723
3724
3725
3726
3727
3728
3729
3730
3731
3732
3733
3734
3735
3736
3737
3738
3739
3740
3741
3742
3743
3744
3745
3746
3747
3748
3749
3750
3751
3752
3753
3754
3755
3756
3757
3758
3759
3760
3761
3762
3763
3764
3765
3766
3767
3768
3769
3770
3771
3772
3773
3774
3775
3776
3777
3778
3779
3780
3781
3782
3783
3784
3785
3786
3787
3788
3789
3790
3791
3792
3793
3794
3795
3796
3797
3798
3799
3800
3801
3802
3803
3804
3805
3806
3807
3808
3809
3810
3811
3812
3813
3814
3815
3816
3817
3818
3819
3820
3821
3822
3823
3824
3825
3826
3827
3828
3829
3830
3831


5450
5451
5452
5453
5454
5455
5456
5457
5458
5459
5460
5461
5462
5463
5464
5465
5466
5467
5468
5469
5470
5471
5472
5473
5474
5475
5476
5477
5478
5479
5480
5481
5482
5483
5484
5485
5486
5487
5488
5489
5490
5491
5492
5493
5494
5495
5496
5497
5498
5499
5500
5501
5502
5503
5504
5505
5506
5507
5508
5509
5510
5511
5512
5513
5514
5515
5516
5517
5518
5519
5520
5521
5522
5523
5524
5525
5526
5527
5528
5529
5530
5531
5532
5533
5534
5535
5536
5537
5538
5539
5540
5541
5542
5543
5544
5545
5546
5547
5548
5549
5550
5551
5552
5553
5554
5555
5556
5557
5558
5559
5560
5561
5562
5563
5564
5565
5566
5567
5568
5569
5570
5571
5572
5573
5574
5575
5576
5577
5578
5579
5580
5581
5582
5583
5584
5585
5586
5587
5588
5589
5590
5591
5592
5593
5594
5595
5596
5597
5598
5599
5600
5601
5602
5603
5604
5605
5606
5607
5608
5609
5610
5611
5612
5613
5614
5615
5616
5617
5618
5619
5620
5621
5622
5623
5624
5625
5626
5627
5628
5629
5630
5631
5632
5633
5634
5635
5636
5637
5638
5639
5640
5641
5642
5643
5644
5645
5646
5647
5648
5649


In [439]:
print(df.iloc[0]['open_ai_prompt'])


Convert text to SQL.

DDL:
```
CREATE TABLE department (Department_ID int,Name text,Creation text,Ranking int,Budget_in_Billions real,Num_Employees real);CREATE TABLE head (head_ID int,name text,born_state text,age real);CREATE TABLE management (department_ID int,head_ID int,temporary_acting text);
```

Question:
"""How many acting statuses are there?"""
Answer:
SELECT count(DISTINCT temporary_acting) FROM management

Question:
"""List the name, born state and age of the heads of departments ordered by age."""
Answer:
SELECT name ,  born_state ,  age FROM head ORDER BY age

Question:
"""List the states where both the secretary of 'Treasury' department and the secretary of 'Homeland Security' were born."""
Answer:
SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T1.name  =  'Treasury' INTERSECT SELECT T3.born_state FROM department AS T1 JOIN management AS T2 ON T1.department_id  

In [440]:
tmp = df.sample(100)
tmp['model_response'] = ''

data = []
for idx, row in tmp.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    print(row['open_ai_completion'])
    new_response = call_raw_model(row)
    data.append(new_response)
    tmp.loc[idx,'model_response'] = new_response

265
 SELECT RESULT FROM musical GROUP BY RESULT ORDER BY COUNT(*) DESC LIMIT 1

SELECT RESULT FROM musical GROUP BY RESULT ORDER BY COUNT(*) DESC LIMIT 1
6957
 SELECT name FROM mill WHERE name LIKE '%Moulin%'

SELECT name FROM mill WHERE name LIKE '%Moulin%'
5797
 SELECT complaint_status_code FROM complaints GROUP BY complaint_status_code HAVING count(*)  >  3

SELECT complaint_status_code, COUNT(*) FROM complaints GROUP BY complaint_status_code HAVING COUNT(*) > 3
5082
 SELECT Shop_Name FROM shop WHERE Shop_ID NOT IN (SELECT Shop_ID FROM stock)

SELECT Shop_Name FROM shop WHERE Shop_ID NOT IN (SELECT Shop_ID FROM stock)
1337
 SELECT count(DISTINCT s_id) FROM advisor

SELECT count(*) FROM advisor
6731
 SELECT count(*) FROM Faculty WHERE Rank  =  "Professor" AND building  =  "NEB"

SELECT count(*) FROM Faculty WHERE building  =  "NEB" AND rank  =  "Professor"
6672
 SELECT customer_status_code , count(*) FROM Customers GROUP BY customer_status_code;

SELECT customer_status_code ,  count(

SELECT primary_conference FROM university u JOIN basketball_match bm ON u.school_id = bm.school_id ORDER BY bm.acc_percent ASC LIMIT 1
5579
 SELECT unit_of_measure FROM ref_product_categories WHERE product_category_code  =  "Herbs"

SELECT unit_of_measure FROM ref_product_categories WHERE product_category_description  =  "Herb"
1270
 SELECT T2.apt_number FROM Apartment_Bookings AS T1 JOIN Apartments AS T2 ON T1.apt_id  =  T2.apt_id WHERE T1.booking_status_code  =  "Confirmed" INTERSECT SELECT T2.apt_number FROM Apartment_Bookings AS T1 JOIN Apartments AS T2 ON T1.apt_id  =  T2.apt_id WHERE T1.booking_status_code  =  "Provisional"

SELECT Apartments.apt_number 
FROM Apartments 
INNER JOIN Apartment_Bookings 
ON Apartments.apt_id = Apartment_Bookings.apt_id 
WHERE Apartment_Bookings.booking_status_code = "Provisional" 
OR Apartment_Bookings.booking_status_code = "Confirmed"
1542
 SELECT claim_status_description FROM claims_processing_stages WHERE claim_status_name  =  "Open"

SELECT clai

SELECT c.name 
FROM channel c 
INNER JOIN broadcast b 
ON c.Channel_ID = b.Channel_ID 
WHERE b.Time_of_day = 'morning'
5107
 SELECT Marketing_Region_Descriptrion FROM Marketing_Regions WHERE Marketing_Region_Name  =  "China"

SELECT Marketing_Region_Descriptrion FROM Marketing_Regions WHERE Marketing_Region_Name  =  "China"
4252
 SELECT clubname FROM club

SELECT clubname FROM club
1580
 SELECT count(*) ,  account_id FROM Financial_transactions

SELECT T1.account_id ,  count(*) FROM Financial_transactions AS T1 GROUP BY T1.account_id
3144
 SELECT T1.part_name FROM Parts AS T1 JOIN Part_Faults AS T2 ON T1.part_id  =  T2.part_id GROUP BY T1.part_name ORDER BY count(*) ASC LIMIT 1

SELECT T1.part_name FROM Parts AS T1 JOIN Asset_Parts AS T2 ON T1.part_id  =  T2.part_id JOIN Part_Faults AS T3 ON T2.part_id  =  T3.part_id GROUP BY T1.part_name ORDER BY count(*) ASC LIMIT 1
657
 SELECT DISTINCT Theme FROM journal

SELECT DISTINCT Theme FROM journal
866
 SELECT COUNT(*) FROM MEDIATYPE AS T1 J

In [441]:
tmp[~(tmp['model_response'] == '')].shape

(100, 17)

In [442]:
tmp[~(tmp['model_response'].isna())].shape

(100, 17)

In [443]:
tmp['open_ai_execution'] = tmp.apply(lambda x: execution(x, 'open_ai_completion'), axis=1)
tmp['model_response_execution'] = tmp.apply(lambda x: execution(x, 'model_response'), axis=1)
tmp['execution_accuracy'] = np.where(tmp['open_ai_execution'] == tmp['model_response_execution'], 1, 0)
tmp[~(tmp['model_response'] == '') & ~(tmp['model_response'].isna())]['execution_accuracy'].mean()

0.72

In [444]:
i = 1
print(tmp[~(tmp['model_response'] == '') & 
                 (tmp['execution_accuracy'] == 0)].iloc[i]['open_ai_prompt'])
print(tmp[~(tmp['model_response'] == '') & 
                 (tmp['execution_accuracy'] == 0)].iloc[i]['open_ai_completion'])
print(tmp[~(tmp['model_response'] == '') & 
                 (tmp['execution_accuracy'] == 0)].iloc[i]['model_response'])



Convert text to SQL.

DDL:
```
CREATE TABLE CLASS (CLASS_CODE varchar(5) ,CRS_CODE varchar(10),CLASS_SECTION varchar(2),CLASS_TIME varchar(20),CLASS_ROOM varchar(8),PROF_NUM int);CREATE TABLE COURSE (CRS_CODE varchar(10) ,DEPT_CODE varchar(10),CRS_DESCRIPTION varchar(35),CRS_CREDIT float(8));CREATE TABLE DEPARTMENT (DEPT_CODE varchar(10) ,DEPT_NAME varchar(30),SCHOOL_CODE varchar(8),EMP_NUM int,DEPT_ADDRESS varchar(20),DEPT_EXTENSION varchar(4));CREATE TABLE EMPLOYEE (EMP_NUM int ,EMP_LNAME varchar(15),EMP_FNAME varchar(12),EMP_INITIAL varchar(1),EMP_JOBCODE varchar(5),EMP_HIREDATE datetime,EMP_DOB datetime);CREATE TABLE ENROLL (CLASS_CODE varchar(5),STU_NUM int,ENROLL_GRADE varchar(50));CREATE TABLE PROFESSOR (EMP_NUM int,DEPT_CODE varchar(10),PROF_OFFICE varchar(50),PROF_EXTENSION varchar(4),PROF_HIGH_DEGREE varchar(5));CREATE TABLE STUDENT (STU_NUM int ,STU_LNAME varchar(15),STU_FNAME varchar(15),STU_INIT varchar(1),STU_DOB datetime,STU_HRS int,STU_CLASS varchar(2),STU_GPA float(8)

In [445]:
df_3_examples = tmp.copy()

In [416]:
df_0_examples[~(df_0_examples['model_response'] == '') & 
              ~(df_0_examples['model_response'].isna())]['execution_accuracy'].mean()

0.16

In [426]:
df_1_examples[~(df_1_examples['model_response'] == '') & 
              ~(df_1_examples['model_response'].isna())]['execution_accuracy'].mean()

0.64

In [436]:
df_2_examples[~(df_2_examples['model_response'] == '') & 
              ~(df_2_examples['model_response'].isna())]['execution_accuracy'].mean()

0.67

In [446]:
df_3_examples[~(df_3_examples['model_response'] == '') & 
              ~(df_3_examples['model_response'].isna())]['execution_accuracy'].mean()

0.72

In [414]:
df_6_examples[~(df_6_examples['model_response'] == '') & 
              ~(df_6_examples['model_response'].isna())]['execution_accuracy'].mean()

0.7

In [402]:
df_db_rand.head()

Unnamed: 0,db_id,query,query_toks,query_toks_no_value,question,question_toks,sql,schema,question-1,completion-1,question-2,completion-2,open_ai_prompt,open_ai_completion,model_response,open_ai_execution,model_response_execution,execution_accuracy
3073,loan_1,"SELECT DISTINCT T1.cust_name , T1.credit_scor...","[SELECT, DISTINCT, T1.cust_name, ,, T1.credit_...","[select, distinct, t1, ., cust_name, ,, t1, .,...",What are the different names and credit scores...,"[What, are, the, different, names, and, credit...","{'from': {'table_units': [['table_unit', 1], [...","CREATE TABLE bank (branch_ID int ,bname varcha...",Find the names of bank branches that have prov...,SELECT T2.bname FROM loan AS T1 JOIN bank AS T...,What are the total account balances for each c...,SELECT sum(acc_bal) FROM customer WHERE state ...,\nConvert text to SQL.\n\nDDL:\n```\nCREATE TA...,"SELECT DISTINCT T1.cust_name , T1.credit_sco...","SELECT cust_name, credit_score FROM customer J...","[(Mary, 30), (Owen, 210)]","[(Mary, 30), (Mary, 30), (Owen, 210)]",0
2310,perpetrator,SELECT T1.Name FROM people AS T1 JOIN perpetra...,"[SELECT, T1.Name, FROM, people, AS, T1, JOIN, ...","[select, t1, ., name, from, people, as, t1, jo...",What are the names of perpetrators whose count...,"[What, are, the, names, of, perpetrators, whos...","{'from': {'table_units': [['table_unit', 1], [...","CREATE TABLE perpetrator (Perpetrator_ID int,P...",Show the countries that have both perpetrators...,SELECT Country FROM perpetrator WHERE Injured ...,What are the countries of perpetrators? Show e...,"SELECT Country , COUNT(*) FROM perpetrator GR...",\nConvert text to SQL.\n\nDDL:\n```\nCREATE TA...,SELECT T1.Name FROM people AS T1 JOIN perpetr...,SELECT Name FROM people JOIN perpetrator ON pe...,"[(Ron Baxter,), (Rob Cunningham,), (Henry John...","[(Ron Baxter,), (Rob Cunningham,), (Henry John...",1
1543,insurance_and_eClaims,SELECT count(DISTINCT claim_outcome_code) FROM...,"[SELECT, count, (, DISTINCT, claim_outcome_cod...","[select, count, (, distinct, claim_outcome_cod...",How many distinct claim outcome codes are there?,"[How, many, distinct, claim, outcome, codes, a...","{'from': {'table_units': [['table_unit', 6]], ...","CREATE TABLE Customers (Customer_ID INTEGER ,C...",Which customers have an insurance policy with ...,SELECT DISTINCT t2.customer_details FROM polic...,Find the number of distinct stages in claim pr...,SELECT count(*) FROM claims_processing_stages,\nConvert text to SQL.\n\nDDL:\n```\nCREATE TA...,SELECT count(DISTINCT claim_outcome_code) FRO...,SELECT count(DISTINCT claim_outcome_code) FROM...,"[(3,)]","[(3,)]",1
479,allergy_1,SELECT Advisor FROM Student WHERE StuID = 1004,"[SELECT, Advisor, FROM, Student, WHERE, StuID,...","[select, advisor, from, student, where, stuid,...",Who is the advisor of student with ID 1004?,"[Who, is, the, advisor, of, student, with, ID,...","{'from': {'table_units': [['table_unit', 2]], ...",create table Allergy_Type ( Allergy \t\t...,What are the student ids of students who don't...,SELECT StuID FROM Student EXCEPT SELECT StuID ...,How many distinct allergies are there?,SELECT count(DISTINCT allergytype) FROM Allerg...,\nConvert text to SQL.\n\nDDL:\n```\ncreate ta...,SELECT Advisor FROM Student WHERE StuID = 1...,SELECT Advisor FROM Student WHERE StuID = 1004,"[(8423,)]","[(8423,)]",1
4672,college_3,SELECT T1.DName FROM DEPARTMENT AS T1 JOIN MIN...,"[SELECT, T1.DName, FROM, DEPARTMENT, AS, T1, J...","[select, t1, ., dname, from, department, as, t...",What is the name of the department with the mo...,"[What, is, the, name, of, the, department, wit...","{'from': {'table_units': [['table_unit', 2], [...",create table Student ( StuID INT...,Find the first name and last name of the instr...,"SELECT T2.Fname , T2.Lname FROM COURSE AS T1 ...",Which courses are taught on days MTW?,"SELECT CName FROM COURSE WHERE Days = ""MTW""",\nConvert text to SQL.\n\nDDL:\n```\ncreate ta...,SELECT T1.DName FROM DEPARTMENT AS T1 JOIN MI...,SELECT T2.DName FROM MINOR_IN AS T1 JOIN DEPAR...,"[(Mathematical Sciences,)]","[(Mathematical Sciences,)]",1


In [403]:
df_6_examples.head()

Unnamed: 0,db_id,query,query_toks,query_toks_no_value,question,question_toks,sql,schema,question-1,question-2,...,completion-3,completion-4,completion-5,completion-6,open_ai_prompt,open_ai_completion,model_response,open_ai_execution,model_response_execution,execution_accuracy
5632,railway,SELECT count(*) FROM railway,"[SELECT, count, (, *, ), FROM, railway]","[select, count, (, *, ), from, railway]",How many railways are there?,"[How, many, railways, are, there, ?]","{'from': {'table_units': [['table_unit', 0]], ...","CREATE TABLE railway (Railway_ID int,Railway t...",Show the most common builder of railways.,Show the distinct countries of managers.,...,SELECT Builder FROM railway ORDER BY Builder ASC,SELECT Builder FROM railway ORDER BY Builder ASC,SELECT Name FROM manager ORDER BY LEVEL ASC,"SELECT LOCATION , COUNT(*) FROM railway GROUP...",\nConvert text to SQL.\n\nDDL:\n```\nCREATE TA...,SELECT count(*) FROM railway\n,SELECT COUNT(*) FROM railway,"[(10,)]","[(10,)]",1
4443,network_2,"SELECT DISTINCT T1.name , T1.age FROM Person ...","[SELECT, DISTINCT, T1.name, ,, T1.age, FROM, P...","[select, distinct, t1, ., name, ,, t1, ., age,...",What are the different names and ages of every...,"[What, are, the, different, names, and, ages, ...","{'from': {'table_units': [['table_unit', 0], [...","CREATE TABLE Person ( name varchar(20) , age...",What are the names of every person who has a f...,What are the names of all friends who are from...,...,"SELECT T1.name , T1.age , T1.job FROM Person...",SELECT T2.name FROM Person AS T1 JOIN PersonFr...,SELECT age FROM Person WHERE job = 'doctor' ...,SELECT count(DISTINCT job) FROM Person,\nConvert text to SQL.\n\nDDL:\n```\nCREATE TA...,"SELECT DISTINCT T1.name , T1.age FROM Person...","SELECT T2.name , T1.age FROM Person AS T1 JOI...","[(Zach, 45)]","[(Zach, 45), (Zach, 45)]",0
4713,department_store,SELECT count(DISTINCT product_type_code) FROM ...,"[SELECT, count, (, DISTINCT, product_type_code...","[select, count, (, distinct, product_type_code...",Find the number of different product types.,"[Find, the, number, of, different, product, ty...","{'from': {'table_units': [['table_unit', 5]], ...","CREATE TABLE Addresses (address_id INTEGER ,ad...",Return ids of all the products that are suppli...,What are the names and ids of customers whose ...,...,SELECT DISTINCT T1.customer_name FROM customer...,SELECT avg(product_price) FROM products WHERE ...,"SELECT T1.supplier_name , T1.supplier_phone F...",SELECT T1.staff_name FROM staff AS T1 JOIN sta...,\nConvert text to SQL.\n\nDDL:\n```\nCREATE TA...,SELECT count(DISTINCT product_type_code) FROM...,SELECT COUNT(DISTINCT product_type_code) FROM ...,"[(2,)]","[(2,)]",1
5365,tracking_software_problems,SELECT count(*) FROM product AS T1 JOIN proble...,"[SELECT, count, (, *, ), FROM, product, AS, T1...","[select, count, (, *, ), from, product, as, t1...","How many problems did the product called ""volu...","[How, many, problems, did, the, product, calle...","{'from': {'table_units': [['table_unit', 3], [...",CREATE TABLE Problem_Category_Codes (problem_c...,Find the first and last name of the staff memb...,What are the ids of the problems that are from...,...,SELECT T1.problem_id FROM problems AS T1 JOIN ...,SELECT T2.product_name FROM problems AS T1 JOI...,SELECT problem_id FROM problems WHERE date_pro...,SELECT problem_id FROM problems WHERE date_pro...,\nConvert text to SQL.\n\nDDL:\n```\nCREATE TA...,SELECT count(*) FROM product AS T1 JOIN probl...,SELECT COUNT(*) FROM problems AS T1 JOIN produ...,"[(0,)]","[(0,)]",1
5345,manufactory_1,"SELECT avg(T1.Price) , T2.name FROM products ...","[SELECT, avg, (, T1.Price, ), ,, T2.name, FROM...","[select, avg, (, t1, ., price, ), ,, t2, ., na...",What are the names and average prices of produ...,"[What, are, the, names, and, average, prices, ...","{'from': {'table_units': [['table_unit', 1], [...","CREATE TABLE Manufacturers ( Code INTEGER, N...",Select the average price of each manufacturer'...,Who is the founders of companies whose first l...,...,SELECT DISTINCT headquarter FROM manufacturers,"SELECT T1.Name , max(T1.Price) , T2.name FRO...","SELECT sum(revenue) , founder FROM manufactur...",SELECT sum(revenue) FROM manufacturers WHERE h...,\nConvert text to SQL.\n\nDDL:\n```\nCREATE TA...,"SELECT avg(T1.Price) , T2.name FROM products...","SELECT T1.name , AVG(T1.price) FROM products ...","[(150.0, Creative Labs), (240.0, Fujitsu), (16...","[(DVD burner, 180.0), (DVD drive, 165.0), (Har...",0


In [473]:
print(df_6_examples[df_6_examples['model_response'].str.len() == 250]['model_response'])#.str.len().max()

2732    SELECT T3.name FROM affected_region AS T1 JOIN...
Name: model_response, dtype: object


In [475]:
df_6_examples.loc[2732]['model_response']

"SELECT T3.name FROM affected_region AS T1 JOIN region AS T2 ON T1.region_id  =  T2.region_id JOIN storm AS T3 ON T1.storm_id  =  T3.storm_id WHERE T2.region_name  =  'Afghanistan' OR T2.region_name  =  'Albania' GROUP BY T3.name HAVING count(*)  =  2"

# Testing other models
Curie (90% chepaer) - 9% accuracy  
Codex models (free in beta, more expensive later?)  
Davinci 1 - 57% accuracy  
Davinci 2 - 76% accuracy  
Cushman - 59% accuracy

In [448]:
df = pd.read_json('spider/train_spider.json')

# Set schema
df['schema'] = df.apply(lambda x: get_schema(x['db_id']), axis=1)
df['schema'] = df['schema'].apply(lambda x: re.sub(r'\s+references\s+.*?[,;)]', '', x)) 
df['schema'] = df['schema'].apply(lambda x: re.sub('NOT NULL', '', x))
df['schema'] = df['schema'].apply(lambda x: re.sub(r',\s*(PRIMARY KEY|FOREIGN KEY).*?(;|$)', ');', x, 
                                                   flags=re.IGNORECASE))
df['schema'] = df['schema'].apply(lambda x: x.replace('`', '"'))
df['schema'] = df['schema'].apply(lambda x: re.sub(r'"(\w+)"', r'\1', x))
df['schema'] = df['schema'].apply(lambda x: x.replace('AUTOINCREMENT', ''))
df['schema'] = df['schema'].apply(lambda x: x.replace('UNSIGNED', ''))
df['schema'] = df['schema'].apply(lambda x: x.replace(' BIT', ' BOOLEAN'))
df['schema'] = df['schema'].apply(lambda x: x.replace('YEAR DEFAULT NULL', 'SMALLINT'))
df['schema'] = df['schema'].apply(lambda x: x.replace('MEDIUMINT', 'SMALLINT'))
df['schema'] = df['schema'].apply(lambda x: x.replace('character varchar', 'varchar'))
df['schema'] = df['schema'].apply(lambda x: re.sub(r'PRIMARY KEY,', ',', x))

# Select the examples to show the model
num_examples = 3
df[['question-' + str(i+1) for i in range(num_examples)] + ['completion-' + str(i+1) for i in range(num_examples)]]=\
    df.apply(lambda x: get_more_db_examples(x, df, num_examples), axis=1)
df['open_ai_prompt'] = df.apply(lambda x: more_db_examples_prompt(x, num_examples), axis=1)
df['open_ai_completion'] = df.apply(lambda x: open_ai_completion(x), axis=1)

In [477]:
df_curie = df.sample(100).copy()
df_curie['model_response'] = ''

data = []
for idx, row in df_curie.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    print(row['open_ai_completion'])
    new_response = call_raw_model(row, engine="curie", stop='\n', max_tokens=128)
    data.append(new_response)
    df_curie.loc[idx,'model_response'] = new_response

1013
 SELECT count(*) ,  affiliation FROM university WHERE enrollment  >  20000 GROUP BY affiliation

SELECT affiliation , SUM(enrollment) OVER ( ORDER BY founded ) FROM university GROUP BY affiliation
4440
 SELECT T1.name ,  T1.age FROM Person AS T1 JOIN PersonFriend AS T2 ON T1.name  =  T2.name WHERE T2.friend  =  'Dan' INTERSECT SELECT T1.name ,   T1.age FROM Person AS T1 JOIN PersonFriend AS T2 ON T1.name  =  T2.name WHERE T2.friend  =  'Alice'

SELECT name, age FROM PersonFriend WHERE name  =  'Dan' AND age  =  '23'
5876
 SELECT transaction_type_code FROM TRANSACTIONS GROUP BY transaction_type_code ORDER BY COUNT(*) DESC LIMIT 1

SELECT T2.transaction_type_code FROM TRANSACTIONS
4866
 SELECT Enrollment FROM school WHERE Denomination != "Catholic"

SELECT Enrollment FROM school WHERE Denomination != 'Catholic'
136
 SELECT bike_id FROM trip WHERE zip_code  =  94002 GROUP BY bike_id ORDER BY COUNT(*) DESC LIMIT 1

SELECT id FROM trip WHERE bike_id  =  636 AND duration  >  (SELECT max

SELECT station_name FROM trip WHERE trip_duration  < 100
2672
 SELECT Nationality ,  COUNT(*) FROM HOST GROUP BY Nationality

SELECT Nationality , COUNT(*) FROM HOST GROUP BY Nationality ORDER BY COUNT(*) DESC LIMIT 1
4139
 SELECT T1.Year FROM film_market_estimation AS T1 JOIN market AS T2 ON T1.Market_ID  =  T2.Market_ID WHERE T2.Country  =  "Japan" ORDER BY T1.Year DESC

SELECT Year,Estimation_ID FROM film_market_estimation AS T1 JOIN market AS T2 ON T1.Market_ID = T2.Market_ID WHERE T1.Year =  '2014' AND T1.Estimation_ID =  '1'
6368
 SELECT id ,  country ,  city ,  name FROM airport ORDER BY name

SELECT id, country, city, name FROM airport ORDER BY name
2271
 SELECT Name FROM People ORDER BY Weight ASC

SELECT Name FROM people ORDER BY Weight DESC LIMIT 10
132
 SELECT DISTINCT T1.name FROM station AS T1 JOIN status AS T2 ON T1.id  =  T2.station_id WHERE T2.bikes_available  =  7

SELECT name FROM station WHERE bikes_available = 7
1433
 SELECT T2.name ,  T2.salary FROM advisor AS T1 

SELECT * FROM employees WHERE hire_date >= '1987-09-07'
380
 SELECT eid ,  salary FROM Employee WHERE name  =  'Mark Young'

SELECT eid , salary FROM Employee WHERE name = 'Mark Young'
5910
 SELECT Name FROM TOURIST_ATTRACTIONS WHERE How_to_Get_There  =  "bus"

SELECT T1.Name FROM Tourist_Attractions AS T1 JOIN VISITS AS T2 ON T1.Tourist_Attraction_ID  =  T2.Tourist_Attraction_ID JOIN BUSES AS T3 ON T2.Tourist_ID  = T3.Tourist_ID WHERE T3.Tourist_Details  =  "Bus"
4923
 SELECT t3.headquartered_city ,  count(*) FROM store AS t1 JOIN store_district AS t2 ON t1.store_id  =  t2.store_id JOIN district AS t3 ON t2.district_id  =  t3.district_id GROUP BY t3.headquartered_city

SELECT city_name FROM city WHERE city_population = 
4769
 SELECT product_id ,  product_name FROM products WHERE product_price  <  600 OR product_price  >  900

SELECT product_id , product_name FROM products WHERE product_price  <  600 OR product_price  >  900
5464
 SELECT DISTINCT Secretary_Vote FROM VOTING_RECORD WHERE

In [478]:
df_curie[~(df_curie['model_response'] == '')].shape

(100, 17)

In [479]:
df_curie[~(df_curie['model_response'].isna())].shape

(98, 17)

In [480]:
df_curie['open_ai_execution'] = df_curie.apply(lambda x: execution(x, 'open_ai_completion'), axis=1)
df_curie['model_response_execution'] = df_curie.apply(lambda x: execution(x, 'model_response'), axis=1)
df_curie['execution_accuracy'] = np.where(df_curie['open_ai_execution'] == df_curie['model_response_execution'], 1, 0)
df_curie[~(df_curie['model_response'] == '') & ~(df_curie['model_response'].isna())]['execution_accuracy'].mean()

0.09183673469387756

In [481]:
print(df_curie.iloc[0]['open_ai_prompt'])
print(df_curie.iloc[0]['open_ai_completion'])
print(df_curie.iloc[0]['model_response'])


Convert text to SQL.

DDL:
```
CREATE TABLE basketball_match (Team_ID int,School_ID int,Team_Name text,ACC_Regular_Season text,ACC_Percent text,ACC_Home text,ACC_Road text,All_Games text,All_Games_Percent int,All_Home text,All_Road text,All_Neutral text);CREATE TABLE university (School_ID int,School text,Location text,Founded real,Affiliation text,Enrollment real,Nickname text,Primary_conference text);
```

Question:
"""Show the enrollment and primary_conference of the oldest college."""
Answer:
SELECT enrollment ,  primary_conference FROM university ORDER BY founded LIMIT 1

Question:
"""Return the total and minimum enrollments across all schools."""
Answer:
SELECT sum(enrollment) ,  min(enrollment) FROM university

Question:
"""What are the different schools and their nicknames, ordered by their founding years?"""
Answer:
SELECT school ,  nickname FROM university ORDER BY founded

Question:
"""Find the number of universities that have over a 20000 enrollment size for each affiliatio

In [482]:
df_code_davinci = df.sample(100).copy()
df_code_davinci['model_response'] = ''

data = []
for idx, row in df_code_davinci.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    print(row['open_ai_completion'])
    new_response = call_raw_model(row, engine="code-davinci-001", stop='\n', max_tokens=128)
    data.append(new_response)
    df_code_davinci.loc[idx,'model_response'] = new_response

6701
 SELECT T1.lesson_id FROM Lessons AS T1 JOIN Staff AS T2 ON T1.staff_id = T2.staff_id WHERE T2.first_name = "Janessa" AND T2.last_name = "Sawayn" AND nickname LIKE "%s%";

SELECT lesson_id FROM Lessons AS T1 JOIN Staff AS T2 ON T1.staff_id = T2.staff_id WHERE T2.nickname LIKE "%s%" AND T2.first_name = "Janessa" AND T2.last_name = "Sawayn";
1789
 SELECT T2.balance FROM accounts AS T1 JOIN checking AS T2 ON T1.custid  =  T2.custid WHERE T1.name LIKE '%ee%'

SELECT T2.balance FROM accounts AS T1 JOIN checking AS T2 ON T1.custid  =  T2.custid WHERE T1.name  LIKE  '%ee%'
255
 SELECT T1.Name FROM actor AS T1 JOIN musical AS T2 ON T1.Musical_ID  =  T2.Musical_ID ORDER BY T2.Year DESC

SELECT Name FROM actor WHERE Musical_ID IN (SELECT Musical_ID FROM musical WHERE Award  =  "Tony Award") ORDER BY Year DESC
6869
 SELECT count(*) FROM routes AS T1 JOIN airports AS T2 ON T1.dst_apid  =  T2.apid WHERE T2.country  =  'Italy'

SELECT count(*) FROM routes AS T1 JOIN airports AS T2 ON T1.dst_api

SELECT DISTINCT T1.customer_name FROM customers AS T1 JOIN Customer_Orders AS T2 ON T1.customer_id  =  T2.customer_id JOIN Order_Items AS T3 ON T2.order_id  =  T3.order_id GROUP BY T1.customer_name HAVING COUNT(DISTINCT T3.product_id)  >=  3
5360
 SELECT DISTINCT staff_first_name ,  staff_last_name FROM staff AS T1 JOIN problem_log AS T2 ON T1.staff_id = T2.assigned_to_staff_id WHERE T2.problem_id = 1

SELECT DISTINCT T1.staff_first_name , T1.staff_last_name FROM staff AS T1 JOIN problem_log AS T2 ON T1.staff_id = T2.assigned_to_staff_id WHERE T2.problem_id = 1
3872
 SELECT sum(Amount_Settled) FROM Settlements

SELECT sum(Amount_Settled) FROM Settlements
2616
 SELECT roomName ,  bedType FROM Rooms WHERE decor = "traditional";

SELECT bedType, roomName FROM Rooms WHERE decor  =  'traditional';
6851
 SELECT city FROM airports WHERE country  =  'United States' GROUP BY city HAVING count(*)  >  3

SELECT city FROM airports WHERE country  =  'United States' GROUP BY city HAVING count(*) > 3

SELECT avg(Points) FROM player WHERE Club_ID  =  (SELECT Club_ID FROM club WHERE name  =  'AIB')
6750
 SELECT T1.fname ,  T1.lname FROM Faculty AS T1 JOIN Student AS T2 ON T1.FacID  =  T2.advisor WHERE T2.fname  =  "Linda" AND T2.lname  =  "Smith"

SELECT T1.Fname ,  T1.Lname FROM Student AS T1 JOIN Faculty AS T2 ON T1.advisor  =  T2.FacID WHERE T1.Lname  =  'Smith' AND T1.Fname  =  'Linda'
5254
 SELECT TYPE FROM vocals AS T1 JOIN band AS T2 ON T1.bandmate  =  T2.id WHERE firstname  =  "Solveig" GROUP BY TYPE ORDER BY count(*) DESC LIMIT 1

SELECT T1.Type FROM Vocals AS T1 JOIN Band AS T2 ON T1.Bandmate  =  T2.id JOIN Songs AS T3 ON T3.SongId  =  T1.SongId WHERE T2.Firstname  =  "Solveig" GROUP BY T1.Type
4577
 SELECT Industry FROM Companies WHERE Headquarters  =  "USA" INTERSECT SELECT Industry FROM Companies WHERE Headquarters  =  "China"

SELECT T1.Industry FROM Companies AS T1 JOIN Companies AS T2 ON T1.Headquarters  =  "USA" AND T2.Headquarters  =  "China" WHERE T1.id  =  T2.id
48

In [483]:
df_code_davinci[~(df_code_davinci['model_response'] == '')].shape

(100, 17)

In [484]:
df_code_davinci[~(df_code_davinci['model_response'].isna())].shape

(100, 17)

In [485]:
df_code_davinci['open_ai_execution'] = df_code_davinci.apply(lambda x: execution(x, 'open_ai_completion'), axis=1)
df_code_davinci['model_response_execution'] = df_code_davinci.apply(lambda x: execution(x, 'model_response'), axis=1)
df_code_davinci['execution_accuracy'] = np.where(df_code_davinci['open_ai_execution'] == df_code_davinci['model_response_execution'], 1, 0)
df_code_davinci[~(df_code_davinci['model_response'] == '') & ~(df_code_davinci['model_response'].isna())]['execution_accuracy'].mean()

0.57

In [494]:
i = 40
print(df_code_davinci[df_code_davinci['execution_accuracy'] == 0].iloc[i]['open_ai_prompt'])
print(df_code_davinci[df_code_davinci['execution_accuracy'] == 0].iloc[i]['open_ai_completion'])
print(df_code_davinci[df_code_davinci['execution_accuracy'] == 0].iloc[i]['model_response'])


Convert text to SQL.

DDL:
```
CREATE TABLE Apartment_Buildings (building_id INTEGER ,building_short_name CHAR(15),building_full_name VARCHAR(80),building_description VARCHAR(255),building_address VARCHAR(255),building_manager VARCHAR(50),building_phone VARCHAR(80));CREATE TABLE Apartments (apt_id INTEGER  ,building_id INTEGER ,apt_type_code CHAR(15),apt_number CHAR(10),bathroom_count INTEGER,bedroom_count INTEGER,room_count CHAR(5));CREATE TABLE Apartment_Facilities (apt_id INTEGER ,facility_code CHAR(15) );CREATE TABLE Guests (guest_id INTEGER  ,gender_code CHAR(1),guest_first_name VARCHAR(80),guest_last_name VARCHAR(80),date_of_birth DATETIME);CREATE TABLE Apartment_Bookings (apt_booking_id INTEGER ,apt_id INTEGER,guest_id INTEGER ,booking_status_code CHAR(15) ,booking_start_date DATETIME,booking_end_date DATETIME);CREATE TABLE View_Unit_Status (apt_id INTEGER,apt_booking_id INTEGER,status_date DATETIME ,available_yn BOOLEAN);
```

Question:
"""Which apartments have bookings with s

In [495]:
df_code_davinci2 = df.sample(100).copy()
df_code_davinci2['model_response'] = ''

data = []
for idx, row in df_code_davinci2.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    print(row['open_ai_completion'])
    new_response = call_raw_model(row, engine="code-davinci-002", stop='\n', max_tokens=128)
    data.append(new_response)
    df_code_davinci2.loc[idx,'model_response'] = new_response

6973
 SELECT publisher FROM book_club GROUP BY publisher ORDER BY count(*) DESC LIMIT 1

SELECT publisher FROM book_club GROUP BY publisher ORDER BY COUNT(*) DESC LIMIT 1
6436
 SELECT T1.project_id ,  T1.project_details FROM Projects AS T1 JOIN Documents AS T2 ON T1.project_id  =  T2.project_id GROUP BY T1.project_id HAVING count(*)  >  2

SELECT T1.project_id ,  T1.project_details FROM Projects AS T1 JOIN Documents AS T2 ON T1.project_id  =  T2.project_id GROUP BY T1.project_id ,  T1.project_details HAVING COUNT(T2.document_id)  >  2
5046
 SELECT count(DISTINCT state) FROM college WHERE enr  >  (SELECT avg(enr) FROM college)

SELECT count(DISTINCT state) FROM college WHERE enr > (SELECT avg(enr) FROM college)
3325
 SELECT T1.stu_lname FROM student AS T1 JOIN enroll AS T2 ON T1.stu_num  =  T2.stu_num WHERE T2.enroll_grade  =  'A' AND T2.class_code  =  10018

SELECT stu_lname FROM student AS T1 JOIN enroll AS T2 ON T1.stu_num  =  T2.stu_num WHERE T2.class_code  =  '10018' AND T2.enroll_

SELECT fname FROM authors ORDER BY fname ASC
4225
 SELECT T1.date_in_location_from ,  T1.date_in_locaton_to FROM Document_locations AS T1 JOIN All_documents AS T2 ON T1.document_id  =  T2.document_id WHERE T2.document_name  =  "Robin CV"

SELECT Date_in_Location_From ,  Date_in_Locaton_To FROM Document_locations AS T1 JOIN All_documents AS T2 ON T1.document_id  =  T2.document_id WHERE T2.document_name  =  "Robin CV"
761
 SELECT count(*) FROM race

SELECT count(*) FROM race
6146
 SELECT Age FROM artist

SELECT Age FROM artist
6660
 SELECT email_address ,  date_of_birth FROM Customers WHERE first_name = "Carole"

RateLimitError
SELECT email_address, date_of_birth FROM Customers WHERE first_name = "Carole";
2221
 SELECT max(T2.fastestlapspeed) FROM races AS T1 JOIN results AS T2 ON T1.raceid = T2.raceid WHERE T1.year = 2008 AND T1.name = "Monaco Grand Prix"

SELECT MAX(T2.fastestlapspeed) FROM races AS T1 JOIN results AS T2 ON T1.raceid = T2.raceid WHERE T1.name = "Monaco Grand Prix" AND 

SELECT count(*) FROM exhibition WHERE exhibition_id IN (SELECT exhibition_id FROM exhibition_record WHERE attendance  >  100) OR exhibition_id IN (SELECT exhibition_id FROM exhibition WHERE ticket_price  <  10)
5771
 SELECT email_address ,  phone_number FROM customers ORDER BY email_address ,  phone_number

RateLimitError
SELECT email_address ,  phone_number FROM customers ORDER BY email_address ,  phone_number
6013
 SELECT StuID ,  count(*) ,  sum(gamesplayed) FROM Sportsinfo GROUP BY StuID

SELECT StuID, count(SportName), sum(GamesPlayed) FROM SportsInfo GROUP BY StuID
5984
 SELECT gtype ,  count(*) FROM Video_games GROUP BY gtype

SELECT gtype, count(*) FROM Video_games GROUP BY gtype
4650
 SELECT max(gradepoint) ,  min(gradepoint) FROM GRADECONVERSION

SELECT max(gradepoint) ,  min(gradepoint) FROM GRADECONVERSION
260
 SELECT Nominee ,  COUNT(*) FROM musical GROUP BY Nominee

RateLimitError
SELECT Nominee, COUNT(Nominee) FROM musical GROUP BY Nominee
4752
 SELECT product_name ,  pr

SELECT T1.employee_id ,  T1.first_name ,  T1.last_name ,  T1.salary FROM employees AS T1 JOIN departments AS T2 ON T1.department_id  =  T2.department_id JOIN employees AS T3 ON T2.department_id  =  T3.department_id WHERE T1.salary  >  (SELECT AVG(salary) FROM employees) AND T3.first_name LIKE '%J%'
981
 SELECT founded FROM university ORDER BY enrollment DESC LIMIT 1

SELECT founded FROM university WHERE enrollment  =  ( SELECT MAX ( enrollment ) FROM university )


In [496]:
df_code_davinci2[~(df_code_davinci2['model_response'] == '')].shape

(100, 17)

In [497]:
df_code_davinci2[~(df_code_davinci2['model_response'].isna())].shape

(100, 17)

In [499]:
df_code_davinci2['open_ai_execution'] = df_code_davinci2.apply(lambda x: execution(x, 'open_ai_completion'), axis=1)
df_code_davinci2['model_response_execution'] = df_code_davinci2.apply(lambda x: execution(x, 'model_response'), axis=1)
df_code_davinci2['execution_accuracy'] = np.where(df_code_davinci2['open_ai_execution'] == 
                                                  df_code_davinci2['model_response_execution'], 1, 0)
df_code_davinci2[~(df_code_davinci2['model_response'] == '') & 
                 ~(df_code_davinci2['model_response'].isna())]['execution_accuracy'].mean()

0.76

In [502]:
i = 1
print(df_code_davinci2[df_code_davinci2['execution_accuracy'] == 0].iloc[i]['open_ai_prompt'])
print(df_code_davinci2[df_code_davinci2['execution_accuracy'] == 0].iloc[i]['open_ai_completion'])
print(df_code_davinci2[df_code_davinci2['execution_accuracy'] == 0].iloc[i]['model_response'])


Convert text to SQL.

DDL:
```
CREATE TABLE Addresses (address_id INTEGER ,address_details VARCHAR(255));CREATE TABLE Staff (staff_id INTEGER ,staff_gender VARCHAR(1),staff_name VARCHAR(80));CREATE TABLE Suppliers (supplier_id INTEGER ,supplier_name VARCHAR(80),supplier_phone VARCHAR(80));CREATE TABLE Department_Store_Chain (dept_store_chain_id INTEGER ,dept_store_chain_name VARCHAR(80));CREATE TABLE Customers (customer_id INTEGER ,payment_method_code VARCHAR(10) ,customer_code VARCHAR(20),customer_name VARCHAR(80),customer_address VARCHAR(255),customer_phone VARCHAR(80),customer_email VARCHAR(80));CREATE TABLE Products (product_id INTEGER ,product_type_code VARCHAR(10) ,product_name VARCHAR(80),product_price DECIMAL(19,4));CREATE TABLE Supplier_Addresses (supplier_id INTEGER ,address_id INTEGER ,date_from DATETIME ,date_to DATETIME);CREATE TABLE Customer_Addresses (customer_id INTEGER ,address_id INTEGER ,date_from DATETIME ,date_to DATETIME);CREATE TABLE Customer_Orders (order_id IN

In [503]:
df_code_cushman = df.sample(100).copy()
df_code_cushman['model_response'] = ''

data = []
for idx, row in df_code_cushman.iterrows():
    print(idx)
    if row['model_response'] != '':
        print('already completed')
        data.append(row['model_response'])
        continue
    
    print(row['open_ai_completion'])
    new_response = call_raw_model(row, engine="code-cushman-001", stop='\n', max_tokens=128)
    data.append(new_response)
    df_code_cushman.loc[idx,'model_response'] = new_response

6174
 SELECT avg(T2.Weeks_on_Top) FROM artist AS T1 JOIN volume AS T2 ON T1.Artist_ID  =  T2.Artist_ID WHERE T1.age  <=  25

SELECT avg(T2.Weeks_on_Top) FROM artist AS T1 JOIN volume AS T2 ON T1.Artist_ID  =  T2.Artist_ID WHERE T1.age  <=  25
5694
 SELECT count(DISTINCT gender) FROM dorm

SELECT count(DISTINCT gender) FROM dorm
2760
 SELECT T2.Delegate FROM county AS T1 JOIN election AS T2 ON T1.County_id  =  T2.District WHERE T1.Population  <  100000

SELECT T2.Delegate FROM county AS T1 JOIN election AS T2 ON T1.County_id  =  T2.District WHERE T1.Population  <  100000
3799
 SELECT personal_name ,  family_name FROM Students ORDER BY family_name

SELECT T1.personal_name ,  T1.family_name FROM Students AS T1 ORDER BY T1.family_name
1261
 SELECT apt_type_code FROM Apartments GROUP BY apt_type_code ORDER BY count(*) DESC LIMIT 1

SELECT apt_type_code FROM Apartments GROUP BY apt_type_code ORDER BY count(apt_type_code) DESC
6452
 SELECT T1.document_id FROM Documents_with_expenses AS T1 JOI

SELECT T2.school_name FROM school AS T2 LEFT JOIN endowment AS T1 ON T2.school_id  =  T1.school_id WHERE T1.endowment_id  IS  NULL
4832
 SELECT Aircraft FROM aircraft WHERE Aircraft_ID NOT IN (SELECT Winning_Aircraft FROM MATCH)

SELECT Aircraft FROM aircraft WHERE Aircraft NOT IN (SELECT Winning_Aircraft FROM match)
5956
 SELECT T1.Name ,  T1.Tourist_Attraction_ID FROM Tourist_Attractions AS T1 JOIN VISITS AS T2 ON T1.Tourist_Attraction_ID  =  T2.Tourist_Attraction_ID GROUP BY T2.Tourist_Attraction_ID HAVING count(*)  <=  1

SELECT T1.Name ,  T1.Tourist_Attraction_ID FROM TOURIST_ATTRACTIONS AS T1 JOIN ( SELECT Tourist_Attraction_ID ,  count(*) AS cnt FROM Visits GROUP BY Tourist_Attraction_ID HAVING cnt = 1 ) AS T2 ON T1.Tourist_Attraction_ID  =  T2.Tourist_Attraction_ID
6601
 SELECT DISTINCT LOCATION FROM station

SELECT LOCATION FROM station
3838
 SELECT T1.student_id ,  T2.personal_name FROM Student_Course_Enrolment AS T1 JOIN Students AS T2 ON T1.student_id  =  T2.student_id GROU

SELECT class FROM captain GROUP BY class HAVING count(*) > 2
419
 SELECT count(DISTINCT eid) FROM Certificate

SELECT count(*) FROM Employee AS T1 JOIN Certificate AS T2 ON T1.eid  =  T2.eid
2566
 SELECT Police_force FROM county_public_safety WHERE LOCATION  =  "East" INTERSECT SELECT Police_force FROM county_public_safety WHERE LOCATION  =  "West"

RateLimitError
SELECT Police_force FROM county_public_safety WHERE LOCATION = 'EAST' OR LOCATION = 'WEST' GROUP BY Police_force ORDER BY COUNT(*) DESC LIMIT 1
1218
 SELECT date_of_birth FROM Guests WHERE gender_code  =  "Male"

SELECT date_of_birth FROM Guests WHERE gender_code = "M"
6144
 SELECT count(*) FROM artist

SELECT COUNT(*) FROM artist
4981
 SELECT pName FROM Player WHERE yCard  =  'yes' ORDER BY HS DESC

SELECT pName FROM Player WHERE yCard  =  'yes' ORDER BY HS DESC
1228
 SELECT DISTINCT T2.apt_number FROM Apartment_Bookings AS T1 JOIN Apartments AS T2 ON T1.apt_id  =  T2.apt_id WHERE T1.booking_status_code  =  "Confirmed"

SELE

In [505]:
df_code_cushman['open_ai_execution'] = df_code_cushman.apply(lambda x: execution(x, 'open_ai_completion'), axis=1)
df_code_cushman['model_response_execution'] = df_code_cushman.apply(lambda x: execution(x, 'model_response'), axis=1)
df_code_cushman['execution_accuracy'] = np.where(df_code_cushman['open_ai_execution'] == 
                                                  df_code_cushman['model_response_execution'], 1, 0)
df_code_cushman[~(df_code_cushman['model_response'] == '') & 
                 ~(df_code_cushman['model_response'].isna())]['execution_accuracy'].mean()

0.59

# SQLglot cleanup

In [176]:
a ="""CREATE TABLE trip (
    id INTEGER PRIMARY KEY,
    duration INTEGER,
    start_date TEXT,
    start_station_name TEXT, -- this should be removed
    start_station_id INTEGER,
    end_date TEXT,
    end_station_name TEXT, -- this should be removed
    end_station_id INTEGER,
    bike_id INTEGER,
    subscription_type TEXT,
    zip_code INTEGER);
"""

def remove_quotes_and_lower_case(node):
    if isinstance(node, sqlglot.exp.Identifier):
        return sqlglot.exp.Identifier(**{**node.args, 'this': node.this.lower(), "quoted": False})
        if ' ' in str(node):
            raise Exception("Whitespace found in column name, cannot strip quotes without causing syntax errors")
    return node

def remove_reference_nodes(node):
    if isinstance(node, sqlglot.exp.Anonymous):
        return
    return node

parsed = sqlglot.parse_one(a)
parsed.transform(remove_quotes_and_lower_case).transform(remove_reference_nodes).sql()
# Output: CREATE TABLE IF NOT EXISTS department (department_id INT, name TEXT, creation TEXT, ranking INT, budget_in_billions FLOAT, num_employees FLOAT)

'CREATE TABLE trip (id INT PRIMARY KEY, duration INT, start_date TEXT, start_station_name TEXT /* this should be removed */, start_station_id INT, end_date TEXT, end_station_name TEXT /* this should be removed */, end_station_id INT, bike_id INT, subscription_type TEXT, zip_code INT)'