Skip to content

Commit

Permalink
bring back source asset metadata (#8195)
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza authored and johannkm committed Jun 9, 2022
1 parent d421a1b commit 58fa8ca
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 37 deletions.
22 changes: 22 additions & 0 deletions python_modules/dagster/dagster/_check/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,28 @@ def opt_nullable_sequence_param(
return opt_sequence_param(obj, param_name, of_type, additional_message)


# ########################
# ##### Iterable
# ########################


def iterable_param(
obj: Iterable[T],
param_name: str,
of_type: Optional[TypeOrTupleOfTypes] = None,
additional_message: Optional[str] = None,
) -> Iterable[T]:
if not isinstance(obj, collections.abc.Iterable):
raise _param_type_mismatch_exception(
obj, (collections.abc.Iterable,), param_name, additional_message
)

if not of_type:
return obj

return _check_iterable_items(obj, of_type, "iterable")


# ########################
# ##### SET
# ########################
Expand Down
9 changes: 9 additions & 0 deletions python_modules/dagster/dagster/core/asset_defs/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def __init__(
self._selected_asset_keys = all_asset_keys
self._can_subset = can_subset

self._metadata_by_asset_key = {
asset_key: node_def.resolve_output_to_origin(output_name, None)[0].metadata
for output_name, asset_key in asset_keys_by_output_name.items()
}

def __call__(self, *args, **kwargs):
return self._node_def(*args, **kwargs)

Expand Down Expand Up @@ -220,6 +225,10 @@ def asset_keys_by_input_name(self) -> Mapping[str, AssetKey]:
def partitions_def(self) -> Optional[PartitionsDefinition]:
return self._partitions_def

@property
def metadata_by_asset_key(self):
return self._metadata_by_asset_key

def get_partition_mapping(self, in_asset_key: AssetKey) -> PartitionMapping:
if self._partitions_def is None:
check.failed("Asset is not partitioned")
Expand Down
22 changes: 17 additions & 5 deletions python_modules/dagster/dagster/core/asset_defs/assets_job.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
import itertools
from typing import AbstractSet, Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
from typing import (
AbstractSet,
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,
)

import dagster._check as check
from dagster.config import Shape
Expand Down Expand Up @@ -32,7 +44,7 @@
@experimental
def build_assets_job(
name: str,
assets: Sequence[AssetsDefinition],
assets: Iterable[AssetsDefinition],
source_assets: Optional[Sequence[Union[SourceAsset, AssetsDefinition]]] = None,
resource_defs: Optional[Mapping[str, ResourceDefinition]] = None,
description: Optional[str] = None,
Expand Down Expand Up @@ -76,7 +88,7 @@ def asset2(asset1):
"""

check.str_param(name, "name")
check.sequence_param(assets, "assets", of_type=AssetsDefinition)
check.iterable_param(assets, "assets", of_type=AssetsDefinition)
check.opt_sequence_param(
source_assets, "source_assets", of_type=(SourceAsset, AssetsDefinition)
)
Expand Down Expand Up @@ -131,7 +143,7 @@ def asset2(asset1):


def build_job_partitions_from_assets(
assets: Sequence[AssetsDefinition],
assets: Iterable[AssetsDefinition],
source_assets: Sequence[Union[SourceAsset, AssetsDefinition]],
) -> Optional[PartitionedConfig]:
assets_with_partitions_defs = [assets_def for assets_def in assets if assets_def.partitions_def]
Expand Down Expand Up @@ -236,7 +248,7 @@ def build_source_assets_by_key(


def build_deps(
assets_defs: Sequence[AssetsDefinition], source_paths: AbstractSet[AssetKey]
assets_defs: Iterable[AssetsDefinition], source_paths: AbstractSet[AssetKey]
) -> Tuple[
Dict[Union[str, NodeInvocation], Dict[str, IDependencyDefinition]],
Mapping[NodeHandle, AssetsDefinition],
Expand Down
66 changes: 46 additions & 20 deletions python_modules/dagster/dagster/core/definitions/asset_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,10 @@ def __init__(
asset_deps: Optional[Mapping[AssetKey, AbstractSet[AssetKey]]] = None,
dependency_node_handles_by_asset_key: Optional[Mapping[AssetKey, Set[NodeHandle]]] = None,
assets_defs: Optional[List["AssetsDefinition"]] = None,
source_asset_defs: Optional[Sequence[Union["SourceAsset", "AssetsDefinition"]]] = None,
source_asset_defs: Optional[Sequence["SourceAsset"]] = None,
io_manager_keys_by_asset_key: Optional[Mapping[AssetKey, str]] = None,
):
from dagster.core.asset_defs import AssetsDefinition, SourceAsset
from dagster.core.asset_defs import SourceAsset

self._asset_keys_by_node_input_handle = check.opt_dict_param(
asset_keys_by_node_input_handle,
Expand All @@ -395,10 +395,17 @@ def __init__(
key_type=AssetKey,
value_type=Set,
)
self._assets_defs = check.opt_list_param(assets_defs, "assets_defs")
self._source_asset_defs = check.opt_list_param(
source_asset_defs, "source_assets", of_type=(SourceAsset, AssetsDefinition)
)
self._source_assets_by_key = {
source_asset.key: source_asset
for source_asset in check.opt_list_param(
source_asset_defs, "source_assets_defs", of_type=SourceAsset
)
}
self._assets_defs_by_key = {
key: assets_def
for assets_def in check.opt_list_param(assets_defs, "assets_defs")
for key in assets_def.asset_keys
}

# keep an index from node handle to all keys expected to be generated in that node
self._asset_keys_by_node_handle: Dict[NodeHandle, Set[AssetKey]] = defaultdict(set)
Expand Down Expand Up @@ -515,6 +522,14 @@ def dependency_node_handles_by_asset_key(self) -> Mapping[AssetKey, Sequence[Nod
def asset_keys(self) -> Iterable[AssetKey]:
return self._dependency_node_handles_by_asset_key.keys()

@property
def source_assets_by_key(self) -> Mapping[AssetKey, "SourceAsset"]:
return self._source_assets_by_key

@property
def assets_defs_by_key(self) -> Mapping[AssetKey, "AssetsDefinition"]:
return self._assets_defs_by_key

def asset_keys_for_node(self, node_handle: NodeHandle) -> AbstractSet[AssetKey]:
return self._asset_keys_by_node_handle[node_handle]

Expand All @@ -524,6 +539,15 @@ def asset_key_for_input(self, node_handle: NodeHandle, input_name: str) -> Optio
def io_manager_key_for_asset(self, asset_key: AssetKey) -> str:
return self._io_manager_keys_by_asset_key.get(asset_key, "io_manager")

def metadata_for_asset(self, asset_key: AssetKey) -> Optional[Dict[str, object]]:
if asset_key in self._source_assets_by_key:
metadata = self._source_assets_by_key[asset_key].metadata
return {key: value.value for key, value in metadata.items()} if metadata else None
elif asset_key in self._assets_defs_by_key:
return self._assets_defs_by_key[asset_key].metadata_by_asset_key[asset_key]
else:
check.failed(f"Couldn't find key {asset_key}")

def asset_info_for_output(
self, node_handle: NodeHandle, output_name: str
) -> Optional[AssetOutputInfo]:
Expand All @@ -532,16 +556,17 @@ def asset_info_for_output(
)

def group_names_by_assets(self) -> Mapping[AssetKey, str]:
group_names: Dict[AssetKey, str] = {}
for assets_def in self._assets_defs:
group_names.update(assets_def.group_names)
return group_names
return {
key: assets_def.group_names[key]
for key, assets_def in self._assets_defs_by_key.items()
if key in assets_def.group_names
}


def build_asset_selection_job(
name: str,
assets: Sequence["AssetsDefinition"],
source_assets: Sequence[Union["AssetsDefinition", "SourceAsset"]],
assets: Iterable["AssetsDefinition"],
source_assets: Iterable["SourceAsset"],
executor_def: ExecutorDefinition,
resource_defs: Mapping[str, ResourceDefinition],
description: str,
Expand All @@ -556,10 +581,8 @@ def build_asset_selection_job(
assets, source_assets, asset_selection
)
else:
included_assets = cast(List["AssetsDefinition"], assets)
# Slice [:] serves as a copy constructor, so that we don't
# accidentally add to the original list
excluded_assets = source_assets[:]
included_assets = cast(Iterable["AssetsDefinition"], assets)
excluded_assets = list(source_assets)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ExperimentalWarning)
Expand All @@ -578,10 +601,10 @@ def build_asset_selection_job(


def _subset_assets_defs(
assets: Sequence["AssetsDefinition"],
source_assets: Sequence[Union["AssetsDefinition", "SourceAsset"]],
assets: Iterable["AssetsDefinition"],
source_assets: Iterable["SourceAsset"],
selected_asset_keys: AbstractSet[AssetKey],
) -> Tuple[Sequence["AssetsDefinition"], Sequence[Union["AssetsDefinition", "SourceAsset"]]]:
) -> Tuple[Iterable["AssetsDefinition"], Sequence[Union["AssetsDefinition", "SourceAsset"]]]:
"""Given a list of asset key selection queries, generate a set of AssetsDefinition objects
representing the included/excluded definitions.
"""
Expand Down Expand Up @@ -614,6 +637,9 @@ def _subset_assets_defs(
"asset keys produced by this asset."
)

all_excluded_assets = [*excluded_assets, *source_assets]
all_excluded_assets: Sequence[Union["AssetsDefinition", "SourceAsset"]] = [
*excluded_assets,
*source_assets,
]

return list(included_assets), all_excluded_assets
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,10 @@ def get_output_mapping(self, output_name: str) -> OutputMapping:
check.failed(f"Could not find output mapping {output_name}")

def resolve_output_to_origin(
self, output_name: str, handle: NodeHandle
self, output_name: str, handle: Optional[NodeHandle]
) -> Tuple[OutputDefinition, NodeHandle]:
check.str_param(output_name, "output_name")
check.inst_param(handle, "handle", NodeHandle)
check.opt_inst_param(handle, "handle", NodeHandle)

mapping = self.get_output_mapping(output_name)
check.invariant(mapping, "Can only resolve outputs for valid output names")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,14 @@ def _get_job_def_for_asset_selection(
)

check.invariant(
self.asset_layer._assets_defs != None, # pylint:disable=protected-access
self.asset_layer.assets_defs_by_key is not None,
"Asset layer must have _asset_defs argument defined",
)

new_job = build_asset_selection_job(
name=self.name,
assets=self.asset_layer._assets_defs, # pylint:disable=protected-access
source_assets=self.asset_layer._source_asset_defs, # pylint:disable=protected-access
assets=self.asset_layer.assets_defs_by_key.values(),
source_assets=self.asset_layer.source_assets_by_key.values(),
executor_def=self.executor_def,
resource_defs=self.resource_defs,
description=self.description,
Expand Down

0 comments on commit 58fa8ca

Please sign in to comment.