Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def test_agent_run_async(memory_logger):
assert chat_span is not None, "chat span not found"

# Check agent span
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
assert agent_span["metadata"]["provider"] == "openai"
assert TEST_PROMPT in str(agent_span["input"])
Expand All @@ -146,6 +146,18 @@ async def test_agent_run_async(memory_logger):
assert agent_span["metrics"]["prompt_tokens"] > 0
assert agent_span["metrics"]["completion_tokens"] > 0

# Regression: no double-counting of cost/tokens. Experiment-level aggregations
# sum metrics across type='llm' spans, so a single agent turn must contribute
# its tokens exactly once. The wrapper agent_run span logs the same usage as
# the leaf chat span; only the leaf should be type=LLM.
llm_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.LLM]
assert len(llm_spans) == 1, f"expected exactly one LLM-typed span, got {len(llm_spans)}"
assert llm_spans[0]["span_id"] == chat_span["span_id"]
llm_prompt_tokens_sum = sum(s["metrics"].get("prompt_tokens", 0) for s in llm_spans)
llm_completion_tokens_sum = sum(s["metrics"].get("completion_tokens", 0) for s in llm_spans)
assert llm_prompt_tokens_sum == chat_span["metrics"]["prompt_tokens"]
assert llm_completion_tokens_sum == chat_span["metrics"]["completion_tokens"]


@pytest.mark.vcr
@pytest.mark.asyncio
Expand Down Expand Up @@ -205,7 +217,7 @@ def test_agent_run_sync(memory_logger):
assert chat_span is not None, "chat span not found"

