Skip to content

Commit

Permalink
Fix kwarg case with invocation (#7095)
Browse files Browse the repository at this point in the history
* Fix kwarg case with invocation

* Fix nits
  • Loading branch information
dpeng817 committed Mar 16, 2022
1 parent 5765c38 commit aa1dc4a
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 1 deletion.
4 changes: 4 additions & 0 deletions python_modules/dagster/dagster/core/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def positional_arg_name_list(params: List[funcsigs.Parameter]) -> List[str]:
return [p.name for p in params if p.kind in accepted_param_types]


def param_is_var_keyword(param: funcsigs.Parameter) -> bool:
return param.kind == funcsigs.Parameter.VAR_KEYWORD


def format_docstring_for_description(fn: Callable) -> Optional[str]:
if fn.__doc__ is not None:
docstring = fn.__doc__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ...decorator_utils import (
get_function_params,
get_valid_name_permutations,
param_is_var_keyword,
positional_arg_name_list,
)
from ..inference import infer_input_props, infer_output_props
Expand All @@ -41,11 +42,19 @@ def has_context_arg(self) -> bool:
return is_context_provided(get_function_params(self.decorated_fn))

@lru_cache(maxsize=1)
def _get_function_params(self) -> List[funcsigs.Parameter]:
return get_function_params(self.decorated_fn)

def positional_inputs(self) -> List[str]:
params = get_function_params(self.decorated_fn)
params = self._get_function_params()
input_args = params[1:] if self.has_context_arg() else params
return positional_arg_name_list(input_args)

def has_var_kwargs(self) -> bool:
params = self._get_function_params()
# var keyword arg has to be the last argument
return len(params) > 0 and param_is_var_keyword(params[-1])


class NoContextDecoratedSolidFunction(DecoratedSolidFunction):
"""Wrapper around a decorated solid function, when the decorator does not permit a context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,12 @@ def _resolve_inputs(
f"Too many input arguments were provided for {node_label} '{context.alias}'. {suggestion}"
)

# If more args were provided than the function has positional args, then fail early.
positional_inputs = cast("DecoratedSolidFunction", solid_def.compute_fn).positional_inputs()
if len(args) > len(positional_inputs):
raise DagsterInvalidInvocationError(
f"{solid_def.node_type_str} '{solid_def.name}' has {len(positional_inputs)} positional inputs, but {len(args)} positional inputs were provided."
)

input_dict = {}

Expand All @@ -147,6 +152,27 @@ def _resolve_inputs(
kwargs[positional_input] if positional_input in kwargs else input_def.default_value
)

unassigned_kwargs = {k: v for k, v in kwargs.items() if k not in input_dict}
# If there are unassigned inputs, then they may be intended for use with a variadic keyword argument.
if unassigned_kwargs and cast("DecoratedSolidFunction", solid_def.compute_fn).has_var_kwargs():
for k, v in unassigned_kwargs.items():
input_dict[k] = v

# Error if any inputs are not represented in input_dict
input_def_names = set(input_defs_by_name.keys())
provided_input_names = set(input_dict.keys())

missing_inputs = input_def_names - provided_input_names
extra_inputs = provided_input_names - input_def_names

if missing_inputs or extra_inputs:
error_msg = ""
if extra_inputs:
error_msg += f"Invocation had extra inputs {list(extra_inputs)}."
if missing_inputs:
error_msg += f"Invocation had missing inputs {list(missing_inputs)}."
raise DagsterInvalidInvocationError(error_msg)

# Type check inputs
op_label = context.describe_op()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AssetMaterialization,
AssetObservation,
DagsterInvalidConfigError,
DagsterInvalidDefinitionError,
DagsterInvariantViolationError,
DagsterType,
DagsterTypeCheckDidNotPass,
Expand Down Expand Up @@ -748,3 +749,32 @@ def my_constant_asset_op():
result = execute_op_in_graph(my_constant_asset_op)
assert result.success
assert len(result.asset_materializations_for_node(my_constant_asset_op.name)) == 1


def test_args_kwargs_op():
with pytest.raises(
DagsterInvalidDefinitionError,
match=r"@op 'the_op' decorated function has positional vararg parameter "
r"'\*args'. @op decorated functions should only have keyword arguments "
r"that match input names and, if system information is required, a "
r"first positional parameter named 'context'.",
):

@op(ins={"the_in": In()})
def the_op(*args):
pass

@op(ins={"the_in": In()})
def the_op(**kwargs):
return kwargs["the_in"]

@op
def emit_op():
return 1

@graph
def the_graph_provides_inputs():
the_op(emit_op())

result = the_graph_provides_inputs.execute_in_process()
assert result.success
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ExpectationResult,
Failure,
Field,
In,
InputDefinition,
Materialization,
Noneable,
Expand Down Expand Up @@ -1019,3 +1020,29 @@ def the_op(context):
match="In op 'the_op', attempted to log output metadata for output 'result' with mapping_key 'one' which has already been yielded. Metadata must be logged before the output is yielded.",
):
list(the_op(build_op_context()))


def test_kwarg_inputs():
@op(ins={"the_in": In(str)})
def the_op(**kwargs) -> str:
return kwargs["the_in"] + "foo"

with pytest.raises(
DagsterInvalidInvocationError,
match="op 'the_op' has 0 positional inputs, but 1 positional inputs were provided.",
):
the_op("bar")

assert the_op(the_in="bar") == "barfoo"

with pytest.raises(
DagsterInvalidInvocationError,
match="Invocation had extra inputs \['bad_val'\].Invocation had missing inputs \['the_in'\].",
):
the_op(bad_val="bar")

@op(ins={"the_in": In(), "kwarg_in": In(), "kwarg_in_two": In()})
def the_op(the_in, **kwargs):
return the_in + kwargs["kwarg_in"] + kwargs["kwarg_in_two"]

assert the_op("foo", kwarg_in="bar", kwarg_in_two="baz") == "foobarbaz"

0 comments on commit aa1dc4a

Please sign in to comment.