Skip to content

Commit

Permalink
[1/n] Interop Stack: Use AssetJobInfo at runtime instead of OutputDef…
Browse files Browse the repository at this point in the history
…inition (#7473)
  • Loading branch information
OwenKephart committed Apr 25, 2022
1 parent 922d786 commit fdb2274
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 14 deletions.
178 changes: 178 additions & 0 deletions python_modules/dagster/dagster/core/definitions/asset_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from typing import (
TYPE_CHECKING,
AbstractSet,
Callable,
Dict,
Mapping,
NamedTuple,
Optional,
Set,
Tuple,
)

from dagster import check
from dagster.core.definitions.events import AssetKey

from .dependency import NodeHandle, NodeInputHandle, NodeOutputHandle
from .graph_definition import GraphDefinition
from .node_definition import NodeDefinition

if TYPE_CHECKING:
from dagster.core.execution.context.output import OutputContext

from .partition import PartitionsDefinition


class AssetOutputInfo(
NamedTuple(
"_AssetOutputInfo",
[
("key", AssetKey),
("partitions_fn", Callable[["OutputContext"], Optional[AbstractSet[str]]]),
("partitions_def", Optional["PartitionsDefinition"]),
],
)
):
"""Defines all of the asset-related information for a given output.
Args:
key (AssetKey): The AssetKey
partitions_fn (OutputContext -> Optional[Set[str]], optional): A function which takes the
current OutputContext and generates a set of partition names that will be materialized
for this asset.
partitions_def (PartitionsDefinition, optional): Defines the set of valid partitions
for this asset.
"""

def __new__(
cls,
key: AssetKey,
partitions_fn: Optional[Callable[["OutputContext"], Optional[AbstractSet[str]]]] = None,
partitions_def: Optional["PartitionsDefinition"] = None,
):
return super().__new__(
cls,
key=check.inst_param(key, "key", AssetKey),
partitions_fn=check.opt_callable_param(partitions_fn, "partitions_fn", lambda _: None),
partitions_def=partitions_def,
)


def _assets_job_info_for_node(
node_def: NodeDefinition, node_handle: Optional[NodeHandle]
) -> Tuple[
Mapping[NodeInputHandle, AssetKey],
Mapping[NodeOutputHandle, AssetOutputInfo],
Mapping[AssetKey, AbstractSet[AssetKey]],
]:
"""
Recursively iterate through all the sub-nodes of a Node to find any ops with asset info
encoded on their inputs/outputs
"""
check.inst_param(node_def, "node_def", NodeDefinition)
check.opt_inst_param(node_handle, "node_handle", NodeHandle)

asset_key_by_input: Dict[NodeInputHandle, AssetKey] = {}
asset_info_by_output: Dict[NodeOutputHandle, AssetOutputInfo] = {}
asset_deps: Dict[AssetKey, AbstractSet[AssetKey]] = {}
if not isinstance(node_def, GraphDefinition):
# must be in an op (or solid)
if node_handle is None:
check.failed("Must have node_handle for non-graph NodeDefinition")
input_asset_keys: Set[AssetKey] = set()
for input_def in node_def.input_defs:
input_key = input_def.hardcoded_asset_key
if input_key:
input_asset_keys.add(input_key)
input_handle = NodeInputHandle(node_handle, input_def.name)
asset_key_by_input[input_handle] = input_key
for output_def in node_def.output_defs:
output_key = output_def.hardcoded_asset_key
if output_key:
output_handle = NodeOutputHandle(node_handle, output_def.name)
asset_info_by_output[output_handle] = AssetOutputInfo(
key=output_key,
partitions_fn=output_def.get_asset_partitions,
partitions_def=output_def.asset_partitions_def,
)
# assume output depends on all inputs
asset_deps[output_key] = input_asset_keys
else:
# keep recursing through structure
for sub_node_name, sub_node in node_def.node_dict.items():
n_asset_key_by_input, n_asset_info_by_output, n_asset_deps = _assets_job_info_for_node(
node_def=sub_node.definition,
node_handle=NodeHandle(sub_node_name, parent=node_handle),
)
asset_key_by_input.update(n_asset_key_by_input)
asset_info_by_output.update(n_asset_info_by_output)
asset_deps.update(n_asset_deps)
return asset_key_by_input, asset_info_by_output, asset_deps


class AssetLayer:
"""
Stores all of the asset-related information for a Dagster job / pipeline. Maps each
input / output in the underlying graph to the asset it represents (if any), and records the
dependencies between each asset.
Args:
asset_key_by_node_input_handle (Mapping[NodeInputHandle, AssetOutputInfo], optional): A mapping
from a unique input in the underlying graph to the associated AssetKey that it loads from.
asset_info_by_node_output_handle (Mapping[NodeOutputHandle, AssetOutputInfo], optional): A mapping
from a unique output in the underlying graph to the associated AssetOutputInfo.
asset_deps (Mapping[AssetKey, AbstractSet[AssetKey]], optional): Records the upstream asset
keys for each asset key produced by this job.
"""

def __init__(
self,
asset_key_by_node_input_handle: Optional[Mapping[NodeInputHandle, AssetKey]] = None,
asset_info_by_node_output_handle: Optional[
Mapping[NodeOutputHandle, AssetOutputInfo]
] = None,
asset_deps: Optional[Mapping[AssetKey, AbstractSet[AssetKey]]] = None,
):
self._asset_key_by_node_input_handle = check.opt_dict_param(
asset_key_by_node_input_handle,
"asset_key_by_node_input_handle",
key_type=NodeInputHandle,
value_type=AssetKey,
)
self._asset_info_by_node_output_handle = check.opt_dict_param(
asset_info_by_node_output_handle,
"asset_info_by_node_output_handle",
key_type=NodeOutputHandle,
value_type=AssetOutputInfo,
)
self._asset_deps = check.opt_dict_param(
asset_deps, "asset_deps", key_type=AssetKey, value_type=set
)

@staticmethod
def from_graph(graph_def: GraphDefinition) -> "AssetLayer":
"""Scrape asset info off of InputDefinition/OutputDefinition instances"""
check.inst_param(graph_def, "graph_def", GraphDefinition)
asset_by_input, asset_by_output, asset_deps = _assets_job_info_for_node(graph_def, None)
return AssetLayer(
asset_key_by_node_input_handle=asset_by_input,
asset_info_by_node_output_handle=asset_by_output,
asset_deps=asset_deps,
)

def upstream_assets(self, asset_key: AssetKey) -> AbstractSet[AssetKey]:
check.invariant(
asset_key in self._asset_deps,
"AssetKey '{asset_key}' is not produced by this JobDefinition.",
)
return self._asset_deps[asset_key]

def asset_key_for_input(self, node_handle: NodeHandle, input_name: str) -> Optional[AssetKey]:
return self._asset_key_by_node_input_handle.get(NodeInputHandle(node_handle, input_name))

def asset_info_for_output(
self, node_handle: NodeHandle, output_name: str
) -> Optional[AssetOutputInfo]:
return self._asset_info_by_node_output_handle.get(
NodeOutputHandle(node_handle, output_name)
)
16 changes: 16 additions & 0 deletions python_modules/dagster/dagster/core/definitions/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,22 @@ def from_dict(cls, dict_repr: Dict[str, Any]) -> Optional["NodeHandle"]:
return NodeHandle(**{k: dict_repr[k] for k in ["name", "parent"]})


class NodeInputHandle(
NamedTuple("_NodeInputHandle", [("node_handle", NodeHandle), ("input_name", str)])
):
"""
A structured object to uniquely identify inputs in the potentially recursive graph structure.
"""


class NodeOutputHandle(
NamedTuple("_NodeOutputHandle", [("node_handle", NodeHandle), ("output_name", str)])
):
"""
A structured object to uniquely identify outputs in the potentially recursive graph structure.
"""


# previous name for NodeHandle was SolidHandle
register_serdes_tuple_fallbacks({"SolidHandle": NodeHandle})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from dagster.core.storage.fs_asset_io_manager import fs_asset_io_manager
from dagster.core.utils import str_format_set

from .asset_layer import AssetLayer
from .config import ConfigMapping
from .executor_definition import ExecutorDefinition
from .graph_definition import GraphDefinition, SubselectedGraphDefinition
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
op_retry_policy: Optional[RetryPolicy] = None,
version_strategy: Optional[VersionStrategy] = None,
_op_selection_data: Optional[OpSelectionData] = None,
asset_layer: Optional[AssetLayer] = None,
):

