# SQL Agent

Expanding on our previous agents notebook, we will now look at using an agent to query a SQL database.

In [1]:
# VS Code's Jupyter extension doesn't support loading .envrc, so if you're using VS Code, we load it here.

import sys

if ".." not in sys.path:
    sys.path.insert(0, "..")

from utils import load_envrc

load_envrc("../../.envrc")

## Dataset: North Carolina General Statutes

We will use the North Carolina General Statutes as our dataset. The statutes are available online at https://www.ncleg.gov/Laws/GeneralStatutesTOC. The `download-nc-statutes.py` script downloads the statutes and saves them to a CSV file.

In [2]:
import pandas as pd

df = pd.read_csv("input/nc_general_statutes.csv.gz")
df.head(5)

|    | chapter_number | chapter_title | article_number | article | section_number | text | source_url |
|---:|---:|:---|---:|:---|:---|:---|:---|
| 0 | 1 | Civil Procedure | 1 | Definitions | 1-1 | § 1-1. Remedies.\nRemedies in the courts of j... | https://www.ncleg.gov/EnactedLegislation/Statu... |
| 1 | 1 | Civil Procedure | 1 | Definitions | 1-2 | § 1-2. Actions.\nAn action is an ordinary pro... | https://www.ncleg.gov/EnactedLegislation/Statu... |
| 2 | 1 | Civil Procedure | 1 | Definitions | 1-3 | § 1-3. Special proceedings.\nEvery other reme... | https://www.ncleg.gov/EnactedLegislation/Statu... |
| 3 | 1 | Civil Procedure | 1 | Definitions | 1-4 | § 1-4. Kinds of actions.\nActions are of two ... | https://www.ncleg.gov/EnactedLegislation/Statu... |
| 4 | 1 | Civil Procedure | 1 | Definitions | 1-5 | § 1-5. Criminal action.\nA criminal action is... | https://www.ncleg.gov/EnactedLegislation/Statu... |


In [3]:
# Let's look at a specific statute to see what the data looks like.

print(df[df["section_number"] == "15A-145"].iloc[0]["text"])

§ 15A-145.  Expunction of records for first offenders under the age of 18 at the time of conviction of misdemeanor; expunction of certain other misdemeanors.
(a)	Whenever any person who has not previously been convicted of any felony, or misdemeanor other than a traffic violation, under the laws of the United States, the laws of this State or any other state, (i) pleads guilty to or is guilty of a misdemeanor other than a traffic violation, and the offense was committed before the person attained the age of 18 years, or (ii) pleads guilty to or is guilty of a misdemeanor possession of alcohol pursuant to G.S. 18B-302(b)(1), and the offense was committed before the person attained the age of 21 years, he may file a petition in the court of the county where he was convicted for expunction of the misdemeanor from his criminal record. The petition cannot be filed earlier than: (i) two years after the date of the conviction, or (ii) the completion of any period of probation, whichever occur

In [3]:
# Now let's load the data into a PostgreSQL database.

import os
from sqlalchemy import create_engine

pg_engine = create_engine(os.getenv("DATABASE_URL"))

df.to_sql("nc_general_statutes", pg_engine, if_exists="replace", index=False)

-1

### Query the data

Let's think about different ways we can query the data. Some simple examples include:

- "How many chapters are there in total?"
- "What is the title of Chapter 15A?"
- "List all articles in Chapter 15A."
- "What is the text of statute 15A-145?"

We can obviously ask more complex questions, but these simple ones will do for now. Let's try to answer these questions using SQL queries.

In [5]:
# "How many chapters are there in total?"

pd.read_sql("SELECT COUNT(DISTINCT chapter_number) FROM nc_general_statutes", pg_engine)

|    | count |
|---:|---:|
| 0 | 395 |
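Because each row of `nc_general_statutes` is one section, counting chapters needs `DISTINCT`; a plain `COUNT(*)` counts sections instead. The sketch below shows the difference on a four-row toy table built from rows shown earlier, using Python's built-in `sqlite3` as a stand-in for PostgreSQL (both queries are standard SQL, so they behave the same way here).

```python
import sqlite3

# Toy subset: four rows taken from the sample output above (one row per section).
rows = [
    ("1", "Civil Procedure", "1-1"),
    ("1", "Civil Procedure", "1-2"),
    ("1", "Civil Procedure", "1-3"),
    ("15A", "Criminal Procedure Act", "15A-145"),
]

# In-memory SQLite database standing in for the Postgres table.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE nc_general_statutes (chapter_number TEXT, chapter_title TEXT, section_number TEXT)"
)
conn.executemany("INSERT INTO nc_general_statutes VALUES (?, ?, ?)", rows)

# COUNT(*) counts rows, i.e. sections.
section_count = conn.execute("SELECT COUNT(*) FROM nc_general_statutes").fetchone()[0]
# COUNT(DISTINCT ...) collapses repeated chapter numbers first.
chapter_count = conn.execute(
    "SELECT COUNT(DISTINCT chapter_number) FROM nc_general_statutes"
).fetchone()[0]

print(section_count, chapter_count)  # 4 2
```

Keep this distinction in mind when reading the model-generated counting queries later in the notebook.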


In [6]:
# "What is the title of Chapter 15A?"

pd.read_sql(
    "SELECT DISTINCT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A'", pg_engine
)

|    | chapter_title |
|---:|:---|
| 0 | Criminal Procedure Act |


In [7]:
# "List all articles in Chapter 15A."

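# Strip non-digits so that e.g. '8A' sorts next to '8'.
# SELECT DISTINCT has no guaranteed row order, so sort explicitly.
pd.read_sql(
    """
    SELECT DISTINCT
        CAST(REGEXP_REPLACE(article_number, '[^0-9]', '', 'g') AS INTEGER) AS article_number_sort
        , article_number
        , article
    FROM nc_general_statutes
    WHERE chapter_number = '15A'
    ORDER BY article_number_sort, article_number
    """,
    pg_engine,
    dtype={"article_number_sort": "Int64"},
).head(20)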

|    | article_number_sort | article_number | article |
|---:|---:|:---|:---|
| 0 | 1 | 1 | Definitions and General Provisions |
| 1 | 2 | 2 | Jurisdiction |
| 2 | 3 | 3 | Venue |
| 3 | 4 | 4 | Entry and Withdrawal of Attorney in Criminal Case |
| 4 | 5 | 5 | Expunction of Records |
| 5 | 6 | 6 | Certificate of Relief |
| 6 | 7 | 7 | Certificate of Relief |
| 7 | 8 | 8 | Electronic Recording of Interrogations |
| 8 | 8 | 8A | SBI and State Crime Laboratory Access to View ... |
| 9 | 9 | 9 | Search and Seizure by Consent |
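The `article_number_sort` column exists because article numbers are text, so a plain text sort would put "10" before "2". If you need the same ordering on the pandas side, a sort key can mirror the SQL: take the digits as an integer and break ties on the raw value. This is a sketch; `article_sort_key` is a hypothetical helper, not part of the notebook.

```python
import math
import re


def article_sort_key(article_number: str) -> tuple[float, str]:
    """Sort key mirroring the SQL above: numeric part first, raw value as tie-break."""
    # Same idea as REGEXP_REPLACE(article_number, '[^0-9]', '', 'g').
    digits = re.sub(r"[^0-9]", "", article_number)
    # The SQL CAST would fail on an empty string; here such values simply sort last.
    return (int(digits) if digits else math.inf, article_number)


print(sorted(["9", "8A", "10", "8", "2"], key=article_sort_key))
# ['2', '8', '8A', '9', '10']
```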


# NC Statute SQL Agent

Now let's create a SQL agent that can answer questions about the North Carolina General Statutes. In addition to the questions above, ideally the agent can take certain questions further, for example: "Explain like I'm 5: Statute 15A-145".

## Using LangChain

Let's try [LangChain](https://python.langchain.com/docs/introduction/), a framework for developing applications powered by LLMs. It should already be installed in your Python environment if you followed the installation instructions in the README file. For the model, we'll use Ollama's `llama3.2` model. If you haven't done so already, [install Ollama](https://ollama.com/download) on your computer and download the model with `ollama pull llama3.2`.

In [5]:
from langchain_community.utilities import SQLDatabase
from langchain_ollama import ChatOllama

In [9]:
llm = ChatOllama(model="llama3.2", temperature=0)
# llm = ChatOpenAI(model="gpt-4o", temperature=0)

You can uncomment the second line above to use OpenAI's `gpt-4o` model instead. You will also need to import it with `from langchain_openai import ChatOpenAI` (from the `langchain-openai` package) and set the `OPENAI_API_KEY` environment variable.

### Generating a SQL query

For the model to be able to generate SQL queries that can get information from our database, it will need to know what tables are available and their structure. Luckily, LangChain has some tools that can simplify the process of providing that information to the model.

In [6]:
db = SQLDatabase(engine=pg_engine)
# Or db = SQLDatabase.from_uri(os.getenv("DATABASE_URL"))
# For a DB with lots of tables, you can limit the number of tables with the include_tables arg
# e.g. db = SQLDatabase(engine=pg_engine, include_tables=["nc_general_statutes"])
print(db.get_table_info())


CREATE TABLE nc_general_statutes (
	chapter_number TEXT, 
	chapter_title TEXT, 
	article_number TEXT, 
	article TEXT, 
	section_number TEXT, 
	text TEXT, 
	source_url TEXT
)

/*
3 rows from nc_general_statutes table:
chapter_number	chapter_title	article_number	article	section_number	text	source_url
1	Civil Procedure	1	Definitions	1-1	§ 1-1.  Remedies.
Remedies in the courts of justice are divided into -
(1)	Actions.
(2)	Special proc	https://www.ncleg.gov/EnactedLegislation/Statutes/HTML/ByChapter/Chapter_1.html
1	Civil Procedure	1	Definitions	1-2	§ 1-2.  Actions.
An action is an ordinary proceeding in a court of justice, by which a party prosecu	https://www.ncleg.gov/EnactedLegislation/Statutes/HTML/ByChapter/Chapter_1.html
1	Civil Procedure	1	Definitions	1-3	§ 1-3.  Special proceedings.
Every other remedy is a special proceeding. (C.C.P., s. 3; Code, s. 127	https://www.ncleg.gov/EnactedLegislation/Statutes/HTML/ByChapter/Chapter_1.html
*/


This is the schema that we'll provide to the model in a system prompt. In this case we only have one table. For a database with lots of tables, I found that the model might work better if we limit the number of tables using the `include_tables` argument in the `SQLDatabase` constructor.

Next, we'll create a system prompt that will instruct the model on how to answer users' questions. We could use a [prebuilt prompt template](https://smith.langchain.com/hub/langchain-ai/sql-query-system-prompt) for this, but let's create our own message so that we can customize it.

In [10]:
SYSTEM_PROMPT = f"""
You are a PostgreSQL expert with read-only access to the database. Given an input question, create a syntactically correct PostgreSQL SELECT query to run to help find the answer. You should never make any DML queries (INSERT, UPDATE, DELETE, DROP etc.).

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
{db.get_table_info()}
"""

Let's create a function that can prompt the model with a user's question and get back structured output.

In [12]:
from typing import Annotated, TypedDict


class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str


class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: State):
    """Generate SQL query to fetch information."""
    prompt = [
        ("system", SYSTEM_PROMPT),
        ("user", state["question"]),
    ]
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}

In [13]:
write_query({"question": "How many statutes are there?"})

{'query': 'SELECT COUNT(*) FROM nc_general_statutes'}

In [14]:
write_query({"question": "What is the title of Chapter 15A?"})

{'query': "SELECT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A'"}

### Executing the generated query

Let's create a function that runs the generated SQL query using LangChain's `QuerySQLDatabaseTool`.

In [15]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool


def execute_query(state: State):
    """Execute SQL query."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}

This is the most dangerous part of our agent. The model can generate any SQL query based on the prompt it's given, so we need to be careful about running what it produces. We instructed it to only generate `SELECT` queries, but nothing guarantees that it will comply, as the next two examples show.
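One defense-in-depth option is to screen each generated query before running it. The sketch below is a heuristic with a hypothetical name (`looks_read_only`); it is stricter than a plain prefix check because it rejects multiple statements (`SELECT 1; DROP TABLE ...`), `SELECT ... INTO`, and data-modifying CTEs, while ignoring keywords inside string literals and comments. It cannot detect side-effecting functions, so it is not a substitute for database permissions.

```python
import re

# Tokens whose contents should be ignored when scanning for keywords.
# re.sub scans left to right, so whichever token starts first wins.
_SKIP = re.compile(
    r"'(?:[^']|'')*'"  # single-quoted string literal
    r'|"(?:[^"]|"")*"'  # double-quoted identifier
    r"|--[^\n]*"  # line comment
    r"|/\*.*?\*/",  # block comment
    re.DOTALL,
)

# Keywords that indicate a write, DDL, or other side effect in PostgreSQL.
_WRITE_KEYWORDS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|MERGE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE"
    r"|COPY|INTO|CALL|DO|VACUUM|LOCK)\b",
    re.IGNORECASE,
)


def looks_read_only(sql: str) -> bool:
    """Heuristic: True if `sql` looks like one SELECT/WITH statement with no write keywords."""
    # Dollar quotes and backslash escapes can hide text from the scan below; refuse them.
    if "$" in sql or "\\" in sql:
        return False
    # Blank out strings, quoted identifiers, and comments so only SQL code is inspected.
    code = _SKIP.sub(" ", sql).strip()
    code = code.rstrip(";").strip()  # tolerate a trailing semicolon
    if not code or ";" in code:  # empty, or more than one statement
        return False
    if not re.match(r"(SELECT|WITH)\b", code, re.IGNORECASE):
        return False
    return _WRITE_KEYWORDS.search(code) is None
```

A check like this errs on the side of rejecting: a benign query that happens to use a flagged keyword outside a string (for example, a column named `do`) is refused, which is the safer failure mode here.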

In [16]:
write_query(
    {"question": "Insert a new statute in the database with random values in each table column."}
)

{'query': "INSERT INTO nc_general_statutes (chapter_number, chapter_title, article_number, article, section_number, text, source_url) VALUES ('2', 'Criminal Procedure', '1', 'Definitions', '1-4', '§ 1-4. \xa0Remedies for criminal offenses are divided into -', 'https://www.ncleg.gov/EnactedLegislation/Statutes/HTML/ByChapter/Chapter_2.html')"}

In [17]:
write_query({"question": "drop the statutes table."})

{'query': 'DROP TABLE nc_general_statutes;'}

Therefore, it's best to restrict the permissions of the Postgres user that executes the model's SQL queries. Here, that means giving the user read-only access. That way, even if `write_query` generates a DML or DDL statement, `execute_query` won't be able to run it.
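As a sketch of what "read-only access" could look like in PostgreSQL, the helper below builds the statements an admin might run to create a login role that can only `SELECT`. The helper name and the role/database names in the usage are hypothetical; the statements themselves are standard PostgreSQL. Set the role's password separately (for example with psql's password command) rather than embedding it in code.

```python
import re


def read_only_role_sql(role: str, database: str, schema: str = "public") -> list[str]:
    """Build PostgreSQL statements that create a login role limited to SELECT."""
    # Only accept plain lowercase identifiers so nothing needs quoting or escaping.
    for name in (role, database, schema):
        if not re.fullmatch(r"[a-z_][a-z0-9_]*", name):
            raise ValueError(f"unsafe identifier: {name!r}")
    return [
        f"CREATE ROLE {role} LOGIN",
        f"GRANT CONNECT ON DATABASE {database} TO {role}",
        f"GRANT USAGE ON SCHEMA {schema} TO {role}",
        # Tables that exist now.
        f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} TO {role}",
        # Tables created later by the role that runs this statement
        # (to_sql with if_exists="replace" drops and recreates the table).
        f"ALTER DEFAULT PRIVILEGES IN SCHEMA {schema} GRANT SELECT ON TABLES TO {role}",
        # Backstop only: a session can override this, so the GRANTs are the real limit.
        f"ALTER ROLE {role} SET default_transaction_read_only = on",
    ]
```

You could execute these with an admin engine (e.g. `conn.execute(sqlalchemy.text(stmt))` for each statement inside `admin_engine.begin()`) and then point the agent's connection URL at the new role. On PostgreSQL versions before 15, every role can create objects in the `public` schema by default, so you may also want to revoke `CREATE` on it.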

### Generating an answer

We can now get the result of the SQL query and generate the final answer for the user.

In [18]:
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f"Question: {state['question']}\n"
        f"SQL Query: {state['query']}\n"
        f"SQL Result: {state['result']}"
    )
    response = llm.invoke(prompt)
    return {"answer": response.content}
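In some of the runs below, `execute_query` returns an empty string and the model still produces an answer. One option is to skip the LLM when there is nothing to ground on. This is a hypothetical variant (`generate_answer_guarded`); it takes the LLM call as a parameter so it can be tried without a model, and re-declares `State` to stay self-contained.

```python
from typing import Callable, TypedDict


class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str


def generate_answer_guarded(state: State, ask_llm: Callable[[str], str]) -> dict:
    """Answer from the SQL result, but don't call the LLM when the result is empty."""
    result = state["result"].strip()
    # QuerySQLDatabaseTool returned '' for zero-row queries; treat '[]' the same way.
    if result in ("", "[]"):
        return {
            "answer": (
                f"The query returned no rows, so there is no data to answer "
                f"{state['question']!r}. Query was: {state['query']}"
            )
        }
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f"Question: {state['question']}\n"
        f"SQL Query: {state['query']}\n"
        f"SQL Result: {state['result']}"
    )
    return {"answer": ask_llm(prompt)}
```

To use it in the graph, wrap it so the step still takes only the state, e.g. `def generate_answer(state): return generate_answer_guarded(state, lambda p: llm.invoke(p).content)`.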

### Chaining the steps

Finally, we'll create a graph that connects the 3 functions into one application.

In [19]:
from langgraph.graph import START, StateGraph


graph_builder = StateGraph(State).add_sequence([write_query, execute_query, generate_answer])
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()
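Conceptually, the sequence runs each step in order and merges the partial dict it returns into the shared state. The plain-Python sketch below (a simplification, not LangGraph's implementation; `run_sequence` is a hypothetical name) also yields `{step_name: update}` after each step, which is the same shape as the `stream_mode="updates"` output printed in the runs that follow.

```python
from typing import Callable, Iterator

Step = Callable[[dict], dict]


def run_sequence(steps: list[Step], state: dict) -> Iterator[dict]:
    """Run steps in order, merging each step's partial update into the state."""
    state = dict(state)  # don't mutate the caller's dict
    for step in steps:
        update = step(state)  # each step sees everything produced so far
        state.update(update)
        yield {step.__name__: update}
```

With stub steps in place of the real `write_query`, `execute_query`, and `generate_answer`, iterating `run_sequence([...], {"question": ...})` prints one update per step, just like `graph.stream`.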

Let's create a function that will take a question and show the result at each step.

In [20]:
def prompt_langchain(question):
    for step in graph.stream({"question": question}, stream_mode="updates"):
        print(step)

In [21]:
# prompt_langchain("What is the unique title of Chapter 15A?")

In [22]:
prompt_langchain("What is the text of section 15A-145?")

{'write_query': {'query': "SELECT text FROM nc_general_statutes WHERE article_number = '15A-145'"}}
{'execute_query': {'result': ''}}
{'generate_answer': {'answer': "Based on the provided information, I don't have the actual SQL result to work with. However, I can provide a general answer.\n\nIf the SQL query returns a single row with a non-null value for the `text` column, then the text of section 15A-145 is likely contained in that row.\n\nTo answer the user question more specifically, I would need to see the actual SQL result. If you provide the result, I can give a more detailed and accurate response."}}


In [23]:
prompt_langchain("How many chapters are there in total?")

{'write_query': {'query': "SELECT COUNT(*) FROM nc_general_statutes WHERE chapter_number = '1' OR chapter_number IS NULL;"}}
{'execute_query': {'result': '[(889,)]'}}
{'generate_answer': {'answer': 'The total number of chapters is 889.'}}


In [24]:
prompt_langchain("Explain section 15A-145 to me like I'm 5")

{'write_query': {'query': "SELECT article FROM nc_general_statutes WHERE article_number = '15A-145'"}}
{'execute_query': {'result': ''}}
{'generate_answer': {'answer': 'Based on the provided information, I can explain Section 15A-145 in a way that a 5-year-old can understand.\n\nSection 15A-145 is a law in North Carolina. Imagine you have a lemonade stand, and someone comes to your stand and wants to buy some lemonade. But they don\'t want to pay for it because they say the cup you gave them was broken.\n\nThis section of the law says that if someone breaks something that belongs to you (like a cup), you can ask them to pay for it or fix it before they leave with it. It\'s like saying, "Hey, I need my cup back! You broke it, so now you have to make it right!"\n\nIn simple terms, Section 15A-145 is about taking care of things that belong to others and making sure people don\'t just take something and run away without paying for it or fixing what they broke.'}}


### Model comparison

Try switching between the Ollama and OpenAI models and compare the answers you get for the various prompts. 

## Using Pydantic AI

Pydantic AI is another framework we can use to create our SQL agent. We'll create one agent for generating the SQL queries and another for executing them, starting with an OpenAI model.

In [7]:
from dataclasses import dataclass

import pandas as pd
import psycopg
from pydantic import BaseModel, Field
from pydantic_ai import Agent, ModelRetry, RunContext
from sqlalchemy import Connection
from typing_extensions import TypeAlias

In [186]:
model = "openai:gpt-4o"

In [44]:
@dataclass
class Deps:
    conn: Connection


class Success(BaseModel):
    """Response when SQL could be successfully generated."""

    sql_query: str = Field(description="Generated SQL query")
    explanation: str = Field("", description="Explanation of the SQL query, as markdown")


class InvalidRequest(BaseModel):
    """Response when the user input didn't include enough information to generate SQL."""

    error_message: str


Response: TypeAlias = Success | InvalidRequest


sql_gen_agent: Agent[Deps, Response] = Agent(
    model,
    system_prompt=SYSTEM_PROMPT,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=Response,  # type: ignore
    deps_type=Deps,
    retries=2,
)


@sql_gen_agent.output_validator
async def validate_result(ctx: RunContext[Deps], result: Response) -> Response:
    if isinstance(result, InvalidRequest):
        return result

    # gemini often adds extraneous backslashes to SQL
    result.sql_query = result.sql_query.replace("\\", "")
    print(f"Running SQL: {result.sql_query}")
    if not result.sql_query.upper().startswith("SELECT"):
        raise ModelRetry("Please create a SELECT query")

    try:
        await ctx.deps.conn.execute(f"EXPLAIN {result.sql_query}")
    except psycopg.Error as e:
        raise ModelRetry(f"Invalid query: {e}") from e
    else:
        return result


class ExecSuccess(BaseModel):
    """Response when the SQL query was successfully executed."""

    query_result: str = Field(description="Result of the SQL query")
    analysis: str = Field("", description="Analysis of the SQL query result, as markdown")


class ExecFailure(BaseModel):
    """Response when the SQL query failed to execute."""

    error_message: str


ExecResponse: TypeAlias = ExecSuccess | ExecFailure


sql_exec_agent: Agent[Deps, ExecResponse] = Agent(
    model,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=ExecResponse,  # type: ignore
    deps_type=Deps,
    retries=2,
)


@sql_exec_agent.tool
async def run_sql(
    ctx: RunContext[Deps],
) -> ExecResponse:
    """Run and analyze a SQL query result."""
    result = await sql_gen_agent.run(user_prompt=ctx.prompt, usage=ctx.usage, deps=ctx.deps)
    print(result)

    if isinstance(result.output, InvalidRequest):
        return ExecFailure(error_message=result.output.error_message)

    print(f"got sql statement {result}")

    async with ctx.deps.conn.cursor() as acur:
        await acur.execute(result.output.sql_query)
        rows = await acur.fetchall()
        print(f"got sql result {rows}")
        return ExecSuccess(query_result=pd.DataFrame(rows).to_markdown())
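The `validate_result` validator above checks each query with `EXPLAIN`, which in PostgreSQL (without `ANALYZE`) plans a statement without executing it, so typos such as unknown columns fail before anything runs. To try the idea without a Postgres server, here is a sketch using SQLite from the standard library as a stand-in; SQLite's `EXPLAIN` likewise compiles a statement without running it. `explain_error` is a hypothetical helper name.

```python
import sqlite3
from typing import Optional


def explain_error(conn: sqlite3.Connection, sql: str) -> Optional[str]:
    """Return None if the database can compile `sql`, else the error message."""
    try:
        # EXPLAIN only produces the compiled program; the statement itself is not run.
        conn.execute(f"EXPLAIN {sql}").fetchall()
    except sqlite3.Error as e:
        return str(e)
    return None
```

Note that a successful `EXPLAIN` only says the query is valid, not that it is read-only (in SQLite, `EXPLAIN DROP TABLE ...` compiles fine without dropping anything), which is why the validator also checks that the query starts with `SELECT`.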

In [45]:
async def prompt_pydantic_ai(question):
    async with await psycopg.AsyncConnection.connect(
        os.getenv("DATABASE_URL").replace("+psycopg", "")
    ) as aconn:
        deps = Deps(aconn)
        result = await sql_exec_agent.run(question, deps=deps)

    print(result.output)

In [46]:
await prompt_pydantic_ai("What is the title of Chapter 15A?")

Running SQL: SELECT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A' LIMIT 1;
AgentRunResult(output=Success(sql_query="SELECT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A' LIMIT 1;", explanation="This query retrieves the title of Chapter 15A from the 'nc_general_statutes' table by selecting the 'chapter_title' for the row where 'chapter_number' is '15A'. The 'LIMIT 1' ensures that only one result is returned, in case of any duplicates."))
got sql statement AgentRunResult(output=Success(sql_query="SELECT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A' LIMIT 1;", explanation="This query retrieves the title of Chapter 15A from the 'nc_general_statutes' table by selecting the 'chapter_title' for the row where 'chapter_number' is '15A'. The 'LIMIT 1' ensures that only one result is returned, in case of any duplicates."))
got sql result [('Criminal Procedure Act',)]
query_result='|    | 0                      |\n|---:|:------------

In [52]:
await prompt_pydantic_ai("What is the text of section 15A-145?")

Running SQL: SELECT text FROM nc_general_statutes WHERE section_number = '15A-145'
AgentRunResult(output=Success(sql_query="SELECT text FROM nc_general_statutes WHERE section_number = '15A-145'", explanation="This query retrieves the text of the section with section_number '15A-145' from the table 'nc_general_statutes'. This will provide the specific legal text for that section."))
got sql statement AgentRunResult(output=Success(sql_query="SELECT text FROM nc_general_statutes WHERE section_number = '15A-145'", explanation="This query retrieves the text of the section with section_number '15A-145' from the table 'nc_general_statutes'. This will provide the specific legal text for that section."))
got sql result [("§ 15A-145. \xa0Expunction of records for first offenders under the age of 18 at the time of conviction of misdemeanor; expunction of certain other misdemeanors.\n(a)\tWhenever any person who has not previously been convicted of any felony, or misdemeanor other than a traffic v

In [48]:
await prompt_pydantic_ai("How many chapters are there in total?")

Running SQL: SELECT COUNT(DISTINCT chapter_number) AS total_chapters FROM nc_general_statutes;
AgentRunResult(output=Success(sql_query='SELECT COUNT(DISTINCT chapter_number) AS total_chapters FROM nc_general_statutes;', explanation='This query counts the distinct chapter numbers in the `nc_general_statutes` table to determine the total number of unique chapters available in the data.'))
got sql statement AgentRunResult(output=Success(sql_query='SELECT COUNT(DISTINCT chapter_number) AS total_chapters FROM nc_general_statutes;', explanation='This query counts the distinct chapter numbers in the `nc_general_statutes` table to determine the total number of unique chapters available in the data.'))
got sql result [(395,)]
query_result='395' analysis='There are a total of 395 chapters.'


In [53]:
await prompt_pydantic_ai("Explain like I'm 5: section 15A-145")

AgentRunResult(output=InvalidRequest(error_message="The question is about section 15A-145, but the database schema provided only contains information from the 'nc_general_statutes' table with limited sections listed, specifically under chapter 1. Therefore, I cannot generate a SQL query to explain section 15A-145 as it isn't included in the provided schema. Please ensure the section is available in the mentioned schema or provide an appropriate schema that includes section 15A-145."))
error_message='The database does not contain information on section 15A-145. Therefore, I cannot provide an explanation based on the available data. Please refer to the relevant legal texts or databases that include this section for accurate information.'


### Using Pydantic AI with an Ollama model

I found that the above code does not work with some Ollama models. With `llama3.2` I was getting an error: `Plain text responses are not permitted, please include your response in a tool call`. It seems to be an issue with the `output_type` argument in the `Agent` constructor. I tried wrapping the output type with a `PromptedOutput` as suggested [here](https://ai.pydantic.dev/output/#prompted-output) (`NativeOutput` was failing with `Native structured output is not supported by the model`), but that did not work either.

Using an [output function](https://ai.pydantic.dev/output/#output-functions) instead of a validator for the first agent seems to fix the error. At least we can get the agent to generate some SQL queries.

In [14]:
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.ollama import OllamaProvider

model = OpenAIChatModel(
    model_name="llama3.2", provider=OllamaProvider(base_url="http://localhost:11434/v1")
)


@dataclass
class Deps:
    conn: Connection


class InvalidRequest(BaseModel):
    """Response when the user input didn't include enough information to generate SQL."""

    error_message: str


async def validate_sql_query(ctx: RunContext[Deps], sql_query: str) -> str:
    """Validate a SQL query on the database."""
    # gemini often adds extraneous backslashes to SQL
    sql_query = sql_query.replace("\\", "")
    print(f"Validating SQL: {sql_query}")
    if not sql_query.upper().startswith("SELECT"):
        raise ModelRetry("Please create a SELECT query")

    try:
        await ctx.deps.conn.execute(f"EXPLAIN {sql_query}")
    except psycopg.Error as e:
        raise ModelRetry(f"Invalid query: {e}") from e
    else:
        return sql_query


sql_gen_agent = Agent(
    model,
    system_prompt=SYSTEM_PROMPT,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=[validate_sql_query, InvalidRequest],  # type: ignore
    deps_type=Deps,
    retries=2,
)


async def prompt_pydantic_llama(question):
    async with await psycopg.AsyncConnection.connect(
        os.getenv("DATABASE_URL").replace("+psycopg", "")
    ) as aconn:
        deps = Deps(aconn)
        result = await sql_gen_agent.run(question, deps=deps)

    print(result.output)

In [15]:
await prompt_pydantic_llama("How many chapters are there in total?")

Validating SQL: SELECT COUNT(chapter_number) FROM nc_general_statutes
SELECT COUNT(chapter_number) FROM nc_general_statutes


### Multi-agent system

Let's create a multi-agent system where one agent handles drafting SQL queries (OpenAI) and another agent executes them (Llama).

In [1]:
import os
from dataclasses import dataclass

import pandas as pd
import psycopg
from langchain_community.utilities import SQLDatabase
from pydantic import BaseModel, Field
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.ollama import OllamaProvider
from sqlalchemy import Connection
from typing_extensions import TypeAlias


exec_model = OpenAIChatModel(
    model_name="llama3.2", provider=OllamaProvider(base_url="http://localhost:11434/v1")
)
gen_model = "openai:gpt-4o"
# gen_model = exec_model


@dataclass
class Deps:
    conn: Connection
    db: SQLDatabase


class GenFailure(BaseModel):
    """Response if we could not generate a valid SELECT SQL query."""

    error_message: str


async def validate_sql_query(ctx: RunContext[Deps], sql_query: str) -> str:
    """Validate a SQL query on the database."""
    # gemini often adds extraneous backslashes to SQL
    sql_query = sql_query.replace("\\", "")
    print(f"Validating SQL: {sql_query}")
    if not sql_query.upper().startswith("SELECT"):
        raise ModelRetry("Please create a SELECT query")

    try:
        await ctx.deps.conn.execute(f"EXPLAIN {sql_query}")
    except psycopg.Error as e:
        raise ModelRetry(f"Invalid query: {e}") from e
    else:
        return sql_query


sql_gen_agent = Agent(
    gen_model,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=[validate_sql_query, GenFailure],  # type: ignore
    deps_type=Deps,
    retries=2,
)


@sql_gen_agent.system_prompt
def gen_agent_system_prompt(ctx: RunContext[Deps]) -> str:
    return f"""
    You are a PostgreSQL expert with read-only access to the database. Given an input question, create a syntactically correct PostgreSQL SELECT query to run to help find the answer. You should never make any DML queries (INSERT, UPDATE, DELETE, DROP etc.).

    Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

    Only use the following tables:
    {ctx.deps.db.get_table_info()}
    """


@dataclass
class ExecDeps:
    conn: Connection
    db: SQLDatabase
    sql_query: str


class ExecFailure(BaseModel):
    """Response when the SQL query failed to execute."""

    error_message: str


sql_exec_agent = Agent(
    exec_model,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=[str, ExecFailure],  # type: ignore
    deps_type=ExecDeps,
    # retries=2,
)


@sql_exec_agent.system_prompt
def exec_agent_system_prompt(ctx: RunContext[ExecDeps]) -> str:
    return f"""
    You are a PostgreSQL expert with read-only access to a database. Answer the user's question by running this SQL query exactly as it is:
    {ctx.deps.sql_query}
    
    This is the schema of the database:
    {ctx.deps.db.get_table_info()}
    """


@sql_exec_agent.tool
async def run_sql_query(ctx: RunContext[ExecDeps], sql_query: str) -> list:
    """Run a SQL query."""
    async with ctx.deps.conn.cursor() as acur:
        await acur.execute(sql_query)
        rows = await acur.fetchall()
        print(f"got sql result {rows}")
        return rows


async def prompt_pydantic_ai(question):
    db = SQLDatabase.from_uri(os.getenv("DATABASE_URL"))

    async with await psycopg.AsyncConnection.connect(
        os.getenv("DATABASE_URL").replace("+psycopg", "")
    ) as aconn:
        deps = Deps(conn=aconn, db=db)
        result = await sql_gen_agent.run(question, deps=deps)

        print(result.output)
        if isinstance(result.output, GenFailure):
            return ExecFailure(error_message=result.output.error_message)

        deps = ExecDeps(conn=aconn, db=db, sql_query=result.output)
        result = await sql_exec_agent.run(question, deps=deps)
        print(result.output)

In [5]:
await prompt_pydantic_ai("What is the title of Chapter 15A?")

Validating SQL: SELECT chapter_title FROM nc_general_statutes WHERE chapter_number = 15A
Validating SQL: SELECT chapter_title FROM nc_general_statutes WHERE chapter_number = 15
Validating SQL: SELECT chapter_title FROM nc_general_statutes WHERE chapter_number = 15A


UnexpectedModelBehavior: Exceeded maximum retries (2) for output validation
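The log shows three validation attempts before the error: with `retries=2`, the agent gets the initial attempt plus two retries. The framework-free sketch below illustrates that counting with hypothetical names (`RetryRequest`, `validate_with_retries`). It is a simplification, not Pydantic AI's implementation; there, the retry request is `ModelRetry`, and its message is sent back to the model before the next attempt.

```python
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")


class RetryRequest(Exception):
    """Hypothetical stand-in for a validator asking the model to try again."""


class RetriesExceeded(Exception):
    pass


def validate_with_retries(
    candidates: Iterable[str], validate: Callable[[str], T], retries: int = 2
) -> T:
    """Try up to 1 + retries model outputs, returning the first one that validates."""
    attempts = 0
    for candidate in candidates:
        attempts += 1
        try:
            return validate(candidate)
        except RetryRequest as e:
            print(f"attempt {attempts} rejected: {e}")
            if attempts > retries:  # initial attempt plus `retries` retries used up
                raise RetriesExceeded(f"Exceeded maximum retries ({retries})") from e
    raise RetriesExceeded("model stopped producing candidates")
```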

I found that iterating over an agent's graph, using the iterator returned by `Agent.iter()`, helps me see what the agent is actually doing. 

In [None]:
async def prompt_pydantic_ai_iter(question):
    db = SQLDatabase.from_uri(os.getenv("DATABASE_URL"), sample_rows_in_table_info=0)

    async with await psycopg.AsyncConnection.connect(
        os.getenv("DATABASE_URL").replace("+psycopg", "")
    ) as aconn:
        print("GEN AGENT RUN")
        print("-" * 20)
        deps = Deps(conn=aconn, db=db)
        async with sql_gen_agent.iter(question, deps=deps) as gen_agent_run:
            async for node in gen_agent_run:
                print(node)
                print("=" * 50)

        if isinstance(gen_agent_run.result.output, GenFailure):
            print(f"Error generating the SQL query: {gen_agent_run.result.output.error_message}")
            return

        print("EXEC AGENT RUN")
        print("-" * 20)
        deps = ExecDeps(conn=aconn, db=db, sql_query=gen_agent_run.result.output)
        async with sql_exec_agent.iter(question, deps=deps) as exec_agent_run:
            async for node in exec_agent_run:
                print(node)
                print("=" * 50)

In [6]:
await prompt_pydantic_ai_iter("What is the title of Chapter 15A?")

GEN AGENT RUN
--------------------
UserPromptNode(user_prompt='What is the title of Chapter 15A?', instructions_functions=[], system_prompts=(), system_prompt_functions=[SystemPromptRunner(function=<function gen_agent_system_prompt at 0x735e464639c0>, dynamic=False, _takes_ctx=True, _is_async=False)], system_prompt_dynamic_functions={})
ModelRequestNode(request=ModelRequest(parts=[SystemPromptPart(content='\n    You are a PostgreSQL expert with read-only access to the database. Given an input question, create a syntactically correct PostgreSQL SELECT query to run to help find the answer. You should never make any DML queries (INSERT, UPDATE, DELETE, DROP etc.).\n\n    Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n\n    Only use the following tables:\n    \nCREATE TABLE nc_general_statutes (\n\tchapter_number TEXT, \n\tchapter_title TEXT