### Understanding SQL Agents

In [1]:
import os

from dotenv import load_dotenv
from langchain.text_splitter import CharacterTextSplitter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_community.utilities import SQLDatabase

In [8]:
load_dotenv(override=True)

openai_api_key = os.getenv("OPENAI_API_KEY")

if not openai_api_key:
    raise ValueError("OPENAI_API_KEY environment variable is not set.")

temperature = 0.7
max_tokens = 1500
model_name = "gpt-3.5-turbo"

llm = ChatOpenAI(
    model=model_name,
    temperature=temperature,
    max_tokens=max_tokens,
    openai_api_key=openai_api_key
)

In [9]:
db_path = "sqlite:///./chinook.db"
db = SQLDatabase.from_uri(db_path)

print(db.dialect)
print(db.get_usable_table_names())

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [10]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(
    llm=llm,
    db=db,
)

In [11]:
question = "How many employees are there in the database?"

response = chain.invoke({
    "question": question
})

print("Question:", question)
print("Response:", response)

Question: How many employees are there in the database?
Response: SELECT COUNT(EmployeeId) AS TotalEmployees
FROM Employee;


In [6]:
question = "which country's customers have spent the most?"

chain = create_sql_query_chain(
    llm=llm,
    db=db,
)

response = chain.invoke({
    "question": question
})

print("Question:", question)
print("Response:", response)

Question: which country's customers have spent the most?
Response: ```sql
SELECT "Customer"."Country", SUM("Invoice"."Total") AS "TotalSpent"
FROM "Customer"
JOIN "Invoice" ON "Customer"."CustomerId" = "Invoice"."CustomerId"
GROUP BY "Customer"."Country"
ORDER BY "TotalSpent" DESC
LIMIT 5;
```


In [None]:
question = "How many employees are there in the database?"

response = chain.invoke({
    "question": question
})

print("Question:", question)
print("Response:", response)

In [None]:
db.run(response)

In [7]:
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [12]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_core.tools import tool


@tool
def parse(query_string: str) -> str:
    """Parse the SQL query string and return a human-readable explanation."""

    splitted_string = query_string.split(":")

    if (len(splitted_string) >= 2):
        query = splitted_string[1].strip()
    else:
        query = query_string

    return query.strip()

In [13]:
parse("""
      SELECT COUNT("EmployeeId") AS EmployeeCount
        FROM Employee
      """)

  parse("""


'SELECT COUNT("EmployeeId") AS EmployeeCount\n        FROM Employee'

In [14]:
parse("""
      ```sql```: SELECT COUNT("EmployeeId") AS EmployeeCount
        FROM Employee
      """)

'SELECT COUNT("EmployeeId") AS EmployeeCount\n        FROM Employee'

In [15]:
execute_query = QuerySQLDatabaseTool(db=db)
write_query = create_sql_query_chain(
    llm=llm,
    db=db,
)
chain = write_query | parse | execute_query
question = "How many employees are there in the database?"
response = chain.invoke({
    "question": question
})

print("Question:", question)
print("Response:", response)

Question: How many employees are there in the database?
Response: [(8,)]


In [16]:
question = "which country's customers have spent the most?"

execute_query = QuerySQLDatabaseTool(db=db)
write_query = create_sql_query_chain(
    llm=llm,
    db=db,
)
chain = write_query | parse | execute_query

response = chain.invoke({
    "question": question
})

print("Question:", question)
print("Response:", response)

Question: which country's customers have spent the most?
Response: [('USA', 523.06)]


In [18]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

In [19]:
answer_prompt = PromptTemplate.from_template(
    """
        Given the following user question, corresponding SQL Query, and SQL Result, answer the user question in a concise manner.
        
        User Question: {question}
        SQL Query: {query}
        SQL Result: {result}
        Answer: 
    """
)

chain = RunnablePassthrough \
    .assign(query=write_query) \
    .assign(result=itemgetter("query") | parse | execute_query) | \
    answer_prompt | llm | StrOutputParser()

In [20]:
question = "How many employees are there in the database?"

response = chain.invoke({
    "question": question
})

print("Question:", question)
print("Response:", response)

Question: How many employees are there in the database?
Response: There are 8 employees in the database.


In [22]:
question = "which country's customers have spent the most and give me their total spending?"

response = chain.invoke({
    "question": question
})

print("Question:", question)
print("Response:", response)

Question: which country's customers have spent the most and give me their total spending?
Response: The customers from the USA have spent the most, with a total spending of $523.06.


In [23]:
import ast
import re


def query_as_list(database, query):
    """
    Execute a SQL query and return the results as a list of dictionaries.
    """
    result = database.run(query)
    result = [el for sub in ast.literal_eval(result) for el in sub if el]
    result = [re.sub(r"\b\d+\b", "", string).strip() for string in result]

    return list(set(result))

artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")

In [24]:
artists[5]

'Lenny Kravitz'

In [25]:
albums[:5]

['Retrospective I (-)',
 'Os Cães Ladram Mas A Caravana Não Pára',
 'Bach: Violin Concertos',
 'The Last Night of the Proms',
 'Body Count']

In [26]:
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain_community.vectorstores import FAISS

In [27]:
vector_database = FAISS.from_texts(
    texts=artists + albums,
    embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
)

retriever = vector_database.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 2}
)

description = """
    use to lookup values to filter on.
    input is an approximate spelling of the valid and proper nouns.
    Use the noun most similar to the input.
    If the input is not a valid noun, return an empty string.
"""

retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="retriever",
    description=description,
)

In [28]:
response = retriever_tool.invoke("Alis Chains")

print("Input: Alis Chains")
print("Response:", response)

Input: Alis Chains
Response: Alice In Chains

Aisha Duo


In [29]:
response = retriever_tool.invoke("Do we have any artists by named Alis Chains")

print("Input: Alis Chains")
print("Response:", response)

Input: Alis Chains
Response: Alice In Chains

Alanis Morissette
