In [2]:
from dotenv import load_dotenv, find_dotenv
from langchain_community.utilities import SQLDatabase

load_dotenv(find_dotenv())

True

In [3]:
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [4]:

examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "how much of the total revenue was related to black sabbath track sales?",
        "query": "SELECT SUM(Total) FROM Invoice WHERE InvoiceId IN (SELECT InvoiceId FROM InvoiceLine WHERE TrackId IN (SELECT TrackId FROM Track WHERE AlbumId IN (SELECT AlbumId FROM Album WHERE ArtistId IN (SELECT ArtistId FROM Artist WHERE Name = 'Black Sabbath'))));",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of invoices.",
        "query": "SELECT COUNT(*) FROM Invoice;",
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    },
    {
        "input": "Which albums are from the year 2000?",
        "query": "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';",
    },
    {
        "input": "How many employees are there",
        "query": 'SELECT COUNT(*) FROM "Employee"',
    },
]

In [5]:
from langchain_community.vectorstores import Annoy
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

#Consider changing to Milvus vector database

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    Annoy,
    k=5,
    input_keys=["input"],
)

In [6]:
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix="",
)

In [7]:
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

In [8]:
import ast
import re


def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return list(set(res))


artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
artists[:50]

['Academy of St. Martin in the Fields & Sir Neville Marriner',
 'Santana Feat. Eric Clapton',
 'Black Eyed Peas',
 'Habib Koité and Bamada',
 'Cássia Eller',
 'Baby Consuelo',
 'Zeca Pagodinho',
 'Buddy Guy',
 'Pedro Luís E A Parede',
 'Lenny Kravitz',
 'The  Cellists of The Berlin Philharmonic',
 'Azymuth',
 'Kid Abelha',
 'Jaguares',
 'Sandra De Sá',
 'Wilhelm Kempff',
 'Academy of St. Martin in the Fields Chamber Ensemble & Sir Neville Marriner',
 "The King's Singers",
 'Dread Zeppelin',
 'The Office',
 'Otto Klemperer & Philharmonia Orchestra',
 "Charles Dutoit & L'Orchestre Symphonique de Montréal",
 'Eric Clapton',
 'Adrian Leaper & Doreen de Feis',
 'The Flaming Lips',
 'Van Halen',
 'Pink Floyd',
 'Instituto',
 'Funk Como Le Gusta',
 'Mundo Livre S/A',
 'Roger Norrington, London Classical Players',
 'Hermeto Pascoal',
 'Milton Nascimento',
 'Iron Maiden',
 "Guns N' Roses",
 'Nando Reis',
 'Velvet Revolver',
 'Aaron Goldberg',
 'Page & Plant',
 'Heroes',
 'Berliner Philharmonike

In [9]:
from langchain.agents.agent_toolkits import create_retriever_tool

vector_db = Annoy.from_texts(artists + albums, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 5})
description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
valid proper nouns. Use the noun most similar to the search."""
retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)

In [10]:
from langchain.agents import AgentType
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

agent = create_sql_agent(
    llm=llm,
    db=db,
    prompt=full_prompt,
    verbose=True,
    agent_type="openai-tools",
)

agent.invoke(
    "What was the total revenue from all sales?"
)

agent.invoke(
    "which album contributed the most to the total revenue?"
)

agent.invoke(
    "Which country is most likely to listen to rock music?"
)

agent.invoke(
    "How many albums does blacksabbath have?"
)

# agent.invoke(
#     "How many albums does alis in chain have?"
# )

agent.invoke("How many albums where sold in the usa?")

# agent.invoke("How many albums where sold in USA?")

In [12]:
agent.invoke("how much of the total revenue was related ac/dc track sales?")
agent.invoke("how much of the total revenue was related led zeppelin track sales?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `SELECT SUM(UnitPrice * Quantity) AS Revenue FROM InvoiceLine WHERE TrackId IN (SELECT TrackId FROM Track WHERE AlbumId IN (SELECT AlbumId FROM Album WHERE ArtistId IN (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC')))`


[0m[36;1m[1;3m[(15.84,)][0m[32;1m[1;3mThe total revenue related to AC/DC track sales is $15.84.[0m

[1m> Finished chain.[0m


[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `SELECT SUM(Total) FROM Invoice WHERE InvoiceId IN (SELECT InvoiceId FROM InvoiceLine WHERE TrackId IN (SELECT TrackId FROM Track WHERE AlbumId IN (SELECT AlbumId FROM Album WHERE ArtistId IN (SELECT ArtistId FROM Artist WHERE Name = 'Led Zeppelin'))))`


[0m[36;1m[1;3m[(204.93,)][0m[32;1m[1;3mThe total revenue related to Led Zeppelin track sales is $204.93.[0m

[1m> Finished chain.[0m


{'input': 'how much of the total revenue was related led zeppelin track sales?',
 'output': 'The total revenue related to Led Zeppelin track sales is $204.93.'}