Skip to content

Commit

Permalink
AssetGroup.from_modules (#6884)
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza committed Mar 3, 2022
1 parent 7601eab commit 8467f5c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 13 deletions.
75 changes: 63 additions & 12 deletions python_modules/dagster/dagster/core/asset_defs/asset_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
Expand Down Expand Up @@ -307,8 +308,8 @@ def from_package_module(
executor_def: Optional[ExecutorDefinition] = None,
) -> "AssetGroup":
"""
Constructs an AssetGroup that includes all asset definitions in all sub-modules of the
given package module.
Constructs an AssetGroup that includes all asset definitions and source assets in all
sub-modules of the given package module.
A package module is the result of importing a package.
Expand All @@ -322,15 +323,12 @@ def from_package_module(
Returns:
AssetGroup: An asset group with all the assets in the package.
"""
assets = set()
source_assets = set()
assets: Set[AssetsDefinition] = set()
source_assets: Set[SourceAsset] = set()
for module in _find_modules_in_package(package_module):
for attr in dir(module):
value = getattr(module, attr)
if isinstance(value, AssetsDefinition):
assets.add(value)
if isinstance(value, SourceAsset):
source_assets.add(value)
module_assets, module_source_assets = _find_assets_in_module(module)
assets.update(module_assets)
source_assets.update(module_source_assets)

return AssetGroup(
assets=list(assets),
Expand All @@ -346,8 +344,8 @@ def from_package_name(
executor_def: Optional[ExecutorDefinition] = None,
) -> "AssetGroup":
"""
Constructs an AssetGroup that includes all asset definitions in all sub-modules of the
given package.
Constructs an AssetGroup that includes all asset definitions and source assets in all
sub-modules of the given package.
Args:
package_name (str): The name of a Python package to look for assets inside.
Expand All @@ -364,6 +362,59 @@ def from_package_name(
package_module, resource_defs=resource_defs, executor_def=executor_def
)

@staticmethod
def from_modules(
modules: Sequence[ModuleType],
resource_defs: Optional[Mapping[str, ResourceDefinition]] = None,
executor_def: Optional[ExecutorDefinition] = None,
) -> "AssetGroup":
"""
Constructs an AssetGroup that includes all asset definitions and source assets in the given
module.
Args:
modules (Sequence[ModuleType]): The Python modules to look for assets inside.
resource_defs (Optional[Mapping[str, ResourceDefinition]]): A dictionary of resource
definitions to include on the returned asset group.
executor_def (Optional[ExecutorDefinition]): An executor to include on the returned
asset group.
Returns:
AssetGroup: An asset group with all the assets defined in the given modules.
"""
assets: Set[AssetsDefinition] = set()
source_assets: Set[SourceAsset] = set()
for module in modules:
module_assets, module_source_assets = _find_assets_in_module(module)
assets.update(module_assets)
source_assets.update(module_source_assets)

return AssetGroup(
assets=list(assets),
source_assets=list(source_assets),
resource_defs=resource_defs,
executor_def=executor_def,
)


def _find_assets_in_module(
module: ModuleType,
) -> Tuple[Sequence[AssetsDefinition], Sequence[SourceAsset]]:
"""
Finds assets in the given module and adds them to the given sets of assets and source assets.
"""
assets: List[AssetsDefinition] = []
source_assets: List[SourceAsset] = []

for attr in dir(module):
value = getattr(module, attr)
if isinstance(value, AssetsDefinition):
assets.append(value)
if isinstance(value, SourceAsset):
source_assets.append(value)

return assets, source_assets


def _find_modules_in_package(package_module: ModuleType) -> Iterable[ModuleType]:
yield package_module
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,29 @@ def test_asset_group_from_package_module():
}


def test_asset_group_from_modules():
from . import asset_package
from .asset_package import module_with_assets

collection = AssetGroup.from_modules([asset_package, module_with_assets])
assert {asset.op.name for asset in collection.assets} == {
"little_richard",
"chuck_berry",
"miles_davis",
}
assert len(collection.assets) == 3
assert {source_asset.key for source_asset in collection.source_assets} == {
AssetKey("elvis_presley")
}


def test_default_io_manager():
@asset
def asset_foo():
return "foo"

group = AssetGroup(assets=[asset_foo])
assert group.resource_defs["io_manager"] == fs_asset_io_manager
assert (
group.resource_defs["io_manager"] # pylint: disable=comparison-with-callable
== fs_asset_io_manager
)

0 comments on commit 8467f5c

Please sign in to comment.