Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions py/src/braintrust/wrappers/pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,26 @@ def agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):

wrap_function_wrapper(Agent, "run_sync", agent_run_sync_wrapper)

def agent_to_cli_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
    """Trace Agent.to_cli_sync() inside a Braintrust span with timing metrics.

    Ensures the agent's model is instrumented, records the CLI invocation's
    input/metadata on the span, and logs start/end/duration wall-clock metrics
    around the wrapped call.
    """
    _ensure_model_wrapped(instance)
    input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)

    # Include the agent's name in the span title when one is set.
    agent_name = getattr(instance, "name", None)
    span_name = f"agent_to_cli_sync [{agent_name}]" if agent_name else "agent_to_cli_sync"

    with start_span(
        name=span_name,
        type=SpanTypeAttribute.LLM,
        input=input_data or None,
        metadata=metadata,
    ) as agent_span:
        started = time.time()
        result = wrapped(*args, **kwargs)
        finished = time.time()
        agent_span.log(metrics={"start": started, "end": finished, "duration": finished - started})
        return result

wrap_function_wrapper(Agent, "to_cli_sync", agent_to_cli_sync_wrapper)

def agent_run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
_ensure_model_wrapped(instance)
input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
Expand Down
67 changes: 67 additions & 0 deletions py/src/braintrust/wrappers/test_pydantic_ai_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# pyright: reportUnknownParameterType=false
# pyright: reportPrivateUsage=false
import asyncio
import inspect
import time

import pytest
Expand All @@ -13,6 +14,7 @@
from pydantic import BaseModel
from pydantic_ai import Agent, ModelSettings
from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.usage import UsageLimits

PROJECT_NAME = "test-pydantic-ai-integration"
MODEL = "openai:gpt-4o-mini" # Use cheaper model for tests
Expand Down Expand Up @@ -168,6 +170,71 @@ def is_descendant(child_span, ancestor_id):
assert "completion_tokens" in agent_sync_span["metrics"]


def test_agent_to_cli_sync(memory_logger, monkeypatch):
    """Test Agent.to_cli_sync() records a CLI session span."""
    assert not memory_logger.pop()

    sig = inspect.signature(Agent.to_cli_sync)
    history = [ModelRequest(parts=[UserPromptPart(content="Previous question")])]
    agent = Agent(MODEL, name="cli-agent", model_settings=ModelSettings(max_tokens=50))
    captured = {}

    # Replace the interactive chat loop so the test never blocks on stdin;
    # capture the version-dependent kwargs for later inspection.
    async def fake_run_chat(
        *,
        stream,
        agent,
        deps,
        console,
        code_theme,
        prog_name,
        message_history,
        model_settings=None,
        usage_limits=None,
    ):
        assert stream is True
        assert prog_name == "braintrust-cli"
        assert message_history is not None
        captured["model_settings"] = model_settings
        captured["usage_limits"] = usage_limits
        return 0

    monkeypatch.setattr("pydantic_ai._cli.run_chat", fake_run_chat)

    # pydantic_ai 1.10.0 exposes a smaller to_cli_sync API; newer versions add
    # model_settings and usage_limits, so assert those fields only when present.
    has_model_settings = "model_settings" in sig.parameters
    has_usage_limits = "usage_limits" in sig.parameters

    call_kwargs = {"prog_name": "braintrust-cli", "message_history": history}
    if has_model_settings:
        call_kwargs["model_settings"] = ModelSettings(max_tokens=20, temperature=0.2)
    if has_usage_limits:
        call_kwargs["usage_limits"] = UsageLimits(request_limit=3)

    before = time.time()
    agent.to_cli_sync(**call_kwargs)
    after = time.time()

    spans = memory_logger.pop()
    assert len(spans) == 1, f"Expected 1 CLI span, got {len(spans)}"

    cli_span = spans[0]
    assert cli_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
    assert cli_span["span_attributes"]["name"] == "agent_to_cli_sync [cli-agent]"
    assert cli_span["metadata"]["model"] == "gpt-4o-mini"
    assert cli_span["metadata"]["provider"] == "openai"
    assert cli_span["input"]["prog_name"] == "braintrust-cli"
    assert "message_history" in cli_span["input"]
    if has_model_settings:
        assert captured["model_settings"] is not None
        assert cli_span["input"]["model_settings"]["max_tokens"] == 20
        assert cli_span["input"]["model_settings"]["temperature"] == 0.2
    if has_usage_limits:
        assert captured["usage_limits"] is not None
        assert cli_span["input"]["usage_limits"]["request_limit"] == 3
    _assert_metrics_are_valid(cli_span["metrics"], before, after)


@pytest.mark.vcr
@pytest.mark.asyncio
async def test_multiple_identical_sequential_streams(memory_logger):
Expand Down