In [1]:
import pandas as pd
import numpy as np

import openai
import sqlite3
import re

# Fix invalid SQL

In [14]:
def get_schema(x):
    return re.sub(r'Schema: (.+?);\nQuestion.*\n\n###\n\n', r'\1', x['open_ai_prompt'])

def get_question(x):
    return re.sub(r'Schema: .+?\nQuestion: (.+?)\n\n###\n\n', r'\1', x['open_ai_prompt'])

def open_ai_completion(x):
    return f" {x['query']}\n" 

def open_ai_correction_prompt(x):
    return f'''
        Your task is to correct a SQL query.
        
        You have the following DDLs:
        ```
        {x['schema']}
        ```
        And following question:
        """{x['question']}"""
        
        The draft SQL query is:
        ```
        {x['model_response']}
        ```
        
        The problem is:
        ```
        {x['error_prompt']}
        ```
        
        Respond with only one concise, corrected SQL statement.
        '''

def execution(row, query_type):
    PATH = 'spider/database/'
    db_id = row['db_id']
    
    query = row[query_type]

    # Connect to the SQLite database file
    conn = sqlite3.connect(PATH + db_id + '/' + db_id + '.sqlite')

    # Create a cursor object
    cursor = conn.cursor()

    # Execute a SELECT statement
    try:
        cursor.execute(query)
    except:
        return "invalid SQL"

    # Fetch the results
    results = cursor.fetchall()

    # Loop through the results and print them
    # for result in results:
    #     print(result)

    # Close the cursor and connection
    cursor.close()
    conn.close()
    
    return results

def correct_query(x):
    prompt = open_ai_correction_prompt(x)

    completions = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        temperature=0.3,
        max_tokens=2000,
        best_of=1,
        frequency_penalty=0,
        presence_penalty=0
    )

    print(completions.choices[0].text)
    return completions.choices[0].text

In [3]:
df = pd.read_csv('davinci-ft-mercator-2023-01-25-23-01-53-completions.csv',index_col=0)
df['schema']   = df.apply(get_schema, axis=1)
df['question'] = df.apply(get_question, axis=1)

df_spider = pd.read_json('spider/train_spider.json')
df_spider['open_ai_completion'] = df_spider.apply(lambda x: open_ai_completion(x), axis=1)


df = df.merge(df_spider[['db_id', 'open_ai_completion']], left_on='open_ai_completion', right_on='open_ai_completion')

df['open_ai_execution'] = df.apply(lambda x: execution(x, 'open_ai_completion'), axis=1)
df['model_response_execution'] = df.apply(lambda x: execution(x, 'model_response'), axis=1)
df['execution_accuracy'] = np.where(df['open_ai_execution'] == df['model_response_execution'], 1, 0)

df

