diff --git a/.agents/skills/sdk-integrations/SKILL.md b/.agents/skills/sdk-integrations/SKILL.md
index 92eed335..112b0f28 100644
--- a/.agents/skills/sdk-integrations/SKILL.md
+++ b/.agents/skills/sdk-integrations/SKILL.md
@@ -105,12 +105,14 @@ Do not start by wiring patchers and only later asking what the logged span shoul
 Keep provider-local code inside `py/src/braintrust/integrations/<provider>/`.
 
+If tracing or normalization logic is genuinely shared across multiple integrations, prefer adding it to `py/src/braintrust/integrations/utils.py` instead of copying it into each provider package. Avoid duplicating code between integrations unless there is a clear provider-specific reason the behavior must diverge.
+
 Typical file ownership:
 
 - `__init__.py`: export the integration class, `setup_<provider>()`, and public `wrap_*()` helpers
 - `integration.py`: define the `BaseIntegration` subclass and register patchers
 - `patchers.py`: define patchers and manual `wrap_*()` helpers
-- `tracing.py`: keep provider-specific tracing, stream handling, normalization, and metadata extraction
+- `tracing.py`: keep provider-specific tracing, stream handling, normalization, and metadata extraction; move cross-integration helpers to `py/src/braintrust/integrations/utils.py`
 - `test_*.py`: keep provider behavior tests next to the integration
 - `cassettes/`: keep VCR recordings next to the integration tests when the provider uses HTTP
diff --git a/AGENTS.md b/AGENTS.md
index bf9d2ee6..b6b8aba5 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -75,6 +75,12 @@ Root `Makefile` exists as a convenience wrapper. The authoritative SDK workflow
 `py/noxfile.py` is the source of truth for compatibility coverage.
 
+Testing preferences:
+
+- Prefer VCR-backed integration tests with checked-in cassettes whenever practical.
+- Avoid mocks, fakes, and heavily synthetic tests unless there is no reasonable cassette-based alternative or the code under test is truly internal/purely local.
+- When fixing a bug or issue, default to a red/green workflow: first add or update a test that reproduces the problem and fails, then implement the fix, unless the user explicitly asks for a different approach.
+
 Key facts:
 
 - `test_core` runs without optional vendor packages.
@@ -87,6 +93,8 @@ When changing behavior, run the narrowest affected session first, then expand on
 ## VCR
 
+VCR/cassette coverage is the default and preferred testing strategy for provider and integration behavior in this repo. Reach for cassette-backed tests before introducing mocks or fakes, and keep new coverage aligned with the existing VCR patterns unless there is a strong reason not to.
+
 VCR cassette directories:
 
 - `py/src/braintrust/cassettes/`
@@ -162,6 +170,7 @@ Avoid editing `py/src/braintrust/version.py` while also running build commands.
 - Keep tests near the code they cover.
 - Reuse existing fixtures and cassette patterns.
+- Prefer extending an existing cassette-backed test over adding a new mock-heavy test.
 - If a change affects examples or integrations, update the nearest example or focused test.
 - For CLI/devserver changes, consider whether wheel-mode behavior also needs coverage.
 - Do **not** add `from __future__ import annotations` unless it is absolutely required (e.g., a genuine forward-reference that cannot be resolved any other way). This import changes annotation evaluation semantics at runtime and can silently break `get_type_hints()`, Pydantic models, and other runtime introspection.
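For instance, a minimal sketch of the preferred alternative (the `mymodule`/`MyClass` names are hypothetical, for illustration only):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only during static type checking; never at runtime, so there is
    # no import cycle and no change to annotation evaluation semantics.
    from mymodule import MyClass


def handle(obj: "MyClass") -> None:  # quoted literal resolves as a forward reference
    ...
```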
Prefer quoted string literals (`"MyClass"`) or `TYPE_CHECKING` guards for forward references instead. diff --git a/py/noxfile.py b/py/noxfile.py index 89aa9082..09c044c1 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -10,6 +10,7 @@ nox -h Get help. """ +import functools import glob import os import pathlib @@ -476,6 +477,21 @@ def _get_braintrust_wheel(): return wheels[0] +@functools.cache +def _integration_subdirs_to_ignore() -> list[str]: + """Return integration subdirectories that require dedicated sessions. + + Top-level tests in ``src/braintrust/integrations/`` (e.g. shared utils and + versioning tests) should still run in ``test_core``. + """ + integrations_root = pathlib.Path("src") / INTEGRATION_DIR + return [ + f"{INTEGRATION_DIR}/{child.name}" + for child in integrations_root.iterdir() + if child.is_dir() and child.name != "__pycache__" + ] + + def _run_core_tests(session): """Run all tests which don't require optional dependencies.""" _run_tests( @@ -483,7 +499,7 @@ def _run_core_tests(session): SRC_DIR, ignore_paths=[ WRAPPER_DIR, - INTEGRATION_DIR, + *_integration_subdirs_to_ignore(), CONTRIB_DIR, DEVSERVER_DIR, ], diff --git a/py/src/braintrust/integrations/agentscope/tracing.py b/py/src/braintrust/integrations/agentscope/tracing.py index 37d9fa92..d2370ef5 100644 --- a/py/src/braintrust/integrations/agentscope/tracing.py +++ b/py/src/braintrust/integrations/agentscope/tracing.py @@ -6,14 +6,11 @@ from braintrust.logger import start_span from braintrust.span_types import SpanTypeAttribute - - -def _clean(mapping: dict[str, Any]) -> dict[str, Any]: - return {key: value for key, value in mapping.items() if value is not None} +from braintrust.util import clean_nones def _args_kwargs_input(args: Any, kwargs: dict[str, Any]) -> dict[str, Any]: - return _clean( + return clean_nones( { "args": list(args) if args else None, "kwargs": kwargs if kwargs else None, @@ -34,7 +31,7 @@ def _pipeline_metadata(args: Any, kwargs: dict[str, Any]) -> dict[str, Any]: if agents: agent_names = [getattr(agent, "name", agent.__class__.__name__) for agent in agents] - return _clean({"agent_names": agent_names}) + return clean_nones({"agent_names": agent_names}) def _extract_metrics(*candidates: Any) -> dict[str, float] | None: @@ -69,7 +66,7 @@ def _model_provider_name(instance: Any) -> str: def _model_metadata(instance: Any) -> dict[str, Any]: - return _clean( + return clean_nones( { "model": getattr(instance, "model_name", None), "provider": _model_provider_name(instance), @@ -95,7 +92,7 @@ def _model_call_input(args: Any, kwargs: dict[str, Any]) -> dict[str, Any]: if structured_model is None and len(args) > 3: structured_model = args[3] - return _clean( + return clean_nones( { "messages": messages, "tools": tools, @@ -125,7 +122,7 @@ def _model_call_output(result: Any) -> Any: else: return result - normalized = _clean( + normalized = clean_nones( { "role": "assistant" if data.get("content") is not None else None, "content": data.get("content"), @@ -178,7 +175,7 @@ async def _wrapper(wrapped: Any, instance: Any, args: Any, kwargs: dict[str, Any _agent_call_wrapper = _make_task_wrapper( name_fn=lambda instance, _a, _k: f"{_agent_name(instance)}.reply", - metadata_fn=lambda instance, _a, _k: _clean({"agent_class": instance.__class__.__name__}), + metadata_fn=lambda instance, _a, _k: clean_nones({"agent_class": instance.__class__.__name__}), ) _sequential_pipeline_wrapper = _make_task_wrapper( @@ -224,13 +221,13 @@ async def _toolkit_call_tool_function_wrapper(wrapped: Any, instance: Any, args: 
start_span( name=f"{tool_name}.execute", type=SpanTypeAttribute.TOOL, - input=_clean( + input=clean_nones( { "tool_name": tool_name, "tool_call": tool_call, } ), - metadata=_clean({"toolkit_class": instance.__class__.__name__}), + metadata=clean_nones({"toolkit_class": instance.__class__.__name__}), ) ) try: diff --git a/py/src/braintrust/integrations/agno/tracing.py b/py/src/braintrust/integrations/agno/tracing.py index 65c45dba..65a6ff7e 100644 --- a/py/src/braintrust/integrations/agno/tracing.py +++ b/py/src/braintrust/integrations/agno/tracing.py @@ -2,6 +2,7 @@ from inspect import isawaitable from typing import Any +from braintrust.integrations.utils import _try_to_dict from braintrust.logger import start_span from braintrust.span_types import SpanTypeAttribute from braintrust.util import is_numeric @@ -24,28 +25,6 @@ def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: list[str]): return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys) -def _try_to_dict(obj: Any) -> Any: - """Convert object to dict, handling different object types like OpenAI wrapper.""" - if isinstance(obj, dict): - return obj - if hasattr(obj, "model_dump") and callable(obj.model_dump): - try: - return obj.model_dump() - except Exception: - pass - if hasattr(obj, "dict") and callable(obj.dict): - try: - return obj.dict() - except Exception: - pass - if hasattr(obj, "__dict__"): - try: - return obj.__dict__.copy() - except Exception: - pass - return obj - - def is_sync_iterator(result: Any) -> bool: return hasattr(result, "__iter__") and hasattr(result, "__next__") diff --git a/py/src/braintrust/integrations/anthropic/_utils.py b/py/src/braintrust/integrations/anthropic/_utils.py index cc0a0242..6e6af181 100644 --- a/py/src/braintrust/integrations/anthropic/_utils.py +++ b/py/src/braintrust/integrations/anthropic/_utils.py @@ -2,6 +2,7 @@ from typing import Any +from braintrust.integrations.utils import _try_to_dict as _shared_try_to_dict from braintrust.util import is_numeric @@ -36,27 +37,14 @@ def __getattr__(self, name: str) -> Any: def _try_to_dict(obj: Any) -> dict[str, Any] | None: - if isinstance(obj, dict): - return obj - - if hasattr(obj, "model_dump"): - try: - candidate = obj.model_dump(mode="python") - except TypeError: - candidate = obj.model_dump() - return candidate if isinstance(candidate, dict) else None - - if hasattr(obj, "to_dict"): - candidate = obj.to_dict() - return candidate if isinstance(candidate, dict) else None - - if hasattr(obj, "dict"): - candidate = obj.dict() - return candidate if isinstance(candidate, dict) else None - - if hasattr(obj, "__dict__"): - return vars(obj) + """Anthropic-flavoured object→dict conversion. + Delegates to the shared ``_try_to_dict`` first, then returns ``None`` + (instead of the original object) when conversion fails. 
+ """ + result = _shared_try_to_dict(obj) + if isinstance(result, dict): + return result return None diff --git a/py/src/braintrust/integrations/anthropic/test_anthropic.py b/py/src/braintrust/integrations/anthropic/test_anthropic.py index 97ae08df..c3f60edd 100644 --- a/py/src/braintrust/integrations/anthropic/test_anthropic.py +++ b/py/src/braintrust/integrations/anthropic/test_anthropic.py @@ -166,6 +166,55 @@ def test_extract_anthropic_usage_includes_server_tool_use_metrics_from_objects() assert metadata == {} +def test_extract_anthropic_usage_supports_to_dict_only_objects(): + class ToDictOnly: + __slots__ = ("_payload",) + + def __init__(self, payload): + self._payload = payload + + def to_dict(self): + return self._payload + + usage = ToDictOnly( + { + "input_tokens": 11, + "output_tokens": 7, + "cache_read_input_tokens": 3, + "cache_creation": ToDictOnly( + { + "ephemeral_5m_input_tokens": 2, + "ephemeral_1h_input_tokens": 5, + } + ), + "server_tool_use": ToDictOnly( + { + "web_search_requests": 2, + "web_fetch_requests": 1, + } + ), + "service_tier": "standard", + } + ) + + metrics, metadata = extract_anthropic_usage(usage) + + assert metrics == { + "prompt_tokens": 21.0, + "completion_tokens": 7.0, + "prompt_cached_tokens": 3.0, + "prompt_cache_creation_tokens": 7.0, + "server_tool_use_web_search_requests": 2.0, + "server_tool_use_web_fetch_requests": 1.0, + "tokens": 28.0, + } + assert metadata == { + "cache_creation_ephemeral_5m_input_tokens": 2, + "cache_creation_ephemeral_1h_input_tokens": 5, + "usage_service_tier": "standard", + } + + @pytest.mark.vcr(match_on=["method", "scheme", "host", "port", "path"]) def test_anthropic_messages_create_with_image_attachment_input(memory_logger): assert not memory_logger.pop() diff --git a/py/src/braintrust/integrations/anthropic/tracing.py b/py/src/braintrust/integrations/anthropic/tracing.py index ac1e40c1..e172c8ad 100644 --- a/py/src/braintrust/integrations/anthropic/tracing.py +++ b/py/src/braintrust/integrations/anthropic/tracing.py @@ -2,11 +2,10 @@ import logging import time import warnings -from contextlib import contextmanager from braintrust.bt_json import bt_safe_deep_copy from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage -from braintrust.logger import NOOP_SPAN, Attachment, log_exc_info_to_span, start_span +from braintrust.logger import Attachment, log_exc_info_to_span, start_span log = logging.getLogger(__name__) @@ -77,12 +76,10 @@ async def __create_with_stream_false(self, *args, **kwargs): try: result = await self.__messages.create(*args, **kwargs) ttft = time.time() - request_start_time - with _catch_exceptions(): - _log_message_to_span(result, span, time_to_first_token=ttft) + _log_message_to_span(result, span, time_to_first_token=ttft) return result except Exception as e: - with _catch_exceptions(): - span.log(error=e) + span.log(error=e) raise finally: span.end() @@ -93,9 +90,8 @@ async def __create_with_stream_true(self, *args, **kwargs): try: stream = await self.__messages.create(*args, **kwargs) except Exception as e: - with _catch_exceptions(): - span.log(error=e) - span.end() + span.log(error=e) + span.end() raise traced_stream = TracedMessageStream(stream, span, request_start_time) @@ -105,16 +101,14 @@ async def async_stream(): async for msg in traced_stream: yield msg except Exception as e: - with _catch_exceptions(): - span.log(error=e) + span.log(error=e) raise finally: - with _catch_exceptions(): - msg = traced_stream._get_final_traced_message() - if msg: - ttft = 
traced_stream._get_time_to_first_token() - _log_message_to_span(msg, span, time_to_first_token=ttft) - span.end() + msg = traced_stream._get_final_traced_message() + if msg: + ttft = traced_stream._get_time_to_first_token() + _log_message_to_span(msg, span, time_to_first_token=ttft) + span.end() return async_stream() @@ -206,12 +200,10 @@ def create(self, *args, **kwargs): span = _start_batch_create_span(kwargs) try: result = self.__batches.create(*args, **kwargs) - with _catch_exceptions(): - _log_batch_create_to_span(result, span) + _log_batch_create_to_span(result, span) return result except Exception as e: - with _catch_exceptions(): - span.log(error=e) + span.log(error=e) raise finally: span.end() @@ -220,12 +212,10 @@ def results(self, *args, **kwargs): span = _start_batch_results_span(args, kwargs) try: result = self.__batches.results(*args, **kwargs) - with _catch_exceptions(): - span.log(output={"type": "jsonl_stream"}) + span.log(output={"type": "jsonl_stream"}) return result except Exception as e: - with _catch_exceptions(): - span.log(error=e) + span.log(error=e) raise finally: span.end() @@ -242,12 +232,10 @@ async def create(self, *args, **kwargs): span = _start_batch_create_span(kwargs) try: result = await self.__batches.create(*args, **kwargs) - with _catch_exceptions(): - _log_batch_create_to_span(result, span) + _log_batch_create_to_span(result, span) return result except Exception as e: - with _catch_exceptions(): - span.log(error=e) + span.log(error=e) raise finally: span.end() @@ -256,12 +244,10 @@ async def results(self, *args, **kwargs): span = _start_batch_results_span(args, kwargs) try: result = await self.__batches.results(*args, **kwargs) - with _catch_exceptions(): - span.log(output={"type": "jsonl_stream"}) + span.log(output={"type": "jsonl_stream"}) return result except Exception as e: - with _catch_exceptions(): - span.log(error=e) + span.log(error=e) raise finally: span.end() @@ -289,26 +275,23 @@ def __aexit__(self, exc_type, exc_value, traceback): try: return self.__msg_stream_mgr.__aexit__(exc_type, exc_value, traceback) finally: - with _catch_exceptions(): - self.__close(exc_type, exc_value, traceback) + self.__close(exc_type, exc_value, traceback) def __exit__(self, exc_type, exc_value, traceback): try: return self.__msg_stream_mgr.__exit__(exc_type, exc_value, traceback) finally: - with _catch_exceptions(): - self.__close(exc_type, exc_value, traceback) + self.__close(exc_type, exc_value, traceback) def __close(self, exc_type, exc_value, traceback): - with _catch_exceptions(): - tms = self.__traced_message_stream - msg = tms._get_final_traced_message() - if msg: - ttft = tms._get_time_to_first_token() - _log_message_to_span(msg, self.__span, time_to_first_token=ttft) - if exc_type: - log_exc_info_to_span(self.__span, exc_type, exc_value, traceback) - self.__span.end() + tms = self.__traced_message_stream + msg = tms._get_final_traced_message() + if msg: + ttft = tms._get_time_to_first_token() + _log_message_to_span(msg, self.__span, time_to_first_token=ttft) + if exc_type: + log_exc_info_to_span(self.__span, exc_type, exc_value, traceback) + self.__span.end() class TracedMessageStream(Wrapper): @@ -340,80 +323,70 @@ def __iter__(self): async def __anext__(self): m = await self.__msg_stream.__anext__() - with _catch_exceptions(): - self.__process_message(m) + self.__process_message(m) return m def __next__(self): m = next(self.__msg_stream) - with _catch_exceptions(): - self.__process_message(m) + self.__process_message(m) return m def 
__process_message(self, m): if self.__time_to_first_token is None: self.__time_to_first_token = time.time() - self.__request_start_time - with _catch_exceptions(): - self.__snapshot = accumulate_event(event=m, current_snapshot=self.__snapshot) + self.__snapshot = accumulate_event(event=m, current_snapshot=self.__snapshot) def _start_batch_create_span(kwargs): - with _catch_exceptions(): - requests = list(kwargs.get("requests", [])) - # Extract models from the batch requests for metadata - models = set() - for req in requests: - params = req.get("params", {}) if isinstance(req, dict) else getattr(req, "params", {}) - model = params.get("model") if isinstance(params, dict) else getattr(params, "model", None) - if model: - models.add(model) + requests = list(kwargs.get("requests", [])) + # Extract models from the batch requests for metadata + models = set() + for req in requests: + params = req.get("params", {}) if isinstance(req, dict) else getattr(req, "params", {}) + model = params.get("model") if isinstance(params, dict) else getattr(params, "model", None) + if model: + models.add(model) - metadata = {"provider": "anthropic", "num_requests": len(requests)} - if len(models) == 1: - metadata["model"] = next(iter(models)) - elif models: - metadata["models"] = sorted(models) + metadata = {"provider": "anthropic", "num_requests": len(requests)} + if len(models) == 1: + metadata["model"] = next(iter(models)) + elif models: + metadata["models"] = sorted(models) - _input = [ - {"custom_id": req.get("custom_id") if isinstance(req, dict) else getattr(req, "custom_id", None)} - for req in requests - ] + _input = [ + {"custom_id": req.get("custom_id") if isinstance(req, dict) else getattr(req, "custom_id", None)} + for req in requests + ] - return start_span(name="anthropic.messages.batches.create", type="task", metadata=metadata, input=_input) - - return NOOP_SPAN + return start_span(name="anthropic.messages.batches.create", type="task", metadata=metadata, input=_input) def _log_batch_create_to_span(result, span): - with _catch_exceptions(): - output = {} - if hasattr(result, "id"): - output["id"] = result.id - if hasattr(result, "processing_status"): - output["processing_status"] = result.processing_status - if hasattr(result, "request_counts"): - rc = result.request_counts - output["request_counts"] = { - "processing": getattr(rc, "processing", 0), - "succeeded": getattr(rc, "succeeded", 0), - "errored": getattr(rc, "errored", 0), - "canceled": getattr(rc, "canceled", 0), - "expired": getattr(rc, "expired", 0), - } - - span.log(output=output) + output = {} + if hasattr(result, "id"): + output["id"] = result.id + if hasattr(result, "processing_status"): + output["processing_status"] = result.processing_status + if hasattr(result, "request_counts"): + rc = result.request_counts + output["request_counts"] = { + "processing": getattr(rc, "processing", 0), + "succeeded": getattr(rc, "succeeded", 0), + "errored": getattr(rc, "errored", 0), + "canceled": getattr(rc, "canceled", 0), + "expired": getattr(rc, "expired", 0), + } + + span.log(output=output) def _start_batch_results_span(args, kwargs): - with _catch_exceptions(): - # message_batch_id can be passed as first positional arg or as kwarg - batch_id = args[0] if args else kwargs.get("message_batch_id") - metadata = {"provider": "anthropic"} - _input = {"message_batch_id": batch_id} - return start_span(name="anthropic.messages.batches.results", type="task", metadata=metadata, input=_input) - - return NOOP_SPAN + # message_batch_id can be passed as 
first positional arg or as kwarg + batch_id = args[0] if args else kwargs.get("message_batch_id") + metadata = {"provider": "anthropic"} + _input = {"message_batch_id": batch_id} + return start_span(name="anthropic.messages.batches.results", type="task", metadata=metadata, input=_input) def _attachment_filename_for_media_type(media_type: str, block_type: str) -> str: @@ -487,43 +460,31 @@ def _get_metadata_from_kwargs(kwargs): def _start_span(name, kwargs): - with _catch_exceptions(): - _input = _get_input_from_kwargs(kwargs) - metadata = _get_metadata_from_kwargs(kwargs) - return start_span(name=name, type="llm", metadata=metadata, input=_input) - - return NOOP_SPAN + _input = _get_input_from_kwargs(kwargs) + metadata = _get_metadata_from_kwargs(kwargs) + return start_span(name=name, type="llm", metadata=metadata, input=_input) def _log_message_to_span(message, span, time_to_first_token: float | None = None): - with _catch_exceptions(): - usage = getattr(message, "usage", {}) - metrics, metadata = extract_anthropic_usage(usage) - - if time_to_first_token is not None: - metrics["time_to_first_token"] = time_to_first_token - - output = { - k: v - for k, v in { - "role": getattr(message, "role", None), - "content": getattr(message, "content", None), - "model": getattr(message, "model", None), - "stop_reason": getattr(message, "stop_reason", None), - "stop_sequence": getattr(message, "stop_sequence", None), - }.items() - if v is not None - } or None - - span.log(output=output, metrics=metrics, metadata=metadata) - - -@contextmanager -def _catch_exceptions(): - try: - yield - except Exception as e: - log.warning("swallowing exception in tracing code", exc_info=e) + usage = getattr(message, "usage", {}) + metrics, metadata = extract_anthropic_usage(usage) + + if time_to_first_token is not None: + metrics["time_to_first_token"] = time_to_first_token + + output = { + k: v + for k, v in { + "role": getattr(message, "role", None), + "content": getattr(message, "content", None), + "model": getattr(message, "model", None), + "stop_reason": getattr(message, "stop_reason", None), + "stop_sequence": getattr(message, "stop_sequence", None), + }.items() + if v is not None + } or None + + span.log(output=output, metrics=metrics, metadata=metadata) _BRAINTRUST_TRACED = "__braintrust_traced__" diff --git a/py/src/braintrust/integrations/google_genai/tracing.py b/py/src/braintrust/integrations/google_genai/tracing.py index f6d0fb57..00b22492 100644 --- a/py/src/braintrust/integrations/google_genai/tracing.py +++ b/py/src/braintrust/integrations/google_genai/tracing.py @@ -14,6 +14,7 @@ from braintrust.bt_json import bt_safe_deep_copy from braintrust.logger import Attachment, start_span from braintrust.span_types import SpanTypeAttribute +from braintrust.util import clean_nones if TYPE_CHECKING: @@ -218,10 +219,6 @@ def _get_args_kwargs( return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, _omit(kwargs, omit_keys or keys) -def _clean(obj: dict[str, Any]) -> dict[str, Any]: - return {k: v for k, v in obj.items() if v is not None} - - def _prepare_traced_call( api_client: Any, args: list[Any], kwargs: dict[str, Any] ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -236,7 +233,7 @@ def _prepare_generate_images_traced_call( input, clean_kwargs = _get_args_kwargs(args, kwargs, ["model", "prompt", "config"], ["prompt", "config"]) if input.get("config") is not None: input["config"] = bt_safe_deep_copy(input["config"]) - return _clean(input), clean_kwargs + return clean_nones(input), clean_kwargs 
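For reference, `clean_nones` from `braintrust.util` is swapped in for the removed per-integration `_clean` helpers throughout this file. Assuming it matches the `_clean` implementations deleted above, it simply drops `None`-valued keys while keeping falsy-but-present values (the model name below is an illustrative placeholder):

```python
from braintrust.util import clean_nones

# None-valued keys are removed; 0 / "" / False are preserved.
assert clean_nones({"model": "gemini-pro", "config": None, "n": 0}) == {"model": "gemini-pro", "n": 0}
```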
def _prepare_interaction_create_traced_call( @@ -244,7 +241,7 @@ def _prepare_interaction_create_traced_call( ) -> tuple[dict[str, Any], dict[str, Any]]: del api_client, args - input_data = _clean( + input_data = clean_nones( { "model": kwargs.get("model"), "agent": kwargs.get("agent"), @@ -262,7 +259,7 @@ def _prepare_interaction_create_traced_call( "agent_config": _serialize_interaction_value(kwargs.get("agent_config")), } ) - metadata = _clean( + metadata = clean_nones( { "api_version": kwargs.get("api_version"), "model": kwargs.get("model"), @@ -279,7 +276,7 @@ def _prepare_interaction_get_traced_call( del api_client interaction_id = args[0] if args else kwargs.get("id") - input_data = _clean( + input_data = clean_nones( { "id": interaction_id, "include_input": kwargs.get("include_input"), @@ -287,7 +284,7 @@ def _prepare_interaction_get_traced_call( "stream": kwargs.get("stream"), } ) - metadata = _clean({"api_version": kwargs.get("api_version")}) + metadata = clean_nones({"api_version": kwargs.get("api_version")}) return input_data, metadata @@ -297,8 +294,8 @@ def _prepare_interaction_id_traced_call( del api_client interaction_id = args[0] if args else kwargs.get("id") - input_data = _clean({"id": interaction_id}) - metadata = _clean({"api_version": kwargs.get("api_version")}) + input_data = clean_nones({"id": interaction_id}) + metadata = clean_nones({"api_version": kwargs.get("api_version")}) return input_data, metadata @@ -335,7 +332,7 @@ def _extract_generate_content_metrics(response: "GenerateContentResponse", start if hasattr(response, "usage_metadata") and response.usage_metadata: _extract_usage_metadata_metrics(response.usage_metadata, metrics) - return _clean(dict(metrics)) + return clean_nones(dict(metrics)) def _extract_embed_content_output(response: "EmbedContentResponse") -> dict[str, Any]: @@ -343,7 +340,7 @@ def _extract_embed_content_output(response: "EmbedContentResponse") -> dict[str, first_embedding = embeddings[0] if embeddings else None first_values = getattr(first_embedding, "values", None) or [] - return _clean( + return clean_nones( { "embedding_length": len(first_values) if first_values else None, "embeddings_count": len(embeddings) if embeddings else None, @@ -376,7 +373,7 @@ def _extract_embed_content_metrics(response: "EmbedContentResponse", start: floa if billable_character_count is not None: metrics["billable_characters"] = billable_character_count - return _clean(metrics) + return clean_nones(metrics) def _extract_generate_images_output(response: Any) -> dict[str, Any]: @@ -389,7 +386,7 @@ def _extract_generate_images_output(response: Any) -> dict[str, Any]: mime_type = getattr(image, "mime_type", None) safety_attributes = getattr(generated_image, "safety_attributes", None) - image_entry: dict[str, Any] = _clean( + image_entry: dict[str, Any] = clean_nones( { "mime_type": mime_type, "gcs_uri": getattr(image, "gcs_uri", None), @@ -415,7 +412,7 @@ def _extract_generate_images_output(response: Any) -> dict[str, Any]: positive_prompt_safety_attributes = getattr(response, "positive_prompt_safety_attributes", None) positive_prompt_summary = None if positive_prompt_safety_attributes is not None: - positive_prompt_summary = _clean( + positive_prompt_summary = clean_nones( { "categories": getattr(positive_prompt_safety_attributes, "categories", None), "scores": getattr(positive_prompt_safety_attributes, "scores", None), @@ -423,7 +420,7 @@ def _extract_generate_images_output(response: Any) -> dict[str, Any]: } ) - return _clean( + return clean_nones( { 
"generated_images_count": len(generated_images), "generated_images": serialized_images, @@ -435,7 +432,7 @@ def _extract_generate_images_output(response: Any) -> dict[str, Any]: def _extract_generic_timing_metrics(start: float) -> dict[str, Any]: end_time = time.time() - return _clean( + return clean_nones( { "start": start, "end": end_time, @@ -480,7 +477,7 @@ def _extract_interaction_output( ) -> dict[str, Any]: outputs_list = serialized_outputs if serialized_outputs is not None else _serialize_interaction_outputs(response) - return _clean( + return clean_nones( { "status": getattr(response, "status", None), "outputs": outputs_list, @@ -494,7 +491,7 @@ def _extract_interaction_metadata(response: "Interaction") -> dict[str, Any]: usage_serialized = _serialize_interaction_value(usage) usage_by_modality = None if isinstance(usage_serialized, dict): - usage_by_modality = _clean( + usage_by_modality = clean_nones( { "input_tokens_by_modality": usage_serialized.get("input_tokens_by_modality"), "output_tokens_by_modality": usage_serialized.get("output_tokens_by_modality"), @@ -503,7 +500,7 @@ def _extract_interaction_metadata(response: "Interaction") -> dict[str, Any]: } ) - return _clean( + return clean_nones( { "interaction_id": getattr(response, "id", None), "previous_interaction_id": getattr(response, "previous_interaction_id", None), @@ -553,7 +550,7 @@ def _tool_span_input(call_item: dict[str, Any] | None) -> Any: if call_item.get("arguments") is not None: return call_item["arguments"] return ( - _clean( + clean_nones( { key: value for key, value in call_item.items() @@ -570,7 +567,7 @@ def _tool_span_output(result_item: dict[str, Any] | None) -> Any: if result_item.get("result") is not None: return result_item["result"] return ( - _clean( + clean_nones( { key: value for key, value in result_item.items() @@ -694,7 +691,7 @@ def _aggregate_generate_content_chunks( if text: aggregated["text"] = text - clean_metrics = _clean(dict(metrics)) + clean_metrics = clean_nones(dict(metrics)) return aggregated, clean_metrics @@ -759,7 +756,7 @@ def _get_active_interaction_tool_spans() -> dict[str, _ActiveInteractionToolSpan def _tool_span_metadata(call_item: dict[str, Any] | None, result_item: dict[str, Any] | None) -> dict[str, Any] | None: return ( - _clean( + clean_nones( { "tool_type": (call_item or result_item or {}).get("type"), "call_id": (call_item or {}).get("id") or (result_item or {}).get("call_id"), @@ -953,7 +950,9 @@ def _aggregate_interaction_events( if first_token_time is not None: metrics["time_to_first_token"] = first_token_time - start - metadata = _clean({"stream_event_types": [et for event in events if (et := getattr(event, "event_type", None))]}) + metadata = clean_nones( + {"stream_event_types": [et for event in events if (et := getattr(event, "event_type", None))]} + ) reconstructed_outputs = _reconstruct_interaction_outputs_from_events(events) final_interaction = next( @@ -968,7 +967,7 @@ def _aggregate_interaction_events( if reconstructed_outputs: return ( {"outputs": reconstructed_outputs, "text": _extract_interaction_text(reconstructed_outputs)}, - _clean(metrics), + clean_nones(metrics), metadata, ) error_event = next( @@ -981,7 +980,7 @@ def _aggregate_interaction_events( ) if error_event is not None: metadata["stream_error"] = _serialize_interaction_value(error_event.error) - return {"events": _serialize_interaction_value(events)}, _clean(metrics), metadata + return {"events": _serialize_interaction_value(events)}, clean_nones(metrics), metadata final_outputs_list = 
_serialize_interaction_outputs(final_interaction) @@ -993,7 +992,7 @@ def _aggregate_interaction_events( output["outputs"] = reconstructed_outputs output["text"] = _extract_interaction_text(reconstructed_outputs) - return output, _clean(metrics), _clean(metadata) + return output, clean_nones(metrics), clean_nones(metadata) # --------------------------------------------------------------------------- diff --git a/py/src/braintrust/integrations/litellm/tracing.py b/py/src/braintrust/integrations/litellm/tracing.py index c880b490..98540855 100644 --- a/py/src/braintrust/integrations/litellm/tracing.py +++ b/py/src/braintrust/integrations/litellm/tracing.py @@ -5,9 +5,14 @@ from types import TracebackType from typing import Any +from braintrust.integrations.utils import ( + _parse_openai_usage_metrics, + _prettify_response_params, + _try_to_dict, +) from braintrust.logger import Span, start_span from braintrust.span_types import SpanTypeAttribute -from braintrust.util import is_numeric, merge_dicts +from braintrust.util import merge_dicts # LiteLLM's representation to Braintrust's representation @@ -461,7 +466,7 @@ def _update_span_payload_from_params(params: dict[str, Any], input_key: str = "i params = params.copy() span_info_d = params.pop("span_info", {}) - params = prettify_params(params) + params = _prettify_response_params(params) input_data = params.pop(input_key, None) model = params.pop("model", None) @@ -473,77 +478,8 @@ def _update_span_payload_from_params(params: dict[str, Any], input_key: str = "i def _parse_metrics_from_usage(usage: Any) -> dict[str, Any]: """Parse usage metrics from API response.""" - metrics: dict[str, Any] = {} - - if not usage: - return metrics - - usage = _try_to_dict(usage) - if not isinstance(usage, dict): - return metrics - - for oai_name, value in usage.items(): - if oai_name.endswith("_tokens_details"): - if not isinstance(value, dict): - continue - raw_prefix = oai_name[: -len("_tokens_details")] - prefix = TOKEN_PREFIX_MAP.get(raw_prefix, raw_prefix) - for k, v in value.items(): - if is_numeric(v): - metrics[f"{prefix}_{k}"] = v - elif is_numeric(value): - name = TOKEN_NAME_MAP.get(oai_name, oai_name) - metrics[name] = value - - return metrics - - -def prettify_params(params: dict[str, Any]) -> dict[str, Any]: - """Return a shallow copy of *params* with response_format serialized for logging.""" - - if "response_format" in params: - ret = params.copy() - ret["response_format"] = serialize_response_format(ret["response_format"]) - return ret - - return params - - -def _try_to_dict(obj: Any) -> dict[str, Any] | Any: - """Try to convert an object to a dictionary.""" - if isinstance(obj, dict): - return obj - if hasattr(obj, "model_dump") and callable(obj.model_dump): - try: - result = obj.model_dump() - if isinstance(result, dict): - return result - except Exception: - pass - if hasattr(obj, "dict") and callable(obj.dict): - try: - result = obj.dict() - if isinstance(result, dict): - return result - except Exception: - pass - return obj - - -def serialize_response_format(response_format: Any) -> Any: - """Serialize response format for logging.""" - try: - from pydantic import BaseModel - except ImportError: - return response_format - - if isinstance(response_format, type) and issubclass(response_format, BaseModel): - return dict( - type="json_schema", - json_schema=dict( - name=response_format.__name__, - schema=response_format.model_json_schema(), - ), - ) - else: - return response_format + return _parse_openai_usage_metrics( + usage, + 
token_name_map=TOKEN_NAME_MAP, + token_prefix_map=TOKEN_PREFIX_MAP, + ) diff --git a/py/src/braintrust/integrations/mistral/tracing.py b/py/src/braintrust/integrations/mistral/tracing.py index 08d8640f..789f6439 100644 --- a/py/src/braintrust/integrations/mistral/tracing.py +++ b/py/src/braintrust/integrations/mistral/tracing.py @@ -6,17 +6,23 @@ import re import time from collections.abc import AsyncIterator, Iterator -from numbers import Real from typing import Any from braintrust.bt_json import bt_safe_deep_copy +from braintrust.integrations.utils import ( + _camel_to_snake, + _convert_data_url_to_attachment, + _is_supported_metric_value, + _log_and_end_span, + _log_error_and_end_span, + _merge_timing_and_usage_metrics, +) from braintrust.logger import Attachment, start_span from braintrust.span_types import SpanTypeAttribute logger = logging.getLogger(__name__) -_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$") _BASE64_RE = re.compile(r"^[A-Za-z0-9+/]+={0,2}$") _TOKEN_NAME_MAP = { "total_tokens": "tokens", @@ -73,44 +79,10 @@ ) -def _camel_to_snake(value: str) -> str: - out = [] - for char in value: - if char.isupper(): - out.append("_") - out.append(char.lower()) - else: - out.append(char) - return "".join(out).lstrip("_") - - def _is_unset(value: Any) -> bool: return value.__class__.__name__ == "Unset" -def _is_supported_metric_value(value: Any) -> bool: - return isinstance(value, Real) and not isinstance(value, bool) - - -def _convert_data_url_to_attachment(data_url: str, filename: str | None = None) -> Attachment | str: - match = _DATA_URL_RE.match(data_url) - if not match: - return data_url - - mime_type, base64_data = match.groups() - try: - binary_data = base64.b64decode(base64_data, validate=True) - except (binascii.Error, ValueError): - return data_url - - if filename is None: - extension = mime_type.split("/")[1] if "/" in mime_type else "bin" - prefix = "image" if mime_type.startswith("image/") else "file" - filename = f"{prefix}.{extension}" - - return Attachment(data=binary_data, filename=filename, content_type=mime_type) - - def _convert_input_audio_to_attachment(value: str) -> Attachment | str: normalized = value.strip().replace("\n", "") if len(normalized) < 64 or len(normalized) % 4 != 0 or not _BASE64_RE.fullmatch(normalized): @@ -237,18 +209,6 @@ def _start_span(name: str, span_input: Any, metadata: dict[str, Any]): ) -def _timing_metrics(start_time: float, first_token_time: float | None = None) -> dict[str, float]: - end_time = time.time() - metrics = { - "start": start_time, - "end": end_time, - "duration": end_time - start_time, - } - if first_token_time is not None: - metrics["time_to_first_token"] = first_token_time - start_time - return metrics - - def _parse_usage_metrics(usage: Any) -> dict[str, float]: usage_data = sanitize_mistral_logged_value(usage) if not isinstance(usage_data, dict): @@ -266,11 +226,13 @@ def _parse_usage_metrics(usage: Any) -> dict[str, float]: return metrics -def _merge_metrics(start_time: float, usage: Any, first_token_time: float | None = None) -> dict[str, float]: - return { - **_timing_metrics(start_time, first_token_time), - **_parse_usage_metrics(usage), - } +def _merge_metrics(start_time: float, usage: Any, first_token_time: float | None = None) -> dict[str, Any]: + return _merge_timing_and_usage_metrics( + start_time, + usage, + _parse_usage_metrics, + first_token_time, + ) def _response_to_metadata(response: Any) -> dict[str, Any]: @@ -307,35 +269,11 @@ def _embeddings_output(response: Any) -> dict[str, Any]: return 
output -def _log_and_end( - span: Any, - *, - output: Any = None, - metrics: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, -): - event = {} - if output is not None: - event["output"] = output - if metrics: - event["metrics"] = metrics - if metadata: - event["metadata"] = metadata - if event: - span.log(**event) - span.end() - - -def _log_error_and_end(span: Any, error: Exception): - span.log(error=error) - span.end() - - def _call_with_error_logging(span: Any, wrapped: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: try: return wrapped(*args, **kwargs) except Exception as error: - _log_error_and_end(span, error) + _log_error_and_end_span(span, error) raise @@ -345,7 +283,7 @@ async def _call_async_with_error_logging( try: return await wrapped(*args, **kwargs) except Exception as error: - _log_error_and_end(span, error) + _log_error_and_end_span(span, error) raise @@ -502,7 +440,7 @@ def _aggregate_completion_events(items: list[Any]) -> dict[str, Any]: def _finalize_completion_response(span: Any, request_metadata: dict[str, Any], response: Any, start_time: float): response_metadata = _response_to_metadata(response) - _log_and_end( + _log_and_end_span( span, output=_completion_response_to_output(response), metrics=_merge_metrics(start_time, getattr(response, "usage", None)), @@ -512,7 +450,7 @@ def _finalize_completion_response(span: Any, request_metadata: dict[str, Any], r def _finalize_embeddings_response(span: Any, request_metadata: dict[str, Any], response: Any, start_time: float): response_metadata = _response_to_metadata(response) - _log_and_end( + _log_and_end_span( span, output=_embeddings_output(response), metrics=_merge_metrics(start_time, getattr(response, "usage", None)), @@ -570,11 +508,11 @@ def _finalize(self, *, error: Exception | None = None): self._closed = True if error is not None: - _log_error_and_end(self._span, error) + _log_error_and_end_span(self._span, error) return response = _aggregate_completion_events(self._items) - _log_and_end( + _log_and_end_span( self._span, output=response.get("choices"), metrics=_merge_metrics(self._start_time, response.get("usage"), self._first_token_time), @@ -632,11 +570,11 @@ def _finalize(self, *, error: Exception | None = None): self._closed = True if error is not None: - _log_error_and_end(self._span, error) + _log_error_and_end_span(self._span, error) return response = _aggregate_completion_events(self._items) - _log_and_end( + _log_and_end_span( self._span, output=response.get("choices"), metrics=_merge_metrics(self._start_time, response.get("usage"), self._first_token_time), diff --git a/py/src/braintrust/integrations/openai/tracing.py b/py/src/braintrust/integrations/openai/tracing.py index 3dde6d8a..4a49da67 100644 --- a/py/src/braintrust/integrations/openai/tracing.py +++ b/py/src/braintrust/integrations/openai/tracing.py @@ -1,17 +1,20 @@ """OpenAI-specific tracing wrappers, stream proxies, and serialization helpers.""" import abc -import base64 import inspect -import re import time -import warnings from collections.abc import Callable from typing import Any -from braintrust.logger import Attachment, Span, start_span +from braintrust.integrations.utils import ( + _convert_data_url_to_attachment, + _parse_openai_usage_metrics, + _prettify_response_params, + _try_to_dict, +) +from braintrust.logger import Span, start_span from braintrust.span_types import SpanTypeAttribute -from braintrust.util import is_numeric, merge_dicts +from braintrust.util import merge_dicts from wrapt import 
FunctionWrapper @@ -96,29 +99,6 @@ def _raw_response_requested(kwargs: dict[str, Any]) -> bool: return False -def _convert_data_url_to_attachment(data_url: str, filename: str | None = None) -> Attachment | str: - """Helper function to convert data URL to an Attachment.""" - data_url_match = re.match(r"^data:([^;]+);base64,(.+)$", data_url) - if not data_url_match: - return data_url - - mime_type, base64_data = data_url_match.groups() - - try: - binary_data = base64.b64decode(base64_data) - - if filename is None: - extension = mime_type.split("/")[1] if "/" in mime_type else "bin" - prefix = "image" if mime_type.startswith("image/") else "document" - filename = f"{prefix}.{extension}" - - attachment = Attachment(data=binary_data, filename=filename, content_type=mime_type) - - return attachment - except Exception: - return data_url - - def _process_attachments_in_input(input_data: Any) -> Any: """Process input to convert data URL images and base64 documents to Attachment objects.""" if isinstance(input_data, list): @@ -979,92 +959,14 @@ def process_output(self, response: Any, span: Span): def _parse_metrics_from_usage(usage: Any) -> dict[str, Any]: - # For simplicity, this function handles all the different APIs - metrics = {} - - if not usage: - return metrics - - # This might be a dict or a Usage object that can be cast to a dict - # to a dict - usage = _try_to_dict(usage) - if not isinstance(usage, dict): - return metrics # unexpected - - for oai_name, value in usage.items(): - if oai_name.endswith("_tokens_details"): - # handle `_tokens_detail` dicts - if not isinstance(value, dict): - continue # unexpected - raw_prefix = oai_name[: -len("_tokens_details")] - prefix = TOKEN_PREFIX_MAP.get(raw_prefix, raw_prefix) - for k, v in value.items(): - if is_numeric(v): - metrics[f"{prefix}_{k}"] = v - elif is_numeric(value): - name = TOKEN_NAME_MAP.get(oai_name, oai_name) - metrics[name] = value - - return metrics + return _parse_openai_usage_metrics( + usage, + token_name_map=TOKEN_NAME_MAP, + token_prefix_map=TOKEN_PREFIX_MAP, + ) def prettify_params(params: dict[str, Any]) -> dict[str, Any]: # Filter out NOT_GIVEN parameters # https://linear.app/braintrustdata/issue/BRA-2467 - ret = {k: v for k, v in params.items() if not _is_not_given(v)} - - if "response_format" in ret: - ret["response_format"] = serialize_response_format(ret["response_format"]) - return ret - - -def _try_to_dict(obj: Any) -> dict[str, Any]: - if isinstance(obj, dict): - return obj - # convert a pydantic object to a dict - # Suppress Pydantic serializer warnings from generic/discriminated-union models - # (e.g. OpenAI's ParsedResponse[T]). See - # https://github.com/braintrustdata/braintrust-sdk-python/issues/60 - if hasattr(obj, "model_dump") and callable(obj.model_dump): - try: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="Pydantic serializer warnings", category=UserWarning) - return obj.model_dump() - except Exception: - pass - # deprecated pydantic method, try model_dump first. 
- if hasattr(obj, "dict") and callable(obj.dict): - try: - return obj.dict() - except Exception: - pass - return obj - - -def serialize_response_format(response_format: Any) -> Any: - try: - from pydantic import BaseModel - except ImportError: - return response_format - - if isinstance(response_format, type) and issubclass(response_format, BaseModel): - return dict( - type="json_schema", - json_schema=dict( - name=response_format.__name__, - schema=response_format.model_json_schema(), - ), - ) - else: - return response_format - - -def _is_not_given(value: Any) -> bool: - if value is None: - return False - try: - # Check by type name and repr to avoid import dependency - type_name = type(value).__name__ - return type_name == "NotGiven" - except Exception: - return False + return _prettify_response_params(params, drop_not_given=True) diff --git a/py/src/braintrust/integrations/openrouter/tracing.py b/py/src/braintrust/integrations/openrouter/tracing.py index 1c12a3b1..f9711bb2 100644 --- a/py/src/braintrust/integrations/openrouter/tracing.py +++ b/py/src/braintrust/integrations/openrouter/tracing.py @@ -3,10 +3,16 @@ import logging import time from collections.abc import AsyncIterator, Iterator -from numbers import Real from typing import TYPE_CHECKING, Any from braintrust.bt_json import bt_safe_deep_copy +from braintrust.integrations.utils import ( + _camel_to_snake, + _is_supported_metric_value, + _log_and_end_span, + _log_error_and_end_span, + _merge_timing_and_usage_metrics, +) from braintrust.logger import start_span from braintrust.span_types import SpanTypeAttribute @@ -50,21 +56,6 @@ } -def _camel_to_snake(value: str) -> str: - out = [] - for char in value: - if char.isupper(): - out.append("_") - out.append(char.lower()) - else: - out.append(char) - return "".join(out).lstrip("_") - - -def _is_supported_metric_value(value: Any) -> bool: - return isinstance(value, Real) and not isinstance(value, bool) - - def sanitize_openrouter_logged_value(value: Any) -> Any: safe = bt_safe_deep_copy(value) @@ -155,23 +146,13 @@ def _parse_openrouter_metrics_from_usage(usage: Any) -> dict[str, float]: return metrics -def _timing_metrics(start_time: float, first_token_time: float | None = None) -> dict[str, float]: - end_time = time.time() - metrics = { - "start": start_time, - "end": end_time, - "duration": end_time - start_time, - } - if first_token_time is not None: - metrics["time_to_first_token"] = first_token_time - start_time - return metrics - - -def _merge_metrics(start_time: float, usage: Any, first_token_time: float | None = None) -> dict[str, float]: - return { - **_timing_metrics(start_time, first_token_time), - **_parse_openrouter_metrics_from_usage(usage), - } +def _merge_metrics(start_time: float, usage: Any, first_token_time: float | None = None) -> dict[str, Any]: + return _merge_timing_and_usage_metrics( + start_time, + usage, + _parse_openrouter_metrics_from_usage, + first_token_time, + ) def _response_to_output(response: Any, *, fallback_output: Any | None = None) -> Any: @@ -220,26 +201,6 @@ def _start_span(name: str, span_input: Any, metadata: dict[str, Any]): ) -def _log_and_end( - span: Any, *, output: Any = None, metrics: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None -): - event = {} - if output is not None: - event["output"] = output - if metrics: - event["metrics"] = metrics - if metadata: - event["metadata"] = metadata - if event: - span.log(**event) - span.end() - - -def _log_error_and_end(span: Any, error: Exception): - span.log(error=error) - 
span.end() - - class _TracedOpenRouterSyncStream: def __init__(self, stream: Any, span: Any, metadata: dict[str, Any], kind: str, start_time: float): self._stream = stream @@ -288,7 +249,7 @@ def _finalize(self, *, error: Exception | None = None): self._closed = True if error is not None: - _log_error_and_end(self._span, error) + _log_error_and_end_span(self._span, error) return if self._kind == "chat": @@ -298,7 +259,7 @@ def _finalize(self, *, error: Exception | None = None): output, usage, response_metadata = _aggregate_responses_stream(self._items) metadata = {**self._metadata, **response_metadata} - _log_and_end( + _log_and_end_span( self._span, output=output, metrics=_merge_metrics(self._start_time, usage, self._first_token_time), @@ -354,7 +315,7 @@ def _finalize(self, *, error: Exception | None = None): self._closed = True if error is not None: - _log_error_and_end(self._span, error) + _log_error_and_end_span(self._span, error) return if self._kind == "chat": @@ -364,7 +325,7 @@ def _finalize(self, *, error: Exception | None = None): output, usage, response_metadata = _aggregate_responses_stream(self._items) metadata = {**self._metadata, **response_metadata} - _log_and_end( + _log_and_end_span( self._span, output=output, metrics=_merge_metrics(self._start_time, usage, self._first_token_time), @@ -490,7 +451,7 @@ def _aggregate_responses_stream(chunks: list[Any]) -> tuple[Any, Any, dict[str, def _finalize_chat_response(span: Any, request_metadata: dict[str, Any], result: Any, start_time: float): - _log_and_end( + _log_and_end_span( span, output=_response_to_output(result), metrics=_merge_metrics(start_time, getattr(result, "usage", None)), @@ -499,7 +460,7 @@ def _finalize_chat_response(span: Any, request_metadata: dict[str, Any], result: def _finalize_embeddings_response(span: Any, request_metadata: dict[str, Any], result: Any, start_time: float): - _log_and_end( + _log_and_end_span( span, output=_embeddings_output(result), metrics=_merge_metrics(start_time, getattr(result, "usage", None)), @@ -508,7 +469,7 @@ def _finalize_embeddings_response(span: Any, request_metadata: dict[str, Any], r def _finalize_responses_response(span: Any, request_metadata: dict[str, Any], result: Any, start_time: float): - _log_and_end( + _log_and_end_span( span, output=_response_to_output(result, fallback_output=getattr(result, "output_text", None)), metrics=_merge_metrics(start_time, getattr(result, "usage", None)), @@ -524,7 +485,7 @@ def _chat_send_wrapper(wrapped, instance, args, kwargs): try: result = wrapped(*args, **kwargs) except Exception as error: - _log_error_and_end(span, error) + _log_error_and_end_span(span, error) raise if kwargs.get("stream"): @@ -542,7 +503,7 @@ async def _chat_send_async_wrapper(wrapped, instance, args, kwargs): try: result = await wrapped(*args, **kwargs) except Exception as error: - _log_error_and_end(span, error) + _log_error_and_end_span(span, error) raise if kwargs.get("stream"): @@ -560,7 +521,7 @@ def _embeddings_generate_wrapper(wrapped, instance, args, kwargs): try: result = wrapped(*args, **kwargs) except Exception as error: - _log_error_and_end(span, error) + _log_error_and_end_span(span, error) raise _finalize_embeddings_response(span, request_metadata, result, start_time) @@ -575,7 +536,7 @@ async def _embeddings_generate_async_wrapper(wrapped, instance, args, kwargs): try: result = await wrapped(*args, **kwargs) except Exception as error: - _log_error_and_end(span, error) + _log_error_and_end_span(span, error) raise _finalize_embeddings_response(span, 
request_metadata, result, start_time) @@ -590,7 +551,7 @@ def _responses_send_wrapper(wrapped, instance, args, kwargs): try: result = wrapped(*args, **kwargs) except Exception as error: - _log_error_and_end(span, error) + _log_error_and_end_span(span, error) raise if kwargs.get("stream"): @@ -608,7 +569,7 @@ async def _responses_send_async_wrapper(wrapped, instance, args, kwargs): try: result = await wrapped(*args, **kwargs) except Exception as error: - _log_error_and_end(span, error) + _log_error_and_end_span(span, error) raise if kwargs.get("stream"): diff --git a/py/src/braintrust/integrations/test_utils.py b/py/src/braintrust/integrations/test_utils.py new file mode 100644 index 00000000..fec97597 --- /dev/null +++ b/py/src/braintrust/integrations/test_utils.py @@ -0,0 +1,271 @@ +import unittest.mock + +import pytest +from braintrust import Attachment +from braintrust.integrations.utils import ( + _camel_to_snake, + _convert_data_url_to_attachment, + _is_supported_metric_value, + _log_and_end_span, + _log_error_and_end_span, + _merge_timing_and_usage_metrics, + _parse_openai_usage_metrics, + _prettify_response_params, + _serialize_response_format, + _timing_metrics, + _try_to_dict, +) + + +class NotGiven: + pass + + +def test_camel_to_snake(): + assert _camel_to_snake("promptTokens") == "prompt_tokens" + assert _camel_to_snake("TotalTokens") == "total_tokens" + assert _camel_to_snake("already_snake") == "already_snake" + + +def test_is_supported_metric_value_excludes_booleans(): + assert _is_supported_metric_value(1) + assert _is_supported_metric_value(1.5) + assert not _is_supported_metric_value(True) + assert not _is_supported_metric_value(False) + assert not _is_supported_metric_value("1") + + +def test_try_to_dict_uses_pydantic_model_dump_for_basemodel_instances(): + pydantic = pytest.importorskip("pydantic") + + class Usage(pydantic.BaseModel): + tokens: int + cached_tokens: int + + result = _try_to_dict(Usage(tokens=3, cached_tokens=1)) + + assert result == {"tokens": 3, "cached_tokens": 1} + + +def test_try_to_dict_uses_to_dict_when_available(): + class ToDictOnly: + __slots__ = ("_payload",) + + def __init__(self, payload): + self._payload = payload + + def to_dict(self): + return self._payload + + result = _try_to_dict(ToDictOnly({"tokens": 3})) + + assert result == {"tokens": 3} + + +def test_try_to_dict_falls_back_from_model_dump_python_to_bare_model_dump(): + class BareModelDumpOnly: + def model_dump(self, mode=None): + if mode == "python": + raise TypeError("mode not supported") + return {"tokens": 3} + + result = _try_to_dict(BareModelDumpOnly()) + + assert result == {"tokens": 3} + + +def test_try_to_dict_continues_past_non_dict_converter_results(): + class MixedConverters: + def model_dump(self, mode=None): + return [mode] + + def to_dict(self): + return {"tokens": 3} + + result = _try_to_dict(MixedConverters()) + + assert result == {"tokens": 3} + + +def test_try_to_dict_falls_back_to_vars_for_plain_objects(): + class PlainObject: + def __init__(self): + self.foo = "bar" + self.count = 2 + + result = _try_to_dict(PlainObject()) + + assert result == {"foo": "bar", "count": 2} + + +def test_try_to_dict_returns_original_when_no_conversion_is_possible(): + obj = object() + + result = _try_to_dict(obj) + + assert result is obj + + +def test_parse_openai_usage_metrics_handles_nested_token_details(): + usage = { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + "input_tokens_details": {"cached_tokens": 4}, + "is_byok": True, + } + + metrics = 
_parse_openai_usage_metrics( + usage, + token_name_map={ + "prompt_tokens": "prompt_tokens", + "completion_tokens": "completion_tokens", + "total_tokens": "tokens", + }, + token_prefix_map={"input": "prompt"}, + ) + + assert metrics == { + "prompt_tokens": 10, + "completion_tokens": 20, + "tokens": 30, + "prompt_cached_tokens": 4, + } + + +def test_prettify_response_params_filters_not_given_without_mutating_input(): + original = { + "model": "gpt-5", + "response_format": object(), + "optional": NotGiven(), + } + + prettified = _prettify_response_params(original, drop_not_given=True) + + assert prettified == { + "model": "gpt-5", + "response_format": original["response_format"], + } + assert "optional" in original + + +def test_convert_data_url_to_attachment_converts_valid_base64(): + data_url = "data:image/png;base64,aGVsbG8=" + + attachment = _convert_data_url_to_attachment(data_url) + + assert isinstance(attachment, Attachment) + assert attachment.reference["content_type"] == "image/png" + assert attachment.reference["filename"] == "image.png" + + +def test_convert_data_url_to_attachment_preserves_invalid_base64(): + data_url = "data:image/png;base64,aGVsbG8=!" + + converted = _convert_data_url_to_attachment(data_url) + + assert converted == data_url + + +def test_convert_data_url_to_attachment_uses_file_prefix_for_non_image_mime_types(): + data_url = "data:application/pdf;base64,aGVsbG8=" + + attachment = _convert_data_url_to_attachment(data_url) + + assert isinstance(attachment, Attachment) + assert attachment.reference["content_type"] == "application/pdf" + assert attachment.reference["filename"] == "file.pdf" + + +def test_convert_data_url_to_attachment_preserves_non_data_urls(): + value = "https://example.com/image.png" + + converted = _convert_data_url_to_attachment(value) + + assert converted == value + + +def test_serialize_response_format_with_pydantic_basemodel_subclass(): + pydantic = pytest.importorskip("pydantic") + + class ResponseFormat(pydantic.BaseModel): + answer: str + + serialized = _serialize_response_format(ResponseFormat) + + assert serialized["type"] == "json_schema" + assert serialized["json_schema"]["name"] == "ResponseFormat" + assert serialized["json_schema"]["schema"]["properties"]["answer"]["title"] == "Answer" + + +def test_timing_metrics_includes_time_to_first_token_when_present(): + assert _timing_metrics(10.0, 15.0, 12.0) == { + "start": 10.0, + "end": 15.0, + "duration": 5.0, + "time_to_first_token": 2.0, + } + + +def test_timing_metrics_omits_time_to_first_token_when_absent(): + assert _timing_metrics(10.0, 15.0) == { + "start": 10.0, + "end": 15.0, + "duration": 5.0, + } + + +def test_log_and_end_span_logs_populated_event_then_ends(): + span = unittest.mock.Mock() + + _log_and_end_span( + span, + output={"answer": "4"}, + metrics={"tokens": 2}, + metadata={"provider": "test"}, + ) + + span.log.assert_called_once_with( + output={"answer": "4"}, + metrics={"tokens": 2}, + metadata={"provider": "test"}, + ) + span.end.assert_called_once_with() + + +def test_log_and_end_span_skips_log_for_empty_event(): + span = unittest.mock.Mock() + + _log_and_end_span(span) + + span.log.assert_not_called() + span.end.assert_called_once_with() + + +def test_log_error_and_end_span_logs_error_then_ends(): + span = unittest.mock.Mock() + error = RuntimeError("boom") + + _log_error_and_end_span(span, error) + + span.log.assert_called_once_with(error=error) + span.end.assert_called_once_with() + + +def test_merge_timing_and_usage_metrics(monkeypatch): + 
+    monkeypatch.setattr("braintrust.integrations.utils.time.time", lambda: 15.0)
+
+    metrics = _merge_timing_and_usage_metrics(
+        10.0,
+        {"usage": 1},
+        lambda usage: {"tokens": usage["usage"]},
+        12.0,
+    )
+
+    assert metrics == {
+        "start": 10.0,
+        "end": 15.0,
+        "duration": 5.0,
+        "time_to_first_token": 2.0,
+        "tokens": 1,
+    }
diff --git a/py/src/braintrust/integrations/utils.py b/py/src/braintrust/integrations/utils.py
new file mode 100644
index 00000000..db8d710f
--- /dev/null
+++ b/py/src/braintrust/integrations/utils.py
@@ -0,0 +1,256 @@
+"""Shared tracing utilities for Braintrust SDK integrations.
+
+These helpers are common building blocks used across multiple provider
+integrations. Keeping them here avoids duplication and makes behavioral fixes
+propagate to all providers at once.
+
+Names are prefixed with ``_`` so that consumer modules can import them
+directly without aliasing (e.g. ``from braintrust.integrations.utils import
+_try_to_dict``).
+"""
+
+import base64
+import binascii
+import re
+import time
+import warnings
+from collections.abc import Callable, Mapping
+from numbers import Real
+from typing import Any
+
+from braintrust.logger import Attachment, Span
+from braintrust.util import is_numeric
+
+
+_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$")
+
+
+def _try_to_dict(obj: Any) -> dict[str, Any] | Any:
+    """Best-effort conversion of an SDK response object to a plain dict.
+
+    Tries, in order:
+    1. ``model_dump(mode="python")`` (preferred for Pydantic v2 objects)
+    2. ``model_dump()`` (fallback for SDKs with custom signatures)
+    3. ``to_dict()`` (used by some provider SDK response objects)
+    4. ``dict()`` (Pydantic v1 / legacy)
+    5. ``vars(obj)`` (plain Python attribute bags)
+    6. returns *obj* unchanged
+
+    Only dict-like conversion results are accepted; non-dict results are
+    ignored so later fallbacks still run.
+
+    Pydantic serializer warnings (common with generic/discriminated-union
+    models such as OpenAI's ``ParsedResponse[T]``) are suppressed.
+    """
+    if isinstance(obj, dict):
+        return obj
+
+    model_dump = getattr(obj, "model_dump", None)
+
+    def _call_model_dump_python() -> Any:
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="Pydantic serializer warnings", category=UserWarning)
+            return model_dump(mode="python")
+
+    def _call_model_dump() -> Any:
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", message="Pydantic serializer warnings", category=UserWarning)
+            return model_dump()
+
+    to_dict = getattr(obj, "to_dict", None)
+    dict_method = getattr(obj, "dict", None)
+
+    converters: list[Callable[[], Any]] = []
+    if callable(model_dump):
+        converters.extend((_call_model_dump_python, _call_model_dump))
+    if callable(to_dict):
+        converters.append(to_dict)
+    if callable(dict_method):
+        converters.append(dict_method)
+    converters.append(lambda: vars(obj))
+
+    for converter in converters:
+        try:
+            result = converter()
+        except Exception:
+            continue
+        if isinstance(result, dict):
+            return result
+
+    return obj
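+
+# A hypothetical illustration (``FakeResponse`` is not a real SDK type): any
+# object exposing one of the converters above flattens to a plain dict,
+# and anything inconvertible passes through untouched.
+#
+#     class FakeResponse:
+#         def to_dict(self):
+#             return {"id": "resp_1"}
+#
+#     _try_to_dict(FakeResponse())  # -> {"id": "resp_1"}
+#     _try_to_dict(object())        # -> the same object, unchanged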
+ """ + if isinstance(obj, dict): + return obj + + model_dump = getattr(obj, "model_dump", None) + + def _call_model_dump_python() -> Any: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Pydantic serializer warnings", category=UserWarning) + return model_dump(mode="python") + + def _call_model_dump() -> Any: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Pydantic serializer warnings", category=UserWarning) + return model_dump() + + to_dict = getattr(obj, "to_dict", None) + dict_method = getattr(obj, "dict", None) + + converters: list[Callable[[], Any]] = [] + if callable(model_dump): + converters.extend((_call_model_dump_python, _call_model_dump)) + if callable(to_dict): + converters.append(to_dict) + if callable(dict_method): + converters.append(dict_method) + converters.append(lambda: vars(obj)) + + for converter in converters: + try: + result = converter() + except Exception: + continue + if isinstance(result, dict): + return result + + return obj + + +def _camel_to_snake(value: str) -> str: + """Convert a camelCase or PascalCase string into snake_case.""" + out = [] + for char in value: + if char.isupper(): + out.append("_") + out.append(char.lower()) + else: + out.append(char) + return "".join(out).lstrip("_") + + +def _is_supported_metric_value(value: Any) -> bool: + """Return ``True`` for numeric metric values, excluding booleans.""" + return isinstance(value, Real) and not isinstance(value, bool) + + +def _convert_data_url_to_attachment(data_url: str, filename: str | None = None) -> Attachment | str: + """Convert a ``data:;base64,…`` URL into an :class:`Attachment`. + + Returns the original *data_url* string unchanged when it does not match + the expected format or cannot be decoded. + """ + match = _DATA_URL_RE.match(data_url) + if not match: + return data_url + + mime_type, base64_data = match.groups() + + try: + binary_data = base64.b64decode(base64_data, validate=True) + except (binascii.Error, ValueError): + return data_url + + if filename is None: + extension = mime_type.split("/")[1] if "/" in mime_type else "bin" + prefix = "image" if mime_type.startswith("image/") else "file" + filename = f"{prefix}.{extension}" + + return Attachment(data=binary_data, filename=filename, content_type=mime_type) + + +def _is_not_given(value: object) -> bool: + """Return ``True`` when *value* is a provider ``NOT_GIVEN`` sentinel. + + Works by type-name inspection so that Braintrust does not need a + direct import dependency on any provider SDK. + """ + if value is None: + return False + try: + return type(value).__name__ == "NotGiven" + except Exception: + return False + + +def _serialize_response_format(response_format: Any) -> Any: + """Serialize a Pydantic ``BaseModel`` subclass into a JSON-schema dict. + + Non-Pydantic values pass through unchanged. Used when logging + ``response_format`` parameters so the span metadata contains a + readable schema rather than a Python class reference. 
+ """ + try: + from pydantic import BaseModel + except ImportError: + return response_format + + if isinstance(response_format, type) and issubclass(response_format, BaseModel): + return dict( + type="json_schema", + json_schema=dict( + name=response_format.__name__, + schema=response_format.model_json_schema(), + ), + ) + return response_format + + +def _prettify_response_params(params: dict[str, Any], *, drop_not_given: bool = False) -> dict[str, Any]: + """Return a shallow copy of traced request params with logging-friendly values.""" + ret = params.copy() + if drop_not_given: + ret = {key: value for key, value in ret.items() if not _is_not_given(value)} + + if "response_format" in ret: + ret["response_format"] = _serialize_response_format(ret["response_format"]) + return ret + + +def _parse_openai_usage_metrics( + usage: Any, + *, + token_name_map: Mapping[str, str], + token_prefix_map: Mapping[str, str], +) -> dict[str, Any]: + """Parse usage payloads that follow OpenAI's ``*_tokens`` conventions.""" + metrics: dict[str, Any] = {} + + if not usage: + return metrics + + usage = _try_to_dict(usage) + if not isinstance(usage, dict): + return metrics + + for name, value in usage.items(): + if name.endswith("_tokens_details"): + if not isinstance(value, dict): + continue + raw_prefix = name[: -len("_tokens_details")] + prefix = token_prefix_map.get(raw_prefix, raw_prefix) + for nested_name, nested_value in value.items(): + if is_numeric(nested_value): + metrics[f"{prefix}_{nested_name}"] = nested_value + elif is_numeric(value): + metrics[token_name_map.get(name, name)] = value + + return metrics + + +def _timing_metrics(start_time: float, end_time: float, first_token_time: float | None = None) -> dict[str, float]: + """Build a standard ``start / end / duration`` metrics dict. + + Optionally includes ``time_to_first_token`` when *first_token_time* + is provided. 
+ """ + metrics: dict[str, float] = { + "start": start_time, + "end": end_time, + "duration": end_time - start_time, + } + if first_token_time is not None: + metrics["time_to_first_token"] = first_token_time - start_time + return metrics + + +def _merge_timing_and_usage_metrics( + start_time: float, + usage: Any, + usage_parser: Callable[[Any], dict[str, Any]], + first_token_time: float | None = None, +) -> dict[str, Any]: + """Combine standard timing metrics with provider-specific usage parsing.""" + return { + **_timing_metrics(start_time, time.time(), first_token_time), + **usage_parser(usage), + } + + +def _log_and_end_span( + span: Span, + *, + output: Any = None, + metrics: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, +) -> None: + """Log *output*, *metrics* and *metadata* (when present) then end the span.""" + event: dict[str, Any] = {} + if output is not None: + event["output"] = output + if metrics: + event["metrics"] = metrics + if metadata: + event["metadata"] = metadata + if event: + span.log(**event) + span.end() + + +def _log_error_and_end_span(span: Span, error: BaseException) -> None: + """Log an error to *span* and immediately end it.""" + span.log(error=error) + span.end() diff --git a/py/src/braintrust/util.py b/py/src/braintrust/util.py index 516cb9b6..3541cb5f 100644 --- a/py/src/braintrust/util.py +++ b/py/src/braintrust/util.py @@ -40,6 +40,11 @@ def is_numeric(v): return isinstance(v, (int, float, complex)) and not isinstance(v, bool) +def clean_nones(obj: dict[str, Any]) -> dict[str, Any]: + """Return a shallow copy of *obj* with ``None``-valued keys removed.""" + return {k: v for k, v in obj.items() if v is not None} + + def eprint(*args, **kwargs) -> None: print(*args, file=sys.stderr, **kwargs)