From d8c0cfea5ff679dc2de55220f8fc500fadef1093 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Wed, 18 Aug 2021 12:46:49 +0100
Subject: [PATCH] Have the dag_maker fixture (optionally) give SerializedDAGs
 (#17577)

All but one test in test_scheduler_job.py wants to operate on serialized
dags, so it makes sense to have this be done in the dag_maker for us, to
make each test "smaller".
---
 pytest.ini                       |   2 +
 tests/conftest.py                | 113 +++++++--
 tests/jobs/test_scheduler_job.py | 420 ++++++------------------------
 3 files changed, 175 insertions(+), 360 deletions(-)

diff --git a/pytest.ini b/pytest.ini
index a7fd3173138ba..7f3075312b0a8 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -32,3 +32,5 @@ faulthandler_timeout = 480
 log_level = INFO
 filterwarnings =
     error::pytest.PytestCollectionWarning
+markers =
+    need_serialized_dag
diff --git a/tests/conftest.py b/tests/conftest.py
index c7685d4287751..81d1d37122b3b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -429,7 +429,7 @@ def app():
 @pytest.fixture
 def dag_maker(request):
     """
-    The dag_maker helps us to create DAG & DagModel automatically.
+    The dag_maker helps us to create DAG, DagModel, and SerializedDAG automatically.

     You have to use the dag_maker as a context manager and it takes the same arguments as DAG::

@@ -451,49 +451,89 @@ def dag_maker(request):

     The dag_maker.create_dagrun takes the same arguments as dag.create_dagrun

+    If you want to operate on serialized DAGs, then either pass ``serialized=True`` to the ``dag_maker()``
+    call, or you can mark your test/class/file with ``@pytest.mark.need_serialized_dag(True)``. In both of
+    these cases the ``dag`` returned by the context manager will be a lazily-evaluated proxy object to the
+    SerializedDAG.
     """
-    from airflow.models import DAG, DagModel
-    from airflow.utils import timezone
-    from airflow.utils.session import provide_session
-    from airflow.utils.state import State
+    import lazy_object_proxy

-    DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+    # IMPORTANT: Delay _all_ imports from `airflow.*` to _inside a method_.
+
+    # This fixture is "called" early on in the pytest collection process, and
+    # if we import airflow.* here the wrong (non-test) config will be loaded
+    # and "baked" in to various constants
+
+    want_serialized = False
+
+    # Allow changing the default serialized behaviour with `@pytest.mark.need_serialized_dag` or
+    # `@pytest.mark.need_serialized_dag(False)`
+    serialized_marker = request.node.get_closest_marker("need_serialized_dag")
+    if serialized_marker:
+        (want_serialized,) = serialized_marker.args or (True,)

     class DagFactory:
+        def __init__(self):
+            from airflow.models import DagBag
+
+            # Keep all the serialized dags we've created in this test
+            self.dagbag = DagBag(os.devnull, include_examples=False, read_dags_from_db=False)
+
         def __enter__(self):
             self.dag.__enter__()
+            if self.want_serialized:
+                return lazy_object_proxy.Proxy(self._serialized_dag)
             return self.dag

+        def _serialized_dag(self):
+            return self.serialized_model.dag
+
         def __exit__(self, type, value, traceback):
+            from airflow.models import DagModel
+            from airflow.models.serialized_dag import SerializedDagModel
+
             dag = self.dag
             dag.__exit__(type, value, traceback)
-            if type is None:
-                dag.clear()
-                self.dag_model = DagModel(
-                    dag_id=dag.dag_id,
-                    next_dagrun=dag.start_date,
-                    is_active=True,
-                    is_paused=False,
-                    max_active_tasks=dag.max_active_tasks,
-                    has_task_concurrency_limits=False,
-                )
-                self.session.add(self.dag_model)
+            if type is not None:
+                return
+
+            dag.clear()
+            dag.sync_to_db(self.session)
+            self.dag_model = self.session.query(DagModel).get(dag.dag_id)
+
+            if self.want_serialized:
+                self.serialized_model = SerializedDagModel(dag)
+                self.session.merge(self.serialized_model)
+                serialized_dag = self._serialized_dag()
+                self.dagbag.bag_dag(serialized_dag, root_dag=serialized_dag)
                 self.session.flush()
+            else:
+                self.dagbag.bag_dag(self.dag, self.dag)

         def create_dagrun(self, **kwargs):
+            from airflow.utils.state import State
+
             dag = self.dag
             defaults = dict(
                 run_id='test',
                 state=State.RUNNING,
                 execution_date=self.start_date,
                 start_date=self.start_date,
+                session=self.session,
             )
             kwargs = {**defaults, **kwargs}
             self.dag_run = dag.create_dagrun(**kwargs)
             return self.dag_run

-        @provide_session
-        def __call__(self, dag_id='test_dag', session=None, **kwargs):
+        def __call__(
+            self, dag_id='test_dag', serialized=want_serialized, fileloc=None, session=None, **kwargs
+        ):
+            from airflow import settings
+            from airflow.models import DAG
+            from airflow.utils import timezone
+
+            if session is None:
+                session = settings.Session()
+
             self.kwargs = kwargs
             self.session = session
             self.start_date = self.kwargs.get('start_date', None)
@@ -506,13 +546,44 @@ def __call__(self, dag_id='test_dag', session=None, **kwargs):
             if hasattr(request.module, 'DEFAULT_DATE'):
                 self.start_date = getattr(request.module, 'DEFAULT_DATE')
             else:
+                DEFAULT_DATE = timezone.datetime(2016, 1, 1)
                 self.start_date = DEFAULT_DATE
             self.kwargs['start_date'] = self.start_date
             self.dag = DAG(dag_id, **self.kwargs)
-            self.dag.fileloc = request.module.__file__
+            self.dag.fileloc = fileloc or request.module.__file__
+            self.want_serialized = serialized
+
             return self

-    return DagFactory()
+        def cleanup(self):
+            from airflow.models import DagModel, DagRun, TaskInstance
+            from airflow.models.serialized_dag import SerializedDagModel
+
+            dag_ids = list(self.dagbag.dag_ids)
+            if not dag_ids:
+                return
+            # To isolate problems here from problems elsewhere on the session object
+            self.session.flush()
+
+            self.session.query(SerializedDagModel).filter(SerializedDagModel.dag_id.in_(dag_ids)).delete(
+                synchronize_session=False
+            )
+            self.session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids)).delete(synchronize_session=False)
+            self.session.query(TaskInstance).filter(TaskInstance.dag_id.in_(dag_ids)).delete(
+                synchronize_session=False
+            )
+            self.session.query(DagModel).filter(DagModel.dag_id.in_(dag_ids)).delete(
+                synchronize_session=False
+            )
+            self.session.commit()
+
+    factory = DagFactory()
+
+    try:
+        yield factory
+    finally:
+        factory.cleanup()
+        del factory.session


 @pytest.fixture
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 9030005aad6e4..f8e8e53902528 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -96,6 +96,7 @@ def dagbag():


 @pytest.mark.usefixtures("disable_load_example")
