|
import dataclasses
from collections import defaultdict
from datetime import datetime, timezone

from rest_framework import serializers
from rest_framework.request import Request
from rest_framework.response import Response

from sentry import features
from sentry.api.api_owners import ApiOwner
from sentry.api.api_publish_status import ApiPublishStatus
from sentry.api.base import region_silo_endpoint
from sentry.api.bases import NoProjects, OrganizationEventsV2EndpointBase
from sentry.api.paginator import GenericOffsetPaginator
from sentry.api.utils import handle_query_errors
from sentry.models.organization import Organization
from sentry.search.eap.types import SearchResolverConfig
from sentry.snuba.referrer import Referrer
from sentry.snuba.spans_rpc import Spans
| 20 | + |
| 21 | + |
class OrganizationAIConversationsSerializer(serializers.Serializer):
    """Serializer for validating query parameters."""

    # Accepted sort keys; a leading "-" requests descending order.
    # Hoisted to a class-level frozenset so it is built once, not on every
    # validate_sort call. NOTE: the endpoint currently only applies timestamp
    # ordering; the other options are accepted for forward compatibility.
    ALLOWED_SORTS = frozenset(
        {
            "timestamp",
            "-timestamp",
            "duration",
            "-duration",
            "errors",
            "-errors",
            "llmCalls",
            "-llmCalls",
            "toolCalls",
            "-toolCalls",
            "totalTokens",
            "-totalTokens",
            "totalCost",
            "-totalCost",
        }
    )

    sort = serializers.CharField(required=False, default="-timestamp")
    query = serializers.CharField(required=False, allow_blank=True)

    def validate_sort(self, value):
        """Reject sort values outside the supported set (message unchanged)."""
        if value not in self.ALLOWED_SORTS:
            raise serializers.ValidationError(f"Invalid sort option: {value}")
        return value
