In [1]:
import re
import sqlite3

import duckdb
import pandas as pd
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

In [2]:
# Load environment variables (presumably including the OpenAI API key that
# ChatOpenAI needs) from the repository-level .env file. The returned bool is
# this cell's displayed output.
load_dotenv("../.env")

True

In [3]:
# temperature=0 minimizes sampling randomness, so the generated SQL (and the
# outputs recorded in this notebook) are as repeatable as the API allows on
# Restart & Run All.
llm = ChatOpenAI(temperature=0)

In [4]:
# Single source of truth for the database location; both connections below
# are derived from it so they can never point at different files.
CHINOOK_DB_PATH = "../data/Chinook.db"

# Both handles execute SQL written by the LLM, so both are opened read-only:
# a hallucinated DROP/DELETE/UPDATE cannot modify the database. mode=ro also
# makes a wrong path fail immediately instead of silently creating an empty DB.
sqlite_alchemy_uri = f"sqlite:///file:{CHINOOK_DB_PATH}?mode=ro&uri=true"
sqlite_uri = f"file:{CHINOOK_DB_PATH}?mode=ro"

db = SQLDatabase.from_uri(sqlite_alchemy_uri)
sqlite_con = sqlite3.connect(sqlite_uri, uri=True)

In [5]:
# NOTE(review): `con` is not used by any cell shown here; queries go through
# `sqlite_con` and `db` instead. Consider dropping this cell (and the duckdb
# import) if nothing downstream needs it.
con = duckdb.connect(database=":memory:")

## Ask for a SQL query based on the user's question


In [6]:
# Prompt for step 1: turn the user's question into a single SQL query.
# The model's output is executed verbatim (pd.read_sql_query / db.run), so it
# is told explicitly to emit bare SQL with no prose or Markdown fences.
sql_template = """Based on the table schema below, write a SQL query that would answer the user's question.
Respond with the SQL query only: no explanation and no Markdown code fences.
{schema}

Question: {question}
SQL Query:"""

prompt = ChatPromptTemplate.from_template(sql_template)

In [7]:
def get_schema(_):
    """Describe the tables in ``db`` for the prompt's ``{schema}`` slot.

    The single positional argument is whatever the chain passes in (its input
    dict); it is deliberately ignored.
    """
    table_info = db.get_table_info()
    return table_info

In [8]:
def _strip_sql_fences(text: str) -> str:
    """Return the SQL inside the first Markdown code fence of ``text``, if any.

    Chat models often wrap generated SQL in a fence such as ```sql ... ```.
    The fence markers are not valid SQL, so the raw text would fail in
    ``pd.read_sql_query`` / ``db.run``. Text containing no fence is returned
    unchanged, so plain SQL output passes through untouched.
    """
    # Optional language tag (e.g. "sql") must be followed by a newline so that
    # a tagless one-liner like ```SELECT 1``` keeps its first keyword.
    match = re.search(r"```(?:[A-Za-z]+\n)?\s*(.*?)\s*```", text, re.DOTALL)
    return match.group(1) if match else text


# Step 1 chain: {"question": ...} -> generated SQL string.
# get_schema fills the prompt's {schema} slot on every invoke.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    # Stop sequence presumably keeps the model from inventing a result section
    # after the query.
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
    | _strip_sql_fences  # LCEL coerces a plain function into a RunnableLambda
)

In [18]:
user_question = "what are the top 5 artists by sales?"
# Step 1 only: the chain returns the generated SQL text, not its result.
sql_chain_result = sql_chain.invoke(dict(question=user_question))
sql_chain_result

'SELECT ar.Name AS Artist, SUM(il.UnitPrice * il.Quantity) AS TotalSales\nFROM Artist ar\nJOIN Album al ON ar.ArtistId = al.ArtistId\nJOIN Track t ON al.AlbumId = t.AlbumId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY ar.ArtistId\nORDER BY TotalSales DESC\nLIMIT 5;'

In [19]:
# Run the generated SQL through the raw sqlite3 connection and preview it.
# NOTE(review): this executes LLM-written SQL verbatim; make sure `sqlite_con`
# cannot write to the database.
df = pd.read_sql_query(sql=sql_chain_result, con=sqlite_con)
df.head()

Unnamed: 0,Artist,TotalSales
0,Iron Maiden,138.6
1,U2,105.93
2,Metallica,90.09
3,Led Zeppelin,86.13
4,Lost,81.59


## Ask for a natural-language response based on the user's question


In [20]:
# Prompt for step 2: turn the executed query's result into a plain-English
# answer. The instruction now names every input the model receives, including
# the SQL response it is supposed to interpret.
response_template = """Based on the table schema, question, SQL query, and SQL response below, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""

prompt_response = ChatPromptTemplate.from_template(template=response_template)

In [21]:
def run_query(query):
    """Execute ``query`` against ``db`` and return whatever ``db.run`` produces."""
    result = db.run(query)
    return result

In [22]:
def _response_from_query(inputs):
    """Execute the SQL that the previous step stored under ``"query"``."""
    return run_query(inputs["query"])


# End-to-end chain: question -> SQL (sql_chain) -> schema + query result ->
# natural-language answer.
full_chain = (
    RunnablePassthrough.assign(query=sql_chain)
    .assign(schema=get_schema, response=_response_from_query)
    | prompt_response
    | llm.bind(stop=["\nNatural Language Response:"])
    | StrOutputParser()
)

In [25]:
user_question = "how many albums are there in the database?"
# Full pipeline: the displayed output is the model's prose answer.
full_chain.invoke(dict(question=user_question))

'There are 347 albums in the database.'