Skip to content

Commit

Permalink
Add some tests (#819)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor committed Jan 19, 2022
1 parent d9f5106 commit b5bf089
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/flytekit/unit/core/test_imperative.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,29 @@ def t2(a: typing.List[int]) -> int:
assert wb() == [3, 6]


def test_imperative_tuples():
@task
def t1() -> (int, str):
return 3, "three"

@task
def t3(a: int, b: str) -> typing.Tuple[int, str]:
return a + 2, "world" + b

wb = ImperativeWorkflow(name="my.workflow.a")
t1_node = wb.add_entity(t1)
t3_node = wb.add_entity(t3, a=t1_node.outputs["o0"], b=t1_node.outputs["o1"])
wb.add_workflow_output("wf0", t3_node.outputs["o0"], python_type=int)
wb.add_workflow_output("wf1", t3_node.outputs["o1"], python_type=str)
res = wb()
assert res == (5, "worldthree")

with pytest.raises(KeyError):
wb = ImperativeWorkflow(name="my.workflow.b")
t1_node = wb.add_entity(t1)
wb.add_entity(t3, a=t1_node.outputs["bad"], b=t1_node.outputs["o2"])


def test_call_normal():
@task
def t1(a: int) -> (int, str):
Expand Down
63 changes: 63 additions & 0 deletions tests/flytekit/unit/core/test_shim_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import tempfile
from collections import OrderedDict

import mock

from flytekit import ContainerTask, kwtypes
from flytekit.core import context_manager
from flytekit.core.context_manager import Image, ImageConfig
from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask, TaskTemplateResolver
from flytekit.core.utils import write_proto_to_file
from flytekit.tools.translator import get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


class Placeholder(object):
...


def test_resolver_load_task():
# any task is fine, just copied one
square = ContainerTask(
name="square",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs=kwtypes(val=int),
outputs=kwtypes(out=int),
image="alpine",
command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
)

resolver = TaskTemplateResolver()
ts = get_serializable(OrderedDict(), serialization_settings, square)
with tempfile.NamedTemporaryFile() as f:
write_proto_to_file(ts.template.to_flyte_idl(), f.name)
# load_task should create an instance of the path to the object given, doesn't need to be a real executor
shim_task = resolver.load_task([f.name, f"{Placeholder.__module__}.Placeholder"])
assert isinstance(shim_task.executor, Placeholder)
assert shim_task.task_template.id.name == "square"
assert shim_task.task_template.interface.inputs["val"] is not None
assert shim_task.task_template.interface.outputs["out"] is not None


@mock.patch("flytekit.core.python_customized_container_task.PythonCustomizedContainerTask.get_config")
@mock.patch("flytekit.core.python_customized_container_task.PythonCustomizedContainerTask.get_custom")
def test_serialize_to_model(mock_custom, mock_config):
mock_custom.return_value = {"a": "custom"}
mock_config.return_value = {"a": "config"}
ct = PythonCustomizedContainerTask(
name="mytest", task_config=None, container_image="someimage", executor_type=Placeholder
)
tt = ct.serialize_to_model(serialization_settings)
assert tt.container.image == "someimage"
assert len(tt.config) == 1
assert tt.id.name == "mytest"
assert len(tt.custom) == 1

0 comments on commit b5bf089

Please sign in to comment.