# Exists for backcompat - JobDefinition is implemented as a single-mode pipeline.
Expand All @@ -98,6 +100,7 @@ def __init__(
solid_retry_policy=op_retry_policy,
graph_def=graph_def,
version_strategy=version_strategy,
asset_layer=asset_layer,
)

@property
Expand Down Expand Up @@ -190,6 +193,7 @@ def execute_in_process(
tags=self.tags,
op_retry_policy=self._solid_retry_policy,
version_strategy=self.version_strategy,
asset_layer=self.asset_layer,
).get_job_def_for_op_selection(op_selection)

tags = None
Expand Down Expand Up @@ -261,6 +265,9 @@ def get_job_def_for_op_selection(
), # equivalent to solids_to_execute. currently only gets top level nodes.
parent_job_def=self, # used by pipeline snapshot lineage
),
# TODO: subset this structure.
# https://github.com/dagster-io/dagster/issues/7541
asset_layer=self.asset_layer,
)

def get_partition_set_def(self) -> Optional["PartitionSetDefinition"]:
Expand Down Expand Up @@ -313,6 +320,7 @@ def with_hooks(self, hook_defs: AbstractSet[HookDefinition]) -> "JobDefinition":
hook_defs=hook_defs | self.hook_defs,
description=self._description,
op_retry_policy=self._solid_retry_policy,
asset_layer=self.asset_layer,
_op_selection_data=self._op_selection_data,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from dagster.utils import frozentags, merge_dicts
from dagster.utils.backcompat import experimental_class_warning

