Skip to content

Commit

Permalink
ctx Context can be used within shell tasks - to access context vars…
Browse files Browse the repository at this point in the history
… and secrets (flyteorg#832)

* Adding context to a substitutable parameter in shell task

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* Support for secrets in context

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* addressed comments

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>
Signed-off-by: Kenny Workman <kennyworkman@sbcglobal.net>
  • Loading branch information
kumare3 committed Jan 29, 2022
1 parent fcfd533 commit 7b0c12e
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 34 deletions.
22 changes: 22 additions & 0 deletions flytekit/core/context_manager.py
Expand Up @@ -337,11 +337,33 @@ class SecretsManager(object):
All configuration values can always be overridden by injecting an environment variable
"""

class _GroupSecrets(object):
    """
    This is a helper class whose sole purpose is to support "attribute" style lookup for secrets.

    ``_GroupSecrets("group", sm).key`` is equivalent to ``sm.get("group", "key")``.
    """

    def __init__(self, group: str, sm: typing.Any):
        """
        :param group: name of the secret group this object is bound to
        :param sm: object used to resolve the secrets, it has to expose ``get(group, key)``
        """
        self._group = group
        self._sm = sm

    def __getattr__(self, item: str) -> str:
        """
        Returns the secret that matches "group"."key", the key here is the item.

        :param item: name of the key to look up within the group
        :raises AttributeError: if item is one of this object's own attributes (``_group`` / ``_sm``) or a dunder
            name. Whatever ``sm.get`` raises for an unknown secret is propagated unchanged.
        """
        # __getattr__ only runs when the regular lookup failed. Reaching this point for _group / _sm means the
        # instance is not initialized yet (copy and pickle create such instances before restoring the state), and
        # reading self._sm below would call __getattr__ again and recurse until a RecursionError.
        # Dunder names are probes for optional protocol methods (e.g. __deepcopy__, __setstate__), never secrets.
        if item in ("_group", "_sm") or (item.startswith("__") and item.endswith("__")):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
        return self._sm.get(self._group, item)

def __init__(self):
    """
    Reads the three secrets settings (default directory, file prefix, env prefix) once and keeps each of them
    as a whitespace-stripped string on the instance.
    """

    def _as_clean_str(setting) -> str:
        # Every setting is fetched through its ``get()`` accessor, coerced to ``str`` and stripped.
        return str(setting.get()).strip()

    self._base_dir = _as_clean_str(secrets.SECRETS_DEFAULT_DIR)
    self._file_prefix = _as_clean_str(secrets.SECRETS_FILE_PREFIX)
    self._env_prefix = _as_clean_str(secrets.SECRETS_ENV_PREFIX)

def __getattr__(self, item: str) -> _GroupSecrets:
    """
    Returns a new _GroupSecrets object, that allows all keys within the group ``item`` to be looked up like
    attributes, e.g. ``secrets_manager.group.key``.

    :param item: name of the secret group
    :raises AttributeError: if item is a dunder name (``__x__``). Such lookups are probes for optional protocol
        methods (``__deepcopy__``, ``__setstate__``, ...) made by ``hasattr``, ``copy`` and ``pickle``. Returning a
        group object for them makes those callers try to call a non-callable object.
    """
    if item.startswith("__") and item.endswith("__"):
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
    return self._GroupSecrets(item, self)

def get(self, group: str, key: str) -> str:
"""
Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError
Expand Down
26 changes: 20 additions & 6 deletions flytekit/extras/tasks/shell.py
@@ -1,4 +1,3 @@
import collections
import datetime
import logging
import os
Expand All @@ -7,6 +6,7 @@
import typing
from dataclasses import dataclass

import flytekit
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.interface import Interface
from flytekit.core.python_function_task import PythonInstanceTask
Expand Down Expand Up @@ -38,7 +38,15 @@ def _dummy_task_func():
return None


T = typing.TypeVar("T")
class AttrDict(dict):
    """
    A ``dict`` whose entries can also be read and written as attributes (``d.key`` next to ``d["key"]``).
    It only exists to namespace inputs and outputs, do not use it as a general purpose container.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The mapping itself becomes the instance namespace, so item access and attribute access share one storage.
        self.__dict__ = self


class _PythonFStringInterpolizer:
Expand Down Expand Up @@ -73,16 +81,22 @@ def interpolate(
"""
inputs = inputs or {}
outputs = outputs or {}
reused_vars = inputs.keys() & outputs.keys()
if reused_vars:
raise ValueError(f"Variables {reused_vars} in Query cannot be shared between inputs and outputs.")
consolidated_args = collections.ChainMap(inputs, outputs)
inputs = AttrDict(inputs)
outputs = AttrDict(outputs)
consolidated_args = {
"inputs": inputs,
"outputs": outputs,
"ctx": flytekit.current_context(),
}
try:
return self._Formatter().format(tmpl, **consolidated_args)
except KeyError as e:
raise ValueError(f"Variable {e} in Query not found in inputs {consolidated_args.keys()}")


T = typing.TypeVar("T")


class ShellTask(PythonInstanceTask[T]):
""" """

Expand Down
87 changes: 86 additions & 1 deletion tests/flytekit/unit/core/test_context_manager.py
@@ -1,4 +1,16 @@
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, look_up_image_info
import os

import py
import pytest

from flytekit.configuration import secrets
from flytekit.core.context_manager import (
ExecutionState,
FlyteContext,
FlyteContextManager,
SecretsManager,
look_up_image_info,
)


class SampleTestClass(object):
Expand Down Expand Up @@ -65,3 +77,76 @@ def test_additional_context():
)
) as exec_ctx_inner:
assert exec_ctx_inner.execution_state.additional_context == {1: "inner", 2: "foo", 3: "baz"}


