In [1]:

import sqlite3
import pandas as pd
import numpy as np
import openai
from dotenv import load_dotenv, find_dotenv
import os
import ast
load_dotenv(find_dotenv())

openai.api_key = os.environ.get("OPENAI_API_KEY")


In [2]:
def call_gpt(messages, model="gpt-4"):
    """Generates the intro script for the episode"""
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
    )
    r = response['choices'][0]['message']['content']
    return r


In [18]:
#load the sqllite database
conn = sqlite3.connect('chinook.db')

#view the tables in the database
tables = pd.read_sql("""SELECT *
                        FROM sqlite_master
                        WHERE type='table';""", conn)
tables

Unnamed: 0,type,name,tbl_name,rootpage,sql
0,table,albums,albums,2,"CREATE TABLE ""albums""\r\n(\r\n [AlbumId] IN..."
1,table,sqlite_sequence,sqlite_sequence,3,"CREATE TABLE sqlite_sequence(name,seq)"
2,table,artists,artists,4,"CREATE TABLE ""artists""\r\n(\r\n [ArtistId] ..."
3,table,customers,customers,5,"CREATE TABLE ""customers""\r\n(\r\n [Customer..."
4,table,employees,employees,8,"CREATE TABLE ""employees""\r\n(\r\n [Employee..."
5,table,genres,genres,10,"CREATE TABLE ""genres""\r\n(\r\n [GenreId] IN..."
6,table,invoices,invoices,11,"CREATE TABLE ""invoices""\r\n(\r\n [InvoiceId..."
7,table,invoice_items,invoice_items,13,"CREATE TABLE ""invoice_items""\r\n(\r\n [Invo..."
8,table,media_types,media_types,15,"CREATE TABLE ""media_types""\r\n(\r\n [MediaT..."
9,table,playlists,playlists,16,"CREATE TABLE ""playlists""\r\n(\r\n [Playlist..."


In [19]:
#manually created data dictionary
table_dict = {'albums': '1 row for each album and its title, FK to artists',
 'artists': '1 row for each artist and their name',
 'customers': '1 row for each customers and their demographic information',
 'employees': '1 row for each employee and their demographic and organizational information',
 'genres': '1 row for each genre and its name',
 'invoices': '1 row for each invoice and its associated customer and billing information',
 'invoice_items': '1 row for each invoice line item, containing the associated invoice, track and quantity of tracks purchased',
 'media_types': '1 row for each media type and its name',
 'playlists': '1 row for each playlist and its name',
 'playlist_track': 'join table between playlists and tracks',
 'tracks': '1 row for each track and its album, media type, genre, etc.',
}

In [20]:
# formatting for prompt
table_string = "".join([f"`{k}` : {v}\n" for k, v in table_dict.items()])
print(table_string)

`albums` : 1 row for each album and its title, FK to artists
`artists` : 1 row for each artist and their name
`customers` : 1 row for each customers and their demographic information
`employees` : 1 row for each employee and their demographic and organizational information
`genres` : 1 row for each genre and its name
`invoices` : 1 row for each invoice and its associated customer and billing information
`invoice_items` : 1 row for each invoice line item, containing the associated invoice, track and quantity of tracks purchased
`media_types` : 1 row for each media type and its name
`playlists` : 1 row for each playlist and its name
`playlist_track` : join table between playlists and tracks
`tracks` : 1 row for each track and its album, media type, genre, etc.



In [21]:
#create a dict where the key is the table name and the value is a concise table schema
#loop through the tables, and query the sqllite DB to get the schema, then format it concisely:

table_schema_dict = {}
for table in tables['name']:
    table_schema = pd.read_sql(f"PRAGMA table_info({table})", conn)
    table_schema_dict[table] = "".join([f"{row['name']} : {row['type']}\n" for _, row in table_schema.iterrows()])

In [24]:
#view an example
print(table_schema_dict['customers'])

CustomerId : INTEGER
FirstName : NVARCHAR(40)
LastName : NVARCHAR(20)
Company : NVARCHAR(80)
Address : NVARCHAR(70)
City : NVARCHAR(40)
State : NVARCHAR(40)
Country : NVARCHAR(40)
PostalCode : NVARCHAR(10)
Phone : NVARCHAR(24)
Fax : NVARCHAR(24)
Email : NVARCHAR(60)
SupportRepId : INTEGER



In [25]:
#this prompt is used to pick which tables are relevant to the question
first_prompt = """
Given the below question, which of the following tables would you need to query to answer the question? Return your answer as a valid Python list of table names. If the question is not answerable with the given tables, return an empty list.

Question: {question}

Tables:
{table_string}
"""

In [26]:
question = "What is the most popular genre of music in the USA?"

messages = [
    {"role": "system", "content": "You are an expert Database Admin, and you are helping a client with a question about their database."},
    {"role": "user", "content": first_prompt.format(question=question, table_string=table_string)},
]

str_tables = call_gpt(messages)

In [27]:
str_tables

