Merged
83 changes: 58 additions & 25 deletions src/sentry/workflow_engine/processors/action.py
@@ -2,7 +2,7 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 
-from django.db import models
+from django.db import connection, models
 from django.db.models import Case, Value, When
 from django.utils import timezone
@@ -36,6 +36,9 @@
 logger = logging.getLogger(__name__)
 
 EnqueuedAction = tuple[DataConditionGroup, list[DataCondition]]
+UpdatedStatuses = int
+CreatedStatuses = int
+ConflictedStatuses = list[tuple[int, int]]  # (workflow_id, action_id)
 
 
 def get_workflow_action_group_statuses(
@@ -71,13 +74,13 @@ def process_workflow_action_group_statuses(
     workflows: BaseQuerySet[Workflow],
     group: Group,
     now: datetime,
-) -> tuple[dict[int, int], set[int], list[WorkflowActionGroupStatus]]:
+) -> tuple[dict[int, set[int]], set[int], list[WorkflowActionGroupStatus]]:
     """
     Determine which workflow actions should be fired based on their statuses.
     Prepare the statuses to update and create.
     """
 
-    action_to_workflow_ids: dict[int, int] = {}  # will dedupe because there can be only 1
+    updated_action_to_workflows_ids: dict[int, set[int]] = defaultdict(set)
     workflow_frequencies: dict[int, timedelta] = {
         workflow.id: workflow.config.get("frequency", 0) * timedelta(minutes=1)
         for workflow in workflows
@@ -91,7 +94,7 @@
                 status.workflow_id, zero_timedelta
             ):
                 # we should fire the workflow for this action
-                action_to_workflow_ids[action_id] = status.workflow_id
+                updated_action_to_workflows_ids[action_id].add(status.workflow_id)
                 statuses_to_update.add(status.id)
 
     missing_statuses: list[WorkflowActionGroupStatus] = []
@@ -107,31 +110,51 @@
                 workflow_id=workflow_id, action_id=action_id, group=group, date_updated=now
             )
         )
-        action_to_workflow_ids[action_id] = workflow_id
+        updated_action_to_workflows_ids[action_id].add(workflow_id)
 
-    return action_to_workflow_ids, statuses_to_update, missing_statuses
+    return updated_action_to_workflows_ids, statuses_to_update, missing_statuses
 
 
 def update_workflow_action_group_statuses(
     now: datetime, statuses_to_update: set[int], missing_statuses: list[WorkflowActionGroupStatus]
-) -> None:
-    WorkflowActionGroupStatus.objects.filter(
+) -> tuple[UpdatedStatuses, CreatedStatuses, ConflictedStatuses]:
+    updated_count = WorkflowActionGroupStatus.objects.filter(
         id__in=statuses_to_update, date_updated__lt=now
     ).update(date_updated=now)
 
-    all_statuses = WorkflowActionGroupStatus.objects.bulk_create(
-        missing_statuses,
-        batch_size=1000,
-        ignore_conflicts=True,
-    )
Comment on lines -122 to -126 (Member, PR author):

Unfortunately, the best way to find the rows that were not created was to use a SQL query. Django's bulk_create doesn't have a nice way to populate PKs in the returned list only for the successful creates that didn't conflict.
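For context, a minimal sketch of the limitation described here (per Django's documented bulk_create semantics; worth verifying against the Django version in use):

    # With ignore_conflicts=True, Django skips RETURNING entirely, so primary
    # keys are not set on any of the returned instances, conflicting or not.
    created = WorkflowActionGroupStatus.objects.bulk_create(
        missing_statuses,
        batch_size=1000,
        ignore_conflicts=True,
    )
    # Every element of `created` has .id == None here, so a check like
    # `status.id is None` cannot distinguish conflicts from successful inserts.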

-    missing_status_pairs = [
-        (status.workflow_id, status.action_id) for status in all_statuses if status.id is None
-    ]
+    if not missing_statuses:
+        return updated_count, 0, []
+
+    # Use raw SQL: only returns successfully created rows
+    # XXX: the query does not currently include batch size limit like bulk_create does
+    with connection.cursor() as cursor:
+        # Build values for batch insert
+        values_placeholders = []
+        values_data = []
Comment (Member):

bulk_create has a batch size because there are statement size limits. I don't know that we expect to hit them (I assume our expected upper end is hundreds here), but it seems worth noting.
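A sketch of how the raw-SQL path could mirror bulk_create's batching if statement size ever becomes a concern (not part of this PR; BATCH_SIZE and insert_status_batches are hypothetical names):

    BATCH_SIZE = 1000  # hypothetical cap, mirroring the old bulk_create batch_size

    def insert_status_batches(cursor, missing_statuses, now):
        created_rows = set()
        for i in range(0, len(missing_statuses), BATCH_SIZE):
            batch = missing_statuses[i : i + BATCH_SIZE]
            placeholders = ", ".join(["(%s, %s, %s, %s, %s)"] * len(batch))
            params = []
            for s in batch:
                params.extend([s.workflow_id, s.action_id, s.group_id, now, now])
            cursor.execute(
                f"""
                INSERT INTO workflow_engine_workflowactiongroupstatus
                    (workflow_id, action_id, group_id, date_added, date_updated)
                VALUES {placeholders}
                ON CONFLICT (workflow_id, action_id, group_id) DO NOTHING
                RETURNING workflow_id, action_id
                """,
                params,
            )
            # Accumulate only the rows actually inserted in this batch
            created_rows.update(cursor.fetchall())
        return created_rows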

+        for s in missing_statuses:
+            values_placeholders.append("(%s, %s, %s, %s, %s)")
+            values_data.extend([s.workflow_id, s.action_id, s.group_id, now, now])
+
+        sql = f"""
+            INSERT INTO workflow_engine_workflowactiongroupstatus
+                (workflow_id, action_id, group_id, date_added, date_updated)
+            VALUES {', '.join(values_placeholders)}
+            ON CONFLICT (workflow_id, action_id, group_id) DO NOTHING
+            RETURNING workflow_id, action_id
+        """
+
+        cursor.execute(sql, values_data)
+        created_rows = set(cursor.fetchall())  # Only returns newly inserted rows
+
+    # Figure out which ones conflicted (weren't returned)
+    conflicted_statuses = [
+        (s.workflow_id, s.action_id)
+        for s in missing_statuses
+        if (s.workflow_id, s.action_id) not in created_rows
+    ]
-    if missing_status_pairs:
-        logger.warning(
-            "Failed to create WorkflowActionGroupStatus objects",
-            extra={"missing_status_pairs": missing_status_pairs},
-        )
+
+    created_count = len(created_rows)
+    return updated_count, created_count, conflicted_statuses
 
 
 def get_unique_active_actions(
@@ -190,7 +213,7 @@ def filter_recently_fired_workflow_actions(
         workflow_ids=workflow_ids,
     )
     now = timezone.now()
-    action_to_workflow_ids, statuses_to_update, missing_statuses = (
+    action_to_workflows_ids, statuses_to_update, missing_statuses = (
         process_workflow_action_group_statuses(
             action_to_workflows_ids=action_to_workflows_ids,
             action_to_statuses=action_to_statuses,
@@ -199,14 +222,24 @@
             now=now,
         )
     )
-    update_workflow_action_group_statuses(now, statuses_to_update, missing_statuses)
+    _, _, conflicted_statuses = update_workflow_action_group_statuses(
+        now, statuses_to_update, missing_statuses
+    )
+
+    # if statuses were not created for some reason, we should not fire for them
+    for workflow_id, action_id in conflicted_statuses:
+        action_to_workflows_ids[action_id].remove(workflow_id)
+        if not action_to_workflows_ids[action_id]:
+            action_to_workflows_ids.pop(action_id)
 
-    actions_queryset = Action.objects.filter(id__in=list(action_to_workflow_ids.keys()))
+    actions_queryset = Action.objects.filter(id__in=list(action_to_workflows_ids.keys()))
 
     # annotate actions with workflow_id they are firing for (deduped)
     workflow_id_cases = [
-        When(id=action_id, then=Value(workflow_id))
-        for action_id, workflow_id in action_to_workflow_ids.items()
+        When(
+            id=action_id, then=Value(min(list(workflow_ids)))
+        )  # select 1 workflow to fire for, this is arbitrary but deterministic
+        for action_id, workflow_ids in action_to_workflows_ids.items()
     ]
 
     return actions_queryset.annotate(
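The annotate call is truncated above; as a sketch, the Case/When pattern it relies on tags each matched Action row with the single workflow it will fire for (assuming standard Django ORM semantics; names taken from the diff):

    from django.db.models import Case, IntegerField

    actions = Action.objects.filter(id__in=action_to_workflows_ids.keys()).annotate(
        workflow_id=Case(*workflow_id_cases, output_field=IntegerField())
    )
    # Each returned action now exposes action.workflow_id, the deterministic
    # (min) choice among the workflows that triggered it.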
54 changes: 51 additions & 3 deletions tests/sentry/workflow_engine/processors/test_action.py
@@ -107,7 +107,7 @@ def test_multiple_workflows_single_action__first_fire(self) -> None:
         # dedupes action if both workflows will fire it
         assert set(triggered_actions) == {self.action}
         # Dedupes action so we have a single workflow_id -> environment to fire with
-        assert {getattr(action, "workflow_id") for action in triggered_actions} == {workflow.id}
+        assert getattr(triggered_actions[0], "workflow_id") == self.workflow.id
 
         assert WorkflowActionGroupStatus.objects.filter(action=self.action).count() == 2
 
@@ -191,8 +191,8 @@ def test_process_workflow_action_group_statuses(self) -> None:
         )
 
         assert action_to_workflow_ids == {
-            self.action.id: self.workflow.id,
-            action.id: workflow.id,
+            self.action.id: {self.workflow.id},
+            action.id: {workflow.id},
         }
         assert statuses_to_update == {status_2.id}
 
@@ -222,6 +222,54 @@ def test_update_workflow_action_group_statuses(self) -> None:
         for status in all_statuses:
             assert status.date_updated == timezone.now()
 
+    def test_returns_uncreated_statuses(self) -> None:
+        WorkflowActionGroupStatus.objects.create(
+            workflow=self.workflow, action=self.action, group=self.group
+        )
+
+        statuses_to_create = [
+            WorkflowActionGroupStatus(
+                workflow=self.workflow,
+                action=self.action,
+                group=self.group,
+                date_updated=timezone.now(),
+            )
+        ]
+        _, _, uncreated_statuses = update_workflow_action_group_statuses(
+            timezone.now(), set(), statuses_to_create
+        )
+
+        assert uncreated_statuses == [(self.workflow.id, self.action.id)]
+
+    @patch("sentry.workflow_engine.processors.action.update_workflow_action_group_statuses")
+    def test_does_not_fire_for_uncreated_statuses(self, mock_update: MagicMock) -> None:
+        mock_update.return_value = (0, 0, [(self.workflow.id, self.action.id)])
+
+        triggered_actions = filter_recently_fired_workflow_actions(
+            set(DataConditionGroup.objects.all()), self.event_data
+        )
+
+        assert set(triggered_actions) == set()
+
+    @patch("sentry.workflow_engine.processors.action.update_workflow_action_group_statuses")
+    def test_fires_for_non_conflicting_workflow(self, mock_update: MagicMock) -> None:
+        workflow = self.create_workflow(organization=self.organization, config={"frequency": 1440})
+        action_group = self.create_data_condition_group(logic_type="any-short")
+        self.create_data_condition_group_action(
+            condition_group=action_group,
+            action=self.action,
+        )  # shared action
+        self.create_workflow_data_condition_group(workflow, action_group)
+
+        mock_update.return_value = (0, 0, [(self.workflow.id, self.action.id)])
+
+        triggered_actions = filter_recently_fired_workflow_actions(
+            set(DataConditionGroup.objects.all()), self.event_data
+        )
+
+        assert set(triggered_actions) == {self.action}
+        assert getattr(triggered_actions[0], "workflow_id") == workflow.id
 
 
 class TestIsActionPermitted(BaseWorkflowTest):
     @patch("sentry.workflow_engine.processors.action._get_integration_features")