diff --git a/src/sentry/seer/endpoints/organization_seer_explorer_chat.py b/src/sentry/seer/endpoints/organization_seer_explorer_chat.py index 9794b557d35e19..ebdbaf91c34141 100644 --- a/src/sentry/seer/endpoints/organization_seer_explorer_chat.py +++ b/src/sentry/seer/endpoints/organization_seer_explorer_chat.py @@ -3,6 +3,7 @@ import logging from rest_framework import serializers +from rest_framework.exceptions import PermissionDenied from rest_framework.request import Request from rest_framework.response import Response @@ -12,7 +13,8 @@ from sentry.api.bases.organization import OrganizationEndpoint, OrganizationPermission from sentry.models.organization import Organization from sentry.ratelimits.config import RateLimitConfig -from sentry.seer.explorer.client import continue_seer_run, get_seer_run, start_seer_run +from sentry.seer.explorer.client import SeerExplorerClient +from sentry.seer.models import SeerPermissionError from sentry.types.ratelimit import RateLimit, RateLimitCategory logger = logging.getLogger(__name__) @@ -77,8 +79,11 @@ def get( return Response({"session": None}, status=404) try: - state = get_seer_run(run_id=int(run_id), organization=organization, user=request.user) + client = SeerExplorerClient(organization, request.user) + state = client.get_run(run_id=int(run_id)) return Response({"session": state.dict()}) + except SeerPermissionError as e: + raise PermissionDenied(e.message) from e except ValueError: return Response({"session": None}, status=404) @@ -106,23 +111,22 @@ def post( insert_index = validated_data.get("insert_index") on_page_context = validated_data.get("on_page_context") - # Use client to start or continue run - if run_id: - # Continue existing conversation - result_run_id = continue_seer_run( - run_id=int(run_id), - organization=organization, - prompt=query, - user=request.user, - insert_index=insert_index, - on_page_context=on_page_context, - ) - else: - # Start new conversation - result_run_id = start_seer_run( - organization=organization, - prompt=query, - user=request.user, - on_page_context=on_page_context, - ) - return Response({"run_id": result_run_id}) + try: + client = SeerExplorerClient(organization, request.user) + if run_id: + # Continue existing conversation + result_run_id = client.continue_run( + run_id=int(run_id), + prompt=query, + insert_index=insert_index, + on_page_context=on_page_context, + ) + else: + # Start new conversation + result_run_id = client.start_run( + prompt=query, + on_page_context=on_page_context, + ) + return Response({"run_id": result_run_id}) + except SeerPermissionError as e: + raise PermissionDenied(e.message) from e diff --git a/src/sentry/seer/endpoints/organization_seer_explorer_runs.py b/src/sentry/seer/endpoints/organization_seer_explorer_runs.py index 7ece231bd1d29f..fdb4a0ad44dba2 100644 --- a/src/sentry/seer/endpoints/organization_seer_explorer_runs.py +++ b/src/sentry/seer/endpoints/organization_seer_explorer_runs.py @@ -13,7 +13,7 @@ from sentry.api.bases.organization import OrganizationEndpoint, OrganizationPermission from sentry.api.paginator import GenericOffsetPaginator from sentry.models.organization import Organization -from sentry.seer.explorer.client import get_seer_runs +from sentry.seer.explorer.client import SeerExplorerClient from sentry.seer.models import SeerPermissionError logger = logging.getLogger(__name__) @@ -48,9 +48,8 @@ def get(self, request: Request, organization: Organization) -> Response: def _make_seer_runs_request(offset: int, limit: int) -> dict[str, Any]: try: - runs = get_seer_runs( - organization=organization, - user=request.user, + client = SeerExplorerClient(organization, request.user) + runs = client.get_runs( category_key=category_key, category_value=category_value, offset=offset, diff --git a/src/sentry/seer/explorer/client.py b/src/sentry/seer/explorer/client.py index 7023646e939c65..dd75dfe2c7bfe9 100644 --- a/src/sentry/seer/explorer/client.py +++ b/src/sentry/seer/explorer/client.py @@ -1,42 +1,13 @@ -""" -Seer Explorer Client - Simple interface for running AI debugging agents. - -This module provides a minimal interface for Sentry developers to build agentic features -with full Sentry context, all without directly touching Seer code. - -Example usage: - from sentry.seer.explorer.client import start_seer_run, continue_seer_run, get_seer_run - - # Start a new conversation (client automatically collects user/org context) - run_id = start_seer_run( - organization=organization, - prompt="Analyze trace XYZ and find performance issues", - user=request.user, - ) - - # Continue the conversation - continue_seer_run( - run_id=run_id, - organization=organization, - prompt="What about memory leaks?", - ) - - # Get current status (non-blocking) - state = get_seer_run(run_id=run_id, organization=organization) - print(state.status, state.blocks) - - # Or wait for completion (blocking with polling) - state = get_seer_run(run_id=run_id, organization=organization, blocking=True) -""" - from __future__ import annotations +import logging from typing import Any import orjson import requests from django.conf import settings from django.contrib.auth.models import AnonymousUser +from pydantic import BaseModel, ValidationError from sentry.models.organization import Organization from sentry.seer.explorer.client_models import ExplorerRun, SeerRunState @@ -50,237 +21,280 @@ from sentry.seer.signed_seer_api import sign_with_seer_secret from sentry.users.models.user import User +logger = logging.getLogger(__name__) -def start_seer_run( - organization: Organization, - prompt: str, - user: User | AnonymousUser | None = None, - on_page_context: str | None = None, - category_key: str | None = None, - category_value: str | None = None, -) -> int: - """ - Start a new Seer Explorer session. - - The client automatically collects user/org context (teams, projects, etc.) - and sends it to Seer for the agent to use. - - Args: - organization: Sentry organization - prompt: The initial task/query for the agent - user: User (from request.user, can be User or AnonymousUser or None) - on_page_context: Optional context from the user's screen - category_key: Optional category key for filtering/grouping runs (e.g., "bug-fixer", "researcher"). Should identify the purpose/use case of the run. - category_value: Optional category value for filtering/grouping runs (e.g., "issue-123", "a5b32"). Should identify individual runs within the category. - - Returns: - int: The run ID that can be used to fetch results or continue the conversation - Raises: - SeerPermissionError: If the user/org doesn't have access to Seer Explorer - requests.HTTPError: If the Seer API request fails +class SeerExplorerClient: """ - # Check access - has_access, error = has_seer_explorer_access_with_detail(organization, user) - if not has_access: - raise SeerPermissionError(error or "Access denied") - - path = "/v1/automation/explorer/chat" - - payload: dict[str, Any] = { - "organization_id": organization.id, - "query": prompt, - "run_id": None, - "insert_index": None, - "on_page_context": on_page_context, - "user_org_context": collect_user_org_context(user, organization), - } - - if category_key or category_value: - if not category_key or not category_value: - raise ValueError("category_key and category_value must be provided together") - payload["category_key"] = category_key - payload["category_value"] = category_value - - body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS) - - response = requests.post( - f"{settings.SEER_AUTOFIX_URL}{path}", - data=body, - headers={ - "content-type": "application/json;charset=utf-8", - **sign_with_seer_secret(body), - }, - ) - - response.raise_for_status() - result = response.json() - return result["run_id"] - - -def continue_seer_run( - run_id: int, - organization: Organization, - prompt: str, - user: User | AnonymousUser | None = None, - insert_index: int | None = None, - on_page_context: str | None = None, -) -> int: + A simple client for Seer Explorer, our general debugging agent. + + This provides a class-based interface for Sentry developers to build agentic features + with full Sentry context. + + Example usage: + from sentry.seer.explorer.client import SeerExplorerClient + from pydantic import BaseModel + + # Simple usage + client = SeerExplorerClient(organization, user) + run_id = client.start_run("Analyze trace XYZ and find performance issues") + state = client.get_run(run_id) + + # With artifacts + class BugAnalysis(BaseModel): + issue_count: int + severity: str + recommendations: list[str] + + client = SeerExplorerClient(organization, user, artifact_schema=BugAnalysis) + run_id = client.start_run("Analyze recent 500 errors") + state = client.get_run(run_id, blocking=True) + + # Artifact is automatically reconstructed as BugAnalysis instance at runtime + if state.artifact: + artifact = cast(BugAnalysis, state.artifact) + print(f"Found {artifact.issue_count} issues") + + Args: + organization: Sentry organization + user: User for permission checks and user-specific context (can be User, AnonymousUser, or None) + artifact_schema: Optional Pydantic model to generate a structured artifact at the end of the run """ - Continue an existing Seer Explorer session. - - This allows you to add follow-up queries to an ongoing conversation. - User context is NOT collected again (it was already captured at start). - - Args: - run_id: The run ID from start_seer_run() - organization: Sentry organization - prompt: The follow-up task/query for the agent - user: User (for permission check) - insert_index: Optional index to insert the message at - on_page_context: Optional context from the user's screen - - Returns: - int: The run ID (same as input) - Raises: - SeerPermissionError: If the user/org doesn't have access to Seer Explorer - requests.HTTPError: If the Seer API request fails - """ - # Check access - has_access, error = has_seer_explorer_access_with_detail(organization, user) - if not has_access: - raise SeerPermissionError(error or "Access denied") - - path = "/v1/automation/explorer/chat" - - payload: dict[str, Any] = { - "organization_id": organization.id, - "query": prompt, - "run_id": run_id, - "insert_index": insert_index, - "on_page_context": on_page_context, - } - - body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS) - - response = requests.post( - f"{settings.SEER_AUTOFIX_URL}{path}", - data=body, - headers={ - "content-type": "application/json;charset=utf-8", - **sign_with_seer_secret(body), - }, - ) - - response.raise_for_status() - result = response.json() - return result["run_id"] - - -def get_seer_run( - run_id: int, - organization: Organization, - user: User | AnonymousUser | None = None, - blocking: bool = False, - poll_interval: float = 2.0, - poll_timeout: float = 600.0, -) -> SeerRunState: - """ - Get the status/result of a Seer Explorer session. - - Args: - run_id: The run ID returned from start_seer_run() - organization: Sentry organization - user: User (for permission check) - blocking: If True, blocks until the run completes (with polling) - If False, returns current state immediately - poll_interval: Seconds between polls when blocking=True - poll_timeout: Maximum seconds to wait when blocking=True - - Returns: - SeerRunState: State object with blocks, status, etc. - - Raises: - SeerPermissionError: If the user/org doesn't have access to Seer Explorer - requests.HTTPError: If the Seer API request fails - TimeoutError: If polling exceeds poll_timeout when blocking=True - """ - # Check access - has_access, error = has_seer_explorer_access_with_detail(organization, user) - if not has_access: - raise SeerPermissionError(error or "Access denied") - - if blocking: - return poll_until_done(run_id, organization, poll_interval, poll_timeout) - - return fetch_run_status(run_id, organization) - - -def get_seer_runs( - organization: Organization, - user: User | AnonymousUser | None = None, - category_key: str | None = None, - category_value: str | None = None, - offset: int | None = None, - limit: int | None = None, -) -> list[ExplorerRun]: - """ - Get a list of Seer Explorer runs for the given organization with optional filters. - - This function supports flexible filtering by user_id, category_key, or category_value. - At least one filter should be provided to avoid returning all runs for the org. - - Args: - organization: Sentry organization - user: Optional user to filter runs by (if provided, only returns runs for this user) - category_key: Optional category key to filter by (e.g., "bug-fixer", "researcher") - category_value: Optional category value to filter by (e.g., "issue-123", "a5b32") - offset: Optional offset for pagination - limit: Optional limit for pagination - - Returns: - list[ExplorerRun]: List of runs matching the filters, sorted by most recent first - - Raises: - SeerPermissionError: If the user/org doesn't have access to Seer Explorer - requests.HTTPError: If the Seer API request fails - """ - has_access, error = has_seer_explorer_access_with_detail(organization, user) - if not has_access: - raise SeerPermissionError(error or "Access denied") - - path = "/v1/automation/explorer/runs" - - payload: dict[str, Any] = { - "organization_id": organization.id, - } - - # Add optional filters - if user and hasattr(user, "id"): - payload["user_id"] = user.id - if category_key is not None: - payload["category_key"] = category_key - if category_value is not None: - payload["category_value"] = category_value - if offset is not None: - payload["offset"] = offset - if limit is not None: - payload["limit"] = limit - - body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS) - - response = requests.post( - f"{settings.SEER_AUTOFIX_URL}{path}", - data=body, - headers={ - "content-type": "application/json;charset=utf-8", - **sign_with_seer_secret(body), - }, - ) - - response.raise_for_status() - result = response.json() - - runs = [ExplorerRun(**run) for run in result.get("data", [])] - return runs + def __init__( + self, + organization: Organization, + user: User | AnonymousUser | None = None, + artifact_schema: type[BaseModel] | None = None, + ): + self.organization = organization + self.user = user + self.artifact_schema = artifact_schema + + # Validate access on init + has_access, error = has_seer_explorer_access_with_detail(organization, user) + if not has_access: + raise SeerPermissionError(error or "Access denied") + + def start_run( + self, + prompt: str, + on_page_context: str | None = None, + category_key: str | None = None, + category_value: str | None = None, + ) -> int: + """ + Start a new Seer Explorer session. + + The client automatically collects user/org context (teams, projects, etc.) + and sends it to Seer for the agent to use. If artifact_schema was provided + in the constructor, it will be automatically included. + + Args: + prompt: The initial task/query for the agent + on_page_context: Optional context from the user's screen + category_key: Optional category key for filtering/grouping runs + category_value: Optional category value for filtering/grouping runs + + Returns: + int: The run ID that can be used to fetch results or continue the conversation + + Raises: + requests.HTTPError: If the Seer API request fails + """ + path = "/v1/automation/explorer/chat" + + payload: dict[str, Any] = { + "organization_id": self.organization.id, + "query": prompt, + "run_id": None, + "insert_index": None, + "on_page_context": on_page_context, + "user_org_context": collect_user_org_context(self.user, self.organization), + } + + # Add artifact schema if provided + if self.artifact_schema: + payload["artifact_schema"] = self.artifact_schema.schema() + + if category_key or category_value: + if not category_key or not category_value: + raise ValueError("category_key and category_value must be provided together") + payload["category_key"] = category_key + payload["category_value"] = category_value + + body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS) + + response = requests.post( + f"{settings.SEER_AUTOFIX_URL}{path}", + data=body, + headers={ + "content-type": "application/json;charset=utf-8", + **sign_with_seer_secret(body), + }, + ) + + response.raise_for_status() + result = response.json() + return result["run_id"] + + def continue_run( + self, + run_id: int, + prompt: str, + insert_index: int | None = None, + on_page_context: str | None = None, + ) -> int: + """ + Continue an existing Seer Explorer session. + + This allows you to add follow-up queries to an ongoing conversation. + User context is NOT collected again (it was already captured at start). + + Args: + run_id: The run ID from start_run() + prompt: The follow-up task/query for the agent + insert_index: Optional index to insert the message at + on_page_context: Optional context from the user's screen + + Returns: + int: The run ID (same as input) + + Raises: + requests.HTTPError: If the Seer API request fails + """ + path = "/v1/automation/explorer/chat" + + payload: dict[str, Any] = { + "organization_id": self.organization.id, + "query": prompt, + "run_id": run_id, + "insert_index": insert_index, + "on_page_context": on_page_context, + } + + body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS) + + response = requests.post( + f"{settings.SEER_AUTOFIX_URL}{path}", + data=body, + headers={ + "content-type": "application/json;charset=utf-8", + **sign_with_seer_secret(body), + }, + ) + + response.raise_for_status() + result = response.json() + return result["run_id"] + + def get_run( + self, + run_id: int, + blocking: bool = False, + poll_interval: float = 2.0, + poll_timeout: float = 600.0, + ) -> SeerRunState: + """ + Get the status/result of a Seer Explorer session. + + If artifact_schema was provided in the constructor and an artifact was generated, + it will be automatically reconstructed as a typed Pydantic instance. + + Args: + run_id: The run ID returned from start_run() + blocking: If True, blocks until the run completes (with polling) + poll_interval: Seconds between polls when blocking=True + poll_timeout: Maximum seconds to wait when blocking=True + + Returns: + SeerRunState: State object with blocks, status, and optionally reconstructed artifact + + Raises: + requests.HTTPError: If the Seer API request fails + TimeoutError: If polling exceeds poll_timeout when blocking=True + """ + if blocking: + state = poll_until_done(run_id, self.organization, poll_interval, poll_timeout) + else: + state = fetch_run_status(run_id, self.organization) + + # Automatically parse raw_artifact into typed artifact if schema was provided + if state.raw_artifact and self.artifact_schema: + try: + state.artifact = self.artifact_schema.parse_obj(state.raw_artifact) + state.raw_artifact = None # clear now that it's not needed + except ValidationError as e: + # Log but don't fail - keep artifact as None + state.artifact = None + logger.warning( + "Failed to parse artifact", + extra={ + "run_id": run_id, + "error": str(e), + "artifact_schema": self.artifact_schema.__name__, + "raw_artifact": state.raw_artifact, + }, + ) + + return state + + def get_runs( + self, + category_key: str | None = None, + category_value: str | None = None, + offset: int | None = None, + limit: int | None = None, + ) -> list[ExplorerRun]: + """ + Get a list of Seer Explorer runs for the organization with optional filters. + + This function supports flexible filtering by user_id (from client), category_key, + or category_value. At least one filter should be provided to avoid returning all runs. + + Args: + category_key: Optional category key to filter by (e.g., "bug-fixer") + category_value: Optional category value to filter by (e.g., "issue-123") + offset: Optional offset for pagination + limit: Optional limit for pagination + + Returns: + list[ExplorerRun]: List of runs matching the filters, sorted by most recent first + + Raises: + requests.HTTPError: If the Seer API request fails + """ + path = "/v1/automation/explorer/runs" + + payload: dict[str, Any] = { + "organization_id": self.organization.id, + } + + # Add optional filters + if self.user and hasattr(self.user, "id"): + payload["user_id"] = self.user.id + if category_key is not None: + payload["category_key"] = category_key + if category_value is not None: + payload["category_value"] = category_value + if offset is not None: + payload["offset"] = offset + if limit is not None: + payload["limit"] = limit + + body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS) + + response = requests.post( + f"{settings.SEER_AUTOFIX_URL}{path}", + data=body, + headers={ + "content-type": "application/json;charset=utf-8", + **sign_with_seer_secret(body), + }, + ) + + response.raise_for_status() + result = response.json() + + runs = [ExplorerRun(**run) for run in result.get("data", [])] + return runs diff --git a/src/sentry/seer/explorer/client_models.py b/src/sentry/seer/explorer/client_models.py index e17498c6302a5f..9650816a123dfc 100644 --- a/src/sentry/seer/explorer/client_models.py +++ b/src/sentry/seer/explorer/client_models.py @@ -5,7 +5,7 @@ from __future__ import annotations from datetime import datetime -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel @@ -50,6 +50,9 @@ class SeerRunState(BaseModel): blocks: list[MemoryBlock] status: Literal["processing", "completed", "error"] updated_at: str + raw_artifact: dict[str, Any] | None = None + artifact: BaseModel | None = None + artifact_reason: str | None = None class Config: extra = "allow" diff --git a/tests/sentry/seer/endpoints/test_organization_seer_explorer_chat.py b/tests/sentry/seer/endpoints/test_organization_seer_explorer_chat.py index 757eed86529c36..e2bea4d9aceae2 100644 --- a/tests/sentry/seer/endpoints/test_organization_seer_explorer_chat.py +++ b/tests/sentry/seer/endpoints/test_organization_seer_explorer_chat.py @@ -1,5 +1,5 @@ from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch from sentry.models.organizationmember import OrganizationMember from sentry.seer.explorer.client_utils import collect_user_org_context @@ -22,8 +22,8 @@ def test_get_without_run_id_returns_null_session(self) -> None: assert response.status_code == 404 assert response.data == {"session": None} - @patch("sentry.seer.endpoints.organization_seer_explorer_chat.get_seer_run") - def test_get_with_run_id_calls_client(self, mock_get_seer_run: MagicMock) -> None: + @patch("sentry.seer.endpoints.organization_seer_explorer_chat.SeerExplorerClient") + def test_get_with_run_id_calls_client(self, mock_client_class: MagicMock) -> None: from sentry.seer.explorer.client_models import SeerRunState # Mock client response @@ -33,14 +33,16 @@ def test_get_with_run_id_calls_client(self, mock_get_seer_run: MagicMock) -> Non status="completed", updated_at="2024-01-01T00:00:00Z", ) - mock_get_seer_run.return_value = mock_state + mock_client = MagicMock() + mock_client.get_run.return_value = mock_state + mock_client_class.return_value = mock_client response = self.client.get(f"{self.url}123/") assert response.status_code == 200 assert response.data["session"]["run_id"] == 123 assert response.data["session"]["status"] == "completed" - assert mock_get_seer_run.call_count == 1 + mock_client.get_run.assert_called_once_with(run_id=123) def test_post_without_query_returns_400(self) -> None: data: dict[str, Any] = {} @@ -54,9 +56,11 @@ def test_post_with_empty_query_returns_400(self) -> None: assert response.status_code == 400 - @patch("sentry.seer.endpoints.organization_seer_explorer_chat.start_seer_run") - def test_post_new_conversation_calls_client(self, mock_start_seer_run: MagicMock): - mock_start_seer_run.return_value = 456 + @patch("sentry.seer.endpoints.organization_seer_explorer_chat.SeerExplorerClient") + def test_post_new_conversation_calls_client(self, mock_client_class: MagicMock): + mock_client = MagicMock() + mock_client.start_run.return_value = 456 + mock_client_class.return_value = mock_client data = {"query": "What is this error about?"} response = self.client.post(self.url, data, format="json") @@ -64,18 +68,17 @@ def test_post_new_conversation_calls_client(self, mock_start_seer_run: MagicMock assert response.status_code == 200 assert response.data == {"run_id": 456} - # Verify client was called - assert mock_start_seer_run.call_count == 1 - call_kwargs = mock_start_seer_run.call_args[1] - assert call_kwargs["organization"] == self.organization - assert call_kwargs["prompt"] == "What is this error about?" - assert call_kwargs["on_page_context"] is None + # Verify client was called correctly + mock_client_class.assert_called_once_with(self.organization, ANY) + mock_client.start_run.assert_called_once_with( + prompt="What is this error about?", on_page_context=None + ) - @patch("sentry.seer.endpoints.organization_seer_explorer_chat.continue_seer_run") - def test_post_continue_conversation_calls_client( - self, mock_continue_seer_run: MagicMock - ) -> None: - mock_continue_seer_run.return_value = 789 + @patch("sentry.seer.endpoints.organization_seer_explorer_chat.SeerExplorerClient") + def test_post_continue_conversation_calls_client(self, mock_client_class: MagicMock) -> None: + mock_client = MagicMock() + mock_client.continue_run.return_value = 789 + mock_client_class.return_value = mock_client data = { "query": "Follow up question", @@ -86,14 +89,11 @@ def test_post_continue_conversation_calls_client( assert response.status_code == 200 assert response.data == {"run_id": 789} - # Verify client was called - assert mock_continue_seer_run.call_count == 1 - call_kwargs = mock_continue_seer_run.call_args[1] - assert call_kwargs["organization"] == self.organization - assert call_kwargs["prompt"] == "Follow up question" - assert call_kwargs["run_id"] == 789 - assert call_kwargs["insert_index"] == 2 - assert call_kwargs["on_page_context"] is None + # Verify client was called correctly + mock_client_class.assert_called_once_with(self.organization, ANY) + mock_client.continue_run.assert_called_once_with( + run_id=789, prompt="Follow up question", insert_index=2, on_page_context=None + ) class CollectUserOrgContextTest(APITestCase): diff --git a/tests/sentry/seer/endpoints/test_organization_seer_explorer_runs.py b/tests/sentry/seer/endpoints/test_organization_seer_explorer_runs.py index 0fb61ae8eb141e..40bd02175b818f 100644 --- a/tests/sentry/seer/endpoints/test_organization_seer_explorer_runs.py +++ b/tests/sentry/seer/endpoints/test_organization_seer_explorer_runs.py @@ -1,5 +1,5 @@ from datetime import datetime -from unittest.mock import patch +from unittest.mock import ANY, MagicMock, patch import requests from django.urls import reverse @@ -26,18 +26,20 @@ def setUp(self) -> None: return_value=(True, None), ) self.seer_access_patcher.start() - self.get_seer_runs_patcher = patch( - "sentry.seer.endpoints.organization_seer_explorer_runs.get_seer_runs" + self.client_patcher = patch( + "sentry.seer.endpoints.organization_seer_explorer_runs.SeerExplorerClient" ) - self.get_seer_runs = self.get_seer_runs_patcher.start() + self.mock_client_class = self.client_patcher.start() + self.mock_client = MagicMock() + self.mock_client_class.return_value = self.mock_client def tearDown(self) -> None: self.seer_access_patcher.stop() - self.get_seer_runs_patcher.stop() + self.client_patcher.stop() super().tearDown() def test_get_simple(self) -> None: - self.get_seer_runs.return_value = [ + self.mock_client.get_runs.return_value = [ ExplorerRun( run_id=1, title="Run 1", @@ -58,18 +60,17 @@ def test_get_simple(self) -> None: assert data[0]["run_id"] == 1 assert data[1]["run_id"] == 2 - self.get_seer_runs.assert_called_once() - call_args = self.get_seer_runs.call_args - assert call_args.kwargs["organization"] == self.organization - assert call_args.kwargs["user"].id == self.user.id - assert call_args.kwargs["limit"] == 101 # Default per_page of 100 + 1 for has_more - assert call_args.kwargs["offset"] == 0 - assert call_args.kwargs["category_key"] is None - assert call_args.kwargs["category_value"] is None + self.mock_client_class.assert_called_once_with(self.organization, ANY) + self.mock_client.get_runs.assert_called_once_with( + category_key=None, + category_value=None, + offset=0, + limit=101, # Default per_page of 100 + 1 for has_more + ) def test_get_cursor_pagination(self) -> None: # Mock seer response for offset 0, limit 3. - self.get_seer_runs.return_value = [ + self.mock_client.get_runs.return_value = [ ExplorerRun( run_id=1, title="Run 1", @@ -98,17 +99,12 @@ def test_get_cursor_pagination(self) -> None: assert data[1]["run_id"] == 2 assert 'rel="next"; results="true"' in response.headers["Link"] - self.get_seer_runs.assert_called_once() - call_args = self.get_seer_runs.call_args - assert call_args.kwargs["organization"] == self.organization - assert call_args.kwargs["user"].id == self.user.id - assert call_args.kwargs["limit"] == 3 # +1 for has_more - assert call_args.kwargs["offset"] == 0 - assert call_args.kwargs["category_key"] is None - assert call_args.kwargs["category_value"] is None + self.mock_client.get_runs.assert_called_once_with( + category_key=None, category_value=None, offset=0, limit=3 + ) # Second page - mock seer response for offset 2, limit 3. - self.get_seer_runs.return_value = [ + self.mock_client.get_runs.return_value = [ ExplorerRun( run_id=3, title="Run 3", @@ -131,21 +127,19 @@ def test_get_cursor_pagination(self) -> None: assert data[1]["run_id"] == 4 assert 'rel="next"; results="false"' in response.headers["Link"] - call_args = self.get_seer_runs.call_args - assert call_args.kwargs["organization"] == self.organization - assert call_args.kwargs["user"].id == self.user.id - assert call_args.kwargs["limit"] == 3 # +1 for has_more + # Verify second call + assert self.mock_client.get_runs.call_count == 2 + call_args = self.mock_client.get_runs.call_args assert call_args.kwargs["offset"] == 2 - assert call_args.kwargs["category_key"] is None - assert call_args.kwargs["category_value"] is None + assert call_args.kwargs["limit"] == 3 def test_get_with_seer_error(self) -> None: - self.get_seer_runs.side_effect = requests.HTTPError("API Error") + self.mock_client.get_runs.side_effect = requests.HTTPError("API Error") response = self.client.get(self.url) assert response.status_code == 500 def test_get_with_category_key_filter(self) -> None: - self.get_seer_runs.return_value = [ + self.mock_client.get_runs.return_value = [ ExplorerRun( run_id=1, title="Run 1", @@ -161,12 +155,12 @@ def test_get_with_category_key_filter(self) -> None: assert len(data) == 1 assert data[0]["run_id"] == 1 - call_args = self.get_seer_runs.call_args + call_args = self.mock_client.get_runs.call_args assert call_args.kwargs["category_key"] == "bug-fixer" assert call_args.kwargs["category_value"] is None def test_get_with_category_value_filter(self) -> None: - self.get_seer_runs.return_value = [ + self.mock_client.get_runs.return_value = [ ExplorerRun( run_id=2, title="Run 2", @@ -182,12 +176,12 @@ def test_get_with_category_value_filter(self) -> None: assert len(data) == 1 assert data[0]["run_id"] == 2 - call_args = self.get_seer_runs.call_args + call_args = self.mock_client.get_runs.call_args assert call_args.kwargs["category_key"] is None assert call_args.kwargs["category_value"] == "issue-123" def test_get_with_both_category_filters(self) -> None: - self.get_seer_runs.return_value = [ + self.mock_client.get_runs.return_value = [ ExplorerRun( run_id=3, title="Run 3", @@ -203,12 +197,12 @@ def test_get_with_both_category_filters(self) -> None: assert len(data) == 1 assert data[0]["run_id"] == 3 - call_args = self.get_seer_runs.call_args + call_args = self.mock_client.get_runs.call_args assert call_args.kwargs["category_key"] == "bug-fixer" assert call_args.kwargs["category_value"] == "issue-123" def test_get_with_category_filters_and_pagination(self) -> None: - self.get_seer_runs.return_value = [ + self.mock_client.get_runs.return_value = [ ExplorerRun( run_id=1, title="Run 1", @@ -236,7 +230,7 @@ def test_get_with_category_filters_and_pagination(self) -> None: data = response.json()["data"] assert len(data) == 2 - call_args = self.get_seer_runs.call_args + call_args = self.mock_client.get_runs.call_args assert call_args.kwargs["category_key"] == "bug-fixer" assert call_args.kwargs["category_value"] == "issue-123" assert call_args.kwargs["limit"] == 3 # +1 for has_more @@ -255,7 +249,7 @@ def setUp(self) -> None: def test_missing_gen_ai_features_flag(self) -> None: with self.feature({"organizations:seer-explorer": True}): with patch( - "sentry.seer.endpoints.organization_seer_explorer_runs.get_seer_runs", + "sentry.seer.endpoints.organization_seer_explorer_runs.SeerExplorerClient", side_effect=SeerPermissionError("Feature flag not enabled"), ): response = self.client.get(self.url) @@ -265,7 +259,7 @@ def test_missing_gen_ai_features_flag(self) -> None: def test_missing_seer_explorer_flag(self) -> None: with self.feature({"organizations:gen-ai-features": True}): with patch( - "sentry.seer.endpoints.organization_seer_explorer_runs.get_seer_runs", + "sentry.seer.endpoints.organization_seer_explorer_runs.SeerExplorerClient", side_effect=SeerPermissionError("Feature flag not enabled"), ): response = self.client.get(self.url) @@ -277,7 +271,7 @@ def test_missing_seer_acknowledgement(self) -> None: {"organizations:gen-ai-features": True, "organizations:seer-explorer": True} ): with patch( - "sentry.seer.endpoints.organization_seer_explorer_runs.get_seer_runs", + "sentry.seer.endpoints.organization_seer_explorer_runs.SeerExplorerClient", side_effect=SeerPermissionError( "Seer has not been acknowledged by the organization." ), @@ -293,7 +287,7 @@ def test_missing_allow_joinleave_org_flag(self) -> None: {"organizations:gen-ai-features": True, "organizations:seer-explorer": True} ): with patch( - "sentry.seer.endpoints.organization_seer_explorer_runs.get_seer_runs", + "sentry.seer.endpoints.organization_seer_explorer_runs.SeerExplorerClient", side_effect=SeerPermissionError( "Organization does not have open team membership enabled. Seer requires this to aggregate context across all projects and allow members to ask questions freely." ), diff --git a/tests/sentry/seer/explorer/test_explorer_client.py b/tests/sentry/seer/explorer/test_explorer_client.py index 3cb98bd2184a79..0e0f25c39f1257 100644 --- a/tests/sentry/seer/explorer/test_explorer_client.py +++ b/tests/sentry/seer/explorer/test_explorer_client.py @@ -1,40 +1,65 @@ from unittest.mock import MagicMock, patch +import orjson import pytest import requests +from pydantic import BaseModel -from sentry.seer.explorer.client import ( - continue_seer_run, - get_seer_run, - get_seer_runs, - start_seer_run, -) +from sentry.seer.explorer.client import SeerExplorerClient from sentry.seer.explorer.client_models import SeerRunState +from sentry.seer.models import SeerPermissionError from sentry.testutils.cases import TestCase -class TestStartSeerRun(TestCase): +class TestSeerExplorerClient(TestCase): def setUp(self): super().setUp() self.user = self.create_user() self.organization = self.create_organization(owner=self.user) + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") + def test_client_init_checks_access(self, mock_access): + """Test that client initialization checks access and raises on denial""" + mock_access.return_value = (False, "Feature flag not enabled") + + with pytest.raises(SeerPermissionError) as exc_info: + SeerExplorerClient(self.organization, self.user) + assert "Feature flag not enabled" in str(exc_info.value) + + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") + def test_client_init_succeeds_with_access(self, mock_access): + """Test that client initialization succeeds with proper access""" + mock_access.return_value = (True, None) + + client = SeerExplorerClient(self.organization, self.user) + assert client.organization == self.organization + assert client.user == self.user + assert client.artifact_schema is None + + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") + def test_client_init_with_artifact_schema(self, mock_access): + """Test that client stores artifact schema""" + mock_access.return_value = (True, None) + + class TestSchema(BaseModel): + count: int + + client = SeerExplorerClient(self.organization, self.user, artifact_schema=TestSchema) + assert client.artifact_schema == TestSchema + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.requests.post") @patch("sentry.seer.explorer.client.collect_user_org_context") - def test_start_seer_run_new_session(self, mock_collect_context, mock_post, mock_access): - """Test starting a new Seer run collects user context""" + def test_start_run_basic(self, mock_collect_context, mock_post, mock_access): + """Test starting a new run collects user context""" mock_access.return_value = (True, None) mock_collect_context.return_value = {"user_id": self.user.id} mock_response = MagicMock() mock_response.json.return_value = {"run_id": 123} mock_post.return_value = mock_response - run_id = start_seer_run( - organization=self.organization, - prompt="Test query", - user=self.user, - ) + client = SeerExplorerClient(self.organization, self.user) + run_id = client.start_run("Test query") assert run_id == 123 mock_collect_context.assert_called_once_with(self.user, self.organization) @@ -42,58 +67,35 @@ def test_start_seer_run_new_session(self, mock_collect_context, mock_post, mock_ @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.requests.post") - def test_start_seer_run_with_optional_params(self, mock_post, mock_access): - """Test starting a run with optional on_page_context""" + def test_start_run_with_optional_params(self, mock_post, mock_access): + """Test starting a run with optional parameters""" mock_access.return_value = (True, None) mock_response = MagicMock() mock_response.json.return_value = {"run_id": 789} mock_post.return_value = mock_response - run_id = start_seer_run( - organization=self.organization, - prompt="Query", - user=self.user, - on_page_context="some context", - ) + client = SeerExplorerClient(self.organization, self.user) + run_id = client.start_run("Query", on_page_context="some context") assert run_id == 789 - # Verify the payload includes optional params call_args = mock_post.call_args assert call_args is not None @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.requests.post") - def test_start_seer_run_http_error(self, mock_post, mock_access): + def test_start_run_http_error(self, mock_post, mock_access): """Test that HTTP errors are propagated""" mock_access.return_value = (True, None) mock_post.return_value.raise_for_status.side_effect = requests.HTTPError("API Error") + client = SeerExplorerClient(self.organization, self.user) with pytest.raises(requests.HTTPError): - start_seer_run( - organization=self.organization, - prompt="Test query", - user=self.user, - ) - - @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") - def test_start_seer_run_permission_denied(self, mock_access): - """Test that SeerPermissionError is raised when access is denied""" - from sentry.seer.models import SeerPermissionError - - mock_access.return_value = (False, "Feature flag not enabled") - - with pytest.raises(SeerPermissionError) as exc_info: - start_seer_run( - organization=self.organization, - prompt="Test query", - user=self.user, - ) - assert "Feature flag not enabled" in str(exc_info.value) + client.start_run("Test query") @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.requests.post") @patch("sentry.seer.explorer.client.collect_user_org_context") - def test_start_seer_run_with_categories(self, mock_collect_context, mock_post, mock_access): + def test_start_run_with_categories(self, mock_collect_context, mock_post, mock_access): """Test starting a run with category fields""" mock_access.return_value = (True, None) mock_collect_context.return_value = {"user_id": self.user.id} @@ -101,92 +103,62 @@ def test_start_seer_run_with_categories(self, mock_collect_context, mock_post, m mock_response.json.return_value = {"run_id": 999} mock_post.return_value = mock_response - run_id = start_seer_run( - organization=self.organization, - prompt="Fix bug", - user=self.user, - category_key="bug-fixer", - category_value="issue-123", - ) + client = SeerExplorerClient(self.organization, self.user) + run_id = client.start_run("Fix bug", category_key="bug-fixer", category_value="issue-123") assert run_id == 999 - import orjson - body = orjson.loads(mock_post.call_args[1]["data"]) assert body["category_key"] == "bug-fixer" assert body["category_value"] == "issue-123" @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") - def test_start_seer_run_category_key_only_raises_error(self, mock_access): + def test_start_run_category_key_only_raises_error(self, mock_access): """Test that ValueError is raised when only category_key is provided""" mock_access.return_value = (True, None) + client = SeerExplorerClient(self.organization, self.user) with pytest.raises( ValueError, match="category_key and category_value must be provided together" ): - start_seer_run( - organization=self.organization, - prompt="Test query", - user=self.user, - category_key="bug-fixer", - ) + client.start_run("Test query", category_key="bug-fixer") @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") - def test_start_seer_run_category_value_only_raises_error(self, mock_access): + def test_start_run_category_value_only_raises_error(self, mock_access): """Test that ValueError is raised when only category_value is provided""" mock_access.return_value = (True, None) + client = SeerExplorerClient(self.organization, self.user) with pytest.raises( ValueError, match="category_key and category_value must be provided together" ): - start_seer_run( - organization=self.organization, - prompt="Test query", - user=self.user, - category_value="issue-123", - ) - - -class TestContinueSeerRun(TestCase): - def setUp(self): - super().setUp() - self.user = self.create_user() - self.organization = self.create_organization(owner=self.user) + client.start_run("Test query", category_value="issue-123") @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.requests.post") - def test_continue_seer_run_basic(self, mock_post, mock_access): - """Test continuing an existing Seer run""" + def test_continue_run_basic(self, mock_post, mock_access): + """Test continuing an existing run""" mock_access.return_value = (True, None) mock_response = MagicMock() mock_response.json.return_value = {"run_id": 456} mock_post.return_value = mock_response - run_id = continue_seer_run( - run_id=456, - organization=self.organization, - prompt="Follow up query", - ) + client = SeerExplorerClient(self.organization, self.user) + run_id = client.continue_run(456, "Follow up query") assert run_id == 456 assert mock_post.called @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.requests.post") - def test_continue_seer_run_with_all_params(self, mock_post, mock_access): + def test_continue_run_with_all_params(self, mock_post, mock_access): """Test continuing a run with all optional parameters""" mock_access.return_value = (True, None) mock_response = MagicMock() mock_response.json.return_value = {"run_id": 789} mock_post.return_value = mock_response - run_id = continue_seer_run( - run_id=789, - organization=self.organization, - prompt="Follow up", - insert_index=2, - on_page_context="context", - ) + client = SeerExplorerClient(self.organization, self.user) + run_id = client.continue_run(789, "Follow up", insert_index=2, on_page_context="context") assert run_id == 789 call_args = mock_post.call_args @@ -194,27 +166,18 @@ def test_continue_seer_run_with_all_params(self, mock_post, mock_access): @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.requests.post") - def test_continue_seer_run_http_error(self, mock_post, mock_access): + def test_continue_run_http_error(self, mock_post, mock_access): """Test that HTTP errors are propagated""" mock_access.return_value = (True, None) mock_post.return_value.raise_for_status.side_effect = requests.HTTPError("API Error") + client = SeerExplorerClient(self.organization, self.user) with pytest.raises(requests.HTTPError): - continue_seer_run( - run_id=123, - organization=self.organization, - prompt="Test query", - ) - - -class TestGetSeerRun(TestCase): - def setUp(self): - super().setUp() - self.organization = self.create_organization() + client.continue_run(123, "Test query") @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.fetch_run_status") - def test_get_seer_run_immediate(self, mock_fetch, mock_access): + def test_get_run_immediate(self, mock_fetch, mock_access): """Test getting run status without waiting""" mock_access.return_value = (True, None) mock_state = SeerRunState( @@ -225,7 +188,8 @@ def test_get_seer_run_immediate(self, mock_fetch, mock_access): ) mock_fetch.return_value = mock_state - result = get_seer_run(run_id=123, organization=self.organization) + client = SeerExplorerClient(self.organization, self.user) + result = client.get_run(123) assert result.run_id == 123 assert result.status == "processing" @@ -233,7 +197,7 @@ def test_get_seer_run_immediate(self, mock_fetch, mock_access): @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.poll_until_done") - def test_get_seer_run_with_blocking(self, mock_poll, mock_access): + def test_get_run_with_blocking(self, mock_poll, mock_access): """Test getting run status with polling""" mock_access.return_value = (True, None) mock_state = SeerRunState( @@ -244,13 +208,8 @@ def test_get_seer_run_with_blocking(self, mock_poll, mock_access): ) mock_poll.return_value = mock_state - result = get_seer_run( - run_id=123, - organization=self.organization, - blocking=True, - poll_interval=1.0, - poll_timeout=30.0, - ) + client = SeerExplorerClient(self.organization, self.user) + result = client.get_run(123, blocking=True, poll_interval=1.0, poll_timeout=30.0) assert result.run_id == 123 assert result.status == "completed" @@ -258,24 +217,18 @@ def test_get_seer_run_with_blocking(self, mock_poll, mock_access): @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.fetch_run_status") - def test_get_seer_run_http_error(self, mock_fetch, mock_access): + def test_get_run_http_error(self, mock_fetch, mock_access): """Test that HTTP errors are propagated""" mock_access.return_value = (True, None) mock_fetch.side_effect = requests.HTTPError("API Error") + client = SeerExplorerClient(self.organization, self.user) with pytest.raises(requests.HTTPError): - get_seer_run(run_id=123, organization=self.organization) - - -class TestGetSeerRuns(TestCase): - def setUp(self): - super().setUp() - self.user = self.create_user() - self.organization = self.create_organization(owner=self.user) + client.get_run(123) @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") @patch("sentry.seer.explorer.client.requests.post") - def test_get_seer_runs_basic(self, mock_post, mock_access): + def test_get_runs_basic(self, mock_post, mock_access): """Test getting runs with filters""" mock_access.return_value = (True, None) mock_response = MagicMock() @@ -293,16 +246,103 @@ def test_get_seer_runs_basic(self, mock_post, mock_access): } mock_post.return_value = mock_response - runs = get_seer_runs( - organization=self.organization, - category_key="bug-fixer", - category_value="issue-123", - ) + client = SeerExplorerClient(self.organization, self.user) + runs = client.get_runs(category_key="bug-fixer", category_value="issue-123") assert len(runs) == 1 assert runs[0].category_key == "bug-fixer" - import orjson - body = orjson.loads(mock_post.call_args[1]["data"]) assert body["category_key"] == "bug-fixer" assert body["category_value"] == "issue-123" + + +class TestSeerExplorerClientArtifacts(TestCase): + """Test artifact schema passing and reconstruction""" + + def setUp(self): + super().setUp() + self.user = self.create_user() + self.organization = self.create_organization(owner=self.user) + + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") + @patch("sentry.seer.explorer.client.requests.post") + @patch("sentry.seer.explorer.client.collect_user_org_context") + def test_start_run_with_artifact_schema(self, mock_collect_context, mock_post, mock_access): + """Test that artifact schema is serialized and sent to API""" + mock_access.return_value = (True, None) + mock_collect_context.return_value = {"user_id": self.user.id} + mock_response = MagicMock() + mock_response.json.return_value = {"run_id": 123} + mock_post.return_value = mock_response + + class IssueAnalysis(BaseModel): + issue_count: int + severity: str + + client = SeerExplorerClient(self.organization, self.user, artifact_schema=IssueAnalysis) + run_id = client.start_run("Analyze errors") + + assert run_id == 123 + + # Verify artifact_schema was included in payload + body = orjson.loads(mock_post.call_args[1]["data"]) + assert "artifact_schema" in body + assert body["artifact_schema"]["type"] == "object" + assert "issue_count" in body["artifact_schema"]["properties"] + assert "severity" in body["artifact_schema"]["properties"] + + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") + @patch("sentry.seer.explorer.client.fetch_run_status") + def test_get_run_reconstructs_artifact(self, mock_fetch, mock_access): + """Test that artifact is automatically reconstructed from dict""" + mock_access.return_value = (True, None) + + class BugReport(BaseModel): + bug_count: int + severity: str + + # Mock API returns dict artifact + mock_state = SeerRunState( + run_id=123, + blocks=[], + status="completed", + updated_at="2024-01-01T00:00:00Z", + raw_artifact={"bug_count": 5, "severity": "high"}, # Raw dict from API + artifact_reason="Successfully generated", + ) + mock_fetch.return_value = mock_state + + client = SeerExplorerClient(self.organization, self.user, artifact_schema=BugReport) + result = client.get_run(123) + + # Verify artifact was reconstructed as Pydantic model + assert isinstance(result.artifact, BugReport) + assert result.artifact.bug_count == 5 + assert result.artifact.severity == "high" + assert result.artifact_reason == "Successfully generated" + + @patch("sentry.seer.explorer.client.has_seer_explorer_access_with_detail") + @patch("sentry.seer.explorer.client.fetch_run_status") + def test_get_run_with_none_artifact(self, mock_fetch, mock_access): + """Test that None artifact is handled gracefully""" + mock_access.return_value = (True, None) + + class MySchema(BaseModel): + field: str + + mock_state = SeerRunState( + run_id=123, + blocks=[], + status="completed", + updated_at="2024-01-01T00:00:00Z", + raw_artifact=None, + artifact_reason="Generation failed", + ) + mock_fetch.return_value = mock_state + + client = SeerExplorerClient(self.organization, self.user, artifact_schema=MySchema) + result = client.get_run(123) + + # Verify None artifact is preserved + assert result.artifact is None + assert result.artifact_reason == "Generation failed"