# Check agent span
assert agent_sync_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_sync_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_sync_span["metadata"]["model"] == "gpt-4o-mini"
assert agent_sync_span["metadata"]["provider"] == "openai"
assert TEST_PROMPT in str(agent_sync_span["input"])
Expand Down Expand Up @@ -287,7 +299,7 @@ async def fake_run_chat(
assert len(spans) == 1, f"Expected 1 CLI span, got {len(spans)}"

cli_span = spans[0]
assert cli_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert cli_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert cli_span["span_attributes"]["name"] == "agent_to_cli_sync [cli-agent]"
assert cli_span["metadata"]["model"] == "gpt-4o-mini"
assert cli_span["metadata"]["provider"] == "openai"
Expand Down Expand Up @@ -497,7 +509,7 @@ async def test_agent_run_stream(memory_logger):
assert chat_span is not None, "chat span not found"

# Check agent span
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
assert "Count from 1 to 5" in str(agent_span["input"])
_assert_metrics_are_valid(agent_span["metrics"], start, end)
Expand Down Expand Up @@ -607,7 +619,7 @@ async def test_direct_model_request(memory_logger, direct):
direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request"), None)
assert direct_span is not None

assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert direct_span["metadata"]["model"] == "gpt-4o-mini"
assert direct_span["metadata"]["provider"] == "openai"
assert TEST_PROMPT in str(direct_span["input"])
Expand Down Expand Up @@ -637,7 +649,7 @@ def test_direct_model_request_sync(memory_logger, direct):
# Find the model_request_sync span
span = next((s for s in spans if s["span_attributes"]["name"] == "model_request_sync"), None)
assert span is not None, "model_request_sync span not found"
assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert span["metadata"]["model"] == "gpt-4o-mini"
assert TEST_PROMPT in str(span["input"])
_assert_metrics_are_valid(span["metrics"], start, end)
Expand Down Expand Up @@ -668,7 +680,7 @@ async def test_direct_model_request_with_settings(memory_logger, direct):
direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request"), None)
assert direct_span is not None

assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.TASK

# Verify model_settings is in input (NOT metadata)
assert "model_settings" in direct_span["input"], "model_settings should be in input"
Expand Down Expand Up @@ -713,7 +725,7 @@ async def test_direct_model_request_stream(memory_logger, direct):
direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request_stream"), None)
assert direct_span is not None

assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert direct_span["metadata"]["model"] == "gpt-4o-mini"
_assert_metrics_are_valid(direct_span["metrics"], start, end)

Expand Down Expand Up @@ -804,7 +816,7 @@ class MathAnswer(BaseModel):
assert chat_span is not None, "chat span not found"

# Check agent span
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
assert agent_span["metadata"]["provider"] == "openai"
assert "10 + 15" in str(agent_span["input"])
Expand Down Expand Up @@ -1092,7 +1104,7 @@ def test_agent_run_stream_sync(memory_logger):
assert chat_span is not None, "chat span not found"

# Check agent span
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
assert "Count from 1 to 3" in str(agent_span["input"])
_assert_metrics_are_valid(agent_span["metrics"], start, end)
Expand Down Expand Up @@ -1165,7 +1177,7 @@ async def test_agent_run_stream_events(memory_logger):
assert agent_span is not None, "agent_run_stream_events span not found"

# Check agent span has basic structure
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
assert "5+5" in str(agent_span["input"]) or "What" in str(agent_span["input"])
assert agent_span["metrics"]["event_count"] == event_count
Expand Down Expand Up @@ -1194,7 +1206,7 @@ def test_direct_model_request_stream_sync(memory_logger, direct):
assert len(spans) == 1

span = spans[0]
assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert span["span_attributes"]["name"] == "model_request_stream_sync"
assert span["metadata"]["model"] == "gpt-4o-mini"
_assert_metrics_are_valid(span["metrics"], start, end)
Expand Down Expand Up @@ -1258,7 +1270,7 @@ async def stream_wrapper():
assert len(spans) >= 1, "Should have at least one span even with early break"

span = spans[0]
assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert span["span_attributes"]["name"] == "model_request_stream"


Expand Down Expand Up @@ -1297,7 +1309,7 @@ async def test_agent_stream_early_break(memory_logger):

# Verify at least agent_run_stream span exists and has basic structure
if agent_span:
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
# Metrics may be incomplete due to early break
assert "start" in agent_span["metrics"]

Expand Down Expand Up @@ -1368,7 +1380,7 @@ async def _buffer_stream() -> LLMStreamResponse:
assert len(spans) >= 1, "Should have at least one span even with early return"

span = spans[0]
assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert span["span_attributes"]["name"] == "model_request_stream"
assert "start" in span["metrics"]
assert span["metrics"]["start"] >= start
Expand Down Expand Up @@ -1446,7 +1458,7 @@ async def _consume_until_final() -> StreamEvent:
# Find agent_run_stream span
agent_span = next((s for s in spans if "agent_run_stream" in s["span_attributes"]["name"]), None)
assert agent_span is not None, "agent_run_stream span should exist"
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert "start" in agent_span["metrics"]


Expand Down Expand Up @@ -1500,7 +1512,7 @@ async def test_agent_with_binary_content(memory_logger):
assert chat_span is not None, "chat span not found"

# Verify basic span structure
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
_assert_metrics_are_valid(agent_span["metrics"], start, end)

Expand Down Expand Up @@ -2113,7 +2125,7 @@ class Product(BaseModel):
assert chat_span is not None, "chat span not found"

# Check agent span
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
_assert_metrics_are_valid(agent_span["metrics"], start, end)

Expand Down Expand Up @@ -2663,7 +2675,7 @@ async def test_no_model_agent_run(memory_logger):
assert agent_span is not None, "agent_run span not found"
assert chat_span is not None, "chat span not found"

assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
assert agent_span["metadata"]["provider"] == "openai"
assert TEST_PROMPT in str(agent_span["input"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def test_no_model_agent_run_with_logfire(memory_logger):
assert agent_span is not None, "agent_run span not found"
assert chat_span is not None, "chat span not found"

assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
assert agent_span["metadata"]["model"] == "gpt-4o-mini"
assert agent_span["metadata"]["provider"] == "openai"
assert TEST_PROMPT in str(agent_span["input"])
Expand Down
43 changes: 29 additions & 14 deletions py/src/braintrust/integrations/pydantic_ai/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def _agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any

with start_span(
name=f"agent_run [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run",
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=input_data if input_data else None,
metadata=metadata,
) as agent_span:
Expand All @@ -96,7 +96,7 @@ def _agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any)

with start_span(
name=f"agent_run_sync [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run_sync",
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=input_data if input_data else None,
metadata=metadata,
) as agent_span:
Expand Down Expand Up @@ -124,7 +124,7 @@ def _agent_to_cli_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: A
name=f"agent_to_cli_sync [{instance.name}]"
if hasattr(instance, "name") and instance.name
else "agent_to_cli_sync",
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=input_data if input_data else None,
metadata=metadata,
) as agent_span:
Expand Down Expand Up @@ -156,7 +156,7 @@ def _agent_run_stream_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwarg
# Create span context BEFORE calling wrapped function so internal spans nest under it
span_cm = start_span(
name=span_name,
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=input_data if input_data else None,
metadata=metadata,
)
Expand Down Expand Up @@ -189,7 +189,7 @@ async def _agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: An

with start_span(
name=span_name,
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=input_data if input_data else None,
metadata=metadata,
) as agent_span:
Expand Down Expand Up @@ -236,7 +236,7 @@ async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):