+@pytest.mark.need_serialized_dag
 class TestSchedulerJob:
     @staticmethod
     def clean_db():
@@ -189,12 +190,11 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_
         dag_id2 = "test_process_executor_events_2"
         task_id_1 = 'dummy_task'

-        with dag_maker(dag_id=dag_id, full_filepath="/test_path1/") as dag:
+        with dag_maker(dag_id=dag_id, fileloc='/test_path1/'):
             task1 = DummyOperator(task_id=task_id_1)
-        with dag_maker(dag_id=dag_id2, full_filepath="/test_path1/") as dag2:
+        with dag_maker(dag_id=dag_id2, fileloc='/test_path1/'):
             DummyOperator(task_id=task_id_1)
-        dag.fileloc = "/test_path1/"
-        dag2.fileloc = "/test_path1/"
+
         mock_stats_incr.reset_mock()

         executor = MockExecutor(do_update=False)
         task_callback = mock.MagicMock()
@@ -203,8 +203,6 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_
         self.scheduler_job.processor_agent = mock.MagicMock()
         session = settings.Session()

-        dag.sync_to_db(session=session)
-        dag2.sync_to_db(session=session)
         ti1 = TaskInstance(task1, DEFAULT_DATE)
         ti1.state = State.QUEUED
@@ -270,7 +268,7 @@ def test_execute_task_instances_is_paused_wont_execute(self, dag_maker):
         with dag_maker(dag_id=dag_id) as dag:
             task1 = DummyOperator(task_id=task_id_1)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+        assert isinstance(dag, SerializedDAG)

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -295,9 +293,8 @@ def test_execute_task_instances_no_dagrun_task_will_execute(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_execute_task_instances_no_dagrun_task_will_execute'
         task_id_1 = 'dummy_task'

-        with dag_maker(dag_id=dag_id) as dag:
+        with dag_maker(dag_id=dag_id):
             task1 = DummyOperator(task_id=task_id_1)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -321,9 +318,8 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_execute_task_instances_backfill_tasks_wont_execute'
         task_id_1 = 'dummy_task'

-        with dag_maker(dag_id=dag_id) as dag:
+        with dag_maker(dag_id=dag_id):
             task1 = DummyOperator(task_id=task_id_1)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -350,7 +346,6 @@ def test_find_executable_task_instances_backfill_nodagrun(self, dag_maker):
         task_id_1 = 'dummy'
         with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag:
             task1 = DummyOperator(task_id=task_id_1)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -390,7 +385,6 @@ def test_find_executable_task_instances_pool(self, dag_maker):
         with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag:
             task1 = DummyOperator(task_id=task_id_1, pool='a')
             task2 = DummyOperator(task_id=task_id_2, pool='b')
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -437,17 +431,14 @@ def test_find_executable_task_instances_order_execution_date(self, dag_maker):
         dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date-a'
         dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date-b'
         task_id = 'task-a'
-        with dag_maker(dag_id=dag_id_1, max_active_tasks=16) as dag_1:
+        with dag_maker(dag_id=dag_id_1, max_active_tasks=16):
             dag1_task = DummyOperator(task_id=task_id)
         dr1 = dag_maker.create_dagrun(execution_date=DEFAULT_DATE + timedelta(hours=1))

-        with dag_maker(dag_id=dag_id_2, max_active_tasks=16) as dag_2:
+        with dag_maker(dag_id=dag_id_2, max_active_tasks=16):
             dag2_task = DummyOperator(task_id=task_id)
         dr2 = dag_maker.create_dagrun()

-        dag_1 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag_1))
-        dag_2 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag_2))
-
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -469,17 +460,14 @@ def test_find_executable_task_instances_order_priority(self, dag_maker):
         dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_priority-a'
         dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_priority-b'
         task_id = 'task-a'
