feat(airflow): allow data to be a callable (#1318)
* feat(airflow): allow data to be a callable

* lint fix

* fix docstrings

* implement on_before_run argument
IlyaFaer committed May 8, 2024
1 parent 80684d1 commit 21f90b4
Showing 2 changed files with 113 additions and 5 deletions.
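In short, the `data` argument of the Airflow task-group helpers may now be a plain callable that is resolved inside the running task (so it can read the Airflow task context), and a new `on_before_run` callback is executed just before the pipeline run, ahead of the `data` callable. A minimal usage sketch, modeled on the tests added in this commit; the resource, callback, DAG parameters, and `PipelineTasksGroup` constructor arguments here are illustrative, not part of the diff:

import dlt
import pendulum
from airflow.decorators import dag
from airflow.operators.python import get_current_context
from dlt.helpers.airflow_helper import PipelineTasksGroup


def make_data():
    # Called inside the running task, right before the pipeline run,
    # so the Airflow task context is available here.
    context = get_current_context()

    @dlt.resource
    def example_rows():
        yield [{"id": 1, "ds": context["ds"]}]

    return example_rows


def warm_up():
    # Executed first, before make_data() is evaluated.
    print("about to run the pipeline")


@dag(schedule=None, start_date=pendulum.datetime(2024, 5, 1), catchup=False)
def example_dag():
    tasks = PipelineTasksGroup("example_group", wipe_local_data=True)
    pipeline = dlt.pipeline(
        pipeline_name="example_pipeline",
        dataset_name="example_data",
        destination="duckdb",
    )
    tasks.run(pipeline, make_data, on_before_run=warm_up)


example_dag()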
36 changes: 33 additions & 3 deletions dlt/helpers/airflow_helper.py
@@ -171,6 +171,7 @@ def run(
loader_file_format: TLoaderFileFormat = None,
schema_contract: TSchemaContract = None,
pipeline_name: str = None,
on_before_run: Callable[[], None] = None,
**kwargs: Any,
) -> PythonOperator:
"""
@@ -179,7 +180,12 @@ def run(
Args:
pipeline (Pipeline): The pipeline to run
data (Any): The data to run the pipeline with
data (Any):
The data to run the pipeline with. If a non-resource
callable is given, it is evaluated during DAG execution,
right before the actual pipeline run.
NOTE: If `on_before_run` is provided, `on_before_run` is
executed first, and the callable `data` is evaluated afterwards.
table_name (str, optional): The name of the table to
which the data should be loaded within the `dataset`.
write_disposition (TWriteDispositionConfig, optional): Same as
@@ -191,6 +197,8 @@ def run(
for the schema contract settings, this will replace
the schema contract settings for all tables in the schema.
pipeline_name (str, optional): The name of the derived pipeline.
on_before_run (Callable, optional): A callable to be
executed right before the actual pipeline run.
Returns:
PythonOperator: Airflow task instance.
@@ -204,6 +212,7 @@ def run(
loader_file_format=loader_file_format,
schema_contract=schema_contract,
pipeline_name=pipeline_name,
on_before_run=on_before_run,
)
return PythonOperator(task_id=self._task_name(pipeline, data), python_callable=f, **kwargs)

@@ -216,12 +225,18 @@ def _run(
loader_file_format: TLoaderFileFormat = None,
schema_contract: TSchemaContract = None,
pipeline_name: str = None,
on_before_run: Callable[[], None] = None,
) -> None:
"""Run the given pipeline with the given data.
Args:
pipeline (Pipeline): The pipeline to run
data (Any): The data to run the pipeline with
data (Any):
The data to run the pipeline with. If a non-resource
callable is given, it is evaluated during DAG execution,
right before the actual pipeline run.
NOTE: If `on_before_run` is provided, `on_before_run` is
executed first, and the callable `data` is evaluated afterwards.
table_name (str, optional): The name of the
table to which the data should be loaded
within the `dataset`.
@@ -236,6 +251,8 @@ def _run(
for all tables in the schema.
pipeline_name (str, optional): The name of the
derived pipeline.
on_before_run (Callable, optional): A callable
to be executed right before the actual pipeline run.
"""
# activate pipeline
pipeline.activate()
@@ -271,6 +288,12 @@ def log_after_attempt(retry_state: RetryCallState) -> None:
)

try:
if on_before_run is not None:
on_before_run()

if callable(data):
data = data()

# retry with given policy on selected pipeline steps
for attempt in self.retry_policy.copy(
retry=retry_if_exception(
@@ -325,6 +348,7 @@ def add_run(
write_disposition: TWriteDispositionConfig = None,
loader_file_format: TLoaderFileFormat = None,
schema_contract: TSchemaContract = None,
on_before_run: Callable[[], None] = None,
**kwargs: Any,
) -> List[PythonOperator]:
"""Creates a task or a group of tasks to run `data` with `pipeline`
@@ -338,7 +362,10 @@ def add_run(
Args:
pipeline (Pipeline): An instance of pipeline used to run the source
data (Any): Any data supported by `run` method of the pipeline
data (Any):
Any data supported by the `run` method of the pipeline.
If a non-resource callable is given, it is called before
the load to get the data.
decompose (Literal["none", "serialize", "parallel"], optional):
A source decomposition strategy into Airflow tasks:
none - no decomposition, default value.
@@ -365,6 +392,8 @@ def add_run(
Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination.
schema_contract (TSchemaContract, optional): An override for the schema contract settings,
this will replace the schema contract settings for all tables in the schema. Defaults to None.
on_before_run (Callable, optional):
A callable to be executed right before the actual pipeline run.
Returns:
Any: Airflow tasks created in order of creation.
@@ -391,6 +420,7 @@ def make_task(pipeline: Pipeline, data: Any, name: str = None) -> PythonOperator
loader_file_format=loader_file_format,
schema_contract=schema_contract,
pipeline_name=name,
on_before_run=on_before_run,
)
return PythonOperator(
task_id=self._task_name(pipeline, data), python_callable=f, **kwargs
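The same pair of options also flows through `add_run`, which can additionally decompose a source into several Airflow tasks. A hedged sketch reusing the hypothetical `make_data` and `warm_up` callables from the sketch above, keeping the default decomposition:

# Inside the same @dag-decorated function as the sketch above:
tasks.add_run(
    pipeline,
    make_data,             # callable; evaluated inside the task, after warm_up()
    decompose="none",      # default; "serialize" and "parallel" strategies are documented above
    on_before_run=warm_up,
)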
82 changes: 80 additions & 2 deletions tests/helpers/airflow_tests/test_airflow_wrapper.py
@@ -4,13 +4,13 @@
from typing import List
from airflow import DAG
from airflow.decorators import dag
from airflow.operators.python import PythonOperator
from airflow.operators.python import PythonOperator, get_current_context
from airflow.models import TaskInstance
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

import dlt
from dlt.common import pendulum
from dlt.common import logger, pendulum
from dlt.common.utils import uniq_id
from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention

@@ -917,3 +917,81 @@ def dag_parallel():
dag_def = dag_parallel()
assert len(tasks_list) == 1
dag_def.test()


def callable_source():
@dlt.resource
def test_res():
context = get_current_context()
yield [
{"id": 1, "tomorrow": context["tomorrow_ds"]},
{"id": 2, "tomorrow": context["tomorrow_ds"]},
{"id": 3, "tomorrow": context["tomorrow_ds"]},
]

return test_res


def test_run_callable() -> None:
quackdb_path = os.path.join(TEST_STORAGE_ROOT, "callable_dag.duckdb")

@dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
def dag_regular():
tasks = PipelineTasksGroup(
"callable_dag_group", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
)

call_dag = dlt.pipeline(
pipeline_name="callable_dag",
dataset_name="mock_data_" + uniq_id(),
destination="duckdb",
credentials=quackdb_path,
)
tasks.run(call_dag, callable_source)

dag_def: DAG = dag_regular()
dag_def.test()

pipeline_dag = dlt.attach(pipeline_name="callable_dag")

with pipeline_dag.sql_client() as client:
with client.execute_query("SELECT * FROM test_res") as result:
results = result.fetchall()

assert len(results) == 3

for row in results:
assert row[1] == pendulum.tomorrow().format("YYYY-MM-DD")


def on_before_run():
context = get_current_context()
logger.info(f'on_before_run test: {context["tomorrow_ds"]}')


def test_on_before_run() -> None:
quackdb_path = os.path.join(TEST_STORAGE_ROOT, "callable_dag.duckdb")

@dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
def dag_regular():
tasks = PipelineTasksGroup(
"callable_dag_group", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
)

call_dag = dlt.pipeline(
pipeline_name="callable_dag",
dataset_name="mock_data_" + uniq_id(),
destination="duckdb",
credentials=quackdb_path,
)
tasks.run(call_dag, mock_data_source, on_before_run=on_before_run)

dag_def: DAG = dag_regular()

with mock.patch("dlt.helpers.airflow_helper.logger.info") as logger_mock:
dag_def.test()
logger_mock.assert_has_calls(
[
mock.call(f'on_before_run test: {pendulum.tomorrow().format("YYYY-MM-DD")}'),
]
)
