diff --git a/src/sentry/incidents/subscription_processor.py b/src/sentry/incidents/subscription_processor.py
index 05accb63750bf7..01ad1fa5465862 100644
--- a/src/sentry/incidents/subscription_processor.py
+++ b/src/sentry/incidents/subscription_processor.py
@@ -5,7 +5,7 @@
 from collections.abc import Sequence
 from copy import deepcopy
 from datetime import datetime, timedelta
-from typing import TypeVar, cast
+from typing import Literal, TypedDict, TypeVar, cast
 
 from django.conf import settings
 from django.db import router, transaction
@@ -87,6 +87,15 @@
 T = TypeVar("T")
 
 
+class MetricIssueDetectorConfig(TypedDict):
+    """
+    Schema for Metric Issue Detector.config.
+    """
+
+    comparison_delta: int | None
+    detection_type: Literal["static", "percent", "dynamic"]
+
+
 class SubscriptionProcessor:
     """
     Class for processing subscription updates for an alert rule. Accepts a subscription
@@ -107,19 +116,20 @@ class SubscriptionProcessor:
 
     def __init__(self, subscription: QuerySubscription) -> None:
         self.subscription = subscription
+        self._alert_rule: AlertRule | None = None
         try:
-            self.alert_rule = AlertRule.objects.get_for_subscription(subscription)
+            self._alert_rule = AlertRule.objects.get_for_subscription(subscription)
         except AlertRule.DoesNotExist:
             return
 
-        self.triggers = AlertRuleTrigger.objects.get_for_alert_rule(self.alert_rule)
+        self.triggers = AlertRuleTrigger.objects.get_for_alert_rule(self._alert_rule)
         self.triggers.sort(key=lambda trigger: trigger.alert_threshold)
 
         (
             self.last_update,
             self.trigger_alert_counts,
             self.trigger_resolve_counts,
-        ) = get_alert_rule_stats(self.alert_rule, self.subscription, self.triggers)
+        ) = get_alert_rule_stats(self._alert_rule, self.subscription, self.triggers)
         self.orig_trigger_alert_counts = deepcopy(self.trigger_alert_counts)
         self.orig_trigger_resolve_counts = deepcopy(self.trigger_resolve_counts)
 
@@ -135,6 +145,14 @@ def __init__(self, subscription: QuerySubscription) -> None:
             or self._has_workflow_engine_processing_only
         )
 
+    @property
+    def alert_rule(self) -> AlertRule:
+        """
+        Only use this in non-single processing contexts.
+        """
+        assert self._alert_rule is not None
+        return self._alert_rule
+
     @property
     def active_incident(self) -> Incident | None:
         """
