In [6]:
from typing import Annotated, Literal
from typing import Any

import requests
from dotenv import load_dotenv
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

# Load environment variables from a .env file (if any)
load_dotenv()

# Download the Chinook sample database and save it in the current working directory
url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
# requests applies no timeout by default, so an unresponsive server would block
# the script forever; bound the wait explicitly.
response = requests.get(url, timeout=30)
if response.status_code == 200:
    with open("Chinook.db", "wb") as file:
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    # Best-effort: report the failure and let the script continue.
    print(f"Failed to download the file. Status code: {response.status_code}")

# Initialize the SQLDatabase instance from the local SQLite file
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
# Print the dialect and the usable table names as a quick sanity check
print(db.dialect)
print(db.get_usable_table_names())

# Helper for eyeballing query output while still handing the result back
def print_and_return_result(query):
    """Run ``query`` through ``db.run``, print the query and its result, and return the result.

    The query is executed before anything is printed, so nothing is echoed
    if ``db.run`` raises.
    """
    outcome = db.run(query)
    for label, value in (("Query", query), ("Result", outcome)):
        print(f"{label}: {value}")
    return outcome

# Compare a query's result against an expected value
def validate_sql_result(query, expected_output):
    """Return True when ``db.run(query)`` equals ``expected_output``, else False.

    On a mismatch, the query, the expected value and the actual value are
    printed before returning False.
    """
    actual = db.run(query)
    if actual == expected_output:
        return True

    # Mismatch: report what was expected versus what came back.
    print(f"Validation failed for query: {query}")
    print(f"Expected: {expected_output}, Got: {actual}")
    return False

# Fallback that turns a tool failure into messages the model can react to
def handle_tool_error(state) -> dict:
    """Build one error ToolMessage per tool call of the last message in ``state``.

    Reads the exception from ``state["error"]`` (None when absent) and the
    pending tool calls from ``state["messages"][-1]``. Each returned message
    carries the repr of the error and the id of the tool call it answers.

    Returns:
        A dict with a single ``"messages"`` key holding the ToolMessages.
    """
    error = state.get("error")
    last_message = state["messages"][-1]
    replies = [
        ToolMessage(
            content=f"Error: {repr(error)}\n please fix your mistakes.",
            tool_call_id=call["id"],
        )
        for call in last_message.tool_calls
    ]
    return {"messages": replies}

# Wrap tools in a ToolNode whose failures are routed to handle_tool_error
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """Return a ToolNode over ``tools`` with ``handle_tool_error`` as its fallback.

    ``exception_key="error"`` matches the ``"error"`` key that
    ``handle_tool_error`` reads from its input.
    """
    node = ToolNode(tools)
    error_reporter = RunnableLambda(handle_tool_error)
    return node.with_fallbacks([error_reporter], exception_key="error")

# Build the SQL toolkit and pick out the two tools the graph calls by name
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model="gpt-4o"))
tools = toolkit.get_tools()
# The generator variable is named `t` so it does not read like the imported `tool` decorator.
list_tables_tool = next(t for t in tools if t.name == "sql_db_list_tables")
get_schema_tool = next(t for t in tools if t.name == "sql_db_schema")

# Tool that runs a SQL query against the database for the agent
@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # The docstring above is kept verbatim: presumably @tool exposes it to the
    # model as the tool description — confirm before rewording it.
    outcome = db.run_no_throw(query)
    # NOTE(review): any falsy outcome (e.g. an empty string) is reported as a
    # failed query; confirm whether a valid query with no rows lands here too.
    return outcome or "Error: Query failed. Please rewrite your query and try again."

# Smoke test: run a sample query through the tool and print what it returns
print(db_query_tool.invoke("SELECT * FROM Artist LIMIT 10;"))

from langchain_core.prompts import ChatPromptTemplate

# System prompt for the query-checking step: lists common SQLite mistakes and
# asks the model to rewrite the query if any are found, or reproduce it unchanged.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double-check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are mistakes, rewrite the query. If no mistakes, reproduce the original query."""

