In [25]:
%%capture --no-stderr
%pip install --upgrade --quiet langchain langchain-community langchain-openai faiss-cpu

In [26]:
import getpass
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

# Switch on LangChain's V2 tracing through its environment flag.
os.environ["LANGCHAIN_TRACING_V2"] = "true"

# Pick up any variables defined in a local .env file before prompting.
load_dotenv()


def _require_secret(name: str, prompt: str) -> None:
    """Ask for the secret `name` interactively unless it is already exported.

    Args:
        name: Environment variable that must end up being set.
        prompt: Text shown by the hidden-input prompt.
    """
    if name in os.environ:
        return
    os.environ[name] = getpass.getpass(prompt)


_require_secret("OPENAI_API_KEY", "Enter your OpenAI API key: ")
_require_secret("LANGCHAIN_API_KEY", "Enter your LangChain API key: ")

# Chat model shared by the cells below.
llm = ChatOpenAI(model="gpt-4o-mini")

In [27]:
from langchain_community.utilities import SQLDatabase

from pathlib import Path

# SQLite creates a brand-new, empty database when the target file is missing.
# That would only surface later as empty table info or "no such table" errors,
# so check for the file up front and fail with a clear message instead.
CHINOOK_DB_PATH = Path("Chinook.db")  # relative to the notebook's working directory
if not CHINOOK_DB_PATH.is_file():
    raise FileNotFoundError(
        f"{CHINOOK_DB_PATH} was not found in {Path.cwd()}; "
        "place the Chinook sample database next to this notebook first."
    )

db = SQLDatabase.from_uri(f"sqlite:///{CHINOOK_DB_PATH}")

# Quick sanity checks (uncomment to inspect the connection):
# print(db.dialect)
# print(db.get_usable_table_names())
# db.run("SELECT * FROM Artist LIMIT 10;")

In [28]:
from typing_extensions import TypedDict

# Schema for the state dictionary that is passed through the pipeline.
class State(TypedDict):
    """State shared by the pipeline steps; each step contributes one key."""
    question: str  # the user's natural-language question (pipeline input)
    query: str  # SQL text returned by write_query
    result: str  # output of running the query, returned by execute_query
    answer: str  # final natural-language answer, returned by generate_answer

In [29]:
from langchain_core.prompts import ChatPromptTemplate

# System instructions for the SQL-writing model. The {dialect}, {top_k} and
# {table_info} placeholders are filled in when the template is invoked.
# NOTE(review): "only ask for a the few relevant columns" in the text below
# looks like a typo for "the few"; it is left untouched here because the
# string is sent to the model verbatim.
system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}
"""

# The human turn carries only the question, via the {input} placeholder.
user_prompt = "Question: {input}" 

# Two-message chat template: system instructions followed by the question.
query_prompt_template = ChatPromptTemplate(
    [("system", system_message),
    ("human", user_prompt)],
)

# Print every message of the template to double-check its placeholders.
for message in query_prompt_template.messages:
    message.pretty_print()



Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
[33;1m[1;3m{table_info}[0m


Question: [33;1m[1;3m{input}[0m


In [30]:
# convert prompt to query using llm itself
from typing_extensions import Annotated

# Structured-output schema for the model: a dictionary holding only the SQL
# string, so the reply carries no extra text such as "The query is: ...".
# NOTE(review): documented with comments rather than a docstring on purpose;
# a docstring on this class may be forwarded to the model as part of the
# schema description — confirm before adding one.
class QueryState(TypedDict):
    # The string inside Annotated is the description attached to the field.
    query: Annotated[str, ..., "A syntactically correct SQL query to run."]

# 1. Write the query: fill in the prompt and let the model produce the SQL.
def write_query(state: State):
    """Turn the question stored in ``state`` into a SQL query via the LLM.

    The query prompt template is filled with the database dialect, a cap of
    10 results, the table info and the question; the model is then asked for
    output shaped like ``QueryState``.

    Args:
        state: Pipeline state; only ``state["question"]`` is read.

    Returns:
        The structured model output, e.g. ``{"query": "SELECT ..."}``.
    """
    template_values = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    # The template is invoked with a single mapping, not keyword arguments.
    filled_prompt = query_prompt_template.invoke(template_values)
    # Constrain the model's reply to the QueryState schema.
    query_writer = llm.with_structured_output(QueryState)
    return query_writer.invoke(filled_prompt)


write_query({"question": "How many tracks are there in the database?"})

{'query': 'SELECT COUNT(*) AS TrackCount FROM Track;'}

In [None]:
# 2. Execute the query: run the generated SQL against the database.
# `QuerySQLDataBaseTool` is the deprecated spelling of this tool; prefer the
# current name and fall back only on releases that do not ship it yet.
try:
    from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
except ImportError:
    from langchain_community.tools.sql_database.tool import (
        QuerySQLDataBaseTool as QuerySQLDatabaseTool,
    )


def execute_query(state: State):
    """Run the SQL held in ``state["query"]`` and return its result.

    Args:
        state: Pipeline state; only ``state["query"]`` is read.

    Returns:
        ``{"result": ...}`` with whatever the SQL tool returns for the query
        (the sample run in this notebook yields a string such as ``'[(3503,)]'``).
    """
    # NOTE(security): the SQL text comes straight from the model and is
    # executed as-is. Use a database connection with the narrowest permissions
    # possible (ideally read-only) before exposing this to untrusted questions.
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}


execute_query({'query': 'SELECT COUNT(*) as TotalTracks FROM Track;'})

{'result': '[(3503,)]'}

In [38]:
# 3. Generate the answer: let the model phrase the SQL result for the user.
def generate_answer(state: State):
    """Ask the LLM to answer the question using the query and its result.

    Args:
        state: Pipeline state; ``question``, ``query`` and ``result`` are read.

    Returns:
        ``{"answer": ...}`` holding the text content of the model's reply.
    """
    instructions = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question."
    )
    context_lines = [
        f'Question: {state["question"]}',
        f'SQL Query: {state["query"]}',
        f'SQL Result: {state["result"]}',
    ]
    # A blank line separates the instructions from the three context lines.
    prompt = instructions + "\n\n" + "\n".join(context_lines)
    return {"answer": llm.invoke(prompt).content}

In [None]:
from langgraph.graph import START, StateGraph
# Assemble the LangGraph pipeline: the three steps are registered as a
# sequence, in the order listed.
graph_builder = StateGraph(State)
graph_builder.add_sequence([write_query, execute_query, generate_answer])

# Entry point: tell the graph to begin at the write_query step.
graph_builder.add_edge(START, "write_query")

# Compile the builder into the runnable graph used by the cells below.
graph = graph_builder.compile()

In [None]:
# Run the whole pipeline on a sample question. With stream_mode="updates",
# each iteration yields a dict keyed by the step that just ran, holding the
# value that step returned (see the printed output below).
for step in graph.stream(
    {"question": "How many employees are there?"}, stream_mode="updates"
):
    print(step)

{'write_query': {'query': 'SELECT COUNT(*) AS EmployeeCount FROM Employee;'}}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': 'There are 8 employees.'}}


In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles

# Draw the compiled graph as an image.
# NOTE(review): PNG rendering appears to rely on a remote Mermaid renderer by
# default (see MermaidDrawMethod), so it needs network access — confirm.
# When rendering fails, report why and fall back to the Mermaid source text
# instead of silently showing nothing.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as render_error:
    print(f"Could not render the graph as a PNG: {render_error!r}")
    print(graph.get_graph().draw_mermaid())


KeyboardInterrupt: 