Skip to content

Commit

Permalink
[AIRFLOW-4232] Add none_skipped trigger rule (apache#5032)
Browse files Browse the repository at this point in the history
Downstream tasks should run as long as their parents are in
`success`, `failed`, or `upstream_failed` states.
  • Loading branch information
cmdoptesc authored and Chad Henderson committed Apr 16, 2019
1 parent 968623e commit 33ca49e
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 5 deletions.
2 changes: 1 addition & 1 deletion airflow/models/__init__.py
Expand Up @@ -2008,7 +2008,7 @@ class derived from this one results in the creation of a task object,
:param trigger_rule: defines the rule by which dependencies are applied
for the task to get triggered. Options are:
``{ all_success | all_failed | all_done | one_success |
one_failed | none_failed | dummy}``
one_failed | none_failed | none_skipped | dummy}``
default is ``all_success``. Options can be set as string or
using the constants defined in the static class
``airflow.utils.TriggerRule``
Expand Down
11 changes: 11 additions & 0 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Expand Up @@ -157,6 +157,9 @@ def _evaluate_trigger_rule(
ti.set_state(State.UPSTREAM_FAILED, session)
elif skipped == upstream:
ti.set_state(State.SKIPPED, session)
elif tr == TR.NONE_SKIPPED:
if skipped:
ti.set_state(State.SKIPPED, session)

if tr == TR.ONE_SUCCESS:
if successes <= 0:
Expand Down Expand Up @@ -208,6 +211,14 @@ def _evaluate_trigger_rule(
"upstream_tasks_state={2}, upstream_task_ids={3}"
.format(tr, num_failures, upstream_tasks_state,
task.upstream_task_ids))
elif tr == TR.NONE_SKIPPED:
if skipped > 0:
yield self._failing_status(
reason="Task's trigger rule '{0}' requires all upstream "
"tasks to not have been skipped, but found {1} task(s) skipped. "
"upstream_tasks_state={2}, upstream_task_ids={3}"
.format(tr, skipped, upstream_tasks_state,
task.upstream_task_ids))
else:
yield self._failing_status(
reason="No strategy to evaluate trigger rule '{0}'.".format(tr))
3 changes: 2 additions & 1 deletion airflow/utils/trigger_rule.py
Expand Up @@ -29,8 +29,9 @@ class TriggerRule(object):
ALL_DONE = 'all_done'
ONE_SUCCESS = 'one_success'
ONE_FAILED = 'one_failed'
DUMMY = 'dummy'
NONE_FAILED = 'none_failed'
NONE_SKIPPED = 'none_skipped'
DUMMY = 'dummy'

_ALL_TRIGGER_RULES = set() # type: Set[str]

Expand Down
3 changes: 2 additions & 1 deletion docs/concepts.rst
Expand Up @@ -743,6 +743,7 @@ while creating tasks:
* ``one_failed``: fires as soon as at least one parent has failed, it does not wait for all parents to be done
* ``one_success``: fires as soon as at least one parent succeeds, it does not wait for all parents to be done
* ``none_failed``: all parents have not failed (``failed`` or ``upstream_failed``) i.e. all parents have succeeded or been skipped
* ``none_skipped``: no parent is in a ``skipped`` state, i.e. all parents are in a ``success``, ``failed``, or ``upstream_failed`` state
* ``dummy``: dependencies are just for show, trigger at will

Note that these can be used in conjunction with ``depends_on_past`` (boolean)
Expand All @@ -752,7 +753,7 @@ previous schedule for the task hasn't succeeded.
One must be aware of the interaction between trigger rules and skipped tasks
in schedule level. Skipped tasks will cascade through trigger rules
``all_success`` and ``all_failed`` but not ``all_done``, ``one_failed``, ``one_success``,
``none_failed`` and ``dummy``.
``none_failed``, ``none_skipped`` and ``dummy``.

For example, consider the following DAG:

Expand Down
70 changes: 69 additions & 1 deletion tests/ti_deps/deps/test_trigger_rule_dep.py
Expand Up @@ -23,6 +23,7 @@
from airflow.models import BaseOperator, TaskInstance
from airflow.utils.trigger_rule import TriggerRule
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils.db import create_session
from airflow.utils.state import State


Expand All @@ -34,7 +35,7 @@ def _get_task_instance(self, trigger_rule=TriggerRule.ALL_SUCCESS,
start_date=datetime(2015, 1, 1))
if upstream_task_ids:
task._upstream_task_ids.update(upstream_task_ids)
return TaskInstance(task=task, state=state, execution_date=None)
return TaskInstance(task=task, state=state, execution_date=task.start_date)

def test_no_upstream_tasks(self):
"""
Expand Down Expand Up @@ -275,6 +276,73 @@ def test_all_done_tr_failure(self):
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)

def test_none_skipped_tr_success(self):
"""
None-skipped trigger rule success
"""

ti = self._get_task_instance(TriggerRule.NONE_SKIPPED,
upstream_task_ids=["FakeTaskID",
"OtherFakeTaskID",
"FailedFakeTaskID"])
with create_session() as session:
dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=2,
skipped=0,
failed=1,
upstream_failed=0,
done=3,
flag_upstream_failed=False,
session=session))
self.assertEqual(len(dep_statuses), 0)

# with `flag_upstream_failed` set to True
dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=0,
skipped=0,
failed=3,
upstream_failed=0,
done=3,
flag_upstream_failed=True,
session=session))
self.assertEqual(len(dep_statuses), 0)

def test_none_skipped_tr_failure(self):
"""
None-skipped trigger rule failure
"""
ti = self._get_task_instance(TriggerRule.NONE_SKIPPED,
upstream_task_ids=["FakeTaskID",
"SkippedTaskID"])

with create_session() as session:
dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=1,
skipped=1,
failed=0,
upstream_failed=0,
done=2,
flag_upstream_failed=False,
session=session))
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)

# with `flag_upstream_failed` set to True
dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=1,
skipped=1,
failed=0,
upstream_failed=0,
done=2,
flag_upstream_failed=True,
session=session))
self.assertEqual(len(dep_statuses), 1)
self.assertFalse(dep_statuses[0].passed)

def test_unknown_tr(self):
"""
Unknown trigger rules should cause this dep to fail
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_trigger_rule.py
Expand Up @@ -30,5 +30,6 @@ def test_valid_trigger_rules(self):
self.assertTrue(TriggerRule.is_valid(TriggerRule.ONE_SUCCESS))
self.assertTrue(TriggerRule.is_valid(TriggerRule.ONE_FAILED))
self.assertTrue(TriggerRule.is_valid(TriggerRule.NONE_FAILED))
self.assertTrue(TriggerRule.is_valid(TriggerRule.NONE_SKIPPED))
self.assertTrue(TriggerRule.is_valid(TriggerRule.DUMMY))
self.assertEqual(len(TriggerRule.all_triggers()), 7)
self.assertEqual(len(TriggerRule.all_triggers()), 8)

0 comments on commit 33ca49e

Please sign in to comment.