from .asset_layer import AssetLayer
from .dependency import (
DependencyDefinition,
DependencyStructure,
Expand Down Expand Up @@ -106,6 +107,8 @@ class PipelineDefinition:
solid_retry_policy (Optional[RetryPolicy]): The default retry policy for all solids in
this pipeline. Only used if retry policy is not defined on the solid definition or
solid invocation.
asset_layer (Optional[AssetLayer]): Structured object containing all definition-time asset
information for this pipeline.
_parent_pipeline_def (INTERNAL ONLY): Used for tracking pipelines created using solid subsets.
Expand Down Expand Up @@ -167,6 +170,7 @@ def __init__(
graph_def=None,
_parent_pipeline_def=None, # https://github.com/dagster-io/dagster/issues/2115
version_strategy: Optional[VersionStrategy] = None,
asset_layer: Optional[AssetLayer] = None,
):
# If a graph is specificed directly use it
if check.opt_inst_param(graph_def, "graph_def", GraphDefinition):
Expand Down Expand Up @@ -268,6 +272,10 @@ def __init__(
if self.version_strategy is not None:
experimental_class_warning("VersionStrategy")

self._asset_layer = check.opt_inst_param(
asset_layer, "asset_layer", AssetLayer, default=AssetLayer.from_graph(self.graph)
)

@property
def name(self):
return self._name
Expand Down Expand Up @@ -495,6 +503,10 @@ def solids_to_execute(self) -> Optional[FrozenSet[str]]:
def hook_defs(self) -> AbstractSet[HookDefinition]:
return self._hook_defs

@property
def asset_layer(self) -> AssetLayer:
return self._asset_layer

def get_all_hooks_for_handle(self, handle: NodeHandle) -> FrozenSet[HookDefinition]:
"""Gather all the hooks for the given solid from all places possibly attached with a hook.
Expand Down
26 changes: 16 additions & 10 deletions python_modules/dagster/dagster/core/execution/plan/execute_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Materialization,
Output,
OutputDefinition,
SolidDefinition,
TypeCheck,
)
from dagster.core.definitions.decorators.solid_decorator import DecoratedSolidFunction
Expand Down Expand Up @@ -429,20 +428,27 @@ def _asset_key_and_partitions_for_output(

manager_asset_key = output_manager.get_output_asset_key(output_context)

if output_def.is_asset:
pipeline_def = output_context.step_context.pipeline_def
node_handle = output_context.step_context.solid_handle
output_asset_info = pipeline_def.asset_layer.asset_info_for_output(
node_handle=output_context.step_context.solid_handle, output_name=output_def.name
)
if output_asset_info:
if manager_asset_key is not None:
solid_def = cast(SolidDefinition, output_context.solid_def)
raise DagsterInvariantViolationError(
f'Both the OutputDefinition and the IOManager of output "{output_def.name}" on '
f'solid "{solid_def.name}" associate it with an asset. Either remove '
"the asset_key parameter on the OutputDefinition or use an IOManager that does not "
"specify an AssetKey in its get_output_asset_key() function."
f'The IOManager of output "{output_def.name}" on node "{node_handle}" associates it '
f'with asset key "{manager_asset_key}", but this output has already been defined to '
f'produce asset "{output_asset_info.key}", either via a Software Defined Asset, '
"or by setting the asset_key parameter on the OutputDefinition. In most cases, this "
"means that you should use an IOManager that does not specify an AssetKey in its "
"get_output_asset_key() function for this output."
)
return (
output_def.get_asset_key(output_context),
output_def.get_asset_partitions(output_context) or set(),
output_asset_info.key,
output_asset_info.partitions_fn(output_context) or set(),
)
elif manager_asset_key:

if manager_asset_key:
return manager_asset_key, output_manager.get_output_asset_partitions(output_context)

return None, set()
Expand Down

0 comments on commit fdb2274

Please sign in to comment.