In [1]:
import os

# Set both LangSmith tracing flags in this process's environment in one call.
os.environ.update(
    {
        "LANGSMITH_TRACING": "true",
        "LANGSMITH_TRACING_V2": "true",
    }
)

In [2]:
from langchain_community.utilities import SQLDatabase

# Connect to the SQLite file; the URI is relative, so it resolves against the
# directory the kernel was started in.
db = SQLDatabase.from_uri("sqlite:///../tmp/Chinook.db")

# Quick sanity check: dialect on the first line, usable table names on the second.
print(db.dialect, db.get_usable_table_names(), sep="\n")

# Bare last expression so the notebook shows the query result as the cell output.
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [None]:
from IPython.display import Markdown, display

# get_table_info() returns plain schema text, not Markdown. Put it inside a fenced
# `sql` block so the notebook shows it verbatim instead of reflowing the lines and
# treating characters such as `*` or `_` as emphasis markers.
table_info_md = "```sql\n" + db.get_table_info() + "\n```"
display(Markdown(table_info_md))

In [4]:
from langchain_google_vertexai import ChatVertexAI

# Chat model used by the pipeline nodes and the agent in this notebook.
# The model id is named so it can be swapped in one place.
GEMINI_MODEL = "gemini-2.0-flash-001"
llm = ChatVertexAI(model=GEMINI_MODEL)

## A basic query session

1. We first need to construct a query from natural language. To construct the query, we need to give the model the SQL dialect and the table information.
2.  We execute the query using the provided tool.
3.  We give the question, table information, and query result to the LLM and ask it to produce a natural-language answer.

The first implementation uses a chain so that we have fine-grained control over the execution flow.




In [5]:
# Step one.

from typing import TypedDict, Annotated
from langchain_core.prompts import PromptTemplate

# Prompt that turns a natural-language question into a SQL query.
# Template variables: {dialect} (SQL dialect name), {top_k} (default row limit),
# {table_info} (schema text the model may rely on) and {input} (the user's question).
q2q_template = PromptTemplate.from_template(
    """system

Given an input question, create a syntactically correct {dialect} query to run to help find the answer.

Unless the user specifies in the question a specific number of examples they wish to obtain, 
always limit your query to at most {top_k} results. 

You can order the results by a relevant column to return the most interesting examples in the database.


Never query for all the columns from a specific table, 
only ask for the few relevant columns given the question.


Pay attention to use only the column names that you can see in the schema description.
Be careful to not query for columns that do not exist. 
Also, pay attention to which column is in which table.


Only use the following tables:

{table_info}


Question: {input}

Query is: 
"""
)


# Shared graph state. It is never shown to the LLM for reasoning, so brief
# per-field notes are enough here.
State = TypedDict(
    "State",
    {
        "question": str,  # read by question_to_query and generate_answer
        "query": str,     # written by question_to_query, read by execute_query
        "result": str,    # written by execute_query, read by generate_answer
        "answer": str,    # written by generate_answer
    },
)

def question_to_query(state: State):
    """Generate a SQL query for the question, based on the database schema.

    Fills ``q2q_template`` with the database dialect, the schema text from
    ``db.get_table_info()``, a row limit of 10 and the question, then asks the
    model for structured output containing a single ``query`` string.

    Args:
        state: Graph state; only ``state['question']`` is read.

    Returns:
        A partial state update of the form ``{'query': <sql text>}``.

    Raises:
        ValueError: If the model's structured output is empty or has no ``query``.
    """
    question = state['question']
    prompt = q2q_template.invoke({
        "top_k": 10,
        "dialect": db.dialect,
        "table_info": db.get_table_info(),
        "input": question,
    })

    # Output schema for this turn. The Annotated description is passed to the model
    # along with the schema, so it must read correctly.
    class QueryResult(TypedDict):
        query: Annotated[str, ..., "the SQL query for the question"]

    structured_llm = llm.with_structured_output(QueryResult)
    resp = structured_llm.invoke(prompt)

    # Fail here with a clear message rather than letting a missing key surface
    # later as a KeyError when the next node reads state['query'].
    if not resp or not resp.get("query"):
        raise ValueError(f"Model returned no SQL query for question: {question!r}")

    # Return only the key this node owns in the graph state.
    return {"query": resp["query"]}

from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

def execute_query(state: State):
    """Run the SQL held in ``state['query']`` and return it as the ``result`` update.

    NOTE(review): the query text comes from the model and is executed as-is;
    consider a read-only database connection if that matters for this data.
    """
    sql_text = state['query']
    runner = QuerySQLDatabaseTool(db=db)
    output = runner.invoke(sql_text)
    return {'result': output}

