Skip to content

Commit

Permalink
[4/n] Interop Stack: node_def on AssetsDefinition can be graph (#7573)
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Apr 27, 2022
1 parent cc2060a commit 6d27f54
Show file tree
Hide file tree
Showing 6 changed files with 457 additions and 18 deletions.
30 changes: 22 additions & 8 deletions python_modules/dagster/dagster/core/asset_defs/assets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import AbstractSet, Mapping, Optional
from typing import AbstractSet, Mapping, Optional, cast

from dagster import check
from dagster.core.definitions import NodeDefinition, OpDefinition
Expand All @@ -13,14 +13,24 @@ def __init__(
self,
asset_keys_by_input_name: Mapping[str, AssetKey],
asset_keys_by_output_name: Mapping[str, AssetKey],
op: OpDefinition,
node_def: NodeDefinition,
partitions_def: Optional[PartitionsDefinition] = None,
partition_mappings: Optional[Mapping[AssetKey, PartitionMapping]] = None,
asset_deps: Optional[Mapping[AssetKey, AbstractSet[AssetKey]]] = None,
):
self._op = op
self._asset_keys_by_input_name = asset_keys_by_input_name
self._asset_keys_by_output_name = asset_keys_by_output_name
self._node_def = node_def
self._asset_keys_by_input_name = check.dict_param(
asset_keys_by_input_name,
"asset_keys_by_input_name",
key_type=str,
value_type=AssetKey,
)
self._asset_keys_by_output_name = check.dict_param(
asset_keys_by_output_name,
"asset_keys_by_output_name",
key_type=str,
value_type=AssetKey,
)

self._partitions_def = partitions_def
self._partition_mappings = partition_mappings or {}
Expand All @@ -31,15 +41,19 @@ def __init__(
}

def __call__(self, *args, **kwargs):
return self._op(*args, **kwargs)
return self._node_def(*args, **kwargs)

@property
def op(self) -> OpDefinition:
return self._op
check.invariant(
isinstance(self._node_def, OpDefinition),
"The NodeDefinition for this AssetsDefinition is not of type OpDefinition.",
)
return cast(OpDefinition, self._node_def)

@property
def node_def(self) -> NodeDefinition:
return self._op
return self._node_def

@property
def asset_deps(self) -> Mapping[AssetKey, AbstractSet[AssetKey]]:
Expand Down
3 changes: 1 addition & 2 deletions python_modules/dagster/dagster/core/asset_defs/assets_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def asset2(asset1):

graph = GraphDefinition(
name=name,
node_defs=[asset.op for asset in assets],
node_defs=[asset.node_def for asset in assets],
dependencies=deps,
description=description,
input_mappings=None,
Expand Down Expand Up @@ -242,7 +242,6 @@ def build_deps(
node_outputs_by_asset[asset_key] = (assets_def.node_def, output_name)

deps: Dict[Union[str, NodeInvocation], Dict[str, IDependencyDefinition]] = {}

# if the same graph/op is used in multiple assets_definitions, their invocations much have
# different names. we keep track of definitions that share a name and add a suffix to their
# invocations to solve this issue
Expand Down
5 changes: 2 additions & 3 deletions python_modules/dagster/dagster/core/asset_defs/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __call__(self, fn: Callable) -> AssetsDefinition:
return AssetsDefinition(
asset_keys_by_input_name=asset_keys_by_input_name,
asset_keys_by_output_name={"result": out_asset_key},
op=op,
node_def=op,
partitions_def=self.partitions_def,
partition_mappings={
asset_keys_by_input_name[input_name]: partition_mapping
Expand Down Expand Up @@ -299,7 +299,6 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
"must be associated with an input to the asset or produced by this asset. Valid "
f"keys: {valid_asset_deps}",
)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ExperimentalWarning)
op = _Op(
Expand All @@ -319,7 +318,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
input_name: asset_key for asset_key, (input_name, _) in asset_ins.items()
},
asset_keys_by_output_name=asset_keys_by_output_name,
op=op,
node_def=op,
asset_deps={asset_keys_by_output_name[name]: asset_deps[name] for name in asset_deps},
)

Expand Down
41 changes: 38 additions & 3 deletions python_modules/dagster/dagster/core/definitions/asset_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
AbstractSet,
Callable,
Dict,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
)
Expand Down Expand Up @@ -59,6 +61,29 @@ def __new__(
)


def _resolve_input_to_destinations(
name: str, node_def: NodeDefinition, handle: NodeHandle
) -> Sequence[NodeInputHandle]:
"""
Recursively follow input mappings to find all op inputs for a graph input.
"""
if not isinstance(node_def, GraphDefinition):
# must be in the op definition
return [NodeInputHandle(node_handle=handle, input_name=name)]
all_destinations: List[NodeInputHandle] = []
for mapping in node_def.input_mappings:
if mapping.definition.name != name:
continue
# recurse into graph structure
all_destinations += _resolve_input_to_destinations(
# update name to be the mapped input name
name=mapping.maps_to.input_name,
node_def=node_def.solid_named(mapping.maps_to.solid_name).definition,
handle=NodeHandle(mapping.maps_to.solid_name, parent=handle),
)
return all_destinations


def _asset_mappings_for_node(
node_def: NodeDefinition, node_handle: Optional[NodeHandle]
) -> Tuple[
Expand Down Expand Up @@ -190,11 +215,21 @@ def from_graph_and_assets_node_mapping(
asset_deps: Dict[AssetKey, AbstractSet[AssetKey]] = {}
for node_handle, assets_def in assets_defs_by_node_handle.items():
asset_deps.update(assets_def.asset_deps)

for input_name, asset_key in assets_def.asset_keys_by_input_name.items():
node_input_handle = NodeInputHandle(node_handle, input_name)
asset_key_by_input[node_input_handle] = asset_key
# resolve graph input to list of op inputs that consume it
node_input_handles = _resolve_input_to_destinations(
input_name, assets_def.node_def, node_handle
)
for node_input_handle in node_input_handles:
asset_key_by_input[node_input_handle] = asset_key

for output_name, asset_key in assets_def.asset_keys_by_output_name.items():
node_output_handle = NodeOutputHandle(node_handle, output_name)
# resolve graph output to the op output it comes from
inner_output_def, inner_node_handle = assets_def.node_def.resolve_output_to_origin(
output_name, handle=node_handle
)
node_output_handle = NodeOutputHandle(inner_node_handle, inner_output_def.name)
partition_fn = lambda context: {context.partition_key}
asset_info_by_output[node_output_handle] = AssetOutputInfo(
asset_key,
Expand Down

0 comments on commit 6d27f54

Please sign in to comment.