-        with dag_maker(dag_id=dag_id_1, max_active_tasks=16) as dag_1:
+        with dag_maker(dag_id=dag_id_1, max_active_tasks=16):
             dag1_task = DummyOperator(task_id=task_id, priority_weight=1)
         dr1 = dag_maker.create_dagrun()

-        with dag_maker(dag_id=dag_id_2, max_active_tasks=16) as dag_2:
+        with dag_maker(dag_id=dag_id_2, max_active_tasks=16):
             dag2_task = DummyOperator(task_id=task_id, priority_weight=4)
         dr2 = dag_maker.create_dagrun()

-        dag_1 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag_1))
-        dag_2 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag_2))
-
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -501,15 +489,13 @@ def test_find_executable_task_instances_order_execution_date_and_priority(self,
         dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-a'
         dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-b'
         task_id = 'task-a'
-        with dag_maker(dag_id=dag_id_1, max_active_tasks=16) as dag_1:
+        with dag_maker(dag_id=dag_id_1, max_active_tasks=16):
             dag1_task = DummyOperator(task_id=task_id, priority_weight=1)
         dr1 = dag_maker.create_dagrun()

-        with dag_maker(dag_id=dag_id_2, max_active_tasks=16) as dag_2:
+        with dag_maker(dag_id=dag_id_2, max_active_tasks=16):
             dag2_task = DummyOperator(task_id=task_id, priority_weight=4)
         dr2 = dag_maker.create_dagrun(execution_date=DEFAULT_DATE + timedelta(hours=1))

-        dag_1 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag_1))
-        dag_2 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag_2))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -535,7 +521,6 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker):
         with dag_maker(dag_id=dag_id) as dag:
             op1 = DummyOperator(task_id='dummy1')
             op2 = DummyOperator(task_id='dummy2')
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         executor = MockExecutor(do_update=True)
         self.scheduler_job = SchedulerJob(executor=executor)
