Skip to content

Commit

Permalink
Pass triggered or existing DAG Run logical date to DagStateTrigger
Browse files Browse the repository at this point in the history
Closes: apache#38353

When using the TriggerDagRunOperator in `deferrable=True` mode, the DagStateTrigger is being passed the incorrect logical date to poll for. The trigger is using a logical date that is calculated on every execution rather than the logical from either the triggered DAG run or an existing DAG run (if the task is configured to not fail for existing DAG runs).

This change corrects the logical date being used by the DagStateTrigger to poll for the triggered (or reset) DAG run.
  • Loading branch information
josh-fell committed Jun 2, 2024
1 parent e3e450e commit c22b3db
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 7 deletions.
6 changes: 3 additions & 3 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def execute(self, context: Context):

except DagRunAlreadyExists as e:
if self.reset_dag_run:
self.log.info("Clearing %s on %s", self.trigger_dag_id, parsed_logical_date)
dag_run = e.dag_run
self.log.info("Clearing %s on %s", self.trigger_dag_id, dag_run.logical_date)

# Get target dag object and call clear()
dag_model = DagModel.get_current(self.trigger_dag_id)
Expand All @@ -208,7 +209,6 @@ def execute(self, context: Context):

dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
dag = dag_bag.get_dag(self.trigger_dag_id)
dag_run = e.dag_run
dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
else:
if self.skip_when_already_exists:
Expand All @@ -231,7 +231,7 @@ def execute(self, context: Context):
trigger=DagStateTrigger(
dag_id=self.trigger_dag_id,
states=self.allowed_states + self.failed_states,
execution_dates=[parsed_logical_date],
execution_dates=[dag_run.logical_date],
poll_interval=self.poke_interval,
),
method_name="execute_complete",
Expand Down
86 changes: 82 additions & 4 deletions tests/operators/test_trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from datetime import datetime
from unittest import mock

import pendulum
import pytest

from airflow.exceptions import AirflowException, DagRunAlreadyExists
from airflow.exceptions import AirflowException, DagRunAlreadyExists, TaskDeferred
from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
Expand All @@ -35,7 +36,7 @@
from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType

pytestmark = pytest.mark.db_test
Expand All @@ -50,8 +51,8 @@
dag = DAG(
dag_id='{TRIGGERED_DAG_ID}',
default_args={{'start_date': datetime(2019, 1, 1)}},
schedule_interval=None
schedule=None,
start_date=datetime(2019, 1, 1),
)
task = EmptyOperator(task_id='test', dag=dag)
Expand Down Expand Up @@ -564,3 +565,80 @@ def test_trigger_dagrun_with_execution_date(self):
assert dagrun.logical_date == custom_execution_date
assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, custom_execution_date)
self.assert_extra_link(dagrun, task, session)

@pytest.mark.parametrize(
argnames=["trigger_logical_date"],
argvalues=[
pytest.param(DEFAULT_DATE, id=f"logical_date={DEFAULT_DATE}"),
pytest.param(None, id="logical_date=None"),
],
)
def test_dagstatetrigger_execution_dates(self, trigger_logical_date):
"""Ensure that the DagStateTrigger is called with the triggered DAG's logical date."""
task = TriggerDagRunOperator(
task_id="test_task",
trigger_dag_id=TRIGGERED_DAG_ID,
logical_date=trigger_logical_date,
wait_for_completion=True,
poke_interval=5,
allowed_states=[DagRunState.QUEUED],
deferrable=True,
dag=self.dag,
)

mock_task_defer = mock.MagicMock(side_effect=task.defer)
with mock.patch.object(TriggerDagRunOperator, "defer", mock_task_defer), pytest.raises(TaskDeferred):
task.execute({"task_instance": mock.MagicMock()})

with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
assert len(dagruns) == 1

assert mock_task_defer.call_args_list[0].kwargs["trigger"].execution_dates == [
pendulum.instance(dagruns[0].logical_date)
]

def test_dagstatetrigger_execution_dates_with_clear_and_reset(self):
"""Check DagStateTrigger is called with the triggered DAG's logical date on subsequent defers."""
task = TriggerDagRunOperator(
task_id="test_task",
trigger_dag_id=TRIGGERED_DAG_ID,
trigger_run_id="custom_run_id",
wait_for_completion=True,
poke_interval=5,
allowed_states=[DagRunState.QUEUED],
deferrable=True,
reset_dag_run=True,
dag=self.dag,
)

mock_task_defer = mock.MagicMock(side_effect=task.defer)
with mock.patch.object(TriggerDagRunOperator, "defer", mock_task_defer), pytest.raises(TaskDeferred):
task.execute({"task_instance": mock.MagicMock()})

with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
triggered_logical_date = dagruns[0].logical_date
assert len(dagruns) == 1

assert mock_task_defer.call_args_list[0].kwargs["trigger"].execution_dates == [
pendulum.instance(triggered_logical_date)
]

# Simulate the TriggerDagRunOperator task being cleared (aka executed again). A DagRunAlreadyExists
# exception should be raised because of the previous DAG run.
with mock.patch.object(TriggerDagRunOperator, "defer", mock_task_defer), pytest.raises(
(DagRunAlreadyExists, TaskDeferred)
):
task.execute({"task_instance": mock.MagicMock()})

# Still only one DAG run should exist for the triggered DAG since the DAG will be cleared since the
# TriggerDagRunOperator task is configured with `reset_dag_run=True`.
with create_session() as session:
dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all()
assert len(dagruns) == 1

# The second DagStateTrigger call should still use the original `logical_date` value.
assert mock_task_defer.call_args_list[1].kwargs["trigger"].execution_dates == [
pendulum.instance(triggered_logical_date)
]

0 comments on commit c22b3db

Please sign in to comment.