## Imports and setup

In [2]:
from openai import OpenAI

from tkinter import * 
from tkinter import ttk

import sqlalchemy
import os
from tabulate import tabulate

from dotenv import load_dotenv
import os 

from sqlalchemy import create_engine, text,Column, Integer, String, Float, Date, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship


In [3]:

# Pull settings from a local .env file (if present) into the process environment.
load_dotenv()

oracle_connection_string = os.getenv('ORACLE_CONNECTION_STRING')

# Fail early with a clear message: without a connection string there is
# nothing sensible the rest of the notebook can do.
if not oracle_connection_string:
    print("Error: ORACLE_CONNECTION_STRING environment variable is not set.")
    raise RuntimeError("ORACLE_CONNECTION_STRING environment variable is not set.")

# Connect exactly once. `engine` and `connection` are the module-level handles
# that later cells use to run SQL, so the connection is deliberately left open.
try:
    engine = create_engine(oracle_connection_string)
    connection = engine.connect()
except Exception as e:
    print("Error:", e)
    raise  # do not continue with a missing/unusable connection
print("Connected to Oracle database successfully!")


Connected to Oracle database successfully!


## LLM Helper Functions

In [4]:
# Read the priming text for the model from disk.
with open('primer.txt', 'r') as file:
    primer = file.read()

# System message placed at the start of every conversation sent to the model.
primobj = dict(role='system', content=primer)

In [5]:
# Shared OpenAI client; doConv and sendUserQuery issue their chat requests through it.
# NOTE(review): presumably the API key comes from the environment — confirm.
client = OpenAI()
def convCompletion(completion):
    """Turn a chat completion into an assistant message dict.

    Uses the message content of the completion's first choice, so the reply
    can be appended to a message history.
    """
    reply_text = completion.choices[0].message.content
    return dict(role='assistant', content=reply_text)
def convQuery(query):
    """Wrap the given query as a user message dict for the chat API."""
    return dict(role='user', content=query)

# completion = client.chat.completions.create(
#   model="gpt-3.5-turbo",
#   messages=[
#     primobj,
#     {"role": "user", "content": "Compose a poem that explains the concept of recursion in programming."},
#   ]
# )

def doConv(t=3):
    """Hold an interactive console conversation with the model for `t` turns.

    The history starts with the system message `primobj`. Each turn reads a
    query from stdin, sends the whole history to the model, prints both sides,
    and records the reply so later turns include the earlier ones.

    Args:
        t: number of turns to run (default 3).
    """
    history = [primobj]
    for _ in range(t):
        typed = input('Enter query')
        history.append(convQuery(typed))
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=history,
        )
        print('user:', typed)
        print('bot:', completion.choices[0].message.content)
        # Keep the assistant's answer in the history for the next turn.
        history.append(convCompletion(completion))


def sendUserQuery(msg):
    """Send one stateless query to the chat model and return the reply text.

    The request consists of the system message `primobj` followed by a single
    user message built from `msg`; no history is kept between calls.

    Args:
        msg: the user's query text.

    Returns:
        The message content of the first choice in the completion.
    """
    # Exactly one system message and one user message. The user message used
    # to be added twice (once in the list literal, once via append), so every
    # request carried the query in duplicate.
    msgs = [primobj, convQuery(msg)]
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=msgs
    )
    return completion.choices[0].message.content
    
def executeQuery(msg):
    """Print the user's query, then the reply sendUserQuery returns for it."""
    print('user:', msg)
    reply = sendUserQuery(msg)
    print('bot:', reply)
# doConv()
        
# executeQuery('find cheapest product at store name honey farms')
# print(completion.choices[0].message.content)

## UI Code

In [6]:
def guiDoQuery(*args):
    """GUI callback: send the text in `userquery` to the model and show the reply.

    Positional arguments are accepted and ignored. The reply is printed and
    stored in `botresponse`.
    """
    question = userquery.get()
    answer = sendUserQuery(question)
    print(answer)
    botresponse.set(answer)

