In [1]:

from llama_index.core.retrievers import NLSQLRetriever

import sqlite3
import os
import llama_index
import sqlalchemy
import openai

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI


from sqlalchemy import text


In [2]:
import os
from sqlalchemy.orm import sessionmaker

In [2]:
api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = api_key





In [9]:
#Establishing engines and connections



engine = create_engine("sqlite:///databases/courses_temp.db")

llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo")
sql_database = SQLDatabase(engine, include_tables=["class_data"])


In [5]:
#sqlite query
cursor.execute("PRAGMA table_info(class_data)")
tables = cursor.fetchall()
for col in tables:
    print(col[1], col[2])

course_campus TEXT
course_credits TEXT
course_description TEXT
course_id TEXT
quarters_offered TEXT
course_title TEXT
department_abbrev TEXT
mean_gpa REAL


In [9]:

#sqlalchemy query

with engine.connect() as con:
    rows = con.execute(text("SELECT course_title\nFROM class_data\nWHERE department_abbrev = 'CSE'\nAND course_credits = '4.0'\nORDER BY mean_gpa DESC\nLIMIT 4"))
    for row in rows:
        print(row)

('Computer Vision',)
('Computer Security',)
('Datacenter Systems',)
('Autonomous Robotics',)


In [10]:
#llama index retriever

nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["class_data"], return_raw=True
)


In [11]:
#llama index query retriever


query = """
Schema Context:
Department_abbrev column represents the department and is abbreviated in all caps. Example: CS for Computer Science.
Course_credits contain floating points. Example: 4.0.

Choose one of these for urban: UDP, URBAN, DP, URBDP, URB
Question:
What are some urban design and planning courses?"""
 

response = nl_sql_retriever.retrieve(query)
print(response)


[NodeWithScore(node=TextNode(id_='390cf0b6-de06-4ca8-a972-46cd42be5f65', embedding=None, metadata={'sql_query': "SELECT course_title, course_description, course_credits \nFROM class_data \nWHERE department_abbrev = 'UDP' OR department_abbrev = 'URBAN' OR department_abbrev = 'DP' OR department_abbrev = 'URBDP' OR department_abbrev = 'URB'", 'result': [], 'col_keys': ['course_title', 'course_description', 'course_credits']}, excluded_embed_metadata_keys=['sql_query', 'result', 'col_keys'], excluded_llm_metadata_keys=['sql_query', 'result', 'col_keys'], relationships={}, metadata_template='{key}: {value}', metadata_separator='\n', text='[]', mimetype='text/plain', start_char_idx=None, end_char_idx=None, metadata_seperator='\n', text_template='{metadata_str}\n\n{content}'), score=None)]


In [26]:
print(response[0].get_content())
print(response[0].get_score())
print(response[0])
print(response[0].metadata)
print(type(response[0]))

[]
0.0
Node ID: 390cf0b6-de06-4ca8-a972-46cd42be5f65
Text: []
Score: None

{'sql_query': "SELECT course_title, course_description, course_credits \nFROM class_data \nWHERE department_abbrev = 'UDP' OR department_abbrev = 'URBAN' OR department_abbrev = 'DP' OR department_abbrev = 'URBDP' OR department_abbrev = 'URB'", 'result': [], 'col_keys': ['course_title', 'course_description', 'course_credits']}
<class 'llama_index.core.schema.NodeWithScore'>


In [30]:
if not response[0].metadata['result']:
    print("Condition met: result is falsy")


Condition met: result is falsy


In [32]:
response[0].metadata['sql_query'].splitlines()[0]

'SELECT course_title, course_description, course_credits '

In [40]:
sql = "SELECT course_title, course_description, course_credits /nFROM class_data /nWHERE department_abbrev = 'UDP' OR department_abbrev = 'URBAN' OR department_abbrev = 'DP' OR department_abbrev = 'URBDP' OR department_abbrev = 'URB AND test'"

for line_num, line in enumerate(sql.splitlines()):
    if 'WHERE' in line:
        where_clause = line
        line_num = line_num

        new_clause = where_clause.split('nothing')[:-1]
        split_sql = sql.splitlines()
        split_sql[line_num] = new_clause
        sql = split_sql

print(sql)
print(response[0].metadata['sql_query'])

[[]]
SELECT course_title, course_description, course_credits 
FROM class_data 
WHERE department_abbrev = 'UDP' OR department_abbrev = 'URBAN' OR department_abbrev = 'DP' OR department_abbrev = 'URBDP' OR department_abbrev = 'URB'
