diff --git a/sentry_sdk/ai_analytics.py b/sentry_sdk/ai_analytics.py
new file mode 100644
index 0000000000..ebdbc56c54
--- /dev/null
+++ b/sentry_sdk/ai_analytics.py
@@ -0,0 +1,78 @@
+from functools import wraps
+
+from sentry_sdk import start_span
+from sentry_sdk.tracing import Span
+from sentry_sdk.utils import ContextVar
+from sentry_sdk._types import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Optional, Callable, Any
+
+_ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)
+
+
+def set_ai_pipeline_name(name):
+    # type: (Optional[str]) -> None
+    _ai_pipeline_name.set(name)
+
+
+def get_ai_pipeline_name():
+    # type: () -> Optional[str]
+    return _ai_pipeline_name.get()
+
+
+def ai_pipeline(description, op="ai.pipeline", **span_kwargs):
+    # type: (str, str, Any) -> Callable[..., Any]
+    def decorator(f):
+        # type: (Callable[..., Any]) -> Callable[..., Any]
+        @wraps(f)
+        def wrapped(*args, **kwargs):
+            # type: (Any, Any) -> Any
+            with start_span(description=description, op=op, **span_kwargs):
+                _ai_pipeline_name.set(description)
+                res = f(*args, **kwargs)
+                _ai_pipeline_name.set(None)
+                return res
+
+        return wrapped
+
+    return decorator
+
+
+def ai_run(description, op="ai.run", **span_kwargs):
+    # type: (str, str, Any) -> Callable[..., Any]
+    def decorator(f):
+        # type: (Callable[..., Any]) -> Callable[..., Any]
+        @wraps(f)
+        def wrapped(*args, **kwargs):
+            # type: (Any, Any) -> Any
+            with start_span(description=description, op=op, **span_kwargs) as span:
+                curr_pipeline = _ai_pipeline_name.get()
+                if curr_pipeline:
+                    span.set_data("ai.pipeline.name", curr_pipeline)
+                return f(*args, **kwargs)
+
+        return wrapped
+
+    return decorator
+
+
+def record_token_usage(
+    span, prompt_tokens=None, completion_tokens=None, total_tokens=None
+):
+    # type: (Span, Optional[int], Optional[int], Optional[int]) -> None
+    ai_pipeline_name = get_ai_pipeline_name()
+    if ai_pipeline_name:
+        span.set_data("ai.pipeline.name", ai_pipeline_name)
+    if prompt_tokens is not None:
+        span.set_measurement("ai_prompt_tokens_used", value=prompt_tokens)
+    if completion_tokens is not None:
+        span.set_measurement("ai_completion_tokens_used", value=completion_tokens)
+    if (
+        total_tokens is None
+        and prompt_tokens is not None
+        and completion_tokens is not None
+    ):
+        total_tokens = prompt_tokens + completion_tokens
+    if total_tokens is not None:
+        span.set_measurement("ai_total_tokens_used", total_tokens)
diff --git a/sentry_sdk/integrations/_ai_common.py b/sentry_sdk/integrations/_ai_common.py
index 5b25d1fc69..42d46304e4 100644
--- a/sentry_sdk/integrations/_ai_common.py
+++ b/sentry_sdk/integrations/_ai_common.py
@@ -1,7 +1,7 @@
 from sentry_sdk._types import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from typing import Any, Optional
+    from typing import Any
 
 from sentry_sdk.tracing import Span
 from sentry_sdk.utils import logger
@@ -30,21 +30,3 @@ def set_data_normalized(span, key, value):
     # type: (Span, str, Any) -> None
     normalized = _normalize_data(value)
     span.set_data(key, normalized)