@@ -575,9 +560,8 @@ def test_find_executable_task_instances_in_default_pool(self, dag_maker):
     def test_nonexistent_pool(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_nonexistent_pool'
         task_id = 'dummy_wrong_pool'
-        with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag:
+        with dag_maker(dag_id=dag_id, max_active_tasks=16):
             task = DummyOperator(task_id=task_id, pool="this_pool_doesnt_exist")
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -597,9 +581,8 @@ def test_nonexistent_pool(self, dag_maker):
     def test_infinite_pool(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_infinite_pool'
         task_id = 'dummy'
-        with dag_maker(dag_id=dag_id, concurrency=16) as dag:
+        with dag_maker(dag_id=dag_id, concurrency=16):
             task = DummyOperator(task_id=task_id, pool="infinite_pool")
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -620,9 +603,8 @@ def test_find_executable_task_instances_none(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_find_executable_task_instances_none'
         task_id_1 = 'dummy'
-        with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag:
+        with dag_maker(dag_id=dag_id, max_active_tasks=16):
             DummyOperator(task_id=task_id_1)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -635,7 +617,6 @@ def test_find_executable_task_instances_concurrency(self, dag_maker):
         task_id_1 = 'dummy'
         with dag_maker(dag_id=dag_id, max_active_tasks=2) as dag:
             task1 = DummyOperator(task_id=task_id_1)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -681,11 +662,10 @@ def test_find_executable_task_instances_concurrency(self, dag_maker):

     def test_find_executable_task_instances_concurrency_queued(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency_queued'
-        with dag_maker(dag_id=dag_id, max_active_tasks=3) as dag:
+        with dag_maker(dag_id=dag_id, max_active_tasks=3):
             task1 = DummyOperator(task_id='dummy1')
             task2 = DummyOperator(task_id='dummy2')
             task3 = DummyOperator(task_id='dummy3')
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -724,9 +704,6 @@ def test_find_executable_task_instances_task_concurrency(self, dag_maker):
         self.scheduler_job = SchedulerJob(executor=executor)
         session = settings.Session()

-        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-        self.scheduler_job.dagbag.sync_to_db(session=session)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
         dr1 = dag_maker.create_dagrun()
         dr2 = dag.create_dagrun(
             run_type=DagRunType.SCHEDULED,
@@ -806,7 +783,6 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_
         task_id_1 = 'dummy'
         with dag_maker(dag_id=dag_id, max_active_tasks=2) as dag:
             task1 = DummyOperator(task_id=task_id_1)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -846,9 +822,8 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self, dag_
     def test_enqueue_task_instances_with_queued_state(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state'
         task_id_1 = 'dummy'
-        with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE) as dag:
+        with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE):
             task1 = DummyOperator(task_id=task_id_1)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -877,7 +852,6 @@ def test_critical_section_execute_task_instances(self, dag_maker):
         with dag_maker(dag_id=dag_id, max_active_tasks=3) as dag:
             task1 = DummyOperator(task_id=task_id_1)
             task2 = DummyOperator(task_id=task_id_2)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -944,7 +918,6 @@ def test_execute_task_instances_limit(self, dag_maker):
         with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag:
             task1 = DummyOperator(task_id=task_id_1)
             task2 = DummyOperator(task_id=task_id_2)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -998,7 +971,6 @@ def test_execute_task_instances_unlimited(self, dag_maker):
         with dag_maker(dag_id=dag_id, max_active_tasks=1024) as dag:
             task1 = DummyOperator(task_id=task_id_1)
             task2 = DummyOperator(task_id=task_id_2)
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
@@ -1032,12 +1004,12 @@ def test_execute_task_instances_unlimited(self, dag_maker):
         session.rollback()

     def test_change_state_for_tis_without_dagrun(self, dag_maker):
-        with dag_maker(dag_id='test_change_state_for_tis_without_dagrun') as dag1:
+        with dag_maker(dag_id='test_change_state_for_tis_without_dagrun'):
             DummyOperator(task_id='dummy')
             DummyOperator(task_id='dummy_b')
         dr1 = dag_maker.create_dagrun()

-        with dag_maker(dag_id='test_change_state_for_tis_without_dagrun_dont_change') as dag2:
+        with dag_maker(dag_id='test_change_state_for_tis_without_dagrun_dont_change'):
             DummyOperator(task_id='dummy')
         dr2 = dag_maker.create_dagrun()

@@ -1062,13 +1034,6 @@ def test_change_state_for_tis_without_dagrun(self, dag_maker):
         session.merge(ti3)
         session.commit()

-        dagbag = DagBag("/dev/null", include_examples=False, read_dags_from_db=False)
-        dagbag.bag_dag(dag1, root_dag=dag1)
-        dagbag.bag_dag(dag2, root_dag=dag2)
-        dagbag.bag_dag(dag3, root_dag=dag3)
-        dagbag.sync_to_db(session)
-        session.commit()
-
         self.scheduler_job = SchedulerJob(num_runs=0)
         self.scheduler_job.dagbag.collect_dags_from_db()
@@ -1122,7 +1087,6 @@ def test_adopt_or_reset_orphaned_tasks(self, dag_maker):
         session = settings.Session()
         with dag_maker('test_execute_helper_reset_orphaned_tasks') as dag:
             op1 = DummyOperator(task_id='op1')
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

         dr = dag_maker.create_dagrun()
         dr2 = dag.create_dagrun(
@@ -1168,15 +1132,9 @@ def test_scheduler_loop_should_change_state_for_tis_without_dagrun(
         with dag_maker(
             dag_id,
             start_date=DEFAULT_DATE + timedelta(days=1),
-        ) as dag:
+        ):
             op1 = DummyOperator(task_id='op1')

-        # Write Dag to DB
-        dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
-        dagbag.bag_dag(dag, root_dag=dag)
-        dagbag.sync_to_db()
-
-        dag = DagBag(read_dags_from_db=True, include_examples=False).get_dag(dag_id)
         # Create DAG run with FAILED state
         dr = dag_maker.create_dagrun(
             state=State.FAILED,
@@ -1190,7 +1148,7 @@ def test_scheduler_loop_should_change_state_for_tis_without_dagrun(
         # This poll interval is large, but the scheduler doesn't sleep that
         # long; instead we hit the clean_tis_without_dagrun interval
         self.scheduler_job = SchedulerJob(num_runs=2, processor_poll_interval=30)
-        self.scheduler_job.dagbag = dagbag
+        self.scheduler_job.dagbag = dag_maker.dagbag
         executor = MockExecutor(do_update=False)
         executor.queued_tasks
         self.scheduler_job.executor = executor
@@ -1253,22 +1211,18 @@ def test_dagrun_timeout_verify_max_active_runs(self, dag_maker):
         with dag_maker(
             dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout',
             start_date=DEFAULT_DATE,
+            max_active_runs=1,
+            dagrun_timeout=datetime.timedelta(seconds=60),
         ) as dag:
             DummyOperator(task_id='dummy')
-        dag.max_active_runs = 1
-        dag.dagrun_timeout = datetime.timedelta(seconds=60)

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-
-        self.scheduler_job.dagbag.sync_to_db()
+        self.scheduler_job.dagbag = dag_maker.dagbag

         session = settings.Session()
         orm_dag = session.query(DagModel).get(dag.dag_id)
         assert orm_dag is not None

-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
         self.scheduler_job._create_dag_runs([orm_dag], session)
         self.scheduler_job._start_queued_dagruns(session)
@@ -1321,30 +1275,18 @@ def test_dagrun_timeout_fails_run(self, dag_maker):
         """
         Test that a dagrun will be set to failed if it times out, even without max_active_runs
         """
-        with dag_maker(dag_id='test_scheduler_fail_dagrun_timeout') as dag:
-            DummyOperator(task_id='dummy')
-        dag.dagrun_timeout = datetime.timedelta(seconds=60)
-
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-        self.scheduler_job.dagbag.sync_to_db()
-
         session = settings.Session()
-        orm_dag = session.query(DagModel).get(dag.dag_id)
-        assert orm_dag is not None
+        with dag_maker(
+            dag_id='test_scheduler_fail_dagrun_timeout',
+            dagrun_timeout=datetime.timedelta(seconds=60),
+            session=session,
+        ):
+            DummyOperator(task_id='dummy')

-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+        dr = dag_maker.create_dagrun(start_date=timezone.utcnow() - datetime.timedelta(days=1))

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job._create_dag_runs([orm_dag], session)
-
-        drs = DagRun.find(dag_id=dag.dag_id, session=session)
-        assert len(drs) == 1
-        dr = drs[0]
-
-        # Should be scheduled as dagrun_timeout has passed
-        dr.start_date = timezone.utcnow() - datetime.timedelta(days=1)
-        session.flush()
+        self.scheduler_job.dagbag = dag_maker.dagbag

         # Mock that processor_agent is started
         self.scheduler_job.processor_agent = mock.Mock()
@@ -1386,25 +1328,13 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak
             DummyOperator(task_id='dummy')

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.dagbag = dag_maker.dagbag
         self.scheduler_job.processor_agent = mock.Mock()
         self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock()
         self.scheduler_job._send_sla_callbacks_to_processor = mock.Mock()

-        # Sync DAG into DB
-        with mock.patch.object(settings, "STORE_DAG_CODE", False):
-            self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-            self.scheduler_job.dagbag.sync_to_db()
-
         session = settings.Session()
-        orm_dag = session.query(DagModel).get(dag.dag_id)
-        assert orm_dag is not None
-
-        # Create DagRun
-        self.scheduler_job._create_dag_runs([orm_dag], session)
-
-        drs = DagRun.find(dag_id=dag.dag_id, session=session)
-        assert len(drs) == 1
-        dr = drs[0]
+        dr = dag_maker.create_dagrun()

         ti = dr.get_task_instance('dummy')
         ti.set_state(state, session)
@@ -1424,7 +1354,9 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak
         self.scheduler_job.processor_agent.send_callback_to_execute.assert_called_once_with(expected_callback)
         # This is already tested separately
         # In this test we just want to verify that this function is called
-        self.scheduler_job._send_sla_callbacks_to_processor.assert_called_once_with(dag)
+        # `dag` is a lazy-object-proxy -- we need to resolve it
+        real_dag = self.scheduler_job.dagbag.get_dag(dag.dag_id)
+        self.scheduler_job._send_sla_callbacks_to_processor.assert_called_once_with(real_dag)

         session.rollback()
         session.close()
@@ -1434,7 +1366,7 @@ def test_dagrun_callbacks_commited_before_sent(self, dag_maker):
         Tests that before any callbacks are sent to the processor, the session is committed.
         This ensures that the dagrun details are up to date when the callbacks are run.
         """
-        with dag_maker(dag_id='test_dagrun_callbacks_commited_before_sent') as dag:
+        with dag_maker(dag_id='test_dagrun_callbacks_commited_before_sent'):
             DummyOperator(task_id='dummy')

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
@@ -1442,21 +1374,8 @@ def test_dagrun_callbacks_commited_before_sent(self, dag_maker):
         self.scheduler_job._send_dag_callbacks_to_processor = mock.Mock()
         self.scheduler_job._schedule_dag_run = mock.Mock()

-        # Sync DAG into DB
-        with mock.patch.object(settings, "STORE_DAG_CODE", False):
-            self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-            self.scheduler_job.dagbag.sync_to_db()
-
+        dr = dag_maker.create_dagrun()
         session = settings.Session()
-        orm_dag = session.query(DagModel).get(dag.dag_id)
-        assert orm_dag is not None
-
-        # Create DagRun
-        self.scheduler_job._create_dag_runs([orm_dag], session)
-
-        drs = DagRun.find(dag_id=dag.dag_id, session=session)
-        assert len(drs) == 1
-        dr = drs[0]

         ti = dr.get_task_instance('dummy')
         ti.set_state(State.SUCCESS, session)
@@ -1494,7 +1413,7 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta
         """
         with dag_maker(
             dag_id='test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined',
-        ) as dag:
+        ):
             BashOperator(task_id='test_task', bash_command='echo hi')

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
@@ -1502,22 +1421,8 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta
         self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock()
         self.scheduler_job._send_dag_callbacks_to_processor = mock.Mock()

-        # Sync DAG into DB
-        with mock.patch.object(settings, "STORE_DAG_CODE", False):
-            self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-            self.scheduler_job.dagbag.sync_to_db()
-
         session = settings.Session()
-        orm_dag = session.query(DagModel).get(dag.dag_id)
-        assert orm_dag is not None
-
-        # Create DagRun
-        self.scheduler_job._create_dag_runs([orm_dag], session)
-
-        drs = DagRun.find(dag_id=dag.dag_id, session=session)
-        assert len(drs) == 1
-        dr = drs[0]
-
+        dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance('test_task')
         ti.set_state(state, session)
@@ -1540,8 +1445,6 @@ def test_do_not_schedule_removed_task(self, dag_maker):

         session = settings.Session()

-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
         dr = dag_maker.create_dagrun()
         assert dr is not None

@@ -1551,7 +1454,7 @@ def test_do_not_schedule_removed_task(self, dag_maker):
         with dag_maker(
             dag_id='test_scheduler_do_not_schedule_removed_task',
             start_date=dag.following_schedule(DEFAULT_DATE),
-        ) as dag:
+        ):
             pass

         self.scheduler_job = SchedulerJob(subdir=os.devnull)
@@ -1879,22 +1782,11 @@ def test_scheduler_verify_pool_full(self, dag_maker):
                 bash_command='echo hi',
             )

-        dagbag = DagBag(
-            dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
-            include_examples=False,
-            read_dags_from_db=True,
-        )
-        dagbag.bag_dag(dag=dag, root_dag=dag)
-        dagbag.sync_to_db()
-
         session = settings.Session()
         pool = Pool(pool='test_scheduler_verify_pool_full', slots=1)
         session.add(pool)
         session.flush()

-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-        SerializedDagModel.write_dag(dag)
-
         self.scheduler_job = SchedulerJob(executor=self.null_exec)
         self.scheduler_job.processor_agent = mock.MagicMock()

@@ -1929,21 +1821,10 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker):
                 bash_command='echo hi',
             )

-        dagbag = DagBag(
-            dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
-            include_examples=False,
-            read_dags_from_db=True,
-        )
-        dagbag.bag_dag(dag=dag, root_dag=dag)
-        dagbag.sync_to_db()
-
         session = settings.Session()
         pool = Pool(pool='test_scheduler_verify_pool_full_2_slots_per_task', slots=6)
         session.add(pool)
-        session.commit()
-        SerializedDagModel.write_dag(dag)
-
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+        session.flush()

         self.scheduler_job = SchedulerJob(executor=self.null_exec)
         self.scheduler_job.processor_agent = mock.MagicMock()
@@ -1984,23 +1865,13 @@ def test_scheduler_keeps_scheduling_pool_full(self, dag_maker):
                 pool='test_scheduler_keeps_scheduling_pool_full_p2',
                 bash_command='echo hi',
             )
-        dagbag = DagBag(
-            dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"),
-            include_examples=False,
-            read_dags_from_db=True,
-        )
-        dagbag.bag_dag(dag=dag_d1, root_dag=dag_d1)
-        dagbag.bag_dag(dag=dag_d2, root_dag=dag_d2)
-        dagbag.sync_to_db()

         session = settings.Session()
         pool_p1 = Pool(pool='test_scheduler_keeps_scheduling_pool_full_p1', slots=1)
         pool_p2 = Pool(pool='test_scheduler_keeps_scheduling_pool_full_p2', slots=10)
         session.add(pool_p1)
         session.add(pool_p2)
-        session.commit()
-
-        dag_d1 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag_d1))
+        session.flush()

         scheduler = SchedulerJob(executor=self.null_exec)
         scheduler.processor_agent = mock.MagicMock()
