class CohortUpdates(pydantic.BaseModel):
    """
    Serializable container for per-cohort timestamps.

    Instances are read and written as a single parsed key by
    ``DelayedWorkflowClient.fetch_updates`` / ``persist_updates``.
    """

    # Maps a cohort id to a float timestamp for that cohort. The scheduler's
    # ProjectChooser stores the ``fetch_time`` of the run in which the cohort
    # was last selected for processing.
    values: dict[int, float]
@@ -69,6 +75,16 @@ def _get_buffer_keys(cls) -> list[str]: for shard in range(cls._BUFFER_SHARDS) ] + _COHORT_UPDATES_KEY = "WORKFLOW_ENGINE_COHORT_UPDATES" + + def fetch_updates(self) -> CohortUpdates: + return self._buffer.get_parsed_key( + self._COHORT_UPDATES_KEY, CohortUpdates + ) or CohortUpdates(values={}) + + def persist_updates(self, cohort_updates: CohortUpdates) -> None: + self._buffer.put_parsed_key(self._COHORT_UPDATES_KEY, cohort_updates) + def for_project(self, project_id: int) -> ProjectDelayedWorkflowClient: """Create a project-specific client for workflow operations.""" return ProjectDelayedWorkflowClient(project_id, self._buffer) diff --git a/src/sentry/workflow_engine/buffer/redis_hash_sorted_set_buffer.py b/src/sentry/workflow_engine/buffer/redis_hash_sorted_set_buffer.py index 4cd3802a8f64bd..33f986f2494fd7 100644 --- a/src/sentry/workflow_engine/buffer/redis_hash_sorted_set_buffer.py +++ b/src/sentry/workflow_engine/buffer/redis_hash_sorted_set_buffer.py @@ -6,6 +6,7 @@ from collections.abc import Callable, Iterable, Mapping, Sequence from typing import Any, TypeAlias, TypeVar +import pydantic import rb from redis.client import Pipeline @@ -50,6 +51,8 @@ def _by_pairs(seq: list[T]) -> Iterable[tuple[T, T]]: "zrangebyscore": False, "zrem": True, "zremrangebyscore": True, + "set": True, + "get": False, } @@ -418,3 +421,12 @@ def _conditional_delete_rb_fallback( converted_results.update(host_parsed) return converted_results + + def get_parsed_key[T: pydantic.BaseModel](self, key: str, model: type[T]) -> T | None: + value = self._execute_redis_operation(key, "get") + if value is None: + return None + return model.parse_raw(value) + + def put_parsed_key[T: pydantic.BaseModel](self, key: str, value: T) -> None: + self._execute_redis_operation(key, "set", value.json()) diff --git a/src/sentry/workflow_engine/processors/schedule.py b/src/sentry/workflow_engine/processors/schedule.py index 8713ce62488421..3d1b0192c65b21 100644 --- 
class ProjectChooser:
    """
    ProjectChooser assists in determining which projects to process based on the cohort updates.

    Every project id is deterministically assigned to one of ``num_cohorts``
    cohorts. On each scheduler run, cohorts that have not run for a minute are
    always selected; otherwise at most one sufficiently-rested cohort is
    selected, so that work is spread across runs.
    """

    # A cohort that has gone this long without running is always processed.
    _MAX_COHORT_INTERVAL = timedelta(minutes=1)

    def __init__(self, buffer_client: DelayedWorkflowClient, num_cohorts: int):
        """
        Args:
            buffer_client: client used by callers to load/persist cohort state.
            num_cohorts: number of cohorts to split projects into (1..255).

        Raises:
            ValueError: if ``num_cohorts`` is outside 1..255.
        """
        # The value comes from a runtime-modifiable option, so validate it
        # explicitly; an ``assert`` would be stripped under ``python -O``.
        if not 0 < num_cohorts <= 255:
            raise ValueError(f"num_cohorts must be between 1 and 255, got {num_cohorts}")
        self.client = buffer_client
        self.num_cohorts = num_cohorts

    def _project_id_to_cohort(self, project_id: int) -> int:
        """Return the cohort (``0 <= cohort < num_cohorts``) that ``project_id`` hashes to."""
        # Byte order is spelled out so the mapping does not depend on the
        # interpreter's default; "big" is the default of ``int.to_bytes``.
        digest = hashlib.sha256(project_id.to_bytes(8, "big")).digest()
        return digest[0] % self.num_cohorts

    def project_ids_to_process(
        self, fetch_time: float, cohort_updates: CohortUpdates, all_project_ids: list[int]
    ) -> list[int]:
        """
        Given the time, the cohort update history, and the list of project ids in need of processing,
        determine which project ids should be processed.

        ``cohort_updates.values`` is updated in place: every selected cohort has
        its entry set to ``fetch_time``. Persisting that state is the caller's job.

        Returns:
            The members of ``all_project_ids`` (original order preserved) whose
            cohort was selected for this run.
        """
        # Cohorts with no recorded run are treated as long overdue.
        long_ago = fetch_time - 1000
        # With N cohorts, a cohort becomes eligible again after 60/N seconds.
        min_interval = timedelta(seconds=60 / self.num_cohorts)

        must_process = set[int]()
        # Eligible-but-not-overdue cohorts, mapped to their last run time.
        may_process: dict[int, float] = {}
        for cohort in range(self.num_cohorts):
            last_run = cohort_updates.values.get(cohort, long_ago)
            elapsed = timedelta(seconds=fetch_time - last_run)
            if elapsed >= self._MAX_COHORT_INTERVAL:
                must_process.add(cohort)
            elif elapsed >= min_interval:
                may_process[cohort] = last_run

        if may_process and not must_process:
            # Nothing is overdue: run the single cohort that has waited the
            # longest, breaking ties by the lowest cohort id.
            oldest = min(may_process, key=lambda c: (may_process[c], c))
            must_process.add(oldest)

        if not must_process:
            # No cohort is due; skip hashing every project id.
            return []

        cohort_updates.values.update(dict.fromkeys(must_process, fetch_time))
        return [
            project_id
            for project_id in all_project_ids
            if self._project_id_to_cohort(project_id) in must_process
        ]
def mark_projects_processed(
    buffer_client: DelayedWorkflowClient,
    processed_project_ids: list[int],
    all_project_ids_and_timestamps: dict[int, list[float]],
) -> None:
    """
    Remove the projects handled in this run from the buffer's project set.

    Args:
        buffer_client: client whose ``mark_project_ids_as_processed`` performs
            the conditional delete and returns the project ids it removed.
        processed_project_ids: project ids that were selected and processed.
        all_project_ids_and_timestamps: every project id fetched for this run,
            mapped to the timestamps found for it. Only entries that are also in
            ``processed_project_ids`` are submitted for deletion.
    """
    if not all_project_ids_and_timestamps:
        return

    # Membership is checked once per fetched project; use a set so the filter
    # stays linear instead of scanning the processed list each time.
    processed_lookup = set(processed_project_ids)

    with metrics.timer("workflow_engine.scheduler.mark_projects_processed"):
        processed_member_maxes = [
            (project_id, max(timestamps))
            for project_id, timestamps in all_project_ids_and_timestamps.items()
            if project_id in processed_lookup
        ]
        deleted_project_ids = set[int]()
        # The conditional delete can be slow, so we break it into chunks that probably
        # aren't big enough to hold onto the main redis thread for too long.
        for chunk in chunked(processed_member_maxes, 500):
            with metrics.timer(
                "workflow_engine.conditional_delete_from_sorted_sets.chunk_duration"
            ):
                deleted = buffer_client.mark_project_ids_as_processed(dict(chunk))
                deleted_project_ids.update(deleted)

        logger.info(
            "process_buffered_workflows.project_ids_deleted",
            extra={
                "deleted_project_ids": sorted(deleted_project_ids),
                "processed_project_ids": sorted(processed_project_ids),
            },
        )
TestDelayedWorkflowClient(TestCase): - def setUp(self) -> None: - self.mock_buffer = Mock() - self.buffer_keys = ["test_key_1", "test_key_2"] - self.workflow_client = DelayedWorkflowClient( - buf=self.mock_buffer, buffer_keys=self.buffer_keys - ) - def test_mark_project_ids_as_processed(self) -> None: +class TestDelayedWorkflowClient: + @pytest.fixture + def mock_buffer(self): + """Create a mock buffer for testing.""" + return Mock(spec=RedisHashSortedSetBuffer) + + @pytest.fixture + def buffer_keys(self): + """Create test buffer keys.""" + return ["test_key_1", "test_key_2"] + + @pytest.fixture + def delayed_workflow_client(self, mock_buffer): + """Create a DelayedWorkflowClient with mocked buffer.""" + return DelayedWorkflowClient(buf=mock_buffer) + + @pytest.fixture + def workflow_client_with_keys(self, mock_buffer, buffer_keys): + """Create a DelayedWorkflowClient with mocked buffer and specific keys.""" + return DelayedWorkflowClient(buf=mock_buffer, buffer_keys=buffer_keys) + + def test_mark_project_ids_as_processed( + self, workflow_client_with_keys, mock_buffer, buffer_keys + ): """Test mark_project_ids_as_processed with mocked RedisHashSortedSetBuffer.""" # Mock the conditional_delete_from_sorted_sets return value # Return value is dict[str, list[int]] where keys are buffer keys and values are deleted project IDs @@ -20,7 +37,7 @@ def test_mark_project_ids_as_processed(self) -> None: "test_key_1": [123, 456], "test_key_2": [789], } - self.mock_buffer.conditional_delete_from_sorted_sets.return_value = mock_return_value + mock_buffer.conditional_delete_from_sorted_sets.return_value = mock_return_value # Input data: project_id -> max_timestamp mapping project_id_max_timestamps = { @@ -30,11 +47,11 @@ def test_mark_project_ids_as_processed(self) -> None: } # Call the method - result = self.workflow_client.mark_project_ids_as_processed(project_id_max_timestamps) + result = workflow_client_with_keys.mark_project_ids_as_processed(project_id_max_timestamps) # 
Verify the mock was called with the correct arguments - self.mock_buffer.conditional_delete_from_sorted_sets.assert_called_once_with( - tuple(self.buffer_keys), # DelayedWorkflowClient stores keys as tuple + mock_buffer.conditional_delete_from_sorted_sets.assert_called_once_with( + tuple(buffer_keys), # DelayedWorkflowClient stores keys as tuple [(123, 1000.5), (456, 2000.0), (789, 1500.75)], ) @@ -42,10 +59,12 @@ def test_mark_project_ids_as_processed(self) -> None: expected_result = [123, 456, 789] assert sorted(result) == sorted(expected_result) - def test_mark_project_ids_as_processed_empty_input(self) -> None: + def test_mark_project_ids_as_processed_empty_input( + self, workflow_client_with_keys, mock_buffer, buffer_keys + ): """Test mark_project_ids_as_processed with empty input.""" # Mock return value for empty input - self.mock_buffer.conditional_delete_from_sorted_sets.return_value = { + mock_buffer.conditional_delete_from_sorted_sets.return_value = { "test_key_1": [], "test_key_2": [], } @@ -54,25 +73,27 @@ def test_mark_project_ids_as_processed_empty_input(self) -> None: project_id_max_timestamps: dict[int, float] = {} # Call the method - result = self.workflow_client.mark_project_ids_as_processed(project_id_max_timestamps) + result = workflow_client_with_keys.mark_project_ids_as_processed(project_id_max_timestamps) # Verify the mock was called with empty member list - self.mock_buffer.conditional_delete_from_sorted_sets.assert_called_once_with( - tuple(self.buffer_keys), + mock_buffer.conditional_delete_from_sorted_sets.assert_called_once_with( + tuple(buffer_keys), [], ) # Result should be empty assert result == [] - def test_mark_project_ids_as_processed_partial_deletion(self) -> None: + def test_mark_project_ids_as_processed_partial_deletion( + self, workflow_client_with_keys, mock_buffer, buffer_keys + ): """Test mark_project_ids_as_processed when only some project IDs are deleted.""" # Mock return value where only some project IDs are actually 
deleted mock_return_value = { "test_key_1": [123], # Only project 123 was deleted from this key "test_key_2": [], # No projects deleted from this key } - self.mock_buffer.conditional_delete_from_sorted_sets.return_value = mock_return_value + mock_buffer.conditional_delete_from_sorted_sets.return_value = mock_return_value # Input with multiple project IDs project_id_max_timestamps = { @@ -81,25 +102,27 @@ def test_mark_project_ids_as_processed_partial_deletion(self) -> None: } # Call the method - result = self.workflow_client.mark_project_ids_as_processed(project_id_max_timestamps) + result = workflow_client_with_keys.mark_project_ids_as_processed(project_id_max_timestamps) # Verify the mock was called with all input project IDs - self.mock_buffer.conditional_delete_from_sorted_sets.assert_called_once_with( - tuple(self.buffer_keys), + mock_buffer.conditional_delete_from_sorted_sets.assert_called_once_with( + tuple(buffer_keys), [(123, 1000.5), (456, 2000.0)], ) # Result should only contain the actually deleted project IDs assert result == [123] - def test_mark_project_ids_as_processed_deduplicates_results(self) -> None: + def test_mark_project_ids_as_processed_deduplicates_results( + self, workflow_client_with_keys, mock_buffer, buffer_keys + ): """Test that mark_project_ids_as_processed deduplicates project IDs from multiple keys.""" # Mock return value where the same project ID appears in multiple keys mock_return_value = { "test_key_1": [123, 456], "test_key_2": [456, 789], # 456 appears in both keys } - self.mock_buffer.conditional_delete_from_sorted_sets.return_value = mock_return_value + mock_buffer.conditional_delete_from_sorted_sets.return_value = mock_return_value # Input data project_id_max_timestamps = { @@ -109,9 +132,173 @@ def test_mark_project_ids_as_processed_deduplicates_results(self) -> None: } # Call the method - result = self.workflow_client.mark_project_ids_as_processed(project_id_max_timestamps) + result = 
workflow_client_with_keys.mark_project_ids_as_processed(project_id_max_timestamps) # Verify the result deduplicates project ID 456 expected_result = [123, 456, 789] assert sorted(result) == sorted(expected_result) assert len(result) == 3 # Should have exactly 3 unique project IDs + + def test_fetch_updates(self, delayed_workflow_client, mock_buffer): + """Test fetching cohort updates from buffer.""" + expected_updates = CohortUpdates(values={1: 100.0}) + mock_buffer.get_parsed_key.return_value = expected_updates + + result = delayed_workflow_client.fetch_updates() + + mock_buffer.get_parsed_key.assert_called_once_with( + "WORKFLOW_ENGINE_COHORT_UPDATES", CohortUpdates + ) + assert result == expected_updates + + def test_persist_updates(self, delayed_workflow_client, mock_buffer): + """Test persisting cohort updates to buffer.""" + updates = CohortUpdates(values={1: 100.0, 2: 200.0}) + + delayed_workflow_client.persist_updates(updates) + + mock_buffer.put_parsed_key.assert_called_once_with( + "WORKFLOW_ENGINE_COHORT_UPDATES", updates + ) + + def test_fetch_updates_missing_key(self, delayed_workflow_client, mock_buffer): + """Test fetching cohort updates when key doesn't exist (returns None).""" + mock_buffer.get_parsed_key.return_value = None + + result = delayed_workflow_client.fetch_updates() + + mock_buffer.get_parsed_key.assert_called_once_with( + "WORKFLOW_ENGINE_COHORT_UPDATES", CohortUpdates + ) + assert isinstance(result, CohortUpdates) + assert result.values == {} # Should be default empty dict + + def test_add_project_ids(self, delayed_workflow_client, mock_buffer): + """Test adding project IDs to a random shard.""" + project_ids = [1, 2, 3] + + delayed_workflow_client.add_project_ids(project_ids) + + # Should call push_to_sorted_set with one of the buffer keys + assert mock_buffer.push_to_sorted_set.call_count == 1 + call_args = mock_buffer.push_to_sorted_set.call_args + assert call_args[1]["value"] == project_ids + # Key should be one of the expected 
buffer keys + called_key = call_args[1]["key"] + expected_keys = DelayedWorkflowClient._get_buffer_keys() + assert called_key in expected_keys + + def test_get_project_ids(self, delayed_workflow_client, mock_buffer): + """Test getting project IDs within score range.""" + expected_result = {1: [100.0], 2: [200.0]} + mock_buffer.bulk_get_sorted_set.return_value = expected_result + + result = delayed_workflow_client.get_project_ids(min=0.0, max=300.0) + + mock_buffer.bulk_get_sorted_set.assert_called_once_with( + tuple(DelayedWorkflowClient._get_buffer_keys()), + min=0.0, + max=300.0, + ) + assert result == expected_result + + def test_clear_project_ids(self, delayed_workflow_client, mock_buffer): + """Test clearing project IDs within score range.""" + delayed_workflow_client.clear_project_ids(min=0.0, max=300.0) + + mock_buffer.delete_keys.assert_called_once_with( + tuple(DelayedWorkflowClient._get_buffer_keys()), + min=0.0, + max=300.0, + ) + + def test_get_buffer_keys(self): + """Test that buffer keys are generated correctly.""" + keys = DelayedWorkflowClient._get_buffer_keys() + + assert len(keys) == 8 # _BUFFER_SHARDS + assert keys[0] == "workflow_engine_delayed_processing_buffer" # shard 0 + assert keys[1] == "workflow_engine_delayed_processing_buffer:1" # shard 1 + assert keys[7] == "workflow_engine_delayed_processing_buffer:7" # shard 7 + + def test_for_project(self, delayed_workflow_client, mock_buffer): + """Test creating a project-specific client.""" + project_id = 123 + + project_client = delayed_workflow_client.for_project(project_id) + + assert project_client.project_id == project_id + assert project_client._buffer == mock_buffer + + +class TestProjectDelayedWorkflowClient: + @pytest.fixture + def mock_buffer(self): + """Create a mock buffer for testing.""" + return Mock(spec=RedisHashSortedSetBuffer) + + @pytest.fixture + def project_client(self, mock_buffer): + """Create a ProjectDelayedWorkflowClient with mocked buffer.""" + return 
DelayedWorkflowClient(buf=mock_buffer).for_project(123) + + def test_filters_without_batch_key(self, project_client): + """Test filters generation without batch key.""" + filters = project_client._filters(batch_key=None) + assert filters == {"project_id": 123} + + def test_filters_with_batch_key(self, project_client): + """Test filters generation with batch key.""" + filters = project_client._filters(batch_key="test-batch") + assert filters == {"project_id": 123, "batch_key": "test-batch"} + + def test_delete_hash_fields(self, project_client, mock_buffer): + """Test deleting specific fields from workflow hash.""" + fields = ["field1", "field2"] + + project_client.delete_hash_fields(batch_key=None, fields=fields) + + from sentry.workflow_engine.models import Workflow + + mock_buffer.delete_hash.assert_called_once_with( + model=Workflow, filters={"project_id": 123}, fields=fields + ) + + def test_get_hash_length(self, project_client, mock_buffer): + """Test getting hash length.""" + mock_buffer.get_hash_length.return_value = 5 + + result = project_client.get_hash_length(batch_key=None) + + from sentry.workflow_engine.models import Workflow + + mock_buffer.get_hash_length.assert_called_once_with( + model=Workflow, filters={"project_id": 123} + ) + assert result == 5 + + def test_get_hash_data(self, project_client, mock_buffer): + """Test fetching hash data.""" + expected_data = {"key1": "value1", "key2": "value2"} + mock_buffer.get_hash.return_value = expected_data + + result = project_client.get_hash_data(batch_key="test-batch") + + from sentry.workflow_engine.models import Workflow + + mock_buffer.get_hash.assert_called_once_with( + model=Workflow, filters={"project_id": 123, "batch_key": "test-batch"} + ) + assert result == expected_data + + def test_push_to_hash(self, project_client, mock_buffer): + """Test pushing data to hash in bulk.""" + data = {"key1": "value1", "key2": "value2"} + + project_client.push_to_hash(batch_key="test-batch", data=data) + + from 
sentry.workflow_engine.models import Workflow + + mock_buffer.push_to_hash_bulk.assert_called_once_with( + model=Workflow, filters={"project_id": 123, "batch_key": "test-batch"}, data=data + ) diff --git a/tests/sentry/workflow_engine/buffer/test_redis_hash_sorted_set_buffer.py b/tests/sentry/workflow_engine/buffer/test_redis_hash_sorted_set_buffer.py index 07afd19fc0d934..5d336b336daefe 100644 --- a/tests/sentry/workflow_engine/buffer/test_redis_hash_sorted_set_buffer.py +++ b/tests/sentry/workflow_engine/buffer/test_redis_hash_sorted_set_buffer.py @@ -65,7 +65,7 @@ def buffer(self, set_sentry_option, request, mock_time_provider): @pytest.fixture(autouse=True) def setup_buffer(self, buffer, mock_time_provider): - self.buf = buffer + self.buf: RedisHashSortedSetBuffer = buffer self.mock_time = mock_time_provider def test_push_to_hash(self): @@ -474,3 +474,55 @@ def test_conditional_delete_rb_host_batching(self): for key in keys: remaining = self.buf.get_sorted_set(key, 0, time.time() + 10) assert len(remaining) == 0 + + def test_get_parsed_key_put_parsed_key(self): + """Test storing and retrieving pydantic models using get_parsed_key/put_parsed_key.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + name: str + value: int + enabled: bool + + # Test putting and getting a parsed model + test_data = TestModel(name="test", value=42, enabled=True) + self.buf.put_parsed_key("test_key", test_data) + + retrieved_data = self.buf.get_parsed_key("test_key", TestModel) + + assert retrieved_data is not None + assert retrieved_data.name == "test" + assert retrieved_data.value == 42 + assert retrieved_data.enabled is True + assert isinstance(retrieved_data, TestModel) + + def test_get_parsed_key_put_parsed_key_complex_model(self): + """Test with more complex pydantic model containing nested data.""" + from pydantic import BaseModel + + class NestedModel(BaseModel): + items: list[int] + metadata: dict[str, str] + + test_data = NestedModel( + items=[1, 2, 3, 4, 
5], metadata={"source": "test", "version": "1.0"} + ) + + self.buf.put_parsed_key("complex_key", test_data) + retrieved_data = self.buf.get_parsed_key("complex_key", NestedModel) + + assert retrieved_data is not None + assert retrieved_data.items == [1, 2, 3, 4, 5] + assert retrieved_data.metadata == {"source": "test", "version": "1.0"} + assert isinstance(retrieved_data, NestedModel) + + def test_get_parsed_key_missing_key(self): + """Test get_parsed_key returns None for missing key.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + name: str + + # Try to get a key that doesn't exist + retrieved_data = self.buf.get_parsed_key("nonexistent_key", TestModel) + assert retrieved_data is None diff --git a/tests/sentry/workflow_engine/processors/test_schedule.py b/tests/sentry/workflow_engine/processors/test_schedule.py index a9adf48ef17db2..9d98e53f66ac85 100644 --- a/tests/sentry/workflow_engine/processors/test_schedule.py +++ b/tests/sentry/workflow_engine/processors/test_schedule.py @@ -1,14 +1,20 @@ from datetime import datetime -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 +import pytest + from sentry.testutils.cases import TestCase from sentry.testutils.helpers.datetime import before_now, freeze_time from sentry.testutils.helpers.options import override_options from sentry.utils import json -from sentry.workflow_engine.buffer.batch_client import DelayedWorkflowClient +from sentry.workflow_engine.buffer.batch_client import CohortUpdates, DelayedWorkflowClient +from sentry.workflow_engine.buffer.redis_hash_sorted_set_buffer import RedisHashSortedSetBuffer from sentry.workflow_engine.processors.schedule import ( + ProjectChooser, bucket_num_groups, + chosen_projects, + mark_projects_processed, process_buffered_workflows, process_in_batches, ) @@ -88,6 +94,42 @@ def test_skips_processing_with_option(self, mock_process_in_batches) -> None: # Should still contain our project 
assert project.id in all_project_ids + @override_options( + {"delayed_workflow.rollout": True, "workflow_engine.use_cohort_selection": False} + ) + @patch("sentry.workflow_engine.processors.schedule.process_in_batches") + def test_processes_all_projects_without_cohort_selection( + self, mock_process_in_batches: MagicMock + ) -> None: + """Test that all projects are processed when cohort selection is disabled.""" + project = self.create_project() + project_two = self.create_project() + group = self.create_group(project) + group_two = self.create_group(project_two) + + # Push data to buffer + self.batch_client.for_project(project.id).push_to_hash( + batch_key=None, + data={f"345:{group.id}": json.dumps({"event_id": "event-1"})}, + ) + self.batch_client.for_project(project_two.id).push_to_hash( + batch_key=None, + data={f"345:{group_two.id}": json.dumps({"event_id": "event-2"})}, + ) + + # Add projects to sorted set + self.batch_client.add_project_ids([project.id, project_two.id]) + + process_buffered_workflows(self.batch_client) + + # All projects should be processed (no cohort filtering) + assert mock_process_in_batches.call_count == 2 + + # Verify that the buffer keys are cleaned up + fetch_time = datetime.now().timestamp() + all_project_ids = self.batch_client.get_project_ids(min=0, max=fetch_time) + assert all_project_ids == {} + class ProcessInBatchesTest(CreateEventTestCase): def setUp(self) -> None: @@ -174,3 +216,333 @@ def test_get_hash_data_with_batch_key(self) -> None: assert f"{self.rule.id}:{self.group.id}" in result data = json.loads(result[f"{self.rule.id}:{self.group.id}"]) assert data["event_id"] == "event-456" + + +class TestProjectChooser: + @pytest.fixture + def mock_buffer(self): + mock_buffer = Mock(spec=DelayedWorkflowClient) + return mock_buffer + + @pytest.fixture + def project_chooser(self, mock_buffer): + return ProjectChooser(mock_buffer, num_cohorts=6) + + def _find_projects_for_cohorts(self, chooser: ProjectChooser, num_cohorts: int) -> 
list[int]: + """Helper method to find project IDs that map to each cohort to ensure even distribution.""" + all_project_ids = [] + used_cohorts: set[int] = set() + project_id = 1 + while len(used_cohorts) < num_cohorts: + cohort = chooser._project_id_to_cohort(project_id) + if cohort not in used_cohorts: + all_project_ids.append(project_id) + used_cohorts.add(cohort) + project_id += 1 + return all_project_ids + + def test_project_id_to_cohort_distribution(self, project_chooser): + project_ids = list(range(1, 1001)) # 1000 project IDs + cohorts = [project_chooser._project_id_to_cohort(pid) for pid in project_ids] + + # Check all cohorts are used + assert set(cohorts) == set(range(6)) + + # Check distribution is reasonably even (each cohort gets some projects) + cohort_counts = [cohorts.count(i) for i in range(6)] + assert all(count > 0 for count in cohort_counts) + assert all(count < 1000 for count in cohort_counts) + + def test_project_id_to_cohort_consistent(self, project_chooser): + for project_id in [123, 999, 4, 193848493]: + cohort1 = project_chooser._project_id_to_cohort(project_id) + cohort2 = project_chooser._project_id_to_cohort(project_id) + cohort3 = project_chooser._project_id_to_cohort(project_id) + + assert cohort1 == cohort2 == cohort3 + assert 0 <= cohort1 < 6 + + def test_project_ids_to_process_must_process_over_minute(self, project_chooser): + fetch_time = 1000.0 + cohort_updates = CohortUpdates( + values={ + 0: 900.0, # 100 seconds ago - must process + 1: 950.0, # 50 seconds ago - may process + 2: 990.0, # 10 seconds ago - no process + } + ) + all_project_ids = [10, 11, 12, 13, 14, 15] # Projects mapping to cohorts 0-5 + + result = project_chooser.project_ids_to_process(fetch_time, cohort_updates, all_project_ids) + + # Should include projects from cohort 0 (over 1 minute old) + expected_cohort = project_chooser._project_id_to_cohort(10) + if expected_cohort == 0: + assert 10 in result + + # cohort_updates should be updated with fetch_time for 
processed cohorts + assert 0 in cohort_updates.values + + def test_project_ids_to_process_may_process_fallback(self, project_chooser): + fetch_time = 1000.0 + cohort_updates = CohortUpdates( + values={ + 0: 995.0, # 5 seconds ago - may process (older) + 1: 998.0, # 2 seconds ago - may process (newer) + 2: 999.0, # 1 second ago - no process + } + ) + all_project_ids = [10, 11, 12] + + result = project_chooser.project_ids_to_process(fetch_time, cohort_updates, all_project_ids) + + # Should choose the oldest from may_process cohorts (cohort 0) + # and update cohort_updates accordingly + assert len(result) > 0 # Should process something + processed_cohorts = {project_chooser._project_id_to_cohort(pid) for pid in result} + + # The processed cohorts should be updated in cohort_updates + for cohort in processed_cohorts: + assert cohort_updates.values[cohort] == fetch_time + + def test_project_ids_to_process_no_processing_needed(self, project_chooser): + fetch_time = 1000.0 + cohort_updates = CohortUpdates( + values={ + 0: 999.0, # 1 second ago + 1: 998.0, # 2 seconds ago + 2: 997.0, # 3 seconds ago + 3: 996.0, # 4 seconds ago + 4: 995.0, # 5 seconds ago + 5: 994.0, # 6 seconds ago + } + ) + all_project_ids = [10, 11, 12, 13, 14, 15] + + result = project_chooser.project_ids_to_process(fetch_time, cohort_updates, all_project_ids) + + # No cohorts are old enough for must_process or may_process + assert len(result) == 0 + + def test_scenario_once_per_minute_6_cohorts(self, project_chooser: ProjectChooser) -> None: + """ + Scenario test: Running once per minute with 6 cohorts. + Since run interval (60s) equals must_process threshold (60s), + all cohorts should be processed on every single run. 
+ """ + all_project_ids = self._find_projects_for_cohorts(project_chooser, 6) + + cohort_updates = CohortUpdates(values={}) + + # Simulate 5 minutes of processing (5 runs, once per minute) + for minute in range(5): + fetch_time = float(minute * 60) # Every 60 seconds + + processed_projects = project_chooser.project_ids_to_process( + fetch_time, cohort_updates, all_project_ids + ) + processed_cohorts = { + project_chooser._project_id_to_cohort(pid) for pid in processed_projects + } + + # Every run should process all 6 cohorts. + assert processed_cohorts == { + 0, + 1, + 2, + 3, + 4, + 5, + }, f"Run {minute} didn't process all cohorts: {processed_cohorts}" + + def test_scenario_six_times_per_minute(self, project_chooser: ProjectChooser) -> None: + """ + Scenario test: Running 6 times per minute (every 10 seconds). + Should process exactly one cohort per run in stable cycle, cycling through all. + """ + all_project_ids = self._find_projects_for_cohorts(project_chooser, 6) + + cohort_updates = CohortUpdates(values={}) + + all_cohorts = set(range(6)) + processed_cohorts_over_time = [] + + # Simulate 2 minutes of processing (12 runs, every 10 seconds) + previous_cohorts = [] + for run in range(12): + fetch_time = float(run * 10) # Every 10 seconds + + processed_projects = project_chooser.project_ids_to_process( + fetch_time, cohort_updates, all_project_ids + ) + processed_cohorts = { + project_chooser._project_id_to_cohort(pid) for pid in processed_projects + } + if run == 0: + assert ( + processed_cohorts == all_cohorts + ), f"First run should process all cohorts, got {processed_cohorts}" + previous_cohorts.append(processed_cohorts) + if len(previous_cohorts) > 6: + previous_cohorts.pop(0) + elif len(previous_cohorts) == 6: + processed_in_last_cycle = set().union(*previous_cohorts) + assert ( + processed_in_last_cycle == all_cohorts + ), f"Run {run} should process all cohorts, got {processed_in_last_cycle}" + processed_cohorts_over_time.append(processed_cohorts) + + def 
def test_scenario_once_per_minute_cohort_count_1(self, mock_buffer) -> None:
    """
    Scenario test: one run per minute with a single cohort.

    With ``num_cohorts=1`` every project id must map to cohort 0, and each
    of the five simulated runs must return the complete project list.
    """
    single_cohort_chooser = ProjectChooser(mock_buffer, num_cohorts=1)

    # Start from the helper's projects, then pad with arbitrary extra ids so
    # the "everything lands in cohort 0" claim is checked on more samples.
    all_project_ids = self._find_projects_for_cohorts(single_cohort_chooser, 1)
    all_project_ids.extend([10, 25, 50, 100, 999, 1001, 5000])

    # Precondition: the chooser assigns every id to cohort 0.
    for project_id in all_project_ids:
        cohort = single_cohort_chooser._project_id_to_cohort(project_id)
        assert cohort == 0, f"Project {project_id} should map to cohort 0, got {cohort}"

    updates = CohortUpdates(values={})

    # Five runs, spaced 60 seconds apart.
    for minute in range(5):
        processed_projects = single_cohort_chooser.project_ids_to_process(
            60.0 * minute, updates, all_project_ids
        )
        processed_cohorts = {
            single_cohort_chooser._project_id_to_cohort(pid) for pid in processed_projects
        }

        # Only cohort 0 exists, so it is the only one that can show up.
        assert processed_cohorts == {
            0
        }, f"Run {minute} should process cohort 0, got {processed_cohorts}"

        # Processing cohort 0 must therefore mean processing every project.
        assert set(processed_projects) == set(all_project_ids), (
            f"Run {minute}: Expected all {len(all_project_ids)} projects to be processed, "
            f"but got {len(processed_projects)}: {sorted(processed_projects)}"
        )
class TestChosenProjects:
    """Tests for the ``chosen_projects`` context manager."""

    @pytest.fixture
    def mock_project_chooser(self):
        """Provide a ``Mock`` restricted to the ``ProjectChooser`` interface."""
        return Mock(spec=ProjectChooser)

    def test_chosen_projects_context_manager(self, mock_project_chooser):
        """Happy path: fetch updates, yield the chosen ids, persist the updates."""
        updates = Mock(spec=CohortUpdates)
        client = Mock(spec=DelayedWorkflowClient)
        client.fetch_updates.return_value = updates
        mock_project_chooser.client = client

        expected_result = [1, 2, 3]
        mock_project_chooser.project_ids_to_process.return_value = expected_result

        fetch_time = 1000.0
        all_project_ids = [1, 2, 3, 4, 5]

        with chosen_projects(mock_project_chooser, fetch_time, all_project_ids) as yielded:
            pass

        # The context manager yields exactly what the chooser returned.
        assert yielded == expected_result

        # Updates were read once through the chooser's client and handed to
        # the chooser together with the original arguments.
        client.fetch_updates.assert_called_once()
        mock_project_chooser.project_ids_to_process.assert_called_once_with(
            fetch_time, updates, all_project_ids
        )

        # By the time the context has exited, the same updates object has
        # been persisted exactly once.
        client.persist_updates.assert_called_once_with(updates)

    def test_chosen_projects_fetch_updates_exception(self, mock_project_chooser):
        """An error raised by ``fetch_updates`` propagates and nothing is persisted."""
        client = Mock(spec=DelayedWorkflowClient)
        # Simulate the client failing while reading the stored cohort updates.
        client.fetch_updates.side_effect = Exception("Key not found")
        mock_project_chooser.client = client

        with (
            pytest.raises(Exception, match="Key not found"),
            chosen_projects(mock_project_chooser, 1000.0, [1, 2, 3, 4, 5]),
        ):
            pass

        client.persist_updates.assert_not_called()

    def test_chosen_projects_exception_during_processing(self, mock_project_chooser):
        """An error raised inside the ``with`` body propagates and skips persistence."""
        client = Mock(spec=DelayedWorkflowClient)
        client.fetch_updates.return_value = Mock(spec=CohortUpdates)
        mock_project_chooser.client = client
        mock_project_chooser.project_ids_to_process.return_value = [1, 2, 3]

        with (
            pytest.raises(RuntimeError, match="Processing failed"),
            chosen_projects(mock_project_chooser, 1000.0, [1, 2, 3, 4, 5]),
        ):
            raise RuntimeError("Processing failed")

        client.persist_updates.assert_not_called()

    def test_chosen_projects_without_cohort_selection(self):
        """With no chooser (``None``), the full project list is yielded as-is."""
        every_project_id = [1, 2, 3, 4, 5]

        # No client is involved here, so there is nothing to fetch or persist.
        with chosen_projects(None, 1000.0, every_project_id) as result:
            assert result == every_project_id
@override_options({"workflow_engine.scheduler.use_conditional_delete": True})
def test_mark_projects_processed_only_cleans_up_processed_projects() -> None:
    """Marking one project as processed must leave the other project in the buffer."""
    processed_project_id, unprocessed_project_id = 5000, 5001

    all_project_ids_and_timestamps = {
        processed_project_id: [1000.0],
        unprocessed_project_id: [2000.0],
    }

    # Controllable clock handed to the buffer as `now_fn`, so each project
    # can be added at its own timestamp.
    clock = {"now": 1000.0}
    client = DelayedWorkflowClient(RedisHashSortedSetBuffer(now_fn=lambda: clock["now"]))

    # Register both projects, moving the fake clock before each insert.
    for project_id, (timestamp,) in all_project_ids_and_timestamps.items():
        clock["now"] = timestamp
        client.add_project_ids([project_id])

    # Report only one of the two projects as processed.
    mark_projects_processed(
        client,
        [processed_project_id],
        all_project_ids_and_timestamps,
    )

    # The range 0..3000 covers both insert timestamps, so a project missing
    # from the result was removed rather than filtered out by the query.
    remaining_project_ids = client.get_project_ids(min=0, max=3000.0)
    assert unprocessed_project_id in remaining_project_ids
    assert processed_project_id not in remaining_project_ids