## Imports and setup

### Imports

In [5]:
from openai import OpenAI

from tkinter import * 
from tkinter import ttk
from tkinter import Tk 

import sqlalchemy
import os
import pandas as pd
from tabulate import tabulate

from dotenv import load_dotenv
import os 
import json 
import re 
import csv

from sqlalchemy import create_engine, MetaData, Table
from sqlalchemy import create_engine, text,Column, Integer, String, Float, Date, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

DBTYPE = 'postgres'


### Load SQL

#### Load Oracle

In [2]:
# ! hi

load_dotenv()

oracle_connection_string = os.getenv('ORACLE_CONNECTION_STRING')

if oracle_connection_string:
    try:
        engine = create_engine(oracle_connection_string)
        connection = engine.connect()
        print("Connected to Oracle database successfully!")
        connection.close()
    except Exception as e:
        print("Error:", e)
else:
    print("Error: ORACLE_CONNECTION_STRING environment variable is not set.")
 
engine = create_engine(oracle_connection_string)
connection = engine.connect()


Connected to Oracle database successfully!


#### Load Postgres

In [33]:
load_dotenv()

pgsql_connection_string = os.getenv('PSQL_CONNECTION_STRING')
print(pgsql_connection_string)
if pgsql_connection_string:
    try:
        engine = create_engine(pgsql_connection_string)
        connection = engine.connect()
        print('Connected to Postgres Successfully!')
        connection.close()
    except Exception as e:
        print('Error:',e)
else:
    print('Error: psql_connection_string environment variable not set')


engine = create_engine(pgsql_connection_string)
connection = engine.connect()
    

postgresql://postgres:postgres@localhost/mydatabase
Connected to Postgres Successfully!


In [34]:
#Define Tables

metadata = MetaData() 

employees_table = Table('employees',metadata,autoload_with=engine)
discounts_table = Table('discounts',metadata,autoload_with=engine)
producers_table = Table('producers',metadata,autoload_with=engine)
produces_table = Table('produces',metadata,autoload_with=engine)
products_table = Table('products',metadata,autoload_with=engine)
purchases_table = Table('purchases',metadata,autoload_with=engine)
storehours_table = Table('storehours',metadata,autoload_with=engine)
stores_table = Table('stores',metadata,autoload_with=engine)
transactions_table = Table('transactions',metadata,autoload_with=engine)
members_table = Table('members',metadata,autoload_with=engine)


In [10]:
with engine.connect() as connection:
    result = connection.execute(text('select * from members;'))
    rows = result.fetchall()

    for row in rows:
        print(row)

(1, 'John Smith', '123 Main Street', '1')
(2, 'Alice Lee', '456 Elm Avenue', '1')


### LLM Helper Functions

In [35]:
if DBTYPE == 'oracle':
    with open('primer.txt', 'r') as file:
        primer = file.read()
elif DBTYPE == 'postgres':
    with open('postgresprimer.txt','r') as file:
        primer = file.read()

primobj = {'role':'system','content':primer}

In [36]:
client = OpenAI()
def convCompletion(completion,type='direct'):
    return {'role':'assistant','content':completion.choices[0].message.content if type=='direct' else completion}
def convQuery(query,role='user'):
    return {'role': role,'content':query}

# completion = client.chat.completions.create(
#   model="gpt-3.5-turbo",
#   messages=[
#     primobj,
#     {"role": "user", "content": "Compose a poem that explains the concept of recursion in programming."},
#   ]
# )

def doConv(t=3, model='ft:gpt-3.5-turbo-0125:personal::96kERypj'):
    msgs = [primobj]
    for i in range(t):
        a = input('Enter query')
        query = convQuery(a)
        msgs.append(query) 
        completion = client.chat.completions.create(
          model=model,
          messages=msgs
        )
        print('user:',a)
        print('bot:',completion.choices[0].message.content)
        msgs.append(convCompletion(completion))


def sendUserQuery(msg,model='ft:gpt-3.5-turbo-0125:personal::96kERypj'):
    query = convQuery(msg)
    msgs = [primobj, query]
    msgs.append(query) 
    completion = client.chat.completions.create(
        model=model,
        messages=msgs
    )
    return completion.choices[0].message.content
    
def executeQuery(msg):
    print('user:',msg)
    print('bot:',sendUserQuery(msg)) 
# doConv()
        
# executeQuery('find cheapest product at store name honey farms')
# print(completion.choices[0].message.content)

def printConv(msgs):
    for msg in msgs:
        print(msg['role']+":",msg['content']) 



In [5]:
print(executeQuery('find all members who have spent more than 30'))

