## Imports and setup

### Imports

In [1]:
from openai import OpenAI

from tkinter import * 
from tkinter import ttk
from tkinter import Tk 

import sqlalchemy
import os
import pandas as pd
from tabulate import tabulate

from dotenv import load_dotenv
import os 
import json 
import re 
import csv

from sqlalchemy import create_engine, MetaData, Table
from sqlalchemy import create_engine, text,Column, Integer, String, Float, Date, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

DBTYPE = 'postgres'


### Load SQL

#### Load Oracle

In [2]:
# ! hi

load_dotenv()

oracle_connection_string = os.getenv('ORACLE_CONNECTION_STRING')

if oracle_connection_string:
    try:
        engine = create_engine(oracle_connection_string)
        connection = engine.connect()
        print("Connected to Oracle database successfully!")
        connection.close()
    except Exception as e:
        print("Error:", e)
else:
    print("Error: ORACLE_CONNECTION_STRING environment variable is not set.")
 
engine = create_engine(oracle_connection_string)
connection = engine.connect()


Connected to Oracle database successfully!


#### Load Postgres

In [2]:
load_dotenv()

pgsql_connection_string = os.getenv('PSQL_CONNECTION_STRING')
if pgsql_connection_string:
    try:
        engine = create_engine(pgsql_connection_string)
        connection = engine.connect()
        print('Connected to Postgres Successfully!')
        connection.close()
    except Exception as e:
        print('Error:',e)
else:
    print('Error: psql_connection_string environment variable not set')


engine = create_engine(pgsql_connection_string)
connection = engine.connect()
    

Connected to Postgres Successfully!


#### Load SQLITE

In [28]:
load_dotenv()

sqlite_connection_string = os.getenv('SQLITE_CONNECTION_STRING')
print(sqlite_connection_string)
if sqlite_connection_string:
    try: 
        engine = create_engine('sqlite:///main.db')
        connection = engine.connect()
        print('Connected to Sqlite Successfully!')
        connection.close()
    except Exception as e:
        print('Error:',e)
else:
    print('Error: Connection String not found')

engine = create_engine('sqlite:///main.db')

sqlite://main.db
Connected to Sqlite Successfully!


#### Define Tables (unused)

In [3]:
#Define Tables

metadata = MetaData() 

employees_table = Table('employees',metadata,autoload_with=engine)
discounts_table = Table('discounts',metadata,autoload_with=engine)
producers_table = Table('producers',metadata,autoload_with=engine)
produces_table = Table('produces',metadata,autoload_with=engine)
products_table = Table('products',metadata,autoload_with=engine)
purchases_table = Table('purchases',metadata,autoload_with=engine)
storehours_table = Table('storehours',metadata,autoload_with=engine)
stores_table = Table('stores',metadata,autoload_with=engine)
transactions_table = Table('transactions',metadata,autoload_with=engine)
members_table = Table('members',metadata,autoload_with=engine)


In [10]:
with engine.connect() as connection:
    result = connection.execute(text('select * from members;'))
    rows = result.fetchall()

    for row in rows:
        print(row)

(1, 'John Smith', '123 Main Street', '1')
(2, 'Alice Lee', '456 Elm Avenue', '1')


### LLM Helper Functions

In [3]:
if DBTYPE == 'oracle':
    with open('primer.txt', 'r') as file:
        primer = file.read()
elif DBTYPE == 'postgres':
    with open('postgresprimer.txt','r') as file:
        primer = file.read()

primobj = {'role':'system','content':primer}

In [4]:
client = OpenAI()
def convCompletion(completion,type='direct'):
    return {'role':'assistant','content':completion.choices[0].message.content if type=='direct' else completion}
def convQuery(query,role='user'):
    return {'role': role,'content':query}

# completion = client.chat.completions.create(
#   model="gpt-3.5-turbo",
#   messages=[
#     primobj,
#     {"role": "user", "content": "Compose a poem that explains the concept of recursion in programming."},
#   ]
# )

def doConv(t=3, model='ft:gpt-3.5-turbo-0125:personal::96kERypj'):
    msgs = [primobj]
    for i in range(t):
        a = input('Enter query')
        query = convQuery(a)
        msgs.append(query) 
        completion = client.chat.completions.create(
          model=model,
          messages=msgs
        )
        print('user:',a)
        print('bot:',completion.choices[0].message.content)
        msgs.append(convCompletion(completion))


def sendUserQuery(msg,model='ft:gpt-3.5-turbo-0125:personal::96kERypj'):
    query = convQuery(msg)
    msgs = [primobj, query]
    msgs.append(query) 
    completion = client.chat.completions.create(
        model=model,
        messages=msgs
    )
    return completion.choices[0].message.content
    