def executeSQL(*args):
    """GUI callback: run the SQL held in `sql` and show the result as a table.

    Positional arguments are accepted and ignored. The statement is executed
    on the module-level `connection`; rows and column names are rendered with
    tabulate and written to `sqlresponse`. Any exception is caught and shown
    in `sqlresponse` as ``error: <message>``.
    """
    try:
        cursor_result = connection.execute(text(sql.get()))
        rows = cursor_result.fetchall()
        column_names = cursor_result.keys()
        cursor_result.close()

        # First row is the header, followed by the fetched rows.
        grid_rows = [list(column_names)] + rows
        rendered = tabulate(grid_rows, headers='firstrow', tablefmt='fancy_grid')

        sqlresponse.set(rendered)
    except Exception as e:
        sqlresponse.set(f'error: {e}')
        


In [7]:


# Build the Tk window: one row for asking the LLM a question, one area for
# running SQL by hand and viewing the tabulated result. Blocks in mainloop()
# until the window is closed.
root = Tk()
root.title("SQL LLM")

mainframe = ttk.Frame(root, padding="3 3 12 12")
mainframe.grid(column=0, row=0, sticky=(N, W, E, S))
# Let the frame grow with the window.
root.columnconfigure(0, weight=1)
root.rowconfigure(0, weight=1)

# --- Natural-language query -> LLM (handled by guiDoQuery) ---
userquery = StringVar()
userquery_entry = ttk.Entry(mainframe, width=10, textvariable=userquery)
ttk.Label(mainframe, text='Enter query:').grid(column=1, row=1, sticky=(W))
userquery_entry.grid(column=2, row=1, sticky=(W, E))
ttk.Button(mainframe, text='Execute Query', command=guiDoQuery).grid(row=1, column=3)

botresponse = StringVar()
botresponse.set('hi')

# --- Manual SQL entry (handled by executeSQL) ---
sql = StringVar()
# Create and place in two steps: .grid() returns None, so chaining it onto the
# constructor would leave `sql_entry` bound to None instead of the widget.
sql_entry = ttk.Entry(mainframe, width=20, textvariable=sql)
sql_entry.grid(column=1, row=5)

sqlresponse = StringVar()
sqlresponse.set('Sample SqL response')
# A fixed-width font keeps the columns of the tabulate grid aligned.
responselabel = ttk.Label(
    mainframe,
    padding=('5 5 5 5'),
    relief=SOLID,
    borderwidth=1,
    textvariable=sqlresponse,
    font='TkFixedFont',
)
responselabel.grid(column=2, row=5, rowspan=2, columnspan=10, sticky=W)

ttk.Label(mainframe, text='Bot Response').grid(column=10, row=1)
ttk.Label(mainframe, textvariable=botresponse).grid(column=11, row=1, sticky=(W, E))

ttk.Label(mainframe, text='SQL To Execute:').grid(column=1, row=3, sticky=W)
ttk.Button(mainframe, text='Execute SQL', command=executeSQL).grid(column=2, row=3, sticky=W)

# Uniform padding around every widget in the frame.
for child in mainframe.winfo_children():
    child.grid_configure(padx=5, pady=5)

userquery_entry.focus()

# Placeholder table shown until the first SQL statement is executed.
testquery = [['name', 'id', 'test'], ['sdf', 3, 'ewr'], ['sdf', 3, 'ewr'], ['sdf', 3, 'ewr']]
sqlresponse.set(tabulate(testquery, headers='firstrow', tablefmt='psql'))
root.mainloop()

SELECT * 
FROM Transactions 
WHERE mid = 1


## Testing SQL

In [6]:

# Pull settings from a local .env file (if present) into the process environment.
load_dotenv()

oracle_connection_string = os.getenv('ORACLE_CONNECTION_STRING')

