Skip to content

Commit

Permalink
[dagster-airflow] add airflow 2 support to `make_dagster_job_from_air…
Browse files Browse the repository at this point in the history
…flow_dag` + xcom mock option (#10337)
  • Loading branch information
Ramshackle-Jamathon authored and alangenfeld committed Nov 3, 2022
1 parent 9157ebb commit 251b666
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 27 deletions.
6 changes: 4 additions & 2 deletions docs/content/integrations/airflow.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ Dagster can convert Airflow operators to Dagster ops for any Airflow operators t

```python file=../../with_airflow/with_airflow/airflow_operator_to_op.py startafter=start_operator_to_op_1 endbefore=end_operator_to_op_1
http_task = SimpleHttpOperator(task_id="http_task", method="GET", endpoint="images")
connections = [Connection(conn_id="http_default", host="https://google.com")]
connections = [Connection(conn_id="http_default", conn_type="uri", host="https://google.com")]
dagster_op = airflow_operator_to_op(http_task, connections=connections)


Expand Down Expand Up @@ -190,17 +190,19 @@ There are two jobs in the repo:
```python file=../../with_airflow/with_airflow/repository.py startafter=start_repo_marker_0 endbefore=end_repo_marker_0
from dagster_airflow.dagster_job_factory import make_dagster_job_from_airflow_dag
from with_airflow.airflow_complex_dag import complex_dag
from with_airflow.airflow_kubernetes_dag import kubernetes_dag
from with_airflow.airflow_simple_dag import simple_dag

from dagster import repository

airflow_simple_dag = make_dagster_job_from_airflow_dag(simple_dag)
airflow_complex_dag = make_dagster_job_from_airflow_dag(complex_dag)
airflow_kubernetes_dag = make_dagster_job_from_airflow_dag(kubernetes_dag, mock_xcom=True)


@repository
def with_airflow():
return [airflow_complex_dag, airflow_simple_dag]
return [airflow_complex_dag, airflow_simple_dag, airflow_kubernetes_dag]
```

Note that the "execution_date" for the Airflow DAG is specified through the job tags. To specify tags, call:
Expand Down
10 changes: 6 additions & 4 deletions examples/with_airflow/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
install_requires=[
"dagster",
"dagster_airflow",
# See https://github.com/dagster-io/dagster/issues/2701
"apache-airflow==1.10.10",
# Conflicts with `Jinja2` which is used in dagster cli that dagster_airflow depends on
"markupsafe<=2.0.1",
"apache-airflow==2.3.0",
# pin jinja2 to version compatible with dagit and airflow
"jinja2==3.0.3",
# for the kubernetes operator
"apache-airflow-providers-cncf-kubernetes>=4.4.0",
"apache-airflow-providers-docker>=3.1.0",
],
extras_require={"dev": ["dagit", "pytest"]},
)
31 changes: 31 additions & 0 deletions examples/with_airflow/with_airflow/airflow_kubernetes_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# pylint: disable=pointless-statement
"""Example Airflow DAG: run one Kubernetes pod, then a no-op sink task.

The DAG object is exposed as ``kubernetes_dag`` so that it can be imported
and converted into a Dagster job.
"""

from airflow import models

# ``airflow.operators.dummy_operator.DummyOperator`` is a deprecated alias;
# ``EmptyOperator`` is its replacement as of Airflow 2.3.
from airflow.operators.empty import EmptyOperator
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
from airflow.utils.dates import days_ago

default_args = {"start_date": days_ago(1)}

# schedule_interval=None: the DAG declares no schedule of its own.
with models.DAG(
    dag_id="kubernetes_dag", default_args=default_args, schedule_interval=None
) as kubernetes_dag:

    # No-op task that acts as the sink of the DAG.
    run_this_last = EmptyOperator(
        task_id="sink_task_1",
        dag=kubernetes_dag,
    )

    k = KubernetesPodOperator(
        name="hello-dry-run",
        # will need to be modified to match your k8s context
        cluster_context="hooli-user-cluster",
        namespace="default",
        image="debian",
        cmds=["bash", "-cx"],
        arguments=["echo", "10"],
        labels={"foo": "bar"},
        task_id="dry_run_demo",
        do_xcom_push=True,
    )

    # The pod task runs first; the sink task depends on it.
    k >> run_this_last
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# start_operator_to_op_1
http_task = SimpleHttpOperator(task_id="http_task", method="GET", endpoint="images")
connections = [Connection(conn_id="http_default", host="https://google.com")]
connections = [Connection(conn_id="http_default", conn_type="uri", host="https://google.com")]
dagster_op = airflow_operator_to_op(http_task, connections=connections)


Expand Down
4 changes: 3 additions & 1 deletion examples/with_airflow/with_airflow/repository.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
# start_repo_marker_0
from dagster_airflow.dagster_job_factory import make_dagster_job_from_airflow_dag
from with_airflow.airflow_complex_dag import complex_dag
from with_airflow.airflow_kubernetes_dag import kubernetes_dag
from with_airflow.airflow_simple_dag import simple_dag

from dagster import repository

airflow_simple_dag = make_dagster_job_from_airflow_dag(simple_dag)
airflow_complex_dag = make_dagster_job_from_airflow_dag(complex_dag)
airflow_kubernetes_dag = make_dagster_job_from_airflow_dag(kubernetes_dag, mock_xcom=True)


@repository
def with_airflow():
return [airflow_complex_dag, airflow_simple_dag]
return [airflow_complex_dag, airflow_simple_dag, airflow_kubernetes_dag]


# end_repo_marker_0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def make_dagster_job_from_airflow_dag(
dag, tags=None, use_airflow_template_context=False, unique_id=None
dag, tags=None, use_airflow_template_context=False, unique_id=None, mock_xcom=False
):
"""Construct a Dagster job corresponding to a given Airflow DAG.
Expand Down Expand Up @@ -45,13 +45,15 @@ def make_dagster_job_from_airflow_dag(
(default: False)
unique_id (int): If not None, this id will be postpended to generated op names. Used by
framework authors to enforce unique op names within a repo.
mock_xcom (bool): If True, dagster will mock out all calls made to xcom. Features that
depend on xcom may not work as expected. (default: False)
Returns:
JobDefinition: The generated Dagster job
"""
pipeline_def = make_dagster_pipeline_from_airflow_dag(
dag, tags, use_airflow_template_context, unique_id
dag, tags, use_airflow_template_context, unique_id, mock_xcom
)
# pass in tags manually because pipeline_def.graph doesn't have it threaded
return pipeline_def.graph.to_job(tags={**pipeline_def.tags})
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import datetime
import logging
import sys
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from unittest.mock import patch

import dateutil
import lazy_object_proxy
import pendulum
from airflow import __version__ as airflow_version
from airflow.models import TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
Expand Down Expand Up @@ -224,7 +226,7 @@ def make_repo_from_dir():


def make_dagster_pipeline_from_airflow_dag(
dag, tags=None, use_airflow_template_context=False, unique_id=None
dag, tags=None, use_airflow_template_context=False, unique_id=None, mock_xcom=False
):
"""Construct a Dagster pipeline corresponding to a given Airflow DAG.
Expand Down Expand Up @@ -274,6 +276,8 @@ def make_dagster_pipeline_from_airflow_dag(
(default: False)
unique_id (int): If not None, this id will be postpended to generated solid names. Used by
framework authors to enforce unique solid names within a repo.
mock_xcom (bool): If True, dagster will mock out all calls made to xcom. Features that
depend on xcom may not work as expected. (default: False)
Returns:
pipeline_def (PipelineDefinition): The generated Dagster pipeline
Expand All @@ -283,14 +287,15 @@ def make_dagster_pipeline_from_airflow_dag(
tags = check.opt_dict_param(tags, "tags")
check.bool_param(use_airflow_template_context, "use_airflow_template_context")
unique_id = check.opt_int_param(unique_id, "unique_id")
mock_xcom = check.opt_bool_param(mock_xcom, "mock_xcom")

if IS_AIRFLOW_INGEST_PIPELINE_STR not in tags:
tags[IS_AIRFLOW_INGEST_PIPELINE_STR] = "true"

tags = validate_tags(tags)

pipeline_dependencies, solid_defs = _get_pipeline_definition_args(
dag, use_airflow_template_context, unique_id
dag, use_airflow_template_context, unique_id, mock_xcom
)
pipeline_def = PipelineDefinition(
name=normalized_name(dag.dag_id, None),
Expand All @@ -312,7 +317,9 @@ def normalized_name(name, unique_id):
return base_name + "_" + str(unique_id)


def _get_pipeline_definition_args(dag, use_airflow_template_context, unique_id=None):
def _get_pipeline_definition_args(
dag, use_airflow_template_context, unique_id=None, mock_xcom=False
):
check.inst_param(dag, "dag", DAG)
check.bool_param(use_airflow_template_context, "use_airflow_template_context")
unique_id = check.opt_int_param(unique_id, "unique_id")
Expand All @@ -331,6 +338,7 @@ def _get_pipeline_definition_args(dag, use_airflow_template_context, unique_id=N
solid_defs,
use_airflow_template_context,
unique_id,
mock_xcom,
)
return (pipeline_dependencies, solid_defs)

Expand All @@ -342,16 +350,18 @@ def _traverse_airflow_dag(
solid_defs,
use_airflow_template_context,
unique_id,
mock_xcom,
):
check.inst_param(task, "task", BaseOperator)
check.list_param(seen_tasks, "seen_tasks", BaseOperator)
check.list_param(solid_defs, "solid_defs", SolidDefinition)
check.bool_param(use_airflow_template_context, "use_airflow_template_context")
unique_id = check.opt_int_param(unique_id, "unique_id")
mock_xcom = check.opt_bool_param(mock_xcom, "mock_xcom")

seen_tasks.append(task)
current_solid = make_dagster_solid_from_airflow_task(
task, use_airflow_template_context, unique_id
task, use_airflow_template_context, unique_id, mock_xcom
)
solid_defs.append(current_solid)

Expand Down Expand Up @@ -382,6 +392,7 @@ def _traverse_airflow_dag(
solid_defs,
use_airflow_template_context,
unique_id,
mock_xcom,
)


Expand All @@ -400,9 +411,18 @@ def replace_airflow_logger_handlers():
logging.getLogger("airflow.task").handlers = prev_airflow_handlers


@contextmanager
def _mock_xcom():
    """Patch ``TaskInstance.xcom_push`` and ``TaskInstance.xcom_pull`` for the duration of the block.

    Both attributes are replaced via ``unittest.mock.patch`` while the context
    is active and restored when it exits.
    """
    push_target = "airflow.models.TaskInstance.xcom_push"
    pull_target = "airflow.models.TaskInstance.xcom_pull"
    # Managers are entered left to right and exited in reverse, exactly like
    # two nested ``with`` statements.
    with patch(push_target), patch(pull_target):
        yield


# If unique_id is not None, this id will be postpended to generated solid names, generally used
# to enforce unique solid names within a repo.
def make_dagster_solid_from_airflow_task(task, use_airflow_template_context, unique_id=None):
def make_dagster_solid_from_airflow_task(
task, use_airflow_template_context, unique_id=None, mock_xcom=False
):
check.inst_param(task, "task", BaseOperator)
check.bool_param(use_airflow_template_context, "use_airflow_template_context")
unique_id = check.opt_int_param(unique_id, "unique_id")
Expand Down Expand Up @@ -443,19 +463,24 @@ def _solid(context): # pylint: disable=unused-argument

check.inst_param(execution_date, "execution_date", datetime.datetime)

with replace_airflow_logger_handlers():
task_instance = TaskInstance(task=task, execution_date=execution_date)

ti_context = (
dagster_get_template_context(task_instance, task, execution_date)
if not use_airflow_template_context
else task_instance.get_template_context()
)
task.render_template_fields(ti_context)
with _mock_xcom() if mock_xcom else nullcontext():
with replace_airflow_logger_handlers():
if airflow_version >= "2.0.0":
task_instance = TaskInstance(
task=task, execution_date=execution_date, run_id="dagster_airflow_run"
)
else:
task_instance = TaskInstance(task=task, execution_date=execution_date)
ti_context = (
dagster_get_template_context(task_instance, task, execution_date)
if not use_airflow_template_context
else task_instance.get_template_context()
)
task.render_template_fields(ti_context)

task.execute(ti_context)
task.execute(ti_context)

return None
return None

return _solid

Expand Down

0 comments on commit 251b666

Please sign in to comment.