In [None]:
from __future__ import annotations

from pathlib import Path

import requests
from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langchain.messages import HumanMessage
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from pydantic import BaseModel, Field

from chain_reaction.config import APIKeys, ModelBehavior, ModelName
from chain_reaction.utils import get_structured_response

# Configure database

In [None]:
# Define local data directory by locating the repository root among the parts
# of the current working directory
project_dir_name = "chain-reaction"
path_parts = Path.cwd().parts
if project_dir_name not in path_parts:
    # Without this check, tuple.index raises a ValueError that does not say
    # what was being looked for or where.
    msg = f"Expected the working directory to be inside '{project_dir_name}', got {Path.cwd()}"
    raise ValueError(msg)
root_dir_index = path_parts.index(project_dir_name)
root_dir = Path(*path_parts[: root_dir_index + 1])
data_dir = root_dir / "data"
data_dir.mkdir(exist_ok=True)

# Define database path
db_path = data_dir / "Chinook.db"

# Download the database if it doesn't exist
if db_path.exists():
    print(f"{db_path.name} already exists, skipping download.")
else:
    url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
    # Named distinctly from `response`, which later cells use for agent output
    download_response = requests.get(url, timeout=10)
    if download_response.status_code != 200:
        # Stop here rather than printing and carrying on: the rest of the
        # notebook cannot work without the database file.
        msg = f"Failed to download the file. Status code: {download_response.status_code}"
        raise RuntimeError(msg)
    db_path.write_bytes(download_response.content)
    print(f"File downloaded and saved as {db_path}")

# Initialize db engine with a read-only connection.
# NOTE: SQLAlchemy's pysqlite dialect only passes SQLite URI options such as
# `mode=ro` through to the driver when the database is written as a `file:`
# URI and `uri=true` is present in the query string. Without that flag the
# option is dropped and the database is opened read-write.
db = SQLDatabase.from_uri(f"sqlite:///file:{db_path.as_posix()}?mode=ro&uri=true")

# Quick sanity check of the connection
print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f"Sample output: {db.run('SELECT * FROM Artist LIMIT 5;')}")  # noqa: S608

# Tools for database interactions

In [None]:
# Keyword arguments dumped from the `ModelBehavior.factual()` settings object
factual_settings = ModelBehavior.factual().model_dump()

# Build the chat model shared by the toolkit and the agent
chat_model = init_chat_model(
    model=ModelName.CLAUDE_HAIKU,
    api_key=APIKeys().anthropic,
    max_retries=2,
    timeout=60,
    **factual_settings,
)

In [None]:
# Build the SQL toolkit from the chat model and the database
toolkit = SQLDatabaseToolkit(llm=chat_model, db=db)

# Pull the individual tools out of the toolkit
# NOTE: these tools are linked to the chat model and database defined above
sql_tools = toolkit.get_tools()

# List every tool's name and description, with a blank line after each entry
for sql_tool in sql_tools:
    print(f"{sql_tool.name}: {sql_tool.description}", end="\n\n")

# System prompt

In [None]:
# Prompt template for the SQL agent. It contains two `str.format` placeholders:
#   {dialect} - the SQL dialect the agent should write queries in
#   {top_k}   - the default cap on the number of rows a query should return
# Both must be supplied via `.format(...)` before the text is used as a prompt.
system_prompt_template = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
"""

# Response format

In [None]:
# Structured-output schema for the SQL agent's final answer.
# NOTE(review): the class docstring and the Field descriptions below are
# presumably surfaced to the model through the generated schema - confirm
# before rewording them, since that would change what the model sees.
class SQLAnswer(BaseModel):
    """Response schema for a question asked that requires a SQL query."""

    # The SQL text the agent reports having run
    query: str = Field(description="The exact SQL query used to generate the answer.")
    # Nullable but declared without a default; the description tells the model
    # when returning None is acceptable
    answer: str | None = Field(
        description="""The natural language final answer based on the query results.
        It is ok to return None if the answer is best represented by the SQL query itself
        (i.e. an intermediate step in a larger analysis).
        """
    )
    # Result-set dimensions as reported by the agent; `ge=0` rejects negatives
    num_rows: int = Field(description="The number of rows returned by the SQL query.", ge=0)
    num_cols: int = Field(description="The number of columns returned by the SQL query.", ge=0)

    @classmethod
    def get_result(cls, response: dict) -> SQLAnswer:
        """Parse result from agent response.

        Args:
            response: The dict returned by invoking the agent.

        Returns:
            The value produced by `get_structured_response` for this model
            class; per the annotation this is expected to be a `SQLAnswer`.
        """
        # Delegates entirely to the project helper, passing this class as the target model
        return get_structured_response(response, model=cls)

# Define agent

In [None]:
# Fill in the prompt placeholders for this database before building the agent
sql_system_prompt = system_prompt_template.format(top_k=5, dialect=db.dialect)

# Assemble the agent from the chat model, the SQL tools and the response schema
sql_agent = create_agent(
    chat_model,
    tools=sql_tools,
    response_format=SQLAnswer,
    system_prompt=sql_system_prompt,
)

# Invoke agent

In [None]:
# Scalar response: a question whose answer is a single value
scalar_message = HumanMessage(content="Which genre on average has the longest tracks?")
response = sql_agent.invoke({"messages": [scalar_message]})
SQLAnswer.get_result(response)

In [None]:
# Tabular response: a question whose answer spans several rows
tabular_message = HumanMessage(content="What are the top 10 artists by user playlist appearances?")
response = sql_agent.invoke({"messages": [tabular_message]})
sql_answer = SQLAnswer.get_result(response)
print(sql_answer.answer)

In [None]:
# Intermediate analysis: the user wants the result for a later step
analysis_request = (
    "Please collate the number of unique playlist per artist and album."
    " I'm going to use this result in downstream analysis."
)
response = sql_agent.invoke({"messages": [HumanMessage(content=analysis_request)]})
sql_answer = SQLAnswer.get_result(response)
sql_answer