Skip to content

Commit

Permalink
enable adding asset groups together (#7634)
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza committed Apr 29, 2022
1 parent 6003fdb commit a953214
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 47 deletions.
118 changes: 73 additions & 45 deletions python_modules/dagster/dagster/core/asset_defs/asset_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Expand Down Expand Up @@ -44,17 +43,7 @@
ASSET_GROUP_BASE_JOB_PREFIX = "__ASSET_GROUP"


class AssetGroup(
NamedTuple(
"_AssetGroup",
[
("assets", Sequence[AssetsDefinition]),
("source_assets", Sequence[SourceAsset]),
("resource_defs", Mapping[str, ResourceDefinition]),
("executor_def", Optional[ExecutorDefinition]),
],
)
):
class AssetGroup:
"""Defines a group of assets, along with environment information in the
form of resources and an executor.
Expand Down Expand Up @@ -109,8 +98,8 @@ def foo_resource():
"""

def __new__(
cls,
def __init__(
self,
assets: Sequence[AssetsDefinition],
source_assets: Optional[Sequence[SourceAsset]] = None,
resource_defs: Optional[Mapping[str, ResourceDefinition]] = None,
Expand All @@ -126,9 +115,6 @@ def __new__(
)
executor_def = check.opt_inst_param(executor_def, "executor_def", ExecutorDefinition)

source_assets_by_key = build_source_assets_by_key(source_assets)
root_manager = build_root_manager(source_assets_by_key)

if "root_manager" in resource_defs:
raise DagsterInvalidDefinitionError(
"Resource dictionary included resource with key 'root_manager', "
Expand All @@ -139,22 +125,32 @@ def __new__(
# In the case of collisions, merge_dicts takes values from the
# dictionary latest in the list, so we place the user provided resource
# defs after the defaults.
resource_defs = merge_dicts(
{"root_manager": root_manager, "io_manager": fs_asset_io_manager},
resource_defs,
)
resource_defs = merge_dicts({"io_manager": fs_asset_io_manager}, resource_defs)

_validate_resource_reqs_for_asset_group(
asset_list=assets, source_assets=source_assets, resource_defs=resource_defs
)

return super(AssetGroup, cls).__new__(
cls,
assets=assets,
source_assets=source_assets,
resource_defs=resource_defs,
executor_def=executor_def,
)
self._assets = assets
self._source_assets = source_assets
self._resource_defs = resource_defs
self._executor_def = executor_def

@property
def assets(self) -> Sequence[AssetsDefinition]:
    """The asset definitions stored on this group by the constructor."""
    return self._assets

@property
def source_assets(self) -> Sequence[SourceAsset]:
    """The source assets stored on this group by the constructor."""
    return self._source_assets

@property
def resource_defs(self) -> Mapping[str, ResourceDefinition]:
    """The resource definitions stored on this group by the constructor.

    The constructor merges the user-supplied definitions over a default
    ``io_manager`` entry before storing them.
    """
    return self._resource_defs

@property
def executor_def(self) -> Optional[ExecutorDefinition]:
    """The executor definition given to the constructor, or None if none was given."""
    return self._executor_def

@staticmethod
def is_base_job_name(name) -> bool:
Expand Down Expand Up @@ -211,17 +207,23 @@ def build_job(

if not isinstance(selection, str):
selection = check.opt_list_param(selection, "selection", of_type=str)
executor_def = check.opt_inst_param(executor_def, "executor_def", ExecutorDefinition)
executor_def = check.opt_inst_param(
executor_def, "executor_def", ExecutorDefinition, self.executor_def
)
description = check.opt_str_param(description, "description")
resource_defs = {
**self.resource_defs,
**{"root_manager": build_root_manager(build_source_assets_by_key(self.source_assets))},
}

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ExperimentalWarning)
mega_job_def = build_assets_job(
name=name,
assets=self.assets,
source_assets=self.source_assets,
resource_defs=self.resource_defs,
executor_def=self.executor_def,
resource_defs=resource_defs,
executor_def=executor_def,
)

if selection:
Expand Down Expand Up @@ -254,8 +256,8 @@ def build_job(
name=name,
assets=included_assets,
source_assets=excluded_assets,
resource_defs=self.resource_defs,
executor_def=self.executor_def,
resource_defs=resource_defs,
executor_def=executor_def,
description=description,
tags=tags,
)
Expand Down Expand Up @@ -529,15 +531,7 @@ def get_base_jobs(self) -> Sequence[JobDefinition]:
if len(assets_by_partitions_def.keys()) == 0 or assets_by_partitions_def.keys() == {
None
}:
return [
build_assets_job(
ASSET_GROUP_BASE_JOB_PREFIX,
assets=self.assets,
source_assets=self.source_assets,
resource_defs=self.resource_defs,
executor_def=self.executor_def,
)
]
return [self.build_job(ASSET_GROUP_BASE_JOB_PREFIX)]
else:
unpartitioned_assets = assets_by_partitions_def.get(None, [])
jobs = []
Expand All @@ -552,7 +546,12 @@ def get_base_jobs(self) -> Sequence[JobDefinition]:
f"{ASSET_GROUP_BASE_JOB_PREFIX}_{i}",
assets=assets_with_partitions + unpartitioned_assets,
source_assets=[*self.source_assets, *self.assets],
resource_defs=self.resource_defs,
resource_defs={
**self.resource_defs,
"root_manager": build_root_manager(
build_source_assets_by_key(self.source_assets)
),
},
executor_def=self.executor_def,
)
)
Expand Down Expand Up @@ -639,10 +638,39 @@ def asset2(asset1):
return AssetGroup(
assets=result_assets,
source_assets=self.source_assets,
resource_defs={k: r for k, r in self.resource_defs.items() if k != "root_manager"},
resource_defs=self.resource_defs,
executor_def=self.executor_def,
)

def __add__(self, other: "AssetGroup") -> "AssetGroup":
    """Combine two asset groups into a new group holding the assets of both.

    Both groups must carry equal resource definitions and equal executor
    definitions; the combined group reuses this group's values for those.

    Args:
        other (AssetGroup): The group whose assets and source assets are placed after
            this group's in the result.

    Returns:
        AssetGroup: A newly constructed group; neither operand is modified.

    Raises:
        DagsterInvalidDefinitionError: If the resource definition dictionaries or the
            executor definitions of the two groups differ.
    """
    check.inst_param(other, "other", AssetGroup)

    if self.resource_defs != other.resource_defs:
        raise DagsterInvalidDefinitionError(
            "Can't add asset groups together with different resource definition dictionaries"
        )

    if self.executor_def != other.executor_def:
        raise DagsterInvalidDefinitionError(
            "Can't add asset groups together with different executor definitions"
        )

    # Unpack instead of using ``+``: the constructor is annotated to accept any
    # Sequence, and ``+`` raises TypeError when the operands are different sequence
    # types (e.g. a list on one side and a tuple on the other).
    return AssetGroup(
        assets=[*self.assets, *other.assets],
        source_assets=[*self.source_assets, *other.source_assets],
        resource_defs=self.resource_defs,
        executor_def=self.executor_def,
    )

def __eq__(self, other: object) -> bool:
    """Compare this group with ``other`` field by field.

    Two groups are equal when ``other`` is an AssetGroup and its assets, source
    assets, resource definitions and executor definition all compare equal to
    this group's, checked in that order.

    NOTE(review): a class that defines ``__eq__`` without ``__hash__`` is unhashable;
    no ``__hash__`` is visible in this view of the class -- confirm that is intended.
    """
    if not isinstance(other, AssetGroup):
        return False

    return (
        self.assets == other.assets
        and self.source_assets == other.source_assets
        and self.resource_defs == other.resource_defs
        and self.executor_def == other.executor_def
    )


def _find_assets_in_module(
module: ModuleType,
Expand Down Expand Up @@ -672,7 +700,7 @@ def _find_modules_in_package(package_module: ModuleType) -> Iterable[ModuleType]
yield submodule
else:
raise ValueError(
f"Tried find modules in package {package_module}, but its __file__ is None"
f"Tried to find modules in package {package_module}, but its __file__ is None"
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
HourlyPartitionsDefinition,
IOManager,
Out,
ResourceDefinition,
fs_asset_io_manager,
graph,
in_process_executor,
Expand Down Expand Up @@ -137,7 +138,7 @@ def asset_foo(context):

with pytest.raises(
DagsterInvalidDefinitionError,
match=r"SourceAsset with key AssetKey\(\['foo'\]\) requires io manager with key 'foo', which was not provided on AssetGroup. Provided keys: \['io_manager', 'root_manager'\]",
match=r"SourceAsset with key AssetKey\(\['foo'\]\) requires io manager with key 'foo', which was not provided on AssetGroup. Provided keys: \['io_manager'\]",
):
AssetGroup([], source_assets=[source_asset_io_req])

Expand Down Expand Up @@ -167,7 +168,7 @@ def asset_foo():
DagsterInvalidDefinitionError,
match=r"Output 'result' with AssetKey 'AssetKey\(\['asset_foo'\]\)' "
r"requires io manager 'blah' but was not provided on asset group. "
r"Provided resources: \['io_manager', 'root_manager'\]",
r"Provided resources: \['io_manager'\]",
):
AssetGroup([asset_foo])

Expand Down Expand Up @@ -643,3 +644,72 @@ def orange(apple):
result = AssetGroup([orange]).prefixed("my_prefix").assets
assert result[0].asset_key == AssetKey(["my_prefix", "orange"])
assert set(result[0].dependency_asset_keys) == {AssetKey("apple")}


def test_add_asset_groups():
    """Adding two groups equals a group built directly from the concatenated inputs."""

    @asset
    def asset1():
        ...

    @asset
    def asset2():
        ...

    first_source = SourceAsset(AssetKey(["source1"]))
    second_source = SourceAsset(AssetKey(["source2"]))

    combined = AssetGroup(assets=[asset1], source_assets=[first_source]) + AssetGroup(
        assets=[asset2], source_assets=[second_source]
    )
    expected = AssetGroup(assets=[asset1, asset2], source_assets=[first_source, second_source])

    assert combined == expected


def test_add_asset_groups_different_resources():
    """Adding groups whose resource definition dictionaries differ must be rejected."""

    @asset
    def asset1():
        ...

    @asset
    def asset2():
        ...

    source1 = SourceAsset(AssetKey(["source1"]))
    source2 = SourceAsset(AssetKey(["source2"]))

    group1 = AssetGroup(
        assets=[asset1],
        source_assets=[source1],
        resource_defs={"apple": ResourceDefinition.none_resource()},
    )
    group2 = AssetGroup(
        assets=[asset2],
        source_assets=[source2],
        resource_defs={"banana": ResourceDefinition.none_resource()},
    )

    # Pin the message so the test only passes when the resource-mismatch check is the
    # one that fired, not some other DagsterInvalidDefinitionError.
    with pytest.raises(
        DagsterInvalidDefinitionError,
        match="different resource definition",
    ):
        group1 + group2  # pylint: disable=pointless-statement


def test_add_asset_groups_different_executors():
    """Adding groups whose executor definitions differ must be rejected."""

    @asset
    def asset1():
        ...

    @asset
    def asset2():
        ...

    source1 = SourceAsset(AssetKey(["source1"]))
    source2 = SourceAsset(AssetKey(["source2"]))

    group1 = AssetGroup(assets=[asset1], source_assets=[source1], executor_def=in_process_executor)
    group2 = AssetGroup(
        assets=[asset2],
        source_assets=[source2],
    )

    # Neither group passes resource_defs, so the resource check passes and the
    # executor check is the one expected to raise; pin its message.
    with pytest.raises(
        DagsterInvalidDefinitionError,
        match="different executor definitions",
    ):
        group1 + group2  # pylint: disable=pointless-statement

0 comments on commit a953214

Please sign in to comment.