| 48 | + |
| 49 | + |
@region_silo_endpoint
class OrganizationAIConversationsEndpoint(OrganizationEventsV2EndpointBase):
    """Endpoint for fetching AI agent conversation traces."""

    publish_status = {
        "GET": ApiPublishStatus.PRIVATE,
    }
    owner = ApiOwner.VISIBILITY

    # Cap on the number of spans fetched by the enrichment query; bounds the
    # cost of the per-conversation flow / trace-ID lookup.
    ENRICHMENT_SPAN_LIMIT = 10000

    def get(self, request: Request, organization: Organization) -> Response:
        """
        Retrieve AI conversation traces for an organization.

        Returns 404 when the feature flag is disabled or no projects are
        accessible, and 400 when query-parameter validation fails.
        """
        if not features.has("organizations:gen-ai-conversations", organization, actor=request.user):
            return Response(status=404)

        try:
            snuba_params = self.get_snuba_params(request, organization)
        except NoProjects:
            return Response(status=404)

        serializer = OrganizationAIConversationsSerializer(data=request.GET)
        if not serializer.is_valid():
            return Response(serializer.errors, status=400)

        validated_data = serializer.validated_data

        # Data function handed to the offset paginator: fetches one page.
        def data_fn(offset: int, limit: int):
            return self._get_conversations(
                snuba_params=snuba_params,
                offset=offset,
                limit=limit,
                _sort=validated_data.get("sort", "-timestamp"),
                _query=validated_data.get("query", ""),
            )

        with handle_query_errors():
            return self.paginate(
                request=request,
                paginator=GenericOffsetPaginator(data_fn=data_fn),
                on_results=lambda results: results,
            )

    def _get_conversations(
        self, snuba_params, offset: int, limit: int, _sort: str, _query: str
    ) -> list[dict]:
        """
        Fetch conversation data by querying spans grouped by gen_ai.conversation.id.

        This is a two-step process:
        1. Find conversation IDs that have spans in the time range (with pagination/sorting)
        2. Get complete aggregations for those conversations (all spans, ignoring time filter)

        Args:
            snuba_params: Snuba parameters including projects, time range, etc.
            offset: Starting index for pagination
            limit: Number of results to return
            _sort: Sort field and direction (currently only supports timestamp sorting, unused for now)
            _query: Search query (not yet implemented)
        """
        # Step 1: Find conversation IDs with spans in the requested window,
        # ordered by most recent conversation activity.
        conversation_ids_results = Spans.run_table_query(
            params=snuba_params,
            query_string="has:gen_ai.conversation.id",
            selected_columns=[
                "gen_ai.conversation.id",
                "max(precise.finish_ts)",
            ],
            orderby=["-max(precise.finish_ts)"],
            offset=offset,
            limit=limit,
            referrer=Referrer.API_AI_CONVERSATIONS.value,
            config=SearchResolverConfig(auto_fields=True),
            sampling_mode=None,
        )

        conversation_ids: list[str] = [
            conv_id
            for row in conversation_ids_results.get("data", [])
            if (conv_id := row.get("gen_ai.conversation.id"))
        ]

        if not conversation_ids:
            return []

        # Step 2: Aggregate over ALL spans of these conversations by widening
        # the time range far beyond any real data.
        # NOTE(review): snuba param datetimes are timezone-aware elsewhere in
        # the API; use aware UTC values so replace() does not mix naive and
        # aware datetimes (which raises TypeError on comparison) — confirm.
        all_time_params = dataclasses.replace(
            snuba_params,
            start=datetime(2020, 1, 1, tzinfo=timezone.utc),
            end=datetime(2100, 1, 1, tzinfo=timezone.utc),
        )

        results = Spans.run_table_query(
            params=all_time_params,
            # conversation_ids come from the step-1 query results (not raw user
            # input), so direct interpolation into the query string is safe here.
            query_string=f"gen_ai.conversation.id:[{','.join(conversation_ids)}]",
            selected_columns=[
                "gen_ai.conversation.id",
                "failure_count()",
                "count_if(gen_ai.operation.type,equals,ai_client)",
                "count_if(span.op,equals,gen_ai.execute_tool)",
                "sum(gen_ai.usage.total_tokens)",
                "sum(gen_ai.usage.total_cost)",
                "min(precise.start_ts)",
                "max(precise.finish_ts)",
                "count_unique(trace)",
            ],
            orderby=None,
            offset=0,
            limit=len(conversation_ids),
            referrer=Referrer.API_AI_CONVERSATIONS_COMPLETE.value,
            config=SearchResolverConfig(auto_fields=True),
            sampling_mode=None,
        )

        # Index aggregate rows by conversation ID for order-preserving assembly.
        conversations_map: dict[str, dict] = {}
        for row in results.get("data", []):
            start_ts = row.get("min(precise.start_ts)", 0)
            finish_ts = row.get("max(precise.finish_ts)", 0)
            # Timestamps arrive as epoch seconds; the API reports milliseconds.
            duration_ms = int((finish_ts - start_ts) * 1000) if finish_ts and start_ts else 0
            timestamp_ms = int(finish_ts * 1000) if finish_ts else 0

            conv_id = row.get("gen_ai.conversation.id", "")
            conversations_map[conv_id] = {
                "conversationId": conv_id,
                "flow": [],
                "duration": duration_ms,
                "errors": int(row.get("failure_count()") or 0),
                "llmCalls": int(row.get("count_if(gen_ai.operation.type,equals,ai_client)") or 0),
                "toolCalls": int(row.get("count_if(span.op,equals,gen_ai.execute_tool)") or 0),
                "totalTokens": int(row.get("sum(gen_ai.usage.total_tokens)") or 0),
                "totalCost": float(row.get("sum(gen_ai.usage.total_cost)") or 0),
                "timestamp": timestamp_ms,
                "traceCount": int(row.get("count_unique(trace)") or 0),
                "traceIds": [],
            }

        # Preserve the recency ordering established by step 1.
        conversations = [
            conversations_map[conv_id]
            for conv_id in conversation_ids
            if conv_id in conversations_map
        ]

        if conversations:
            self._enrich_conversations(all_time_params, conversations)

        return conversations

    def _enrich_conversations(self, snuba_params, conversations: list[dict]) -> None:
        """
        Enrich conversations in place with agent flow and trace IDs.

        Runs a single span-level query covering every conversation, then fans
        the rows out per conversation.
        """
        conversation_ids = [conv["conversationId"] for conv in conversations]

        # Query all spans for these conversations to get both agent flows and trace IDs
        all_spans_results = Spans.run_table_query(
            params=snuba_params,
            query_string=f"gen_ai.conversation.id:[{','.join(conversation_ids)}]",
            selected_columns=[
                "gen_ai.conversation.id",
                "span.op",
                "span.description",
                "trace",
                "precise.start_ts",
            ],
            # Ordering by start time keeps each flow chronological.
            orderby=["gen_ai.conversation.id", "precise.start_ts"],
            offset=0,
            limit=self.ENRICHMENT_SPAN_LIMIT,
            referrer=Referrer.API_AI_CONVERSATIONS_ENRICHMENT.value,
            config=SearchResolverConfig(auto_fields=True),
            sampling_mode=None,
        )

        flows_by_conversation = defaultdict(list)
        traces_by_conversation = defaultdict(set)

        for row in all_spans_results.get("data", []):
            conv_id = row.get("gen_ai.conversation.id", "")
            if not conv_id:
                continue

            # Collect trace IDs
            trace_id = row.get("trace", "")
            if trace_id:
                traces_by_conversation[conv_id].add(trace_id)

            # Collect agent flow (only from invoke_agent spans)
            if row.get("span.op") == "gen_ai.invoke_agent":
                agent_name = row.get("span.description", "")
                if agent_name:
                    flows_by_conversation[conv_id].append(agent_name)

        for conversation in conversations:
            conv_id = conversation["conversationId"]
            conversation["flow"] = flows_by_conversation.get(conv_id, [])
            # Sorted for a deterministic response; sets iterate in arbitrary order.
            conversation["traceIds"] = sorted(traces_by_conversation.get(conv_id, set()))
0 commit comments