def executeQuery(msg):
    print('user:',msg)
    print('bot:',sendUserQuery(msg)) 
# doConv()
        
# executeQuery('find cheapest product at store name honey farms')
# print(completion.choices[0].message.content)

def printConv(msgs):
    for msg in msgs:
        print(msg['role']+":",msg['content']) 



In [4]:
print(executeQuery('find all members who have spent more than 30'))

user: find all members who have spent more than 30
bot: SELECT T1.mid FROM MEMBERS AS T1 JOIN TRANSACTIONS AS T2 ON T1.mid  =  T2.mid JOIN PURCHASES AS T3 ON T2.tid  =  T3.tid JOIN PRODUCTS AS T4 ON T3.pid  =  T4.pid GROUP BY T1.mid HAVING sum(T3.quantity * T4.price)  >  30
None


## UI Code

#### Helper Functions

In [6]:
def guiDoQuery(*args):
    res = sendUserQuery(userquery.get())
    print(res)
    botresponse.set(res)

def executeSQL(sqlresponse,*args):
    for widget in mainframe.grid_slaves(row=5,column=2):
        widget.grid_forget()

    with engine.connect() as connection:        
        try:
            runquery = connection.execute(text(sql.get()))
            result = runquery.fetchall()
            cols = runquery.keys()
            runquery.close()

            # print(type(cols))
            table = [list(cols)]+result
            temp = tabulate(table,headers='firstrow',tablefmt='fancy_grid')

            sqlresponse = makeTable(mainframe,data=table)
            sqlresponse.grid(row=5,column=2,rowspan=3,columnspan=10,sticky=(N,W))

            
        except Exception as e:
            # sqlresponse.forget()
            sqlresponse = Text(mainframe)
            sqlresponse.insert(END,e)
            sqlresponse.config(state=DISABLED)
            sqlresponse.grid(row=5,column=2,rowspan=3, columnspan=10,sticky=(N,W))
        
def makeTable(rt,data):
    frame = ttk.Frame(rt)

    for i in range(len(data)):
        for j in range(len(data[0])):
            e = Text(frame,width=20,height=3)
            e.grid(row=i,column=j)

            e.insert(END,data[i][j])
            e.config(state=DISABLED)

            if i==0:
                e.tag_configure('bold', font=('TkDefaultFont',12,'bold'))
                e.tag_add('bold','1.0','end')

    return frame

def clearInp():
    sql.set('')
    userquery.set('')

#### Run Full UI

In [7]:
root = Tk()
root.title("SQL LLM")

#Initialize Main Frame
mainframe = ttk.Frame(root, padding="3 3 12 12")
mainframe.grid(column=0, row=0, sticky=(N, W, E, S))
root.columnconfigure(0, weight=1)
root.rowconfigure(0, weight=1)

#Begin user entry spot
#Grid:  (1,1),(1,2),(1,3),(1,4)
userquery = StringVar()
userquery_entry = ttk.Entry(mainframe, width=10, textvariable=userquery)
ttk.Label(mainframe, text='Enter query:').grid(column=1,row=1,sticky=(W))
userquery_entry.grid(column=2, row=1, sticky=(W, E))
ttk.Button(mainframe, text='Execute Query',command=guiDoQuery).grid(row=1,column=3)
ttk.Button(mainframe,text='Send to sql',command=lambda:sql.set(botresponse.get())).grid(row=1,column=4)
#Bot Response
#Grid: (1,10), (1,11)
botresponse = StringVar()
botresponse.set('hi')
ttk.Label(mainframe,text='Bot Response').grid(column=10,row=1)
ttk.Label(mainframe, textvariable=botresponse).grid(column=11, row=1, sticky=(W, E))

#SQL Entry
#Grid: (5,1),(5,2)-(7,12),(3,1),(3,2)
sql = StringVar() 
sql_entry = ttk.Entry(mainframe,width=20, textvariable=sql).grid(column=1,row=5,sticky=(N))
sqlresponse = makeTable(mainframe, [ ['ID','Name','Salary',],[3,'a b',333333],[3,'a b',333333],[3,'a b',333333]])
sqlresponse.grid(row=5,column=2,rowspan=3,columnspan=10,sticky=(N,W))
ttk.Label(mainframe,text='SQL To Execute:').grid(column=1,row=3,sticky=W)
ttk.Button(mainframe,text='Execute SQL',command=lambda:executeSQL(sqlresponse)).grid(column=2,row=3,sticky=W)
ttk.Button(mainframe,text='Clear', command=clearInp).grid(row=3,column=3)

for child in mainframe.winfo_children(): 
    child.grid_configure(padx=5, pady=5)