user: find all members who have spent more than 30
bot: SELECT DISTINCT T1.mid FROM TRANSACTIONS AS T1 JOIN PRODUCTS AS T2 ON T1.mid  =  T2.pid JOIN PRODUCTS AS T3 ON T2.category  =  T3.name HAVING sum(T3.price * T2.quantity)  >  30
None


## UI Code

#### Helper Functions

In [31]:
def guiDoQuery(*args):
    res = sendUserQuery(userquery.get())
    print(res)
    botresponse.set(res)

def executeSQL(sqlresponse,*args):
    with engine.connect() as connection:        
        try:
            runquery = connection.execute(text(sql.get()))
            result = runquery.fetchall()
            cols = runquery.keys()
            runquery.close()

            # print(type(cols))
            table = [list(cols)]+result
            temp = tabulate(table,headers='firstrow',tablefmt='fancy_grid')

            sqlresponse.destroy()
            sqlresponse = makeTable(mainframe,data=table)
            sqlresponse.grid(row=5,column=2,rowspan=2,columnspan=10)

            
        except Exception as e:
            sqlresponse.destroy()
            sqlresponse = Text(mainframe)
            sqlresponse.insert(END,e)
            sqlresponse.config(state=DISABLED)
            sqlresponse.grid(row=5,column=2)
        
def makeTable(rt,data):
    frame = ttk.Frame(rt)

    for i in range(len(data)):
        for j in range(len(data[0])):
            e = Text(frame,width=20,height=3)
            e.grid(row=i,column=j)

            e.insert(END,data[i][j])
            e.config(state=DISABLED)

            if i==0:
                e.tag_configure('bold', font=('TkDefaultFont',12,'bold'))
                e.tag_add('bold','1.0','end')

    return frame


#### Run Full UI

In [32]:
root = Tk()
root.title("SQL LLM")

mainframe = ttk.Frame(root, padding="3 3 12 12")
mainframe.grid(column=0, row=0, sticky=(N, W, E, S))
root.columnconfigure(0, weight=1)
root.rowconfigure(0, weight=1)

userquery = StringVar()
userquery_entry = ttk.Entry(mainframe, width=10, textvariable=userquery)
ttk.Label(mainframe, text='Enter query:').grid(column=1,row=1,sticky=(W))
userquery_entry.grid(column=2, row=1, sticky=(W, E))
ttk.Button(mainframe, text='Execute Query',command=guiDoQuery).grid(row=1,column=3)

botresponse = StringVar()
botresponse.set('hi')

sql = StringVar() 
sql_entry = ttk.Entry(mainframe,width=20, textvariable=sql).grid(column=1,row=5)

#SQL RESPONSE SPOT: Col 2, row 5, rowspan 2, colspan 10
sqlresponse = makeTable(mainframe, [ ['ID','Name','Salary',],[3,'a b',333333],[3,'a b',333333],[3,'a b',333333]])
sqlresponse.grid(row=5,column=2,rowspan=2,columnspan=10)
#render labels
ttk.Label(mainframe,text='Bot Response').grid(column=10,row=1)
ttk.Label(mainframe, textvariable=botresponse).grid(column=11, row=1, sticky=(W, E))
# ttk.Label(mainframe, textvariable=botresponse).grid(column=3,row=2)
# ttk.Button(mainframe, text="Calculate", command=calculate).grid(column=3, row=3, sticky=W)


ttk.Label(mainframe,text='SQL To Execute:').grid(column=1,row=3,sticky=W)
ttk.Button(mainframe,text='Execute SQL',command=lambda:executeSQL(sqlresponse)).grid(column=2,row=3,sticky=W)
for child in mainframe.winfo_children(): 
    child.grid_configure(padx=5, pady=5)

userquery_entry.focus()



testquery = [ ['name','id','test'],['sdf',3,'ewr'],['sdf',3,'ewr'],['sdf',3,'ewr'],]
root.mainloop()

## Testing Suite

### Run Tests

#### Helper Functions

In [37]:
def extract_queries_from_file(file_path,includesemi=False):
    with open(file_path, 'r') as file:
        sql_content = file.read()

    # Split SQL content into individual queries
    # queries = re.split(r';\s*\n(?![\s\w]*SELECT)', sql_content)
    queries = sql_content.split(';')
    queries = [query.strip() for query in queries if query.strip()]
    queries = [q.replace('\n',' ') for q in queries]

    if includesemi:
        queries = [query + ';' for query in queries]

    # Remove empty or whitespace-only queries
    # queries = [query.strip() for query in queries if query.strip()]

    return queries

def shortenString(text, l=20):
    return text[:l] + '...' if len(text) > l else ''

