In [18]:
!pip install --break-system-packages pydantic-graph

Defaulting to user installation because normal site-packages is not writeable


In [23]:
import asyncio
from functools import lru_cache, wraps, partial

from pydantic import BaseModel
from typing import Union, Literal, Optional, Callable, Awaitable
from pydantic_ai import Agent

# ===========================
# 📦 MODELOS DE RESPOSTA
# ===========================

class ToolSuccess(BaseModel):
    type: Literal["success"] = "success"
    tool_name: str
    result: BaseModel  # O resultado real da tool (ProductInfo, OrderStatus, etc)

class ToolFailure(BaseModel):
    type: Literal["failure"] = "failure"
    tool_name: str
    error_message: str
    input_params: Optional[dict] = None

ToolResult = Union[ToolSuccess, ToolFailure]

class ProductInfo(BaseModel):
    name: str
    price: float
    available: bool

class OrderStatus(BaseModel):
    order_id: str
    status: Literal["processing", "shipped", "delivered", "cancelled"]

class ComplaintResponse(BaseModel):
    complaint_id: str
    status: str
    resolution_eta: Optional[str] = None

# ===========================
# 🛠️ FUNÇÕES AUXILIARES
# ==========================

def concurrent_tool(tool_func):
    @wraps(tool_func)
    async def wrapper(*args, **kwargs):
        return await asyncio.to_thread(tool_func, *args, **kwargs)
    return wrapper

async def execute_tool_safely(
    tool_func: Callable[..., Awaitable[BaseModel]],
    tool_name: str,
    **kwargs
) -> ToolResult:
    try:
        result = await tool_func(**kwargs)
        return ToolSuccess(tool_name=tool_name, result=result)
    except Exception as e:
        return ToolFailure(
            tool_name=tool_name,
            error_message=str(e),
            input_params=kwargs
        )

# ===========================
# 🛠️ TOOLS FUNCIONAIS
# ===========================

@lru_cache
def get_fake_product_data():
    return {
        "laptop": {"name": "Laptop Pro X", "price": 1499.99, "available": True},
        "mouse": {"name": "Mouse Ergonômico", "price": 49.90, "available": False},
    }

async def get_product_info(product_name: str) -> ProductInfo:
    await asyncio.sleep(0.1)  # menos latência
    fake_db = get_fake_product_data()
    product = fake_db.get(product_name.lower())
    if not product:
        raise ValueError(f"Produto '{product_name}' não encontrado.")
    return ProductInfo(**product)


async def check_order_status(order_id: str) -> OrderStatus:
    await asyncio.sleep(0.2)
    return OrderStatus(order_id=order_id, status="shipped")

async def register_complaint(order_id: str, reason: str) -> ComplaintResponse:
    await asyncio.sleep(0.5)
    return ComplaintResponse(
        complaint_id=f"C-{order_id}",
        status="received",
        resolution_eta="3 dias úteis"
    )

# ===========================
# 🤖 AGENTE INTELIGENTE
# ===========================

async def safe_get_product_info(product_name: str) -> ToolResult:
    result = await execute_tool_safely(get_product_info, tool_name="get_product_info", product_name=product_name)
    if isinstance(result, ToolFailure):
        result.error_message = f"Desculpe, não conseguimos encontrar o produto '{product_name}'. Por favor, tente novamente ou veja outros produtos disponíveis."
    return result


async def safe_check_order_status(order_id: str) -> ToolResult:
    return await execute_tool_safely(check_order_status, tool_name="check_order_status", order_id=order_id)

async def safe_register_complaint(order_id: str, reason: str) -> ToolResult:
    return await execute_tool_safely(register_complaint, tool_name="register_complaint", order_id=order_id, reason=reason)

tools=[
    safe_get_product_info,
    safe_check_order_status,
    safe_register_complaint
]

model_name = "gpt-4o-mini"
system_prompt = """
Você é um agente de atendimento ao cliente.
Você deve responder perguntas de forma educada e objetiva, chamando ferramentas quando necessário.
"""

agent = Agent(model=model_name, system_prompt=system_prompt, tools=tools, memory=False)

# ===========================
# 🚀 EXECUÇÃO INTERATIVA
# ===========================
perguntas = [
    "Qual o preço do laptop?",
    "E o mouse está disponível?",
    "Qual o status do pedido #XYZ123?",
    "Quero reclamar do pedido #XYZ123, veio com defeito.",
    "Qual o preço do hoverboard?"
]


responses = await asyncio.gather(*(agent.run(p) for p in perguntas))
for pergunta, response in zip(perguntas, responses):
    result = response.output
    print(f"\n👤 Usuário: {pergunta}")
    if isinstance(result, ToolSuccess):
        print(f"✅ [{result.tool_name}] → {result.result}")
    elif isinstance(result, ToolFailure):
        print(f"❌ [{result.tool_name}] Falha: {result.error_message}")
    else:
        print(f"🤖 {result}")





👤 Usuário: Qual o preço do laptop?
🤖 Desculpe, mas não consegui encontrar informações sobre o preço do laptop. Você poderia fornecer mais detalhes, como a marca ou o modelo específico que está procurando?

👤 Usuário: E o mouse está disponível?
🤖 Atualmente, não há informações disponíveis sobre a disponibilidade do mouse. Se precisar de mais detalhes ou se estiver buscando um modelo específico, por favor, me avise!

👤 Usuário: Qual o status do pedido #XYZ123?
🤖 O status do pedido #XYZ123 foi verificado com sucesso, mas não foram encontrados detalhes disponíveis sobre ele. Recomendo verificar novamente ou, se preferir, fornecer mais informações para que eu possa ajudar melhor.

👤 Usuário: Quero reclamar do pedido #XYZ123, veio com defeito.
🤖 Sua reclamação sobre o pedido #XYZ123 foi registrada com sucesso devido ao defeito identificado. Se precisar de mais alguma coisa, estou à disposição para ajudar!

👤 Usuário: Qual o preço do hoverboard?
🤖 Desculpe, não consegui encontrar informações

In [43]:
from __future__ import annotations

from dataclasses import dataclass

from rich.prompt import Prompt

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class MachineState:  
    user_balance: float = 0.0
    product: str | None = None

coins_inserted = 10

@dataclass
class InsertCoin(BaseNode[MachineState]):  
    async def run(self, ctx: GraphRunContext[MachineState]) -> CoinsInserted:  
        return CoinsInserted(coins_inserted)  


@dataclass
class CoinsInserted(BaseNode[MachineState]):
    amount: float  

    async def run(
        self, ctx: GraphRunContext[MachineState]
    ) -> SelectProduct | Purchase:  
        ctx.state.user_balance += self.amount  
        if ctx.state.product is not None:  
            return Purchase(ctx.state.product)
        else:
            return SelectProduct()


@dataclass
class SelectProduct(BaseNode[MachineState]):
    async def run(self, ctx: GraphRunContext[MachineState]) -> Purchase:
        return Purchase(Prompt.ask('Select product'))


PRODUCT_PRICES = {  
    'water': 1.25,
    'soda': 1.50,
    'crisps': 1.75,
    'chocolate': 2.00,
}


