Skip to content

Commit

Permalink
Asset key to node handle mapping (#7599)
Browse files Browse the repository at this point in the history
  • Loading branch information
clairelin135 committed Apr 27, 2022
1 parent 6d27f54 commit 3bcb757
Show file tree
Hide file tree
Showing 3 changed files with 511 additions and 2 deletions.
13 changes: 11 additions & 2 deletions python_modules/dagster/dagster/core/asset_defs/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,19 @@ def __init__(
self._partitions_def = partitions_def
self._partition_mappings = partition_mappings or {}

all_input_asset_keys = set(asset_keys_by_input_name.values())
# if not specified assume all output assets depend on all input assets
all_asset_keys = set(asset_keys_by_output_name.values())
self._asset_deps = asset_deps or {
output_key: all_input_asset_keys for output_key in asset_keys_by_output_name.values()
out_asset_key: set(asset_keys_by_input_name.values())
for out_asset_key in all_asset_keys
}
check.invariant(
set(self._asset_deps.keys()) == all_asset_keys,
"The set of asset keys with dependencies specified in the asset_deps argument must "
"equal the set of asset keys produced by this AssetsDefinition. \n"
f"asset_deps keys: {set(self._asset_deps.keys())} \n"
f"expected keys: {all_asset_keys}",
)

def __call__(self, *args, **kwargs):
    """Call the wrapped node definition with the given arguments and return its result."""
    return self._node_def(*args, **kwargs)
Expand Down
193 changes: 193 additions & 0 deletions python_modules/dagster/dagster/core/definitions/asset_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Sequence,
Set,
Tuple,
Union,
)

from dagster import check
Expand Down Expand Up @@ -84,6 +85,184 @@ def _resolve_input_to_destinations(
return all_destinations


def _build_graph_dependencies(
    graph_def: GraphDefinition,
    parent_handle: Union[NodeHandle, None],
    outputs_by_graph_handle: Dict[NodeHandle, Dict[str, NodeOutputHandle]],
    non_asset_inputs_by_node_handle: Dict[NodeHandle, Sequence[NodeOutputHandle]],
    assets_defs_by_node_handle: Mapping[NodeHandle, "AssetsDefinition"],
):
    """
    Visit every node of ``graph_def``, recursing into nodes that are themselves graphs, and
    record the results in the two dictionaries supplied by the caller. Both dictionaries are
    mutated in place; nothing is returned.

    Arguments:
        graph_def: The graph whose nodes are scanned.
        parent_handle: Handle of the node that contains ``graph_def``; None for the outermost
            graph.
        outputs_by_graph_handle: Filled with one entry per graph node. Each entry maps an output
            name of that graph to a NodeOutputHandle naming the inner node and inner output the
            graph output is mapped from.
        non_asset_inputs_by_node_handle: Filled with one entry per node (graph or not). Each
            entry lists, as NodeOutputHandles, the upstream outputs of that node whose producing
            node handle is not a key of ``assets_defs_by_node_handle``.
        assets_defs_by_node_handle: Node handles that correspond to assets; used only to filter
            upstream outputs.
    """
    dependency_structure = graph_def.dependency_structure

    for node_name, node in graph_def.node_dict.items():
        handle = NodeHandle(node_name, parent=parent_handle)
        node_definition = node.definition

        if isinstance(node_definition, GraphDefinition):
            # Descend first so nested nodes are registered with this node as their parent.
            _build_graph_dependencies(
                node_definition,
                handle,
                outputs_by_graph_handle,
                non_asset_inputs_by_node_handle,
                assets_defs_by_node_handle,
            )

            inner_outputs: Dict[str, NodeOutputHandle] = {}
            for output_mapping in node_definition.output_mappings:
                source = output_mapping.maps_from
                inner_outputs[output_mapping.definition.name] = NodeOutputHandle(
                    NodeHandle(source.solid_name, parent=handle),
                    source.output_name,
                )
            outputs_by_graph_handle[handle] = inner_outputs

        # Recorded for every node, graph or not: upstream outputs in the same graph that are
        # not produced by an asset node.
        upstream_non_asset_outputs = []
        for upstream_output in dependency_structure.all_upstream_outputs_from_solid(node_name):
            if (
                NodeHandle(upstream_output.solid.name, parent=parent_handle)
                in assets_defs_by_node_handle
            ):
                continue
            upstream_non_asset_outputs.append(
                NodeOutputHandle(
                    NodeHandle(upstream_output.solid_name, parent=parent_handle),
                    upstream_output.output_def.name,
                )
            )
        non_asset_inputs_by_node_handle[handle] = upstream_non_asset_outputs


def _get_dependency_node_handles(
    non_asset_inputs_by_node_handle: Mapping[NodeHandle, Sequence[NodeOutputHandle]],
    outputs_by_graph_handle: Mapping[NodeHandle, Mapping[str, NodeOutputHandle]],
    dep_node_handles_by_node: Dict[NodeHandle, List[NodeHandle]],
    node_output_handle: NodeOutputHandle,
) -> Sequence[NodeHandle]:
    """
    Collect the non-graph node handles reachable upstream of ``node_output_handle``.

    If the handle's node is a graph (it is a key of ``outputs_by_graph_handle``), the named
    output is first traced to the inner node output it is mapped from. Otherwise the node itself
    is the first element of the result. In both cases the entries recorded for the node in
    ``non_asset_inputs_by_node_handle`` are then followed recursively.

    Arguments:
        non_asset_inputs_by_node_handle: For each node handle, the upstream NodeOutputHandles
            to follow.
        outputs_by_graph_handle: For each graph node handle, a mapping from the graph's output
            names to the inner NodeOutputHandle each output is mapped from.
        dep_node_handles_by_node: Memo of results already computed for non-graph nodes. It is
            read on entry and updated in place before returning.
        node_output_handle: The node handle and output name to start from. The output name is
            only consulted when the node is a graph.

    Returns:
        A list of node handles, possibly containing duplicates. For a non-graph node the node
        itself comes first.
    """
    handle = node_output_handle.node_handle

    if handle in dep_node_handles_by_node:
        return dep_node_handles_by_node[handle]

    def _upstream_of(output_handle: NodeOutputHandle) -> Sequence[NodeHandle]:
        return _get_dependency_node_handles(
            non_asset_inputs_by_node_handle,
            outputs_by_graph_handle,
            dep_node_handles_by_node,
            output_handle,
        )

    is_graph = handle in outputs_by_graph_handle

    collected: List[NodeHandle]
    if is_graph:
        # Follow the requested graph output down to the inner node that produces it.
        inner_output = outputs_by_graph_handle[handle][node_output_handle.output_name]
        collected = list(_upstream_of(inner_output))
    else:
        # A leaf node: it heads its own dependency list.
        collected = [handle]

    for upstream_output in non_asset_inputs_by_node_handle[handle]:
        collected.extend(_upstream_of(upstream_output))

    # Only leaf results are memoized; a graph's result depends on which output was requested.
    if not is_graph:
        dep_node_handles_by_node[handle] = collected

    return collected


def _asset_key_to_dep_node_handles(
    graph_def: GraphDefinition, assets_defs_by_node_handle: Mapping[NodeHandle, "AssetsDefinition"]
) -> Mapping[AssetKey, Set[NodeHandle]]:
    """
    Build, for every asset key produced by the definitions in ``assets_defs_by_node_handle``,
    the set of node handles associated with producing that asset.

    The computation has three stages:
      1. Scan ``graph_def`` to learn each graph node's output mappings and each node's
         non-asset upstream outputs.
      2. For every asset output, start from the node that emits it (resolved through graph
         output mappings when the asset node is a graph) and gather upstream node handles.
      3. For each asset that depends on another asset of the same definition, drop the node
         handles collected for that other asset, always keeping the node that emits the asset.
    """
    # node handle -> upstream NodeOutputHandles that are not produced by asset nodes
    upstream_outputs_by_node: Dict[NodeHandle, Sequence[NodeOutputHandle]] = {}
    # graph node handle -> {graph output name: inner NodeOutputHandle}
    inner_outputs_by_graph: Dict[NodeHandle, Dict[str, NodeOutputHandle]] = {}
    _build_graph_dependencies(
        graph_def=graph_def,
        parent_handle=None,
        outputs_by_graph_handle=inner_outputs_by_graph,
        non_asset_inputs_by_node_handle=upstream_outputs_by_node,
        assets_defs_by_node_handle=assets_defs_by_node_handle,
    )

    # Lists rather than sets at this stage: element 0 is the node that emits the asset.
    handles_by_asset_key: Dict[AssetKey, List[NodeHandle]] = {}

    for asset_node_handle, assets_def in assets_defs_by_node_handle.items():
        # Memo shared by all outputs of this asset node.
        memo: Dict[NodeHandle, List[NodeHandle]] = {}
        is_graph_node = asset_node_handle in inner_outputs_by_graph

        for declared_output_name, asset_key in assets_def.asset_keys_by_output_name.items():
            resolved_output_name = assets_def.node_def.output_def_named(declared_output_name).name

            handles: List[NodeHandle]
            if is_graph_node:
                inner_output = inner_outputs_by_graph[asset_node_handle][resolved_output_name]
                handles = list(
                    _get_dependency_node_handles(
                        upstream_outputs_by_node,
                        inner_outputs_by_graph,
                        memo,
                        inner_output,
                    )
                )
            else:
                handles = [asset_node_handle]
            handles_by_asset_key[asset_key] = handles

    # Handle dependencies between assets produced by the same definition.
    for assets_def in assets_defs_by_node_handle.values():
        produced_keys = assets_def.asset_keys
        for asset_key, upstream_keys in assets_def.asset_deps.items():
            for upstream_key in upstream_keys:
                if upstream_key not in produced_keys:
                    continue
                # The first entry is the node that emits this asset and is never removed.
                emitting_node = handles_by_asset_key[asset_key][0]
                to_remove = [
                    handle
                    for handle in handles_by_asset_key[upstream_key]
                    if handle != emitting_node
                ]
                handles_by_asset_key[asset_key] = [
                    handle
                    for handle in handles_by_asset_key[asset_key]
                    if handle not in to_remove
                ]

    return {asset_key: set(handles) for asset_key, handles in handles_by_asset_key.items()}


def _asset_mappings_for_node(
node_def: NodeDefinition, node_handle: Optional[NodeHandle]
) -> Tuple[
Expand Down Expand Up @@ -162,6 +341,7 @@ def __init__(
Mapping[NodeOutputHandle, AssetOutputInfo]
] = None,
asset_deps: Optional[Mapping[AssetKey, AbstractSet[AssetKey]]] = None,
dependency_node_handles_by_asset_key: Optional[Mapping[AssetKey, Set[NodeHandle]]] = None,
):
self._asset_keys_by_node_input_handle = check.opt_dict_param(
asset_keys_by_node_input_handle,
Expand All @@ -178,6 +358,12 @@ def __init__(
self._asset_deps = check.opt_dict_param(
asset_deps, "asset_deps", key_type=AssetKey, value_type=set
)
self._dependency_node_handles_by_asset_key = check.opt_dict_param(
dependency_node_handles_by_asset_key,
"dependency_node_handles_by_asset_key",
key_type=AssetKey,
value_type=Set,
)

@staticmethod
def from_graph(graph_def: GraphDefinition) -> "AssetLayer":
Expand Down Expand Up @@ -240,6 +426,9 @@ def from_graph_and_assets_node_mapping(
asset_keys_by_node_input_handle=asset_key_by_input,
asset_info_by_node_output_handle=asset_info_by_output,
asset_deps=asset_deps,
dependency_node_handles_by_asset_key=_asset_key_to_dep_node_handles(
graph_def, assets_defs_by_node_handle
),
)

@property
Expand All @@ -253,6 +442,10 @@ def upstream_assets_for_asset(self, asset_key: AssetKey) -> AbstractSet[AssetKey
)
return self._asset_deps[asset_key]

@property
def dependency_node_handles_by_asset_key(self) -> Mapping[AssetKey, Set[NodeHandle]]:
    """
    The mapping from asset key to a set of node handles that was stored on this layer when it
    was constructed.
    """
    # The return annotation matches what __init__ validates (value_type=Set) and what
    # _asset_key_to_dep_node_handles produces: sets of NodeHandle, not sequences of
    # NodeOutputHandle.
    return self._dependency_node_handles_by_asset_key

def asset_key_for_input(self, node_handle: NodeHandle, input_name: str) -> Optional[AssetKey]:
    """
    Look up the asset key recorded for the given node input.

    Returns None when no entry exists for the (node_handle, input_name) pair.
    """
    input_handle = NodeInputHandle(node_handle, input_name)
    return self._asset_keys_by_node_input_handle.get(input_handle)

Expand Down

0 comments on commit 3bcb757

Please sign in to comment.