def debugResult(userquery,sql2, initresult,sqlres, querylim=5,debug=False):
    msgs = [primobj,convQuery(userquery),convQuery(initresult,role='assistant')]

    fixed = False
    i = 0 
    curerr = sqlres

    while fixed is not True and i < querylim:
        if curerr is False:
            msgs.append(convQuery('That gave the wrong result. Fix.'))
        else:
            msgs.append(convQuery('That gave me '+curerr+'. Fix to work for SQLOracle'))

        completion = client.chat.completions.create(
          model='ft:gpt-3.5-turbo-0125:personal::96kERypj',
          messages=msgs
        )

        newsql=completion.choices[0].message.content
        nmsg = convCompletion(completion)
        msgs.append(nmsg)

        curerr = compSQL(newsql,sql2,debug=False,uq=userquery)
        if curerr is True:
            fixed = True


            if debug:
                print('Succesfully Fixed!')
                print(newsql)
                printConv(msgs[1:])
                print('\n')

            return newsql

        i += 1

    if debug:
        print('Failed to Fix')
        printConv(msgs[1:])
        print('\n')

    return False
                 
def compSQL(predicted,sql2,debug=False,uq=None):
    with engine.connect() as connection:
        real = connection.execute(text(sql2))
        res2 = real.fetchall()
        cols2 = real.keys()
        real.close()

        try:
            predicted = connection.execute(text(predicted))
            res1 = predicted.fetchall()
            cols1 = predicted.keys()
            predicted.close()
        except Exception as e:
            if debug:
                if uq is not None:
                    print(uq)
                # print(e,'\n')
                
                # print(predicted)
                # print(fixedquery,'\n')
            fixedquerystr = 'Your attempt to code user query '+uq + ' in oracleSQL gave me error ' +str(e)+'. Fix'
            fixedquery = sendUserQuery(fixedquerystr)

            return 'Error'+fixedquery


        
        return cols1 == cols2 and res1 == res2
        


# print(tabulate([ [r[1],r[-1]] for r in restable],headers='firstrow',tablefmt='fancy_grid'))

#### Run Tests

In [39]:
# connection = engine.connect()
file = open('userqueries.txt','r')
restable=[['Index','User Query', 'Our SQL','Generated SQL','Correct?']]
userqueries = [l for l in file]

realqueries = extract_queries_from_file('postgresqueries.sql',includesemi=True)

errors = 0
wrong = 0
fixed = 0 
correct = 0 
attemptedfixes = {True:0,False:0,'Error':0}
for i in range(len(userqueries)):
    r = sendUserQuery(userqueries[i])   
    # r = r.replace('"',"'")
    # r = r.replace('\n',' ')

    # print('Debug Messages:',i)
    # print(f'\t{userqueries[i]}\n')
    # print(f'\t{r}\n')
    # print(f'\t{realqueries[i]}\n')
    
    comp = compSQL(r,realqueries[i],debug=True,uq = userqueries[i])

    if comp is True:
        correct += 1
    elif comp is False:
        # wrong += 1
        res = debugResult(userqueries[i],realqueries[i],r,comp)
        if res == False:
            wrong += 1
        elif res == True:
            r = res
            comp = 'Fixed'
            fixed += 1
        else:
            errors += 1

    else:
        # errors += 1
        res = debugResult(userqueries[i],realqueries[i],r,comp)
        if res == True:
            r = res
            comp = 'Fixed'
            fixed += 1
        elif res == False:
            wrong += 1
        else:
            errors += 1

    restable.append([i,userqueries[i],realqueries[i],r,comp])

print(f'Correct:{correct}, Wrong:{wrong},Errors: {errors}, Fixed:{fixed}')
# df = pd.DataFrame(restable[1:],columns=restable[0])
connection.close()

Provide the address and employee id of the store which has the longest opening time on Monday.

Provide the information about the products which are produced by Tech Innovate with the price below 600.

Display the total sales revenue generated by each store in the last quarter.

Correct:7, Wrong:11,Errors: 2, Fixed:0


In [40]:
temp = [ [r[0],r[4]] for r in restable]
print(tabulate(temp, headers='firstrow'))

  Index  Correct?
