Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support positional arguments #2522

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,19 +1202,22 @@ def flyte_entity_call_handler(
#. Start a local execution - This means that we're not already in a local workflow execution, which means that
we should expect inputs to be native Python values and that we should return Python native values.
"""
# Sanity checks
# Only keyword args allowed
if len(args) > 0:
raise _user_exceptions.FlyteAssertion(
f"When calling tasks, only keyword args are supported. "
f"Aborting execution as detected {len(args)} positional args {args}"
)
# Make sure arguments are part of interface
for k, v in kwargs.items():
if k not in cast(SupportsNodeCreation, entity).python_interface.inputs:
raise AssertionError(
f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'"
)
if k not in entity.python_interface.inputs:
raise AssertionError(f"Received unexpected keyword argument '{k}' in function '{entity.name}'")

# Check if we have more arguments than expected
if len(args) > len(entity.python_interface.inputs):
raise AssertionError(
f"Received more arguments than expected in function '{entity.name}'. Expected {len(entity.python_interface.inputs)} but got {len(args)}"
)

# Convert args to kwargs
for arg, input_name in zip(args, entity.python_interface.inputs.keys()):
if input_name in kwargs:
raise AssertionError(f"Got multiple values for argument '{input_name}' in function '{entity.name}'")
kwargs[input_name] = arg

ctx = FlyteContextManager.current_context()
if ctx.execution_state and (
Expand All @@ -1234,15 +1237,12 @@ def flyte_entity_call_handler(
child_ctx.execution_state
and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED
):
if (
len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0
or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0
):
output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys())
if len(entity.python_interface.inputs) > 0 or len(entity.python_interface.outputs) > 0:
output_names = list(entity.python_interface.outputs.keys())
if len(output_names) == 0:
return VoidPromise(entity.name)
vals = [Promise(var, None) for var in output_names]
return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface)
return create_task_output(vals, entity.python_interface)
else:
return None
return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)
Expand All @@ -1255,7 +1255,7 @@ def flyte_entity_call_handler(
cast(ExecutionParameters, child_ctx.user_space_params)._decks = []
result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs)

expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs)
expected_outputs = len(entity.python_interface.outputs)
if expected_outputs == 0:
if result is None or isinstance(result, VoidPromise):
return None
Expand All @@ -1268,10 +1268,10 @@ def flyte_entity_call_handler(
if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or (
result is not None and expected_outputs == 1
):
return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface)
return create_native_named_tuple(ctx, result, entity.python_interface)

raise AssertionError(
f"Expected outputs and actual outputs do not match."
f"Result {result}. "
f"Python interface: {cast(SupportsNodeCreation, entity).python_interface}"
f"Python interface: {entity.python_interface}"
)
123 changes: 123 additions & 0 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,3 +943,126 @@ def wf_with_input() -> typing.Optional[typing.List[int]]:
)

assert wf_with_input() == input_val

def test_positional_args_task():
arg1 = 5
arg2 = 6
ret = 17

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def wf_pure_positional_args() -> int:
return t1(arg1, arg2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you also write a test (and run on serverless) where a downstream task takes both positional and named arguments from both 1) workflow inputs and 2) outputs of upstream tasks?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Please see if these tests are what you want. Thanks. Also, I'll run them on the serverless later.

https://github.com/flyteorg/flytekit/pull/2522/files#diff-fa72fab667c8393cca8d5afd3943db0edc55794423b7cfc3ab33d0f8c3146650R1022-R1044


@workflow
def wf_mixed_positional_and_keyword_args() -> int:
return t1(arg1, y=arg2)

wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args)
wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args)

arg1_binding = Scalar(primitive=Primitive(integer=arg1))
arg2_binding = Scalar(primitive=Primitive(integer=arg2))
output_type = LiteralType(simple=SimpleType.INTEGER)

assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type


assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret

def test_positional_args_workflow():
arg1 = 5
arg2 = 6
ret = 17

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def sub_wf(x: int, y: int) -> int:
return t1(x=x, y=y)

@workflow
def wf_pure_positional_args() -> int:
return sub_wf(arg1, arg2)

@workflow
def wf_mixed_positional_and_keyword_args() -> int:
return sub_wf(arg1, y=arg2)

wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args)
wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args)

arg1_binding = Scalar(primitive=Primitive(integer=arg1))
arg2_binding = Scalar(primitive=Primitive(integer=arg2))
output_type = LiteralType(simple=SimpleType.INTEGER)

assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret

def test_positional_args_chained_tasks():
@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def wf() -> int:
x = t1(2, y = 3)
y = t1(3, 4)
return t1(x, y = y)

assert wf() == 30

def test_positional_args_task_inputs_from_workflow_args():
@task
def t1(x: int, y: int, z: int) -> int:
return x + y * 2 + z * 3

@workflow
def wf(x: int, y: int) -> int:
return t1(x, y=y, z=3)

assert wf(1, 2) == 14

def test_unexpected_kwargs_task_raises_error():
@task
def t1(a: int) -> int:
return a

with pytest.raises(AssertionError, match="Received unexpected keyword argument"):
t1(b=6)

def test_too_many_positional_args_task_raises_error():
@task
def t1(a: int) -> int:
return a

with pytest.raises(AssertionError, match="Received more arguments than expected"):
t1(1, 2)

def test_both_positional_and_keyword_args_task_raises_error():
@task
def t1(a: int) -> int:
return a

with pytest.raises(AssertionError, match="Got multiple values for argument"):
t1(1, a=2)
Loading