diff --git a/src/sentry/workflow_engine/processors/action.py b/src/sentry/workflow_engine/processors/action.py
index 3580d600831948..c444f0df4e07c7 100644
--- a/src/sentry/workflow_engine/processors/action.py
+++ b/src/sentry/workflow_engine/processors/action.py
@@ -2,7 +2,7 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 
-from django.db import models
+from django.db import connection, models
 from django.db.models import Case, Value, When
 from django.utils import timezone
 
@@ -36,6 +36,9 @@
 logger = logging.getLogger(__name__)
 
 EnqueuedAction = tuple[DataConditionGroup, list[DataCondition]]
+UpdatedStatuses = int
+CreatedStatuses = int
+ConflictedStatuses = list[tuple[int, int]]  # (workflow_id, action_id)
 
 
 def get_workflow_action_group_statuses(
@@ -71,13 +74,13 @@ def process_workflow_action_group_statuses(
     workflows: BaseQuerySet[Workflow],
     group: Group,
     now: datetime,
-) -> tuple[dict[int, int], set[int], list[WorkflowActionGroupStatus]]:
+) -> tuple[dict[int, set[int]], set[int], list[WorkflowActionGroupStatus]]:
     """
     Determine which workflow actions should be fired based on their statuses.
 
     Prepare the statuses to update and create.
     """
-    action_to_workflow_ids: dict[int, int] = {}  # will dedupe because there can be only 1
+    updated_action_to_workflows_ids: dict[int, set[int]] = defaultdict(set)
     workflow_frequencies: dict[int, timedelta] = {
         workflow.id: workflow.config.get("frequency", 0) * timedelta(minutes=1)
         for workflow in workflows
@@ -91,7 +94,7 @@
             status.workflow_id, zero_timedelta
         ):
             # we should fire the workflow for this action
-            action_to_workflow_ids[action_id] = status.workflow_id
+            updated_action_to_workflows_ids[action_id].add(status.workflow_id)
             statuses_to_update.add(status.id)
 
     missing_statuses: list[WorkflowActionGroupStatus] = []
@@ -107,31 +110,51 @@
                     workflow_id=workflow_id, action_id=action_id, group=group, date_updated=now
                 )
             )
-            action_to_workflow_ids[action_id] = workflow_id
+            updated_action_to_workflows_ids[action_id].add(workflow_id)
 
-    return action_to_workflow_ids, statuses_to_update, missing_statuses
+    return updated_action_to_workflows_ids, statuses_to_update, missing_statuses
 
 
 def update_workflow_action_group_statuses(
     now: datetime, statuses_to_update: set[int], missing_statuses: list[WorkflowActionGroupStatus]
-) -> None:
-    WorkflowActionGroupStatus.objects.filter(
+) -> tuple[UpdatedStatuses, CreatedStatuses, ConflictedStatuses]:
+    updated_count = WorkflowActionGroupStatus.objects.filter(
         id__in=statuses_to_update, date_updated__lt=now
     ).update(date_updated=now)
 
-    all_statuses = WorkflowActionGroupStatus.objects.bulk_create(
-        missing_statuses,
-        batch_size=1000,
-        ignore_conflicts=True,
-    )
-    missing_status_pairs = [
-        (status.workflow_id, status.action_id) for status in all_statuses if status.id is None
+    if not missing_statuses:
+        return updated_count, 0, []
+
+    # Use raw SQL: RETURNING only reports successfully created rows
+    # XXX: the query does not currently include a batch size limit like bulk_create does
+    with connection.cursor() as cursor:
+        # Build values for batch insert
+        values_placeholders = []
+        values_data = []
+        for s in missing_statuses:
+            values_placeholders.append("(%s, %s, %s, %s, %s)")
+            values_data.extend([s.workflow_id, s.action_id, s.group_id, now, now])
+
+        sql = f"""
+            INSERT INTO workflow_engine_workflowactiongroupstatus
+                (workflow_id, action_id, group_id, date_added, date_updated)
+            VALUES {', '.join(values_placeholders)}
+            ON CONFLICT (workflow_id, action_id, group_id) DO NOTHING
+            RETURNING workflow_id, action_id
+        """
+
+        cursor.execute(sql, values_data)
+        created_rows = set(cursor.fetchall())  # only contains newly inserted rows
+
+    # Figure out which ones conflicted (weren't returned)
+    conflicted_statuses = [
+        (s.workflow_id, s.action_id)
+        for s in missing_statuses
+        if (s.workflow_id, s.action_id) not in created_rows
     ]
-    if missing_status_pairs:
-        logger.warning(
-            "Failed to create WorkflowActionGroupStatus objects",
-            extra={"missing_status_pairs": missing_status_pairs},
-        )
+
+    created_count = len(created_rows)
+    return updated_count, created_count, conflicted_statuses
 
 
 def get_unique_active_actions(
@@ -190,7 +213,7 @@ def filter_recently_fired_workflow_actions(
         workflow_ids=workflow_ids,
     )
     now = timezone.now()
-    action_to_workflow_ids, statuses_to_update, missing_statuses = (
+    action_to_workflows_ids, statuses_to_update, missing_statuses = (
        process_workflow_action_group_statuses(
            action_to_workflows_ids=action_to_workflows_ids,
            action_to_statuses=action_to_statuses,
@@ -199,14 +222,24 @@
             now=now,
         )
     )
-    update_workflow_action_group_statuses(now, statuses_to_update, missing_statuses)
+    _, _, conflicted_statuses = update_workflow_action_group_statuses(
+        now, statuses_to_update, missing_statuses
+    )
+
+    # if statuses were not created for some reason, we should not fire for them
+    for workflow_id, action_id in conflicted_statuses:
+        action_to_workflows_ids[action_id].remove(workflow_id)
+        if not action_to_workflows_ids[action_id]:
+            action_to_workflows_ids.pop(action_id)
 
-    actions_queryset = Action.objects.filter(id__in=list(action_to_workflow_ids.keys()))
+    actions_queryset = Action.objects.filter(id__in=list(action_to_workflows_ids.keys()))
 
     # annotate actions with workflow_id they are firing for (deduped)
     workflow_id_cases = [
-        When(id=action_id, then=Value(workflow_id))
-        for action_id, workflow_id in action_to_workflow_ids.items()
+        When(
+            id=action_id, then=Value(min(workflow_ids))
+        )  # select 1 workflow to fire for; the choice is arbitrary but deterministic
+        for action_id, workflow_ids in action_to_workflows_ids.items()
     ]
 
     return actions_queryset.annotate(
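The XXX comment above is worth taking seriously: unlike the removed bulk_create(batch_size=1000, ...), the raw INSERT is sent as a single statement whose size grows with len(missing_statuses). A minimal sketch of how the same ON CONFLICT ... DO NOTHING ... RETURNING query could be chunked to restore that cap; the helper name and the 1000-row default are illustrative, not part of this patch:

    from django.db import connection

    def _insert_statuses_returning_created(missing_statuses, now, batch_size=1000):
        # Hypothetical helper: issues the patch's upsert in batches so a huge
        # missing_statuses list never becomes one enormous statement.
        created_rows: set[tuple[int, int]] = set()
        with connection.cursor() as cursor:
            for start in range(0, len(missing_statuses), batch_size):
                batch = missing_statuses[start : start + batch_size]
                placeholders = ", ".join(["(%s, %s, %s, %s, %s)"] * len(batch))
                params: list[object] = []
                for s in batch:
                    params.extend([s.workflow_id, s.action_id, s.group_id, now, now])
                cursor.execute(
                    f"""
                    INSERT INTO workflow_engine_workflowactiongroupstatus
                        (workflow_id, action_id, group_id, date_added, date_updated)
                    VALUES {placeholders}
                    ON CONFLICT (workflow_id, action_id, group_id) DO NOTHING
                    RETURNING workflow_id, action_id
                    """,
                    params,
                )
                # RETURNING means fetchall() only yields rows this batch inserted
                created_rows.update(cursor.fetchall())
        return created_rows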
diff --git a/tests/sentry/workflow_engine/processors/test_action.py b/tests/sentry/workflow_engine/processors/test_action.py
index acef8e820ed4c6..29891311d9e6e0 100644
--- a/tests/sentry/workflow_engine/processors/test_action.py
+++ b/tests/sentry/workflow_engine/processors/test_action.py
@@ -107,7 +107,7 @@ def test_multiple_workflows_single_action__first_fire(self) -> None:
         # dedupes action if both workflows will fire it
         assert set(triggered_actions) == {self.action}
         # Dedupes action so we have a single workflow_id -> environment to fire with
-        assert {getattr(action, "workflow_id") for action in triggered_actions} == {workflow.id}
+        assert getattr(triggered_actions[0], "workflow_id") == self.workflow.id
 
         assert WorkflowActionGroupStatus.objects.filter(action=self.action).count() == 2
 
@@ -191,8 +191,8 @@ def test_process_workflow_action_group_statuses(self) -> None:
         )
 
         assert action_to_workflow_ids == {
-            self.action.id: self.workflow.id,
-            action.id: workflow.id,
+            self.action.id: {self.workflow.id},
+            action.id: {workflow.id},
         }
         assert statuses_to_update == {status_2.id}
 
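The conflict handling that the next hunk's new tests exercise reduces to two dictionary steps inside filter_recently_fired_workflow_actions: drop every (workflow_id, action_id) pair whose status row lost the insert race, then pick one workflow per action with min(). A standalone sketch of those semantics, with made-up ids:

    # action 10 may fire for workflows {1, 2}; action 11 only for workflow 2
    action_to_workflows_ids = {10: {1, 2}, 11: {2}}
    conflicted_statuses = [(1, 10), (2, 11)]  # (workflow_id, action_id) pairs

    for workflow_id, action_id in conflicted_statuses:
        action_to_workflows_ids[action_id].remove(workflow_id)
        if not action_to_workflows_ids[action_id]:
            # nothing is left to fire this action for, so drop the action entirely
            action_to_workflows_ids.pop(action_id)

    # one workflow per action; min() is arbitrary but deterministic
    firing = {action_id: min(ids) for action_id, ids in action_to_workflows_ids.items()}
    assert firing == {10: 2}  # action 11 was pruned; action 10 fires for workflow 2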
@@ -222,6 +222,54 @@ def test_update_workflow_action_group_statuses(self) -> None:
         for status in all_statuses:
             assert status.date_updated == timezone.now()
 
+    def test_returns_uncreated_statuses(self) -> None:
+        WorkflowActionGroupStatus.objects.create(
+            workflow=self.workflow, action=self.action, group=self.group
+        )
+
+        statuses_to_create = [
+            WorkflowActionGroupStatus(
+                workflow=self.workflow,
+                action=self.action,
+                group=self.group,
+                date_updated=timezone.now(),
+            )
+        ]
+        _, _, uncreated_statuses = update_workflow_action_group_statuses(
+            timezone.now(), set(), statuses_to_create
+        )
+
+        assert uncreated_statuses == [(self.workflow.id, self.action.id)]
+
+    @patch("sentry.workflow_engine.processors.action.update_workflow_action_group_statuses")
+    def test_does_not_fire_for_uncreated_statuses(self, mock_update: MagicMock) -> None:
+        mock_update.return_value = (0, 0, [(self.workflow.id, self.action.id)])
+
+        triggered_actions = filter_recently_fired_workflow_actions(
+            set(DataConditionGroup.objects.all()), self.event_data
+        )
+
+        assert set(triggered_actions) == set()
+
+    @patch("sentry.workflow_engine.processors.action.update_workflow_action_group_statuses")
+    def test_fires_for_non_conflicting_workflow(self, mock_update: MagicMock) -> None:
+        workflow = self.create_workflow(organization=self.organization, config={"frequency": 1440})
+        action_group = self.create_data_condition_group(logic_type="any-short")
+        self.create_data_condition_group_action(
+            condition_group=action_group,
+            action=self.action,
+        )  # shared action
+        self.create_workflow_data_condition_group(workflow, action_group)
+
+        mock_update.return_value = (0, 0, [(self.workflow.id, self.action.id)])
+
+        triggered_actions = filter_recently_fired_workflow_actions(
+            set(DataConditionGroup.objects.all()), self.event_data
+        )
+
+        assert set(triggered_actions) == {self.action}
+        assert getattr(triggered_actions[0], "workflow_id") == workflow.id
+
 
 class TestIsActionPermitted(BaseWorkflowTest):
     @patch("sentry.workflow_engine.processors.action._get_integration_features")
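One footnote on the processor change: the filter_recently_fired_workflow_actions hunk is truncated mid-statement at "return actions_queryset.annotate(". As a reading aid, a sketch of how a Case/When annotation like this one is typically completed; the wrapper function name and the IntegerField output_field are assumptions, not taken from the patch:

    from django.db.models import Case, IntegerField

    def annotate_with_workflow_id(actions_queryset, workflow_id_cases):
        # Hypothetical completion: Case(*cases) resolves to the single workflow id
        # chosen for each matching action; output_field is assumed, not confirmed.
        return actions_queryset.annotate(
            workflow_id=Case(*workflow_id_cases, output_field=IntegerField())
        )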