# Build the chat prompt (system text + a placeholder filled from the "messages" input)
# and pipe it into a gpt-4o model that has db_query_tool bound with tool_choice="required".
query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
query_check = query_check_prompt | ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(
    [db_query_tool], tool_choice="required"
)

# Exercise the checker once with a sample query; the returned message is not
# stored or printed here.
query_check.invoke({"messages": [("user", "SELECT * FROM Artist LIMIT 10;")]})

# Define agent state: the only field is the running list of chat messages
class State(TypedDict):
    # Annotated with add_messages — presumably so that node outputs are merged
    # into the existing list rather than replacing it; confirm against LangGraph docs.
    messages: Annotated[list[AnyMessage], add_messages]

# Create new StateGraph that uses State as its schema
workflow = StateGraph(State)

# Entry node: hand-craft the first tool call instead of asking the model for it
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Return an empty-content AIMessage that requests ``sql_db_list_tables``.

    ``state`` is not read. The tool call has no arguments and uses the fixed
    id ``"tool_abcd123"``.
    """
    list_tables_call = {"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}
    kickoff = AIMessage(content="", tool_calls=[list_tables_call])
    return {"messages": [kickoff]}

# Node that passes the most recent message through the query checker
def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """Invoke ``query_check`` on only the last message in ``state`` and return its reply.

    Returns:
        A dict with a single ``"messages"`` key holding the checker's response.
    """
    latest = state["messages"][-1]
    checked = query_check.invoke({"messages": [latest]})
    return {"messages": [checked]}

# Register the entry node and the two tool nodes (each wrapped with the error fallback)
workflow.add_node("first_tool_call", first_tool_call)
workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))

# Model with only the schema tool bound; its node invokes the model on the full
# message history and returns the reply as a new message.
model_get_schema = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools([get_schema_tool])
workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])]})

# Schema the query-generation model calls to deliver its final answer.
# NOTE(review): documented with comments rather than a docstring because this
# class is passed to bind_tools; a docstring would presumably become part of
# the tool description sent to the model — confirm before adding one.
class SubmitFinalAnswer(BaseModel):
    # Required string field holding the answer text.
    final_answer: str = Field(..., description="The final answer to the user")

# System prompt for the query-generation step.
# NOTE(review): the prompt asks the model to execute the query, but the only
# tool bound below is SubmitFinalAnswer — confirm this wording is intended.
query_gen_system = """You are a SQL expert with strong attention to detail.
Generate a syntactically correct SQLite query, execute it, and return the results.

