Skip to content

Commit

Permalink
enable @asset-decorated functions to accept kwargs (#7871)
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza committed May 13, 2022
1 parent a4d79e3 commit bfbfa62
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
24 changes: 15 additions & 9 deletions python_modules/dagster/dagster/core/asset_defs/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dagster.core.definitions.utils import NoValueSentinel
from dagster.core.errors import DagsterInvalidDefinitionError
from dagster.core.types.dagster_type import DagsterType
from dagster.seven import funcsigs
from dagster.utils.backcompat import ExperimentalWarning, experimental_decorator

from .asset_in import AssetIn
Expand Down Expand Up @@ -342,18 +343,23 @@ def build_asset_ins(
is_context_provided = len(params) > 0 and params[0].name in get_valid_name_permutations(
"context"
)
input_param_names = [
input_param.name for input_param in (params[1:] if is_context_provided else params)
input_params = params[1:] if is_context_provided else params
non_var_input_param_names = [
param.name
for param in input_params
if param.kind == funcsigs.Parameter.POSITIONAL_OR_KEYWORD
]
has_kwargs = any(param.kind == funcsigs.Parameter.VAR_KEYWORD for param in input_params)

all_input_names = set(input_param_names) | asset_ins.keys()
all_input_names = set(non_var_input_param_names) | asset_ins.keys()

for in_key in asset_ins.keys():
if in_key not in input_param_names:
raise DagsterInvalidDefinitionError(
f"Key '{in_key}' in provided ins dict does not correspond to any of the names "
"of the arguments to the decorated function"
)
if not has_kwargs:
for in_key in asset_ins.keys():
if in_key not in non_var_input_param_names:
raise DagsterInvalidDefinitionError(
f"Key '{in_key}' in provided ins dict does not correspond to any of the names "
"of the arguments to the decorated function"
)

ins_by_asset_key: Dict[AssetKey, Tuple[str, In]] = {}
for input_name in all_input_names:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,14 @@ def my_asset():
...

assert my_asset.op.tags == tags_stringified


def test_kwargs():
@asset(ins={"upstream": AssetIn()})
def my_asset(**kwargs):
del kwargs

assert isinstance(my_asset, AssetsDefinition)
assert len(my_asset.op.output_defs) == 1
assert len(my_asset.op.input_defs) == 1
assert AssetKey("upstream") in my_asset.asset_keys_by_input_name.values()

0 comments on commit bfbfa62

Please sign in to comment.