diff --git a/airflow/example_dags/example_xcomargs.py b/airflow/example_dags/example_xcomargs.py
new file mode 100644
index 0000000000000..870030371cacc
--- /dev/null
+++ b/airflow/example_dags/example_xcomargs.py
@@ -0,0 +1,51 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG demonstrating the usage of the XComArgs."""
+
+from airflow import DAG
+from airflow.operators.python import PythonOperator
+from airflow.utils.dates import days_ago
+
+args = {
+    'owner': 'airflow',
+    'start_date': days_ago(2),
+}
+
+
+def dummy(*args, **kwargs):
+    """Dummy function"""
+    return "pass"
+
+
+with DAG(
+    dag_id='example_xcom_args',
+    default_args=args,
+    schedule_interval=None,
+    tags=['example']
+) as dag:
+    task1 = PythonOperator(
+        task_id='task1',
+        python_callable=dummy,
+    )
+
+    task2 = PythonOperator(
+        task_id='task2',
+        python_callable=dummy,
+        op_kwargs={"dummy": task1.output},
+    )
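For readers new to XComArg: passing task1.output above both templates the pulled value and creates the task1 >> task2 edge. The hand-written equivalent, before this change, would look roughly like the sketch below. This is a minimal illustration, not part of the patch; it assumes the default XCom key and the simple xcom_pull template form used in the set_xcomargs_dependencies docstring further down.

# Sketch only: what example_xcomargs.py would have to spell out without XComArg support.
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago


def dummy(*args, **kwargs):
    """Dummy function"""
    return "pass"


with DAG(dag_id='example_xcom_args_manual', start_date=days_ago(2), schedule_interval=None) as dag:
    task1 = PythonOperator(task_id='task1', python_callable=dummy)
    task2 = PythonOperator(
        task_id='task2',
        python_callable=dummy,
        # Manually written template instead of task1.output (assumes the default XCom key):
        op_kwargs={"dummy": "{{ task_instance.xcom_pull(task_ids='task1') }}"},
    )
    task1 >> task2  # with XComArg this upstream edge is created automatically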
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 85f01dccb8949..7101f858ab6fc 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -18,6 +18,7 @@
 """
 Base operator for all operators.
 """
+import abc
 import copy
 import functools
 import logging
@@ -60,9 +61,29 @@
 ScheduleInterval = Union[str, timedelta, relativedelta]
 
 
+class BaseOperatorMeta(abc.ABCMeta):
+    """
+    Base metaclass of BaseOperator.
+    """
+
+    def __call__(cls, *args, **kwargs):
+        """
+        Called when you call BaseOperator(). In this way we are able to perform an action
+        after initializing an operator no matter where the ``super().__init__`` is called
+        (before or after assign of new attributes in a custom operator).
+        """
+        obj: BaseOperator = type.__call__(cls, *args, **kwargs)
+        # Here we set upstream task defined by XComArgs passed to template fields of the operator
+        obj.set_xcomargs_dependencies()
+
+        # Mark instance as instantiated https://docs.python.org/3/tutorial/classes.html#private-variables
+        obj._BaseOperator__instantiated = True
+        return obj
+
+
 # pylint: disable=too-many-instance-attributes,too-many-public-methods
 @functools.total_ordering
-class BaseOperator(Operator, LoggingMixin):
+class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta):
     """
     Abstract base class for all operators. Since operators create objects that become
     nodes in the dag, BaseOperator contains many recursive methods for
@@ -292,6 +313,12 @@ class derived from this one results in the creation of a task object,
     # Defines if the operator supports lineage without manual definitions
     supports_lineage = False
 
+    # If True then the class constructor was called
+    __instantiated = False
+
+    # Set to True before calling execute method
+    _lock_for_execution = False
+
     # noinspection PyUnusedLocal
     # pylint: disable=too-many-arguments,too-many-locals, too-many-statements
     @apply_defaults
@@ -547,6 +574,18 @@ def __lt__(self, other):
 
         return self
 
+    def __setattr__(self, key, value):
+        super().__setattr__(key, value)
+        if self._lock_for_execution:
+            # Skip any custom behaviour during execute
+            return
+        if self.__instantiated and key in self.template_fields:
+            # Resolve upstreams set by assigning an XComArg after initializing
+            # an operator, example:
+            #   op = BashOperator()
+            #   op.bash_command = "sleep 1"
+            self.set_xcomargs_dependencies()
+
     def add_inlets(self, inlets: Iterable[Any]):
         """
         Sets inlets to this operator
@@ -633,6 +672,56 @@ def deps(self) -> Set[BaseTIDep]:
             NotPreviouslySkippedDep(),
         }
 
+    def prepare_for_execution(self) -> "BaseOperator":
+        """
+        Lock task for execution to disable custom action in __setattr__ and
+        returns a copy of the task
+        """
+        other = copy.copy(self)
+        other._lock_for_execution = True  # pylint: disable=protected-access
+        return other
+
+    def set_xcomargs_dependencies(self) -> None:
+        """
+        Resolves upstream dependencies of a task. In this way passing an ``XComArg``
+        as value for a template field will result in creating upstream relation between
+        two tasks.
+
+        **Example**: ::
+
+            with DAG(...):
+                generate_content = GenerateContentOperator(task_id="generate_content")
+                send_email = EmailOperator(..., html_content=generate_content.output)
+
+            # This is equivalent to
+            with DAG(...):
+                generate_content = GenerateContentOperator(task_id="generate_content")
+                send_email = EmailOperator(
+                    ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
+                )
+                generate_content >> send_email
+
+        """
+        from airflow.models.xcom_arg import XComArg
+
+        def apply_set_upstream(arg: Any):
+            if isinstance(arg, XComArg):
+                self.set_upstream(arg.operator)
+            elif isinstance(arg, (tuple, set, list)):
+                for elem in arg:
+                    apply_set_upstream(elem)
+            elif isinstance(arg, dict):
+                for elem in arg.values():
+                    apply_set_upstream(elem)
+            elif hasattr(arg, "template_fields"):
+                for elem in arg.template_fields:
+                    apply_set_upstream(elem)
+
+        for field in self.template_fields:
+            if hasattr(self, field):
+                arg = getattr(self, field)
+                apply_set_upstream(arg)
+
     @property
     def priority_weight_total(self) -> int:
         """
@@ -1140,7 +1229,7 @@ def set_upstream(self, task_or_task_list: Union['BaseOperator', List['BaseOperat
 
     @property
     def output(self):
-        """Returns default XComArg for the operator"""
+        """Returns reference to XCom pushed by current operator"""
         from airflow.models.xcom_arg import XComArg
         return XComArg(operator=self)
 
@@ -1205,7 +1294,8 @@ def get_serialized_fields(cls):
         if not cls.__serialized_fields:
             cls.__serialized_fields = frozenset(
                 vars(BaseOperator(task_id='test')).keys() - {
-                    'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag'
+                    'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag',
+                    '_BaseOperator__instantiated',
                 } | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'})
 
         return cls.__serialized_fields
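The BaseOperatorMeta.__call__ hook above is what lets set_xcomargs_dependencies() run after every operator's __init__ has finished, no matter where a subclass calls super().__init__(). A stripped-down sketch of that post-init-hook pattern, with illustrative names that are not Airflow APIs:

import abc


class PostInitMeta(abc.ABCMeta):
    """Illustrative metaclass: run a hook once __init__ of any subclass instance has finished."""

    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)  # runs __new__ and __init__ as usual
        obj.after_init()  # post-init hook, analogous to obj.set_xcomargs_dependencies() in the patch
        return obj


class Base(metaclass=PostInitMeta):
    def after_init(self):
        print("attributes visible to the hook:", vars(self))


class Child(Base):
    def __init__(self):
        super().__init__()
        self.field = 42  # assigned after super().__init__(), yet still visible to after_init()


Child()  # prints: attributes visible to the hook: {'field': 42}

The _lock_for_execution flag set in prepare_for_execution() then switches the __setattr__ hook off for the copy that actually runs, so assignments made inside execute() do not retrigger dependency resolution.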
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 514e54dbbd882..6a4d625814054 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -16,7 +16,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import copy
 import getpass
 import hashlib
 import logging
@@ -970,7 +969,7 @@ def _run_raw_task(
 
         if not mark_success:
             context = self.get_template_context()
-            task_copy = copy.copy(task)
+            task_copy = task.prepare_for_execution()
 
             # Sensors in `poke` mode can block execution of DAGs when running
             # with single process executor, thus we change the mode to`reschedule`
@@ -1154,7 +1153,7 @@ def run(
 
     def dry_run(self):
         task = self.task
-        task_copy = copy.copy(task)
+        task_copy = task.prepare_for_execution()
         self.task = task_copy
 
         self.render_templates()
diff --git a/airflow/providers/google/cloud/operators/sql_to_gcs.py b/airflow/providers/google/cloud/operators/sql_to_gcs.py
index fc71852670c35..a7bc5297c0e98 100644
--- a/airflow/providers/google/cloud/operators/sql_to_gcs.py
+++ b/airflow/providers/google/cloud/operators/sql_to_gcs.py
@@ -31,7 +31,7 @@
 from airflow.utils.decorators import apply_defaults
 
 
-class BaseSQLToGCSOperator(BaseOperator, metaclass=abc.ABCMeta):
+class BaseSQLToGCSOperator(BaseOperator):
     """
     :param sql: The SQL to execute.
     :type sql: str
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index b04cb7757e388..305635d0ca501 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -299,7 +299,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
     _decorated_fields = {'executor_config'}
 
     _CONSTRUCTOR_PARAMS = {
-        k: v.default for k, v in signature(BaseOperator).parameters.items()
+        k: v.default for k, v in signature(BaseOperator.__init__).parameters.items()
         if v.default is not v.empty
     }
 
@@ -537,7 +537,7 @@ def __get_constructor_defaults():  # pylint: disable=no-method-argument
             'access_control': '_access_control',
         }
         return {
-            param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items()
+            param_to_attr.get(k, k): v.default for k, v in signature(DAG.__init__).parameters.items()
            if v.default is not v.empty
         }
 
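The signature(...) changes above are a direct consequence of the new metaclass: once a metaclass defines __call__, inspect.signature(SomeClass) reports that __call__'s (*args, **kwargs) signature, so the constructor defaults disappear, whereas signature(SomeClass.__init__) still exposes them. A small sketch of the effect; this reflects CPython's inspect behaviour on the Python versions Airflow supported at the time and is worth re-verifying on newer interpreters:

import abc
from inspect import signature


class Meta(abc.ABCMeta):
    # Stand-in for BaseOperatorMeta: any metaclass with a user-defined __call__.
    def __call__(cls, *args, **kwargs):
        return super().__call__(*args, **kwargs)


class Op(metaclass=Meta):
    def __init__(self, task_id, retries=0):
        self.task_id = task_id
        self.retries = retries


print(signature(Op))           # (*args, **kwargs) -- defaults hidden by the metaclass __call__
print(signature(Op.__init__))  # (self, task_id, retries=0) -- defaults still recoverable here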
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index c9a07f94542b8..9c003a2a5e4f8 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -15,13 +15,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 import unittest
 import uuid
 from datetime import date, datetime
 from unittest import mock
 
 import jinja2
+import pytest
 from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
@@ -29,6 +29,7 @@
 from airflow.models import DAG
 from airflow.models.baseoperator import chain, cross_downstream
 from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.decorators import apply_defaults
 from tests.models import DEFAULT_DATE
 from tests.test_utils.mock_operators import MockNamedTuple, MockOperator
 
@@ -347,3 +348,61 @@ def test_lineage_composition(self):
         task4 = DummyOperator(task_id="op4", dag=dag)
         task4 > [inlet, outlet, extra]
         self.assertEqual(task4.get_outlet_defs(), [inlet, outlet, extra])
+
+
+class CustomOp(DummyOperator):
+    template_fields = ("field", "field2")
+
+    @apply_defaults
+    def __init__(self, field=None, field2=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.field = field
+        self.field2 = field2
+
+    def execute(self, context):
+        self.field = None
+
+
+class TestXComArgsRelationsAreResolved:
+    def test_setattr_performs_no_custom_action_at_execute_time(self):
+        op = CustomOp(task_id="test_task")
+        op_copy = op.prepare_for_execution()
+
+        with mock.patch(
+            "airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies"
+        ) as method_mock:
+            op_copy.execute({})
+        assert method_mock.call_count == 0
+
+    def test_upstream_is_set_when_template_field_is_xcomarg(self):
+        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
+            op1 = DummyOperator(task_id="op1")
+            op2 = CustomOp(task_id="op2", field=op1.output)
+
+        assert op1 in op2.upstream_list
+        assert op2 in op1.downstream_list
+
+    def test_set_xcomargs_dependencies_works_recursively(self):
+        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
+            op1 = DummyOperator(task_id="op1")
+            op2 = DummyOperator(task_id="op2")
+            op3 = CustomOp(task_id="op3", field=[op1.output, op2.output])
+            op4 = CustomOp(task_id="op4", field={"op1": op1.output, "op2": op2.output})
+
+        assert op1 in op3.upstream_list
+        assert op2 in op3.upstream_list
+        assert op1 in op4.upstream_list
+        assert op2 in op4.upstream_list
+
+    def test_set_xcomargs_dependencies_works_when_set_after_init(self):
+        with DAG(dag_id='xcomargs_test', default_args={"start_date": datetime.today()}):
+            op1 = DummyOperator(task_id="op1")
+            op2 = CustomOp(task_id="op2")
+            op2.field = op1.output  # value is set after init
+
+        assert op1 in op2.upstream_list
+
+    def test_set_xcomargs_dependencies_error_when_outside_dag(self):
+        with pytest.raises(AirflowException):
+            op1 = DummyOperator(task_id="op1")
+            CustomOp(task_id="op2", field=op1.output)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 7af8eb90dbe15..152a3c210c4bf 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -726,7 +726,8 @@ def test_no_new_fields_added_to_base_operator(self):
         """
         base_operator = BaseOperator(task_id="10")
         fields = base_operator.__dict__
-        self.assertEqual({'_dag': None,
+        self.assertEqual({'_BaseOperator__instantiated': True,
+                          '_dag': None,
                           '_downstream_task_ids': set(),
                           '_inlets': [],
                           '_log': base_operator.log,
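For completeness, the user-facing behaviour the new tests lock in, namely that assigning an XComArg to a template field after the operator is constructed still wires the upstream edge, looks roughly like the sketch below outside the test suite. It is an illustration only, using operators that appear elsewhere in the patch and relying on op_kwargs being a template field of PythonOperator.

from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python import PythonOperator


def consume(value, **kwargs):
    print(value)


with DAG(dag_id="xcomargs_after_init", start_date=datetime(2020, 1, 1), schedule_interval=None):
    producer = DummyOperator(task_id="producer")
    consumer = PythonOperator(task_id="consumer", python_callable=consume, op_kwargs={"value": None})
    # op_kwargs is a template field, so this assignment goes through BaseOperator.__setattr__
    # and calls set_xcomargs_dependencies(), creating producer >> consumer.
    consumer.op_kwargs = {"value": producer.output}

assert producer.task_id in consumer.upstream_task_ids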