@@ -188,7 +206,7 @@ def check_trigger_matches_status(
         incident_trigger = self.incident_trigger_map.get(trigger.id)
         return incident_trigger is not None and incident_trigger.status == status.value
 
-    def reset_trigger_counts(self) -> None:
+    def reset_trigger_counts(self, alert_rule: AlertRule) -> None:
         """
         Helper method that clears both the trigger alert and the trigger resolve counts
         """
@@ -196,7 +214,7 @@ def reset_trigger_counts(self) -> None:
             self.trigger_alert_counts[trigger_id] = 0
         for trigger_id in self.trigger_resolve_counts:
             self.trigger_resolve_counts[trigger_id] = 0
-        self.update_alert_rule_stats()
+        self.update_alert_rule_stats(alert_rule)
 
     def calculate_resolve_threshold(self, trigger: AlertRuleTrigger) -> float:
         """
@@ -253,8 +271,8 @@ def get_crash_rate_alert_metrics_aggregation_value(
         aggregation_value = get_crash_rate_alert_metrics_aggregation_value_helper(
             subscription_update
         )
-        if aggregation_value is None:
-            self.reset_trigger_counts()
+        if aggregation_value is None and self._alert_rule is not None:
+            self.reset_trigger_counts(self._alert_rule)
         return aggregation_value
 
     def get_aggregation_value(
@@ -271,7 +289,7 @@ def get_aggregation_value(
             organization_id=self.subscription.project.organization.id,
             project_ids=[self.subscription.project_id],
             comparison_delta=comparison_delta,
-            alert_rule_id=self.alert_rule.id,
+            alert_rule_id=self._alert_rule.id if self._alert_rule else None,
         )
 
         return aggregation_value
@@ -300,7 +318,7 @@ def handle_trigger_anomalies(
                     is_resolved=False,
                 )
                 incremented = metrics_incremented or incremented
-                incident_trigger = self.trigger_alert_threshold(trigger, aggregation_value)
+                incident_trigger = self.trigger_alert_threshold(trigger)
                 if incident_trigger is not None:
                     fired_incident_triggers.append(incident_trigger)
             else:
@@ -332,9 +350,12 @@ def get_comparison_delta(self, detector: Detector | None) -> int | None:
 
         comparison_delta = None
         if detector:
-            comparison_delta = detector.config.get("comparison_delta")
+            detector_cfg: MetricIssueDetectorConfig = detector.config
+            comparison_delta = detector_cfg.get("comparison_delta")
         else:
-            comparison_delta = self.alert_rule.comparison_delta
+            # If we don't have a Detector, we must have an AlertRule.
+            assert self._alert_rule is not None
+            comparison_delta = self._alert_rule.comparison_delta
 
         return comparison_delta
 
@@ -421,7 +442,7 @@ def handle_trigger_alerts(
                 )
                 incremented = metrics_incremented or incremented
                 # triggering a threshold will create an incident and set the status to active
-                incident_trigger = self.trigger_alert_threshold(trigger, aggregation_value)
+                incident_trigger = self.trigger_alert_threshold(trigger)
                 if incident_trigger is not None:
                     fired_incident_triggers.append(incident_trigger)
             else:
@@ -455,11 +476,13 @@ def handle_trigger_alerts(
 
     def process_results_workflow_engine(
         self,
+        detector: Detector,
         subscription_update: QuerySubscriptionUpdate,
         aggregation_value: float,
         organization: Organization,
     ) -> list[tuple[Detector, dict[DetectorGroupKey, DetectorEvaluationResult]]]:
-        if self.alert_rule.detection_type == AlertRuleDetectionType.DYNAMIC:
+        detector_cfg: MetricIssueDetectorConfig = detector.config
+        if detector_cfg["detection_type"] == AlertRuleDetectionType.DYNAMIC.value:
             anomaly_detection_packet = AnomalyDetectionUpdate(
                 entity=subscription_update.get("entity", ""),
                 subscription_id=subscription_update["subscription_id"],
@@ -499,14 +522,13 @@ def process_results_workflow_engine(
                 "results": results,
                 "num_results": len(results),
                 "value": aggregation_value,
-                "rule_id": self.alert_rule.id,
+                "rule_id": self._alert_rule.id if self._alert_rule else None,
             },
         )
         return results
 
     def process_legacy_metric_alerts(
         self,
-        subscription_update: QuerySubscriptionUpdate,
         aggregation_value: float,
         detector: Detector | None,
         results: list[tuple[Detector, dict[DetectorGroupKey, DetectorEvaluationResult]]] | None,
@@ -632,7 +654,7 @@ def process_legacy_metric_alerts(
         # is killed here. The trade-off is that we might process an update twice. Mostly
         # this will have no effect, but if someone manages to close a triggered incident
         # before the next one then we might alert twice.
-        self.update_alert_rule_stats()
+        self.update_alert_rule_stats(self.alert_rule)
         return fired_incident_triggers
 
     def has_downgraded(self, dataset: str, organization: Organization) -> bool:
@@ -677,7 +699,7 @@ def process_update(self, subscription_update: QuerySubscriptionUpdate) -> None:
         if self.has_downgraded(dataset, organization):
             return
 
-        if not hasattr(self, "alert_rule"):
+        if self._alert_rule is None:
             # QuerySubscriptions must _always_ have an associated AlertRule
             # If the alert rule has been removed then clean up associated tables and return
             metrics.incr("incidents.alert_rules.no_alert_rule_for_subscription", sample_rate=1.0)
@@ -736,8 +758,9 @@ def process_update(self, subscription_update: QuerySubscriptionUpdate) -> None:
 
         legacy_results = None
         if self._has_workflow_engine_processing:
+            assert detector is not None
            workflow_engine_results = self.process_results_workflow_engine(
-                subscription_update, aggregation_value, organization
+                detector, subscription_update, aggregation_value, organization
             )
 
             if self._has_workflow_engine_processing_only:
@@ -756,7 +779,6 @@ def process_update(self, subscription_update: QuerySubscriptionUpdate) -> None:
             workflow engine "and" metric alerts.
             """
             legacy_results = self.process_legacy_metric_alerts(
-                subscription_update,
                 aggregation_value,
                 detector,
                 workflow_engine_results,
@@ -775,7 +797,8 @@ def process_update(self, subscription_update: QuerySubscriptionUpdate) -> None:
             )
 
     def trigger_alert_threshold(
-        self, trigger: AlertRuleTrigger, metric_value: float
+        self,
+        trigger: AlertRuleTrigger,
     ) -> IncidentTrigger | None:
         """
         Called when a subscription update exceeds the value defined in the
@@ -1019,7 +1042,7 @@ def handle_incident_severity_update(self) -> None:
                     status_method=IncidentStatusMethod.RULE_TRIGGERED,
                 )
 
-    def update_alert_rule_stats(self) -> None:
+    def update_alert_rule_stats(self, alert_rule: AlertRule) -> None:
         """
         Updates stats about the alert rule, if they're changed.
         :return:
@@ -1036,7 +1059,7 @@ def update_alert_rule_stats(self) -> None:
         }
 
         update_alert_rule_stats(
-            self.alert_rule,
+            alert_rule,
             self.subscription,
             self.last_update,
             updated_trigger_alert_counts,
diff --git a/tests/sentry/incidents/subscription_processor/test_subscription_processor.py b/tests/sentry/incidents/subscription_processor/test_subscription_processor.py
index a2116b87526b27..d235a61cdcfebb 100644
--- a/tests/sentry/incidents/subscription_processor/test_subscription_processor.py
+++ b/tests/sentry/incidents/subscription_processor/test_subscription_processor.py
@@ -3481,7 +3481,7 @@ def test_seer_call_null_aggregation_value(
         mock_seer_request.return_value = HTTPResponse(orjson.dumps(seer_return_value), status=200)
 
         processor = SubscriptionProcessor(self.sub)
-        processor.alert_rule = self.dynamic_rule
+        processor._alert_rule = self.dynamic_rule
         result = get_anomaly_data_from_seer_legacy(
             alert_rule=processor.alert_rule,
             subscription=processor.subscription,