Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions sentry_sdk/ai/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import json

from collections import deque
from typing import TYPE_CHECKING
from sys import getsizeof

if TYPE_CHECKING:
from typing import Any, Callable
from typing import Any, Callable, Dict, List, Optional, Tuple

from sentry_sdk.tracing import Span

import sentry_sdk
from sentry_sdk.utils import logger

MAX_GEN_AI_MESSAGE_BYTES = 20_000 # 20KB


class GEN_AI_ALLOWED_MESSAGE_ROLES:
SYSTEM = "system"
Expand Down Expand Up @@ -95,3 +99,48 @@ def get_start_span_function():
current_span is not None and current_span.containing_transaction is not None
)
return sentry_sdk.start_span if transaction_exists else sentry_sdk.start_transaction


def _find_truncation_index(messages, max_bytes):
# type: (List[Dict[str, Any]], int) -> int
"""
Find the index of the first message that would exceed the max bytes limit.
Compute the individual message sizes, and return the index of the first message from the back
of the list that would exceed the max bytes limit.
"""
running_sum = 0
for idx in range(len(messages) - 1, -1, -1):
size = len(json.dumps(messages[idx], separators=(",", ":")).encode("utf-8"))
running_sum += size
if running_sum > max_bytes:
return idx + 1

return 0


def truncate_messages_by_size(messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES):
    # type: (List[Dict[str, Any]], int) -> Tuple[List[Dict[str, Any]], int]
    """Drop the oldest messages until the serialized payload fits *max_bytes*.

    Returns a tuple of (possibly shortened message list, number of messages
    removed from the front). The input list itself is never mutated.
    """
    payload = json.dumps(messages, separators=(",", ":")).encode("utf-8")
    if len(payload) <= max_bytes:
        # Fast path: everything already fits, nothing to remove.
        return messages, 0

    cut_at = _find_truncation_index(messages, max_bytes)
    return messages[cut_at:], cut_at


def truncate_and_annotate_messages(
    messages, span, scope, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
):
    # type: (Optional[List[Dict[str, Any]]], Any, Any, int) -> Optional[List[Dict[str, Any]]]
    """Truncate *messages* to fit *max_bytes* and record how many were dropped.

    When messages are removed, the removal count is stored on the scope keyed
    by the span id, so that the client can later wrap the span's message data
    in an AnnotatedValue carrying the original length. Returns None when
    there is nothing to store.
    """
    if not messages:
        return None

    kept, removed = truncate_messages_by_size(messages, max_bytes)
    if removed > 0:
        # Remembered per-span; consumed during event preparation to annotate
        # the gen_ai request messages with their pre-truncation length.
        scope._gen_ai_messages_truncated[span.span_id] = len(messages) - len(kept)

    return kept
19 changes: 19 additions & 0 deletions sentry_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,24 @@ def _prepare_event(
if event_scrubber:
event_scrubber.scrub_event(event)

if scope is not None and scope._gen_ai_messages_truncated:
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue
if isinstance(spans, list):
for span in spans:
span_id = span.get("span_id", None)
span_data = span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_messages_truncated
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
{
"len": scope._gen_ai_messages_truncated[span_id]
+ len(span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES])
},
)
if previous_total_spans is not None:
event["spans"] = AnnotatedValue(
event.get("spans", []), {"len": previous_total_spans}
Expand All @@ -606,6 +624,7 @@ def _prepare_event(
event["breadcrumbs"] = AnnotatedValue(
event.get("breadcrumbs", []), {"len": previous_total_breadcrumbs}
)

# Postprocess the event here so that annotated types do
# generally not surface in before_send
if event is not None:
Expand Down
18 changes: 12 additions & 6 deletions sentry_sdk/integrations/openai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from functools import wraps
from collections.abc import Iterable

import sentry_sdk
from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import set_data_normalized, normalize_message_roles
from sentry_sdk.ai.utils import (
set_data_normalized,
normalize_message_roles,
truncate_and_annotate_messages,
)
from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
Expand All @@ -18,7 +21,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, List, Optional, Callable, AsyncIterator, Iterator
from typing import Any, Iterable, List, Optional, Callable, AsyncIterator, Iterator
from sentry_sdk.tracing import Span

try:
Expand Down Expand Up @@ -189,9 +192,12 @@ def _set_input_data(span, kwargs, operation, integration):
and integration.include_prompts
):
normalized_messages = normalize_message_roles(messages)
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, normalized_messages, unpack=False
)
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(normalized_messages, span, scope)
if messages_data is not None:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
)