-------  -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      0  True
      1  True
      2  True
      3  False
      4  False
      5  True
      6  True
      7  False
      8  True
      9  True
     10  False
     11  False
     12  False
     13  False
     14  ErrorSELECT T3.location ,  T2.eid FROM storehours AS T1 JOIN employees AS T2 ON T1.sid  =  T2.storeID JOIN stores AS T3 ON T1.sid  =  T3.sid WHERE T1.day  =  1 ORDER BY T1.closetime - T1.opentime DESC LIMIT 1
     15  ErrorSELECT T1.pname FROM products AS T1 JOIN produces AS T2 ON T1.pid  =  T2.pid JOIN producers AS T3 ON T2.prid  =  T3.prid WHERE T3.name  =  'Tech Innovate' AND T1.price  <  600
     16  False
     17  False
     18  ErrorSELECT sum(T1.price * T2.quantity) ,  T3.name FROM TRANSACTIONS AS T3 JOIN purchases AS T2

In [42]:
print('Test Results'.center(150))

print(tabulate([ [r[1],r[4] if (r[4] == True or r[4] == False  or r[4] == 'Correct?' or r[4] == 'Fixed') else 'SQL Failed'] for r in restable],headers='firstrow',tablefmt='fancy_grid'))

csv_file = "output.csv"

# Write the list of lists to a CSV file
with open(csv_file, mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerows([ [r[1],r[4] if (r[4] == True or r[4] == False  or r[4] == 'Correct?' or r[4] == 'Fixed') else 'SQL Failed'] for r in restable])
    



                                                                     Test Results                                                                     
╒═════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╤════════════╕
│ User Query                                                                                                                              │ Correct?   │
╞═════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╪════════════╡
│ Give me the addresses of all registered stores.                                                                                         │ True       │
├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┼────────────┤
│ Give me the names along with locations of registered producers.                   

## Spider Dataset processing

In [51]:
directory_path = 'spider/database'

# Get all directories within the specified directory
folders = [folder for folder in os.listdir(directory_path) if os.path.isdir(os.path.join(directory_path, folder))]

# Print the list of folders
# print(folders)


dbquerytable = {}
for db in folders[:50]:
    dbpath = os.path.join(os.path.join(directory_path,db),'schema.sql')
    # print(db)
    try:
        with open(dbpath, 'r') as f:
            dbquerytable[db] = {}
            dbquerytable[db]['schema'] = f.read()
    except Exception as e:
        print(f'{db} Error: {e}')



# print(dbquerytable)

trainingdatapath = 'spider/train_spider.json'

with open(trainingdatapath, 'r') as file:
    trainingdata = json.load(file)


for entry in trainingdata:
    db = entry['db_id']
    if db in dbquerytable:
        sql = entry['query']
        uquery = entry['question']
        if 'queries' not in dbquerytable[db]:
            dbquerytable[db]['queries'] = []
        dbquerytable[db]['queries'].append( (uquery,sql))



# print(dbquerytable['medicine_enzyme_interaction']['queries'])
# print(dbquerytable)

openaijson = []

out = 'You convert user queries to sql. use these examples:\n'
for key, val in dbquerytable.items():
    out += f"\nDatabase: {key}\nschema:\n{val['schema']}\n"
    if 'queries' in val:
        
        for p in val['queries']:
            out += f'\tUser query: {p[0]}\n\tSQL:{p[1]}\n'

            pt ={'role':'system','content':f'You convert user queries to SQL statements. use the following schema: {val["schema"]}'}
            pt2 = {'role':'user','content':p[0]}
            pt3 = {'role':'assistant','content':p[1]}
            openaijson.append({'messages':[pt,pt2,pt3]})
            with open('finetuningfile.jsonl','a') as file:
                # print('writing')
                jstr = json.dumps({'messages': [pt,pt2,pt3]})
                file.write(jstr+'\n')


epinions_1 Error: [Errno 2] No such file or directory: 'spider/database/epinions_1/schema.sql'
formula_1 Error: [Errno 2] No such file or directory: 'spider/database/formula_1/schema.sql'
college_1 Error: [Errno 2] No such file or directory: 'spider/database/college_1/schema.sql'
voter_1 Error: [Errno 2] No such file or directory: 'spider/database/voter_1/schema.sql'


In [52]:
#Create finetuning job

file = client.files.create(
  file=open("finetuningfile.jsonl", "rb"),
  purpose="fine-tune"
)

client.fine_tuning.jobs.create(
  training_file=file.id,
  model="gpt-3.5-turbo"
)

FineTuningJob(id='ftjob-nL36YDN3x3OLLo9xvwLnHj1O', created_at=1711388152, error=Error(code=None, message=None, param=None, error=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs='auto', batch_size='auto', learning_rate_multiplier='auto'), model='gpt-3.5-turbo-0125', object='fine_tuning.job', organization_id='org-U9LCC1iwJFC3QqGaYQw6frNf', result_files=[], status='validating_files', trained_tokens=None, training_file='file-LYUqnOn7AFijH5zhr0iHCztz', validation_file=None, user_provided_suffix=None)