Skip to content

Commit

Permalink
Make current working directory a templated field in BashOperator (ap…
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored and howardyoo committed Mar 18, 2024
1 parent 99399e4 commit df7f52a
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 13 deletions.
6 changes: 4 additions & 2 deletions airflow/operators/bash.py
Expand Up @@ -59,8 +59,10 @@ class BashOperator(BaseOperator):
:param skip_on_exit_code: If task exits with this exit code, leave the task
in ``skipped`` state (default: 99). If set to ``None``, any non-zero
exit code will be treated as a failure.
:param cwd: Working directory to execute the command in.
:param cwd: Working directory to execute the command in (templated).
If None (default), the command is run in a temporary directory.
To use current DAG folder as the working directory,
you might set template ``{{ dag_run.dag.folder }}``.
Airflow will evaluate the exit code of the Bash command. In general, a non-zero exit code will result in
task failure and zero will result in task success.
Expand Down Expand Up @@ -130,7 +132,7 @@ class BashOperator(BaseOperator):
"""

template_fields: Sequence[str] = ("bash_command", "env")
template_fields: Sequence[str] = ("bash_command", "env", "cwd")
template_fields_renderers = {"bash_command": "bash", "env": "json"}
template_ext: Sequence[str] = (".sh", ".bash")
ui_color = "#f0ede4"
Expand Down
19 changes: 12 additions & 7 deletions tests/models/test_renderedtifields.py
Expand Up @@ -137,9 +137,11 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da
session.add(rtif)
session.flush()

assert {"bash_command": expected_rendered_field, "env": None} == RTIF.get_templated_fields(
ti=ti, session=session
)
assert {
"bash_command": expected_rendered_field,
"env": None,
"cwd": None,
} == RTIF.get_templated_fields(ti=ti, session=session)
# Test the else part of get_templated_fields
# i.e. for the TIs that are not stored in RTIF table
# Fetching them will return None
Expand Down Expand Up @@ -261,7 +263,7 @@ def test_write(self, dag_maker):
)
.first()
)
assert ("test_write", "test", {"bash_command": "echo test_val", "env": None}) == result
assert ("test_write", "test", {"bash_command": "echo test_val", "env": None, "cwd": None}) == result

# Test that overwrite saves new values to the DB
Variable.delete("test_key")
Expand All @@ -287,7 +289,7 @@ def test_write(self, dag_maker):
assert (
"test_write",
"test",
{"bash_command": "echo test_val_updated", "env": None},
{"bash_command": "echo test_val_updated", "env": None, "cwd": None},
) == result_updated

@mock.patch.dict(os.environ, {"AIRFLOW_VAR_API_KEY": "secret"})
Expand All @@ -301,8 +303,10 @@ def test_redact(self, redact, dag_maker):
)
dr = dag_maker.create_dagrun()
redact.side_effect = [
"val 1",
"val 2",
# Order depends on order in Operator template_fields
"val 1", # bash_command
"val 2", # env
"val 3", # cwd
]

ti = dr.task_instances[0]
Expand All @@ -311,4 +315,5 @@ def test_redact(self, redact, dag_maker):
assert rtif.rendered_fields == {
"bash_command": "val 1",
"env": "val 2",
"cwd": "val 3",
}
6 changes: 5 additions & 1 deletion tests/models/test_taskinstance.py
Expand Up @@ -636,7 +636,11 @@ def test_retry_handling(self, dag_maker):
"""
Test that task retries are handled properly
"""
expected_rendered_ti_fields = {"env": None, "bash_command": "echo test_retry_handling; exit 1"}
expected_rendered_ti_fields = {
"env": None,
"bash_command": "echo test_retry_handling; exit 1",
"cwd": None,
}

with dag_maker(dag_id="test_retry_handling") as dag:
task = BashOperator(
Expand Down
20 changes: 20 additions & 0 deletions tests/operators/test_bash.py
Expand Up @@ -20,6 +20,7 @@
import os
import signal
from datetime import datetime, timedelta
from pathlib import Path
from time import sleep
from unittest import mock

Expand Down Expand Up @@ -244,3 +245,22 @@ def test_bash_operator_kill(self, dag_maker):
os.kill(proc.pid, signal.SIGTERM)
assert False, "BashOperator's subprocess still running after stopping on timeout!"
break

@pytest.mark.db_test
def test_templated_fields(self, create_task_instance_of_operator):
ti = create_task_instance_of_operator(
BashOperator,
# Templated fields
bash_command='echo "{{ dag_run.dag_id }}"',
env={"FOO": "{{ ds }}"},
cwd="{{ dag_run.dag.folder }}",
# Other parameters
dag_id="test_templated_fields_dag",
task_id="test_templated_fields_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
ti.render_templates()
task: BashOperator = ti.task
assert task.bash_command == 'echo "test_templated_fields_dag"'
assert task.env == {"FOO": "2024-02-01"}
assert task.cwd == Path(__file__).absolute().parent.as_posix()
6 changes: 3 additions & 3 deletions tests/serialization/test_dag_serialization.py
Expand Up @@ -169,7 +169,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i
"ui_color": "#f0ede4",
"ui_fgcolor": "#000",
"template_ext": [".sh", ".bash"],
"template_fields": ["bash_command", "env"],
"template_fields": ["bash_command", "env", "cwd"],
"template_fields_renderers": {"bash_command": "bash", "env": "json"},
"bash_command": "echo {{ task.task_id }}",
"_task_type": "BashOperator",
Expand Down Expand Up @@ -2150,7 +2150,7 @@ def test_operator_expand_serde():
},
"task_id": "a",
"operator_extra_links": [],
"template_fields": ["bash_command", "env"],
"template_fields": ["bash_command", "env", "cwd"],
"template_ext": [".sh", ".bash"],
"template_fields_renderers": {"bash_command": "bash", "env": "json"},
"ui_color": "#f0ede4",
Expand All @@ -2168,7 +2168,7 @@ def test_operator_expand_serde():
"downstream_task_ids": [],
"task_id": "a",
"template_ext": [".sh", ".bash"],
"template_fields": ["bash_command", "env"],
"template_fields": ["bash_command", "env", "cwd"],
"template_fields_renderers": {"bash_command": "bash", "env": "json"},
"ui_color": "#f0ede4",
"ui_fgcolor": "#000",
Expand Down

0 comments on commit df7f52a

Please sign in to comment.