"['genres', 'tracks', 'invoice_items', 'invoices', 'customers']"

In [29]:
#safely evaluate the string as a python object
try:
    query_tables = ast.literal_eval(str_tables)
    assert type(query_tables) == list
    print(query_tables)
except:
    print(f"Model returned invalid python object, try again: {str_tables}")

['genres', 'tracks', 'invoice_items', 'invoices', 'customers']


In [35]:
#next, we need to generate the SQL query to answer the question given the specific columns in the preselected tables

query_prompt = """
Given the below question, and the columns in each of the tables you selected, write a SQL query to answer the question. 
Return your answer as a valid SQL query for SQLite. 
Do not include any formatting or other characters, just the SQL Code ending in a semicolon.

Question: {question}

Tables and their columns:

{table_schema}

SQL Query:
"""

In [36]:
table_schema = ''
for table in query_tables:
    table_schema += f"TABLE: {table}:\n{table_schema_dict[table]}\n"
print(table_schema)

TABLE: genres:
GenreId : INTEGER
Name : NVARCHAR(120)

TABLE: tracks:
TrackId : INTEGER
Name : NVARCHAR(200)
AlbumId : INTEGER
MediaTypeId : INTEGER
GenreId : INTEGER
Composer : NVARCHAR(220)
Milliseconds : INTEGER
Bytes : INTEGER
UnitPrice : NUMERIC(10,2)

TABLE: invoice_items:
InvoiceLineId : INTEGER
InvoiceId : INTEGER
TrackId : INTEGER
UnitPrice : NUMERIC(10,2)
Quantity : INTEGER

TABLE: invoices:
InvoiceId : INTEGER
CustomerId : INTEGER
InvoiceDate : DATETIME
BillingAddress : NVARCHAR(70)
BillingCity : NVARCHAR(40)
BillingState : NVARCHAR(40)
BillingCountry : NVARCHAR(40)
BillingPostalCode : NVARCHAR(10)
Total : NUMERIC(10,2)

TABLE: customers:
CustomerId : INTEGER
FirstName : NVARCHAR(40)
LastName : NVARCHAR(20)
Company : NVARCHAR(80)
Address : NVARCHAR(70)
City : NVARCHAR(40)
State : NVARCHAR(40)
Country : NVARCHAR(40)
PostalCode : NVARCHAR(10)
Phone : NVARCHAR(24)
Fax : NVARCHAR(24)
Email : NVARCHAR(60)
SupportRepId : INTEGER




In [37]:
messages = [
    {'role': 'system', 'content': 'You are an expert Data Analyst, and you are helping a client write a SQL query to answer a question about their database.'},
    {'role': 'user', 'content': query_prompt.format(question=question, table_schema=table_schema)}
]

str_query = call_gpt(messages)
print(str_query)

SELECT g.Name AS Genre, SUM(ii.Quantity) AS Total_Sold
FROM genres g
JOIN tracks t ON g.GenreId = t.GenreId
JOIN invoice_items ii ON t.TrackId = ii.TrackId
JOIN invoices i ON ii.InvoiceId = i.InvoiceId
JOIN customers c ON i.CustomerId = c.CustomerId
WHERE c.Country = 'USA'
GROUP BY g.Name
ORDER BY Total_Sold DESC
LIMIT 1;


In [40]:
# run the query: 
try: 
    tabular_answer = pd.read_sql(str_query, conn)
    print(tabular_answer)
except:
    print(f"Model returned invalid SQL query, try again: {str_query}")

  Genre  Total_Sold
0  Rock         157


In [41]:
final_prompt = """
Given the below question, and the tabular data returned by your SQL query, write a short answer to the question, including an explanation of the query used to generate the answer.

Question: {question}

SQL Query: {str_query}

Answer: {tabular_answer}
"""

In [42]:
messages = [
    {'role': 'system', 'content': 'You are an expert Data Analyst, and you are helping a client write a short answer to a question about their database.'},
    {'role': 'user', 'content': final_prompt.format(question=question, str_query=str_query, tabular_answer=tabular_answer)}
]

str_answer = call_gpt(messages)
print(str_answer)

The most popular genre of music in the USA, according to the database, is Rock. The SQL query used to determine this joins five different tables: genres, tracks, invoice_items, invoices, and customers. The query first links each genre of music to its tracks, which are then linked to items on customer invoices. It filters by customers from the USA. Each sold item quantity is summed up by genre and then sorted in descending order. The top result, with a total sold of 157, is Rock.


# Other Future Considerations:
 - What if the question isn't answerable by the data? Table query will return an empty string - then it should prompt the user for clarification, or let them know we don't have data to answer their question.
 - What if the question isn't specific enough (ex: the above sums quantity, but that's not the only valid measure of popularity)? LLM can list out reasonable possibilities and prompt the user to choose.
 - Program should suggest alternative spellings (USA vs. U.S.A in the above query) - a solution here would depend on the specifics of the problem and dataset