diff --git a/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py b/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py
index 6b553c5a..88bb97ac 100644
--- a/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py
+++ b/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py
@@ -125,7 +125,7 @@ async def test_agent_run_async(memory_logger):
     assert chat_span is not None, "chat span not found"

     # Check agent span
-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     assert agent_span["metadata"]["provider"] == "openai"
     assert TEST_PROMPT in str(agent_span["input"])
@@ -146,6 +146,18 @@
     assert agent_span["metrics"]["prompt_tokens"] > 0
     assert agent_span["metrics"]["completion_tokens"] > 0

+    # Regression: no double-counting of cost/tokens. Experiment-level aggregations
+    # sum metrics across type='llm' spans, so a single agent turn must contribute
+    # its tokens exactly once. The wrapper agent_run span logs the same usage as
+    # the leaf chat span; only the leaf should be type=LLM.
+    llm_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.LLM]
+    assert len(llm_spans) == 1, f"expected exactly one LLM-typed span, got {len(llm_spans)}"
+    assert llm_spans[0]["span_id"] == chat_span["span_id"]
+    llm_prompt_tokens_sum = sum(s["metrics"].get("prompt_tokens", 0) for s in llm_spans)
+    llm_completion_tokens_sum = sum(s["metrics"].get("completion_tokens", 0) for s in llm_spans)
+    assert llm_prompt_tokens_sum == chat_span["metrics"]["prompt_tokens"]
+    assert llm_completion_tokens_sum == chat_span["metrics"]["completion_tokens"]
+

 @pytest.mark.vcr
 @pytest.mark.asyncio
@@ -205,7 +217,7 @@ def test_agent_run_sync(memory_logger):
     assert chat_span is not None, "chat span not found"

     # Check agent span
-    assert agent_sync_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_sync_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_sync_span["metadata"]["model"] == "gpt-4o-mini"
     assert agent_sync_span["metadata"]["provider"] == "openai"
     assert TEST_PROMPT in str(agent_sync_span["input"])
@@ -287,7 +299,7 @@ async def fake_run_chat(
     assert len(spans) == 1, f"Expected 1 CLI span, got {len(spans)}"
     cli_span = spans[0]

-    assert cli_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert cli_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert cli_span["span_attributes"]["name"] == "agent_to_cli_sync [cli-agent]"
     assert cli_span["metadata"]["model"] == "gpt-4o-mini"
     assert cli_span["metadata"]["provider"] == "openai"
@@ -497,7 +509,7 @@ async def test_agent_run_stream(memory_logger):
     assert chat_span is not None, "chat span not found"

     # Check agent span
-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     assert "Count from 1 to 5" in str(agent_span["input"])
     _assert_metrics_are_valid(agent_span["metrics"], start, end)
@@ -607,7 +619,7 @@ async def test_direct_model_request(memory_logger, direct):
     direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request"), None)
     assert direct_span is not None

-    assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert direct_span["metadata"]["model"] == "gpt-4o-mini"
     assert direct_span["metadata"]["provider"] == "openai"
     assert TEST_PROMPT in str(direct_span["input"])
@@ -637,7 +649,7 @@ def test_direct_model_request_sync(memory_logger, direct):
     # Find the model_request_sync span
     span = next((s for s in spans if s["span_attributes"]["name"] == "model_request_sync"), None)
     assert span is not None, "model_request_sync span not found"
-    assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert span["metadata"]["model"] == "gpt-4o-mini"
     assert TEST_PROMPT in str(span["input"])
     _assert_metrics_are_valid(span["metrics"], start, end)
@@ -668,7 +680,7 @@ async def test_direct_model_request_with_settings(memory_logger, direct):
     direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request"), None)
     assert direct_span is not None
-    assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.TASK

     # Verify model_settings is in input (NOT metadata)
     assert "model_settings" in direct_span["input"], "model_settings should be in input"
@@ -713,7 +725,7 @@ async def test_direct_model_request_stream(memory_logger, direct):
     direct_span = next((s for s in spans if s["span_attributes"]["name"] == "model_request_stream"), None)
     assert direct_span is not None
-    assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert direct_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert direct_span["metadata"]["model"] == "gpt-4o-mini"
     _assert_metrics_are_valid(direct_span["metrics"], start, end)
@@ -804,7 +816,7 @@ class MathAnswer(BaseModel):
     assert chat_span is not None, "chat span not found"

     # Check agent span
-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     assert agent_span["metadata"]["provider"] == "openai"
     assert "10 + 15" in str(agent_span["input"])
@@ -1092,7 +1104,7 @@ def test_agent_run_stream_sync(memory_logger):
     assert chat_span is not None, "chat span not found"

     # Check agent span
-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     assert "Count from 1 to 3" in str(agent_span["input"])
     _assert_metrics_are_valid(agent_span["metrics"], start, end)
@@ -1165,7 +1177,7 @@ async def test_agent_run_stream_events(memory_logger):
     assert agent_span is not None, "agent_run_stream_events span not found"

     # Check agent span has basic structure
-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     assert "5+5" in str(agent_span["input"]) or "What" in str(agent_span["input"])
     assert agent_span["metrics"]["event_count"] == event_count
@@ -1194,7 +1206,7 @@ def test_direct_model_request_stream_sync(memory_logger, direct):
     assert len(spans) == 1
     span = spans[0]

-    assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert span["span_attributes"]["name"] == "model_request_stream_sync"
     assert span["metadata"]["model"] == "gpt-4o-mini"
     _assert_metrics_are_valid(span["metrics"], start, end)
@@ -1258,7 +1270,7 @@ async def stream_wrapper():
     assert len(spans) >= 1, "Should have at least one span even with early break"
     span = spans[0]

-    assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert span["span_attributes"]["name"] == "model_request_stream"
@@ -1297,7 +1309,7 @@ async def test_agent_stream_early_break(memory_logger):
     # Verify at least agent_run_stream span exists and has basic structure
     if agent_span:
-        assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+        assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
         # Metrics may be incomplete due to early break
         assert "start" in agent_span["metrics"]
@@ -1368,7 +1380,7 @@ async def _buffer_stream() -> LLMStreamResponse:
     assert len(spans) >= 1, "Should have at least one span even with early return"
     span = spans[0]

-    assert span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert span["span_attributes"]["name"] == "model_request_stream"
     assert "start" in span["metrics"]
     assert span["metrics"]["start"] >= start
@@ -1446,7 +1458,7 @@ async def _consume_until_final() -> StreamEvent:
     # Find agent_run_stream span
     agent_span = next((s for s in spans if "agent_run_stream" in s["span_attributes"]["name"]), None)
     assert agent_span is not None, "agent_run_stream span should exist"
-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert "start" in agent_span["metrics"]
@@ -1500,7 +1512,7 @@ async def test_agent_with_binary_content(memory_logger):
     assert chat_span is not None, "chat span not found"

     # Verify basic span structure
-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     _assert_metrics_are_valid(agent_span["metrics"], start, end)
@@ -2113,7 +2125,7 @@ class Product(BaseModel):
     assert chat_span is not None, "chat span not found"

     # Check agent span
-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     _assert_metrics_are_valid(agent_span["metrics"], start, end)
@@ -2663,7 +2675,7 @@ async def test_no_model_agent_run(memory_logger):
     assert agent_span is not None, "agent_run span not found"
     assert chat_span is not None, "chat span not found"

-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     assert agent_span["metadata"]["provider"] == "openai"
     assert TEST_PROMPT in str(agent_span["input"])
diff --git a/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_logfire.py b/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_logfire.py
index 661b7bf7..c8dde528 100644
--- a/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_logfire.py
+++ b/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_logfire.py
@@ -66,7 +66,7 @@ async def test_no_model_agent_run_with_logfire(memory_logger):
     assert agent_span is not None, "agent_run span not found"
     assert chat_span is not None, "chat span not found"

-    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.LLM
+    assert agent_span["span_attributes"]["type"] == SpanTypeAttribute.TASK
     assert agent_span["metadata"]["model"] == "gpt-4o-mini"
     assert agent_span["metadata"]["provider"] == "openai"
     assert TEST_PROMPT in str(agent_span["input"])
diff --git a/py/src/braintrust/integrations/pydantic_ai/tracing.py b/py/src/braintrust/integrations/pydantic_ai/tracing.py
index 547cd857..1019441e 100644
--- a/py/src/braintrust/integrations/pydantic_ai/tracing.py
+++ b/py/src/braintrust/integrations/pydantic_ai/tracing.py
@@ -70,7 +70,7 @@ async def _agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any
     with start_span(
         name=f"agent_run [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run",
-        type=SpanTypeAttribute.LLM,
+        type=SpanTypeAttribute.TASK,
         input=input_data if input_data else None,
         metadata=metadata,
     ) as agent_span:
@@ -96,7 +96,7 @@ def _agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any)
     with start_span(
         name=f"agent_run_sync [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run_sync",
-        type=SpanTypeAttribute.LLM,
+        type=SpanTypeAttribute.TASK,
         input=input_data if input_data else None,
         metadata=metadata,
     ) as agent_span:
@@ -124,7 +124,7 @@ def _agent_to_cli_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: A
         name=f"agent_to_cli_sync [{instance.name}]"
         if hasattr(instance, "name") and instance.name
         else "agent_to_cli_sync",
-        type=SpanTypeAttribute.LLM,
+        type=SpanTypeAttribute.TASK,
         input=input_data if input_data else None,
         metadata=metadata,
     ) as agent_span:
@@ -156,7 +156,7 @@ def _agent_run_stream_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwarg
     # Create span context BEFORE calling wrapped function so internal spans nest under it
     span_cm = start_span(
         name=span_name,
-        type=SpanTypeAttribute.LLM,
+        type=SpanTypeAttribute.TASK,
         input=input_data if input_data else None,
         metadata=metadata,
     )
@@ -189,7 +189,7 @@ async def _agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: An
     with start_span(
         name=span_name,
-        type=SpanTypeAttribute.LLM,
+        type=SpanTypeAttribute.TASK,
         input=input_data if input_data else None,
         metadata=metadata,
     ) as agent_span:
@@ -236,7 +236,7 @@ async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
         with start_span(
             name="model_request",
-            type=SpanTypeAttribute.LLM,
+            type=SpanTypeAttribute.TASK,
             input=input_data,
             metadata=metadata,
         ) as span:
@@ -261,7 +261,7 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
         with start_span(
             name="model_request_sync",
-            type=SpanTypeAttribute.LLM,
+            type=SpanTypeAttribute.TASK,
             input=input_data,
             metadata=metadata,
         ) as span:
@@ -289,6 +289,7 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
             "model_request_stream",
             input_data,
             metadata,
+            span_type=SpanTypeAttribute.TASK,
         )

     return wrapper
@@ -316,7 +317,7 @@ async def wrapper(*args, **kwargs):
         with start_span(
             name="model_request",
-            type=SpanTypeAttribute.LLM,
+            type=SpanTypeAttribute.TASK,
             input=input_data,
             metadata=metadata,
         ) as span:
@@ -339,7 +340,7 @@ def wrapper(*args, **kwargs):
         with start_span(
             name="model_request_sync",
-            type=SpanTypeAttribute.LLM,
+            type=SpanTypeAttribute.TASK,
             input=input_data,
             metadata=metadata,
         ) as span:
@@ -365,6 +366,7 @@ def wrapper(*args, **kwargs):
             "model_request_stream",
             input_data,
             metadata,
+            span_type=SpanTypeAttribute.TASK,
         )

     return wrapper
@@ -466,7 +468,7 @@ async def __aenter__(self):
         # DON'T pass start_time here - we'll set it via metrics in __aexit__
         self.span_cm = start_span(
             name=self.span_name,
-            type=SpanTypeAttribute.LLM,
+            type=SpanTypeAttribute.TASK,
             input=self.input_data if self.input_data else None,
             metadata=self.metadata,
         )
@@ -535,13 +537,26 @@ async def wrapped_method(*args, **kwargs):


 class _DirectStreamWrapper(AbstractAsyncContextManager):
-    """Wrapper for model_request_stream() that adds tracing while passing through the stream."""
+    """Wrapper for model_request_stream() that adds tracing while passing through the stream.

-    def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
+    Used both as the leaf `chat ` span (from `_wrap_concrete_model_class`, default
+    `span_type=LLM`) and as a non-leaf wrapper around a nested model call (from
+    `direct.model_request_stream`, which passes `span_type=TASK` to avoid double-counting).
+    """
+
+    def __init__(
+        self,
+        stream_cm: Any,
+        span_name: str,
+        input_data: Any,
+        metadata: Any,
+        span_type: str = SpanTypeAttribute.LLM,
+    ):
         self.stream_cm = stream_cm
         self.span_name = span_name
         self.input_data = input_data
         self.metadata = metadata
+        self.span_type = span_type
         self.span_cm = None
         self.start_time = None
         self.stream = None
@@ -555,7 +570,7 @@ async def __aenter__(self):
         # DON'T pass start_time here - we'll set it via metrics in __aexit__
         self.span_cm = start_span(
             name=self.span_name,
-            type=SpanTypeAttribute.LLM,
+            type=self.span_type,
             input=self.input_data if self.input_data else None,
             metadata=self.metadata,
         )
@@ -723,7 +738,7 @@ def __enter__(self):
         # DON'T pass start_time here - we'll set it via metrics in __exit__
         self.span_cm = start_span(
             name=self.span_name,
-            type=SpanTypeAttribute.LLM,
+            type=SpanTypeAttribute.TASK,
             input=self.input_data if self.input_data else None,
             metadata=self.metadata,
         )
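
Reviewer note: here is the intuition behind the LLM-to-TASK flip as a runnable sketch. This is not the Braintrust aggregation code itself; it only assumes what the regression test above states (rollups sum metrics over `type='llm'` spans) plus the usual string values of `SpanTypeAttribute.LLM`/`.TASK` (`"llm"`/`"task"`). `sum_llm_metric` is a hypothetical stand-in for the experiment-level rollup.

```python
# Hypothetical stand-in for the experiment-level rollup described in the
# regression test: sum one metric over spans whose type is 'llm'.
def sum_llm_metric(spans, key):
    return sum(
        s["metrics"].get(key, 0)
        for s in spans
        if s["span_attributes"]["type"] == "llm"
    )

# One agent turn: a wrapper span plus the leaf chat span, both logging the
# same usage, as the wrappers in tracing.py do.
chat = {"span_attributes": {"type": "llm"}, "metrics": {"prompt_tokens": 42}}

# Before this change the wrapper was also type='llm' -> tokens counted twice.
wrapper_before = {"span_attributes": {"type": "llm"}, "metrics": {"prompt_tokens": 42}}
assert sum_llm_metric([wrapper_before, chat], "prompt_tokens") == 84  # double-counted

# After: the wrapper is type='task', so only the leaf chat span contributes.
wrapper_after = {"span_attributes": {"type": "task"}, "metrics": {"prompt_tokens": 42}}
assert sum_llm_metric([wrapper_after, chat], "prompt_tokens") == 42  # counted once
```

The one exception in the patch is deliberate: `_DirectStreamWrapper` keeps `span_type=LLM` as its default so that when it backs the leaf `chat ` span its usage still counts; only the non-leaf `direct.model_request_stream` call sites pass `span_type=SpanTypeAttribute.TASK`.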