Skip to content

Commit

Permalink
[asset-resources 5/n] io manager defs on source assets (#8105)
Browse files Browse the repository at this point in the history
* Allow io manager definitions to be specified directly on SourceAssets

* Fix rebasing error
  • Loading branch information
dpeng817 committed Jun 2, 2022
1 parent 6fddda8 commit 6324cc4
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 20 deletions.
35 changes: 25 additions & 10 deletions python_modules/dagster/dagster/core/asset_defs/assets_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ def asset2(asset1):
config=None,
)

# turn any AssetsDefinitions into SourceAssets
resolved_source_assets: List[SourceAsset] = []
for asset in source_assets or []:
if isinstance(asset, AssetsDefinition):
resolved_source_assets += asset.to_source_assets()
elif isinstance(asset, SourceAsset):
resolved_source_assets.append(asset)

asset_layer = AssetLayer.from_graph_and_assets_node_mapping(
graph, assets_defs_by_node_handle, resolved_source_assets
)

all_resource_defs = dict(resource_defs)
for asset_def in assets:
for resource_key, resource_def in asset_def.resource_defs.items():
Expand All @@ -110,22 +122,25 @@ def asset2(asset1):
)
all_resource_defs[resource_key] = resource_def

# turn any AssetsDefinitions into SourceAssets
resolved_source_assets: List[SourceAsset] = []
for asset in source_assets or []:
if isinstance(asset, AssetsDefinition):
resolved_source_assets += asset.to_source_assets()
elif isinstance(asset, SourceAsset):
resolved_source_assets.append(asset)
required_io_manager_keys = set()
for source_asset in resolved_source_assets:
if not source_asset.io_manager_def:
required_io_manager_keys.add(source_asset.get_io_manager_key())
else:
all_resource_defs[source_asset.get_io_manager_key()] = source_asset.io_manager_def

for required_key in sorted(list(required_io_manager_keys)):
if required_key not in all_resource_defs and required_key != "io_manager":
raise DagsterInvalidDefinitionError(
f"Error when attempting to build job '{name}': IO Manager required for key '{required_key}', but none was provided."
)

return graph.to_job(
resource_defs=all_resource_defs,
config=config or partitioned_config,
tags=tags,
executor_def=executor_def,
asset_layer=AssetLayer.from_graph_and_assets_node_mapping(
graph, assets_defs_by_node_handle, resolved_source_assets
),
asset_layer=asset_layer,
_asset_selection_data=_asset_selection_data,
)

Expand Down
25 changes: 20 additions & 5 deletions python_modules/dagster/dagster/core/asset_defs/source_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
normalize_metadata,
)
from dagster.core.definitions.partition import PartitionsDefinition
from dagster.core.storage.io_manager import IOManagerDefinition


