asset partitioned io manager (#7413)
sryza committed Apr 14, 2022
1 parent be2469b commit bb27b2d
Showing 9 changed files with 96 additions and 28 deletions.
11 changes: 10 additions & 1 deletion python_modules/dagster/dagster/core/execution/context/output.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Sequence, Union, cast
 
 from dagster import check
 from dagster.core.definitions.events import (
@@ -384,6 +384,15 @@ def get_output_identifier(self) -> List[str]:

         return identifier
 
+    def get_asset_output_identifier(self) -> Sequence[str]:
+        if self.asset_key is not None:
+            if self.has_asset_partitions:
+                return self.asset_key.path + [self.asset_partition_key]
+            else:
+                return self.asset_key.path
+        else:
+            check.failed("Can't get asset output identifier for an output with no asset key")
+
     def log_event(
         self, event: Union[AssetObservation, AssetMaterialization, Materialization]
     ) -> None:
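
The helper's contract is simple: when the output has asset partitions, the partition key becomes the last component of the identifier. A standalone sketch of that behavior (the function name and values here are illustrative, not part of the commit):

def sketch_identifier(asset_key_path, partition_key=None):
    # Mirrors get_asset_output_identifier above: append the partition key,
    # if any, as the final path component.
    return asset_key_path + [partition_key] if partition_key else asset_key_path

assert sketch_identifier(["one", "two", "three", "asset1"]) == ["one", "two", "three", "asset1"]
assert sketch_identifier(["one", "two", "three", "asset1"], "2020-05-03") == [
    "one", "two", "three", "asset1", "2020-05-03",
]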
dagster/core/storage/fs_asset_io_manager.py:
@@ -81,4 +81,4 @@ def asset2(asset1):

 class AssetPickledObjectFilesystemIOManager(PickledObjectFilesystemIOManager):
     def _get_path(self, context):
-        return os.path.join(self.base_dir, *context.asset_key.path)
+        return os.path.join(self.base_dir, *context.get_asset_output_identifier())
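
With this change, each partition of a filesystem-backed asset is pickled to its own file under the asset's key path, instead of every run overwriting a single file. A sketch of the resulting layout (the base_dir value is illustrative):

import os

base_dir = "/tmp/dagster_storage"  # illustrative
unpartitioned = ["one", "two", "three", "asset1"]
partitioned = ["one", "two", "three", "asset1", "2020-05-03"]

print(os.path.join(base_dir, *unpartitioned))  # /tmp/dagster_storage/one/two/three/asset1
print(os.path.join(base_dir, *partitioned))    # /tmp/dagster_storage/one/two/three/asset1/2020-05-03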
Tests for fs_asset_io_manager:
@@ -2,18 +2,23 @@
 import pickle
 import tempfile
 
+from dagster import DailyPartitionsDefinition
 from dagster.core.asset_defs import AssetIn, asset, build_assets_job
 from dagster.core.storage.fs_asset_io_manager import fs_asset_io_manager
 
 
-def get_assets_job(io_manager_def):
+def get_assets_job(io_manager_def, partitions_def=None):
     asset1_namespace = ["one", "two", "three"]
 
-    @asset(namespace=["one", "two", "three"])
+    @asset(namespace=["one", "two", "three"], partitions_def=partitions_def)
     def asset1():
         return [1, 2, 3]
 
-    @asset(namespace=["four", "five"], ins={"asset1": AssetIn(namespace=asset1_namespace)})
+    @asset(
+        namespace=["four", "five"],
+        ins={"asset1": AssetIn(namespace=asset1_namespace)},
+        partitions_def=partitions_def,
+    )
     def asset2(asset1):
         return asset1 + [4]
 
@@ -48,3 +53,33 @@ def test_fs_asset_io_manager():
         assert os.path.isfile(filepath_b)
         with open(filepath_b, "rb") as read_obj:
             assert pickle.load(read_obj) == [1, 2, 3, 4]
+
+
+def test_fs_asset_io_manager_partitioned():
+    with tempfile.TemporaryDirectory() as tmpdir_path:
+        io_manager_def = fs_asset_io_manager.configured({"base_dir": tmpdir_path})
+        job_def = get_assets_job(
+            io_manager_def, partitions_def=DailyPartitionsDefinition(start_date="2020-02-01")
+        )
+
+        result = job_def.execute_in_process(partition_key="2020-05-03")
+        assert result.success
+
+        handled_output_events = list(
+            filter(lambda evt: evt.is_handled_output, result.all_node_events)
+        )
+        assert len(handled_output_events) == 2
+
+        filepath_a = os.path.join(tmpdir_path, "one", "two", "three", "asset1", "2020-05-03")
+        assert os.path.isfile(filepath_a)
+        with open(filepath_a, "rb") as read_obj:
+            assert pickle.load(read_obj) == [1, 2, 3]
+
+        loaded_input_events = list(filter(lambda evt: evt.is_loaded_input, result.all_node_events))
+        assert len(loaded_input_events) == 1
+        assert loaded_input_events[0].event_specific_data.upstream_step_key.endswith("asset1")
+
+        filepath_b = os.path.join(tmpdir_path, "four", "five", "asset2", "2020-05-03")
+        assert os.path.isfile(filepath_b)
+        with open(filepath_b, "rb") as read_obj:
+            assert pickle.load(read_obj) == [1, 2, 3, 4]
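
For context, DailyPartitionsDefinition generates one partition per day from the start date onward, with ISO-date strings as keys, which is why "2020-05-03" is a valid partition_key above. A quick sketch, assuming get_partition_keys() with no arguments enumerates keys up to the current time:

from dagster import DailyPartitionsDefinition

partitions_def = DailyPartitionsDefinition(start_date="2020-02-01")
keys = partitions_def.get_partition_keys()
assert keys[0] == "2020-02-01"  # keys are ISO date strings, one per day
assert "2020-05-03" in keys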
dagster-aws S3 asset IO manager:
@@ -109,7 +109,7 @@ def my_job():

 class PickledObjectS3AssetIOManager(PickledObjectS3IOManager):
     def _get_path(self, context):
-        return "/".join([self.s3_prefix, *context.asset_key.path])
+        return "/".join([self.s3_prefix, *context.get_asset_output_identifier()])
 
 
 @io_manager(
dagster-aws S3 IO manager tests:
@@ -6,6 +6,7 @@
     Int,
     Out,
     Output,
+    StaticPartitionsDefinition,
     VersionStrategy,
     asset,
     build_assets_job,
@@ -119,7 +120,7 @@ def memoized():
     assert len(result.all_node_events) == 0
 
 
-def define_assets_job():
+def define_assets_job(bucket):
     @asset
     def asset1():
         return 1
@@ -128,30 +129,33 @@ def asset1():
     def asset2(asset1):
         return asset1 + 1
 
+    @asset(partitions_def=StaticPartitionsDefinition(["apple", "orange"]))
+    def partitioned():
+        return 8
+
     return build_assets_job(
         name="assets",
-        assets=[asset1, asset2],
+        assets=[asset1, asset2, partitioned],
         resource_defs={
-            "io_manager": s3_pickle_asset_io_manager,
+            "io_manager": s3_pickle_asset_io_manager.configured({"s3_bucket": bucket}),
             "s3": s3_test_resource,
         },
     )
 
 
 def test_s3_pickle_asset_io_manager_execution(mock_s3_bucket):
     assert not len(list(mock_s3_bucket.objects.all()))
-    inty_job = define_assets_job()
-
-    run_config = {"resources": {"io_manager": {"config": {"s3_bucket": mock_s3_bucket.name}}}}
+    inty_job = define_assets_job(mock_s3_bucket.name)
 
-    result = inty_job.execute_in_process(run_config)
+    result = inty_job.execute_in_process(partition_key="apple")
 
     assert result.output_for_node("asset1") == 1
     assert result.output_for_node("asset2") == 2
 
     objects = list(mock_s3_bucket.objects.all())
-    assert len(objects) == 2
-    assert objects[0].bucket_name == "test-bucket"
-    assert objects[0].key == "dagster/asset1"
-    assert objects[1].bucket_name == "test-bucket"
-    assert objects[1].key == "dagster/asset2"
+    assert len(objects) == 3
+    assert {(o.bucket_name, o.key) for o in objects} == {
+        ("test-bucket", "dagster/asset1"),
+        ("test-bucket", "dagster/asset2"),
+        ("test-bucket", "dagster/partitioned/apple"),
+    }
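
Note the structural change in this test: the bucket is now baked into the IO manager with .configured(...) rather than supplied through run_config, which frees execute_in_process to take partition_key instead. A minimal sketch of how configured pre-fills resource config, using a hypothetical resource:

from dagster import resource

@resource(config_schema={"bucket": str})
def bucket_resource(init_context):
    return init_context.resource_config["bucket"]

# configured() returns a copy of the definition with its config already
# satisfied, so callers no longer pass it via run_config at execution time.
pinned = bucket_resource.configured({"bucket": "test-bucket"})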
dagster-azure ADLS2 asset IO manager:
@@ -136,7 +136,7 @@ def my_job():

 class PickledObjectADLS2AssetIOManager(PickledObjectADLS2IOManager):
     def _get_path(self, context):
-        return "/".join([self.prefix, *context.asset_key.path])
+        return "/".join([self.prefix, *context.get_asset_output_identifier()])
 
 
 @io_manager(
dagster-gcp fake GCS client (test helper):
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Union
+from typing import AbstractSet, Dict, Optional, Union
 
 
 class FakeGCSBlob:
@@ -90,3 +90,10 @@ def list_blobs(
                 yield blob
             elif prefix in blob.name:
                 yield blob
+
+    def get_all_blob_paths(self) -> AbstractSet[str]:
+        return {
+            f"{bucket.name}/{blob.name}"
+            for bucket in self.buckets.values()
+            for blob in bucket.blobs.values()
+        }
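
The new helper flattens every bucket/blob pair into a "bucket/blob" string so a test can assert on the entire contents of the fake store with one set comparison. A self-contained sketch of the same logic (the data and names here are hypothetical):

buckets = {"test-bucket": ["assets/upstream", "assets/partitioned/apple"]}

def all_blob_paths(buckets):
    # Flatten each (bucket, blob) pair into a "bucket/blob" string.
    return {f"{name}/{blob}" for name, blobs in buckets.items() for blob in blobs}

assert all_blob_paths(buckets) == {
    "test-bucket/assets/upstream",
    "test-bucket/assets/partitioned/apple",
}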
dagster-gcp GCS asset IO manager:
@@ -109,7 +109,7 @@ def my_job():

 class PickledObjectGCSAssetIOManager(PickledObjectGCSIOManager):
     def _get_path(self, context):
-        return "/".join([self.prefix, *context.asset_key.path])
+        return "/".join([self.prefix, *context.get_asset_output_identifier()])
 
 
 @io_manager(
dagster-gcp GCS IO manager tests:
@@ -16,6 +16,8 @@
     Int,
     Out,
     PipelineRun,
+    ResourceDefinition,
+    StaticPartitionsDefinition,
     asset,
     build_input_context,
     build_output_context,
@@ -166,15 +168,26 @@ def upstream():
     def downstream(upstream):
         return 1 + upstream
 
+    @asset(partitions_def=StaticPartitionsDefinition(["apple", "orange"]))
+    def partitioned():
+        return 8
+
+    fake_gcs_client = FakeGCSClient()
     asset_group = AssetGroup(
-        [upstream, downstream],
-        resource_defs={"io_manager": gcs_pickle_asset_io_manager, "gcs": mock_gcs_resource},
+        [upstream, downstream, partitioned],
+        resource_defs={
+            "io_manager": gcs_pickle_asset_io_manager.configured(
+                {"gcs_bucket": gcs_bucket, "gcs_prefix": "assets"}
+            ),
+            "gcs": ResourceDefinition.hardcoded_resource(fake_gcs_client),
+        },
     )
     asset_job = asset_group.build_job(name="my_asset_job")
 
-    run_config = {
-        "resources": {"io_manager": {"config": {"gcs_bucket": gcs_bucket, "gcs_prefix": "assets"}}}
-    }
-
-    result = asset_job.execute_in_process(run_config=run_config)
+    result = asset_job.execute_in_process(partition_key="apple")
     assert result.success
+    assert fake_gcs_client.get_all_blob_paths() == {
+        f"{gcs_bucket}/assets/upstream",
+        f"{gcs_bucket}/assets/downstream",
+        f"{gcs_bucket}/assets/partitioned/apple",
+    }
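
ResourceDefinition.hardcoded_resource wraps an existing object as a resource, which is what lets this test keep a direct handle on fake_gcs_client and inspect it after the run. A minimal sketch with a hypothetical client object:

from dagster import ResourceDefinition, job, op

class FakeClient:
    def __init__(self):
        self.calls = []

fake = FakeClient()

@op(required_resource_keys={"client"})
def record(context):
    context.resources.client.calls.append("ran")

@job(resource_defs={"client": ResourceDefinition.hardcoded_resource(fake)})
def my_job():
    record()

my_job.execute_in_process()
assert fake.calls == ["ran"]  # the test still holds the same object the job used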