with start_span(
name="model_request",
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=input_data,
metadata=metadata,
) as span:
Expand All @@ -261,7 +261,7 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):

with start_span(
name="model_request_sync",
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=input_data,
metadata=metadata,
) as span:
Expand Down Expand Up @@ -289,6 +289,7 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
"model_request_stream",
input_data,
metadata,
span_type=SpanTypeAttribute.TASK,
)

return wrapper
Expand Down Expand Up @@ -316,7 +317,7 @@ async def wrapper(*args, **kwargs):

with start_span(
name="model_request",
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=input_data,
metadata=metadata,
) as span:
Expand All @@ -339,7 +340,7 @@ def wrapper(*args, **kwargs):

with start_span(
name="model_request_sync",
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=input_data,
metadata=metadata,
) as span:
Expand All @@ -365,6 +366,7 @@ def wrapper(*args, **kwargs):
"model_request_stream",
input_data,
metadata,
span_type=SpanTypeAttribute.TASK,
)

return wrapper
Expand Down Expand Up @@ -466,7 +468,7 @@ async def __aenter__(self):
# DON'T pass start_time here - we'll set it via metrics in __aexit__
self.span_cm = start_span(
name=self.span_name,
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=self.input_data if self.input_data else None,
metadata=self.metadata,
)
Expand Down Expand Up @@ -535,13 +537,26 @@ async def wrapped_method(*args, **kwargs):


class _DirectStreamWrapper(AbstractAsyncContextManager):
"""Wrapper for model_request_stream() that adds tracing while passing through the stream."""
"""Wrapper for model_request_stream() that adds tracing while passing through the stream.

def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
Used both as the leaf `chat <model>` span (from `_wrap_concrete_model_class`, default
`span_type=LLM`) and as a non-leaf wrapper around a nested model call (from
`direct.model_request_stream`, which passes `span_type=TASK` to avoid double-counting).
"""

def __init__(
self,
stream_cm: Any,
span_name: str,
input_data: Any,
metadata: Any,
span_type: str = SpanTypeAttribute.LLM,
):
self.stream_cm = stream_cm
self.span_name = span_name
self.input_data = input_data
self.metadata = metadata
self.span_type = span_type
self.span_cm = None
self.start_time = None
self.stream = None
Expand All @@ -555,7 +570,7 @@ async def __aenter__(self):
# DON'T pass start_time here - we'll set it via metrics in __aexit__
self.span_cm = start_span(
name=self.span_name,
type=SpanTypeAttribute.LLM,
type=self.span_type,
input=self.input_data if self.input_data else None,
metadata=self.metadata,
)
Expand Down Expand Up @@ -723,7 +738,7 @@ def __enter__(self):
# DON'T pass start_time here - we'll set it via metrics in __exit__
self.span_cm = start_span(
name=self.span_name,
type=SpanTypeAttribute.LLM,
type=SpanTypeAttribute.TASK,
input=self.input_data if self.input_data else None,
metadata=self.metadata,
)
Expand Down