# Fail early with a clear message: without a connection string there is
# nothing sensible the following cells can do.
if not oracle_connection_string:
    print("Error: ORACLE_CONNECTION_STRING environment variable is not set.")
    raise RuntimeError("ORACLE_CONNECTION_STRING environment variable is not set.")

# Connect exactly once. `engine` and `connection` are the module-level handles
# the following cells use to run SQL, so the connection is deliberately left open.
try:
    engine = create_engine(oracle_connection_string)
    connection = engine.connect()
except Exception as e:
    print("Error:", e)
    raise  # do not continue with a missing/unusable connection
print("Connected to Oracle database successfully!")

# statement = text("SELECT * FROM members" )


Connected to Oracle database successfully!


In [28]:
# Manual check: fetch everything from MEMBERS and print it as a grid.
runquery = connection.execute(text('SELECT * FROM MEMBERS'))
result, cols = runquery.fetchall(), runquery.keys()
runquery.close()

# Header row first, then the fetched rows.
table = [list(cols), *result]

print(tabulate(table, headers='firstrow', tablefmt='fancy_grid'))

╒═══════╤════════════╤═════════════════╤═══════════════╕
│   mid │ name       │ address         │   phonenumber │
╞═══════╪════════════╪═════════════════╪═══════════════╡
│     1 │ John Smith │ 123 Main Street │             1 │
├───────┼────────────┼─────────────────┼───────────────┤
│     2 │ Alice Lee  │ 456 Elm Avenue  │             1 │
╘═══════╧════════════╧═════════════════╧═══════════════╛


In [13]:
# Show the column names held in `cols` (set by the query cell that ran last, e.g. the MEMBERS query).
print(list(cols))


['mid', 'name', 'address', 'phonenumber']


## Testing Suite

In [43]:
def compSQL(predicted,sql2):
    """Compare the result of a generated SQL statement with a reference one.

    Both statements run on the module-level `connection`. The reference
    statement `sql2` runs first, outside any error handling, so a failure
    there propagates to the caller.

    Args:
        predicted: SQL text to evaluate.
        sql2: reference SQL text.

    Returns:
        The string 'Error' if running `predicted` raises (the exception is
        printed); otherwise the outcome of comparing both the column keys and
        the fetched rows for equality.
    """
    reference = connection.execute(text(sql2))
    expected_rows = reference.fetchall()
    expected_cols = reference.keys()
    reference.close()

    try:
        candidate = connection.execute(text(predicted))
        actual_rows = candidate.fetchall()
        actual_cols = candidate.keys()
        candidate.close()
    except Exception as e:
        print(e)
        return 'Error'

    return actual_cols == expected_cols and actual_rows == expected_rows
    


# Evaluation run: send every natural-language query to the model, execute the
# generated SQL, and compare its result with the matching reference query.
# NOTE: extract_queries_from_file must already be defined when this cell runs.

# One query per line; drop the trailing newline and skip blank lines. The
# context manager closes the file, which was previously left open.
with open('userqueries.txt', 'r') as file:
    userqueries = [line.strip() for line in file if line.strip()]

realqueries = extract_queries_from_file('test_queries.sql')

# Queries are paired by position, so a length mismatch means some entries
# cannot be evaluated. Report it instead of failing with an IndexError.
if len(userqueries) != len(realqueries):
    print(f'Warning: {len(userqueries)} user queries vs {len(realqueries)} reference queries; '
          'unpaired entries are skipped.')

restable = [['User Query', 'Generated SQL', 'Is Correct']]

for question, reference_sql in zip(userqueries, realqueries):
    # Flatten the generated SQL to one statement on one line without ';'.
    generated = sendUserQuery(question).replace(';', '').replace('\n', ' ')
    # Long questions are truncated and short ones shown in full. The condition
    # used to apply to the whole expression, so short questions showed as ''.
    label = question[:20] + '...' if len(question) > 20 else question
    restable.append([label, generated, compSQL(generated, reference_sql)])


