diff --git a/sentry_sdk/ai/utils.py b/sentry_sdk/ai/utils.py
index 0c0b937006..1fb291bdac 100644
--- a/sentry_sdk/ai/utils.py
+++ b/sentry_sdk/ai/utils.py
@@ -1,14 +1,17 @@
 import json
 
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from typing import Any, Callable
+    from typing import Any, Callable, Dict, List, Optional, Tuple
+    from sentry_sdk.tracing import Span
 
 import sentry_sdk
 from sentry_sdk.utils import logger
 
+MAX_GEN_AI_MESSAGE_BYTES = 20_000  # 20KB
+
 
 class GEN_AI_ALLOWED_MESSAGE_ROLES:
     SYSTEM = "system"
@@ -95,3 +98,46 @@ def get_start_span_function():
         current_span is not None and current_span.containing_transaction is not None
     )
     return sentry_sdk.start_span if transaction_exists else sentry_sdk.start_transaction
+
+
+def _find_truncation_index(messages, max_bytes):
+    # type: (List[Dict[str, Any]], int) -> int
+    """
+    Find the first message, counting from the back of the list, at which the
+    running serialized size exceeds max_bytes, and return the index right
+    after it: messages[index:] is the newest suffix that fits the budget.
+    """
+    running_sum = 0
+    for idx in range(len(messages) - 1, -1, -1):
+        size = len(json.dumps(messages[idx], separators=(",", ":")).encode("utf-8"))
+        running_sum += size
+        if running_sum > max_bytes:
+            return idx + 1
+
+    return 0
+
+
+def truncate_messages_by_size(messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES):
+    # type: (List[Dict[str, Any]], int) -> Tuple[List[Dict[str, Any]], int]
+    serialized_json = json.dumps(messages, separators=(",", ":"))
+    current_size = len(serialized_json.encode("utf-8"))
+
+    if current_size <= max_bytes:
+        return messages, 0
+
+    truncation_index = _find_truncation_index(messages, max_bytes)
+    return messages[truncation_index:], truncation_index
+
+
+def truncate_and_annotate_messages(
+    messages, span, scope, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
+):
+    # type: (Optional[List[Dict[str, Any]]], Span, Any, int) -> Optional[List[Dict[str, Any]]]
+    if not messages:
+        return None
+
+    truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes)
+    if removed_count > 0:
+        scope._gen_ai_messages_truncated[span.span_id] = removed_count
+
+    return truncated_messages
diff --git a/sentry_sdk/client.py b/sentry_sdk/client.py
index d17f922642..ffd899b545 100644
--- a/sentry_sdk/client.py
+++ b/sentry_sdk/client.py
@@ -598,6 +598,24 @@ def _prepare_event(
         if event_scrubber:
             event_scrubber.scrub_event(event)
 
+        if scope is not None and scope._gen_ai_messages_truncated:
+            spans = event.get("spans", [])  # type: List[Dict[str, Any]] | AnnotatedValue
+            if isinstance(spans, list):
+                for span in spans:
+                    span_id = span.get("span_id", None)
+                    span_data = span.get("data", {})
+                    if (
+                        span_id
+                        and span_id in scope._gen_ai_messages_truncated
+                        and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
+                    ):
+                        span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
+                            span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
+                            {
+                                "len": scope._gen_ai_messages_truncated[span_id]
+                                + len(span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES])
+                            },
+                        )
         if previous_total_spans is not None:
             event["spans"] = AnnotatedValue(
                 event.get("spans", []), {"len": previous_total_spans}
@@ -606,6 +624,7 @@ def _prepare_event(
         event["breadcrumbs"] = AnnotatedValue(
             event.get("breadcrumbs", []), {"len": previous_total_breadcrumbs}
         )
+
         # Postprocess the event here so that annotated types do
        # generally not surface in before_send
         if event is not None:
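The helpers above drop messages whole, oldest first, until the JSON-serialized remainder fits the 20 KB budget; nothing is cut mid-message. A minimal standalone sketch of the same algorithm (a hypothetical `truncate` helper, not the SDK function itself) shows the behavior:

    import json

    def truncate(messages, max_bytes):
        # Accumulate serialized sizes from newest to oldest; keep the newest
        # suffix that fits within max_bytes, dropping oldest messages whole.
        total = 0
        for idx in range(len(messages) - 1, -1, -1):
            total += len(json.dumps(messages[idx], separators=(",", ":")).encode("utf-8"))
            if total > max_bytes:
                return messages[idx + 1 :], idx + 1
        return messages, 0

    msgs = [{"role": "user", "content": "x" * 40} for _ in range(5)]
    kept, removed = truncate(msgs, 150)
    print(removed, len(kept))  # -> 3 2: the three oldest messages are dropped

Note that `_find_truncation_index` sums per-message serializations, so the enclosing list's brackets and commas are not counted; the kept suffix can therefore exceed the budget by a few bytes in edge cases.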
diff --git a/sentry_sdk/integrations/openai.py b/sentry_sdk/integrations/openai.py
index 19d7717b3c..9ff2489765 100644
--- a/sentry_sdk/integrations/openai.py
+++ b/sentry_sdk/integrations/openai.py
@@ -1,10 +1,13 @@
 from functools import wraps
-from collections.abc import Iterable
 
 import sentry_sdk
 from sentry_sdk import consts
 from sentry_sdk.ai.monitoring import record_token_usage
-from sentry_sdk.ai.utils import set_data_normalized, normalize_message_roles
+from sentry_sdk.ai.utils import (
+    set_data_normalized,
+    normalize_message_roles,
+    truncate_and_annotate_messages,
+)
 from sentry_sdk.consts import SPANDATA
 from sentry_sdk.integrations import DidNotEnable, Integration
 from sentry_sdk.scope import should_send_default_pii
@@ -18,7 +21,7 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from typing import Any, List, Optional, Callable, AsyncIterator, Iterator
+    from typing import Any, Iterable, List, Optional, Callable, AsyncIterator, Iterator
     from sentry_sdk.tracing import Span
 
 try:
@@ -189,9 +192,12 @@ def _set_input_data(span, kwargs, operation, integration):
         and integration.include_prompts
     ):
         normalized_messages = normalize_message_roles(messages)
-        set_data_normalized(
-            span, SPANDATA.GEN_AI_REQUEST_MESSAGES, normalized_messages, unpack=False
-        )
+        scope = sentry_sdk.get_current_scope()
+        messages_data = truncate_and_annotate_messages(normalized_messages, span, scope)
+        if messages_data is not None:
+            set_data_normalized(
+                span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
+            )
 
         # Input attributes: Common
         set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")
diff --git a/sentry_sdk/scope.py b/sentry_sdk/scope.py
index f9caf7e1d6..5815a65440 100644
--- a/sentry_sdk/scope.py
+++ b/sentry_sdk/scope.py
@@ -188,6 +188,7 @@ class Scope:
         "_extras",
         "_breadcrumbs",
         "_n_breadcrumbs_truncated",
+        "_gen_ai_messages_truncated",
         "_event_processors",
         "_error_processors",
         "_should_capture",
@@ -213,6 +214,7 @@ def __init__(self, ty=None, client=None):
         self._name = None  # type: Optional[str]
         self._propagation_context = None  # type: Optional[PropagationContext]
         self._n_breadcrumbs_truncated = 0  # type: int
+        self._gen_ai_messages_truncated = {}  # type: Dict[str, int]
 
         self.client = NonRecordingClient()  # type: sentry_sdk.client.BaseClient
 
@@ -247,6 +249,7 @@ def __copy__(self):
 
         rv._breadcrumbs = copy(self._breadcrumbs)
         rv._n_breadcrumbs_truncated = self._n_breadcrumbs_truncated
+        rv._gen_ai_messages_truncated = self._gen_ai_messages_truncated.copy()
         rv._event_processors = self._event_processors.copy()
         rv._error_processors = self._error_processors.copy()
         rv._propagation_context = self._propagation_context
@@ -1583,6 +1586,8 @@ def update_from_scope(self, scope):
             self._n_breadcrumbs_truncated = (
                 self._n_breadcrumbs_truncated + scope._n_breadcrumbs_truncated
             )
+        if scope._gen_ai_messages_truncated:
+            self._gen_ai_messages_truncated.update(scope._gen_ai_messages_truncated)
         if scope._span:
             self._span = scope._span
         if scope._attachments:
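The scope bookkeeping exists because `_prepare_event` runs long after the integration code: the per-span counts in `_gen_ai_messages_truncated` must survive scope copies and forks (hence the `__copy__` and `update_from_scope` changes) until the client folds them into the outgoing event. A small sketch of that fold, mirroring the client.py hunk above with hypothetical stand-in data:

    from sentry_sdk._types import AnnotatedValue

    # Hypothetical stand-ins for what _prepare_event sees.
    truncated_counts = {"span1": 3}  # scope._gen_ai_messages_truncated
    span = {"span_id": "span1", "data": {"gen_ai.request.messages": ["kept"]}}

    kept = span["data"]["gen_ai.request.messages"]
    removed = truncated_counts[span["span_id"]]
    # "len" records the ORIGINAL count: dropped plus still-present messages.
    span["data"]["gen_ai.request.messages"] = AnnotatedValue(
        kept, {"len": removed + len(kept)}
    )

    annotated = span["data"]["gen_ai.request.messages"]
    print(annotated.value, annotated.metadata)  # ['kept'] {'len': 4}

During serialization the metadata is split out into the event's `_meta` tree while the value stays in place, so the original message count is reported without shipping the dropped content.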
diff --git a/tests/integrations/openai/test_openai.py b/tests/integrations/openai/test_openai.py
index 276a1b4886..ccef4f336e 100644
--- a/tests/integrations/openai/test_openai.py
+++ b/tests/integrations/openai/test_openai.py
@@ -1,3 +1,4 @@
+import json
 import pytest
 
 from sentry_sdk.utils import package_version
@@ -6,7 +7,6 @@
     from openai import NOT_GIVEN
 except ImportError:
     NOT_GIVEN = None
-
 try:
     from openai import omit
 except ImportError:
@@ -1456,6 +1456,7 @@ def test_empty_tools_in_chat_completion(sentry_init, capture_events, tools):
 
 def test_openai_message_role_mapping(sentry_init, capture_events):
     """Test that OpenAI integration properly maps message roles like 'ai' to 'assistant'"""
+
     sentry_init(
         integrations=[OpenAIIntegration(include_prompts=True)],
         traces_sample_rate=1.0,
@@ -1465,7 +1466,6 @@ def test_openai_message_role_mapping(sentry_init, capture_events):
 
     client = OpenAI(api_key="z")
     client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)
-
     # Test messages with mixed roles including "ai" that should be mapped to "assistant"
     test_messages = [
         {"role": "system", "content": "You are helpful."},
@@ -1476,11 +1476,9 @@ def test_openai_message_role_mapping(sentry_init, capture_events):
 
     with start_transaction(name="openai tx"):
         client.chat.completions.create(model="test-model", messages=test_messages)
-
+    # Verify that the span was created correctly
     (event,) = events
     span = event["spans"][0]
-
-    # Verify that the span was created correctly
     assert span["op"] == "gen_ai.chat"
     assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
@@ -1505,3 +1503,55 @@ def test_openai_message_role_mapping(sentry_init, capture_events):
     # Verify no "ai" roles remain
     roles = [msg["role"] for msg in stored_messages]
     assert "ai" not in roles
+
+
+def test_openai_message_truncation(sentry_init, capture_events):
+    """Test that large messages are truncated properly in OpenAI integration."""
+    sentry_init(
+        integrations=[OpenAIIntegration(include_prompts=True)],
+        traces_sample_rate=1.0,
+        send_default_pii=True,
+    )
+    events = capture_events()
+
+    client = OpenAI(api_key="z")
+    client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)
+
+    large_content = (
+        "This is a very long message that will exceed our size limits. " * 1000
+    )
+    large_messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": large_content},
+        {"role": "assistant", "content": large_content},
+        {"role": "user", "content": large_content},
+    ]
+
+    with start_transaction(name="openai tx"):
+        client.chat.completions.create(
+            model="some-model",
+            messages=large_messages,
+        )
+
+    (event,) = events
+    span = event["spans"][0]
+    assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]
+
+    messages_data = span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+    assert isinstance(messages_data, str)
+
+    parsed_messages = json.loads(messages_data)
+    assert isinstance(parsed_messages, list)
+    assert len(parsed_messages) <= len(large_messages)
+
+    if "_meta" in event and len(parsed_messages) < len(large_messages):
+        meta_path = event["_meta"]
+        if (
+            "spans" in meta_path
+            and "0" in meta_path["spans"]
+            and "data" in meta_path["spans"]["0"]
+        ):
+            span_meta = meta_path["spans"]["0"]["data"]
+            if SPANDATA.GEN_AI_REQUEST_MESSAGES in span_meta:
+                messages_meta = span_meta[SPANDATA.GEN_AI_REQUEST_MESSAGES]
+                assert "len" in messages_meta.get("", {})
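For reference, the `_meta` tree that the new test walks mirrors the event structure, with the empty-string key holding a node's own annotations. For a truncated span it looks roughly like this (illustrative values only):

    event_meta = {
        "spans": {
            "0": {
                "data": {
                    "gen_ai.request.messages": {
                        "": {"len": 4}  # original message count before truncation
                    }
                }
            }
        }
    }
    assert "len" in event_meta["spans"]["0"]["data"]["gen_ai.request.messages"][""]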
" * 1000 + ) + large_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": large_content}, + {"role": "assistant", "content": large_content}, + {"role": "user", "content": large_content}, + ] + + with start_transaction(name="openai tx"): + client.chat.completions.create( + model="some-model", + messages=large_messages, + ) + + (event,) = events + span = event["spans"][0] + assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"] + + messages_data = span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] + assert isinstance(messages_data, str) + + parsed_messages = json.loads(messages_data) + assert isinstance(parsed_messages, list) + assert len(parsed_messages) <= len(large_messages) + + if "_meta" in event and len(parsed_messages) < len(large_messages): + meta_path = event["_meta"] + if ( + "spans" in meta_path + and "0" in meta_path["spans"] + and "data" in meta_path["spans"]["0"] + ): + span_meta = meta_path["spans"]["0"]["data"] + if SPANDATA.GEN_AI_REQUEST_MESSAGES in span_meta: + messages_meta = span_meta[SPANDATA.GEN_AI_REQUEST_MESSAGES] + assert "len" in messages_meta.get("", {}) diff --git a/tests/test_ai_monitoring.py b/tests/test_ai_monitoring.py index ee757f82cd..be66860384 100644 --- a/tests/test_ai_monitoring.py +++ b/tests/test_ai_monitoring.py @@ -1,7 +1,19 @@ +import json + import pytest import sentry_sdk +from sentry_sdk._types import AnnotatedValue from sentry_sdk.ai.monitoring import ai_track +from sentry_sdk.ai.utils import ( + MAX_GEN_AI_MESSAGE_BYTES, + set_data_normalized, + truncate_and_annotate_messages, + truncate_messages_by_size, + _find_truncation_index, +) +from sentry_sdk.serializer import serialize +from sentry_sdk.utils import safe_serialize def test_ai_track(sentry_init, capture_events): @@ -160,3 +172,291 @@ async def async_tool(**kwargs): assert span["description"] == "my async tool" assert span["op"] == "custom.async.operation" + + +@pytest.fixture +def sample_messages(): + """Sample messages similar to what gen_ai integrations would use""" + return [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "What is the difference between a list and a tuple in Python?", + }, + { + "role": "assistant", + "content": "Lists are mutable and use [], tuples are immutable and use ().", + }, + {"role": "user", "content": "Can you give me some examples?"}, + { + "role": "assistant", + "content": "Sure! Here are examples:\n\n```python\n# List\nmy_list = [1, 2, 3]\nmy_list.append(4)\n\n# Tuple\nmy_tuple = (1, 2, 3)\n# my_tuple.append(4) would error\n```", + }, + ] + + +@pytest.fixture +def large_messages(): + """Messages that will definitely exceed size limits""" + large_content = "This is a very long message. 
" * 100 + return [ + {"role": "system", "content": large_content}, + {"role": "user", "content": large_content}, + {"role": "assistant", "content": large_content}, + {"role": "user", "content": large_content}, + ] + + +class TestTruncateMessagesBySize: + def test_no_truncation_needed(self, sample_messages): + """Test that messages under the limit are not truncated""" + result, removed_count = truncate_messages_by_size( + sample_messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES + ) + assert len(result) == len(sample_messages) + assert result == sample_messages + assert removed_count == 0 + + def test_truncation_removes_oldest_first(self, large_messages): + """Test that oldest messages are removed first during truncation""" + small_limit = 3000 + result, removed_count = truncate_messages_by_size( + large_messages, max_bytes=small_limit + ) + assert len(result) < len(large_messages) + + if result: + assert result[-1] == large_messages[-1] + assert removed_count == len(large_messages) - len(result) + + def test_empty_messages_list(self): + """Test handling of empty messages list""" + result, removed_count = truncate_messages_by_size( + [], max_bytes=MAX_GEN_AI_MESSAGE_BYTES // 500 + ) + assert result == [] + assert removed_count == 0 + + def test_find_truncation_index( + self, + ): + """Test that the truncation index is found correctly""" + # when represented in JSON, these are each 7 bytes long + messages = ["A" * 5, "B" * 5, "C" * 5, "D" * 5, "E" * 5] + truncation_index = _find_truncation_index(messages, 20) + assert truncation_index == 3 + assert messages[truncation_index:] == ["D" * 5, "E" * 5] + + messages = ["A" * 5, "B" * 5, "C" * 5, "D" * 5, "E" * 5] + truncation_index = _find_truncation_index(messages, 40) + assert truncation_index == 0 + assert messages[truncation_index:] == [ + "A" * 5, + "B" * 5, + "C" * 5, + "D" * 5, + "E" * 5, + ] + + def test_progressive_truncation(self, large_messages): + """Test that truncation works progressively with different limits""" + limits = [ + MAX_GEN_AI_MESSAGE_BYTES // 5, + MAX_GEN_AI_MESSAGE_BYTES // 10, + MAX_GEN_AI_MESSAGE_BYTES // 25, + MAX_GEN_AI_MESSAGE_BYTES // 100, + MAX_GEN_AI_MESSAGE_BYTES // 500, + ] + prev_count = len(large_messages) + + for limit in limits: + result = truncate_messages_by_size(large_messages, max_bytes=limit) + current_count = len(result) + + assert current_count <= prev_count + assert current_count >= 1 + prev_count = current_count + + +class TestTruncateAndAnnotateMessages: + def test_no_truncation_returns_list(self, sample_messages): + class MockSpan: + def __init__(self): + self.span_id = "test_span_id" + self.data = {} + + def set_data(self, key, value): + self.data[key] = value + + class MockScope: + def __init__(self): + self._gen_ai_messages_truncated = {} + + span = MockSpan() + scope = MockScope() + result = truncate_and_annotate_messages(sample_messages, span, scope) + + assert isinstance(result, list) + assert not isinstance(result, AnnotatedValue) + assert len(result) == len(sample_messages) + assert result == sample_messages + assert span.span_id not in scope._gen_ai_messages_truncated + + def test_truncation_sets_metadata_on_scope(self, large_messages): + class MockSpan: + def __init__(self): + self.span_id = "test_span_id" + self.data = {} + + def set_data(self, key, value): + self.data[key] = value + + class MockScope: + def __init__(self): + self._gen_ai_messages_truncated = {} + + small_limit = 1000 + span = MockSpan() + scope = MockScope() + original_count = len(large_messages) + result = 
+
+
+class TestTruncateAndAnnotateMessages:
+    def test_no_truncation_returns_list(self, sample_messages):
+        class MockSpan:
+            def __init__(self):
+                self.span_id = "test_span_id"
+                self.data = {}
+
+            def set_data(self, key, value):
+                self.data[key] = value
+
+        class MockScope:
+            def __init__(self):
+                self._gen_ai_messages_truncated = {}
+
+        span = MockSpan()
+        scope = MockScope()
+        result = truncate_and_annotate_messages(sample_messages, span, scope)
+
+        assert isinstance(result, list)
+        assert not isinstance(result, AnnotatedValue)
+        assert len(result) == len(sample_messages)
+        assert result == sample_messages
+        assert span.span_id not in scope._gen_ai_messages_truncated
+
+    def test_truncation_sets_metadata_on_scope(self, large_messages):
+        class MockSpan:
+            def __init__(self):
+                self.span_id = "test_span_id"
+                self.data = {}
+
+            def set_data(self, key, value):
+                self.data[key] = value
+
+        class MockScope:
+            def __init__(self):
+                self._gen_ai_messages_truncated = {}
+
+        small_limit = 1000
+        span = MockSpan()
+        scope = MockScope()
+        original_count = len(large_messages)
+        result = truncate_and_annotate_messages(
+            large_messages, span, scope, max_bytes=small_limit
+        )
+
+        assert isinstance(result, list)
+        assert not isinstance(result, AnnotatedValue)
+        assert len(result) < len(large_messages)
+        n_removed = original_count - len(result)
+        assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
+
+    def test_scope_tracks_removed_messages(self, large_messages):
+        class MockSpan:
+            def __init__(self):
+                self.span_id = "test_span_id"
+                self.data = {}
+
+            def set_data(self, key, value):
+                self.data[key] = value
+
+        class MockScope:
+            def __init__(self):
+                self._gen_ai_messages_truncated = {}
+
+        small_limit = 1000
+        original_count = len(large_messages)
+        span = MockSpan()
+        scope = MockScope()
+
+        result = truncate_and_annotate_messages(
+            large_messages, span, scope, max_bytes=small_limit
+        )
+
+        n_removed = original_count - len(result)
+        assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
+        assert len(result) + n_removed == original_count
+
+    def test_empty_messages_returns_none(self):
+        class MockSpan:
+            def __init__(self):
+                self.span_id = "test_span_id"
+                self.data = {}
+
+            def set_data(self, key, value):
+                self.data[key] = value
+
+        class MockScope:
+            def __init__(self):
+                self._gen_ai_messages_truncated = {}
+
+        span = MockSpan()
+        scope = MockScope()
+        result = truncate_and_annotate_messages([], span, scope)
+        assert result is None
+
+        result = truncate_and_annotate_messages(None, span, scope)
+        assert result is None
+
+    def test_truncated_messages_newest_first(self, large_messages):
+        class MockSpan:
+            def __init__(self):
+                self.span_id = "test_span_id"
+                self.data = {}
+
+            def set_data(self, key, value):
+                self.data[key] = value
+
+        class MockScope:
+            def __init__(self):
+                self._gen_ai_messages_truncated = {}
+
+        small_limit = 3000
+        span = MockSpan()
+        scope = MockScope()
+        result = truncate_and_annotate_messages(
+            large_messages, span, scope, max_bytes=small_limit
+        )
+
+        assert isinstance(result, list)
+        assert result[0] == large_messages[-len(result)]
+
+
+class TestClientAnnotation:
+    def test_client_wraps_truncated_messages_in_annotated_value(self, large_messages):
+        """Test that client.py properly wraps truncated messages in AnnotatedValue using scope data"""
+        from sentry_sdk.consts import SPANDATA
+
+        class MockSpan:
+            def __init__(self):
+                self.span_id = "test_span_123"
+                self.data = {}
+
+            def set_data(self, key, value):
+                self.data[key] = value
+
+        class MockScope:
+            def __init__(self):
+                self._gen_ai_messages_truncated = {}
+
+        small_limit = 3000
+        span = MockSpan()
+        scope = MockScope()
+        original_count = len(large_messages)
+
+        # Simulate what integrations do
+        truncated_messages = truncate_and_annotate_messages(
+            large_messages, span, scope, max_bytes=small_limit
+        )
+        span.set_data(SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated_messages)
+
+        # Verify metadata was set on scope
+        assert span.span_id in scope._gen_ai_messages_truncated
+        assert scope._gen_ai_messages_truncated[span.span_id] > 0
+
+        # Simulate what client.py does
+        event = {"spans": [{"span_id": span.span_id, "data": span.data.copy()}]}
+
+        # Mimic client.py logic - using scope to get the removed count
+        for event_span in event["spans"]:
+            span_id = event_span.get("span_id")
+            span_data = event_span.get("data", {})
+            if (
+                span_id
+                and span_id in scope._gen_ai_messages_truncated
+                and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
+            ):
+                messages = span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES]
+                n_removed = scope._gen_ai_messages_truncated[span_id]
+                n_remaining = len(messages) if isinstance(messages, list) else 0
+                original_count_calculated = n_removed + n_remaining
+
+                span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
+                    safe_serialize(messages),
+                    {"len": original_count_calculated},
+                )
+
+        # Verify the annotation happened
+        messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
+        assert isinstance(messages_value, AnnotatedValue)
+        assert messages_value.metadata["len"] == original_count
+        assert isinstance(messages_value.value, str)
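Putting the pieces together, a sketch of how an integration is expected to use these helpers, mirroring `_set_input_data` in the OpenAI integration (assuming `sentry_sdk.init()` has been called; without a DSN, events are built but not sent):

    import sentry_sdk
    from sentry_sdk.ai.utils import truncate_and_annotate_messages
    from sentry_sdk.consts import SPANDATA

    sentry_sdk.init(traces_sample_rate=1.0)

    # Three ~15 KB messages: well over the 20 KB default budget.
    messages = [{"role": "user", "content": "hi " * 5000} for _ in range(3)]

    with sentry_sdk.start_transaction(name="demo"):
        with sentry_sdk.start_span(op="gen_ai.chat") as span:
            scope = sentry_sdk.get_current_scope()
            kept = truncate_and_annotate_messages(messages, span, scope)
            if kept is not None:
                span.set_data(SPANDATA.GEN_AI_REQUEST_MESSAGES, kept)
                # Only the newest message fits; the scope records the two
                # dropped ones under this span's span_id.
                print(len(kept), scope._gen_ai_messages_truncated.get(span.span_id))

Expected output is `1 2`; the AnnotatedValue wrapping itself is deferred to `_prepare_event`, as shown in the client.py hunk.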