Unnamed: 0,open_ai_prompt,open_ai_completion,model_response,schema,question,db_id,open_ai_execution,model_response_execution,execution_accuracy
0,Schema: create table `catalog_contents` (`cata...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM CATAL...,create table `catalog_contents` (`catalog_entr...,Find the names of all the catalog entries.,product_catalog,"[(Cola,), (Root beer,), (Cream Soda,), (Carbon...","[(Cola,), (Root beer,), (Cream Soda,), (Carbon...",1
1,Schema: create table `catalog_contents` (`cata...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM CATAL...,create table `catalog_contents` (`catalog_entr...,Find the names of all the catalog entries.,product_catalog,"[(Cola,), (Root beer,), (Cream Soda,), (Carbon...","[(Cola,), (Root beer,), (Cream Soda,), (Carbon...",1
2,Schema: create table `catalog_contents` (`cata...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM Catal...,create table `catalog_contents` (`catalog_entr...,What are all the catalog entry names?,product_catalog,"[(Cola,), (Root beer,), (Cream Soda,), (Carbon...","[(Cola,), (Root beer,), (Cream Soda,), (Carbon...",1
3,Schema: create table `catalog_contents` (`cata...,SELECT distinct(catalog_entry_name) FROM cata...,SELECT DISTINCT catalog_entry_name FROM Catal...,create table `catalog_contents` (`catalog_entr...,What are all the catalog entry names?,product_catalog,"[(Cola,), (Root beer,), (Cream Soda,), (Carbon...","[(Cola,), (Root beer,), (Cream Soda,), (Carbon...",1
4,Schema: create table `attribute_definitions` (...,SELECT attribute_data_type FROM Attribute_Def...,SELECT attribute_data_type FROM attribute_def...,create table `attribute_definitions` (`attribu...,Find the list of attribute data types possesse...,product_catalog,"[(Bool,)]","[(Bool,)]",1
...,...,...,...,...,...,...,...,...,...
985,Schema: create table airports ( apid integer...,"SELECT avg(elevation) , country FROM airport...","SELECT country , avg(elevation) FROM airport...",create table airports ( apid integer primary...,"For each country, what is the average elevatio...",flight_4,"[(3732.4, Afghanistan), (281.4, Albania), (144...","[(Afghanistan, 3732.4), (Albania, 281.4), (Alg...",0
986,Schema: create table airports ( apid integer...,SELECT city FROM airports GROUP BY city HAVIN...,SELECT city FROM airports GROUP BY city HAVIN...,create table airports ( apid integer primary...,Find the cities which have exactly two airports.,flight_4,"[(Abilene,), (Adana,), (Addis Ababa,), (Adelai...","[(Abilene,), (Adana,), (Addis Ababa,), (Adelai...",1
987,Schema: create table airports ( apid integer...,SELECT city FROM airports GROUP BY city HAVIN...,SELECT city FROM airports GROUP BY city HAVIN...,create table airports ( apid integer primary...,Find the cities which have exactly two airports.,flight_4,"[(Abilene,), (Adana,), (Addis Ababa,), (Adelai...","[(Abilene,), (Adana,), (Addis Ababa,), (Adelai...",1
988,Schema: create table airports ( apid integer...,SELECT city FROM airports GROUP BY city HAVIN...,SELECT city FROM airports GROUP BY city HAVIN...,create table airports ( apid integer primary...,What are the cities with exactly two airports?,flight_4,"[(Abilene,), (Adana,), (Addis Ababa,), (Adelai...","[(Abilene,), (Adana,), (Addis Ababa,), (Adelai...",1


In [21]:
df_wrong = df[df['execution_accuracy'] == 0].copy()
df_wrong['error_prompt'] = np.where(df_wrong['model_response_execution'] == 'invalid SQL',
                                   'invalid SQL', '')

df_wrong['corrected_query'] = ''

for idx, row in df_wrong.iterrows():
    if df_wrong.loc[idx]['error_prompt'] != 'invalid SQL':
        continue
    df_wrong.at[idx,'corrected_query'] = correct_query(row)

df_wrong.head()


        SELECT SUM(settlement_amount), AVG(settlement_amount) FROM Settlements;

        SELECT SUM(settlement_amount), AVG(settlement_amount) FROM Settlements;

        SELECT dependent_name FROM dependent WHERE relationship  =  'spouse'

        SELECT hispanic FROM city WHERE black > 10;

        SELECT hispanic FROM city WHERE black > 10;

        SELECT COUNT(DISTINCT region_id) FROM affected_region

        SELECT COUNT(DISTINCT region_id) FROM affected_region;

        SELECT format FROM files GROUP BY format ORDER BY COUNT(*) DESC LIMIT 1

        SELECT format FROM files GROUP BY format ORDER BY COUNT(*) DESC LIMIT 1

        SELECT language, AVG(rating) FROM song GROUP BY language;

        SELECT language, AVG(rating) FROM song GROUP BY language;

        SELECT language, AVG(rating) FROM song WHERE rating > 0 AND rating < 11 GROUP BY language;

        SELECT language, AVG(rating) FROM song WHERE rating > 0 AND rating < 11 GROUP BY language;

        SELECT formats, COUNT(

Unnamed: 0,open_ai_prompt,open_ai_completion,model_response,schema,question,db_id,open_ai_execution,model_response_execution,execution_accuracy,error_prompt,corrected_query
16,Schema: create table `catalogs` (`catalog_id` ...,SELECT distinct(catalog_publisher) FROM catal...,SELECT catalog_publisher FROM catalogs WHERE ...,create table `catalogs` (`catalog_id` integer ...,Find all the catalog publishers whose name con...,product_catalog,"[(Murray Coffee shop,)]",[],0,,
17,Schema: create table `catalogs` (`catalog_id` ...,SELECT distinct(catalog_publisher) FROM catal...,SELECT catalog_publisher FROM catalogs WHERE ...,create table `catalogs` (`catalog_id` integer ...,Find all the catalog publishers whose name con...,product_catalog,"[(Murray Coffee shop,)]",[],0,,
18,Schema: create table `catalogs` (`catalog_id` ...,SELECT distinct(catalog_publisher) FROM catal...,SELECT catalog_publisher FROM catalogs WHERE ...,create table `catalogs` (`catalog_id` integer ...,"Which catalog publishers have substring ""Murra...",product_catalog,"[(Murray Coffee shop,)]","[(Murray Coffee shop,), (Murray Coffee shop,)]",0,,
19,Schema: create table `catalogs` (`catalog_id` ...,SELECT distinct(catalog_publisher) FROM catal...,SELECT catalog_publisher FROM catalogs WHERE ...,create table `catalogs` (`catalog_id` integer ...,"Which catalog publishers have substring ""Murra...",product_catalog,"[(Murray Coffee shop,)]","[(Murray Coffee shop,), (Murray Coffee shop,)]",0,,
36,Schema: create table `catalog_contents` (`cata...,SELECT catalog_entry_name FROM catalog_conten...,SELECT product_stock_number FROM Catalog_Cont...,create table `catalog_contents` (`catalog_entr...,Find the name of the product that has the smal...,product_catalog,"[(Cola,)]","[(89 cp,)]",0,,


In [49]:
i = 5
print(df_wrong[df_wrong['error_prompt'] == 'invalid SQL']['open_ai_prompt'].iloc[i])
print(df_wrong[df_wrong['error_prompt'] == 'invalid SQL']['model_response'].iloc[i].lstrip())
print(df_wrong[df_wrong['error_prompt'] == 'invalid SQL']['corrected_query'].iloc[i].lstrip())
print(df_wrong[df_wrong['error_prompt'] == 'invalid SQL']['open_ai_completion'].iloc[i].lstrip())

Schema: create table `affected_region` (`region_id` int,`storm_id` int,`number_city_affected` real,primary key (`region_id`,`storm_id`),foreign key (`region_id`);
Question: Count the number of different affected regions.

###


SELECT count(DISTINCT affected_region) FROM affected_region
SELECT COUNT(DISTINCT region_id) FROM affected_region
SELECT count(DISTINCT region_id) FROM affected_region



In [5]:
df_wrong[df_wrong['error_prompt'] == 'invalid SQL']['model_response']

82      SELECT sum ,  avg(settlement_amount) FROM Set...
83      SELECT sum ,  avg(settlement_amount) FROM Set...
231     SELECT dependent_name FROM dependent WHERE re...
352     SELECT hispanic_percent FROM city WHERE black...
353     SELECT hispanic_percent FROM city WHERE black...
406     SELECT count(DISTINCT affected_region) FROM a...
407     SELECT count(DISTINCT affected_region) FROM a...
471     SELECT format FROM files GROUP BY format ORDE...
472     SELECT format FROM files GROUP BY format ORDE...
479     SELECT language ,  avg(rating) FROM song GROU...
480     SELECT language ,  avg(rating) FROM song GROU...
481     SELECT language ,  avg(rating) FROM song GROU...
482     SELECT language ,  avg(rating) FROM song GROU...
483     SELECT format ,  count(DISTINCT artist_name) ...
484     SELECT format ,  count(DISTINCT artist_name) ...
485     SELECT count(*) ,  format FROM Files GROUP BY...
486     SELECT count(*) ,  format FROM Files GROUP BY...
509     SELECT song_name FROM S

In [7]:
print(open_ai_correction_prompt(df_wrong.loc[82]))


        Your task is to correct a SQL query.
        
        You have the following DDLs:
        ```
        create table settlements (settlement_id integer not null,claim_id integer,effective_date date,settlement_amount real,primary key (settlement_id),unique (settlement_id),foreign key (claim_id)
        ```
        And following question:
        """Return the sum and average of all settlement amounts."""
        
        The draft SQL query is:
        ```
         SELECT sum ,  avg(settlement_amount) FROM Settlements
        ```
        
        The problem is:
        ```
        invalid SQL
        ```
        
        Respond with only one concise, corrected SQL statement.
        


In [50]:
# ignore: 
# - whitespace
# - capitalization
# - trailing semicolons
# - the difference between single and double quotes
# - difference between 'distinct(...)' and 'distinct ...'
# - explictly stating ASC (the ORDER BY default)
tmp = df_wrong[df_wrong['corrected_query'] != ''].copy()

tmp['correct'] = np.where(tmp['open_ai_completion'].str
                                                                         .strip()
                                                                         .str
                                                                         .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                                .rstrip(';')) == 
                                     tmp['corrected_query'].str
                                                                     .strip()
                                                                     .str
                                                                     .lower()
                                                                         .apply(lambda x: re.sub("distinct\((.*)\)", "distinct\\1", 
                                                                                          re.sub('orderby(.*)asc', 'orderby\\1',
                                                                                          re.sub('"', "'",
                                                                                          re.sub('\s+', '', x))))
                                                                            .rstrip(';')), 1, 0)
tmp['correct'].mean()

0.6

In [51]:
tmp

Unnamed: 0,open_ai_prompt,open_ai_completion,model_response,schema,question,db_id,open_ai_execution,model_response_execution,execution_accuracy,error_prompt,corrected_query,correct
82,Schema: create table settlements (settlement_i...,"SELECT sum(settlement_amount) , avg(settleme...","SELECT sum , avg(settlement_amount) FROM Set...",create table settlements (settlement_id intege...,Return the sum and average of all settlement a...,insurance_fnol,"[(53543.69999999999, 5354.369999999999)]",invalid SQL,0,invalid SQL,"\n SELECT SUM(settlement_amount), AVG(s...",1
83,Schema: create table settlements (settlement_i...,"SELECT sum(settlement_amount) , avg(settleme...","SELECT sum , avg(settlement_amount) FROM Set...",create table settlements (settlement_id intege...,Return the sum and average of all settlement a...,insurance_fnol,"[(53543.69999999999, 5354.369999999999)]",invalid SQL,0,invalid SQL,"\n SELECT SUM(settlement_amount), AVG(s...",1
231,"Schema: create table dependent(essn integer,de...",SELECT Dependent_name FROM dependent WHERE re...,SELECT dependent_name FROM dependent WHERE re...,"create table dependent(essn integer,dependent_...",find all dependent names who have a spouse rel...,company_1,"[(Joy,), (Abner,), (Elizabeth,)]",invalid SQL,0,invalid SQL,\n SELECT dependent_name FROM dependent...,1
352,"Schema: create table ""city"" (""city_id"" int,""co...",SELECT Hispanic FROM city WHERE Black > 10\n,SELECT hispanic_percent FROM city WHERE black...,"create table ""city"" (""city_id"" int,""county_id""...",What are the percentage of hispanics in cities...,county_public_safety,"[(99.5,), (99.1,), (99.0,), (99.1,), (99.2,), ...",invalid SQL,0,invalid SQL,\n SELECT hispanic FROM city WHERE blac...,1
353,"Schema: create table ""city"" (""city_id"" int,""co...",SELECT Hispanic FROM city WHERE Black > 10\n,SELECT hispanic_percent FROM city WHERE black...,"create table ""city"" (""city_id"" int,""county_id""...",What are the percentage of hispanics in cities...,county_public_safety,"[(99.5,), (99.1,), (99.0,), (99.1,), (99.2,), ...",invalid SQL,0,invalid SQL,\n SELECT hispanic FROM city WHERE blac...,1
406,Schema: create table `affected_region` (`regio...,SELECT count(DISTINCT region_id) FROM affecte...,SELECT count(DISTINCT affected_region) FROM a...,create table `affected_region` (`region_id` in...,Count the number of different affected regions.,storm_record,"[(4,)]",invalid SQL,0,invalid SQL,\n SELECT COUNT(DISTINCT region_id) FRO...,1
407,Schema: create table `affected_region` (`regio...,SELECT count(DISTINCT region_id) FROM affecte...,SELECT count(DISTINCT affected_region) FROM a...,create table `affected_region` (`region_id` in...,Count the number of different affected regions.,storm_record,"[(4,)]",invalid SQL,0,invalid SQL,\n SELECT COUNT(DISTINCT region_id) FRO...,1
471,Schema: create table files(\tf_id number(10) n...,SELECT formats FROM files GROUP BY formats OR...,SELECT format FROM files GROUP BY format ORDE...,"create table files(\tf_id number(10) not null,...",What is the most popular file format?,music_1,"[(mp4,)]",invalid SQL,0,invalid SQL,\n SELECT format FROM files GROUP BY fo...,0
472,Schema: create table files(\tf_id number(10) n...,SELECT formats FROM files GROUP BY formats OR...,SELECT format FROM files GROUP BY format ORDE...,"create table files(\tf_id number(10) not null,...",What is the most popular file format?,music_1,"[(mp4,)]",invalid SQL,0,invalid SQL,\n SELECT format FROM files GROUP BY fo...,0
479,Schema: create table song(\tsong_name varchar2...,"SELECT avg(rating) , languages FROM song GRO...","SELECT language , avg(rating) FROM song GROU...","create table song(\tsong_name varchar2(50),\ta...",What is the average rating of songs for each l...,music_1,"[(7.5, bangla), (7.0, english)]",invalid SQL,0,invalid SQL,"\n SELECT language, AVG(rating) FROM so...",0


In [54]:
for i in range(len(tmp)):
    # print(tmp['open_ai_prompt'].iloc[i])
    print(tmp['model_response'].iloc[i].lstrip())
    print(tmp['corrected_query'].iloc[i].lstrip())
    print(tmp['open_ai_completion'].iloc[i].lstrip())

SELECT sum ,  avg(settlement_amount) FROM Settlements
SELECT SUM(settlement_amount), AVG(settlement_amount) FROM Settlements;
SELECT sum(settlement_amount) ,  avg(settlement_amount) FROM settlements

SELECT sum ,  avg(settlement_amount) FROM Settlements
SELECT SUM(settlement_amount), AVG(settlement_amount) FROM Settlements;
SELECT sum(settlement_amount) ,  avg(settlement_amount) FROM settlements

SELECT dependent_name FROM dependent WHERE relationship  =  'spouse' AND employee_id  !=  0
SELECT dependent_name FROM dependent WHERE relationship  =  'spouse'
SELECT Dependent_name FROM dependent WHERE relationship  =  'Spouse'

SELECT hispanic_percent FROM city WHERE black  >  10
SELECT hispanic FROM city WHERE black > 10;
SELECT Hispanic FROM city WHERE Black  >  10

SELECT hispanic_percent FROM city WHERE black  >  10
SELECT hispanic FROM city WHERE black > 10;
SELECT Hispanic FROM city WHERE Black  >  10

SELECT count(DISTINCT affected_region) FROM affected_region
SELECT COUNT(DISTINCT r