# SQL Agent

Expanding on our previous agents notebook, we will now look at using an agent to query a SQL database.

In [1]:
# VS Code's Jupyter extension doesn't support loading .envrc, so if you're using VS Code, we load it here.

import sys

if ".." not in sys.path:
    sys.path.insert(0, "..")

from utils import load_envrc

load_envrc("../../.envrc")

## Dataset: North Carolina General Statutes

We will use the North Carolina General Statutes as our dataset. The statutes are available online at https://www.ncleg.gov/Laws/GeneralStatutesTOC. The `download-nc-statutes.py` script downloads the statutes and saves them to a CSV file.

In [2]:
import pandas as pd

df = pd.read_csv("input/nc_general_statutes.csv.gz")
df.head(5)

Unnamed: 0,chapter_number,chapter_title,article_number,article,section_number,text,source_url
0,1,Civil Procedure,1,Definitions,1-1,§ 1-1. Remedies.\nRemedies in the courts of j...,https://www.ncleg.gov/EnactedLegislation/Statu...
1,1,Civil Procedure,1,Definitions,1-2,§ 1-2. Actions.\nAn action is an ordinary pro...,https://www.ncleg.gov/EnactedLegislation/Statu...
2,1,Civil Procedure,1,Definitions,1-3,§ 1-3. Special proceedings.\nEvery other reme...,https://www.ncleg.gov/EnactedLegislation/Statu...
3,1,Civil Procedure,1,Definitions,1-4,§ 1-4. Kinds of actions.\nActions are of two ...,https://www.ncleg.gov/EnactedLegislation/Statu...
4,1,Civil Procedure,1,Definitions,1-5,§ 1-5. Criminal action.\nA criminal action is...,https://www.ncleg.gov/EnactedLegislation/Statu...


In [3]:
# Let's look at a specific statute to see what the data looks like.

print(df[df["section_number"] == "15A-145"].iloc[0]["text"])

§ 15A-145.  Expunction of records for first offenders under the age of 18 at the time of conviction of misdemeanor; expunction of certain other misdemeanors.
(a)	Whenever any person who has not previously been convicted of any felony, or misdemeanor other than a traffic violation, under the laws of the United States, the laws of this State or any other state, (i) pleads guilty to or is guilty of a misdemeanor other than a traffic violation, and the offense was committed before the person attained the age of 18 years, or (ii) pleads guilty to or is guilty of a misdemeanor possession of alcohol pursuant to G.S. 18B-302(b)(1), and the offense was committed before the person attained the age of 21 years, he may file a petition in the court of the county where he was convicted for expunction of the misdemeanor from his criminal record. The petition cannot be filed earlier than: (i) two years after the date of the conviction, or (ii) the completion of any period of probation, whichever occur

In [4]:
# Now let's load the data into a PostgreSQL database.

import os
from sqlalchemy import create_engine

pg_engine = create_engine(os.getenv("DATABASE_URL"))

df.to_sql("nc_general_statutes", pg_engine, if_exists="replace", index=False)

-1

### Query the data

Let's think about different ways we can query the data. Some simple examples include:

- "How many chapters are there in total?"
- "What is the title of Chapter 15A?"
- "List all articles in Chapter 15A."
- "What is the text of statute 15A-145?"

We can obviously ask more complex questions, but these simple ones will do for now. Let's try to answer these questions using SQL queries.

In [5]:
# "How many chapters are there in total?"

pd.read_sql("SELECT COUNT(DISTINCT chapter_number) FROM nc_general_statutes", pg_engine)

Unnamed: 0,count
0,395


In [6]:
# "What is the title of Chapter 15A?"

pd.read_sql(
    "SELECT DISTINCT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A'", pg_engine
)

Unnamed: 0,chapter_title
0,Criminal Procedure Act


In [7]:
# "List all articles in Chapter 15A."

pd.read_sql(
    """
    SELECT DISTINCT
        CAST(REGEXP_REPLACE(article_number, '[^0-9]', '', 'g') AS INTEGER) AS article_number_sort
        , article_number
        , article
    FROM nc_general_statutes
    WHERE chapter_number = '15A'
    ORDER BY article_number_sort, article_number
    """,
    pg_engine,
    dtype={"article_number_sort": "Int64"},
).head(20)

Unnamed: 0,article_number_sort,article_number,article
0,1,1,Definitions and General Provisions
1,2,2,Jurisdiction
2,3,3,Venue
3,4,4,Entry and Withdrawal of Attorney in Criminal Case
4,5,5,Expunction of Records
5,6,6,Certificate of Relief
6,7,7,Certificate of Relief
7,8,8,Electronic Recording of Interrogations
8,8,8A,SBI and State Crime Laboratory Access to View ...
9,9,9,Search and Seizure by Consent
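
The fourth question from the list, "What is the text of statute 15A-145?", is not run above. Here is a minimal sketch of how it could be answered. It uses an in-memory SQLite table as a stand-in for the PostgreSQL one so that it runs on its own; the two sample rows and their placeholder text are made up for illustration, and `statute_text` is a hypothetical helper name.

```python
import sqlite3
from typing import Optional

# Stand-in for the PostgreSQL table; only the columns needed here.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE nc_general_statutes (chapter_number TEXT, section_number TEXT, text TEXT)"
)
conn.executemany(
    "INSERT INTO nc_general_statutes VALUES (?, ?, ?)",
    [
        ("15A", "15A-145", "§ 15A-145. (illustrative placeholder text)"),
        ("15A", "15A-146", "§ 15A-146. (illustrative placeholder text)"),
    ],
)


def statute_text(conn: sqlite3.Connection, section_number: str) -> Optional[str]:
    """Return the text of one section, or None if it does not exist."""
    # Bind the section number as a parameter instead of formatting it into the SQL.
    row = conn.execute(
        "SELECT text FROM nc_general_statutes WHERE section_number = ?",
        (section_number,),
    ).fetchone()
    return row[0] if row else None


print(statute_text(conn, "15A-145"))
```

Against the PostgreSQL table the `SELECT` itself stays the same; only the placeholder style depends on the driver.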


# NC Statute SQL Agent

Now let's create a SQL agent that can answer questions about the North Carolina General Statutes. In addition to the questions above, ideally the agent can take certain questions further, for example: "Explain like I'm 5: Statute 15A-145".

Pydantic AI is another framework we can use to create our SQL agent. It should already be installed in your Python environment if you followed the installation instructions in the README file. We'll create one agent for generating the SQL queries and another one for executing the queries. We'll use it with an OpenAI model first.

In [8]:
from dataclasses import dataclass

import pandas as pd
import psycopg
from pydantic import BaseModel, Field
from pydantic_ai import Agent, ModelRetry, RunContext
from sqlalchemy import Connection
from typing_extensions import TypeAlias

In [9]:
model = "openai:gpt-4o"

We'll use LangChain's `SQLDatabase` tool to generate the database schema for the system prompt, as we did in the LangChain notebook. It is possible, however, to copy the tool's code out of LangChain so that we don't have to install LangChain just for that.
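
As a rough idea of what a LangChain-free replacement could look like, the sketch below renders a `CREATE TABLE` description from a list of `(column_name, data_type)` pairs. The helper name `format_table_info` is hypothetical, and the assumption is that the pairs would come from a query against PostgreSQL's `information_schema.columns`. LangChain's own output also appends a few sample rows by default, which this sketch omits.

```python
def format_table_info(table_name: str, columns: list[tuple[str, str]]) -> str:
    """Render a CREATE TABLE statement suitable for a system prompt."""
    column_lines = ",\n".join(
        f"\t{name} {data_type.upper()}" for name, data_type in columns
    )
    return f"CREATE TABLE {table_name} (\n{column_lines}\n)"


# Hypothetical rows, shaped like the result of:
#   SELECT column_name, data_type FROM information_schema.columns
#   WHERE table_name = 'nc_general_statutes' ORDER BY ordinal_position
columns = [
    ("chapter_number", "text"),
    ("chapter_title", "text"),
    ("section_number", "text"),
]

print(format_table_info("nc_general_statutes", columns))
```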

In [10]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase(engine=pg_engine)
SYSTEM_PROMPT = f"""
You are a PostgreSQL expert with read-only access to the database. Given an input question, create a syntactically correct PostgreSQL SELECT query to run to help find the answer. You should never make any DML queries (INSERT, UPDATE, DELETE, DROP etc.).

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
{db.get_table_info()}
"""

In [11]:
@dataclass
class Deps:
    # The agents receive a psycopg async connection (see the cell that opens it below).
    conn: psycopg.AsyncConnection


class Success(BaseModel):
    """Response when SQL could be successfully generated."""

    sql_query: str = Field(description="Generated SQL query")
    explanation: str = Field("", description="Explanation of the SQL query, as markdown")


class InvalidRequest(BaseModel):
    """Response when the user input didn't include enough information to generate SQL."""

    error_message: str


Response: TypeAlias = Success | InvalidRequest


sql_gen_agent: Agent[Deps, Response] = Agent(
    model,
    system_prompt=SYSTEM_PROMPT,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=Response,  # type: ignore
    deps_type=Deps,
    retries=2,
)


@sql_gen_agent.output_validator
async def validate_result(ctx: RunContext[Deps], result: Response) -> Response:
    if isinstance(result, InvalidRequest):
        return result

    # gemini often adds extraneous backslashes to SQL
    result.sql_query = result.sql_query.replace("\\", "")
    print(f"Running SQL: {result.sql_query}")
    if not result.sql_query.upper().startswith("SELECT"):
        raise ModelRetry("Please create a SELECT query")

    try:
        await ctx.deps.conn.execute(f"EXPLAIN {result.sql_query}")
    except psycopg.Error as e:
        raise ModelRetry(f"Invalid query: {e}") from e
    else:
        return result


class ExecSuccess(BaseModel):
    """Response when the SQL query was executed successfully."""

    query_result: str = Field(description="Result of the SQL query")
    analysis: str = Field("", description="Analysis of the SQL query result, as markdown")


class ExecFailure(BaseModel):
    """Response when the SQL query failed to execute."""

    error_message: str


ExecResponse: TypeAlias = ExecSuccess | ExecFailure


sql_exec_agent: Agent[Deps, ExecResponse] = Agent(
    model,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=ExecResponse,  # type: ignore
    deps_type=Deps,
    retries=2,
)


@sql_exec_agent.tool
async def run_sql(
    ctx: RunContext[Deps],
) -> ExecResponse:
    """Run and analyze a SQL query result."""
    result = await sql_gen_agent.run(user_prompt=ctx.prompt, usage=ctx.usage, deps=ctx.deps)
    print(result)

    if isinstance(result.output, InvalidRequest):
        return ExecFailure(error_message=result.output.error_message)

    print(f"got sql statement {result}")

    async with ctx.deps.conn.cursor() as acur:
        await acur.execute(result.output.sql_query)
        rows = await acur.fetchall()
        print(f"got sql result {rows}")
        return ExecSuccess(query_result=pd.DataFrame(rows).to_markdown())
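
The `startswith("SELECT")` check in the validator is deliberately simple: it rejects a query with leading whitespace, and it would let through a second statement appended after a semicolon. A slightly stricter check could look like the sketch below. The function name `is_single_select` is hypothetical, and this is a plain string heuristic rather than a SQL parser; connecting with a read-only database role remains the more reliable safeguard.

```python
def is_single_select(sql_query: str) -> bool:
    """Heuristic check that the text is exactly one SELECT statement."""
    # Ignore surrounding whitespace and one trailing semicolon.
    statement = sql_query.strip()
    if statement.endswith(";"):
        statement = statement[:-1].rstrip()
    # Any remaining semicolon suggests more than one statement.
    # This also rejects semicolons inside string literals, which is overly strict.
    if ";" in statement:
        return False
    return statement.upper().startswith("SELECT")


print(is_single_select("  SELECT 1;"))
print(is_single_select("SELECT 1; DROP TABLE nc_general_statutes"))
```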

In [12]:
async def prompt_pydantic_ai(question):
    async with await psycopg.AsyncConnection.connect(
        os.getenv("DATABASE_URL").replace("+psycopg", "")
    ) as aconn:
        deps = Deps(aconn)
        result = await sql_exec_agent.run(question, deps=deps)

    print(result.output)
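
The `.replace("+psycopg", "")` call above is needed because a SQLAlchemy URL can name a driver in its scheme (`postgresql+psycopg://`), while psycopg itself expects a plain `postgresql://` connection URI. Here is a small sketch of the same conversion as a helper that only touches the scheme; the name `to_libpq_url` and the example credentials are made up.

```python
from urllib.parse import urlsplit, urlunsplit


def to_libpq_url(sqlalchemy_url: str) -> str:
    """Drop the '+driver' suffix from the URL scheme, leaving the rest untouched."""
    parts = urlsplit(sqlalchemy_url)
    scheme = parts.scheme.split("+", 1)[0]
    return urlunsplit((scheme, parts.netloc, parts.path, parts.query, parts.fragment))


print(to_libpq_url("postgresql+psycopg://user:secret@localhost:5432/statutes"))
```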

In [13]:
await prompt_pydantic_ai("What is the title of Chapter 15A?")

Running SQL: SELECT DISTINCT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A';
AgentRunResult(output=Success(sql_query="SELECT DISTINCT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A';", explanation='This query retrieves the distinct chapter title for Chapter 15A from the `nc_general_statutes` table.'))
got sql statement AgentRunResult(output=Success(sql_query="SELECT DISTINCT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A';", explanation='This query retrieves the distinct chapter title for Chapter 15A from the `nc_general_statutes` table.'))
got sql result [('Criminal Procedure Act',)]
Running SQL: SELECT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A' LIMIT 1;
AgentRunResult(output=Success(sql_query="SELECT chapter_title FROM nc_general_statutes WHERE chapter_number = '15A' LIMIT 1;", explanation='This query retrieves the title of Chapter 15A from the `nc_general_statutes` table.'))
got sql statement Agent

In [14]:
await prompt_pydantic_ai("What is the text of section 15A-145?")

Running SQL: SELECT text FROM nc_general_statutes WHERE section_number = '15A-145'
AgentRunResult(output=Success(sql_query="SELECT text FROM nc_general_statutes WHERE section_number = '15A-145'", explanation="This query retrieves the text of the statute section with the section number '15A-145' from the nc_general_statutes table."))
got sql statement AgentRunResult(output=Success(sql_query="SELECT text FROM nc_general_statutes WHERE section_number = '15A-145'", explanation="This query retrieves the text of the statute section with the section number '15A-145' from the nc_general_statutes table."))
got sql result [("§ 15A-145. \xa0Expunction of records for first offenders under the age of 18 at the time of conviction of misdemeanor; expunction of certain other misdemeanors.\n(a)\tWhenever any person who has not previously been convicted of any felony, or misdemeanor other than a traffic violation, under the laws of the United States, the laws of this State or any other state, (i) pleads

In [15]:
await prompt_pydantic_ai("How many chapters are there in total?")

Running SQL: SELECT COUNT(DISTINCT chapter_number) FROM nc_general_statutes;
AgentRunResult(output=Success(sql_query='SELECT COUNT(DISTINCT chapter_number) FROM nc_general_statutes;', explanation='This query counts the number of distinct chapter numbers in the nc_general_statutes table, effectively giving the total number of chapters.'))
got sql statement AgentRunResult(output=Success(sql_query='SELECT COUNT(DISTINCT chapter_number) FROM nc_general_statutes;', explanation='This query counts the number of distinct chapter numbers in the nc_general_statutes table, effectively giving the total number of chapters.'))
got sql result [(395,)]
query_result='395' analysis='The total number of chapters is 395.'


In [16]:
await prompt_pydantic_ai("Explain like I'm 5: section 15A-145")

AgentRunResult(output=InvalidRequest(error_message="The requested section 15A-145 cannot be found in the given table schema. The available data only includes chapters and sections beginning with '1' from the 'nc_general_statutes' table. Please provide a different question or verify the section number."))
error_message="The requested section 15A-145 cannot be found in the given table schema. It seems the data only includes chapters and sections beginning with '1' from the 'nc_general_statutes' table. Please provide a valid section number or verify the section details."


## Using Pydantic AI with an Ollama model

I found that the above code does not work with some Ollama models. With `llama3.2` I was getting an error: `Plain text responses are not permitted, please include your response in a tool call`. It seems to be an issue with the `output_type` argument in the `Agent` constructor. I tried wrapping the output type with a `PromptedOutput` as suggested [here](https://ai.pydantic.dev/output/#prompted-output) (`NativeOutput` was failing with `Native structured output is not supported by the model`), but that did not work either.

Using an [output function](https://ai.pydantic.dev/output/#output-functions) instead of a validator for the first agent seems to fix the error. At least we can get the agent to generate some SQL queries.

In [17]:
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.ollama import OllamaProvider

model = OpenAIChatModel(
    model_name="llama3.2", provider=OllamaProvider(base_url="http://localhost:11434/v1")
)


@dataclass
class Deps:
    conn: psycopg.AsyncConnection


class InvalidRequest(BaseModel):
    """Response when the user input didn't include enough information to generate SQL."""

    error_message: str


async def validate_sql_query(ctx: RunContext[Deps], sql_query: str) -> str:
    """Validate a SQL query on the database."""
    # gemini often adds extraneous backslashes to SQL
    sql_query = sql_query.replace("\\", "")
    print(f"Validating SQL: {sql_query}")
    if not sql_query.upper().startswith("SELECT"):
        raise ModelRetry("Please create a SELECT query")

    try:
        await ctx.deps.conn.execute(f"EXPLAIN {sql_query}")
    except psycopg.Error as e:
        raise ModelRetry(f"Invalid query: {e}") from e
    else:
        return sql_query


sql_gen_agent = Agent(
    model,
    system_prompt=SYSTEM_PROMPT,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=[validate_sql_query, InvalidRequest],  # type: ignore
    deps_type=Deps,
    retries=2,
)


async def prompt_pydantic_llama(question):
    async with await psycopg.AsyncConnection.connect(
        os.getenv("DATABASE_URL").replace("+psycopg", "")
    ) as aconn:
        deps = Deps(aconn)
        result = await sql_gen_agent.run(question, deps=deps)

    print(result.output)

In [18]:
await prompt_pydantic_llama("How many chapters are there in total?")

Validating SQL: SELECT COUNT(*) FROM nc_general_statutes
SELECT COUNT(*) FROM nc_general_statutes
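
Note that this query counts rows (one per section) rather than distinct chapters, so it does not actually answer the question; the earlier query used `COUNT(DISTINCT chapter_number)`. The difference is easy to see on a toy table. In this sketch SQLite stands in for PostgreSQL and the three rows are made up.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE nc_general_statutes (chapter_number TEXT, section_number TEXT)")
conn.executemany(
    "INSERT INTO nc_general_statutes VALUES (?, ?)",
    [("1", "1-1"), ("1", "1-2"), ("15A", "15A-145")],
)

# COUNT(*) counts every row, i.e. every section.
row_count = conn.execute("SELECT COUNT(*) FROM nc_general_statutes").fetchone()[0]

# COUNT(DISTINCT ...) counts each chapter once.
chapter_count = conn.execute(
    "SELECT COUNT(DISTINCT chapter_number) FROM nc_general_statutes"
).fetchone()[0]

print(row_count, chapter_count)  # prints: 3 2
```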


## Multi-agent system

Let's create a multi-agent system where one agent handles drafting SQL queries (OpenAI) and another agent executes them (Llama).

In [19]:
import os
from dataclasses import dataclass

import pandas as pd
import psycopg
from langchain_community.utilities import SQLDatabase
from pydantic import BaseModel, Field
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.ollama import OllamaProvider
from sqlalchemy import Connection
from typing_extensions import TypeAlias


exec_model = OpenAIChatModel(
    model_name="llama3.2", provider=OllamaProvider(base_url="http://localhost:11434/v1")
)
gen_model = "openai:gpt-4o"
# gen_model = exec_model


@dataclass
class Deps:
    conn: psycopg.AsyncConnection
    db: SQLDatabase


class GenFailure(BaseModel):
    """Response if we could not generate a valid SELECT SQL query."""

    error_message: str


async def validate_sql_query(ctx: RunContext[Deps], sql_query: str) -> str:
    """Validate a SQL query on the database."""
    # gemini often adds extraneous backslashes to SQL
    sql_query = sql_query.replace("\\", "")
    print(f"Validating SQL: {sql_query}")
    if not sql_query.upper().startswith("SELECT"):
        raise ModelRetry("Please create a SELECT query")

    try:
        await ctx.deps.conn.execute(f"EXPLAIN {sql_query}")
    except psycopg.Error as e:
        raise ModelRetry(f"Invalid query: {e}") from e
    else:
        return sql_query


sql_gen_agent = Agent(
    gen_model,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=[validate_sql_query, GenFailure],  # type: ignore
    deps_type=Deps,
    retries=2,
)


@sql_gen_agent.system_prompt
def gen_agent_system_prompt(ctx: RunContext[Deps]) -> str:
    return f"""
    You are a PostgreSQL expert with read-only access to the database. Given an input question, create a syntactically correct PostgreSQL SELECT query to run to help find the answer. You should never make any DML queries (INSERT, UPDATE, DELETE, DROP etc.).

    Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

    Only use the following tables:
    {ctx.deps.db.get_table_info()}
    """


@dataclass
class ExecDeps:
    conn: psycopg.AsyncConnection
    db: SQLDatabase
    sql_query: str


class ExecFailure(BaseModel):
    """Response when the SQL query failed to execute."""

    error_message: str


sql_exec_agent = Agent(
    exec_model,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=[str, ExecFailure],  # type: ignore
    deps_type=ExecDeps,
    # retries=2,
)


@sql_exec_agent.system_prompt
def exec_agent_system_prompt(ctx: RunContext[ExecDeps]) -> str:
    return f"""
    You are a PostgreSQL expert with read-only access to a database. Answer the user's question by running this SQL query exactly as it is:
    {ctx.deps.sql_query}
    
    This is the schema of the database:
    {ctx.deps.db.get_table_info()}
    """


@sql_exec_agent.tool
async def run_sql_query(ctx: RunContext[ExecDeps], sql_query: str) -> list:
    """Run a SQL query."""
    async with ctx.deps.conn.cursor() as acur:
        await acur.execute(sql_query)
        rows = await acur.fetchall()
        print(f"got sql result {rows}")
        return rows


async def prompt_pydantic_ai(question):
    db = SQLDatabase.from_uri(os.getenv("DATABASE_URL"))

    async with await psycopg.AsyncConnection.connect(
        os.getenv("DATABASE_URL").replace("+psycopg", "")
    ) as aconn:
        deps = Deps(conn=aconn, db=db)
        result = await sql_gen_agent.run(question, deps=deps)

        print(result.output)
        if isinstance(result.output, GenFailure):
            return ExecFailure(error_message=result.output.error_message)

        deps = ExecDeps(conn=aconn, db=db, sql_query=result.output)
        result = await sql_exec_agent.run(question, deps=deps)
        print(result.output)

In [20]:
await prompt_pydantic_ai("What is the title of Chapter 15A?")

error_message='The available schema does not include details about Chapter 15A or its title. The provided data only contains entries for Chapter 1 of the NC General Statutes.'


ExecFailure(error_message='The available schema does not include details about Chapter 15A or its title. The provided data only contains entries for Chapter 1 of the NC General Statutes.')

I found that iterating over an agent's graph, using the iterator returned by `Agent.iter()`, helps me see what the agent is actually doing. 

In [21]:
async def prompt_pydantic_ai_iter(question):
    db = SQLDatabase.from_uri(os.getenv("DATABASE_URL"), sample_rows_in_table_info=0)

    async with await psycopg.AsyncConnection.connect(
        os.getenv("DATABASE_URL").replace("+psycopg", "")
    ) as aconn:
        print("GEN AGENT RUN")
        print("-" * 20)
        deps = Deps(conn=aconn, db=db)
        async with sql_gen_agent.iter(question, deps=deps) as gen_agent_run:
            async for node in gen_agent_run:
                print(node)
                print("=" * 50)

        if isinstance(gen_agent_run.result.output, GenFailure):
            print(f"Error generating the SQL query: {gen_agent_run.result.output.error_message}")
            return

        print("EXEC AGENT RUN")
        print("-" * 20)
        deps = ExecDeps(conn=aconn, db=db, sql_query=gen_agent_run.result.output)
        async with sql_exec_agent.iter(question, deps=deps) as exec_agent_run:
            async for node in exec_agent_run:
                print(node)
                print("=" * 50)

In [22]:
await prompt_pydantic_ai_iter("What is the title of Chapter 15A?")

GEN AGENT RUN
--------------------
UserPromptNode(user_prompt='What is the title of Chapter 15A?', instructions_functions=[], system_prompts=(), system_prompt_functions=[SystemPromptRunner(function=<function gen_agent_system_prompt at 0x7eb9900bf6a0>, dynamic=False, _takes_ctx=True, _is_async=False)], system_prompt_dynamic_functions={})
ModelRequestNode(request=ModelRequest(parts=[SystemPromptPart(content='\n    You are a PostgreSQL expert with read-only access to the database. Given an input question, create a syntactically correct PostgreSQL SELECT query to run to help find the answer. You should never make any DML queries (INSERT, UPDATE, DELETE, DROP etc.).\n\n    Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n\n    Only use the following tables:\n    \nCREATE TABLE nc_general_statutes (\n\tchapter_number TEXT, \n\tchapter_title TEXT

SyntaxError: syntax error at end of input
LINE 1: ...apter_title FROM nc_general_statutes WHERE chapter_number = 
                                                                       ^