# Input attributes: Common
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")
Expand Down
5 changes: 5 additions & 0 deletions sentry_sdk/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class Scope:
"_extras",
"_breadcrumbs",
"_n_breadcrumbs_truncated",
"_gen_ai_messages_truncated",
"_event_processors",
"_error_processors",
"_should_capture",
Expand All @@ -213,6 +214,7 @@ def __init__(self, ty=None, client=None):
self._name = None # type: Optional[str]
self._propagation_context = None # type: Optional[PropagationContext]
self._n_breadcrumbs_truncated = 0 # type: int
self._gen_ai_messages_truncated = {} # type: Dict[str, int]

self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient

Expand Down Expand Up @@ -247,6 +249,7 @@ def __copy__(self):

rv._breadcrumbs = copy(self._breadcrumbs)
rv._n_breadcrumbs_truncated = self._n_breadcrumbs_truncated
rv._gen_ai_messages_truncated = self._gen_ai_messages_truncated.copy()
rv._event_processors = self._event_processors.copy()
rv._error_processors = self._error_processors.copy()
rv._propagation_context = self._propagation_context
Expand Down Expand Up @@ -1583,6 +1586,8 @@ def update_from_scope(self, scope):
self._n_breadcrumbs_truncated = (
self._n_breadcrumbs_truncated + scope._n_breadcrumbs_truncated
)
if scope._gen_ai_messages_truncated:
self._gen_ai_messages_truncated.update(scope._gen_ai_messages_truncated)
if scope._span:
self._span = scope._span
if scope._attachments:
Expand Down
63 changes: 58 additions & 5 deletions tests/integrations/openai/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import pytest

from sentry_sdk.utils import package_version
Expand All @@ -6,7 +7,6 @@
from openai import NOT_GIVEN
except ImportError:
NOT_GIVEN = None

try:
from openai import omit
except ImportError:
Expand Down Expand Up @@ -44,6 +44,9 @@
OpenAIIntegration,
_calculate_token_usage,
)
from sentry_sdk.ai.utils import MAX_GEN_AI_MESSAGE_BYTES
from sentry_sdk._types import AnnotatedValue
from sentry_sdk.serializer import serialize

from unittest import mock # python 3.3 and above

Expand Down Expand Up @@ -1456,6 +1459,7 @@ def test_empty_tools_in_chat_completion(sentry_init, capture_events, tools):

def test_openai_message_role_mapping(sentry_init, capture_events):
"""Test that OpenAI integration properly maps message roles like 'ai' to 'assistant'"""

sentry_init(
integrations=[OpenAIIntegration(include_prompts=True)],
traces_sample_rate=1.0,
Expand All @@ -1465,7 +1469,6 @@ def test_openai_message_role_mapping(sentry_init, capture_events):

client = OpenAI(api_key="z")
client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)

# Test messages with mixed roles including "ai" that should be mapped to "assistant"
test_messages = [
{"role": "system", "content": "You are helpful."},
Expand All @@ -1476,11 +1479,9 @@ def test_openai_message_role_mapping(sentry_init, capture_events):

with start_transaction(name="openai tx"):
client.chat.completions.create(model="test-model", messages=test_messages)

# Verify that the span was created correctly
(event,) = events
span = event["spans"][0]

# Verify that the span was created correctly
assert span["op"] == "gen_ai.chat"
assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]

Expand All @@ -1505,3 +1506,55 @@ def test_openai_message_role_mapping(sentry_init, capture_events):
# Verify no "ai" roles remain
roles = [msg["role"] for msg in stored_messages]
assert "ai" not in roles


def test_openai_message_truncation(sentry_init, capture_events):
    """Test that large messages are truncated properly in OpenAI integration."""
    sentry_init(
        integrations=[OpenAIIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    client = OpenAI(api_key="z")
    client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)

    # One oversized string reused across several messages so the combined
    # payload comfortably exceeds the truncation budget.
    oversized_text = (
        "This is a very long message that will exceed our size limits. " * 1000
    )
    oversized_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": oversized_text},
        {"role": "assistant", "content": oversized_text},
        {"role": "user", "content": oversized_text},
    ]

    with start_transaction(name="openai tx"):
        client.chat.completions.create(
            model="some-model",
            messages=oversized_messages,
        )

    (event,) = events
    span = event["spans"][0]
    assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]

    # Stored messages are serialized as a JSON string.
    raw_messages = span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
    assert isinstance(raw_messages, str)

    stored_messages = json.loads(raw_messages)
    assert isinstance(stored_messages, list)
    assert len(stored_messages) <= len(oversized_messages)

    # When truncation happened, the _meta tree should annotate the original
    # message count under the span's data entry.
    if "_meta" in event and len(stored_messages) < len(oversized_messages):
        span_meta = event["_meta"].get("spans", {}).get("0", {}).get("data", {})
        if SPANDATA.GEN_AI_REQUEST_MESSAGES in span_meta:
            messages_meta = span_meta[SPANDATA.GEN_AI_REQUEST_MESSAGES]
            assert "len" in messages_meta.get("", {})
Loading