Have the dag_maker fixture (optionally) give SerializedDAGs (apache#17577)

All but one test in test_scheduler_job.py wants to operate on serialized
DAGs, so it makes sense to have the dag_maker do this for us, making
each test "smaller".
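
As a rough sketch of how a test can lean on the updated fixture (the test name, DAG id, and operator choice here are illustrative, not part of this commit):

    import pytest
    from airflow.operators.dummy import DummyOperator


    @pytest.mark.need_serialized_dag
    def test_scheduling_a_serialized_dag(dag_maker):
        # Because of the marker, `dag` is a lazy proxy around the serialized DAG.
        with dag_maker(dag_id='example_dag') as dag:
            DummyOperator(task_id='task_1')
        # create_dagrun forwards to DAG.create_dagrun with sensible test defaults.
        dag_run = dag_maker.create_dagrun()
        assert dag_run.dag_id == dag.dag_id
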
ashb committed Aug 18, 2021
1 parent 1cd3d8f commit d8c0cfe
Showing 3 changed files with 175 additions and 360 deletions.
2 changes: 2 additions & 0 deletions pytest.ini
@@ -32,3 +32,5 @@ faulthandler_timeout = 480
log_level = INFO
filterwarnings =
error::pytest.PytestCollectionWarning
markers =
need_serialized_dag
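
The marker registered above can be applied per test, per class, or per module. A minimal, hypothetical sketch of both directions (the test and DAG names are made up):

    import pytest

    # Opt every test in this module in to serialized DAGs ...
    pytestmark = pytest.mark.need_serialized_dag


    # ... and opt a single test back out again if it needs a plain DAG.
    @pytest.mark.need_serialized_dag(False)
    def test_with_plain_dag(dag_maker):
        with dag_maker('plain_dag') as dag:
            ...  # here `dag` is the ordinary DAG object, not a serialized proxy
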
113 changes: 92 additions & 21 deletions tests/conftest.py
@@ -429,7 +429,7 @@ def app():
@pytest.fixture
def dag_maker(request):
"""
The dag_maker helps us to create DAG & DagModel automatically.
The dag_maker helps us to create DAG, DagModel, and SerializedDAG automatically.
You have to use the dag_maker as a context manager and it takes
the same arguments as DAG::
@@ -451,49 +451,89 @@ def dag_maker(request):
The dag_maker.create_dagrun takes the same arguments as dag.create_dagrun
If you want to operate on serialized DAGs, then either pass ``serialized=True`` to the ``dag_maker()``
call, or you can mark your test/class/file with ``@pytest.mark.need_serialized_dag(True)``. In both of
these cases the ``dag`` returned by the context manager will be a lazily-evaluated proxy object to the
SerializedDAG.
"""
from airflow.models import DAG, DagModel
from airflow.utils import timezone
from airflow.utils.session import provide_session
from airflow.utils.state import State
import lazy_object_proxy

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
# IMPORTANT: Delay _all_ imports from `airflow.*` to _inside a method_.
# This fixture is "called" early on in the pytest collection process, and
# if we import airflow.* here the wrong (non-test) config will be loaded
# and "baked" in to various constants

want_serialized = False

# Allow changing default serialized behaviour with `@pytest.mark.need_serialized_dag` or
# `@pytest.mark.need_serialized_dag(False)`
serialized_marker = request.node.get_closest_marker("need_serialized_dag")
if serialized_marker:
(want_serialized,) = serialized_marker.args or (True,)

class DagFactory:
def __init__(self):
from airflow.models import DagBag

# Keep all the serialized dags we've created in this test
self.dagbag = DagBag(os.devnull, include_examples=False, read_dags_from_db=False)

def __enter__(self):
self.dag.__enter__()
if self.want_serialized:
return lazy_object_proxy.Proxy(self._serialized_dag)
return self.dag

def _serialized_dag(self):
return self.serialized_model.dag

def __exit__(self, type, value, traceback):
from airflow.models import DagModel
from airflow.models.serialized_dag import SerializedDagModel

dag = self.dag
dag.__exit__(type, value, traceback)
if type is None:
dag.clear()
self.dag_model = DagModel(
dag_id=dag.dag_id,
next_dagrun=dag.start_date,
is_active=True,
is_paused=False,
max_active_tasks=dag.max_active_tasks,
has_task_concurrency_limits=False,
)
self.session.add(self.dag_model)
if type is not None:
return

dag.clear()
dag.sync_to_db(self.session)
self.dag_model = self.session.query(DagModel).get(dag.dag_id)

if self.want_serialized:
self.serialized_model = SerializedDagModel(dag)
self.session.merge(self.serialized_model)
serialized_dag = self._serialized_dag()
self.dagbag.bag_dag(serialized_dag, root_dag=serialized_dag)
self.session.flush()
else:
self.dagbag.bag_dag(self.dag, self.dag)

def create_dagrun(self, **kwargs):
from airflow.utils.state import State

dag = self.dag
defaults = dict(
run_id='test',
state=State.RUNNING,
execution_date=self.start_date,
start_date=self.start_date,
session=self.session,
)
kwargs = {**defaults, **kwargs}
self.dag_run = dag.create_dagrun(**kwargs)
return self.dag_run

@provide_session
def __call__(self, dag_id='test_dag', session=None, **kwargs):
def __call__(
self, dag_id='test_dag', serialized=want_serialized, fileloc=None, session=None, **kwargs
):
from airflow import settings
from airflow.models import DAG
from airflow.utils import timezone

if session is None:
session = settings.Session()

self.kwargs = kwargs
self.session = session
self.start_date = self.kwargs.get('start_date', None)
@@ -506,13 +546,44 @@ def __call__(self, dag_id='test_dag', session=None, **kwargs):
if hasattr(request.module, 'DEFAULT_DATE'):
self.start_date = getattr(request.module, 'DEFAULT_DATE')
else:
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
self.start_date = DEFAULT_DATE
self.kwargs['start_date'] = self.start_date
self.dag = DAG(dag_id, **self.kwargs)
self.dag.fileloc = request.module.__file__
self.dag.fileloc = fileloc or request.module.__file__
self.want_serialized = serialized

return self

return DagFactory()
def cleanup(self):
from airflow.models import DagModel, DagRun, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel

dag_ids = list(self.dagbag.dag_ids)
if not dag_ids:
return
# To isolate problems here with problems from elsewhere on the session object
self.session.flush()

self.session.query(SerializedDagModel).filter(SerializedDagModel.dag_id.in_(dag_ids)).delete(
synchronize_session=False
)
self.session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids)).delete(synchronize_session=False)
self.session.query(TaskInstance).filter(TaskInstance.dag_id.in_(dag_ids)).delete(
synchronize_session=False
)
self.session.query(DagModel).filter(DagModel.dag_id.in_(dag_ids)).delete(
synchronize_session=False
)
self.session.commit()

factory = DagFactory()

try:
yield factory
finally:
factory.cleanup()
del factory.session


@pytest.fixture
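
Tying the new pieces together, a test can also opt in per call with ``serialized=True`` instead of the marker; a rough, illustrative sketch (the test and DAG names are assumptions, not from this commit):

    from airflow.operators.dummy import DummyOperator


    def test_serialized_per_call(dag_maker):
        # Opt in for this one call instead of relying on the marker or a module default.
        with dag_maker('per_call_dag', serialized=True) as dag:
            DummyOperator(task_id='only_task')
        # `dag` is a lazy_object_proxy.Proxy: the serialized DAG is only
        # materialised when an attribute is first accessed.
        assert dag.dag_id == 'per_call_dag'
        # Keyword arguments override the fixture defaults (run_id='test', state=RUNNING, ...).
        dag_run = dag_maker.create_dagrun(run_id='manual_run')
        assert dag_run.run_id == 'manual_run'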
