### Understanding SQL Agents

In [None]:
import os

from dotenv import load_dotenv
from langchain.text_splitter import CharacterTextSplitter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_community.utilities import SQLDatabase

In [None]:
# Load variables from a .env file; override=True lets them replace values
# that are already present in the process environment.
load_dotenv(override=True)

# Fail early with a clear message when the API key is missing or empty.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("OPENAI_API_KEY environment variable is not set.")

# Chat model settings.
model_name, temperature, max_tokens = "gpt-3.5-turbo-0125", 0.7, 1500

llm = ChatOpenAI(
    openai_api_key=openai_api_key,
    model=model_name,
    max_tokens=max_tokens,
    temperature=temperature,
)

In [None]:
# Open the Chinook SQLite file that sits next to the notebook.
db_path = "sqlite:///./chinook.db"
db = SQLDatabase.from_uri(db_path)

# Show the SQL dialect, then the tables that can be queried.
for detail in (db.dialect, db.get_usable_table_names()):
    print(detail)

In [None]:
from langchain.chains import create_sql_query_chain

# Chain that asks the LLM to write a SQL query for a natural-language question.
chain = create_sql_query_chain(llm=llm, db=db)

In [None]:
# Ask the query-writing chain for SQL that answers the question.
question = "How many employees are there in the database?"
response = chain.invoke({"question": question})

print(f"Question: {question}")
print(f"Response: {response!s}")

In [None]:
question = "which country's customers have spent the most?"

# Build the query-writing chain again, then ask the new question.
chain = create_sql_query_chain(llm=llm, db=db)
response = chain.invoke(dict(question=question))

print("Question: " + question)
print("Response: " + str(response))

In [None]:
# Re-run the employee-count question so `response` holds its generated query.
question = "How many employees are there in the database?"
payload = {"question": question}
response = chain.invoke(payload)

print(f"Question: {question}")
print(f"Response: {response!s}")

In [None]:
# Execute the text held in `response` against the database and show the result.
# NOTE(review): presumably `response` is the SQL produced by the chain in the
# previous cell; confirm it carries no label prefix before running it.
db.run(response)

In [None]:
# Pretty-print the first prompt exposed by the current chain.
chain.get_prompts()[0].pretty_print()

In [None]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_core.tools import tool


@tool
def parse(query_string: str) -> str:
    """Strip a leading label such as ``SQLQuery:`` from a generated SQL string.

    The input may start with a label (for example ``QUERY: SELECT ...``).
    Only that leading label is removed; colons that appear later in the
    statement (time literals, bound parameters) are left untouched.

    Args:
        query_string: Raw text that contains a SQL statement, optionally
            prefixed with a label.

    Returns:
        The statement without the leading label and without surrounding
        whitespace.
    """
    # Imported locally so this cell does not rely on imports from other cells.
    import re

    # A label is a single identifier, optionally preceded by the word "SQL",
    # immediately followed by a colon: "QUERY:", "SQLQuery:", "SQL Query:".
    label = re.match(
        r"\s*(?:sql\s+)?[A-Za-z_]\w*:",
        query_string,
        flags=re.IGNORECASE,
    )
    if label is None:
        return query_string.strip()

    # Keep everything after the label, including any later colons.
    return query_string[label.end():].strip()

In [None]:
# Tools are runnables, so go through .invoke(); calling the tool object
# directly is deprecated in langchain-core.
parse.invoke("""
      SELECT COUNT("EmployeeId") AS EmployeeCount
        FROM Employee
      """)

In [None]:
# Same check with a "QUERY:" label in front of the statement. Use .invoke()
# because calling the tool object directly is deprecated in langchain-core.
parse.invoke("""
      QUERY: SELECT COUNT("EmployeeId") AS EmployeeCount
        FROM Employee
      """)

In [None]:
# Step 1 writes the SQL, step 2 cleans it with `parse`, step 3 runs it.
write_query = create_sql_query_chain(llm=llm, db=db)
execute_query = QuerySQLDatabaseTool(db=db)
chain = write_query | parse | execute_query

question = "How many employees are there in the database?"
response = chain.invoke({"question": question})

print(f"Question: {question}")
print(f"Response: {response!s}")

