In [1]:

from llama_index.core.retrievers import NLSQLRetriever

import sqlite3
import os
import llama_index
import sqlalchemy
import openai

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI


from sqlalchemy import text


In [2]:
import os
from sqlalchemy.orm import sessionmaker

In [2]:
api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = api_key





In [9]:
#Establishing engines and connections



engine = create_engine("sqlite:///databases/courses_temp.db")

llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo")
sql_database = SQLDatabase(engine, include_tables=["class_data"])


In [5]:
#sqlite query
cursor.execute("PRAGMA table_info(class_data)")
tables = cursor.fetchall()
for col in tables:
    print(col[1], col[2])

course_campus TEXT
course_credits TEXT
course_description TEXT
course_id TEXT
quarters_offered TEXT
course_title TEXT
department_abbrev TEXT
mean_gpa REAL


In [46]:

#sqlalchemy query

with engine.connect() as con:
    rows = con.execute(text("SELECT course_title, course_id\nFROM class_data\nWHERE department_abbrev = 'CSE'\nAND course_credits = '4.0'\nORDER BY mean_gpa DESC\nLIMIT 4"))
    for row in rows:
        print(row)

('Computer Vision', 'CSE 455')
('Computer Security', 'CSE 484')
('Datacenter Systems', 'CSE 453')
('Autonomous Robotics', 'CSE 478')


In [None]:
course_ids = ['CSE 455', 'CSE 484']
formatted_ids = ", ".join(f"'{course_id}'" for course_id in course_ids)
print(f'{course_ids}')

['CSE 455', 'CSE 484']


In [56]:
returns = []
if returns:
    print("yes")

In [54]:

#sqlalchemy query

with engine.connect() as con:
    rows = con.execute(text(f'SELECT course_title, course_id\nFROM class_data\nWHERE course_id IN ({formatted_ids})'))
    for row in rows:
        print(row)

('Computer Vision', 'CSE 455')
('Computer Security', 'CSE 484')


In [10]:
#llama index retriever

nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["class_data"], return_raw=True
)


In [44]:
#llama index query retriever


query = """
Schema Context:
Department_abbrev column represents the department and is abbreviated in all caps. Example: CS for Computer Science.
Course_credits contain floating points. Example: 4.0.

Do not order by unless the question suggests that you should.

Always return course_id

Question:
What are some CSE courses with the highest average gpa?"""
 

response = nl_sql_retriever.retrieve(query)
print(response)


[NodeWithScore(node=TextNode(id_='0a273585-842e-4eaa-b9fa-aed137be4471', embedding=None, metadata={'sql_query': "SELECT course_id, mean_gpa\nFROM class_data\nWHERE department_abbrev = 'CSE'\nORDER BY mean_gpa DESC", 'result': [('CSE 496', 3.9976190476190476), ('CSE 460', 3.9870588235294115), ('CSE 498', 3.982476635514019), ('CSE 428', 3.940625), ('CSE 455', 3.9210769230769227), ('CSE 475', 3.9111111111111114), ('CSE 482', 3.907228915662651), ('CSE 481', 3.872076023391813), ('CSE 484', 3.871465461588121), ('CSE 453', 3.852100840336134), ('CSE 478', 3.851572327044025), ('CSE 458', 3.8224719101123594), ('CSE 464', 3.80919540229885), ('CSE 495', 3.801298701298701), ('CSE 427', 3.78663967611336), ('CSE 444', 3.768900804289544), ('CSE 180', 3.7555819477434675), ('CSE 493', 3.7503124999999997), ('CSE 469', 3.742528735632184), ('CSE 451', 3.7359999999999998), ('CSE 452', 3.7126271186440674), ('CSE 490', 3.6962088698140203), ('CSE 369', 3.694385026737968), ('CSE 442', 3.691517323775388), ('CSE 

In [26]:
print(response[0].get_content())
print(response[0].get_score())
print(response[0])
print(response[0].metadata)
print(type(response[0]))

[]
0.0
Node ID: 390cf0b6-de06-4ca8-a972-46cd42be5f65
Text: []
Score: None

{'sql_query': "SELECT course_title, course_description, course_credits \nFROM class_data \nWHERE department_abbrev = 'UDP' OR department_abbrev = 'URBAN' OR department_abbrev = 'DP' OR department_abbrev = 'URBDP' OR department_abbrev = 'URB'", 'result': [], 'col_keys': ['course_title', 'course_description', 'course_credits']}
<class 'llama_index.core.schema.NodeWithScore'>


In [30]:
if not response[0].metadata['result']:
    print("Condition met: result is falsy")


Condition met: result is falsy


In [32]:
response[0].metadata['sql_query'].splitlines()[0]

'SELECT course_title, course_description, course_credits '

In [40]:
sql = "SELECT course_title, course_description, course_credits /nFROM class_data /nWHERE department_abbrev = 'UDP' OR department_abbrev = 'URBAN' OR department_abbrev = 'DP' OR department_abbrev = 'URBDP' OR department_abbrev = 'URB AND test'"

for line_num, line in enumerate(sql.splitlines()):
    if 'WHERE' in line:
        where_clause = line
        line_num = line_num

        new_clause = where_clause.split('nothing')[:-1]
        split_sql = sql.splitlines()
        split_sql[line_num] = new_clause
        sql = split_sql

print(sql)
print(response[0].metadata['sql_query'])

[[]]
SELECT course_title, course_description, course_credits 
FROM class_data 
WHERE department_abbrev = 'UDP' OR department_abbrev = 'URBAN' OR department_abbrev = 'DP' OR department_abbrev = 'URBDP' OR department_abbrev = 'URB'
