Skip to content

Commit

Permalink
AssetSelection (#8202)
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Jun 8, 2022
1 parent 4b691da commit f7718fa
Show file tree
Hide file tree
Showing 10 changed files with 359 additions and 46 deletions.
2 changes: 2 additions & 0 deletions python_modules/dagster/dagster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dagster.core.asset_defs import (
AssetGroup,
AssetIn,
AssetSelection,
AssetsDefinition,
SourceAsset,
asset,
Expand Down Expand Up @@ -377,6 +378,7 @@ def __dir__() -> typing.List[str]:
"AssetIn",
"AssetMaterialization",
"AssetObservation",
"AssetSelection",
"AssetSensorDefinition",
"AssetsDefinition",
"DagsterAssetMetadataValue",
Expand Down
1 change: 1 addition & 0 deletions python_modules/dagster/dagster/core/asset_defs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .asset_group import AssetGroup
from .asset_in import AssetIn
from .asset_selection import AssetSelection
from .assets import AssetsDefinition
from .assets_from_modules import (
assets_from_current_module,
Expand Down
142 changes: 142 additions & 0 deletions python_modules/dagster/dagster/core/asset_defs/asset_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import operator
from abc import ABC
from functools import reduce
from typing import AbstractSet, FrozenSet, Optional

from typing_extensions import TypeAlias

import dagster._check as check
from dagster.core.asset_defs.assets import AssetsDefinition
from dagster.core.selector.subset_selector import (
Direction,
fetch_connected_assets_definitions,
generate_asset_dep_graph,
generate_asset_name_to_definition_map,
)

# Alias for "a set of asset definitions"; it exists only to keep the
# signatures in this module readable.
AssetSet: TypeAlias = AbstractSet[AssetsDefinition]


class AssetSelection(ABC):
    """Base class for the nodes of an asset-selection expression tree.

    Leaf selections are created with the static constructors (``all``,
    ``keys``, ``groups``).  Any selection can be widened with ``upstream`` /
    ``downstream`` and combined with ``|`` (union) and ``&`` (intersection).
    A selection holds no assets itself; call ``resolve`` with the candidate
    assets to evaluate it.
    """

    @staticmethod
    def all() -> "AllAssetSelection":
        """Return a selection matching every asset passed to ``resolve``."""
        return AllAssetSelection()

    @staticmethod
    def keys(*key_strs: str) -> "KeysAssetSelection":
        """Return a selection matching assets by key.

        An asset matches when the user-string form of any of its asset keys
        equals one of ``key_strs``.
        """
        return KeysAssetSelection(*key_strs)

    @staticmethod
    def groups(*group_strs: str) -> "GroupsAssetSelection":
        """Return a selection matching assets by group name.

        An asset matches when any of its group names equals one of
        ``group_strs``.
        """
        return GroupsAssetSelection(*group_strs)

    def downstream(self, depth: Optional[int] = None) -> "DownstreamAssetSelection":
        """Return this selection widened with the assets downstream of it.

        Args:
            depth: Forwarded unchanged to the dependency traversal.
                NOTE(review): ``None`` presumably means "no limit" -- confirm
                against ``fetch_connected_assets_definitions``.
        """
        return DownstreamAssetSelection(self, depth=depth)

    def upstream(self, depth: Optional[int] = None) -> "UpstreamAssetSelection":
        """Return this selection widened with the assets upstream of it.

        Args:
            depth: Forwarded unchanged to the dependency traversal.
                NOTE(review): ``None`` presumably means "no limit" -- confirm
                against ``fetch_connected_assets_definitions``.
        """
        return UpstreamAssetSelection(self, depth=depth)

    def __or__(self, other: "AssetSelection") -> "OrAssetSelection":
        """Return the union of this selection and ``other``."""
        # Decline non-selection operands so Python raises TypeError at the
        # operator, instead of building a tree that cannot be resolved.
        if not isinstance(other, AssetSelection):
            return NotImplemented
        return OrAssetSelection(self, other)

    def __and__(self, other: "AssetSelection") -> "AndAssetSelection":
        """Return the intersection of this selection and ``other``."""
        # Same reasoning as in __or__.
        if not isinstance(other, AssetSelection):
            return NotImplemented
        return AndAssetSelection(self, other)

    def resolve(self, all_assets: "AssetSet") -> "AssetSet":
        """Evaluate this selection against ``all_assets``.

        Args:
            all_assets: The candidate asset definitions.

        Returns:
            The asset definitions matched by this selection.
        """
        return Resolver(all_assets).resolve(self)


class AllAssetSelection(AssetSelection):
    """Selection node that matches every asset it is resolved against.

    Carries no state; the resolver recognises it purely by type.
    """


class AndAssetSelection(AssetSelection):
    """Selection node for the intersection of two selections (``a & b``)."""

    def __init__(self, child_1: AssetSelection, child_2: AssetSelection):
        # Exactly two operands, kept in order as a tuple.
        self.children = child_1, child_2


class DownstreamAssetSelection(AssetSelection):
    """Selection node that widens a selection with its downstream assets.

    ``depth`` is stored as given and handed to the dependency traversal
    when the node is resolved.
    """

    def __init__(self, child: AssetSelection, *, depth: Optional[int] = None):
        self.depth = depth
        # A one-element tuple, so every node exposes ``children`` uniformly.
        self.children = (child,)


class GroupsAssetSelection(AssetSelection):
    """Selection node that matches assets by group name.

    Unlike the combinator nodes, ``children`` here holds the group-name
    strings to match rather than nested selections.
    """

    def __init__(self, *children: str):
        # The varargs tuple is stored unchanged.
        self.children = children


class KeysAssetSelection(AssetSelection):
    """Selection node that matches assets by asset-key user string.

    Unlike the combinator nodes, ``children`` here holds the key strings
    to match rather than nested selections.
    """

    def __init__(self, *children: str):
        # The varargs tuple is stored unchanged.
        self.children = children


class OrAssetSelection(AssetSelection):
    """Selection node for the union of two selections (``a | b``)."""

    def __init__(self, child_1: AssetSelection, child_2: AssetSelection):
        # Exactly two operands, kept in order as a tuple.
        self.children = child_1, child_2


class UpstreamAssetSelection(AssetSelection):
    """Selection node that widens a selection with its upstream assets.

    ``depth`` is stored as given and handed to the dependency traversal
    when the node is resolved.
    """

    def __init__(self, child: AssetSelection, *, depth: Optional[int] = None):
        self.depth = depth
        # A one-element tuple, so every node exposes ``children`` uniformly.
        self.children = (child,)


# ########################
# ##### RESOLUTION
# ########################


class Resolver:
    """Evaluates an ``AssetSelection`` tree against a fixed set of assets.

    The dependency graph and the name-to-definition map are built once in
    the constructor and reused for every upstream/downstream traversal.
    """

    def __init__(self, all_assets: AssetSet):
        """Index ``all_assets`` for later resolution.

        Args:
            all_assets: The candidate asset definitions that selections are
                evaluated against.
        """
        self.all_assets = all_assets
        self.asset_dep_graph = generate_asset_dep_graph(list(all_assets))
        self.all_assets_by_name = generate_asset_name_to_definition_map(all_assets)

    def resolve(self, node: AssetSelection) -> AssetSet:
        """Return the assets matched by ``node``.

        Args:
            node: Root of the selection (sub)tree to evaluate.

        Returns:
            The matching asset definitions.  For ``AllAssetSelection`` this
            is the very set passed to the constructor, not a copy.

        Raises:
            Whatever ``check.failed`` raises, if ``node`` is not one of the
            known selection node types.
        """
        if isinstance(node, AllAssetSelection):
            return self.all_assets
        elif isinstance(node, AndAssetSelection):
            left, right = [self.resolve(child) for child in node.children]
            return left & right
        elif isinstance(node, DownstreamAssetSelection):
            return self._resolve_connected(node, "downstream")
        elif isinstance(node, GroupsAssetSelection):
            return {
                a
                for a in self.all_assets
                if any(_match_group(a, pattern) for pattern in node.children)
            }
        elif isinstance(node, KeysAssetSelection):
            return {a for a in self.all_assets if any(_match_key(a, key) for key in node.children)}
        elif isinstance(node, OrAssetSelection):
            left, right = [self.resolve(child) for child in node.children]
            return left | right
        elif isinstance(node, UpstreamAssetSelection):
            return self._resolve_connected(node, "upstream")
        else:
            check.failed(f"Unknown node type: {type(node)}")

    def _resolve_connected(self, node: AssetSelection, direction: Direction) -> AssetSet:
        """Resolve the child of an upstream/downstream node and widen it."""
        roots = self.resolve(node.children[0])
        # The explicit initial value matters: without it, reduce() raises
        # TypeError when the child selection matches nothing.
        return reduce(
            operator.or_,
            (self._gather_connected_assets(asset, direction, node.depth) for asset in roots),
            frozenset(),
        )

    def _gather_connected_assets(
        self, asset: AssetsDefinition, direction: Direction, depth: Optional[int]
    ) -> FrozenSet[AssetsDefinition]:
        """Return ``asset`` together with the assets connected to it.

        Args:
            asset: Starting asset definition.
            direction: ``"upstream"`` or ``"downstream"``.
            depth: Forwarded unchanged to the traversal helper.
                NOTE(review): ``None`` presumably means "no limit" -- confirm
                against ``fetch_connected_assets_definitions``.
        """
        connected = fetch_connected_assets_definitions(
            asset, self.asset_dep_graph, self.all_assets_by_name, direction=direction, depth=depth
        )
        # The starting asset is always part of its own expansion.
        return connected | {asset}


def _match_key(asset: AssetsDefinition, key_str: str) -> bool:
    """Return True if any key of ``asset`` has ``key_str`` as its user string."""
    for asset_key in asset.asset_keys:
        if key_str == asset_key.to_user_string():
            return True
    return False


def _match_group(asset: AssetsDefinition, group_str: str) -> bool:
    """Return True if any group name of ``asset`` equals ``group_str``."""
    for candidate in asset.group_names.values():
        if group_str == candidate:
            return True
    return False
6 changes: 3 additions & 3 deletions python_modules/dagster/dagster/core/definitions/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ def to_string(self, legacy: Optional[bool] = False) -> Optional[str]:

def to_user_string(self) -> str:
"""
E.g. "first_component>second_component"
E.g. "first_component/second_component"
"""
return ">".join(self.path)
return "/".join(self.path)

@staticmethod
def from_user_string(asset_key_string: str) -> "AssetKey":
return AssetKey(asset_key_string.split(">"))
return AssetKey(asset_key_string.split("/"))

@staticmethod
def from_db_string(asset_key_string: Optional[str]) -> Optional["AssetKey"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def __init__(
version_strategy: Optional[VersionStrategy] = None,
asset_layer: Optional[AssetLayer] = None,
):
# If a graph is specificed directly use it
if check.opt_inst_param(graph_def, "graph_def", GraphDefinition):
# If a graph is specified directly use it
if isinstance(graph_def, GraphDefinition):
self._graph_def = graph_def
self._name = name or graph_def.name

Expand Down

0 comments on commit f7718fa

Please sign in to comment.