From d10fae4f37795ed7a373a96c1dcabc73a35b13a0 Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Mon, 27 Apr 2026 19:23:12 -0400 Subject: [PATCH 1/2] perf(adk): reduce tracing serialization overhead Capture ADK request, config, and event fields directly instead of deep-copying entire objects on tracing paths. Preserve extra dict config fields while still serializing Pydantic schemas for logged config readability. Verified with targeted ADK capture_config and call-type tests. --- .../braintrust/integrations/adk/test_adk.py | 32 +-- py/src/braintrust/integrations/adk/tracing.py | 268 ++++++++---------- 2 files changed, 141 insertions(+), 159 deletions(-) diff --git a/py/src/braintrust/integrations/adk/test_adk.py b/py/src/braintrust/integrations/adk/test_adk.py index 8adb7ead..cda77281 100644 --- a/py/src/braintrust/integrations/adk/test_adk.py +++ b/py/src/braintrust/integrations/adk/test_adk.py @@ -1454,9 +1454,9 @@ class Person(BaseModel): @pytest.mark.asyncio -async def test_serialize_config_handles_all_schema_fields(): - """Test that _serialize_config handles all 4 schema fields.""" - from braintrust.integrations.adk.tracing import _serialize_config +async def test_capture_config_handles_all_schema_fields(): + """Test that _capture_config handles all 4 schema fields.""" + from braintrust.integrations.adk.tracing import _capture_config class TestSchema(BaseModel): value: str = Field(description="Test value") @@ -1470,7 +1470,7 @@ class TestSchema(BaseModel): "other_field": "keep me", } - serialized = _serialize_config(config) + serialized = _capture_config(config) assert isinstance(serialized, dict) @@ -1488,14 +1488,14 @@ class TestSchema(BaseModel): @pytest.mark.asyncio -async def test_serialize_config_handles_non_pydantic(): - """Test that _serialize_config handles non-Pydantic values gracefully.""" - from braintrust.integrations.adk.tracing import _serialize_config +async def test_capture_config_handles_non_pydantic(): + """Test that _capture_config handles non-Pydantic values gracefully.""" + from braintrust.integrations.adk.tracing import _capture_config # Test with non-Pydantic values config = {"response_schema": "not a pydantic model", "other_field": {"key": "value"}} - serialized = _serialize_config(config) + serialized = _capture_config(config) assert isinstance(serialized, dict) # Non-Pydantic schema should remain as-is @@ -1695,28 +1695,28 @@ async def test_adk_response_json_schema_dict(memory_logger): @pytest.mark.asyncio -async def test_serialize_config_preserves_none(): - """Test that _serialize_config returns None when config is None (not empty dict).""" - from braintrust.integrations.adk.tracing import _serialize_config +async def test_capture_config_preserves_none(): + """Test that _capture_config returns None when config is None (not empty dict).""" + from braintrust.integrations.adk.tracing import _capture_config # None should be preserved as None, not converted to {} - result = _serialize_config(None) + result = _capture_config(None) assert result is None, f"Expected None, got {result}" # Empty dict should remain empty dict - result = _serialize_config({}) + result = _capture_config({}) assert result == {} # False should be preserved as False - result = _serialize_config(False) + result = _capture_config(False) assert result is False # 0 should be preserved as 0 - result = _serialize_config(0) + result = _capture_config(0) assert result == 0 # Empty string should be preserved - result = _serialize_config("") + result = _capture_config("") assert result == "" diff --git a/py/src/braintrust/integrations/adk/tracing.py b/py/src/braintrust/integrations/adk/tracing.py index 73c3ea6f..c84b51b2 100644 --- a/py/src/braintrust/integrations/adk/tracing.py +++ b/py/src/braintrust/integrations/adk/tracing.py @@ -4,9 +4,11 @@ import inspect import logging import time -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from contextlib import aclosing -from typing import Any, cast +from functools import lru_cache +from itertools import chain +from typing import Any from braintrust.bt_json import bt_safe_deep_copy from braintrust.integrations.utils import _materialize_attachment @@ -84,6 +86,7 @@ def _serialize_part(part: Any) -> Any: return bt_safe_deep_copy(part) +@lru_cache(maxsize=128) def _serialize_pydantic_schema(schema_class: Any) -> dict[str, Any]: """ Serialize a Pydantic model class to its full JSON schema. @@ -103,51 +106,49 @@ def _serialize_pydantic_schema(schema_class: Any) -> dict[str, Any]: return {"__class__": schema_class.__name__ if inspect.isclass(schema_class) else str(type(schema_class).__name__)} -def _serialize_config(config: Any) -> dict[str, Any] | Any: +def _capture_config(config: Any) -> dict[str, Any] | Any: """ - Serialize a config object, specifically handling schema fields that may contain Pydantic classes. + Capture the ADK config fields that make LLM spans readable. Google ADK uses these fields for schemas: - response_schema, response_json_schema (in GenerateContentConfig for LLM requests) - input_schema, output_schema (in agent config) """ - if config is None: - return None - if not config: + if config is None or not config: return config - # Extract schema fields BEFORE calling bt_safe_deep_copy (which converts Pydantic classes to dicts) - schema_fields = ["response_schema", "response_json_schema", "input_schema", "output_schema"] - serialized_schemas: dict[str, Any] = {} - - for field in schema_fields: - schema_value = None - - # Try to get the field value - if hasattr(config, field): - schema_value = getattr(config, field) - elif isinstance(config, dict) and field in config: - schema_value = config[field] - - # If it's a Pydantic class, serialize it - if schema_value is not None and inspect.isclass(schema_value): + config_fields = [ + "system_instruction", + "response_mime_type", + "response_schema", + "response_json_schema", + "input_schema", + "output_schema", + "max_output_tokens", + "temperature", + "top_p", + "top_k", + "stop_sequences", + "candidate_count", + ] + captured: dict[str, Any] = dict(config) if isinstance(config, dict) else {} + + for field in config_fields: + value = _get_field(config, field) + if value is None: + continue + if inspect.isclass(value): try: from pydantic import BaseModel - if issubclass(schema_value, BaseModel): - serialized_schemas[field] = _serialize_pydantic_schema(schema_value) + if issubclass(value, BaseModel): + captured[field] = _serialize_pydantic_schema(value) + continue except (TypeError, ImportError): pass + captured[field] = value - # Serialize the config - config_dict = bt_safe_deep_copy(config) - if not isinstance(config_dict, dict): - return config_dict # type: ignore - - # Replace schema fields with serialized versions - config_dict.update(serialized_schemas) - - return config_dict + return captured or config def _omit(obj: Any, keys: Iterable[str]): @@ -213,6 +214,55 @@ def _extract_model_name(response: Any, llm_request: Any, instance: Any) -> str | return None +def _get_field(value: Any, field: str, default: Any = None) -> Any: + return value.get(field, default) if isinstance(value, Mapping) else getattr(value, field, default) + + +def _part_has_field(part: Any, *field_names: str) -> bool: + return any(_get_field(part, field_name) is not None for field_name in field_names) + + +def _capture_llm_request_input(llm_request: Any) -> Any: + """Capture the ADK request fields that make LLM spans readable.""" + if llm_request is None: + return None + + contents = _get_field(llm_request, "contents") + config = _get_field(llm_request, "config") + model = _get_field(llm_request, "model") + live_connect_config = _get_field(llm_request, "live_connect_config") + + captured: dict[str, Any] = {} + if model: + captured["model"] = model + if contents: + captured["contents"] = ( + [_serialize_content(c) for c in contents] if isinstance(contents, list) else _serialize_content(contents) + ) + if config: + captured["config"] = _capture_config(config) + if live_connect_config is not None or hasattr(llm_request, "live_connect_config") or isinstance(llm_request, dict): + captured["live_connect_config"] = live_connect_config + + return captured or llm_request + + +def _event_output_with_content(last_event: Any, event_with_content: Any | None) -> Any: + if event_with_content is None or _get_field(last_event, "content") is not None: + return last_event + + content = _get_field(event_with_content, "content") + if content is None: + return last_event + + if isinstance(last_event, dict): + return {**last_event, "content": content} + + # Keep the original event instead of recursively serializing it; add the + # captured content alongside it so Braintrust can serialize both values. + return {"event": last_event, "content": content} + + def _determine_llm_call_type(llm_request: Any, model_response: Any = None) -> str: """ Determine the type of LLM call based on the request and response content. @@ -223,64 +273,30 @@ def _determine_llm_call_type(llm_request: Any, model_response: Any = None) -> st - "direct_response" if there are no tools involved or tools available but not used """ try: - # Convert to dict if it's a model object - request_dict = cast(dict[str, Any], bt_safe_deep_copy(llm_request)) - - # Check the conversation history for function responses - contents = request_dict.get("contents", []) - has_function_response = False - - for content in contents: - if isinstance(content, dict): - parts = content.get("parts", []) - for part in parts: - if isinstance(part, dict): - if "function_response" in part and part["function_response"] is not None: - has_function_response = True - - # Check if the response contains function calls + has_function_response = any( + _part_has_field(part, "function_response", "functionResponse") + for content in (_get_field(llm_request, "contents", []) or []) + for part in (_get_field(content, "parts", []) or []) + ) + response_has_function_call = False if model_response: - # Check if it's an Event object with get_function_calls method (ADK Event) if hasattr(model_response, "get_function_calls"): try: function_calls = model_response.get_function_calls() - if function_calls and len(function_calls) > 0: - response_has_function_call = True + response_has_function_call = bool(function_calls) except Exception: pass - # Fallback: Check the response dict structure if not response_has_function_call: - response_dict = bt_safe_deep_copy(model_response) - if isinstance(response_dict, dict): - # Try multiple possible response structures - # 1. Standard: response.content.parts - content = response_dict.get("content", {}) - if isinstance(content, dict): - parts = content.get("parts", []) - if isinstance(parts, list): - for part in parts: - if isinstance(part, dict): - if ("function_call" in part and part["function_call"] is not None) or ( - "functionCall" in part and part["functionCall"] is not None - ): - response_has_function_call = True - break - - # 2. Alternative: response has parts directly (for some event types) - if not response_has_function_call and "parts" in response_dict: - parts = response_dict.get("parts", []) - if isinstance(parts, list): - for part in parts: - if isinstance(part, dict): - if ("function_call" in part and part["function_call"] is not None) or ( - "functionCall" in part and part["functionCall"] is not None - ): - response_has_function_call = True - break - - # Determine the call type + content = _get_field(model_response, "content") + response_has_function_call = any( + _part_has_field(part, "function_call", "functionCall") + for part in chain( + _get_field(content, "parts", []) or [], _get_field(model_response, "parts", []) or [] + ) + ) + if has_function_response: return "response_generation" elif response_has_function_call: @@ -327,7 +343,7 @@ async def _trace(): with start_span( name=f"agent_run [{instance.name}]", type=SpanTypeAttribute.TASK, - metadata=bt_safe_deep_copy({"parent_context": parent_context, **_omit(kwargs, ["parent_context"])}), + metadata={"parent_context": parent_context, **_omit(kwargs, ["parent_context"])}, ) as agent_span: last_event = None async with aclosing(wrapped(*args, **kwargs)) as agen: @@ -350,12 +366,10 @@ async def _trace(): with start_span( name="call_llm", type=SpanTypeAttribute.TASK, - metadata=bt_safe_deep_copy( - { - "invocation_context": invocation_context, - **_omit(kwargs, ["invocation_context"]), - } - ), + metadata={ + "invocation_context": invocation_context, + **_omit(kwargs, ["invocation_context"]), + }, ) as llm_span: last_event = None async with aclosing(wrapped(*args, **kwargs)) as agen: @@ -376,28 +390,10 @@ async def _flow_call_llm_async_wrapper(wrapped: Any, instance: Any, args: Any, k model_response_event = args[2] if len(args) > 2 else kwargs.get("model_response_event") async def _trace(): - # Extract and serialize contents BEFORE converting to dict - # This is critical because bt_safe_deep_copy converts bytes to string representations - serialized_contents = None - if llm_request and hasattr(llm_request, "contents"): - contents = llm_request.contents - if contents: - serialized_contents = ( - [_serialize_content(c) for c in contents] - if isinstance(contents, list) - else _serialize_content(contents) - ) - - # Now convert the whole request to dict - serialized_request = bt_safe_deep_copy(llm_request) - - # Replace contents with our serialized version that has Attachments - if serialized_contents is not None and isinstance(serialized_request, dict): - serialized_request["contents"] = serialized_contents - - # Handle config specifically to serialize Pydantic schema classes - if isinstance(serialized_request, dict) and "config" in serialized_request: - serialized_request["config"] = _serialize_config(serialized_request["config"]) + # Capture only the fields we need to alter: contents may contain binary + # data that should become Attachments, and config may contain Pydantic + # schema classes that are clearer as JSON schema. + captured_request = _capture_llm_request_input(llm_request) # Extract model name from request or instance model_name = _extract_model_name(None, llm_request, instance) @@ -407,16 +403,14 @@ async def _trace(): with start_span( name="llm_call", type=SpanTypeAttribute.LLM, - input=serialized_request, - metadata=bt_safe_deep_copy( - { - "invocation_context": invocation_context, - "model_response_event": model_response_event, - "flow_class": instance.__class__.__name__, - "model": model_name, - **_omit(kwargs, ["invocation_context", "model_response_event", "flow_class", "llm_call_type"]), - } - ), + input=captured_request, + metadata={ + "invocation_context": invocation_context, + "model_response_event": model_response_event, + "flow_class": instance.__class__.__name__, + "model": model_name, + **_omit(kwargs, ["invocation_context", "model_response_event", "flow_class", "llm_call_type"]), + }, ) as llm_span: # Execute the LLM call and yield events while span is active last_event = None @@ -437,18 +431,8 @@ async def _trace(): # After execution, update span with correct call type and output if last_event: - # We need to check if we should merge content from an earlier event - # Convert to dict to inspect/modify, but let span.log() handle final serialization - output_dict = bt_safe_deep_copy(last_event) - if event_with_content and isinstance(output_dict, dict): - if "content" not in output_dict or output_dict.get("content") is None: - content = ( - bt_safe_deep_copy(event_with_content.content) - if hasattr(event_with_content, "content") - else None - ) - if content: - output_dict["content"] = content + # We need to check if we should merge content from an earlier event. + output = _event_output_with_content(last_event, event_with_content) # Extract metrics from response metrics = _extract_metrics(last_event) @@ -469,7 +453,7 @@ async def _trace(): ) # Log output and metrics (span.log will handle serialization) - llm_span.log(output=output_dict, metrics=metrics) + llm_span.log(output=output, metrics=metrics) async with aclosing(_trace()) as agen: async for event in agen: @@ -490,14 +474,12 @@ async def _trace(): name=f"invocation [{instance.app_name}]", type=SpanTypeAttribute.TASK, input={"new_message": serialized_message}, - metadata=bt_safe_deep_copy( - { - "user_id": user_id, - "session_id": session_id, - "state_delta": state_delta, - **_omit(kwargs, ["user_id", "session_id", "new_message", "state_delta"]), - } - ), + metadata={ + "user_id": user_id, + "session_id": session_id, + "state_delta": state_delta, + **_omit(kwargs, ["user_id", "session_id", "new_message", "state_delta"]), + }, ) as runner_span: last_event = None async with aclosing(wrapped(*args, **kwargs)) as agen: @@ -526,7 +508,7 @@ async def _tool_call_async_wrapper(wrapped: Any, instance: Any, args: Any, kwarg with start_span( name=f"tool [{tool_name}]", type=SpanTypeAttribute.TOOL, - input={"tool_name": tool_name, "arguments": bt_safe_deep_copy(tool_args)}, + input={"tool_name": tool_name, "arguments": tool_args}, metadata={"tool_class": tool.__class__.__name__ if tool is not None else None}, ) as tool_span: try: From 8d7b4d7add43a419c35871e9b902a2f057e6c78f Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Tue, 28 Apr 2026 10:13:31 -0400 Subject: [PATCH 2/2] fix pylint --- py/src/braintrust/integrations/adk/test_adk.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/py/src/braintrust/integrations/adk/test_adk.py b/py/src/braintrust/integrations/adk/test_adk.py index cda77281..3473360d 100644 --- a/py/src/braintrust/integrations/adk/test_adk.py +++ b/py/src/braintrust/integrations/adk/test_adk.py @@ -1479,9 +1479,10 @@ class TestSchema(BaseModel): assert field in serialized, f"Missing {field}" schema = serialized[field] assert isinstance(schema, dict) - assert "properties" in schema - assert "value" in schema["properties"] - assert schema["properties"]["value"]["description"] == "Test value" + properties = schema.get("properties") + assert isinstance(properties, dict) + assert "value" in properties + assert properties["value"]["description"] == "Test value" # Other fields should be preserved assert "other_field" in serialized