def test_secrets_manager_default():
    """Both the ``get`` call and the attribute style lookup of group/key are expected to raise ValueError."""
    with pytest.raises(ValueError):
        manager = SecretsManager()
        manager.get("group", "key")

    with pytest.raises(ValueError):
        # Same lookup as above, spelled as manager.group.key
        getattr(manager.group, "key")


def test_secrets_manager_get_envvar():
    """An empty group or key is rejected, otherwise the name is the env prefix followed by GROUP_TEST."""
    manager = SecretsManager()
    for group, key in (("test", ""), ("", "x")):
        with pytest.raises(ValueError):
            manager.get_secrets_env_var(group, key)

    env_var_name = manager.get_secrets_env_var("group", "test")
    assert env_var_name == f"{secrets.SECRETS_ENV_PREFIX.get()}GROUP_TEST"


def test_secrets_manager_get_file():
    """An empty group or key is rejected, otherwise the path is <default dir>/<group>/<file prefix><key>."""
    manager = SecretsManager()
    for group, key in (("test", ""), ("", "x")):
        with pytest.raises(ValueError):
            manager.get_secrets_file(group, key)

    secret_path = manager.get_secrets_file("group", "test")
    expected_path = os.path.join(
        secrets.SECRETS_DEFAULT_DIR.get(),
        "group",
        f"{secrets.SECRETS_FILE_PREFIX.get()}test",
    )
    assert secret_path == expected_path


def test_secrets_manager_file(tmpdir: py.path.local):
    """
    Secrets are resolved from files below the directory configured through FLYTE_SECRETS_DEFAULT_DIR.

    The environment variable is restored in a ``finally`` block, so a failing assertion cannot leak the
    temporary directory setting into tests that run afterwards.
    """
    tmp = tmpdir.mkdir("file_test").dirname
    previous_dir = os.environ.get("FLYTE_SECRETS_DEFAULT_DIR")
    os.environ["FLYTE_SECRETS_DEFAULT_DIR"] = tmp
    try:
        sec = SecretsManager()
        f = os.path.join(tmp, "test")
        with open(f, "w+") as w:
            w.write("my-password")

        with pytest.raises(ValueError):
            sec.get("test", "")
        with pytest.raises(ValueError):
            sec.get("", "x")
        # The group directory does not exist yet
        with pytest.raises(ValueError):
            sec.get("group", "test")

        g = os.path.join(tmp, "group")
        os.makedirs(g)
        f = os.path.join(g, "test")
        with open(f, "w+") as w:
            w.write("my-password")
        assert sec.get("group", "test") == "my-password"
        assert sec.group.test == "my-password"
    finally:
        # Put the environment back exactly as it was found
        if previous_dir is None:
            os.environ.pop("FLYTE_SECRETS_DEFAULT_DIR", None)
        else:
            os.environ["FLYTE_SECRETS_DEFAULT_DIR"] = previous_dir


