diff --git a/src/sentry/seer/endpoints/seer_rpc.py b/src/sentry/seer/endpoints/seer_rpc.py index 7a865707418c6d..b546b591adb2ff 100644 --- a/src/sentry/seer/endpoints/seer_rpc.py +++ b/src/sentry/seer/endpoints/seer_rpc.py @@ -72,6 +72,7 @@ rpc_get_trace_for_transaction, rpc_get_transactions_for_project, ) +from sentry.seer.explorer.tools import execute_trace_query_chart, execute_trace_query_table from sentry.seer.fetch_issues import by_error_type, by_function_name, by_text_query, utils from sentry.seer.seer_setup import get_seer_org_acknowledgement from sentry.sentry_apps.tasks.sentry_apps import broadcast_webhooks_for_organization @@ -234,6 +235,21 @@ def get_organization_slug(*, org_id: int) -> dict: return {"slug": org.slug} +def get_organization_project_ids(*, org_id: int) -> dict: + """Get all project IDs for an organization""" + from sentry.models.project import Project + + try: + organization = Organization.objects.get(id=org_id) + except Organization.DoesNotExist: + return {"project_ids": []} + + project_ids = list( + Project.objects.filter(organization=organization).values_list("id", flat=True) + ) + return {"project_ids": project_ids} + + def _can_use_prevent_ai_features(org: Organization) -> bool: hide_ai_features = org.get_option("sentry:hide_ai_features", HIDE_AI_FEATURES_DEFAULT) pr_review_test_generation_enabled = bool( @@ -972,6 +988,7 @@ def send_seer_webhook(*, event_name: str, organization_id: int, payload: dict) - # Common to Seer features "get_organization_seer_consent_by_org_name": get_organization_seer_consent_by_org_name, "get_github_enterprise_integration_config": get_github_enterprise_integration_config, + "get_organization_project_ids": get_organization_project_ids, # # Autofix "get_organization_slug": get_organization_slug, @@ -999,6 +1016,8 @@ def send_seer_webhook(*, event_name: str, organization_id: int, payload: dict) - "get_trace_for_transaction": rpc_get_trace_for_transaction, "get_profiles_for_trace": rpc_get_profiles_for_trace, 
"get_issues_for_transaction": rpc_get_issues_for_transaction, + "execute_trace_query_chart": execute_trace_query_chart, + "execute_trace_query_table": execute_trace_query_table, # # Replays "get_replay_summary_logs": rpc_get_replay_summary_logs, diff --git a/src/sentry/seer/explorer/tools.py b/src/sentry/seer/explorer/tools.py new file mode 100644 index 00000000000000..72b7ef87ed0add --- /dev/null +++ b/src/sentry/seer/explorer/tools.py @@ -0,0 +1,135 @@ +import logging +from typing import Any + +from sentry.api import client +from sentry.models.apikey import ApiKey +from sentry.models.organization import Organization +from sentry.snuba.referrer import Referrer + +logger = logging.getLogger(__name__) + + +def execute_trace_query_chart( + *, + org_id: int, + query: str, + stats_period: str, + y_axes: list[str], + group_by: list[str] | None = None, +) -> dict[str, Any] | None: + """ + Execute a trace query to get chart/timeseries data by calling the events-stats endpoint. + """ + try: + organization = Organization.objects.get(id=org_id) + except Organization.DoesNotExist: + logger.warning("Organization not found", extra={"org_id": org_id}) + return None + + # Get all project IDs for the organization + project_ids = list(organization.project_set.values_list("id", flat=True)) + if not project_ids: + logger.warning("No projects found for organization", extra={"org_id": org_id}) + return None + + params: dict[str, Any] = { + "query": query, + "statsPeriod": stats_period, + "yAxis": y_axes, + "project": project_ids, + "dataset": "spans", + "referrer": Referrer.SEER_RPC, + "transformAliasToInputFormat": "1", # Required for RPC datasets + } + + # Add group_by if provided (for top events) + if group_by and len(group_by) > 0: + params["topEvents"] = 5 + params["field"] = group_by + params["excludeOther"] = "0" # Include "Other" series + + resp = client.get( + auth=ApiKey(organization_id=organization.id, scope_list=["org:read", "project:read"]), + user=None, + 
path=f"/organizations/{organization.slug}/events-stats/", + params=params, + ) + data = resp.data + + # Normalize response format: single-axis returns flat format, multi-axis returns nested + # We always want the nested format {"metric": {"data": [...]}} + if isinstance(data, dict) and "data" in data and len(y_axes) == 1: + # Single axis response - wrap it + metric_name = y_axes[0] + return {metric_name: data} + + return data + + +def execute_trace_query_table( + *, + org_id: int, + query: str, + stats_period: str, + sort: str, + group_by: list[str] | None = None, + y_axes: list[str] | None = None, + per_page: int = 50, +) -> dict[str, Any] | None: + """ + Execute a trace query to get table data by calling the events endpoint. + """ + try: + organization = Organization.objects.get(id=org_id) + except Organization.DoesNotExist: + logger.warning("Organization not found", extra={"org_id": org_id}) + return None + + # Get all project IDs for the organization + project_ids = list(organization.project_set.values_list("id", flat=True)) + if not project_ids: + logger.warning("No projects found for organization", extra={"org_id": org_id}) + return None + + # Determine fields based on mode + if group_by and len(group_by) > 0: + # Aggregates mode: group_by fields + aggregate functions + fields = list(group_by) + if y_axes: + fields.extend(y_axes) + else: + # Samples mode: default span fields + fields = [ + "id", + "span.op", + "span.description", + "span.duration", + "transaction", + "timestamp", + "project", + "project.name", + "trace", + ] + + params: dict[str, Any] = { + "query": query, + "statsPeriod": stats_period, + "field": fields, + "sort": sort if sort else ("-timestamp" if not group_by else None), + "per_page": per_page, + "project": project_ids, + "dataset": "spans", + "referrer": Referrer.SEER_RPC, + "transformAliasToInputFormat": "1", # Required for RPC datasets + } + + # Remove None values + params = {k: v for k, v in params.items() if v is not None} + + resp = 
client.get( + auth=ApiKey(organization_id=organization.id, scope_list=["org:read", "project:read"]), + user=None, + path=f"/organizations/{organization.slug}/events/", + params=params, + ) + return resp.data diff --git a/tests/sentry/seer/explorer/test_tools.py b/tests/sentry/seer/explorer/test_tools.py new file mode 100644 index 00000000000000..cae0c8530fe4da --- /dev/null +++ b/tests/sentry/seer/explorer/test_tools.py @@ -0,0 +1,303 @@ +import pytest + +from sentry.seer.explorer.tools import execute_trace_query_chart, execute_trace_query_table +from sentry.testutils.cases import APITransactionTestCase, SnubaTestCase, SpanTestCase +from sentry.testutils.helpers.datetime import before_now + + +@pytest.mark.django_db(databases=["default", "control"]) +class TestExplorerTools(APITransactionTestCase, SnubaTestCase, SpanTestCase): + databases = {"default", "control"} + + def setUp(self): + super().setUp() + self.ten_mins_ago = before_now(minutes=10) + + # Create spans using the exact pattern from working tests + spans = [ + self.create_span( + { + "description": "SELECT * FROM users WHERE id = ?", + "sentry_tags": {"op": "db", "transaction": "api/user/profile"}, + }, + start_ts=self.ten_mins_ago, + duration=150, + ), + self.create_span( + { + "description": "SELECT * FROM posts WHERE user_id = ?", + "sentry_tags": {"op": "db", "transaction": "api/user/posts"}, + }, + start_ts=self.ten_mins_ago, + duration=200, + ), + self.create_span( + { + "description": "GET https://api.external.com/data", + "sentry_tags": {"op": "http.client", "transaction": "api/external/fetch"}, + }, + start_ts=self.ten_mins_ago, + duration=500, + ), + self.create_span( + { + "description": "Redis GET user:123", + "sentry_tags": {"op": "cache.get", "transaction": "api/user/profile"}, + }, + start_ts=self.ten_mins_ago, + duration=25, + ), + ] + + self.store_spans(spans, is_eap=True) + + def test_execute_trace_query_chart_count_metric(self): + """Test chart query with count() metric using real 
data""" + result = execute_trace_query_chart( + org_id=self.organization.id, + query="", + stats_period="1h", + y_axes=["count()"], + ) + + assert result is not None + # Result is now dict from events-stats endpoint + assert "count()" in result + assert "data" in result["count()"] + + data_points = result["count()"]["data"] + assert len(data_points) > 0 + + # Each data point is [timestamp, [{"count": value}]] + total_count = sum(point[1][0]["count"] for point in data_points if point[1]) + assert total_count == 4 + + def test_execute_trace_query_chart_multiple_metrics(self): + """Test chart query with multiple metrics""" + result = execute_trace_query_chart( + org_id=self.organization.id, + query="", + stats_period="1h", + y_axes=["count()", "avg(span.duration)"], + ) + + assert result is not None + # Should have both metrics in result + assert "count()" in result + assert "avg(span.duration)" in result + + # Check count metric + count_data = result["count()"]["data"] + assert len(count_data) > 0 + total_count = sum(point[1][0]["count"] for point in count_data if point[1]) + assert total_count == 4 + + # Check avg duration metric + avg_duration_data = result["avg(span.duration)"]["data"] + assert len(avg_duration_data) > 0 + # Should have duration values where count > 0 + duration_values = [ + point[1][0]["count"] + for point in avg_duration_data + if point[1] and point[1][0]["count"] > 0 + ] + assert len(duration_values) > 0 + + def test_execute_trace_query_table_basic_query(self): + """Test table query returns actual span data""" + result = execute_trace_query_table( + org_id=self.organization.id, + query="", + stats_period="1h", + sort="-timestamp", + per_page=10, + ) + + assert result is not None + assert "data" in result + assert "meta" in result + + rows = result["data"] + assert len(rows) == 4 # Should find all 4 spans we created + + # Verify span data + db_rows = [row for row in rows if row.get("span.op") == "db"] + assert len(db_rows) == 2 # Two database 
spans + + http_rows = [row for row in rows if row.get("span.op") == "http.client"] + assert len(http_rows) == 1 # One HTTP span + + cache_rows = [row for row in rows if row.get("span.op") == "cache.get"] + assert len(cache_rows) == 1 # One cache span + + def test_execute_trace_query_table_specific_operation(self): + """Test table query filtering by specific operation""" + result = execute_trace_query_table( + org_id=self.organization.id, + query="span.op:http.client", + stats_period="1h", + sort="-timestamp", + ) + + assert result is not None + rows = result["data"] + + # Should find our http.client span + http_rows = [row for row in rows if row.get("span.op") == "http.client"] + assert len(http_rows) == 1 + + # Check description contains our external API call + descriptions = [row.get("span.description", "") for row in http_rows] + assert any("api.external.com" in desc for desc in descriptions) + + def test_execute_trace_query_chart_empty_results(self): + """Test chart query with query that returns no results""" + result = execute_trace_query_chart( + org_id=self.organization.id, + query="span.op:nonexistent", + stats_period="1h", + y_axes=["count()"], + ) + + assert result is not None + assert "count()" in result + assert "data" in result["count()"] + + # Should have time buckets but with zero counts + data_points = result["count()"]["data"] + if data_points: + total_count = sum(point[1][0]["count"] for point in data_points if point[1]) + assert total_count == 0 + + def test_execute_trace_query_table_empty_results(self): + """Test table query with query that returns no results""" + result = execute_trace_query_table( + org_id=self.organization.id, + query="span.op:nonexistent", + stats_period="1h", + sort="-timestamp", + ) + + assert result is not None + assert "data" in result + assert len(result["data"]) == 0 + + def test_execute_trace_query_chart_duration_filtering(self): + """Test chart query with duration filter""" + result = execute_trace_query_chart( + 
org_id=self.organization.id, + query="span.duration:>100ms", # Should match spans > 100ms + stats_period="1h", + y_axes=["count()"], + ) + + assert result is not None + assert "count()" in result + assert "data" in result["count()"] + + data_points = result["count()"]["data"] + + # Should find our longer spans (150ms, 200ms, 500ms) + total_count = sum(point[1][0]["count"] for point in data_points if point[1]) + assert total_count == 3 + + def test_execute_trace_query_table_duration_stats(self): + """Test table query with duration statistics""" + result = execute_trace_query_table( + org_id=self.organization.id, + query="", + stats_period="1h", + sort="-span.duration", + per_page=20, + ) + + assert result is not None + rows = result["data"] + assert len(rows) == 4 # All our spans + + # Check that durations are present and reasonable + durations = [row.get("span.duration") for row in rows if row.get("span.duration")] + assert len(durations) == 4 + + # Should include our test durations (inputs and stored span.duration values are both in ms) + expected_durations = [150, 200, 500, 25] + for expected in expected_durations: + # Allow for some tolerance in duration matching + assert any(abs(d - expected) < 10 for d in durations) + + def test_execute_trace_query_nonexistent_organization(self): + """Test queries handle nonexistent organization gracefully""" + chart_result = execute_trace_query_chart( + org_id=99999, + query="", + stats_period="1h", + y_axes=["count()"], + ) + assert chart_result is None + + table_result = execute_trace_query_table( + org_id=99999, + query="", + stats_period="1h", + sort="-count", + ) + assert table_result is None + + def test_execute_trace_query_chart_with_groupby(self): + """Test chart query with group_by parameter for aggregates""" + result = execute_trace_query_chart( + org_id=self.organization.id, + query="", + stats_period="1h", + y_axes=["count()"], + group_by=["span.op"], + ) + + assert result is not None + # Grouped results have group values as
top-level keys + # Should have different span.op values like "db", "http.client", etc. + assert len(result) > 0 + + # Each group should have the metric + for group_value, metrics in result.items(): + if isinstance(metrics, dict) and "count()" in metrics: + assert "data" in metrics["count()"] + + def test_execute_trace_query_table_with_groupby(self): + """Test table query with group_by for aggregates mode""" + result = execute_trace_query_table( + org_id=self.organization.id, + query="", + stats_period="1h", + sort="-count()", + group_by=["span.op"], + y_axes=["count()"], + per_page=10, + ) + + assert result is not None + assert "data" in result + assert "meta" in result + + rows = result["data"] + # Should have one row per unique span.op value + assert len(rows) > 0 + + # Each row should have span.op and count() + for row in rows: + assert "span.op" in row + assert "count()" in row + + def test_get_organization_project_ids(self): + """Test the get_organization_project_ids RPC method""" + from sentry.seer.endpoints.seer_rpc import get_organization_project_ids + + # Test with valid organization + result = get_organization_project_ids(org_id=self.organization.id) + assert "project_ids" in result + assert isinstance(result["project_ids"], list) + assert self.project.id in result["project_ids"] + + # Test with nonexistent organization + result = get_organization_project_ids(org_id=99999) + assert result == {"project_ids": []}