## SQLAgent

In [553]:
import os
LANGCHAIN_TRACING_V2 = os.getenv("LANGCHAIN_TRACING_V2")
LANGCHAIN_ENDPOINT = os.getenv("LANGCHAIN_ENDPOINT")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")

In [554]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("postgresql://kunalkothavade:postgres@localhost:5432/chinook")
# print(db.dialect)
# print(db.get_usable_table_names())
# db.run("SELECT * FROM Artist LIMIT 10")


### Utility functions

In [555]:
from typing import Any
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode


def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Create a ToolNode with a fallback to handle errors and surface them to the agent.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )


def handle_tool_error(state) -> dict:
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }


## LLM

In [556]:
from langchain_groq import ChatGroq
llm =  ChatGroq(temperature=0, model_name="gemma2-9b-it", groq_api_key=GROQ_API_KEY)
# llm.invoke("What is sqlagent").content

### Tools

In [557]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db,llm=llm)
tools = toolkit.get_tools()

list_tables_tool = next(tool for tool in tools if tool.name == 'sql_db_list_tables')
get_schema_tool = next(tool for tool in tools if tool.name == 'sql_db_schema')

# print(list_tables_tool.invoke(""))
# print("==========================")
# print(get_schema_tool.invoke("artist"))

In [558]:
# Manually created tool

from langchain_core.tools import tool