userquery_entry.focus()



testquery = [ ['name','id','test'],['sdf',3,'ewr'],['sdf',3,'ewr'],['sdf',3,'ewr'],]
root.mainloop()

SELECT * FROM members


## Testing Suite

### Run Tests

#### Helper Functions

In [6]:
def extract_queries_from_file(file_path,includesemi=False):
    with open(file_path, 'r') as file:
        sql_content = file.read()

    # Split SQL content into individual queries
    # queries = re.split(r';\s*\n(?![\s\w]*SELECT)', sql_content)
    queries = sql_content.split(';')
    queries = [query.strip() for query in queries if query.strip()]
    queries = [q.replace('\n',' ') for q in queries]

    if includesemi:
        queries = [query + ';' for query in queries]

    # Remove empty or whitespace-only queries
    # queries = [query.strip() for query in queries if query.strip()]

    return queries

def shortenString(text, l=20):
    return text[:l] + '...' if len(text) > l else ''

def debugResult(userquery,sql2, initresult,sqlres, querylim=5,debug=False):
    msgs = [primobj,convQuery(userquery),convQuery(initresult,role='assistant')]

    fixed = False
    i = 0 
    curerr = sqlres

    while fixed is not True and i < querylim:
        if curerr is False:
            msgs.append(convQuery('That gave the wrong result. Fix.'))
        else:
            msgs.append(convQuery('That gave me '+curerr+'. Fix to work for SQLOracle'))

        completion = client.chat.completions.create(
          model='ft:gpt-3.5-turbo-0125:personal::96kERypj',
          messages=msgs
        )

        newsql=completion.choices[0].message.content
        nmsg = convCompletion(completion)
        msgs.append(nmsg)

        curerr = compSQL(newsql,sql2,debug=False,uq=userquery)
        if curerr is True:
            fixed = True


            if debug:
                print('Succesfully Fixed!')
                print(newsql)
                printConv(msgs[1:])
                print('\n')

            return newsql

        i += 1

    if debug:
        print('Failed to Fix')
        printConv(msgs[1:])
        print('\n')

    return False
                 
def compSQL(predicted,sql2,debug=False,uq=None):
    with engine.connect() as connection:
        real = connection.execute(text(sql2))
        res2 = real.fetchall()
        cols2 = real.keys()
        real.close()

        try:
            predicted = connection.execute(text(predicted))
            res1 = predicted.fetchall()
            cols1 = predicted.keys()
            predicted.close()
        except Exception as e:
            # if debug:
            #     if uq is not None:
            #         print(uq)
                # print(e,'\n')
                
                # print(predicted)
                # print(fixedquery,'\n')
            fixedquerystr = 'Your attempt to code user query '+uq + ' in oracleSQL gave me error ' +str(e)+'. Fix'
            fixedquery = sendUserQuery(fixedquerystr)

            return 'Error '+fixedquery


        
        return cols1 == cols2 and res1 == res2
        


# print(tabulate([ [r[1],r[-1]] for r in restable],headers='firstrow',tablefmt='fancy_grid'))

#### Run Tests

In [7]:
# connection = engine.connect()
file = open('userqueries.txt','r')
restable=[['Index','User Query', 'Our SQL','Generated SQL','Correct?']]
userqueries = [l for l in file]

realqueries = extract_queries_from_file('postgresqueries.sql',includesemi=True)

errors = 0
wrong = 0
fixed = 0 
correct = 0 
attemptedfixes = {True:0,False:0,'Error':0}
for i in range(len(userqueries)):
    r = sendUserQuery(userqueries[i])   
    # r = r.replace('"',"'")
    # r = r.replace('\n',' ')

    # print('Debug Messages:',i)
    # print(f'\t{userqueries[i]}\n')
    # print(f'\t{r}\n')
    # print(f'\t{realqueries[i]}\n')
    
    comp = compSQL(r,realqueries[i],debug=True,uq = userqueries[i])

    if comp is True:
        correct += 1
    elif comp is False:
        # wrong += 1
        res = debugResult(userqueries[i],realqueries[i],r,comp,querylim=10)
        if res == False:
            wrong += 1
        elif res == True:
            r = res
            comp = 'Fixed'
            fixed += 1
        else:
            errors += 1

    else:
        # errors += 1
        res = debugResult(userqueries[i],realqueries[i],r,comp,querylim=10)
        if res == True:
            r = res
            comp = 'Fixed'
            fixed += 1
        elif res == False:
            wrong += 1
        else:
            errors += 1

    restable.append([i,userqueries[i],realqueries[i],r,comp])

print(f'Correct:{correct}, Wrong:{wrong},Errors: {errors}, Fixed:{fixed}')
# df = pd.DataFrame(restable[1:],columns=restable[0])
connection.close()

