## SQLAgent

In [1]:
import os
LANGCHAIN_TRACING_V2 = os.getenv("LANGCHAIN_TRACING_V2")
LANGCHAIN_ENDPOINT = os.getenv("LANGCHAIN_ENDPOINT")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")

In [2]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("postgresql://kunalkothavade:postgres@localhost:5432/chinook")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10")


postgresql
['album', 'artist', 'customer', 'employee', 'genre', 'invoice', 'invoice_line', 'media_type', 'playlist', 'playlist_track', 'track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

### Utility functions

In [3]:
from typing import Any
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode


def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Create a ToolNode with a fallback to handle errors and surface them to the agent.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )


def handle_tool_error(state) -> dict:
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }


### Tools

In [4]:
from langchain_groq import ChatGroq
llm =  ChatGroq(temperature=0, model_name="gemma2-9b-it", groq_api_key=GROQ_API_KEY)
llm.invoke("What is sqlagent").content

"SQL Agent is a powerful tool within Microsoft SQL Server that allows you to automate a wide range of database management tasks. \n\nHere's a breakdown:\n\n**What it does:**\n\n* **Schedules Jobs:**  You can create jobs that run specific SQL Server tasks at predefined intervals (e.g., daily, weekly, on specific dates).\n* **Manages Tasks:**  Jobs can contain multiple tasks, such as:\n    * **SQL Server Agent Jobs:** Executing SQL scripts, backing up databases, restoring databases, monitoring performance, sending alerts, and more.\n    * **Other Tasks:**  Running external programs, sending emails, or integrating with other systems.\n* **Provides Logging and Reporting:** SQL Agent keeps detailed logs of job executions, including success/failure status, start/end times, and any errors encountered. You can generate reports to track job history and performance.\n\n**Why it's useful:**\n\n* **Automation:**  Eliminates the need for manual intervention in repetitive tasks, saving time and redu

In [5]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db,llm=llm)
tools = toolkit.get_tools()

list_tables_tool = next(tool for tool in tools if tool.name == 'sql_db_list_tables')
get_schema_tool = next(tool for tool in tools if tool.name == 'sql_db_schema')

print(list_tables_tool.invoke(""))
print("==========================")
print(get_schema_tool.invoke("artist"))

album, artist, customer, employee, genre, invoice, invoice_line, media_type, playlist, playlist_track, track

CREATE TABLE artist (
	artist_id INTEGER NOT NULL, 
	name VARCHAR(120), 
	CONSTRAINT artist_pkey PRIMARY KEY (artist_id)
)

/*
3 rows from artist table:
artist_id	name
1	AC/DC
2	Accept
3	Aerosmith
*/


In [6]:
# Manually created tool

from langchain_core.tools import tool