@@ -2045,7 +1916,7 @@ def test_scheduler_verify_priority_and_slots(self, dag_maker):
         Though tasks with lower priority might be executed.
""" - with dag_maker(dag_id='test_scheduler_verify_priority_and_slots') as dag: + with dag_maker(dag_id='test_scheduler_verify_priority_and_slots'): # Medium priority, not enough slots BashOperator( task_id='test_scheduler_verify_priority_and_slots_t0', @@ -2071,21 +1942,10 @@ def test_scheduler_verify_priority_and_slots(self, dag_maker): bash_command='echo hi', ) - dagbag = DagBag( - dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), - include_examples=False, - read_dags_from_db=True, - ) - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db() - session = settings.Session() pool = Pool(pool='test_scheduler_verify_priority_and_slots', slots=2) session.add(pool) - session.commit() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - SerializedDagModel.write_dag(dag) + session.flush() self.scheduler_job = SchedulerJob(executor=self.null_exec) self.scheduler_job.processor_agent = mock.MagicMock() @@ -2132,11 +1992,9 @@ def test_verify_integrity_if_dag_not_changed(self, dag_maker): BashOperator(task_id='dummy', bash_command='echo hi') self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) - self.scheduler_job.dagbag.sync_to_db() session = settings.Session() - orm_dag = session.query(DagModel).get(dag.dag_id) + orm_dag = dag_maker.dag_model assert orm_dag is not None self.scheduler_job = SchedulerJob(subdir=os.devnull) @@ -2184,11 +2042,9 @@ def test_verify_integrity_if_dag_changed(self, dag_maker): BashOperator(task_id='dummy', bash_command='echo hi') self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) - self.scheduler_job.dagbag.sync_to_db() session = settings.Session() - orm_dag = session.query(DagModel).get(dag.dag_id) + orm_dag = dag_maker.dag_model assert orm_dag is not None self.scheduler_job = SchedulerJob(subdir=os.devnull) @@ -2255,16 +2111,12 @@ def test_retry_still_in_executor(self, dag_maker): bash_command='exit 1', retries=1, ) - dag.is_subdag = False with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) orm_dag.is_paused = False session.merge(orm_dag) - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db() - @mock.patch('airflow.dag_processing.processor.DagBag', return_value=dagbag) def do_schedule(mock_dagbag): # Use a empty file since the above mock will return the @@ -2628,9 +2480,6 @@ def test_send_sla_callbacks_to_processor_sla_with_task_slas(self, dag_maker): with dag_maker(dag_id=dag_id, schedule_interval='@daily') as dag: DummyOperator(task_id='task1', sla=timedelta(seconds=60)) - # Used Serialized DAG as Serialized DAG is used in Scheduler - dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) - with patch.object(settings, "CHECK_SLAS", True): self.scheduler_job = SchedulerJob(subdir=os.devnull) mock_agent = mock.MagicMock() @@ -2653,14 +2502,7 @@ def test_create_dag_runs(self, dag_maker): with dag_maker(dag_id='test_create_dag_runs') as dag: DummyOperator(task_id='dummy') - dagbag = DagBag( - dag_folder=os.devnull, - include_examples=False, - read_dags_from_db=True, - ) - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db() - dag_model = DagModel.get_dagmodel(dag.dag_id) + dag_model = dag_maker.dag_model self.scheduler_job = SchedulerJob(executor=self.null_exec) self.scheduler_job.processor_agent = mock.MagicMock() @@ -2689,14 +2531,7 @@ def test_start_dagruns(self, stats_timing, dag_maker): task_id='dummy', ) - dagbag = DagBag( - dag_folder=os.devnull, - include_examples=False, - 
read_dags_from_db=True, - ) - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db() - dag_model = DagModel.get_dagmodel(dag.dag_id) + dag_model = dag_maker.dag_model self.scheduler_job = SchedulerJob(executor=self.null_exec) self.scheduler_job.processor_agent = mock.MagicMock() @@ -2726,17 +2561,7 @@ def test_extra_operator_links_not_loaded_in_scheduler_loop(self, dag_maker): # This CustomOperator has Extra Operator Links registered via plugins _ = CustomOperator(task_id='custom_task') - dagbag = DagBag( - dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), - include_examples=False, - read_dags_from_db=True, - ) - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db() - - # Get serialized dag - s_dag_1 = dagbag.get_dag(dag.dag_id) - custom_task = s_dag_1.task_dict['custom_task'] + custom_task = dag.task_dict['custom_task'] # Test that custom_task has >= 1 Operator Links (after de-serialization) assert custom_task.operator_extra_links @@ -2755,21 +2580,11 @@ def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker) Test that scheduler._create_dag_runs does not raise an error when the DAG does not exist in serialized_dag table """ - with dag_maker(dag_id='test_scheduler_create_dag_runs_does_not_raise_error') as dag: + with dag_maker(dag_id='test_scheduler_create_dag_runs_does_not_raise_error', serialized=False): DummyOperator( task_id='dummy', ) - dagbag = DagBag( - dag_folder=os.devnull, - include_examples=False, - read_dags_from_db=False, - ) - dagbag.bag_dag(dag=dag, root_dag=dag) - # Only write to dag table and not serialized_dag table - DAG.bulk_write_to_db(dagbag.dags.values()) - dag_model = DagModel.get_dagmodel(dag.dag_id) - self.scheduler_job = SchedulerJob(subdir=os.devnull, executor=self.null_exec) self.scheduler_job.processor_agent = mock.MagicMock() @@ -2779,7 +2594,7 @@ def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker) 'ERROR', logger='airflow.jobs.scheduler_job', ): - self.scheduler_job._create_dag_runs([dag_model], session) + self.scheduler_job._create_dag_runs([dag_maker.dag_model], session) assert caplog.messages == [ "DAG 'test_scheduler_create_dag_runs_does_not_raise_error' not found in serialized_dag table", ] @@ -2798,18 +2613,9 @@ def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_mak DummyOperator(task_id='dummy') session = settings.Session() - dagbag = DagBag( - dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), - include_examples=False, - read_dags_from_db=True, - ) - dagbag.bag_dag(dag=dag, root_dag=dag) - # Write to dag and serialized_dag table - dagbag.sync_to_db(session) - dag = dagbag.get_dag(dag.dag_id) # Verify that dag_model.next_dagrun is equal to next execution_date - dag_model = session.query(DagModel).get(dag.dag_id) + dag_model = dag_maker.dag_model assert dag_model.next_dagrun == DEFAULT_DATE assert dag_model.next_dagrun_data_interval_start == DEFAULT_DATE assert dag_model.next_dagrun_data_interval_end == DEFAULT_DATE + timedelta(minutes=1) @@ -2870,22 +2676,11 @@ def test_scheduler_create_dag_runs_check_existing_run(self, dag_maker): session = settings.Session() assert dag.get_last_dagrun(session) is None - dagbag = DagBag( - dag_folder=os.devnull, - include_examples=False, - read_dags_from_db=False, - ) - dagbag.bag_dag(dag=dag, root_dag=dag) - - # Create DagModel - DAG.bulk_write_to_db(dagbag.dags.values()) - dag_model = DagModel.get_dagmodel(dag.dag_id) + dag_model = dag_maker.dag_model # Assert dag_model.next_dagrun is set correctly assert 
dag_model.next_dagrun == DEFAULT_DATE - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagrun = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=dag_model.next_dagrun, @@ -2900,7 +2695,6 @@ def test_scheduler_create_dag_runs_check_existing_run(self, dag_maker): assert dag.get_last_dagrun(session) == dagrun self.scheduler_job = SchedulerJob(subdir=os.devnull, executor=self.null_exec) - self.scheduler_job.dagbag = dagbag self.scheduler_job.processor_agent = mock.MagicMock() # Test that this does not raise any error @@ -2929,14 +2723,6 @@ def test_do_schedule_max_active_runs_dag_timed_out(self, dag_maker): ) session = settings.Session() - dagbag = DagBag( - dag_folder=os.devnull, - include_examples=False, - read_dags_from_db=True, - ) - - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db(session=session) run1 = dag.create_dagrun( run_type=DagRunType.SCHEDULED, @@ -2988,15 +2774,6 @@ def test_do_schedule_max_active_runs_task_removed(self, dag_maker): task1 = BashOperator(task_id='dummy1', bash_command='true') session = settings.Session() - dagbag = DagBag( - dag_folder=os.devnull, - include_examples=False, - read_dags_from_db=True, - ) - - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db(session=session) - session.add(TaskInstance(task1, DEFAULT_DATE, State.REMOVED)) session.flush() @@ -3039,15 +2816,6 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker): BashOperator(task_id='dummy3', bash_command='true') session = settings.Session() - dagbag = DagBag( - dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), - include_examples=False, - read_dags_from_db=True, - ) - - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db(session=session) - dag_run = dag_maker.create_dagrun( state=State.QUEUED, session=session, @@ -3106,14 +2874,11 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ Test if _process_task_instances puts the right task instances into the mock_list. """ - with dag_maker(dag_id='test_scheduler_process_execute_task') as dag: + with dag_maker(dag_id='test_scheduler_process_execute_task'): BashOperator(task_id='dummy', bash_command='echo hi') - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - self.scheduler_job = SchedulerJob(subdir=os.devnull) self.scheduler_job.processor_agent = mock.MagicMock() - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) assert dr is not None @@ -3153,14 +2918,11 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency( Test if _process_task_instances puts the right task instances into the mock_list. 
""" - with dag_maker(dag_id='test_scheduler_process_execute_task_with_task_concurrency') as dag: + with dag_maker(dag_id='test_scheduler_process_execute_task_with_task_concurrency'): BashOperator(task_id='dummy', task_concurrency=2, bash_command='echo Hi') - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - self.scheduler_job = SchedulerJob(subdir=os.devnull) self.scheduler_job.processor_agent = mock.MagicMock() - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, @@ -3207,15 +2969,12 @@ def test_dag_file_processor_process_task_instances_depends_on_past( default_args={ 'depends_on_past': True, }, - ) as dag: + ): BashOperator(task_id='dummy1', bash_command='echo hi') BashOperator(task_id='dummy2', bash_command='echo hi') - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - self.scheduler_job = SchedulerJob(subdir=os.devnull) self.scheduler_job.processor_agent = mock.MagicMock() - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun( run_type=DagRunType.SCHEDULED, ) @@ -3244,14 +3003,10 @@ def test_scheduler_job_add_new_task(self, dag_maker): BashOperator(task_id='dummy', bash_command='echo test') self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) - - # Since we don't want to store the code for the DAG defined in this file - with mock.patch.object(settings, "STORE_DAG_CODE", False): - self.scheduler_job.dagbag.sync_to_db() + self.scheduler_job.dagbag = dag_maker.dagbag session = settings.Session() - orm_dag = session.query(DagModel).get(dag.dag_id) + orm_dag = dag_maker.dag_model assert orm_dag is not None if self.scheduler_job.processor_agent: @@ -3286,22 +3041,13 @@ def test_runs_respected_after_clear(self, dag_maker): """ Test dag after dag.clear, max_active_runs is respected """ - with dag_maker(dag_id='test_scheduler_max_active_runs_respected_after_clear') as dag: + with dag_maker( + dag_id='test_scheduler_max_active_runs_respected_after_clear', max_active_runs=1 + ) as dag: BashOperator(task_id='dummy', bash_command='echo Hi') - dag.max_active_runs = 1 - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - # Write Dag to DB - dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False) - dagbag.bag_dag(dag, root_dag=dag) - dagbag.sync_to_db() - - dag = DagBag(read_dags_from_db=True, include_examples=False).get_dag(dag.dag_id) self.scheduler_job = SchedulerJob(subdir=os.devnull) self.scheduler_job.processor_agent = mock.MagicMock() - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) session = settings.Session() date = DEFAULT_DATE @@ -3327,47 +3073,43 @@ def test_runs_respected_after_clear(self, dag_maker): session = settings.Session() self.scheduler_job._start_queued_dagruns(session) - session.commit() + session.flush() # Assert that only 1 dagrun is active assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1 # Assert that the other two are queued assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 2 - def test_timeout_triggers(self): + def test_timeout_triggers(self, dag_maker): """ Tests that tasks in the deferred state, but whose trigger timeout has expired, are correctly failed. 
""" + session = settings.Session() # Create the test DAG and task - with DAG( + with dag_maker( dag_id='test_timeout_triggers', start_date=DEFAULT_DATE, schedule_interval='@once', max_active_runs=1, - ) as dag: - task1 = DummyOperator(task_id='dummy1') - - # Load it into the DagBag - session = settings.Session() - dagbag = DagBag( - dag_folder=os.devnull, - include_examples=False, - read_dags_from_db=True, - ) - dagbag.bag_dag(dag=dag, root_dag=dag) - dagbag.sync_to_db(session=session) + session=session, + ): + DummyOperator(task_id='dummy1') # Create a Task Instance for the task that is allegedly deferred # but past its timeout, and one that is still good. # We don't actually need a linked trigger here; the code doesn't check. - ti1 = TaskInstance(task1, DEFAULT_DATE, State.DEFERRED) - ti2 = TaskInstance(task1, DEFAULT_DATE + datetime.timedelta(seconds=1), State.DEFERRED) + dr1 = dag_maker.create_dagrun() + dr2 = dag_maker.create_dagrun( + run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(seconds=1) + ) + ti1 = dr1.get_task_instance('dummy1', session) + ti2 = dr2.get_task_instance('dummy1', session) + ti1.state = State.DEFERRED ti1.trigger_timeout = timezone.utcnow() - datetime.timedelta(seconds=60) + ti2.state = State.DEFERRED ti2.trigger_timeout = timezone.utcnow() + datetime.timedelta(seconds=60) - session.add(ti1) - session.add(ti2) session.flush() # Boot up the scheduler and make it check timeouts @@ -3375,8 +3117,8 @@ def test_timeout_triggers(self): self.scheduler_job.check_trigger_timeouts(session=session) # Make sure that TI1 is now scheduled to fail, and 2 wasn't touched - ti1.refresh_from_db() - ti2.refresh_from_db() + session.refresh(ti1) + session.refresh(ti2) assert ti1.state == State.SCHEDULED assert ti1.next_method == "__fail__" assert ti2.state == State.DEFERRED