@dataclass
class Purchase(BaseNode[MachineState, None, None]):  
    product: str

    async def run(
        self, ctx: GraphRunContext[MachineState]
    ) -> End | InsertCoin | SelectProduct:
        if price := PRODUCT_PRICES.get(self.product):  
            ctx.state.product = self.product  
            if ctx.state.user_balance >= price:  
                ctx.state.user_balance -= price
                return End(None)
            else:
                diff = price - ctx.state.user_balance
                print(f'Not enough money for {self.product}, need {diff:0.2f} more')
                #> Not enough money for crisps, need 0.75 more
                return InsertCoin()  
        else:
            print(f'No such product: {self.product}, try again')
            return SelectProduct()  


vending_machine_graph = Graph(  
    nodes=[InsertCoin, CoinsInserted, SelectProduct, Purchase]
)


async def main():
    state = MachineState()  
    await vending_machine_graph.run(InsertCoin(), state=state)  
    print(f'purchase successful item={state.product} change={state.user_balance:0.2f}')
    #> purchase successful item=crisps change=0.25

await main()

purchase successful item=water change=8.75


In [4]:
import os

PROVIDER_NAME = os.getenv('PYDANTIC_AI_PROVIDER', 'openai')
MODEL_NAME = os.getenv('PYDANTIC_AI_MODEL', 'gpt-4o-mini')
pydantic_model = f'{PROVIDER_NAME}:{MODEL_NAME}'
print(f'Using model: {pydantic_model}')

Using model: openai:gpt-4o-mini


In [5]:
import os

import logfire
from pydantic import BaseModel

from pydantic_ai import Agent


class MyModel(BaseModel):
    city: str
    country: str

agent = Agent(pydantic_model, output_type=MyModel, instrument=True)
result = await agent.run('The windy city in the US of A.')

print(result.output)
print(result.usage())

13:26:22.826 agent run
13:26:22.830   chat gpt-4o-mini
city='Chicago' country='United States'
Usage(requests=1, request_tokens=58, response_tokens=20, total_tokens=78, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0})


In [6]:
from __future__ import annotations as _annotations

import asyncio
import os
from dataclasses import dataclass
from typing import Any

import logfire
from devtools import debug
from httpx import AsyncClient

from pydantic_ai import Agent, ModelRetry, RunContext

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')


@dataclass
class Deps:
    client: AsyncClient
    weather_api_key: str | None
    geo_api_key: str | None


weather_agent = Agent(
    pydantic_model,
    # 'Be concise, reply with one sentence.' is enough for some models (like openai) to use
    # the below tools appropriately, but others like anthropic and gemini require a bit more direction.
    system_prompt=(
        'Be concise, reply with one sentence.'
        'Use the `get_lat_lng` tool to get the latitude and longitude of the locations, '
        'then use the `get_weather` tool to get the weather.'
    ),
    deps_type=Deps,
    retries=2,
    instrument=True,
)


@weather_agent.tool
async def get_lat_lng(
    ctx: RunContext[Deps], location_description: str
) -> dict[str, float]:
    """Get the latitude and longitude of a location.

    Args:
        ctx: The context.
        location_description: A description of a location.
    """
    if ctx.deps.geo_api_key is None:
        # if no API key is provided, return a dummy response (London)
        return {'lat': 51.1, 'lng': -0.1}

    params = {
        'q': location_description,
        'api_key': ctx.deps.geo_api_key,
    }
    with logfire.span('calling geocode API', params=params) as span:
        r = await ctx.deps.client.get('https://geocode.maps.co/search', params=params)
        r.raise_for_status()
        data = r.json()
        span.set_attribute('response', data)

    if data:
        return {'lat': data[0]['lat'], 'lng': data[0]['lon']}
    else:
        raise ModelRetry('Could not find the location')


@weather_agent.tool
async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
    """Get the weather at a location.

    Args:
        ctx: The context.
        lat: Latitude of the location.
        lng: Longitude of the location.
    """
    if ctx.deps.weather_api_key is None:
        # if no API key is provided, return a dummy response
        return {'temperature': '21 °C', 'description': 'Sunny'}

    params = {
        'apikey': ctx.deps.weather_api_key,
        'location': f'{lat},{lng}',
        'units': 'metric',
    }
    with logfire.span('calling weather API', params=params) as span:
        r = await ctx.deps.client.get(
            'https://api.tomorrow.io/v4/weather/realtime', params=params
        )
        r.raise_for_status()
        data = r.json()
        span.set_attribute('response', data)

    values = data['data']['values']
    # https://docs.tomorrow.io/reference/data-layers-weather-codes
    code_lookup = {
        1000: 'Clear, Sunny',
        1100: 'Mostly Clear',
        1101: 'Partly Cloudy',
        1102: 'Mostly Cloudy',
        1001: 'Cloudy',
        2000: 'Fog',
        2100: 'Light Fog',
        4000: 'Drizzle',
        4001: 'Rain',
        4200: 'Light Rain',
        4201: 'Heavy Rain',
        5000: 'Snow',
        5001: 'Flurries',
        5100: 'Light Snow',
        5101: 'Heavy Snow',
        6000: 'Freezing Drizzle',
        6001: 'Freezing Rain',
        6200: 'Light Freezing Rain',
        6201: 'Heavy Freezing Rain',
        7000: 'Ice Pellets',
        7101: 'Heavy Ice Pellets',
        7102: 'Light Ice Pellets',
        8000: 'Thunderstorm',
    }
    return {
        'temperature': f'{values["temperatureApparent"]:0.0f}°C',
        'description': code_lookup.get(values['weatherCode'], 'Unknown'),
    }


async def main():
    async with AsyncClient() as client:
        # create a free API key at https://www.tomorrow.io/weather-api/
        weather_api_key = os.getenv('WEATHER_API_KEY')
        # create a free API key at https://geocode.maps.co/
        geo_api_key = os.getenv('GEO_API_KEY')
        deps = Deps(
            client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key
        )
        result = await weather_agent.run(
            'What is the weather like in London and in Wiltshire?', deps=deps
        )
        debug(result)
        print('Response:', result.output)


await main()