@tool
def db_query_tool(query :str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """

    result = db.run_no_throw(query)

    if not result:
        return "Error: Query failed. Please rewrite your query and try again."
    return result
print(db_query_tool.invoke("select * from artist limit 10"))
print("==============================")
print(db_query_tool.invoke("select * from artist where artist_id = 0"))

[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]
Error: Query failed. Please rewrite your query and try again.


In [7]:
from langchain_core.prompts import ChatPromptTemplate
query_check_system_message = """You are a SQL expert with a strong attention to detail.
Double check the Postgresql query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

query_check_prompt = ChatPromptTemplate.from_messages([("system",query_check_system_message),("placeholder","{messages}")])
print(query_check_prompt)

query_check = query_check_prompt | llm.bind_tools([db_query_tool],tool_choice='required')

query_check.invoke({"messages":[("user","select * from artist limit 10")]})

input_variables=[] optional_variables=['messages'] input_types={'messages': list[typing.Annotated[typing.Union[typing.Annotated[langchain_core.messages.ai.AIMessage, Tag(tag='ai')], typing.Annotated[langchain_core.messages.human.HumanMessage, Tag(tag='human')], typing.Annotated[langchain_core.messages.chat.ChatMessage, Tag(tag='chat')], typing.Annotated[langchain_core.messages.system.SystemMessage, Tag(tag='system')], typing.Annotated[langchain_core.messages.function.FunctionMessage, Tag(tag='function')], typing.Annotated[langchain_core.messages.tool.ToolMessage, Tag(tag='tool')], typing.Annotated[langchain_core.messages.ai.AIMessageChunk, Tag(tag='AIMessageChunk')], typing.Annotated[langchain_core.messages.human.HumanMessageChunk, Tag(tag='HumanMessageChunk')], typing.Annotated[langchain_core.messages.chat.ChatMessageChunk, Tag(tag='ChatMessageChunk')], typing.Annotated[langchain_core.messages.system.SystemMessageChunk, Tag(tag='SystemMessageChunk')], typing.Annotated[langchain_core.m

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_ztc2', 'function': {'arguments': '{"query":"select * from artist limit 10"}', 'name': 'db_query_tool'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 89, 'prompt_tokens': 1018, 'total_tokens': 1107, 'completion_time': 0.161818182, 'prompt_time': 0.035382068, 'queue_time': 0.242447857, 'total_time': 0.19720025}, 'model_name': 'gemma2-9b-it', 'system_fingerprint': 'fp_10c08bf97d', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-06ce4423-90f5-4b1c-b43c-704f22e0ad79-0', tool_calls=[{'name': 'db_query_tool', 'args': {'query': 'select * from artist limit 10'}, 'id': 'call_ztc2', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1018, 'output_tokens': 89, 'total_tokens': 1107})

In [8]:
from pydantic import BaseModel, Field
# Describe a tool to represent the end state
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    final_answer: str = Field(..., description="The final answer to the user")


# Add a node for a model to generate a query based on the question and schema
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set. 
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""
query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("placeholder", "{messages}")]
)
query_gen = query_gen_prompt | llm.bind_tools(
    [SubmitFinalAnswer]
)

### Create nodes

In [9]:
from typing import Annotated,Literal

from langchain_core.messages import AIMessage
from pydantic import BaseModel,Field
from typing_extensions import TypedDict

from langgraph.graph import END,StateGraph,START
from langgraph.graph.message import AnyMessage,add_messages


In [10]:
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]


workflow = StateGraph(State)

In [11]:
# Add node for first tool call
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    return {"messages" : [
                AIMessage(
                    content = "",
                    tool_calls=[
                        {
                            "name" : "sql_db_list_tables",
                            "args": {},
                            "id":"tool_abcd123"
                        }
                        ]
                )
        ]
    }

In [12]:
def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """
    Use this tool to double-check if your query is correct before executing it.
    """
    return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}

In [13]:
# node to choose the relevant tables based on the question and available tables
model_get_schema = llm.bind_tools([get_schema_tool])

In [14]:
def query_gen_node(state: State):
    message = query_gen.invoke(state)

    # Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.
    tool_messages = []
    if message.tool_calls:
        for tc in message.tool_calls:
            if tc["name"] != "SubmitFinalAnswer":
                tool_messages.append(
                    ToolMessage(
                        content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
                        tool_call_id=tc["id"],
                    )
                )
    else:
        tool_messages = []
    return {"messages": [message] + tool_messages}

### Create Workflow

In [15]:
workflow.add_node("first_tool_call", first_tool_call)
workflow.add_node(
    "list_tables_tool", create_tool_node_with_fallback([list_tables_tool])
)
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))

workflow.add_node(
    "model_get_schema",
    lambda state: {
        "messages": [model_get_schema.invoke(state["messages"])],
    },
)

workflow.add_node("query_gen", query_gen_node)

# Add a node for the model to check the query before executing it
workflow.add_node("correct_query", model_check_query)

# Add node for executing the query
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))


<langgraph.graph.state.StateGraph at 0x11109f510>

In [16]:
# Define a conditional edge to decide whether to continue or end the workflow
def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
    messages = state["messages"]
    last_message = messages[-1]
    # If there is a tool call, then we finish
    if getattr(last_message, "tool_calls", None):
        return END
    if last_message.content.startswith("Error:"):
        return "query_gen"
    else:
        return "correct_query"

In [17]:
# Specify the edges between the nodes
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
workflow.add_conditional_edges(
    "query_gen",
    should_continue,
)
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")
# workflow.add_edge("query_gen", "execute_query")

# Compile the workflow into a runnable
app = workflow.compile()

In [18]:
messages = app.invoke(
    {"messages": [("user", "Which sales agent made most sales in 2009?")]}
)
json_str = messages["messages"][-1].tool_calls[0]["args"]["final_answer"]
json_str

IndexError: list index out of range

In [19]:
for event in app.stream(
    {"messages": [("user", "Which sales agent made most sales in 2009?")]}
):
    print(event)

{'first_tool_call': {'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={}, id='3dedb0fc-e742-4941-b17a-144878d54578', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123', 'type': 'tool_call'}])]}}
{'list_tables_tool': {'messages': [ToolMessage(content='album, artist, customer, employee, genre, invoice, invoice_line, media_type, playlist, playlist_track, track', name='sql_db_list_tables', id='26da9bc7-81ef-49bf-a589-06bf46782040', tool_call_id='tool_abcd123')]}}
{'model_get_schema': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_kmzk', 'function': {'arguments': '{"table_names":"customer, employee, invoice, invoice_line"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 53, 'prompt_tokens': 1140, 'total_tokens': 1193, 'completion_time': 0.096363636, 'prompt_time': 0.052325623, 'queue_time': 0.23731869100000003, 'total_time': 0.148689259}, 'model_n