In [5]:
%%capture --no-stderr
%pip install --upgrade --quiet langchain langchain-community langchain-openai faiss-cpu

In [6]:
import getpass
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

# Enable LangChain tracing for this kernel (LANGCHAIN_TRACING_V2 flag).
os.environ["LANGCHAIN_TRACING_V2"] = "true"

# Load variables through python-dotenv first, then interactively ask for
# whichever API key is still absent so that no secret is hard-coded here.
load_dotenv()
for env_var, prompt_text in (
    ("OPENAI_API_KEY", "Enter your OpenAI API key: "),
    ("LANGCHAIN_API_KEY", "Enter your LangChain API key: "),
):
    if env_var not in os.environ:
        os.environ[env_var] = getpass.getpass(prompt_text)

# Chat model shared by the cells below.
llm = ChatOpenAI(model="gpt-4o-mini")

In [7]:
from langchain_community.utilities import SQLDatabase

# Connect to the Chinook SQLite database through a relative file URI.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# Sanity check: print the dialect and the usable table names on separate
# lines, then display a small sample of the Artist table as the cell output.
print(db.dialect, db.get_usable_table_names(), sep="\n")
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [8]:
from typing_extensions import TypedDict

class State(TypedDict):
    """Dictionary schema shared by the question-answering steps in this notebook."""

    question: str  # the user's question; read by write_query
    query: str  # SQL text; produced by write_query and read by execute_query
    result: str  # value stored under "result" by execute_query
    answer: str  # not written by any cell shown here; presumably the final answer text (confirm)

In [9]:
from langchain_core.prompts import ChatPromptTemplate

# System instructions for the query-writing model. {dialect}, {top_k} and
# {table_info} are placeholders that get filled in when the prompt is used.
system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}
"""

# Human turn: carries the user's question through the {input} placeholder.
user_prompt = "Question: {input}"

# Two-message chat prompt: the system instructions followed by the question.
query_prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", system_message),
        ("human", user_prompt),
    ]
)

# Show each message template so the placeholders can be checked by eye.
for message in query_prompt_template.messages:
    message.pretty_print()



Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
[33;1m[1;3m{table_info}[0m


Question: [33;1m[1;3m{input}[0m


***TODO: I still need to understand the cells below.***

In [None]:
from typing_extensions import Annotated

class QueryState(TypedDict):  # schema for the LLM's structured output: a dictionary with a single "query" key
    # Annotated[...] carries, in order: the value type (str), `...`, and a
    # human-readable description of the field.
    # NOTE(review): `...` presumably marks the field as required (no default) — confirm.
    query: Annotated[str, ..., "A syntactically correct SQL query to run."]

# Fill in the query prompt and ask the model for a SQL query.
def write_query(state: State):
    """Generate a SQL query intended to answer ``state["question"]``.

    The prompt placeholders are filled with the database dialect, a limit of
    10 rows, the table information reported by ``db`` and the question. The
    model is asked for output structured as ``QueryState``.

    Args:
        state: Mapping that must contain the key ``"question"``.

    Returns:
        The structured model output, e.g. ``{"query": "SELECT ..."}``.
    """
    # Use invoke() rather than format(): format() flattens the chat prompt
    # into one plain string, so the system instructions would reach the model
    # as part of a single human message. invoke() keeps the system and human
    # messages separate.
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = llm.with_structured_output(QueryState)
    return structured_llm.invoke(prompt)

write_query({"question": "How many tracks are there in the database?"})

{'query': 'SELECT COUNT(*) as TotalTracks FROM Track;'}

In [None]:
#execute the query

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

def execute_query(state: State):
    """Run ``state["query"]`` against ``db`` and return the outcome.

    The SQL text is handed to the query tool exactly as given; this function
    performs no validation of its own.

    Args:
        state: Mapping that must contain the key ``"query"``.

    Returns:
        A one-key dictionary ``{"result": <tool output>}``.
    """
    sql_tool = QuerySQLDataBaseTool(db=db)
    query_output = sql_tool.invoke(state["query"])
    return {"result": query_output}

execute_query({'query': 'SELECT COUNT(*) as TotalTracks FROM Track;'})

{'result': '[(3503,)]'}