13:26:44.611 weather_agent run
13:26:44.613   chat gpt-4o-mini
13:26:46.357   running 2 tools
13:26:46.357     running tool: get_lat_lng
13:26:46.357     running tool: get_lat_lng
13:26:46.358   chat gpt-4o-mini
13:26:48.017   running 2 tools
13:26:48.017     running tool: get_weather
13:26:48.018     running tool: get_weather
13:26:48.019   chat gpt-4o-mini
/tmp/ipykernel_6624/3760346197.py:141 main
    result: AgentRunResult(
        output='The weather in both London and Wiltshire is sunny with a temperature of 21 °C.',
        _output_tool_name=None,
        _state=GraphAgentState(
            message_history=[
                ModelRequest(
                    parts=[
                        SystemPromptPart(
                            content=(
                                'Be concise, reply with one sentence.Use the `get_lat_lng` tool to get the latitude an'
                                'd longitude of the locations, then use the `get_weather` tool to get the weather.'
   

In [57]:
!pip install 'logfire[asyncpg]'

[1;31merror[0m: [1mexternally-managed-environment[0m

[31m×[0m This environment is externally managed
[31m╰─>[0m To install Python packages system-wide, try apt install
[31m   [0m python3-xyz, where xyz is the package you are trying to
[31m   [0m install.
[31m   [0m 
[31m   [0m If you wish to install a non-Debian-packaged Python package,
[31m   [0m create a virtual environment using python3 -m venv path/to/venv.
[31m   [0m Then use path/to/venv/bin/python and path/to/venv/bin/pip. Make
[31m   [0m sure you have python3-full installed.
[31m   [0m 
[31m   [0m If you wish to install a non-Debian packaged Python application,
[31m   [0m it may be easiest to use pipx install xyz, which will manage a
[31m   [0m virtual environment for you. Make sure you have pipx installed.
[31m   [0m 
[31m   [0m See /usr/share/doc/python3.12/README.venv for more information.

[1;35mnote[0m: If you believe this is a mistake, please contact your Python installation or OS dist

In [9]:
import asyncio
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import date
from typing import Annotated, Any, Union

import asyncpg
import logfire
from annotated_types import MinLen
from devtools import debug
from pydantic import BaseModel, Field
from typing_extensions import TypeAlias

from pydantic_ai import Agent, ModelRetry, RunContext, format_as_xml

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_asyncpg()

DB_SCHEMA = """
CREATE TABLE records (
    created_at timestamptz,
    start_timestamp timestamptz,
    end_timestamp timestamptz,
    trace_id text,
    span_id text,
    parent_span_id text,
    level log_level,
    span_name text,
    message text,
    attributes_json_schema text,
    attributes jsonb,
    tags text[],
    is_exception boolean,
    otel_status_message text,
    service_name text
);
"""
SQL_EXAMPLES = [
    {
        'request': 'show me records where foobar is false',
        'response': "SELECT * FROM records WHERE attributes->>'foobar' = false",
    },
    {
        'request': 'show me records where attributes include the key "foobar"',
        'response': "SELECT * FROM records WHERE attributes ? 'foobar'",
    },
    {
        'request': 'show me records from yesterday',
        'response': "SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day'",
    },
    {
        'request': 'show me error records with the tag "foobar"',
        'response': "SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)",
    },
]


@dataclass
class Deps:
    conn: asyncpg.Connection


class Success(BaseModel):
    """Response when SQL could be successfully generated."""

    sql_query: Annotated[str, MinLen(1)]
    explanation: str = Field(
        '', description='Explanation of the SQL query, as markdown'
    )


class InvalidRequest(BaseModel):
    """Response the user input didn't include enough information to generate SQL."""

    error_message: str


Response: TypeAlias = Union[Success, InvalidRequest]
agent: Agent[Deps, Response] = Agent(
    pydantic_model,
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    output_type=Response,  # type: ignore
    deps_type=Deps,
    instrument=True,
)


@agent.system_prompt
async def system_prompt() -> str:
    return f"""\
Given the following PostgreSQL table of records, your job is to
write a SQL query that suits the user's request.

Database schema:

{DB_SCHEMA}

today's date = {date.today()}

{format_as_xml(SQL_EXAMPLES)}
"""


@agent.output_validator
async def validate_output(ctx: RunContext[Deps], output: Response) -> Response:
    if isinstance(output, InvalidRequest):
        return output

    # gemini often adds extraneous backslashes to SQL
    output.sql_query = output.sql_query.replace('\\', '')
    if not output.sql_query.upper().startswith('SELECT'):
        raise ModelRetry('Please create a SELECT query')

    try:
        await ctx.deps.conn.execute(f'EXPLAIN {output.sql_query}')
    except asyncpg.exceptions.PostgresError as e:
        raise ModelRetry(f'Invalid query: {e}') from e
    else:
        return output


# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
@asynccontextmanager
async def database_connect(server_dsn: str, database: str) -> AsyncGenerator[Any, None]:
    with logfire.span('check and create DB'):
        conn = await asyncpg.connect(server_dsn)
        try:
            db_exists = await conn.fetchval(
                'SELECT 1 FROM pg_database WHERE datname = $1', database
            )
            if not db_exists:
                await conn.execute(f'CREATE DATABASE {database}')
        finally:
            await conn.close()

    conn = await asyncpg.connect(f'{server_dsn}/{database}')
    try:
        with logfire.span('create schema'):
            async with conn.transaction():
                if not db_exists:
                    await conn.execute(
                        "CREATE TYPE log_level AS ENUM ('debug', 'info', 'warning', 'error', 'critical')"
                    )
                    await conn.execute(DB_SCHEMA)
        yield conn
    finally:
        await conn.close()

async def main():
    if len(sys.argv) == 1:
        prompt = 'show me logs from yesterday, with level "error"'
    else:
        prompt = sys.argv[1]

    async with database_connect(
        'postgresql://postgres:postgres@localhost:5432', 'pydantic_ai_sql_gen'
    ) as conn:
        deps = Deps(conn)
        result = await agent.run(prompt, deps=deps)
    debug(result.output)



await main()



Attempting to instrument while already instrumented


13:29:31.411 check and create DB
13:29:31.461   SELECT
13:29:31.506 create schema
13:29:31.507   BEGIN;
13:29:31.508   COMMIT;
13:29:31.510 agent run
13:29:31.511   chat gpt-4o-mini
/tmp/ipykernel_6624/2723136973.py:164 main
    result.output: InvalidRequest(
        error_message="Your request doesn't contain enough information to generate a SQL query.",
    ) (InvalidRequest)


In [14]:
!pip install --break-system-package ipywidgets

Defaulting to user installation because normal site-packages is not writeable
Collecting ipywidgets
  Downloading ipywidgets-8.1.7-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Downloading widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl.metadata (20 kB)
Downloading ipywidgets-8.1.7-py3-none-any.whl (139 kB)
Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl (216 kB)
Downloading widgetsnbextension-4.0.14-py3-none-any.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: widgetsnbextension, jupyterlab_widgets, ipywidgets
Successfully installed ipywidgets-8.1.7 jupyterlab_widgets-3.0.15 widgetsnbextension-4.0.14


In [None]:
import asyncio
import os

import logfire
from rich.console import Console, ConsoleOptions, RenderResult
from rich.live import Live
from rich.markdown import CodeBlock, Markdown
from rich.syntax import Syntax
from rich.text import Text

from pydantic_ai import Agent
from pydantic_ai.models import KnownModelName

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')

agent = Agent(instrument=True)

# models to try, and the appropriate env var
models: list[tuple[KnownModelName, str]] = [
    ('openai:gpt-3.5-turbo', 'OPENAI_API_KEY'),
    ('openai:gpt-4o-mini', 'OPENAI_API_KEY'),
]


async def main():
    prettier_code_blocks()
    console = Console()
    prompt = 'Show me a short example of using Pydantic.'
    console.log(f'Asking: {prompt}...', style='cyan')
    for model, env_var in models:
        if env_var in os.environ:
            console.log(f'Using model: {model}')
            with Live('', console=console, vertical_overflow='visible') as live:
                async with agent.run_stream(prompt, model=model) as result:
                    async for message in result.stream():
                        live.update(Markdown(message))
            console.log(result.usage())
        else:
            console.log(f'{model} requires {env_var} to be set.')


def prettier_code_blocks():
    """Make rich code blocks prettier and easier to copy.

    From https://github.com/samuelcolvin/aicli/blob/v0.8.0/samuelcolvin_aicli.py#L22
    """

    class SimpleCodeBlock(CodeBlock):
        def __rich_console__(
            self, console: Console, options: ConsoleOptions
        ) -> RenderResult:
            code = str(self.text).rstrip()
            yield Text(self.lexer_name, style='dim')
            yield Syntax(
                code,
                self.lexer_name,
                theme=self.theme,
                background_color='default',
                word_wrap=True,
            )
            yield Text(f'/{self.lexer_name}', style='dim')

    Markdown.elements['fence'] = SimpleCodeBlock


await main()

Output()

In [2]:
from __future__ import annotations as _annotations

from dataclasses import dataclass, field
from pathlib import Path

import logfire
from groq import BaseModel
from pydantic_graph import (
    BaseNode,
    End,
    Graph,
    GraphRunContext,
)
from pydantic_graph.persistence.file import FileStatePersistence

from pydantic_ai import Agent, format_as_xml
from pydantic_ai.messages import ModelMessage

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')

ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True)


@dataclass
class QuestionState:
    question: str | None = None
    ask_agent_messages: list[ModelMessage] = field(default_factory=list)
    evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)


@dataclass
class Ask(BaseNode[QuestionState]):
    async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
        result = await ask_agent.run(
            'Ask a simple question with a single correct answer.',
            message_history=ctx.state.ask_agent_messages,
        )
        ctx.state.ask_agent_messages += result.all_messages()
        ctx.state.question = result.output
        return Answer(result.output)


@dataclass
class Answer(BaseNode[QuestionState]):
    question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
        answer = input(f'{self.question}: ')
        return Evaluate(answer)


class EvaluationOutput(BaseModel, use_attribute_docstrings=True):
    correct: bool
    """Whether the answer is correct."""
    comment: str
    """Comment on the answer, reprimand the user if the answer is wrong."""


evaluate_agent = Agent(
    'openai:gpt-4o',
    output_type=EvaluationOutput,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
)


@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> End[str] | Reprimand:
        assert ctx.state.question is not None
        result = await evaluate_agent.run(
            format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        ctx.state.evaluate_agent_messages += result.all_messages()
        if result.output.correct:
            return End(result.output.comment)
        else:
            return Reprimand(result.output.comment)


@dataclass
class Reprimand(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
        ctx.state.question = None
        return Ask()


question_graph = Graph(
    nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)


async def run_as_continuous():
    state = QuestionState()
    node = Ask()
    end = await question_graph.run(node, state=state)
    print('END:', end.output)


async def run_as_cli(answer: str | None):
    persistence = FileStatePersistence(Path('question_graph.json'))
    persistence.set_graph_types(question_graph)

    if snapshot := await persistence.load_next():
        state = snapshot.state
        assert answer is not None, (
            'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli <answer>"'
        )
        node = Evaluate(answer)
    else:
        state = QuestionState()
        node = Ask()
    # debug(state, node)

    async with question_graph.iter(node, state=state, persistence=persistence) as run:
        while True:
            node = await run.next()
            if isinstance(node, End):
                print('END:', node.data)
                history = await persistence.load_all()
                print('history:', '\n'.join(str(e.node) for e in history), sep='\n')
                print('Finished!')
                break
            elif isinstance(node, Answer):
                print(node.question)
                break
            # otherwise just continue

await run_as_continuous()

13:48:14.584 run graph question_graph
13:48:14.584   run node Ask
13:48:14.586     ask_agent run
13:48:14.588       chat gpt-4o
13:48:15.176   run node Answer
13:48:22.329   run node Evaluate
13:48:23.585   run node Reprimand
Comment: The answer is incorrect. The capital city of France is Paris.
13:48:23.585   run node Ask
13:48:23.586     ask_agent run
13:48:23.586       chat gpt-4o
13:48:24.488   run node Answer
13:48:30.628   run node Evaluate
13:48:31.505   run node Reprimand
Comment: The answer is incorrect. The chemical symbol for water is H2O.
13:48:31.506   run node Ask
13:48:31.506     ask_agent run
13:48:31.507       chat gpt-4o
13:48:32.167   run node Answer
13:48:43.315   run node Evaluate
13:48:44.606   run node Reprimand
Comment: The answer is incorrect. The largest planet in our solar system is Jupiter.
13:48:44.607   run node Ask
13:48:44.607     ask_agent run
13:48:44.608       chat gpt-4o
13:48:45.544   run node Answer
13:48:54.060   run node Evaluate
END: The answer 

In [11]:
from __future__ import annotations as _annotations

import asyncio
import re
import sys
import unicodedata
from contextlib import asynccontextmanager
from dataclasses import dataclass

import asyncpg
import httpx
import logfire
import pydantic_core
from openai import AsyncOpenAI
from pydantic import TypeAdapter
from typing_extensions import AsyncGenerator

from pydantic_ai import RunContext
from pydantic_ai.agent import Agent

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_asyncpg()


@dataclass
class Deps:
    openai: AsyncOpenAI
    pool: asyncpg.Pool


agent = Agent(pydantic_model, deps_type=Deps, instrument=True)


@agent.tool
async def retrieve(context: RunContext[Deps], search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        context: The call context.
        search_query: The search query.
    """
    with logfire.span(
        'create embedding for {search_query=}', search_query=search_query
    ):
        embedding = await context.deps.openai.embeddings.create(
            input=search_query,
            model='text-embedding-3-small',
        )

    assert len(embedding.data) == 1, (
        f'Expected 1 embedding, got {len(embedding.data)}, doc query: {search_query!r}'
    )
    embedding = embedding.data[0].embedding
    embedding_json = pydantic_core.to_json(embedding).decode()
    rows = await context.deps.pool.fetch(
        'SELECT url, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8',
        embedding_json,
    )
    return '\n\n'.join(
        f'# {row["title"]}\nDocumentation URL:{row["url"]}\n\n{row["content"]}\n'
        for row in rows
    )


async def run_agent(question: str):
    """Entry point to run the agent and perform RAG based question answering."""
    openai = AsyncOpenAI()
    logfire.instrument_openai(openai)

    logfire.info('Asking "{question}"', question=question)

    async with database_connect(True) as pool:
        deps = Deps(openai=openai, pool=pool)
        answer = await agent.run(question, deps=deps)
    print(answer.output)


#######################################################
# The rest of this file is dedicated to preparing the #
# search database, and some utilities.                #
#######################################################

# JSON document from
# https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992
DOCS_JSON = (
    'https://gist.githubusercontent.com/'
    'samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/'
    '80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json'
)


async def build_search_db():
    """Build the search database."""
    async with httpx.AsyncClient() as client:
        response = await client.get(DOCS_JSON)
        response.raise_for_status()
    sections = sessions_ta.validate_json(response.content)

    openai = AsyncOpenAI()
    logfire.instrument_openai(openai)

    async with database_connect(True) as pool:
        with logfire.span('create schema'):
            async with pool.acquire() as conn:
                async with conn.transaction():
                    await conn.execute(DB_SCHEMA)

        sem = asyncio.Semaphore(10)
        async with asyncio.TaskGroup() as tg:
            for section in sections:
                tg.create_task(insert_doc_section(sem, openai, pool, section))


async def insert_doc_section(
    sem: asyncio.Semaphore,
    openai: AsyncOpenAI,
    pool: asyncpg.Pool,
    section: DocsSection,
) -> None:
    async with sem:
        url = section.url()
        exists = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url)
        if exists:
            logfire.info('Skipping {url=}', url=url)
            return

        with logfire.span('create embedding for {url=}', url=url):
            embedding = await openai.embeddings.create(
                input=section.embedding_content(),
                model='text-embedding-3-small',
            )
        assert len(embedding.data) == 1, (
            f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
        )
        embedding = embedding.data[0].embedding
        embedding_json = pydantic_core.to_json(embedding).decode()
        await pool.execute(
            'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)',
            url,
            section.title,
            section.content,
            embedding_json,
        )


@dataclass
class DocsSection:
    id: int
    parent: int | None
    path: str
    level: int
    title: str
    content: str

    def url(self) -> str:
        url_path = re.sub(r'\.md$', '', self.path)
        return (
            f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, "-")}'
        )

    def embedding_content(self) -> str:
        return '\n\n'.join((f'path: {self.path}', f'title: {self.title}', self.content))


sessions_ta = TypeAdapter(list[DocsSection])


# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
@asynccontextmanager
async def database_connect(
    create_db: bool = False,
) -> AsyncGenerator[asyncpg.Pool, None]:
    server_dsn, database = (
        'postgresql://postgres:postgres@localhost:54320',
        'pydantic_ai_rag',
    )
    if create_db:
        with logfire.span('check and create DB'):
            conn = await asyncpg.connect(server_dsn)
            try:
                db_exists = await conn.fetchval(
                    'SELECT 1 FROM pg_database WHERE datname = $1', database
                )
                if not db_exists:
                    await conn.execute(f'CREATE DATABASE {database}')
            finally:
                await conn.close()

    pool = await asyncpg.create_pool(f'{server_dsn}/{database}')
    try:
        yield pool
    finally:
        await pool.close()


DB_SCHEMA = """
CREATE EXTENSION IF NOT EXISTS vector;

CREATE TABLE IF NOT EXISTS doc_sections (
    id serial PRIMARY KEY,
    url text NOT NULL UNIQUE,
    title text NOT NULL,
    content text NOT NULL,
    -- text-embedding-3-small returns a vector of 1536 floats
    embedding vector(1536) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_doc_sections_embedding ON doc_sections USING hnsw (embedding vector_l2_ops);
"""


def slugify(value: str, separator: str, unicode: bool = False) -> str:
    """Slugify a string, to make it URL friendly."""
    # Taken unchanged from https://github.com/Python-Markdown/markdown/blob/3.7/markdown/extensions/toc.py#L38
    if not unicode:
        # Replace Extended Latin characters with ASCII, i.e. `žlutý` => `zluty`
        value = unicodedata.normalize('NFKD', value)
        value = value.encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value).strip().lower()
    return re.sub(rf'[{separator}\s]+', separator, value)


action = 'search'

if action == 'build':
    await build_search_db()
elif action == 'search':
    q = 'How do I configure logfire to work with FastAPI?'
    await run_agent(q)
else:
    print(f'Unknown action: {action!r}')
    sys.exit(1)

Attempting to instrument while already instrumented


13:34:14.921 Asking "How do I configure logfire to work with FastAPI?"
13:34:14.923 check and create DB
13:34:14.953   SELECT
13:34:14.971   CREATE
13:34:15.360 agent run
13:34:15.362   chat gpt-4o-mini
13:34:16.658   running 1 tool
13:34:16.658     running tool: retrieve
13:34:16.660       create embedding for search_query=configure logfire with FastAPI
13:34:16.666         Embedding Creation with 'text-embedding-3-small' [LLM]
13:34:19.313       SELECT
13:34:19.320       SELECT


UndefinedTableError: relation "doc_sections" does not exist

In [46]:
import datetime
from dataclasses import dataclass
from typing import Literal

import logfire
from pydantic import BaseModel, Field
from rich.prompt import Prompt

from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.messages import ModelMessage
from pydantic_ai.usage import Usage, UsageLimits

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')


class FlightDetails(BaseModel):
    """Details of the most suitable flight."""

    flight_number: str
    price: int
    origin: str = Field(description='Three-letter airport code')
    destination: str = Field(description='Three-letter airport code')
    date: datetime.date


class NoFlightFound(BaseModel):
    """When no valid flight is found."""


@dataclass
class Deps:
    web_page_text: str
    req_origin: str
    req_destination: str
    req_date: datetime.date


# This agent is responsible for controlling the flow of the conversation.
search_agent = Agent[Deps, FlightDetails | NoFlightFound](
    f'openai:{model_name}',
    output_type=FlightDetails | NoFlightFound,  # type: ignore
    retries=4,
    system_prompt=(
        'Your job is to find the cheapest flight for the user on the given date. '
    ),
    instrument=True,
)


# This agent is responsible for extracting flight details from web page text.
extraction_agent = Agent(
    f'openai:{model_name}',
    output_type=list[FlightDetails],
    system_prompt='Extract all the flight details from the given text.',
)


@search_agent.tool
async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]:
    """Get details of all flights."""
    # we pass the usage to the search agent so requests within this agent are counted
    result = await extraction_agent.run(ctx.deps.web_page_text, usage=ctx.usage)
    logfire.info('found {flight_count} flights', flight_count=len(result.output))
    return result.output


@search_agent.output_validator
async def validate_output(
    ctx: RunContext[Deps], output: FlightDetails | NoFlightFound
) -> FlightDetails | NoFlightFound:
    """Procedural validation that the flight meets the constraints."""
    if isinstance(output, NoFlightFound):
        return output

    errors: list[str] = []
    if output.origin != ctx.deps.req_origin:
        errors.append(
            f'Flight should have origin {ctx.deps.req_origin}, not {output.origin}'
        )
    if output.destination != ctx.deps.req_destination:
        errors.append(
            f'Flight should have destination {ctx.deps.req_destination}, not {output.destination}'
        )
    if output.date != ctx.deps.req_date:
        errors.append(f'Flight should be on {ctx.deps.req_date}, not {output.date}')

    if errors:
        raise ModelRetry('\n'.join(errors))
    else:
        return output


class SeatPreference(BaseModel):
    row: int = Field(ge=1, le=30)
    seat: Literal['A', 'B', 'C', 'D', 'E', 'F']


class Failed(BaseModel):
    """Unable to extract a seat selection."""


# This agent is responsible for extracting the user's seat selection
seat_preference_agent = Agent[None, SeatPreference | Failed](
    f'openai:{model_name}',
    output_type=SeatPreference | Failed,  # type: ignore
    system_prompt=(
        "Extract the user's seat preference. "
        'Seats A and F are window seats. '
        'Row 1 is the front row and has extra leg room. '
        'Rows 14, and 20 also have extra leg room. '
    ),
)


# in reality this would be downloaded from a booking site,
# potentially using another agent to navigate the site
flights_web_page = """
1. Flight SFO-AK123
- Price: $350
- Origin: San Francisco International Airport (SFO)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 10, 2025

2. Flight SFO-AK456
- Price: $370
- Origin: San Francisco International Airport (SFO)
- Destination: Fairbanks International Airport (FAI)
- Date: January 10, 2025

3. Flight SFO-AK789
- Price: $400
- Origin: San Francisco International Airport (SFO)
- Destination: Juneau International Airport (JNU)
- Date: January 20, 2025

4. Flight NYC-LA101
- Price: $250
- Origin: San Francisco International Airport (SFO)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 10, 2025

5. Flight CHI-MIA202
- Price: $200
- Origin: Chicago O'Hare International Airport (ORD)
- Destination: Miami International Airport (MIA)
- Date: January 12, 2025

6. Flight BOS-SEA303
- Price: $120
- Origin: Boston Logan International Airport (BOS)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 12, 2025

7. Flight DFW-DEN404
- Price: $150
- Origin: Dallas/Fort Worth International Airport (DFW)
- Destination: Denver International Airport (DEN)
- Date: January 10, 2025

8. Flight ATL-HOU505
- Price: $180
- Origin: Hartsfield-Jackson Atlanta International Airport (ATL)
- Destination: George Bush Intercontinental Airport (IAH)
- Date: January 10, 2025
"""

# restrict how many requests this app can make to the LLM
usage_limits = UsageLimits(request_limit=15)


async def find_seat(usage: Usage) -> SeatPreference:
    message_history: list[ModelMessage] | None = None
    while True:
        answer = Prompt.ask('What seat would you like?')

        result = await seat_preference_agent.run(
            answer,
            message_history=message_history,
            usage=usage,
            usage_limits=usage_limits,
        )
        if isinstance(result.output, SeatPreference):
            return result.output
        else:
            print('Could not understand seat preference. Please try again.')
            message_history = result.all_messages()


async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference):
    print(f'Purchasing flight {flight_details=!r} {seat=!r}...')

async def main():
    deps = Deps(
        web_page_text=flights_web_page,
        req_origin='SFO',
        req_destination='ANC',
        req_date=datetime.date(2025, 1, 10),
    )
    message_history: list[ModelMessage] | None = None
    usage: Usage = Usage()
    # run the agent until a satisfactory flight is found
    while True:
        result = await search_agent.run(
            f'Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}',
            deps=deps,
            usage=usage,
            message_history=message_history,
            usage_limits=usage_limits,
        )
        if isinstance(result.output, NoFlightFound):
            print('No flight found')
            break
        else:
            flight = result.output
            print(f'Flight found: {flight}')
            answer = Prompt.ask(
                'Do you want to buy this flight, or keep searching? (buy/*search)',
                choices=['buy', 'search', ''],
                show_choices=False,
            )
            if answer == 'buy':
                seat = await find_seat(usage)
                await buy_tickets(flight, seat)
                break
            else:
                message_history = result.all_messages(
                    output_tool_return_content='Please suggest another flight'
                )

await main()

12:42:04.792 search_agent run
12:42:04.796   chat gpt-4o-mini
12:42:06.164   running 1 tool
12:42:06.164     running tool: extract_flights
12:42:11.206       found 8 flights
12:42:11.208   chat gpt-4o-mini
Flight found: flight_number='LA101' price=250 origin='SFO' destination='ANC' date=datetime.date(2025, 1, 10)


12:42:27.701 search_agent run
12:42:27.702   chat gpt-4o-mini
Flight found: flight_number='LA101' price=250 origin='SFO' destination='ANC' date=datetime.date(2025, 1, 10)


Purchasing flight flight_details=FlightDetails(flight_number='LA101', price=250, origin='SFO', destination='ANC', date=datetime.date(2025, 1, 10)) seat=SeatPreference(row=1, seat='A')...


In [None]:
from __future__ import annotations as _annotations

from dataclasses import dataclass, field

from pydantic import BaseModel, EmailStr

from pydantic_ai import Agent, format_as_xml
from pydantic_ai.messages import ModelMessage
from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class User:
    name: str
    email: EmailStr
    interests: list[str]


@dataclass
class Email:
    subject: str
    body: str


@dataclass
class State:
    user: User
    write_agent_messages: list[ModelMessage] = field(default_factory=list)


email_writer_agent = Agent(
    'google-vertex:gemini-1.5-pro',
    output_type=Email,
    system_prompt='Write a welcome email to our tech blog.',
)


@dataclass
class WriteEmail(BaseNode[State]):
    email_feedback: str | None = None

    async def run(self, ctx: GraphRunContext[State]) -> Feedback:
        if self.email_feedback:
            prompt = (
                f'Rewrite the email for the user:\n'
                f'{format_as_xml(ctx.state.user)}\n'
                f'Feedback: {self.email_feedback}'
            )
        else:
            prompt = (
                f'Write a welcome email for the user:\n'
                f'{format_as_xml(ctx.state.user)}'
            )

        result = await email_writer_agent.run(
            prompt,
            message_history=ctx.state.write_agent_messages,
        )
        ctx.state.write_agent_messages += result.new_messages()
        return Feedback(result.output)


class EmailRequiresWrite(BaseModel):
    feedback: str


class EmailOk(BaseModel):
    pass


feedback_agent = Agent[None, EmailRequiresWrite | EmailOk](
    f'openai:{model_name}',
    output_type=EmailRequiresWrite | EmailOk,  # type: ignore
    system_prompt=(
        'Review the email and provide feedback, email must reference the users specific interests.'
    ),
)


@dataclass
class Feedback(BaseNode[State, None, Email]):
    email: Email

    async def run(
        self,
        ctx: GraphRunContext[State],
    ) -> WriteEmail | End[Email]:
        prompt = format_as_xml({'user': ctx.state.user, 'email': self.email})
        result = await feedback_agent.run(prompt)
        if isinstance(result.output, EmailRequiresWrite):
            return WriteEmail(email_feedback=result.output.feedback)
        else:
            return End(self.email)


async def main():
    user = User(
        name='John Doe',
        email='john.joe@example.com',
        interests=['Haskel', 'Lisp', 'Fortran'],
    )
    state = State(user)
    feedback_graph = Graph(nodes=(WriteEmail, Feedback))
    result = await feedback_graph.run(WriteEmail(), state=state)
    print(result.output)
    """
    Email(
        subject='Welcome to our tech blog!',
        body='Hello John, Welcome to our tech blog! ...',
    )
    """



CancelledError: 

In [42]:
print(vending_machine_graph.mermaid_code(start_node=InsertCoin))

---
title: vending_machine_graph
---
stateDiagram-v2
  [*] --> InsertCoin
  InsertCoin --> CoinsInserted
  CoinsInserted --> SelectProduct
  CoinsInserted --> Purchase
  SelectProduct --> Purchase
  Purchase --> InsertCoin
  Purchase --> SelectProduct
  Purchase --> [*]


In [None]:
!pip install nest_asyncio


[1;31merror[0m: [1mexternally-managed-environment[0m

[31m×[0m This environment is externally managed
[31m╰─>[0m To install Python packages system-wide, try apt install
[31m   [0m python3-xyz, where xyz is the package you are trying to
[31m   [0m install.
[31m   [0m 
[31m   [0m If you wish to install a non-Debian-packaged Python package,
[31m   [0m create a virtual environment using python3 -m venv path/to/venv.
[31m   [0m Then use path/to/venv/bin/python and path/to/venv/bin/pip. Make
[31m   [0m sure you have python3-full installed.
[31m   [0m 
[31m   [0m If you wish to install a non-Debian packaged Python application,
[31m   [0m it may be easiest to use pipx install xyz, which will manage a
[31m   [0m virtual environment for you. Make sure you have pipx installed.
[31m   [0m 
[31m   [0m See /usr/share/doc/python3.12/README.venv for more information.

[1;35mnote[0m: If you believe this is a mistake, please contact your Python installation or OS dist

In [21]:
import simpy
import asyncio
import random

from abc import ABC, abstractmethod
from typing import Any, Dict, Callable, List

class Artifact(ABC):
    """A data object that agents interact with."""
    pass

class Context(ABC):
    """Environmental factors influencing actions."""
    data: Dict[str, Any]

    def __init__(self, **data):
        self.data = data

class Profile(ABC):
    """Behavioral profile encapsulating decision parameters for an Agent."""
    @abstractmethod
    def parameters(self) -> Dict[str, Any]:
        pass

class Policy(ABC):
    """Rule or strategy that conditions agent behavior."""
    @abstractmethod
    def evaluate(self, agent: 'Agent', artifact: Artifact, context: Context) -> bool:
        pass

class Event:
    """An occurrence triggered by an action."""
    def __init__(self, name: str, payload: Dict[str, Any] = None):
        self.name = name
        self.payload = payload or {}

# Event registry and decorator
_EVENT_HANDLERS: Dict[str, List[Callable[[Event], None]]] = {}

def on_event(event_name: str):
    """Decorator to subscribe a handler to an event."""
    def decorator(fn: Callable[[Event], None]):
        _EVENT_HANDLERS.setdefault(event_name, []).append(fn)
        return fn
    return decorator

def emit(event: Event):
    """Emit an event to all subscribers."""
    for handler in _EVENT_HANDLERS.get(event.name, []):
        handler(event)

from typing import Dict, Any, Callable
from abc import ABC

class Agent(ABC):
    def __init__(
        self,
        env: simpy.Environment,
        id: str,
        profile: Profile = None,
        policy: Policy = None,
        actions: Dict[str, Callable] = None
    ):
        self.env = env
        self.id = id
        self.profile = profile
        self.policy = policy
        self.actions = actions or {}
        self._register_actions()

    def _register_actions(self):
        for name in dir(self):
            if callable(getattr(self, name)) and not name.startswith("_"):
                self.actions.setdefault(name, getattr(self, name))

    async def perform(
        self,
        action: str,
        artifact: Artifact,
        context: Dict[str, Context],
        **kwargs
    ) -> Any:
        if action not in self.actions:
            raise ValueError(f"Action '{action}' not found for agent '{self.id}'")

        method = self.actions[action]

        result = await self._consume_async_generator(
            method(artifact, context, **kwargs)
        )

        ev = Event(f"{self.__class__.__name__}.{action}", {
            'agent_id': self.id,
            'artifact': artifact,
            'context': {k: v.dict() if hasattr(v, 'dict') else str(v) for k, v in context.items()},
            'result': result,
        })
        emit(ev)
        return result

    async def _consume_async_generator(self, async_gen) -> Any:
        """Consume the async generator and return the final value."""
        async for item in async_gen:
            # Return the last yielded value
            pass
        return item

# --- Example Implementations --- #

# Artifact subclass
class Product(Artifact):
    def __init__(self, sku: str, price: float):
        self.sku = sku
        self.price = price

class Inventory(Artifact):
    """Represents the inventory in a store."""
    
    def __init__(self):
        self.items = {}  # Mapping from product SKU to quantity

    def add_product(self, product: Product, quantity: int):
        """Add products to the inventory."""
        if product.sku in self.items:
            self.items[product.sku] += quantity
        else:
            self.items[product.sku] = quantity

    def remove_product(self, product: Product, quantity: int):
        """Remove products from inventory."""
        if product.sku in self.items and self.items[product.sku] >= quantity:
            self.items[product.sku] -= quantity
            return True
        return False

    def check_stock(self, product: Product) -> int:
        """Check the stock of a specific product."""
        return self.items.get(product.sku, 0)


# Context subclass
class MarketContext(Context):
    pass

class LeadTime(Context):
    """Represents the lead time for inventory restocking and product arrival."""
    
    def __init__(self, restocking_lead_time: int, delivery_lead_time: int):
        super().__init__(
            restocking_lead_time=restocking_lead_time, 
            delivery_lead_time=delivery_lead_time
        )
        self.restocking_lead_time = restocking_lead_time
        self.delivery_lead_time = delivery_lead_time

# Profile subclass
class CustomerProfile(Profile):
    def __init__(self, price_sensitivity: float):
        self.price_sensitivity = price_sensitivity
    def parameters(self) -> Dict[str, Any]:
        return {'price_sensitivity': self.price_sensitivity}

# Policy subclass
class PurchasePolicy(Policy):
    def __init__(self, threshold: float):
        self.threshold = threshold

    def evaluate(self, agent: Agent, artifact: Product, context: MarketContext) -> bool:
        sensitivity = agent.profile.parameters()['price_sensitivity']
        return artifact.price < self.threshold * sensitivity

# Agent subclass
class Customer(Agent):
    def __init__(self, env: simpy.Environment, id: str, profile: CustomerProfile, policy: PurchasePolicy):
        super().__init__(env, id, profile, policy)

    def define_actions(self) -> Dict[str, Callable]:
        return {
            'place_order': self.place_order,
            'add_to_cart': self.add_to_cart
        }

    async def place_order(self, store: "Store", sku: str, quantity: int, context: dict):
        return await store.process_order(self, sku, quantity, context)

    async def add_to_cart(self, product: Product, context: Dict[str, Context], quantity: int = 1):
        yield self.env.timeout(0.5)
        print(f"Customer {self.id} added {quantity}x {product.sku} to cart.")
        yield True

class InventoryManager(Agent):
    def __init__(
        self, env: simpy.Environment, id: str, 
        profile: Profile, policy: Policy, 
        inventory: Inventory
    ):
        super().__init__(env, id, profile, policy)
        self.inventory = inventory

    async def manage_inventory(
        self, 
        product: Product, context: Dict[str, Context], 
        quantity: int, action: str
    ) -> bool:
        await asyncio.sleep(0.1)
        if action == "add":
            self.inventory.add_product(product, quantity)
            return True
        elif action == "remove":
            self.inventory.remove_product(product, quantity)
            return True
        else:
            raise ValueError(f"Unknown action '{action}'")

class Salesman(Agent):
    def __init__(
        self, env: simpy.Environment, id: str, 
        profile: Profile, policy: Policy, 
        inventory: Inventory
    ):
        super().__init__(env, id, profile, policy)
        self.inventory = inventory

    async def sell(
        self, product: Product, context: MarketContext, quantity: int
    ) -> bool:
        await asyncio.sleep(1)
        if self.inventory.check_stock(product) >= quantity:
            self.inventory.remove_product(product, quantity)
            print(f"Salesman {self.id} sold {quantity}x {product.sku}.")
            return True
        print(f"Salesman {self.id} could not sell {quantity}x {product.sku} due to stock shortage.")
        return False

class Supplier(Agent):
    def __init__(
        self, env: simpy.Environment, 
        id: str, inventory: Inventory, 
        supply_time: float = 1.0
    ):
        super().__init__(env, id)
        self.inventory = inventory  # Supplier's own inventory
        self.supply_time = supply_time

    async def supply_product(self, product: Product, quantity: int, inventory: Inventory) -> bool:
        """Supply a specific store with a product."""
        print(f"Supplier {self.id} is supplying {quantity}x {product.sku} to {store.id}.")
        await asyncio.sleep(1)  # Simulate supply time
        
        # Check if supplier has enough product
        if self.inventory.check_stock(product) >= quantity:
            # Decrease the supplier's stock
            self.inventory.remove_product(product, quantity)

            # Supply time is simulated
            await asyncio.sleep(self.supply_time)

            # Restock the store's inventory
            inventory.add_product(product, quantity)
            print(f"Supplier {self.id} supplied {quantity}x{product.sku} to {store.id}.")
            return True
        else:
            print(f"Supplier {self.id} does not have enough {product.sku}.")
            return False

class Store(Agent):
    def __init__(
        self,
        env: simpy.Environment, id: str, 
        inventory: Inventory,
        suppliers: List[Supplier],
        num_salesmen: int = 5,
        restock_threshold: int = 10
    ):
        super().__init__(env, id)
        self.inventory = inventory
        self.suppliers = suppliers
        self.restock_threshold = restock_threshold

        # Staff members
        self.inventory_manager = InventoryManager(
            env, f"{id}_inv_mgr", 
            profile=None, policy=None, 
            inventory=inventory
        )
        self.salesmen = [
            Salesman(
                env, f"{id}_salesman_{i}", 
                profile=None, policy=None, 
                inventory=inventory
            )
            for i in range(num_salesmen)
        ]

    def process_order(self, customer, sku, quantity, context):
        def _inner():
            print(f"Store {self.id} received order from {customer.id} for {quantity}x{sku}")
            available_qty = self.inventory.check_stock(Product(sku, 0))
            if available_qty >= quantity:
                self.inventory.remove_product(Product(sku, 0), quantity)
                print(f"Store {self.id} fulfilled order for {customer.id}: {quantity}x{sku}")
                result = True
            else:
                print(f"Store {self.id} cannot fulfill order for {customer.id}: {quantity}x{sku}")
                result = False
            yield self.env.timeout(0)  # Optional delay
            return result
        return _inner()  # ✅ Return the generator, not the result


    async def restock_inventory(
        self, product: Product, context: Dict[str, Context], quantity: int
    ) -> bool:
        current_stock = self.inventory.check_stock(product)
        if current_stock >= self.restock_threshold:
            print(f"[{self.id}] No need to restock {product.sku} (stock: {current_stock}).")
            return False
        
        # Delegate to inventory manager who may coordinate with suppliers
        print(f"[{self.id}] Triggering restock for {product.sku}.")
        return await self.inventory_manager.manage_inventory(product, context, quantity, action="add")

    async def handle_sale(self, product: Product, context: Dict[str, Context], quantity: int) -> bool:
        # Random or round-robin strategy could be used here
        salesman = random.choice(self.salesmen)
        success = await salesman.sell(product, context, quantity)
        if success:
            print(f"[{self.id}] Sale successful: {quantity}x{product.sku}.")
        else:
            print(f"[{self.id}] Sale failed: Insufficient stock for {product.sku}.")
        return success

class RetailEcosystem:
    def __init__(self, env: simpy.Environment):
        self.env = env
        self.stores = []
        self.suppliers = []
        self.customers = []

    def add_store(self, store: Store):
        self.stores.append(store)

    def add_supplier(self, supplier: Supplier):
        self.suppliers.append(supplier)

    def add_customer(self, customer: Customer):
        self.customers.append(customer)

    def get_global_catalog(self):
        result = []
        for store in self.stores:
            for sku, quantity in store.inventory.items.items():
                if quantity > 0:
                    result.append({
                        "store_id": store.id,
                        "product": sku,
                        "available": True,
                    })
        return result

    def simulate(self, duration: int = 100):
        # Register all agent behaviors as processes
        for store in self.stores:
            self.env.process(self.store_behavior(store))

        for customer in self.customers:
            self.env.process(self.customer_behavior(customer))

        for supplier in self.suppliers:
            self.env.process(self.supplier_behavior(supplier))

        # Start simulation
        self.env.run(until=duration)

    def store_behavior(self, store: Store):
        while True:
            for product_sku, quantity in store.inventory.items.items():
                if quantity < store.restock_threshold:
                    product = Product(sku=product_sku, price=10.0)
                    context = {"time": {"lead_time": LeadTime(3, 2)}}
                    yield self.env.process(store.restock_inventory(product, context, quantity=50))
            yield self.env.timeout(10)  # Check every X time units

    def customer_behavior(self, customer: Customer):
        while True:
            catalog = self.get_global_catalog()
            if not catalog:
                print(f"No available products for {customer.id}")
                yield self.env.timeout(5)
                continue

            item = random.choice(catalog)
            store_id = item["store_id"]
            sku = item["product"]
            store = next((s for s in self.stores if s.id == store_id), None)
            if not store:
                yield self.env.timeout(1)
                continue

            quantity = random.randint(1, 5)
            context = {"market": {"lead_time": LeadTime(5, 2)}}

            # Now simulate customer placing order
            success = yield self.env.process(store.process_order(customer, sku, quantity, context))

            if success:
                print(f"Customer {customer.id} successfully ordered {quantity}x{sku} from {store_id}.")
            else:
                print(f"Customer {customer.id} failed to order {sku} from {store_id}.")

            yield self.env.timeout(random.randint(10, 20))  # Wait before next order


    def supplier_behavior(self, supplier: Supplier):
        while True:
            # Placeholder for supplier behavior (e.g., batching, responding to orders)
            yield self.env.timeout(1)

# Create simulation environment
env = simpy.Environment()

# Create products
product_1 = Product(sku="P123", price=5.0)
product_2 = Product(sku="P456", price=30.0)

# Create inventory, stores, suppliers, and customers
supplier_1_inventory = Inventory()
supplier_1_inventory.add_product(product_1, 100)
supplier_1 = Supplier(env, id="supplier_1", inventory=supplier_1_inventory)

# Create a product and add it to the supplier's inventory
supplier_2_inventory = Inventory()
supplier_2 = Supplier(env, id="supplier_2", inventory=supplier_2_inventory)

supplier_2_inventory.add_product(product_1, 50)
supplier_2_inventory.add_product(product_2, 200)

# Create a store
store_inventory = Inventory()
store_inventory.add_product(product_1, 20)

store = Store(
    env, id="store_x", 
    inventory=store_inventory, 
    suppliers=[supplier_1, supplier_2], 
    num_salesmen=3, 
    restock_threshold=10
)

# Create and add to Retail Ecosystem
duration = 30
ecosystem = RetailEcosystem(env)
ecosystem.add_store(store)

ecosystem.add_supplier(supplier_1)
ecosystem.add_supplier(supplier_2)

# Add a customer for testing
customer_1 = Customer(
    env=env, id="customer_1", 
    profile=CustomerProfile(price_sensitivity=0.8), 
    policy=PurchasePolicy(10.0)
)
customer_2 = Customer(
    env=env, id="customer_2", 
    profile=CustomerProfile(price_sensitivity=1), 
    policy=PurchasePolicy(30.0)
)

ecosystem.add_customer(customer_1)
ecosystem.add_customer(customer_2)

# Simulate the ecosystem
ecosystem.simulate(duration=duration)


Store store_x received order from customer_1 for 3xP123
Store store_x fulfilled order for customer_1: 3xP123
Store store_x received order from customer_2 for 3xP123
Store store_x fulfilled order for customer_2: 3xP123
Customer customer_1 successfully ordered 3xP123 from store_x.
Customer customer_2 successfully ordered 3xP123 from store_x.
Store store_x received order from customer_1 for 3xP123
Store store_x fulfilled order for customer_1: 3xP123
Customer customer_1 successfully ordered 3xP123 from store_x.
Store store_x received order from customer_2 for 2xP123
Store store_x fulfilled order for customer_2: 2xP123
Customer customer_2 successfully ordered 2xP123 from store_x.
[store_x] Triggering restock for P123.


AttributeError: 'coroutine' object has no attribute 'gi_frame'