Correct:8, Wrong:11,Errors: 1, Fixed:0


In [8]:
temp = [ [r[0],r[4] if type(r[4]) is bool else 'Error'] for r in restable]
print(tabulate(temp, headers='firstrow',tablefmt='fancy_grid'))

╒═════════╤═════════╕
│   Index │ Error   │
╞═════════╪═════════╡
│       0 │ True    │
├─────────┼─────────┤
│       1 │ True    │
├─────────┼─────────┤
│       2 │ True    │
├─────────┼─────────┤
│       3 │ False   │
├─────────┼─────────┤
│       4 │ False   │
├─────────┼─────────┤
│       5 │ True    │
├─────────┼─────────┤
│       6 │ True    │
├─────────┼─────────┤
│       7 │ False   │
├─────────┼─────────┤
│       8 │ True    │
├─────────┼─────────┤
│       9 │ True    │
├─────────┼─────────┤
│      10 │ False   │
├─────────┼─────────┤
│      11 │ False   │
├─────────┼─────────┤
│      12 │ True    │
├─────────┼─────────┤
│      13 │ False   │
├─────────┼─────────┤
│      14 │ Error   │
├─────────┼─────────┤
│      15 │ Error   │
├─────────┼─────────┤
│      16 │ False   │
├─────────┼─────────┤
│      17 │ Error   │
├─────────┼─────────┤
│      18 │ Error   │
├─────────┼─────────┤
│      19 │ False   │
╘═════════╧═════════╛


In [42]:
print('Test Results'.center(150))

print(tabulate([ [r[1],r[4] if (r[4] == True or r[4] == False  or r[4] == 'Correct?' or r[4] == 'Fixed') else 'SQL Failed'] for r in restable],headers='firstrow',tablefmt='fancy_grid'))

csv_file = "output.csv"

# Write the list of lists to a CSV file
with open(csv_file, mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerows([ [r[1],r[4] if (r[4] == True or r[4] == False  or r[4] == 'Correct?' or r[4] == 'Fixed') else 'SQL Failed'] for r in restable])
    



                                                                     Test Results                                                                     
╒═════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╤════════════╕
│ User Query                                                                                                                              │ Correct?   │
╞═════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╪════════════╡
│ Give me the addresses of all registered stores.                                                                                         │ True       │
├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┼────────────┤
│ Give me the names along with locations of registered producers.                   

#### Spider Dataset Chek

In [49]:
# dirs = os.listdir('finetuningstuff/spider/database')
# print(dirs)

# table = 'academic'

# prim = 'You Convert user queries to sqlite sql statements. Use the following Schema:\n'
# with open('finetuningstuff/spider/database/academic/schema.sql') as f:
#     text = f.read()

# prim += text 

# print(prim)

output = {}
with open('finetuningstuff/spider/train_gold.sql') as f:
    text = f.read()
    lines = text.split('\n')
    
    for line in lines:
        temp = line.split('\t')
        table = temp[-1]
        query = ''.join(temp[:-1])
        print(temp)
        # print(query)
        print(query,table)
        print('\n')

        if table not in output:
            output[table] = []
        output[table].append(query)



['SELECT count(*) FROM head WHERE age  >  56', 'department_management']
SELECT count(*) FROM head WHERE age  >  56 department_management


['SELECT name ,  born_state ,  age FROM head ORDER BY age', 'department_management']
SELECT name ,  born_state ,  age FROM head ORDER BY age department_management


['SELECT creation ,  name ,  budget_in_billions FROM department', 'department_management']
SELECT creation ,  name ,  budget_in_billions FROM department department_management


['SELECT max(budget_in_billions) ,  min(budget_in_billions) FROM department', 'department_management']
SELECT max(budget_in_billions) ,  min(budget_in_billions) FROM department department_management


['SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15', 'department_management']
SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15 department_management


["SELECT name FROM head WHERE born_state != 'California'", 'department_management']
SELECT name FROM head WHERE born_

In [50]:
print(output)

{'department_management': ['SELECT count(*) FROM head WHERE age  >  56', 'SELECT name ,  born_state ,  age FROM head ORDER BY age', 'SELECT creation ,  name ,  budget_in_billions FROM department', 'SELECT max(budget_in_billions) ,  min(budget_in_billions) FROM department', 'SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15', "SELECT name FROM head WHERE born_state != 'California'", "SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'", 'SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3', 'SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1', "SELECT T1.name ,  T1.num_employees FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id WHERE T2.temporary_acting  =  'Yes'", 'SELECT count(DISTINCT temporary_acting) FROM management', 'SELECT count(*) FRO