If an error occurs during query execution, rewrite and retry the query.
Never make up data, and avoid DML statements."""

# Chat prompt (system text + a placeholder filled from the "messages" input)
# piped into a gpt-4o model with SubmitFinalAnswer bound as its only tool.
query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("placeholder", "{messages}")]
)
query_gen = query_gen_prompt | ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(
    [SubmitFinalAnswer]
)

# Node that asks the model for a query or a final answer
def query_gen_node(state: State):
    """Invoke ``query_gen`` on ``state`` and reject any tool call other than SubmitFinalAnswer.

    For every tool call on the model's reply whose name is not
    ``"SubmitFinalAnswer"``, an error ToolMessage answering that call id is
    appended after the reply.

    Returns:
        A dict with a single ``"messages"`` key: the model reply followed by
        any error ToolMessages.
    """
    reply = query_gen.invoke(state)
    rejections = [
        ToolMessage(
            content=f"Error: Wrong tool called: {call['name']}. Please fix the query and remember to call only SubmitFinalAnswer.",
            tool_call_id=call["id"],
        )
        for call in (reply.tool_calls or [])
        if call["name"] != "SubmitFinalAnswer"
    ]
    return {"messages": [reply] + rejections}

# Register the query generation, query checking, and query execution nodes;
# execution goes through db_query_tool wrapped with the error fallback.
workflow.add_node("query_gen", query_gen_node)
workflow.add_node("correct_query", model_check_query)
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

# Router applied after query_gen: finish, regenerate, or send the query for checking
def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
    """Choose the next node from the last message in ``state``.

    Returns:
        ``END`` when the last message has a non-empty ``tool_calls``;
        ``"query_gen"`` when its content starts with ``"Error:"``;
        ``"correct_query"`` otherwise.
    """
    latest = state["messages"][-1]
    has_tool_calls = bool(getattr(latest, "tool_calls", None))
    if has_tool_calls:
        return END
    return "query_gen" if latest.content.startswith("Error:") else "correct_query"

# Wire the graph: a fixed path from START through table listing and schema
# lookup into query_gen, then a conditional branch chosen by should_continue.
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
workflow.add_conditional_edges("query_gen", should_continue)
# A checked query is executed, and the result goes back to query_gen.
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")

# Compile the workflow into a runnable application
app = workflow.compile()

# Run the app with sample user input
messages = app.invoke(
    {"messages": [("user", "Which sales agent made the most in sales in 2009?")]}
)
# Read the final_answer argument of the first tool call on the last message.
# Despite the variable name, the value is printed as-is and never parsed as JSON.
json_str = messages["messages"][-1].tool_calls[0]["args"]["final_answer"]
print(json_str)

# Run the same question again, this time printing each streamed event
for event in app.stream(
    {"messages": [("user", "Which sales agent made the most in sales in 2009?")]}
):
    print(event)

import json

# Target function for evaluation: run the agent on one dataset example
def predict_sql_agent_answer(example: dict):
    """Invoke ``app`` on ``example["input"]`` and return the agent's final answer.

    The answer is read from the ``final_answer`` argument of the first tool
    call on the last message of the resulting state.

    Returns:
        ``{"response": <final answer>}``.
    """
    final_state = app.invoke({"messages": ("user", example["input"])})
    submit_args = final_state["messages"][-1].tool_calls[0]["args"]
    return {"response": submit_args["final_answer"]}

# Pull the "rag-answer-vs-reference" grading prompt from the LangChain Hub;
# it is used below to grade agent answers against reference answers.
from langchain import hub
grade_prompt_answer_accuracy = hub.pull("langchain-ai/rag-answer-vs-reference")

# Evaluator that grades the agent's answer against the reference answer
def answer_evaluator(run, example) -> dict:
    """Grade ``run``'s response against ``example``'s reference answer with an LLM.

    Reads the question from ``example.inputs["input"]``, the reference from
    ``example.outputs["output"]`` and the prediction from
    ``run.outputs["response"]``, then invokes the hub grading prompt piped
    into a gpt-4-turbo model.

    Returns:
        ``{"key": "answer_v_reference_score", "score": <grader's "Score">}``.
    """
    grading_input = {
        "question": example.inputs["input"],
        "correct_answer": example.outputs["output"],
        "student_answer": run.outputs["response"],
    }
    grader = grade_prompt_answer_accuracy | ChatOpenAI(model="gpt-4-turbo", temperature=0)
    verdict = grader.invoke(grading_input)
    return {"key": "answer_v_reference_score", "score": verdict["Score"]}

# Experiment evaluation with LangSmith
from langsmith.evaluation import evaluate

# Name of the dataset passed to evaluate() as its data argument
dataset_name = "SQL Agent Response"
try:
    # Run the agent over the dataset, three repetitions, scored by answer_evaluator
    experiment_results = evaluate(
        predict_sql_agent_answer,
        data=dataset_name,
        evaluators=[answer_evaluator],
        num_repetitions=3,
        experiment_prefix="sql-agent-multi-step-response-v-reference",
        metadata={"version": "Chinook, gpt-4o multi-step-agent"},
    )
except Exception as e:
    # Best-effort: any failure is printed and otherwise ignored, so
    # experiment_results may be left undefined.
    print(e)


File downloaded and saved as Chinook.db
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]
The tables provided do not contain a direct link between sales agents and invoices. The "Employee" table lists employees, including sales agents, but does not directly associate them with sales data in the "Invoice" table. Therefore, it's not possible to determine which sales agent made the most in sales in 2009 with the given database schema.
{'first_tool_call': {'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={}, id='bea66952-6c3c-48c8-acc1-3a999be638a3', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123', 'type': 'tool_call'}])]}}
{'list_tables_tool': {'me

0it [00:00, ?it/s]