def test_secrets_manager_bad_env():
    """
    Setting the environment variable TEST does not make the lookup of group/test succeed.

    TEST is set outside of the ``pytest.raises`` body and restored afterwards, so the variable does not leak
    into other tests.
    """
    previous_value = os.environ.get("TEST")
    os.environ["TEST"] = "value"
    try:
        with pytest.raises(ValueError):
            sec = SecretsManager()
            sec.get("group", "test")
    finally:
        if previous_value is None:
            os.environ.pop("TEST", None)
        else:
            os.environ["TEST"] = previous_value


def test_secrets_manager_env():
    """
    A secret is found once the environment variable named by ``get_secrets_env_var`` is set, with positional
    as well as with keyword arguments.

    Both variables are restored afterwards. Left behind, they would make group/test and group/key resolvable for
    every test that runs later in the same process.
    """
    sec = SecretsManager()
    positional_var = sec.get_secrets_env_var("group", "test")
    keyword_var = sec.get_secrets_env_var(group="group", key="key")
    saved = {name: os.environ.get(name) for name in (positional_var, keyword_var)}
    try:
        os.environ[positional_var] = "value"
        assert sec.get("group", "test") == "value"

        os.environ[keyword_var] = "value"
        assert sec.get(group="group", key="key") == "value"
    finally:
        for name, old_value in saved.items():
            if old_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old_value
72 changes: 47 additions & 25 deletions tests/flytekit/unit/extras/tasks/test_shell.py
Expand Up @@ -7,6 +7,7 @@
import pytest
from dataclasses_json import dataclass_json

import flytekit
from flytekit import kwtypes
from flytekit.extras.tasks.shell import OutputLocation, ShellTask
from flytekit.types.directory import FlyteDirectory
Expand Down Expand Up @@ -46,8 +47,8 @@ def test_input_substitution_primitive():
name="test",
script="""
set -ex
cat {f}
echo "Hello World {y} on {j}"
cat {inputs.f}
echo "Hello World {inputs.y} on {inputs.j}"
""",
inputs=kwtypes(f=str, y=int, j=datetime.datetime),
)
Expand All @@ -62,24 +63,46 @@ def test_input_substitution_files():
t = ShellTask(
name="test",
script="""
cat {f}
echo "Hello World {y} on {j}"
cat {inputs.f}
echo "Hello World {inputs.y} on {inputs.j}"
""",
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
)

assert t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) is None


def test_input_substitution_files_ctx():
    """
    The script template can reference ``ctx``: the execution id and a secret looked up as ctx.secrets.group.key.

    The secret is provided through its environment variable, which is removed in a ``finally`` block so that a
    failing assertion or task run cannot leak it into other tests.
    """
    sec = flytekit.current_context().secrets
    envvar = sec.get_secrets_env_var("group", "key")
    os.environ[envvar] = "value"
    try:
        assert sec.get("group", "key") == "value"

        t = ShellTask(
            name="test",
            script="""
export EXEC={ctx.execution_id}
export SECRET={ctx.secrets.group.key}
cat {inputs.f}
echo "Hello World {inputs.y} on {inputs.j}"
""",
            inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
            debug=True,
        )

        assert t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) is None
    finally:
        os.environ.pop(envvar, None)


