# An agent for interacting with a SQL database

In this tutorial, we will walk through how to build an agent that can answer questions about a SQL database. 

At a high level, the agent will:
1. Fetch the available tables from the database
2. Decide which tables are relevant to the question
3. Fetch the DDL for the relevant tables
4. Generate a query based on the question and information from the DDL
5. Double-check the query for common mistakes using an LLM
6. Execute the query and return the results
7. Correct mistakes surfaced by the database engine until the query is successful
8. Formulate a response based on the results

## Setup

First, let's install the required packages and set our API keys.

In [None]:
%%capture --no-stderr
%pip install -U langgraph langchain_community

In [None]:
!pip install -U langchain
!pip install -qU langchain-groq langchain-openai python-dotenv psycopg2-binary



In [None]:
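import os

from dotenv import load_dotenv

# Load API keys and the database connection string from a local .env file
load_dotenv()

# Variables this notebook expects to find in .env:
# os.environ["OPENAI_API_KEY"]
# os.environ["LANGCHAIN_API_KEY"]
# os.environ["GROQ_API_KEY"]
# os.environ["PGPASSWORD"]
# os.environ["DATABASE_URI"]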

os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "llama-agent"
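
The commented lines above name the variables this notebook reads from `.env`. Below is a minimal sketch that reports any missing ones before the first model or database call. The required set is an assumption: the database URI, the Groq key used by `ChatGroq`, and the LangSmith key needed because tracing is switched on. Adjust it to your setup.

```python
import os
from collections.abc import Mapping, Sequence
from typing import Optional

# Assumed minimum set for this notebook; adjust to your own .env
REQUIRED_ENV_VARS = ("DATABASE_URI", "GROQ_API_KEY", "LANGCHAIN_API_KEY")


def missing_env_vars(
    required: Sequence[str] = REQUIRED_ENV_VARS,
    env: Optional[Mapping[str, str]] = None,
) -> list[str]:
    """Return the names in `required` that are unset or empty in `env` (default: os.environ)."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]


missing = missing_env_vars()
if missing:
    print(f"Missing environment variables: {', '.join(missing)}")
```

The sketch only prints a warning, so the notebook still runs cell by cell. Raising an exception instead would make a missing key fail fast.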

<div class="admonition tip">
    <p class="admonition-title">Set up <a href="https://smith.langchain.com">LangSmith</a> for LangGraph development</p>
    <p style="padding-top: 5px;">
        Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph — read more about how to get started <a href="https://docs.smith.langchain.com">here</a>. 
    </p>
</div>    

## Configure the database

We will connect to an existing PostgreSQL database. The connection string is read from the `DATABASE_URI` environment variable (loaded from `.env` above) and passed to `SQLDatabase.from_uri`. As the output below shows, the tables (`block`, `tx`, `tx_in`, `pool_hash`, `epoch_stake`, `reward`, and so on) hold blockchain data: blocks, transactions, stake pools, epochs, and rewards.
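
`SQLDatabase.from_uri` takes a SQLAlchemy database URL. The same `DATABASE_URI` is later passed to `psycopg2.connect`, so a plain `postgresql://user:password@host:port/dbname` URI (without a `+driver` suffix) works for both. Here is a sketch of building one. The user name, password, host, and database name are hypothetical placeholders. The password is percent-encoded so that characters such as `@`, `:` or `/` do not break parsing.

```python
from urllib.parse import quote


def build_postgres_uri(user: str, password: str, host: str, port: int, dbname: str) -> str:
    """Build a postgresql:// URI that both SQLAlchemy and libpq (psycopg2) can parse.

    User and password are percent-encoded (safe="" also encodes '/').
    """
    return f"postgresql://{quote(user, safe='')}:{quote(password, safe='')}@{host}:{port}/{dbname}"


# Hypothetical credentials, for illustration only
print(build_postgres_uri("app", "p@ss:w/rd", "localhost", 5432, "dbsync"))
```

Setting `os.environ["DATABASE_URI"]` to such a string before creating `db` would override the value loaded from `.env`.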

In [None]:
from langchain_community.utilities import SQLDatabase

database_uri = os.getenv("DATABASE_URI")

# Initialize the SQLDatabase instance
db = SQLDatabase.from_uri(database_uri)

print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM tx LIMIT 10;")
db.run("SELECT MAX(tx_in_id) FROM tx_in;")

postgresql
['ada_pots', 'block', 'collateral_tx_in', 'collateral_tx_out', 'committee', 'committee_de_registration', 'committee_hash', 'committee_member', 'committee_registration', 'constitution', 'cost_model', 'datum', 'delegation', 'delegation_vote', 'delisted_pool', 'drep_distr', 'drep_hash', 'drep_registration', 'epoch', 'epoch_param', 'epoch_stake', 'epoch_stake_progress', 'epoch_state', 'epoch_sync_time', 'event_info', 'extra_key_witness', 'extra_migrations', 'gov_action_proposal', 'ma_tx_mint', 'ma_tx_out', 'meta', 'multi_asset', 'new_committee', 'off_chain_pool_data', 'off_chain_pool_fetch_error', 'off_chain_vote_author', 'off_chain_vote_data', 'off_chain_vote_drep_data', 'off_chain_vote_external_update', 'off_chain_vote_fetch_error', 'off_chain_vote_gov_action_data', 'off_chain_vote_reference', 'param_proposal', 'pool_hash', 'pool_metadata_ref', 'pool_owner', 'pool_relay', 'pool_retire', 'pool_stat', 'pool_update', 'pot_transfer', 'redeemer', 'redeemer_data', 'reference_tx_in',

'[(96591643,)]'

## Utility functions

We will define a few utility functions to help us with the agent implementation. Specifically, we will wrap a `ToolNode` with a fallback to handle errors and surface them to the agent.

In [None]:
from typing import Any

from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode


def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Create a ToolNode with a fallback to handle errors and surface them to the agent.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )


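def handle_tool_error(state) -> dict:
    """Turn the exception caught by the fallback into one ToolMessage per pending tool call."""
    # `with_fallbacks(..., exception_key="error")` passes the exception under this key
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\nPlease fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }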
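
`with_fallbacks(..., exception_key="error")` runs the fallback on the same input, with the caught exception added under the `"error"` key. That is why `handle_tool_error` reads `state.get("error")`. The sketch below shows the same control flow in plain Python. The plain dicts stand in for the LangChain messages, and `failing_tool_node` is a hypothetical stand-in for a `ToolNode` whose tool raised.

```python
from typing import Callable


def with_fallback(
    primary: Callable[[dict], dict],
    fallback: Callable[[dict], dict],
    exception_key: str = "error",
) -> Callable[[dict], dict]:
    """On failure, call `fallback` with the original input plus the exception."""

    def run(state: dict) -> dict:
        try:
            return primary(state)
        except Exception as exc:
            # The fallback sees the exception under `exception_key`
            return fallback({**state, exception_key: exc})

    return run


def failing_tool_node(state: dict) -> dict:
    # Hypothetical tool failure
    raise ValueError('relation "txx" does not exist')


def error_to_tool_messages(state: dict) -> dict:
    # Plain-dict version of handle_tool_error: one error reply per pending tool call
    error = state.get("error")
    return {
        "messages": [
            {"role": "tool", "tool_call_id": tc["id"], "content": f"Error: {error!r}\nPlease fix your mistakes."}
            for tc in state["messages"][-1]["tool_calls"]
        ]
    }


node = with_fallback(failing_tool_node, error_to_tool_messages)
result = node({"messages": [{"role": "ai", "tool_calls": [{"id": "call_1"}]}]})
print(result["messages"][0]["content"])
```

Each error reply carries the `tool_call_id` of the call that failed, so the next model turn can match it to its own request.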

## Define tools for the agent

We will define a few tools that the agent will use to interact with the database.

1. `list_tables_tool`: Fetch the available tables from the database
2. `get_schema_tool`: Fetch the DDL for a table
3. `db_query_tool`: Execute the query and fetch the results OR return an error message if the query fails

We take the first two tools from the `SQLDatabaseToolkit`, which is also part of the `langchain_community` package.

In [None]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_groq import ChatGroq

toolkit = SQLDatabaseToolkit(db=db, llm=ChatGroq(model="llama3-70b-8192"))
tools = toolkit.get_tools()

list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")

print(list_tables_tool.invoke(""))
print(get_schema_tool.invoke("tx_in"))

ada_pots, block, collateral_tx_in, collateral_tx_out, committee, committee_de_registration, committee_hash, committee_member, committee_registration, constitution, cost_model, datum, delegation, delegation_vote, delisted_pool, drep_distr, drep_hash, drep_registration, epoch, epoch_param, epoch_stake, epoch_stake_progress, epoch_state, epoch_sync_time, event_info, extra_key_witness, extra_migrations, gov_action_proposal, ma_tx_mint, ma_tx_out, meta, multi_asset, new_committee, off_chain_pool_data, off_chain_pool_fetch_error, off_chain_vote_author, off_chain_vote_data, off_chain_vote_drep_data, off_chain_vote_external_update, off_chain_vote_fetch_error, off_chain_vote_gov_action_data, off_chain_vote_reference, param_proposal, pool_hash, pool_metadata_ref, pool_owner, pool_relay, pool_retire, pool_stat, pool_update, pot_transfer, redeemer, redeemer_data, reference_tx_in, reserve, reserved_pool_ticker, reverse_index, reward, reward_rest, schema_version, script, slot_leader, stake_address, 

We define the third tool, `db_query_tool`, ourselves. It executes the query against the database and returns the results, or the error message if the query fails.

In [None]:
from langchain_core.tools import tool


@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
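    result = db.run_no_throw(query)
    if isinstance(result, str) and result.startswith("Error:"):
        # run_no_throw returns the database error as text instead of raising it
        return f"{result}\nPlease rewrite your query and try again."
    if not result:
        # A valid query that matched no rows comes back as an empty string
        return "Query returned no rows."
    return f"Query results: {result}"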


print(db_query_tool.invoke("SELECT * FROM tx_in LIMIT 10;"))

Query results: [(1, 14506, 7853, 0, None), (2, 14507, 14506, 0, None), (3, 14508, 14507, 0, None), (4, 14509, 14508, 1, None), (5, 14509, 14507, 1, None), (6, 14510, 14508, 0, None), (7, 14510, 14509, 1, None), (8, 14511, 1326, 0, None), (9, 14512, 14511, 0, None), (10, 14513, 4360, 0, None)]
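
`db.run_no_throw` does not raise on failure. Instead it returns the database error as a string starting with `Error:`, and it returns an empty string when a valid query matches no rows. The sketch below shows that no-throw pattern in a self-contained form. It swaps in the standard-library `sqlite3` module and a simplified, hypothetical `tx_in` table for the PostgreSQL database. It keeps the two cases apart, so the model sees either the engine's own message or an explicit "no rows" answer.

```python
import sqlite3


def run_no_throw(conn: sqlite3.Connection, query: str) -> str:
    """Execute `query`; return rows as text, or an 'Error: ...' string instead of raising."""
    try:
        rows = conn.execute(query).fetchall()
    except sqlite3.Error as exc:
        return f"Error: {exc}"
    if not rows:
        # A valid query that matched nothing is not an error
        return "Query returned no rows."
    return f"Query results: {rows}"


# Hypothetical, simplified table for the demo
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tx_in (id INTEGER PRIMARY KEY, tx_in_id INTEGER)")
conn.executemany("INSERT INTO tx_in (tx_in_id) VALUES (?)", [(14506,), (14507,)])

print(run_no_throw(conn, "SELECT id, tx_in_id FROM tx_in LIMIT 1;"))  # → Query results: [(1, 14506)]
print(run_no_throw(conn, "SELECT id FROM tx_in WHERE tx_in_id = -1;"))  # → Query returned no rows.
print(run_no_throw(conn, "SELECT id FROM tx_inn;"))  # → Error: no such table: tx_inn
```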


While not strictly a tool, we will prompt an LLM to check for common mistakes in the query and later add this as a node in the workflow.

In [None]:
from langchain_core.prompts import ChatPromptTemplate

query_check_system = """You are a SQL expert with a strong attention to detail.
Generate a SQL query to retrieve data from the live PostgreSQL database.
Do not use example data. Ensure your query targets the database directly.
Double check the PostgreSQL query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
query_check = query_check_prompt | ChatGroq(model="llama3-70b-8192", temperature=0).bind_tools(
    [db_query_tool], tool_choice="any"
)


#query_check.invoke({"messages": [("user", "SELECT * FROM tx_in LIMIT 10;")]})



query_check.invoke({"messages": [("user", 
"""
The relevant table here is "block" and the relevant columns are "time" and "id". We can count the number of blocks per day by grouping the data by date. Here is the SQL query:

```sql
SELECT DATE(time) as date, COUNT(id) as number_of_blocks
FROM block
WHERE time BETWEEN '2023-05-01' AND '2023-06-30'
GROUP BY date
ORDER BY date;
```

This query will return the number of blocks for each day from May 2023 to June 2023.
"""
)]})

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_8pzw', 'function': {'arguments': '{"query":"SELECT DATE(time) as date, COUNT(id) as number_of_blocks FROM block WHERE time BETWEEN \'2023-05-01\' AND \'2023-06-30\' GROUP BY date ORDER BY date;"}', 'name': 'db_query_tool'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 86, 'prompt_tokens': 1114, 'total_tokens': 1200, 'completion_time': 0.245714286, 'prompt_time': 0.054149074, 'queue_time': 0.01862576700000001, 'total_time': 0.29986336}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_2e0feca3c9', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-c8b5c790-a1a4-4472-8a08-c9ac2735a55c-0', tool_calls=[{'name': 'db_query_tool', 'args': {'query': "SELECT DATE(time) as date, COUNT(id) as number_of_blocks FROM block WHERE time BETWEEN '2023-05-01' AND '2023-06-30' GROUP BY date ORDER BY date;"}, 'id': 'call_8pzw', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1114, 
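
The query passed to the checker above filters with `time BETWEEN '2023-05-01' AND '2023-06-30'`. That is the "BETWEEN for exclusive ranges" case from the checklist: against a timestamp column, the upper bound means midnight at the start of June 30, so blocks later that day are dropped. The self-contained sketch below contrasts it with a half-open range. It uses `sqlite3` with hypothetical rows whose timestamps are stored as ISO-8601 text, which compare the same way for these rows.

```python
import sqlite3

# Hypothetical rows: one early in May, one on June 29, one at noon on June 30
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE block (id INTEGER PRIMARY KEY, time TEXT)")
conn.executemany(
    "INSERT INTO block (time) VALUES (?)",
    [("2023-05-01 09:00:00",), ("2023-06-29 18:30:00",), ("2023-06-30 12:00:00",)],
)

# Closed range: the upper bound '2023-06-30' sorts before any time later on June 30
between_count = conn.execute(
    "SELECT COUNT(*) FROM block WHERE time BETWEEN '2023-05-01' AND '2023-06-30'"
).fetchone()[0]

# Half-open range: everything from the first instant of May up to, not including, July 1
half_open_count = conn.execute(
    "SELECT COUNT(*) FROM block WHERE time >= '2023-05-01' AND time < '2023-07-01'"
).fetchone()[0]

print(between_count, half_open_count)  # → 2 3
```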

## Define the workflow

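We will then define the workflow for the agent. The agent first force-calls `list_tables_tool` to fetch the available tables. A model then chooses the relevant tables, and `get_schema_tool` fetches their DDL. Next, `query_gen` writes a PostgreSQL query, and `new_execute_query` runs it against the database and saves the result set to `query_result.csv`. The `query_check` chain defined above is not wired into this graph.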

<div class="admonition note">
    <p class="admonition-title">Using Pydantic with LangChain</p>
    <p>
        This notebook uses Pydantic v2 <code>BaseModel</code>, which requires <code>langchain-core >= 0.3</code>. Using <code>langchain-core < 0.3</code> will result in errors due to mixing of Pydantic v1 and v2 <code>BaseModels</code>.
    </p>
</div>
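
The note above requires `langchain-core >= 0.3`. The sketch below checks the installed version. `meets_minimum` is a hypothetical helper that compares only the leading numeric parts of the version string, which is enough for this check. For general version comparisons, the `packaging` library's `Version` class is the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def meets_minimum(installed: str, minimum: tuple[int, ...] = (0, 3)) -> bool:
    """Compare the leading numeric parts of a version string such as '0.3.15' with `minimum`."""
    release = []
    for piece in installed.split(".")[: len(minimum)]:
        match = re.match(r"\d+", piece)
        # Pieces without a leading number count as 0
        release.append(int(match.group()) if match else 0)
    return tuple(release) >= minimum


try:
    core_version = version("langchain-core")
    status = "OK" if meets_minimum(core_version) else "upgrade to >= 0.3"
    print(f"langchain-core {core_version}: {status}")
except PackageNotFoundError:
    print("langchain-core is not installed")
```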

In [None]:
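# Load a Markdown description of the schema (optional: `schema_content` is not referenced by the graph below)
file_path = './schema.md'
with open(file_path, "r") as file:
    schema_content = file.read()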

In [None]:
import json
from typing import Annotated, Literal

from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import AnyMessage, add_messages


# Define the state for the agent
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]


# Define a new graph
workflow = StateGraph(State)


# Add a node for the first tool call
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": "sql_db_list_tables",
                        "args": {},
                        "id": "tool_abcd123",
                    }
                ],
            )
        ]
    }


def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """
    Use this tool to double-check if your query is correct before executing it.
    """
    return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}


workflow.add_node("first_tool_call", first_tool_call)

# Add nodes for the first two tools
workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))

workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))

# Add a node for a model to choose the relevant tables based on the question and available tables
model_get_schema = ChatGroq(model="llama3-70b-8192", temperature=0).bind_tools([get_schema_tool])
workflow.add_node(
    "model_get_schema",
    lambda state: {
        "messages": [model_get_schema.invoke(state["messages"])],
    },
)


# Describe a tool to represent the end state
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    final_answer: str = Field(..., description="The final answer to the user")


# Add a node for a model to generate a query based on the question and schema
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct PostgreSQL query to run.

You must generate a SQL query that retrieves live data from the PostgreSQL database. Do not use any example rows that may be provided by the schema tool as part of the final answer.
The final answer should only be based on the actual query results from the database.

When generating the query:

Output the SQL query that answers the input question without a tool call. The output must be only the SQL query without any text or explanation.

You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

Generate only the query. IMPORTANT: DO NOT make any DML or DDL statements (INSERT, UPDATE, DELETE, DROP, etc.) to the database.

"""

# Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
query_gen_prompt = ChatPromptTemplate.from_messages([("system", query_gen_system), ("placeholder", "{messages}")])
query_gen = query_gen_prompt | ChatGroq(model="llama3-70b-8192", temperature=0)


def query_gen_node(state: State):
    message = query_gen.invoke(state)

    # If the message contains a SQL query (not a tool call), return it
    if not message.tool_calls:
        return {"messages": [message]}

    # If there's a tool call, it's an error
    tool_messages = [
        ToolMessage(
            content="Error: Direct tool calls are not allowed. Please generate a SQL query instead.",
            tool_call_id=tc["id"],
        )
        for tc in message.tool_calls
    ]
    return {"messages": tool_messages}


workflow.add_node("query_gen", query_gen_node)



import csv

import psycopg2


def new_execute_query_node(state: State):
    """Run the SQL query produced by `query_gen` and save the result set to a CSV file."""
    # The last message holds the generated query; drop a ```sql fence if the model added one
    query = state["messages"][-1].content.strip()
    query = query.removeprefix("```sql").removesuffix("```").strip()

    # Connect with the same connection string used by SQLDatabase
    conn = psycopg2.connect(os.environ["DATABASE_URI"])
    try:
        # The cursor is closed automatically when the `with` block exits
        with conn.cursor() as cur:
            cur.execute(query)
            rows = cur.fetchall()
            # Column names come from the cursor description
            columns = [desc[0] for desc in cur.description]
    finally:
        conn.close()

    # Write the header and the rows; csv.writer quotes values that contain commas or quotes
    csv_file_path = "query_result.csv"
    with open(csv_file_path, mode="w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(columns)
        writer.writerows(rows)

    print(f"Data has been saved to {csv_file_path}")
    # Nothing is returned, so the graph state is left unchanged



workflow.add_node("new_execute_query", new_execute_query_node)

# Route after query generation: an error message goes back to query_gen, anything else is executed
def should_continue(state: State) -> Literal["new_execute_query", "query_gen"]:
    messages = state["messages"]
    last_message = messages[-1]
    if last_message.content.startswith("Error:"):
        return "query_gen"
    else:
        return "new_execute_query"

# Specify the edges between the nodes
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
workflow.add_conditional_edges("query_gen", should_continue)
#workflow.add_edge("correct_query", "execute_query")
#workflow.add_edge("execute_query", "submit_final_answer")
#workflow.add_edge("submit_final_answer", END)
workflow.add_edge("new_execute_query", END)

# Compile the workflow into a runnable
app = workflow.compile()
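
The only branch in the graph is `should_continue`. A last message starting with `Error:` sends control back to `query_gen`, and anything else goes on to `new_execute_query`. The graph itself does not cap the number of retries. What eventually stops a loop is LangGraph's recursion limit, set with `{"recursion_limit": ...}` in the config passed to `app.invoke`. The plain-Python sketch below applies the same routing rule with an explicit attempt cap. `fake_query_gen` is a hypothetical stand-in for the model.

```python
from typing import Callable, Optional


def route(last_content: str) -> str:
    """Same rule as should_continue: error replies loop back, anything else is executed."""
    return "query_gen" if last_content.startswith("Error:") else "new_execute_query"


def generate_until_valid(
    generate: Callable[[list[str]], str], max_attempts: int = 3
) -> tuple[Optional[str], list[str]]:
    """Call `generate` until route() says execute, at most `max_attempts` times."""
    history: list[str] = []
    for _ in range(max_attempts):
        output = generate(history)
        history.append(output)
        if route(output) == "new_execute_query":
            return output, history
    return None, history


def fake_query_gen(history: list[str]) -> str:
    # Hypothetical model: fails on the first call, then produces a query
    if not history:
        return "Error: Direct tool calls are not allowed. Please generate a SQL query instead."
    return "SELECT COUNT(*) FROM block;"


query, history = generate_until_valid(fake_query_gen)
print(query, len(history))  # → SELECT COUNT(*) FROM block; 2
```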

## Visualize the graph

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

display(
    Image(
        app.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)

## Run the agent

In [None]:
#    {"messages": [("user", "Write a PostgreSQL query that retrieves schema information for the tables tx_in, tx_out, reward, in the database. The query should include: Table Name (table_name), Column Name (column_name). Format the query to retrieve these details in a way that captures column metadata across the first ten tables in the database. Add one more column named Description that describes each row in the tables, for example, for the row block_id, the description should be - The Block table index of the block for which this snapshot was taken. IMPORTANT: Avoid generating a description of No description available. The final output should only be the SQL query, with no additional explanation or comments.")]}
#    {"messages": [("user", "Generate a SQL query that extracts the total stakes and total rewards each epoch")]}

messages = app.invoke(
    {"messages": [("user", "Generate a SQL query that returns pool id, the total stakes delegated to each pool, and total rewards earned by each pool in epoch 280 from tables reward and epoch_stake")]}
)
print(messages)
#json_str = messages["messages"][-1].tool_calls[0]["args"]["final_answer"]
#json_str
#print(messages["messages"][-1].content)

Data has been saved to query_result.csv
{'messages': [HumanMessage(content='Generate a SQL query that returns pool id, the total stakes delegated to each pool, and total rewards earned by each pool in epoch 280 from tables reward and epoch_stake', additional_kwargs={}, response_metadata={}, id='9cc26cca-4c03-447f-8eda-7fa453e33702'), AIMessage(content='', additional_kwargs={}, response_metadata={}, id='310b5a48-f67c-4392-ae4b-ac7afeaa5be7', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123', 'type': 'tool_call'}]), ToolMessage(content='ada_pots, block, collateral_tx_in, collateral_tx_out, committee, committee_de_registration, committee_hash, committee_member, committee_registration, constitution, cost_model, datum, delegation, delegation_vote, delisted_pool, drep_distr, drep_hash, drep_registration, epoch, epoch_param, epoch_stake, epoch_stake_progress, epoch_state, epoch_sync_time, event_info, extra_key_witness, extra_migrations, gov_action_proposal, ma_tx_min
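
`new_execute_query` writes the result set to `query_result.csv` in the working directory rather than adding it to the messages. The small sketch below loads that file back with the standard `csv` module. Every value comes back as a string, so you need to convert numeric columns such as stake or reward totals yourself.

```python
import csv
from pathlib import Path


def load_query_result(path: str = "query_result.csv") -> list[dict[str, str]]:
    """Read the CSV written by new_execute_query into a list of {column: value} dicts."""
    with Path(path).open(newline="") as file:
        return list(csv.DictReader(file))


if Path("query_result.csv").exists():
    result_rows = load_query_result()
    print(f"{len(result_rows)} rows; first rows: {result_rows[:3]}")
```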

In [None]:
#Write a PostgreSQL query that retrieves schema information for the tables tx_in, tx_out, reward, in the database. The query should include: Table Name (table_name), Column Name (column_name). Format the query to retrieve these details in a way that captures column metadata across the first ten tables in the database. Add one more column named Description that describes each row in the tables, for example, for the row block_id, the description should be - The Block table index of the block for which this snapshot was taken. IMPORTANT: Avoid generating a description of No description available. The final output should only be the SQL query, with no additional explanation or comments.

In [None]:
print(messages["messages"][-1])