def generate_answer(state: State):
    """Generate the answer to the original question from the query result.

    Builds a plain-text prompt from ``state['question']`` and ``state['result']``
    and sends it to the module-level ``llm``.

    Args:
        state: Graph state; ``question`` and ``result`` are read.

    Returns:
        A partial state update of the form ``{'answer': <model response content>}``.
    """
    # Adjacent string literals are concatenated with no separator, so every
    # sentence boundary needs its own trailing space or newline.
    prompt = (
        "The user asks the question:\n"
        f"{state['question']}\n"
        "We first convert the question into a query and execute the query against a database."
        " Here is the result from this query:\n"
        f"{state['result']}\n"
        "Could you please formulate an answer based on the question and query result? "
        "Please only show the short answer. Don't give context and where the result comes from."
    )
    response = llm.invoke(prompt)
    return {'answer': response.content}


In [6]:
# Finally, we build a chain to execute the code

from langgraph.graph import START, END, StateGraph

# Chain the three node functions in order; each node is registered under its
# function name, which is why the entry edge targets "question_to_query".
pipeline_steps = [question_to_query, execute_query, generate_answer]
builder = StateGraph(State)
builder.add_sequence(pipeline_steps)
builder.add_edge(START, "question_to_query")
graph = builder.compile()



In [7]:
# Run the pipeline once and print what each node contributed to the state.
employee_count_input = {'question': "How many employees do we have?"}
for update in graph.stream(employee_count_input, stream_mode="updates"):
    print(update)


{'question_to_query': {'query': 'SELECT count(*) FROM Employee'}}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': '8\n'}}


In [8]:
from langgraph.checkpoint.memory import MemorySaver

# Same three-node sequence as before, but compiled with a checkpointer and an
# interrupt placed before the node that runs the SQL.
builder1 = StateGraph(State).add_sequence(
    [question_to_query, execute_query, generate_answer]
)
builder1.set_entry_point("question_to_query")

memory = MemorySaver()
graph1 = builder1.compile(
    checkpointer=memory,
    interrupt_before=["execute_query"],
)

# Both stream calls share this config, so they address the same thread id.
config = {'configurable': {'thread_id': "sql_jj1"}}

# First pass: starts from the question and stops at the interrupt.
first_input = {'question': 'How many employees do we have?'}
for step in graph1.stream(first_input, config=config, stream_mode="updates"):
    print(step)

print("==========Continue the invocation after interrupt")

# Second pass: no new input (None); the same thread is streamed again.
for step in graph1.stream(None, config=config, stream_mode="updates"):
    print(step)



{'question_to_query': {'query': 'SELECT count(*) FROM Employee'}}
{'__interrupt__': ()}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': '8\n'}}


In [9]:

# System prompt for the SQL agent. {dialect} and {top_k} are filled in below.
sql_agent_prompt_text = """system

You are an agent designed to interact with a SQL database.

Given an input question, create a syntactically correct {dialect} query to run, 
then look at the results of the query and return the answer.

Unless the user specifies a specific number of examples they wish to obtain, 
always limit your query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for the relevant columns given 
the question.

You have access to tools for interacting with the database.

Only use the below tools. Only use the information returned by the below tools to 
construct your final answer.

You MUST double check your query before executing it. If you 
get an error while executing a query, rewrite the query and try again.


DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.


To start you should ALWAYS look at the tables in the database to see what you can query.

Do NOT skip this step.

Then you should query the schema of the most relevant tables.
"""

# Build the template once, then render it with the values for this database.
sql_agent_prompt_template = PromptTemplate.from_template(sql_agent_prompt_text)
sql_agent_prompt_values = dict(dialect=db.dialect, top_k=10)
sql_agent_prompt = sql_agent_prompt_template.invoke(sql_agent_prompt_values)



In [None]:
from langgraph.prebuilt import create_react_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Build the SQL tools from the toolkit and hand them to a ReAct-style agent.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

agent = create_react_agent(llm, tools)

question = "Which country's customers spent the most?"

# The rendered system prompt goes first, followed by the user's question.
conversation = [
    {'role': 'system', 'content': sql_agent_prompt.text},
    {'role': 'user', 'content': question},
]

for step in agent.stream({'messages': conversation}, stream_mode="values"):
    # Show only the most recent message of each streamed state.
    latest_message = step["messages"][-1]
    latest_message.pretty_print()
