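# SQL Agent

A LangGraph agent that answers natural-language questions about the Chinook PostgreSQL database. It works in these steps:

- list the tables;
- let the model pick the relevant ones and fetch their schema;
- generate a PostgreSQL query and execute it;
- on an error result, ask the model to fix the query and retry.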

In [600]:
import os
# Service settings and credentials come from the environment so no secrets live
# in the notebook itself; any variable that is not set comes back as None.
LANGCHAIN_TRACING_V2 = os.environ.get("LANGCHAIN_TRACING_V2")
LANGCHAIN_ENDPOINT = os.environ.get("LANGCHAIN_ENDPOINT")
LANGCHAIN_API_KEY = os.environ.get("LANGCHAIN_API_KEY")
LANGSMITH_API_KEY = os.environ.get("LANGSMITH_API_KEY")

GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
# The Google GenAI key is intentionally sourced from the GEMINI_API_KEY variable.
GOOGLE_API_KEY = os.environ.get("GEMINI_API_KEY")

In [601]:
from langchain_community.utilities import SQLDatabase

# Connection URI for the Chinook sample database. Set CHINOOK_DB_URI in the
# environment so credentials are not committed with the notebook; the fallback
# is the original local development URI, kept so existing setups still run.
# TODO: drop the hardcoded fallback once CHINOOK_DB_URI is configured everywhere.
CHINOOK_DB_URI = os.getenv(
    "CHINOOK_DB_URI",
    "postgresql://kunalkothavade:postgres@localhost:5432/chinook",
)
db = SQLDatabase.from_uri(CHINOOK_DB_URI)
# print(db.dialect)
# print(db.get_usable_table_names())
# db.run("SELECT * FROM Artist LIMIT 10")


### Utility functions

In [602]:
from typing import Any
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode


def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """Wrap ``tools`` in a ToolNode whose failures are reported back to the agent.

    If the ToolNode raises, the exception is handed to ``handle_tool_error``
    under the ``"error"`` key. That function turns it into one error ToolMessage
    per pending tool call, so the model can see the problem and try again.

    Args:
        tools: the tools this node is allowed to execute.

    Returns:
        The ToolNode wrapped with a ``handle_tool_error`` fallback.
    """
    tool_node = ToolNode(tools)
    error_reporter = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_reporter], exception_key="error")


def handle_tool_error(state) -> dict:
    """Convert a tool-node exception into error ToolMessages for the agent.

    Args:
        state: the failing node's input, with the caught exception under
            ``"error"``. Its last message must expose ``tool_calls`` (the
            AIMessage whose calls failed).

    Returns:
        ``{"messages": [...]}`` with one ToolMessage per tool call. Each is linked
        back via ``tool_call_id`` and its content starts with ``"Error: "``.
    """
    caught = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(caught)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}


## LLM

In [603]:
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
# Groq-hosted `gemma2-9b-it`; temperature 0 keeps sampling randomness to a minimum.
LLM_MODEL_NAME = "gemma2-9b-it"
llm = ChatGroq(
    model_name=LLM_MODEL_NAME,
    temperature=0,
    groq_api_key=GROQ_API_KEY,
)

# Alternative backend, currently disabled:
# llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=GOOGLE_API_KEY)

### Tools

In [604]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()


def _find_tool(name: str):
    """Return the first toolkit tool called ``name`` (StopIteration if absent)."""
    return next(t for t in tools if t.name == name)


list_tables_tool = _find_tool("sql_db_list_tables")
get_schema_tool = _find_tool("sql_db_schema")

# print(list_tables_tool.invoke(""))
# print("==========================")
# print(get_schema_tool.invoke("artist"))

In [605]:
# Manually created tool

from langchain_core.tools import tool

@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # The docstring above is kept verbatim: `@tool` turns it into the tool
    # description the LLM sees. Implementation notes therefore live here.
    #
    # Database failures come back from `run_no_throw` as text starting with
    # "Error:" (see the captured output of the final cell), and the graph's
    # `handle_execution` router sends those to FixSQL. An empty result is a query
    # that ran fine but matched no rows. Reporting that as "Error: ..." would push
    # a correct query into the retry loop, so it gets a neutral message instead.
    #
    # NOTE(review): `query` is model-generated and executed as-is. The prompt
    # forbids DML, but nothing here enforces read-only access; prefer a
    # read-only database role.
    result = db.run_no_throw(query)
    if not result:
        return "Query executed successfully but returned no rows."
    return result
# print(db_query_tool.invoke("select * from artist limit 10"))
# print("==============================")
# print(db_query_tool.invoke("select * from artist where artist_id = 0"))

### Nodes

In [606]:
from typing import Annotated,Literal
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage,add_messages


In [607]:
class State(TypedDict):
    """Graph state shared by the SQL-agent nodes.

    The nodes wired into the workflow in this notebook read and write only
    ``messages``. ``handle_tool_error`` additionally reads an ``"error"`` key,
    which the fallback wrapper supplies via ``exception_key="error"``.
    ``initial_state`` provides only ``messages``.
    """
    # Conversation history; values returned by nodes are merged in by the add_messages reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    # NOTE(review): none of the keys below are populated by any node visible here.
    # They are presumably left over from the commented-out, non-tool-calling nodes.
    # Confirm before relying on them (e.g. retry_count/max_retries are not enforced).
    query: str
    tables: list[str]
    schema_info: str
    sqlquery: str
    result: str | list | None
    error: str | None
    retry_count: int
    max_retries: int

In [608]:
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate

# Add node for first tool call
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Kick off the graph with a hard-coded ``sql_db_list_tables`` call.

    No model is invoked and ``state`` is not read. The returned AIMessage carries
    a single tool call with empty arguments and a fixed id; the ListTables
    ToolNode executes it next.
    """
    list_tables_call = {"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}
    opening_message = AIMessage(content="", tool_calls=[list_tables_call])
    return {"messages": [opening_message]}

# def identify_tables_node(state:State):
#     question = state["query"]
#     all_tables = list_tables_tool.run("")
#     try:
#         state["all_tables"] = eval(all_tables)
#     except:
#         state["all_tables"] = []

#     state["messages"] = state.get("messages",[]) + [{"role":"assistant","content":all_tables}]
#     return state


# def get_schema_node(state:State):
#     schemas = []
#     for table in state['tables']:
#         schema = get_schema_tool.invoke(table)
#         schemas.append(schema)
#     state["schema_info"] = "\n".join(schemas)
#     state["messages"] = state.get("messages",[]) + [{"role":"assistant","content":state["schema_info"]}]
#     return state

def gen_sqlquery_node(state:State):
    '''
    Generate a PostgreSQL query for the user's question and request its execution.

    The model is bound to ``db_query_tool`` with ``tool_choice='required'``, so
    the returned AIMessage always carries a tool call. This matters because the
    graph routes GenerateSQL straight into the ExecuteSQL ToolNode; a plain-text
    reply would leave nothing to execute. It also matches how
    ``fix_sqlquery_node`` binds the same tool.

    Args:
        state: graph state; the prompt template consumes only ``messages``.

    Returns:
        ``{"messages": [message]}`` containing the model's AIMessage.
    '''
    # Plain string (no f-prefix): the prompt has no interpolated values.
    system_prompt = """You are a SQL expert with a strong attention to detail.

                Given an input question, output a syntactically correct PostgreSQL query.

                When generating the query:

                Output the SQL query that answers the input question.

                Use the provided schema info to generate the SQL query.

                Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
                You can order the results by a relevant column to return the most interesting examples in the database.
                Never query for all the columns from a specific table, only ask for the relevant columns given the question.

                DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
                
                Double check the Postgresql query for common mistakes, including:
                - Using NOT IN with NULL values
                - Using UNION when UNION ALL should have been used
                - Using BETWEEN for exclusive ranges
                - Data type mismatch in predicates
                - Properly quoting identifiers
                - Using the correct number of arguments for functions
                - Casting to the correct data type
                - Using the proper columns for joins

                Always call the db_query_tool with the generated query.

                """
    query_gen_prompt = ChatPromptTemplate.from_messages(
        [("system", system_prompt), ("placeholder", "{messages}")]
    )
    # Force the tool call the prompt asks for; see the docstring for why.
    query_gen = query_gen_prompt | llm.bind_tools([db_query_tool], tool_choice='required')
    message = query_gen.invoke(state)
    return {"messages": [message]}

# def execute_sql_node(state:State):
#     try:
#         state["result"] = db_query_tool.invoke(state["sqlquery"])
#         state["error"] = None
#     except Exception as e:
#         state["result"] = None
#         state["error"] = str(e)
    
#     msg = {"role":"assistant", "content": f"Execution result: {state['result'] or state['error']}"}
#     state["messages"] = state.get("messages",[]) + [msg]
#     return state


def fix_sqlquery_node(state:State):
    """
    Ask the model to repair the last failed SQL query and re-issue it.

    This runs when ExecuteSQL produced an error result. The model sees the whole
    message history (question, schema, failing query, error text). It is forced
    (``tool_choice='required'``) to respond with a ``db_query_tool`` call, which
    the graph sends back to ExecuteSQL.

    Args:
        state: graph state; the prompt template consumes only ``messages``.

    Returns:
        ``{"messages": [message]}`` containing the model's AIMessage.
    """
    # The literal is sent verbatim to the model, so it must end cleanly after the
    # last instruction: no stray quote characters, no trailing indentation.
    # Plain string (no f-prefix): the prompt has no interpolated values.
    system_prompt = """Look at the error message and regenerate the correct PostGreSQL query. 
    Refer to the schema details again, if needed. 
    Ensure this time the SQL query is correct to answer the user question.
    Execute the db_query_tool once the sql query is generated."""
    query_fix_prompt = ChatPromptTemplate.from_messages(
        [("system", system_prompt), ("placeholder", "{messages}")]
    )
    query_fix = query_fix_prompt | llm.bind_tools([db_query_tool], tool_choice='required')
    message = query_fix.invoke(state)

    return {"messages": [message]}


In [609]:
# Model used by the IdentifyRelevantTables node. It reads the question and the
# table list and should respond with an `sql_db_schema` call. The graph wires
# IdentifyRelevantTables directly into the GetSchema ToolNode, so the call is
# forced: a plain-text reply would leave GetSchema with nothing to run and
# GenerateSQL without any schema.
get_relevant_tables = llm.bind_tools([get_schema_tool], tool_choice="required")

### Create Workflow

In [610]:
from langgraph.graph import END,StateGraph,START
from langchain_core.messages import AIMessage
from pydantic import BaseModel,Field

In [611]:
workflow = StateGraph(State)


def identify_relevant_tables(state):
    """Ask the schema-tool-bound model which tables the question needs."""
    return {"messages": [get_relevant_tables.invoke(state["messages"])]}


# Nodes, in pipeline order. Tool-executing nodes are wrapped so a tool failure
# reaches the model as an error ToolMessage instead of raising.
pipeline_nodes = [
    ("first_tool_call", first_tool_call),
    ("ListTables", create_tool_node_with_fallback([list_tables_tool])),
    ("IdentifyRelevantTables", identify_relevant_tables),
    ("GetSchema", create_tool_node_with_fallback([get_schema_tool])),
    ("GenerateSQL", gen_sqlquery_node),
    ("ExecuteSQL", create_tool_node_with_fallback([db_query_tool])),
    ("FixSQL", fix_sqlquery_node),
]
for node_name, node_action in pipeline_nodes:
    workflow.add_node(node_name, node_action)

# Fixed linear path from START up to the first SQL execution; the routing out
# of ExecuteSQL is wired separately.
linear_edges = [
    (START, "first_tool_call"),
    ("first_tool_call", "ListTables"),
    ("ListTables", "IdentifyRelevantTables"),
    ("IdentifyRelevantTables", "GetSchema"),
    ("GetSchema", "GenerateSQL"),
    ("GenerateSQL", "ExecuteSQL"),
]
for edge_source, edge_target in linear_edges:
    workflow.add_edge(edge_source, edge_target)

# Maximum number of error tool results tolerated before the graph stops.
# Without a cap, ExecuteSQL <-> FixSQL loops until LangGraph's recursion limit
# raises.
MAX_SQL_ERRORS = 3


def _is_error_text(content) -> bool:
    """True when a message's content is a string starting with ``"Error:"``."""
    return isinstance(content, str) and content.startswith("Error:")


def handle_execution(state):
    """Route out of ExecuteSQL: retry via FixSQL on an error result, else finish.

    Error results are tool outputs whose text starts with ``"Error:"``. That
    covers database errors and the messages produced by ``handle_tool_error``.

    Routing rules:

    * The last message is not an error: go to ``END``.
    * It is an error and fewer than ``MAX_SQL_ERRORS`` error ToolMessages exist
      in the whole history: go to ``"FixSQL"``.
    * Otherwise: go to ``END``, so the final message is the last error.

    The count covers every error ToolMessage in ``messages``, including any
    from the ListTables/GetSchema steps.
    """
    messages = state.get("messages") or []
    if not messages or not _is_error_text(messages[-1].content):
        return END
    error_count = sum(
        1 for msg in messages
        if isinstance(msg, ToolMessage) and _is_error_text(msg.content)
    )
    if error_count >= MAX_SQL_ERRORS:
        return END
    return "FixSQL"


workflow.add_conditional_edges("ExecuteSQL", handle_execution)
workflow.add_edge("FixSQL", "ExecuteSQL")

<langgraph.graph.state.StateGraph at 0x31b3e47d0>

In [612]:
# The user's natural-language question; the graph needs only `messages` to start.
QUESTION = "Which sales agent made most sales in 2009?"
initial_state = {"messages": [{"role": "user", "content": QUESTION}]}

In [614]:
app = workflow.compile()

# Print each node's state update as it is produced, one step of the pipeline at a time.
for node_update in app.stream(initial_state):
    print(node_update)

{'first_tool_call': {'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={}, id='72476b38-7a6e-4474-af6b-255c339d5637', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123', 'type': 'tool_call'}])]}}
{'ListTables': {'messages': [ToolMessage(content='album, artist, customer, employee, genre, invoice, invoice_line, media_type, playlist, playlist_track, track', name='sql_db_list_tables', id='4bce17d6-d0fa-4776-afb0-75a69cb0d44b', tool_call_id='tool_abcd123')]}}
{'IdentifyRelevantTables': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_6ecq', 'function': {'arguments': '{"table_names":"customer, employee, invoice, invoice_line"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 53, 'prompt_tokens': 1140, 'total_tokens': 1193, 'completion_time': 0.096363636, 'prompt_time': 0.048758205, 'queue_time': 0.237722504, 'total_time': 0.145121841}, 'model_name': 'g

In [615]:

# Run the whole graph again in one call and keep only the final state.
final_state = app.invoke(input=initial_state)

Error: (psycopg2.errors.UndefinedColumn) column i.support_rep_id does not exist
LINE 1: ...FROM employee e JOIN invoice i ON e.employee_id = i.support_...
                                                             ^

[SQL: SELECT e.last_name, e.first_name, COUNT(i.invoice_id) AS num_invoices FROM employee e JOIN invoice i ON e.employee_id = i.support_rep_id GROUP BY e.last_name, e.first_name ORDER BY num_invoices DESC LIMIT 5; ]
(Background on this error at: https://sqlalche.me/e/20/f405)
Error: (psycopg2.errors.UndefinedColumn) column i.support_rep_id does not exist
LINE 1: ...FROM employee e JOIN invoice i ON e.employee_id = i.support_...
                                                             ^

[SQL: SELECT e.last_name, e.first_name, COUNT(i.invoice_id) AS num_invoices FROM employee e JOIN invoice i ON e.employee_id = i.support_rep_id GROUP BY e.last_name, e.first_name ORDER BY num_invoices DESC LIMIT 1; ]
(Background on this error at: https://sqlalche.me/e/20/f405)
[('Peaco