print(tabulate(restable, headers='firstrow', tablefmt='fancy_grid'))
    



(cx_Oracle.DatabaseError) ORA-00907: missing right parenthesis
[SQL: SELECT p.name FROM Products p JOIN (SELECT pid 	FROM Discounts 	ORDER BY newPrice / price DESC 	LIMIT 1) d ON p.pid = d.pid JOIN Produces pr ON p.pid = pr.pid JOIN Producers prd ON pr.prID = prd.prID]
(Background on this error at: https://sqlalche.me/e/20/4xp6)
(cx_Oracle.DatabaseError) ORA-00907: missing right parenthesis
[SQL: SELECT AVG(E.salary) FROM Employees E WHERE E.storeID =      (SELECT T1.sid     FROM          (SELECT T.sid, COUNT(T.tid) AS num_transactions         FROM Transactions T         WHERE T.pdate BETWEEN '2024-03-01' AND '2024-03-31'         GROUP BY T.sid         ORDER BY num_transactions DESC         LIMIT 1) AS T1     )]
(Background on this error at: https://sqlalche.me/e/20/4xp6)
(cx_Oracle.DatabaseError) ORA-30089: missing or invalid <datetime field>
[SQL: SELECT Members.phonenumber FROM Members JOIN Transactions ON Members.mid = Transactions.mid WHERE Transactions.pdate < CURRENT_DATE - INTE

In [31]:
import re

queries = ''
def extract_queries_from_file(file_path):
    """Read a SQL file and return its statements as a list of strings.

    The file content is split on every ';'. Each piece is stripped of
    surrounding whitespace, and pieces that are empty after stripping are
    dropped.

    Args:
        file_path: path of the SQL file to read.

    Returns:
        List of statement strings, in file order, without the ';'.
    """
    with open(file_path, 'r') as handle:
        raw_sql = handle.read()

    statements = []
    for fragment in raw_sql.split(';'):
        cleaned = fragment.strip()
        if cleaned:
            statements.append(cleaned)
    return statements

# Sanity check of the reference queries: load them, show them, and run each
# one against the database.
queries = extract_queries_from_file('test_queries.sql')

print(queries)
print(len(queries))
# Spot-check one entry. Guarded so that a file with fewer than 11 statements
# does not stop the cell with an IndexError.
if len(queries) > 10:
    print(queries[10])
# NOTE(review): this rebinds `connection` to a fresh connection; any previously
# opened one is not closed here — confirm that is intended.
connection = engine.connect()

for i, query in enumerate(queries):
    real = connection.execute(text(query))
    print(i, real.fetchall())
    real.close()

['SELECT location\nFROM stores', 'SELECT name, location\nFROM producers', "SELECT pname\nFROM products\nWHERE category = 'Groceries'", 'SELECT COUNT(DISTINCT mid) AS num_customers\nFROM transactions\nWHERE sid = 1', 'SELECT COUNT(*) AS num_completed\nFROM transactions', 'SELECT name\nFROM employees\nORDER BY salary DESC\nFETCH FIRST 1 ROW ONLY', 'SELECT pid, pname\nFROM products\nWHERE quantity < 50', 'SELECT COUNT(DISTINCT category) AS num_categories\nFROM products', "SELECT address\nFROM members\nWHERE address LIKE '%Elm%'", 'SELECT openTime, closeTime\nFROM storehours\nWHERE sid = 1', 'SELECT pr.name AS manufacturer_name\nFROM producers pr\nJOIN produces p ON pr.prID = p.prID\nJOIN products pro ON pro.pid = p.pid\nJOIN discounts d ON pro.pid = d.pid\nORDER BY d.newprice / pro.price DESC\nFETCH FIRST 1 ROW ONLY', 'SELECT m.name AS customer_name, SUM(p.quantity * pr.price) AS total_price\nFROM members m\nJOIN transactions t ON m.mid = t.mid\nJOIN purchases p ON t.tid = p.tid\nJOIN pro