feat(airflow): allow data to be a callable #1318

Merged · 5 commits · May 8, 2024
Changes shown from 1 commit.
18 changes: 15 additions & 3 deletions dlt/helpers/airflow_helper.py
@@ -179,7 +179,10 @@ def run(

Args:
pipeline (Pipeline): The pipeline to run
data (Any): The data to run the pipeline with
data (Any):
The data to run the pipeline with. If a non-resource
callable is given, it is called before the load to get
the data.
table_name (str, optional): The name of the table to
which the data should be loaded within the `dataset`.
write_disposition (TWriteDispositionConfig, optional): Same as
@@ -221,7 +224,10 @@ def _run(

Args:
pipeline (Pipeline): The pipeline to run
data (Any): The data to run the pipeline with
data (Any):
The data to run the pipeline with. If a non-resource
callable is given, it is called before the load to get
the data.
table_name (str, optional): The name of the
table to which the data should be loaded
within the `dataset`.
@@ -271,6 +277,9 @@ def log_after_attempt(retry_state: RetryCallState) -> None:
)

try:
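# a data source passed as a callable is evaluated once here, inside the
# running task and before the retry loop, so the Airflow context is available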
if callable(data):
data = data()

# retry with given policy on selected pipeline steps
for attempt in self.retry_policy.copy(
retry=retry_if_exception(
@@ -338,7 +347,10 @@ def add_run(

Args:
pipeline (Pipeline): An instance of pipeline used to run the source
data (Any): Any data supported by `run` method of the pipeline
data (Any):
Any data supported by the `run` method of the pipeline.
If a non-resource callable is given, it is called before
the load to get the data.
decompose (Literal["none", "serialize", "parallel"], optional):
A source decomposition strategy into Airflow tasks:
none - no decomposition, default value.
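The new contract in one sketch: when `data` is a non-resource callable, the helper invokes it inside the running task, right before the load. Below is a minimal, hypothetical usage example — the `daily_rows` source, its fields, and the DAG/group parameters are illustrative and assume the `PipelineTasksGroup` defaults suffice:

import dlt
import pendulum
from airflow.decorators import dag
from airflow.operators.python import get_current_context

from dlt.helpers.airflow_helper import PipelineTasksGroup


def daily_rows():
    # Runs inside the task, so the live Airflow context is available
    # and the resource can use templated dates such as "ds".
    @dlt.resource
    def rows():
        context = get_current_context()
        yield {"id": 1, "day": context["ds"]}

    return rows


@dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def sketch_dag():
    tasks = PipelineTasksGroup("sketch_group")
    pipeline = dlt.pipeline(
        pipeline_name="sketch", destination="duckdb", dataset_name="sketch_data"
    )
    # Pass the callable itself, not its result; the helper defers the call.
    tasks.run(pipeline, daily_rows)


sketch_dag()

Deferring the call to task runtime is what makes `get_current_context()` usable inside the resource; the test added below exercises exactly this pattern.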
47 changes: 46 additions & 1 deletion tests/helpers/airflow_tests/test_airflow_wrapper.py
@@ -4,7 +4,7 @@
from typing import List
from airflow import DAG
from airflow.decorators import dag
from airflow.operators.python import PythonOperator
from airflow.operators.python import PythonOperator, get_current_context
from airflow.models import TaskInstance
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
@@ -917,3 +917,48 @@ def dag_parallel():
dag_def = dag_parallel()
assert len(tasks_list) == 1
dag_def.test()


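# A callable data source: the helper invokes it inside the running task,
# where get_current_context() returns the live Airflow context.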
def callable_source():
@dlt.resource
def test_res():
context = get_current_context()
yield [
{"id": 1, "tomorrow": context["tomorrow_ds"]},
{"id": 2, "tomorrow": context["tomorrow_ds"]},
{"id": 3, "tomorrow": context["tomorrow_ds"]},
]

return test_res


def test_run_callable() -> None:
quackdb_path = os.path.join(TEST_STORAGE_ROOT, "callable_dag.duckdb")

@dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args)
def dag_regular():
tasks = PipelineTasksGroup(
"callable_dag_group", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False
)

call_dag = dlt.pipeline(
pipeline_name="callable_dag",
dataset_name="mock_data_" + uniq_id(),
destination="duckdb",
credentials=quackdb_path,
)
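# pass the callable itself; it is invoked inside the task, not at DAG parse time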
tasks.run(call_dag, callable_source)

dag_def: DAG = dag_regular()
dag_def.test()

pipeline_dag = dlt.attach(pipeline_name="callable_dag")

with pipeline_dag.sql_client() as client:
with client.execute_query("SELECT * FROM test_res") as result:
results = result.fetchall()

assert len(results) == 3

for row in results:
assert row[1] == pendulum.tomorrow().format("YYYY-MM-DD")