@tool
def db_query_tool(query :str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """

    result = db.run_no_throw(query)

    if not result:
        return "No result: Query failed. Please rewrite your query and try again."
    return result
# print(db_query_tool.invoke("select * from artist limit 10"))
# print("==============================")
# print(db_query_tool.invoke("select * from artist where artist_id = 0"))

### Nodes

In [559]:
from typing import Annotated,Literal
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage,add_messages


In [560]:
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    query: str
    tables: list[str]
    schema_info: str
    sqlquery: str
    result: str | list | None
    error: str | None
    retry_count: int
    max_retries: int

In [561]:
def identify_tables_node(state:State):
    question = state["query"]
    all_tables = list_tables_tool.run("")
    prompt = f"""
                You are given a list of tables in a Postgres database: {all_tables}

                Based on the following user question: "{question}", select which tables are most relevant to answer it.

                Only Return a Python list of table names, like ['customers', 'sales'] and nothing else
            """
    relevant_tables = llm.invoke([*state.get("messages",[]), prompt]).content
    try:
        state["tables"] = eval(relevant_tables)
    except:
        state["tables"] = []

    state["messages"] = state.get("messages",[]) + [{"role":"assistant","content":relevant_tables}]
    return state


def get_schema_node(state:State):
    schemas = []
    for table in state['tables']:
        schema = get_schema_tool.invoke(table)
        schemas.append(schema)
    state["schema_info"] = "\n".join(schemas)
    state["messages"] = state.get("messages",[]) + [{"role":"assistant","content":state["schema_info"]}]
    return state

def gen_sqlquery_node(state:State):
    prompt = f"""You are a SQL expert with a strong attention to detail.

                Given an input question, output a syntactically correct PostgreSQL query.

                When generating the query:

                Output the SQL query that answers the input question.

                Use the provided schema info to generate the SQL query.

                Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
                You can order the results by a relevant column to return the most interesting examples in the database.
                Never query for all the columns from a specific table, only ask for the relevant columns given the question.

                DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
                Do not add additional string like ```sql before the sql query just output the executable sql query.
                
                Double check the Postgresql query for common mistakes, including:
                - Using NOT IN with NULL values
                - Using UNION when UNION ALL should have been used
                - Using BETWEEN for exclusive ranges
                - Data type mismatch in predicates
                - Properly quoting identifiers
                - Using the correct number of arguments for functions
                - Casting to the correct data type
                - Using the proper columns for joins

                Question : {state['query']} 
                Schema Info : {state['schema_info']}
                """
    msg = {"role":"user", "content":prompt}
    llm_response = llm.invoke([*state.get("messages",[]),msg])

    state["messages"] = state.get("messages",[]) + [{"role":"assistant","content":llm_response.content}]
    state["sqlquery"] = llm_response.content
    return state

def execute_sql_node(state:State):
    try:
        state["result"] = db_query_tool.invoke(state["sqlquery"])
        state["error"] = None
    except Exception as e:
        state["result"] = None
        state["error"] = str(e)
    
    msg = {"role":"assistant", "content": f"Execution result: {state['result'] or state['error']}"}
    state["messages"] = state.get("messages",[]) + [msg]
    return state


def fix_sqlquery_node(state:State):
    prompt = f"""Fix the PostgreSQL query below and return the corrected query only and nothing else:
                Error : {state['result']}
                Corrected query: ""
            """
    msg = {"role":"user", "content": prompt}

    llm_response = llm.invoke([*state.get("messages",[]), msg])
    state["messages"] = state.get("messages",[]) + [{"role":"assistant","content":llm_response.content}]
    state["sqlquery"] = llm_response.content
    return state


### Create Workflow

In [562]:
from langgraph.graph import END,StateGraph,START
from langchain_core.messages import AIMessage
from pydantic import BaseModel,Field

In [563]:
workflow = StateGraph(State)

workflow.add_node("IdentifyTables", identify_tables_node)
workflow.add_node("GetSchema", get_schema_node)
workflow.add_node("GenerateSQL", gen_sqlquery_node)
workflow.add_node("ExecuteSQL", execute_sql_node)
workflow.add_node("FixSQL", fix_sqlquery_node)

workflow.set_entry_point("IdentifyTables")

workflow.add_edge("IdentifyTables", "GetSchema")
workflow.add_edge("GetSchema", "GenerateSQL")
workflow.add_edge("GenerateSQL","ExecuteSQL")

def handle_execution(state):
    if state.get("result","").startswith("Error:") and state.get("retry_count", 0) < state.get("max_retries", 2):
        state["retry_count"] += 1
        return "FixSQL"
    return END

workflow.add_conditional_edges("ExecuteSQL", handle_execution)
workflow.add_edge("FixSQL", "ExecuteSQL")

<langgraph.graph.state.StateGraph at 0x324cf8f90>

In [564]:

app = workflow.compile()
initial_state = {
    "query": "Which sales agent made most sales in 2009?",
    "retry_count" : 0,
    "max_retries": 2,
    "messages" : [
        {"role":"user", 
        "content":"Which sales agent made most sales in 2009?"}
        ]
}
final_state = app.invoke(initial_state)



Wrong query: SELECT e.first_name, e.last_name 
FROM employee e
JOIN invoice_line il ON e.employee_id = il.invoice_id
WHERE EXTRACT(YEAR FROM il.invoice_date) = 2009
GROUP BY e.first_name, e.last_name
ORDER BY COUNT(il.invoice_id) DESC
LIMIT 5; 

Fix the PostgreSQL query below and return the corrected query only and nothing else:
                Error : Error: (psycopg2.errors.UndefinedColumn) column il.invoice_date does not exist
LINE 4: WHERE EXTRACT(YEAR FROM il.invoice_date) = 2009
                                ^

[SQL: SELECT e.first_name, e.last_name 
FROM employee e
JOIN invoice_line il ON e.employee_id = il.invoice_id
WHERE EXTRACT(YEAR FROM il.invoice_date) = 2009
GROUP BY e.first_name, e.last_name
ORDER BY COUNT(il.invoice_id) DESC
LIMIT 5; 
]
(Background on this error at: https://sqlalche.me/e/20/f405)
                Corrected query: ""
            
Updated query: SELECT e.first_name, e.last_name 
FROM employee e
JOIN invoice_line il ON e.employee_id = il.invoice_id
WHERE 

GraphRecursionError: Recursion limit of 25 reached without hitting a stop condition. You can increase the limit by setting the `recursion_limit` config key.
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/GRAPH_RECURSION_LIMIT

In [565]:
for event in app.stream(initial_state):
    print(event)

{'IdentifyTables': {'messages': [HumanMessage(content='Which sales agent made most sales in 2009?', additional_kwargs={}, response_metadata={}, id='06c6ab04-a925-4cdc-b434-229d40e1608c'), {'role': 'assistant', 'content': "['employee', 'invoice', 'invoice_line'] \n"}], 'query': 'Which sales agent made most sales in 2009?', 'retry_count': 0, 'max_retries': 2, 'tables': ['employee', 'invoice', 'invoice_line']}}
{'GetSchema': {'messages': [HumanMessage(content='Which sales agent made most sales in 2009?', additional_kwargs={}, response_metadata={}, id='06c6ab04-a925-4cdc-b434-229d40e1608c'), AIMessage(content="['employee', 'invoice', 'invoice_line'] \n", additional_kwargs={}, response_metadata={}, id='f608b3fa-6d5d-45f0-af73-b3bc5201512f'), {'role': 'assistant', 'content': '\nCREATE TABLE employee (\n\temployee_id INTEGER NOT NULL, \n\tlast_name VARCHAR(20) NOT NULL, \n\tfirst_name VARCHAR(20) NOT NULL, \n\ttitle VARCHAR(30), \n\treports_to INTEGER, \n\tbirth_date TIMESTAMP WITHOUT TIME ZO