def test_input_output_substitution_files():
script = "cat {f} > {y}"
script = "cat {inputs.f} > {outputs.y}"
t = ShellTask(
name="test",
debug=True,
script=script,
inputs=kwtypes(f=CSVFile),
output_locs=[
OutputLocation(var="y", var_type=FlyteFile, location="{f}.mod"),
OutputLocation(var="y", var_type=FlyteFile, location="{inputs.f}.mod"),
],
)

Expand All @@ -101,15 +124,15 @@ def test_input_output_substitution_files():

def test_input_single_output_substitution_files():
script = """
cat {f} >> {z}
echo "Hello World {y} on {j}"
cat {inputs.f} >> {outputs.z}
echo "Hello World {inputs.y} on {inputs.j}"
"""
t = ShellTask(
name="test",
debug=True,
script=script,
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc")],
output_locs=[OutputLocation(var="z", var_type=FlyteFile, location="{inputs.f}.pyc")],
)

assert t.script == script
Expand All @@ -122,14 +145,14 @@ def test_input_single_output_substitution_files():
[
(
"""
cat {missing} >> {z}
echo "Hello World {y} on {j} - output {x}"
cat {missing} >> {outputs.z}
echo "Hello World {inputs.y} on {inputs.j} - output {outputs.x}"
"""
),
(
"""
cat {f} {missing} >> {z}
echo "Hello World {y} on {j} - output {x}"
cat {inputs.f} {missing} >> {outputs.z}
echo "Hello World {inputs.y} on {inputs.j} - output {outputs.x}"
"""
),
],
Expand All @@ -141,31 +164,30 @@ def test_input_output_extra_and_missing_variables(script):
script=script,
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[
OutputLocation(var="x", var_type=FlyteDirectory, location="{y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc"),
OutputLocation(var="x", var_type=FlyteDirectory, location="{inputs.y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{inputs.f}.pyc"),
],
)

with pytest.raises(ValueError, match="missing"):
t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))


def test_cannot_reuse_variables_for_both_inputs_and_outputs():
def test_reuse_variables_for_both_inputs_and_outputs():
t = ShellTask(
name="test",
debug=True,
script="""
cat {f} >> {y}
echo "Hello World {y} on {j}"
cat {inputs.f} >> {outputs.y}
echo "Hello World {inputs.y} on {inputs.j}"
""",
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[
OutputLocation(var="y", var_type=FlyteFile, location="{f}.pyc"),
OutputLocation(var="y", var_type=FlyteFile, location="{inputs.f}.pyc"),
],
)

with pytest.raises(ValueError, match="Variables {'y'} in Query"):
t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))


def test_can_use_complex_types_for_inputs_to_f_string_template():
Expand All @@ -177,10 +199,10 @@ class InputArgs:
t = ShellTask(
name="test",
debug=True,
script="""cat {input_args.in_file} >> {input_args.in_file}.tmp""",
script="""cat {inputs.input_args.in_file} >> {inputs.input_args.in_file}.tmp""",
inputs=kwtypes(input_args=InputArgs),
output_locs=[
OutputLocation(var="x", var_type=FlyteFile, location="{input_args.in_file}.tmp"),
OutputLocation(var="x", var_type=FlyteFile, location="{inputs.input_args.in_file}.tmp"),
],
)

Expand All @@ -196,8 +218,8 @@ def test_shell_script():
script_file=script_sh,
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[
OutputLocation(var="x", var_type=FlyteDirectory, location="{y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc"),
OutputLocation(var="x", var_type=FlyteDirectory, location="{inputs.y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{inputs.f}.pyc"),
],
)

Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/extras/tasks/testdata/script.sh
Expand Up @@ -2,5 +2,5 @@

set -ex

cat "{f}" >> "{z}"
echo "Hello World {y} on {j} - output {x}"
cat "{inputs.f}" >> "{outputs.z}"
echo "Hello World {inputs.y} on {inputs.j} - output {outputs.x}"

0 comments on commit 7b0c12e

Please sign in to comment.