-
-
-def record_token_usage(
-    span, prompt_tokens=None, completion_tokens=None, total_tokens=None
-):
-    # type: (Span, Optional[int], Optional[int], Optional[int]) -> None
-    if prompt_tokens is not None:
-        span.set_measurement("ai_prompt_tokens_used", value=prompt_tokens)
-    if completion_tokens is not None:
-        span.set_measurement("ai_completion_tokens_used", value=completion_tokens)
-    if (
-        total_tokens is None
-        and prompt_tokens is not None
-        and completion_tokens is not None
-    ):
-        total_tokens = prompt_tokens + completion_tokens
-    if total_tokens is not None:
-        span.set_measurement("ai_total_tokens_used", total_tokens)
diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py
index 0cebe1ec17..f3058fe087 100644
--- a/sentry_sdk/integrations/langchain.py
+++ b/sentry_sdk/integrations/langchain.py
@@ -3,8 +3,9 @@
 
 import sentry_sdk
 from sentry_sdk._types import TYPE_CHECKING
+from sentry_sdk.ai_analytics import set_ai_pipeline_name, record_token_usage
 from sentry_sdk.consts import OP, SPANDATA
-from sentry_sdk.integrations._ai_common import set_data_normalized, record_token_usage
+from sentry_sdk.integrations._ai_common import set_data_normalized
 from sentry_sdk.scope import should_send_default_pii
 from sentry_sdk.tracing import Span
 
@@ -88,6 +89,7 @@ class WatchedSpan:
     num_prompt_tokens = 0  # type: int
     no_collect_tokens = False  # type: bool
     children = []  # type: List[WatchedSpan]
+    is_pipeline = False  # type: bool
 
     def __init__(self, span):
         # type: (Span) -> None
@@ -134,9 +136,6 @@ def _normalize_langchain_message(self, message):
 
     def _create_span(self, run_id, parent_id, **kwargs):
         # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
-        if "origin" not in kwargs:
-            kwargs["origin"] = "auto.ai.langchain"
-
         watched_span = None  # type: Optional[WatchedSpan]
         if parent_id:
             parent_span = self.span_map[parent_id]  # type: Optional[WatchedSpan]
@@ -146,6 +145,11 @@ def _create_span(self, run_id, parent_id, **kwargs):
         if watched_span is None:
             watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))
 
+        if kwargs.get("op", "").startswith("ai.pipeline."):
+            if kwargs.get("description"):
+                set_ai_pipeline_name(kwargs.get("description"))
+            watched_span.is_pipeline = True
+
         watched_span.span.__enter__()
         self.span_map[run_id] = watched_span
         self.gc_span_map()
@@ -154,6 +158,9 @@ def _create_span(self, run_id, parent_id, **kwargs):
 
     def _exit_span(self, span_data, run_id):
         # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
+        if span_data.is_pipeline:
+            set_ai_pipeline_name(None)
+
         span_data.span.__exit__(None, None, None)
         del self.span_map[run_id]
 
diff --git a/sentry_sdk/integrations/openai.py b/sentry_sdk/integrations/openai.py
index ffb8a391fa..ae5c9e70ac 100644
--- a/sentry_sdk/integrations/openai.py
+++ b/sentry_sdk/integrations/openai.py
@@ -2,8 +2,9 @@
 
 from sentry_sdk import consts
 from sentry_sdk._types import TYPE_CHECKING
+from sentry_sdk.ai_analytics import record_token_usage
 from sentry_sdk.consts import SPANDATA
-from sentry_sdk.integrations._ai_common import set_data_normalized, record_token_usage
+from sentry_sdk.integrations._ai_common import set_data_normalized
 
 if TYPE_CHECKING:
     from typing import Any, Iterable, List, Optional, Callable, Iterator
@@ -141,7 +142,6 @@ def new_chat_completion(*args, **kwargs):
 
         span = sentry_sdk.start_span(
             op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE,
-            origin="auto.ai.openai",
             description="Chat Completion",
         )
         span.__enter__()
@@ -225,7 +225,6 @@ def new_embeddings_create(*args, **kwargs):
         # type: (*Any, **Any) -> Any
         with sentry_sdk.start_span(
             op=consts.OP.OPENAI_EMBEDDINGS_CREATE,
-            origin="auto.ai.openai",
             description="OpenAI Embedding Creation",
         ) as span:
             integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)