class SourceAsset(
Expand All @@ -18,7 +19,8 @@ class SourceAsset(
[
("key", AssetKey),
("metadata_entries", Sequence[Union[MetadataEntry, PartitionMetadataEntry]]),
("io_manager_key", str),
("io_manager_key", Optional[str]),
("io_manager_def", Optional[IOManagerDefinition]),
("description", Optional[str]),
("partitions_def", Optional[PartitionsDefinition]),
],
Expand All @@ -29,7 +31,9 @@ class SourceAsset(
Attributes:
key (Union[AssetKey, Sequence[str], str]): The key of the asset.
metadata_entries (List[MetadataEntry]): Metadata associated with the asset.
io_manager_key (str): The key for the IOManager that will be used to load the contents of
io_manager_key (Optional[str]): The key for the IOManager that will be used to load the contents of
the asset when it's used as an input to other assets inside a job.
io_manager_def (Optional[IOManagerDefinition]): The definition of the IOManager that will be used to load the contents of
the asset when it's used as an input to other assets inside a job.
description (Optional[str]): The description of the asset.
partitions_def (Optional[PartitionsDefinition]): Defines the set of partition keys that
Expand All @@ -40,18 +44,23 @@ def __new__(
cls,
key: CoerceableToAssetKey,
metadata: Optional[MetadataUserInput] = None,
io_manager_key: str = "io_manager",
io_manager_key: Optional[str] = None,
io_manager_def: Optional[IOManagerDefinition] = None,
description: Optional[str] = None,
partitions_def: Optional[PartitionsDefinition] = None,
):

key = AssetKey.from_coerceable(key)
metadata = check.opt_dict_param(metadata, "metadata", key_type=str)
metadata_entries = normalize_metadata(metadata, [], allow_invalid=True)
return super().__new__(
cls,
key=AssetKey.from_coerceable(key),
key=key,
metadata_entries=metadata_entries,
io_manager_key=check.str_param(io_manager_key, "io_manager_key"),
io_manager_key=check.opt_str_param(io_manager_key, "io_manager_key"),
io_manager_def=check.opt_inst_param(
io_manager_def, "io_manager_def", IOManagerDefinition
),
description=check.opt_str_param(description, "description"),
partitions_def=check.opt_inst_param(
partitions_def, "partitions_def", PartitionsDefinition
Expand All @@ -62,3 +71,9 @@ def __new__(
def metadata(self) -> MetadataMapping:
# PartitionMetadataEntry (unstable API) case is unhandled
return {entry.label: entry.entry_data for entry in self.metadata_entries} # type: ignore

def get_io_manager_key(self) -> str:
if not self.io_manager_key and not self.io_manager_def:
return "io_manager"
source_asset_path = "__".join(self.key.path)
return self.io_manager_key or f"{source_asset_path}__io_manager"
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def from_graph_and_assets_node_mapping(
asset_info_by_output: Dict[NodeOutputHandle, AssetOutputInfo] = {}
asset_deps: Dict[AssetKey, AbstractSet[AssetKey]] = {}
io_manager_by_asset: Dict[AssetKey, str] = {
source_asset.key: source_asset.io_manager_key for source_asset in source_assets
source_asset.key: source_asset.get_io_manager_key() for source_asset in source_assets
}
for node_handle, assets_def in assets_defs_by_node_handle.items():
asset_deps.update(assets_def.asset_deps)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1042,9 +1042,9 @@ def my_multi_asset():
yield Output(2, "my_other_out_name")

assert AssetGroup([my_asset, my_multi_asset]).to_source_assets() == [
SourceAsset(AssetKey(["my_asset"])),
SourceAsset(AssetKey(["my_asset_name"])),
SourceAsset(AssetKey(["my_other_asset"])),
SourceAsset(AssetKey(["my_asset"]), io_manager_key="io_manager"),
SourceAsset(AssetKey(["my_asset_name"]), io_manager_key="io_manager"),
SourceAsset(AssetKey(["my_other_asset"]), io_manager_key="io_manager"),
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def asset1(source1):

with pytest.raises(
DagsterInvalidDefinitionError,
match="input manager with key 'special_io_manager' required by input 'source1' of op 'asset1' was not provided",
match="Error when attempting to build job 'a': IO Manager required for key 'special_io_manager', but none was provided.",
):
build_assets_job(
"a",
Expand Down Expand Up @@ -1303,3 +1303,89 @@ def my_asset(asset_1):

result = execute_pipeline(my_job)
assert result.success


def test_source_asset_io_manager_def():
class MyIOManager(IOManager):
def handle_output(self, context, obj):
pass

def load_input(self, context):
return 5

@io_manager
def the_manager():
return MyIOManager()

my_source_asset = SourceAsset(key=AssetKey("my_source_asset"), io_manager_def=the_manager)

@asset
def my_derived_asset(my_source_asset):
return my_source_asset + 4

source_asset_job = AssetGroup(
assets=[my_derived_asset],
source_assets=[my_source_asset],
).build_job("source_asset_job")

result = source_asset_job.execute_in_process(asset_selection=[AssetKey("my_derived_asset")])
assert result.success
assert result.output_for_node("my_derived_asset") == 9


def test_source_asset_io_manager_not_provided():
class MyIOManager(IOManager):
def handle_output(self, context, obj):
pass

def load_input(self, context):
return 5

@io_manager
def the_manager():
return MyIOManager()

my_source_asset = SourceAsset(key=AssetKey("my_source_asset"))

@asset
def my_derived_asset(my_source_asset):
return my_source_asset + 4

source_asset_job = AssetGroup(
assets=[my_derived_asset],
source_assets=[my_source_asset],
resource_defs={"io_manager": the_manager},
).build_job("source_asset_job")

result = source_asset_job.execute_in_process(asset_selection=[AssetKey("my_derived_asset")])
assert result.success
assert result.output_for_node("my_derived_asset") == 9


def test_source_asset_io_manager_key_provided():
class MyIOManager(IOManager):
def handle_output(self, context, obj):
pass

def load_input(self, context):
return 5

@io_manager
def the_manager():
return MyIOManager()

my_source_asset = SourceAsset(key=AssetKey("my_source_asset"), io_manager_key="some_key")

@asset
def my_derived_asset(my_source_asset):
return my_source_asset + 4

source_asset_job = AssetGroup(
assets=[my_derived_asset],
source_assets=[my_source_asset],
resource_defs={"some_key": the_manager},
).build_job("source_asset_job")

result = source_asset_job.execute_in_process(asset_selection=[AssetKey("my_derived_asset")])
assert result.success
assert result.output_for_node("my_derived_asset") == 9

0 comments on commit 6324cc4

Please sign in to comment.