Allow asset observations in IO Managers (#6653)
clairelin135 committed Mar 31, 2022
1 parent dd6f57e commit 0831538
Showing 4 changed files with 253 additions and 5 deletions.
7 changes: 7 additions & 0 deletions python_modules/dagster/dagster/core/events/__init__.py
@@ -1067,6 +1067,7 @@ def loaded_input(
upstream_output_name: Optional[str] = None,
upstream_step_key: Optional[str] = None,
message_override: Optional[str] = None,
metadata_entries: Optional[List[MetadataEntry]] = None,
) -> "DagsterEvent":

message = f'Loaded input "{input_name}" using input manager "{manager_key}"'
@@ -1081,6 +1082,7 @@ def loaded_input(
manager_key=manager_key,
upstream_output_name=upstream_output_name,
upstream_step_key=upstream_step_key,
metadata_entries=metadata_entries if metadata_entries else [],
),
message=message_override or message,
)
@@ -1450,6 +1452,7 @@ class LoadedInputData(
("manager_key", str),
("upstream_output_name", Optional[str]),
("upstream_step_key", Optional[str]),
("metadata_entries", Optional[List[MetadataEntry]]),
],
)
):
@@ -1459,13 +1462,17 @@ def __new__(
manager_key: str,
upstream_output_name: Optional[str] = None,
upstream_step_key: Optional[str] = None,
metadata_entries: Optional[List[MetadataEntry]] = None,
):
return super(LoadedInputData, cls).__new__(
cls,
input_name=check.str_param(input_name, "input_name"),
manager_key=check.str_param(manager_key, "manager_key"),
upstream_output_name=check.opt_str_param(upstream_output_name, "upstream_output_name"),
upstream_step_key=check.opt_str_param(upstream_step_key, "upstream_step_key"),
metadata_entries=check.opt_list_param(
metadata_entries, "metadata_entries", of_type=MetadataEntry
),
)


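For orientation, here is a minimal sketch of constructing the extended LoadedInputData payload directly. The field names come from the definition above; the concrete values and the "nrows" metadata entry are illustrative only:

from dagster.core.definitions.metadata import MetadataEntry
from dagster.core.events import LoadedInputData

# Event payload for a LOADED_INPUT event, now carrying metadata entries
# recorded by the IO manager while loading the input.
data = LoadedInputData(
    input_name="upstream_table",  # illustrative input name
    manager_key="io_manager",
    upstream_output_name="result",  # illustrative upstream output
    metadata_entries=[MetadataEntry("nrows", value=123)],
)
assert data.metadata_entries[0].label == "nrows"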
80 changes: 78 additions & 2 deletions python_modules/dagster/dagster/core/execution/context/input.py
@@ -1,7 +1,8 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast

from dagster import check
from dagster.core.definitions.events import AssetKey
from dagster.core.definitions.events import AssetKey, AssetObservation
from dagster.core.definitions.metadata import MetadataEntry, PartitionMetadataEntry
from dagster.core.definitions.op_definition import OpDefinition
from dagster.core.definitions.partition_key_range import PartitionKeyRange
from dagster.core.definitions.solid_definition import SolidDefinition
@@ -13,6 +14,7 @@

if TYPE_CHECKING:
from dagster.core.definitions.resource_definition import Resources
from dagster.core.events import DagsterEvent
from dagster.core.execution.context.system import StepExecutionContext
from dagster.core.log_manager import DagsterLogManager
from dagster.core.types.dagster_type import DagsterType
@@ -86,6 +88,10 @@ def __init__(
self._resources_contain_cm = isinstance(self._resources, IContainsGenerator)
self._cm_scope_entered = False

self._events: List["DagsterEvent"] = []
self._observations: List[AssetObservation] = []
self._metadata_entries: List[Union[MetadataEntry, PartitionMetadataEntry]] = []

def __enter__(self):
if self._resources_cm:
self._cm_scope_entered = True
@@ -287,6 +293,76 @@ def asset_partitions_time_window(self) -> TimeWindow:
partitions_def.time_window_for_partition_key(partition_key_range.end).end,
)

    def consume_events(self) -> Iterator["DagsterEvent"]:
        """Pops and yields all user-generated events that have been recorded from this context.

        If consume_events has not yet been called, this will yield all logged events since the
        call to `load_input`. If consume_events has been called, it will yield all events since
        the last time consume_events was called. Designed for internal use. Users should never
        need to invoke this method.
        """
        events = self._events
        self._events = []
        yield from events

    def add_input_metadata(
        self,
        metadata: Dict[str, Any],
        description: Optional[str] = None,
    ) -> None:
        """Accepts a dictionary of metadata. Metadata entries will appear on the LOADED_INPUT event.

        If the input is an asset, the metadata will also be attached to an asset observation,
        which is yielded from the run and appears in the event log. Only valid if the context
        has an asset key.
        """
        from dagster.core.definitions.metadata import normalize_metadata
        from dagster.core.events import DagsterEvent

        metadata = check.dict_param(metadata, "metadata", key_type=str)
        self._metadata_entries.extend(normalize_metadata(metadata, []))
        if self.asset_key:
            check.opt_str_param(description, "description")

            observation = AssetObservation(
                asset_key=self.asset_key,
                description=description,
                partition=self.asset_partition_key if self.has_asset_partitions else None,
                metadata=metadata,
            )
            self._observations.append(observation)
            if self._step_context:
                self._events.append(
                    DagsterEvent.asset_observation(self._step_context, observation)
                )
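As a usage sketch (the IO manager class and metadata values below are illustrative, not part of this change), a user-defined IO manager can record input metadata like this; when the loaded input corresponds to an asset, the same call also emits an AssetObservation:

from dagster import IOManager, io_manager

class WarehouseIOManager(IOManager):  # illustrative name
    def handle_output(self, context, obj):
        pass

    def load_input(self, context):
        # The metadata appears on the LOADED_INPUT event; if the input is an
        # asset, an AssetObservation carrying the same metadata is also emitted.
        context.add_input_metadata(
            metadata={"nrows": 123},
            description="loaded from the warehouse",
        )
        return []

@io_manager
def warehouse_io_manager(_):
    return WarehouseIOManager()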

    def get_observations(
        self,
    ) -> List[AssetObservation]:
        """Retrieve the list of user-generated asset observations that were observed via the context.

        User-generated events that were yielded will not appear in this list.

        **Examples:**

        .. code-block:: python

            from dagster import IOManager, build_input_context, AssetObservation

            class MyIOManager(IOManager):
                def load_input(self, context):
                    ...

            def test_load_input():
                mgr = MyIOManager()
                context = build_input_context()
                mgr.load_input(context)
                observations = context.get_observations()
                ...
        """
        return self._observations

    def consume_metadata_entries(self) -> List[Union[MetadataEntry, PartitionMetadataEntry]]:
        result = self._metadata_entries
        self._metadata_entries = []
        return result


def build_input_context(
name: Optional[str] = None,
23 changes: 20 additions & 3 deletions python_modules/dagster/dagster/core/execution/plan/inputs.py
@@ -5,6 +5,7 @@
from dagster import check
from dagster.core.definitions import InputDefinition, NodeHandle, PipelineDefinition
from dagster.core.definitions.events import AssetLineageInfo
from dagster.core.definitions.metadata import MetadataEntry
from dagster.core.errors import (
DagsterExecutionLoadInputError,
DagsterInvariantViolationError,
@@ -160,11 +161,17 @@ def load_input_object(self, step_context: "StepExecutionContext") -> Iterator["DagsterEvent"]:
].config,
resources=build_resources_for_manager(input_def.root_manager_key, step_context),
)
yield _load_input_with_input_manager(loader, load_input_context)
yield from _load_input_with_input_manager(loader, load_input_context)

metadata_entries = load_input_context.consume_metadata_entries()

yield DagsterEvent.loaded_input(
step_context,
input_name=input_def.name,
manager_key=input_def.root_manager_key,
metadata_entries=[
entry for entry in metadata_entries if isinstance(entry, MetadataEntry)
],
)

def compute_version(self, step_versions, pipeline_def, resolved_run_config) -> Optional[str]:
@@ -278,13 +285,20 @@ def load_input_object(self, step_context: "StepExecutionContext") -> Iterator["DagsterEvent"]:
f"Please ensure that the resource returned for resource key "
f'"{manager_key}" is an IOManager.',
)
yield _load_input_with_input_manager(input_manager, self.get_load_context(step_context))
load_input_context = self.get_load_context(step_context)
yield from _load_input_with_input_manager(input_manager, load_input_context)

metadata_entries = load_input_context.consume_metadata_entries()

yield DagsterEvent.loaded_input(
step_context,
input_name=self.input_name,
manager_key=manager_key,
upstream_output_name=source_handle.output_name,
upstream_step_key=source_handle.step_key,
metadata_entries=[
entry for entry in metadata_entries if isinstance(entry, MetadataEntry)
],
)

def compute_version(
@@ -586,7 +600,10 @@ def _load_input_with_input_manager(input_manager: "InputManager", context: "InputContext"):
):
value = input_manager.load_input(context)
# close user code boundary before returning value
return value
for event in context.consume_events():
yield event

yield value
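With this change, _load_input_with_input_manager is a generator: it re-yields any events the IO manager recorded on the context and then yields the loaded value last, so callers switch from `yield` to `yield from` and the surrounding execution code separates events from the value. A self-contained sketch of that pattern using stand-in types (not Dagster internals):

from typing import Any, Iterator, List, Tuple

class Event:
    """Stand-in for DagsterEvent in this sketch."""

class FakeContext:
    """Stand-in for InputContext: records events and hands them off once."""

    def __init__(self) -> None:
        self._events: List[Event] = []

    def record(self, event: Event) -> None:
        self._events.append(event)

    def consume_events(self) -> Iterator[Event]:
        events, self._events = self._events, []
        yield from events

def load_with_events(load_fn, context: FakeContext) -> Iterator[Any]:
    value = load_fn(context)
    # Re-yield user-recorded events first, then the loaded value last.
    yield from context.consume_events()
    yield value

def drain(load_fn, context: FakeContext) -> Tuple[List[Event], Any]:
    # The consumer separates events from the final value.
    events: List[Event] = []
    value: Any = None
    for item in load_with_events(load_fn, context):
        if isinstance(item, Event):
            events.append(item)
        else:
            value = item
    return events, value

def example_loader(ctx: FakeContext) -> int:
    ctx.record(Event())
    return 42

events, value = drain(example_loader, FakeContext())
assert len(events) == 1 and value == 42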


@whitelist_for_serdes
@@ -2,13 +2,23 @@

from dagster import (
AssetKey,
AssetObservation,
In,
InputDefinition,
ModeDefinition,
Out,
Output,
OutputDefinition,
StaticPartitionsDefinition,
asset,
build_assets_job,
build_input_context,
execute_pipeline,
io_manager,
job,
op,
pipeline,
root_input_manager,
solid,
)
from dagster.check import CheckError
@@ -132,6 +142,144 @@ def my_pipeline():
)


def test_io_manager_add_input_metadata():
class MyIOManager(IOManager):
def handle_output(self, context, obj):
pass

def load_input(self, context):
context.add_input_metadata(metadata={"foo": "bar"})
context.add_input_metadata(metadata={"baz": "qux"})

observations = context.get_observations()
assert observations[0].asset_key == context.asset_key
assert observations[0].metadata_entries[0].label == "foo"
assert observations[1].metadata_entries[0].label == "baz"
return 1

@io_manager
def my_io_manager(_):
return MyIOManager()

in_asset_key = AssetKey(["a", "b"])
out_asset_key = AssetKey(["c", "d"])

@op(out=Out(asset_key=out_asset_key))
def before():
pass

@op(ins={"a": In(asset_key=in_asset_key)}, out={})
def after(a):
pass

@job(resource_defs={"io_manager": my_io_manager})
def my_job():
after(before())

get_observation = lambda event: event.event_specific_data.asset_observation

result = my_job.execute_in_process()
observations = [
event for event in result.all_node_events if event.event_type_value == "ASSET_OBSERVATION"
]

# first observation
assert observations[0].step_key == "after"
assert get_observation(observations[0]) == AssetObservation(
asset_key=in_asset_key, metadata={"foo": "bar"}
)
# second observation
assert observations[1].step_key == "after"
assert get_observation(observations[1]) == AssetObservation(
asset_key=in_asset_key, metadata={"baz": "qux"}
)

# confirm loaded_input event contains metadata
loaded_input_event = [
event for event in result.all_events if event.event_type_value == "LOADED_INPUT"
][0]
assert loaded_input_event
loaded_input_event_metadata = loaded_input_event.event_specific_data.metadata_entries
assert len(loaded_input_event_metadata) == 2
assert loaded_input_event_metadata[0].label == "foo"
assert loaded_input_event_metadata[1].label == "baz"


def test_root_input_manager_add_input_metadata():
@root_input_manager
def my_root_input_manager(context):
context.add_input_metadata(metadata={"foo": "bar"})
context.add_input_metadata(metadata={"baz": "qux"})
return []

@op(ins={"input1": In(root_manager_key="my_root_input_manager")})
def my_op(_, input1):
return input1

@job(resource_defs={"my_root_input_manager": my_root_input_manager})
def my_job():
my_op()

result = my_job.execute_in_process()
loaded_input_event = [
event for event in result.all_events if event.event_type_value == "LOADED_INPUT"
][0]
metadata_entries = loaded_input_event.event_specific_data.metadata_entries
assert len(metadata_entries) == 2
assert metadata_entries[0].label == "foo"
assert metadata_entries[1].label == "baz"


def test_io_manager_single_partition_add_input_metadata():
partitions_def = StaticPartitionsDefinition(["a", "b", "c"])

@asset(partitions_def=partitions_def)
def asset_1():
return 1

@asset(partitions_def=partitions_def)
def asset_2(asset_1):
return 2

class MyIOManager(IOManager):
def handle_output(self, context, obj):
pass

def load_input(self, context):
context.add_input_metadata(metadata={"foo": "bar"}, description="hello world")
return 1

@io_manager
def my_io_manager(_):
return MyIOManager()

assets_job = build_assets_job(
"assets_job", [asset_1, asset_2], resource_defs={"io_manager": my_io_manager}
)
result = assets_job.execute_in_process(partition_key="a")

get_observation = lambda event: event.event_specific_data.asset_observation

observations = [
event for event in result.all_node_events if event.event_type_value == "ASSET_OBSERVATION"
]

assert observations[0].step_key == "asset_2"
assert get_observation(observations[0]) == AssetObservation(
asset_key="asset_1", metadata={"foo": "bar"}, description="hello world", partition="a"
)


def test_context_error_add_input_metadata():
@op
def my_op():
pass

context = build_input_context(op_def=my_op)
with pytest.raises(CheckError):
context.add_input_metadata({"foo": "bar"})


def test_io_manager_single_partition_materialization():

entry1 = MetadataEntry("nrows", value=123)