In [None]:
question = "which country's customers have spent the most?"

# Rebuild the write -> parse -> execute pipeline for this question.
write_query = create_sql_query_chain(llm=llm, db=db)
execute_query = QuerySQLDatabaseTool(db=db)
chain = write_query | parse | execute_query

response = chain.invoke(dict(question=question))

print("Question: " + question)
print("Response: " + str(response))

In [None]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

In [None]:
# Prompt that combines the question, the generated SQL and its result and
# asks the model for a concise answer.
answer_prompt = PromptTemplate.from_template(
    """
        Given the following user question, corresponding SQL Query, and SQL Result, answer the user question in a concise manner.
        
        User Question: {question}
        SQL Query: {query}
        SQL Result: {result}
        Answer: 
    """
)

# Pipeline: add `query` (generated SQL), add `result` (parsed and executed
# query), fill the prompt, call the model, and return plain text.
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | parse | execute_query
    )
    | answer_prompt
    | llm
    | StrOutputParser()
)

In [None]:
# Run the full question -> SQL -> result -> answer pipeline.
question = "How many employees are there in the database?"
response = chain.invoke({"question": question})

print(f"Question: {question}")
print(f"Response: {response!s}")

In [None]:
# Same pipeline, second question.
question = "which country's customers have spent the most?"
payload = dict(question=question)
response = chain.invoke(payload)

print("Question: " + question)
print("Response: " + str(response))

In [None]:
import ast
import re


def query_as_list(database, query):
    """Run *query* and return its distinct, cleaned text values as a flat list.

    Every truthy cell of the result set is collected, standalone numbers
    (digit runs delimited by word boundaries) are removed, and surrounding
    whitespace is trimmed.

    Args:
        database: Object whose ``run(query)`` returns the rows rendered as a
            Python-literal string (a list of tuples). A falsy return value,
            such as an empty string, is treated as "no rows".
        query: SQL statement to execute.

    Returns:
        Unique, non-empty cleaned strings in first-seen order; ``[]`` when
        the query yields no rows.
    """
    raw = database.run(query)
    if not raw:
        # ast.literal_eval("") raises, so handle an empty result up front.
        return []

    rows = ast.literal_eval(raw)
    cells = (cell for row in rows for cell in row if cell)
    cleaned = (re.sub(r"\b\d+\b", "", cell).strip() for cell in cells)

    # dict.fromkeys removes duplicates while keeping first-seen order, so the
    # list is identical on every run (set ordering of strings is not).
    # Values that became empty after cleaning are dropped.
    return list(dict.fromkeys(text for text in cleaned if text))

# Collect the artist names and album titles.
artists, albums = [
    query_as_list(db, sql)
    for sql in ("SELECT Name FROM Artist", "SELECT Title FROM Album")
]

In [None]:
# Show the artist entry at index 5 as a spot check.
artists[5]

In [None]:
# Show the first five album entries as a spot check.
albums[:5]

In [None]:
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain_community.vectorstores import FAISS

In [None]:
# Tool description text handed to create_retriever_tool.
description = """
    use to lookup values to filter on.
    input is an approximate spelling of the valid and proper nouns.
    Use the noun most similar to the input.
    If the input is not a valid noun, return an empty string.
"""

# Index every artist name and album title in a FAISS vector store.
embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
vector_database = FAISS.from_texts(
    texts=artists + albums,
    embedding=embedding_model,
)

# Similarity search configured with k=2.
retriever = vector_database.as_retriever(
    search_kwargs={"k": 2},
    search_type="similarity",
)

retriever_tool = create_retriever_tool(
    name="retriever",
    description=description,
    retriever=retriever,
)

In [None]:
# Look up a misspelled name through the retriever tool.
lookup_text = "Alis Chains"
response = retriever_tool.invoke(lookup_text)

print(f"Input: {lookup_text}")
print(f"Response: {response!s}")

In [None]:
# Look up a whole sentence instead of just the misspelled name.
lookup_text = "Do we have any artists by named Alis Chains"
response = retriever_tool.invoke(lookup_text)

# Print the text that was actually sent; the label used to show a different
# input ("Alis Chains") than the one passed to the tool.
print("Input:", lookup_text)
print("Response:", response)