Skip to content

Commit

Permalink
Fix upstream context handling in fs_asset_io_manager (#8007)
Browse files Browse the repository at this point in the history
* Add test case that is currently failing

* Add helper function to get asset identifier from InputContext

* Rework asset path determination logic in fs_asset_io_manager

By passing either InputContext or OutputContext to _get_path, we can
implement it without using upstream_output for assets.

Using upstream_output for assets was problematic, because fetching the
partition_key required upstream step config, which is not available in
upstream_output.

* Refactor to enable change of arguments in get_output_context.

* Emit warning on use of step_context through upstream_output

This change should be reverted once upstream_output.step_context is
initialized as None in PlanExecutionContext.for_input_manager

* Add uniform get_asset_identifier for both InputContext and OutputContext

* Add uniform get_identifier method for InputContext and OutputContext

* Adapt to new API

* Fix a couple of typing issues related to optionals

* Fix another typing issue

Black formatting unsilenced a mypy error that was explicitly ignored. We
move things around to fix the original mypy complaint.

* Change all uses of get_output_identifier to get_identifier

This fixes an error on dagster-tests.
  • Loading branch information
aroig committed May 26, 2022
1 parent 28f9930 commit 8e8ee85
Show file tree
Hide file tree
Showing 16 changed files with 258 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def metadata_for_actions(df):

class MyDatabaseIOManager(PickledObjectFilesystemIOManager):
def _get_path(self, context):
keys = context.get_output_identifier()
keys = context.get_identifier()

return os.path.join("/tmp", *keys)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def handle_output(self, context, obj):
if self._throw_output:
raise ExampleException("throwing up trying to handle output")

keys = tuple(context.get_output_identifier())
keys = tuple(context.get_identifier())
self._values[keys] = obj

def load_input(self, context):
if self._throw_input:
raise ExampleException("throwing up trying to load input")

keys = tuple(context.upstream_output.get_output_identifier())
keys = tuple(context.upstream_output.get_identifier())
return self._values[keys]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def metadata_for_actions(df):

class MyDatabaseIOManager(PickledObjectFilesystemIOManager):
def _get_path(self, context):
keys = context.get_output_identifier()
keys = context.get_identifier()

return os.path.join("/tmp", *keys)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def handle_output(self, context, obj):
if self._throw_output:
raise ExampleException("throwing up trying to handle output")

keys = tuple(context.get_output_identifier())
keys = tuple(context.get_identifier())
self._values[keys] = obj

def load_input(self, context):
if self._throw_input:
raise ExampleException("throwing up trying to load input")

keys = tuple(context.upstream_output.get_output_identifier())
keys = tuple(context.upstream_output.get_identifier())
return self._values[keys]


Expand Down
53 changes: 44 additions & 9 deletions python_modules/dagster/dagster/core/execution/context/input.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Sequence, Union, cast

import dagster._check as check
from dagster.core.definitions.events import AssetKey, AssetObservation
Expand Down Expand Up @@ -284,28 +284,63 @@ def asset_partitions_time_window(self) -> TimeWindow:
if self.upstream_output is None:
check.failed("InputContext needs upstream_output to get asset_partitions_time_window")

asset_info = self.upstream_output.asset_info
partitions_def = asset_info.partitions_def if asset_info else None

if not partitions_def:
if self.upstream_output.asset_info is None:
raise ValueError(
"Tried to get asset partitions for an output that does not correspond to a "
"partitioned asset."
)

if not isinstance(partitions_def, TimeWindowPartitionsDefinition):
asset_info = self.upstream_output.asset_info

if not isinstance(asset_info.partitions_def, TimeWindowPartitionsDefinition):
raise ValueError(
"Tried to get asset partitions for an input that correponds to a partitioned "
"asset that is not partitioned with a TimeWindowPartitionsDefinition."
)

partitions_def: TimeWindowPartitionsDefinition = asset_info.partitions_def

partition_key_range = self.asset_partition_key_range
return TimeWindow(
# mypy thinks partitions_def is <nothing> here because ????
partitions_def.time_window_for_partition_key(partition_key_range.start).start, # type: ignore
partitions_def.time_window_for_partition_key(partition_key_range.end).end, # type: ignore
partitions_def.time_window_for_partition_key(partition_key_range.start).start,
partitions_def.time_window_for_partition_key(partition_key_range.end).end,
)

def get_identifier(self) -> List[str]:
    """Return the identifiers that together name this step input uniquely.

    The work is delegated to the upstream output's ``get_identifier``, so the
    result is whatever that OutputContext reports for the output feeding this
    input.

    Without memoization the collection is made of:

    - ``run_id``: the id of the run which generates the input. If the step that
      generates the input was skipped in a re-execution, this is the id of its
      parent run.
    - ``step_key``: the key for a compute step.
    - ``name``: the name of the output (default: 'result').

    With memoization, the ``version`` corresponding to the step output takes the
    place of the ``run_id``.

    Returns:
        List[str]: (run_id or version), step_key, and output_name.

    Raises:
        DagsterInvariantViolationError: if this input has no ``upstream_output``.
    """
    upstream = self.upstream_output
    if upstream is None:
        raise DagsterInvariantViolationError(
            "InputContext.upstream_output not defined. Cannot compute an identifier"
        )

    return upstream.get_identifier()

def get_asset_identifier(self) -> Sequence[str]:
    """Return the path components that identify the asset behind this input.

    The result is the asset key's ``path``; when the input has asset partitions,
    the asset partition key is appended as a final component.

    Fails via ``check.failed`` when the input has no asset key.
    """
    if self.asset_key is None:
        check.failed("Can't get asset identifier for an input with no asset key")

    if not self.has_asset_partitions:
        return self.asset_key.path

    # Partitioned asset: the partition key becomes the last path component.
    return self.asset_key.path + [self.asset_partition_key]

def consume_events(self) -> Iterator["DagsterEvent"]:
"""Pops and yields all user-generated events that have been recorded from this context.
Expand Down
82 changes: 79 additions & 3 deletions python_modules/dagster/dagster/core/execution/context/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
step_context: Optional["StepExecutionContext"] = None,
op_def: Optional["OpDefinition"] = None,
asset_info: Optional[AssetOutputInfo] = None,
warn_on_step_context_use: bool = False,
):
from dagster.core.definitions.resource_definition import IContainsGenerator, Resources
from dagster.core.execution.build_resources import build_resources
Expand All @@ -97,6 +98,7 @@ def __init__(
self._resource_config = resource_config
self._step_context = step_context
self._asset_info = asset_info
self._warn_on_step_context_use = warn_on_step_context_use

if isinstance(resources, Resources):
self._resources_cm = None
Expand Down Expand Up @@ -262,6 +264,14 @@ def asset_key(self) -> AssetKey:

@property
def step_context(self) -> "StepExecutionContext":
if self._warn_on_step_context_use:
warnings.warn(
"You are using InputContext.upstream_output.step_context"
"This use on upstream_output is deprecated and will fail in the future"
"Try to obtain what you need directly from InputContext"
"For more details: https://github.com/dagster-io/dagster/issues/7900"
)

if self._step_context is None:
raise DagsterInvariantViolationError(
"Attempting to access step_context, "
Expand All @@ -273,6 +283,14 @@ def step_context(self) -> "StepExecutionContext":
@property
def has_partition_key(self) -> bool:
    """Whether the current run is a partitioned run.

    The value is read from ``self.step_context``. When this context was built
    with ``warn_on_step_context_use`` set, a deprecation-style warning is emitted
    first, because the answer then comes from a step context that callers should
    not rely on through ``InputContext.upstream_output``.
    """
    if self._warn_on_step_context_use:
        # Each fragment ends with ". " so the implicitly concatenated literals
        # read as separate sentences instead of running into one another.
        warnings.warn(
            "You are using InputContext.upstream_output.has_partition_key. "
            "This use on upstream_output is deprecated and will fail in the future. "
            "Try to obtain what you need directly from InputContext. "
            "For more details: https://github.com/dagster-io/dagster/issues/7900",
            # Point the warning at the code that accessed the property.
            stacklevel=2,
        )

    return self.step_context.has_partition_key

@property
Expand All @@ -281,10 +299,26 @@ def partition_key(self) -> str:
Raises an error if the current run is not a partitioned run.
"""
if self._warn_on_step_context_use:
warnings.warn(
"You are using InputContext.upstream_output.partition_key"
"This use on upstream_output is deprecated and will fail in the future"
"Try to obtain what you need directly from InputContext"
"For more details: https://github.com/dagster-io/dagster/issues/7900"
)

return self.step_context.partition_key

@property
def has_asset_partitions(self) -> bool:
if self._warn_on_step_context_use:
warnings.warn(
"You are using InputContext.upstream_output.has_asset_partitions"
"This use on upstream_output is deprecated and will fail in the future"
"Try to obtain what you need directly from InputContext"
"For more details: https://github.com/dagster-io/dagster/issues/7900"
)

if self._step_context is not None:
return self._step_context.has_asset_partitions_for_output(self.name)
else:
Expand All @@ -297,6 +331,14 @@ def asset_partition_key(self) -> str:
Raises an error if the output asset has no partitioning, or if the run covers a partition
range for the output asset.
"""
if self._warn_on_step_context_use:
warnings.warn(
"You are using InputContext.upstream_output.asset_partition_key"
"This use on upstream_output is deprecated and will fail in the future"
"Try to obtain what you need directly from InputContext"
"For more details: https://github.com/dagster-io/dagster/issues/7900"
)

return self.step_context.asset_partition_key_for_output(self.name)

@property
Expand All @@ -305,6 +347,14 @@ def asset_partition_key_range(self) -> PartitionKeyRange:
Raises an error if the output asset has no partitioning.
"""
if self._warn_on_step_context_use:
warnings.warn(
"You are using InputContext.upstream_output.asset_partition_key_range"
"This use on upstream_output is deprecated and will fail in the future"
"Try to obtain what you need directly from InputContext"
"For more details: https://github.com/dagster-io/dagster/issues/7900"
)

return self.step_context.asset_partition_key_range_for_output(self.name)

@property
Expand All @@ -315,6 +365,14 @@ def asset_partitions_time_window(self) -> TimeWindow:
- The output asset has no partitioning.
- The output asset is not partitioned with a TimeWindowPartitionsDefinition.
"""
if self._warn_on_step_context_use:
warnings.warn(
"You are using InputContext.upstream_output.asset_partitions_time_window"
"This use on upstream_output is deprecated and will fail in the future"
"Try to obtain what you need directly from InputContext"
"For more details: https://github.com/dagster-io/dagster/issues/7900"
)

return self.step_context.asset_partitions_time_window_for_output(self.name)

def get_run_scoped_output_identifier(self) -> List[str]:
Expand All @@ -336,7 +394,7 @@ def get_run_scoped_output_identifier(self) -> List[str]:

warnings.warn(
"`OutputContext.get_run_scoped_output_identifier` is deprecated. Use "
"`OutputContext.get_output_identifier` instead."
"`OutputContext.get_identifier` instead."
)
# if run_id is None and this is a re-execution, it means we failed to find its source run id
check.invariant(
Expand All @@ -360,7 +418,7 @@ def get_run_scoped_output_identifier(self) -> List[str]:

return [run_id, step_key, name]

def get_output_identifier(self) -> List[str]:
def get_identifier(self) -> List[str]:
"""Utility method to get a collection of identifiers that as a whole represent a unique
step output.
Expand Down Expand Up @@ -397,7 +455,15 @@ def get_output_identifier(self) -> List[str]:

return identifier

def get_asset_output_identifier(self) -> Sequence[str]:
def get_output_identifier(self) -> List[str]:
    """Deprecated alias for ``get_identifier``.

    Emits a warning that directs callers to ``OutputContext.get_identifier`` and
    then returns exactly what ``get_identifier`` returns.
    """
    warnings.warn(
        "`OutputContext.get_output_identifier` is deprecated. Use "
        "`OutputContext.get_identifier` instead.",
        # Attribute the warning to the caller's line so the deprecated call site
        # is what shows up in the warning output.
        stacklevel=2,
    )

    return self.get_identifier()

def get_asset_identifier(self) -> Sequence[str]:
if self.asset_key is not None:
if self.has_asset_partitions:
return self.asset_key.path + [self.asset_partition_key]
Expand All @@ -406,6 +472,14 @@ def get_asset_output_identifier(self) -> Sequence[str]:
else:
check.failed("Can't get asset output identifier for an output with no asset key")

def get_asset_output_identifier(self) -> Sequence[str]:
    """Deprecated alias for ``get_asset_identifier``.

    Emits a warning that directs callers to ``OutputContext.get_asset_identifier``
    and then returns exactly what ``get_asset_identifier`` returns.
    """
    warnings.warn(
        "`OutputContext.get_asset_output_identifier` is deprecated. Use "
        "`OutputContext.get_asset_identifier` instead.",
        # Attribute the warning to the caller's line so the deprecated call site
        # is what shows up in the warning output.
        stacklevel=2,
    )

    return self.get_asset_identifier()

def log_event(
self, event: Union[AssetObservation, AssetMaterialization, Materialization]
) -> None:
Expand Down Expand Up @@ -534,6 +608,7 @@ def get_output_context(
step_context: Optional["StepExecutionContext"],
resources: Optional["Resources"],
version: Optional[str],
warn_on_step_context_use: bool = False,
) -> "OutputContext":
"""
Args:
Expand Down Expand Up @@ -587,6 +662,7 @@ def get_output_context(
resource_config=resource_config,
resources=resources,
asset_info=asset_info,
warn_on_step_context_use=warn_on_step_context_use,
)


Expand Down
27 changes: 24 additions & 3 deletions python_modules/dagster/dagster/core/execution/context/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,15 +506,36 @@ def for_input_manager(
) -> InputContext:
if source_handle and artificial_output_context:
check.failed("Cannot specify both source_handle and artificial_output_context.")

upstream_output: Optional[OutputContext] = None

if source_handle is not None:
version = self.execution_plan.get_version_for_step_output_handle(source_handle)

# NOTE: this is using downstream step_context for upstream OutputContext. step_context
# will be set to None for 0.15 release.
upstream_output = get_output_context(
self.execution_plan,
self.pipeline_def,
self.resolved_run_config,
source_handle,
self._get_source_run_id(source_handle),
log_manager=self.log,
step_context=self,
resources=None,
version=version,
warn_on_step_context_use=True,
)
else:
upstream_output = artificial_output_context

return InputContext(
pipeline_name=self.pipeline_def.name,
name=name,
solid_def=self.solid_def,
config=config,
metadata=metadata,
upstream_output=self.get_output_context(source_handle)
if source_handle is not None
else artificial_output_context,
upstream_output=upstream_output,
dagster_type=dagster_type,
log_manager=self.log,
step_context=self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
from typing import Union

from dagster.config import Field
from dagster.config.source import StringSource
from dagster.core.execution.context.input import InputContext
from dagster.core.execution.context.output import OutputContext
from dagster.core.storage.io_manager import io_manager

from .fs_io_manager import PickledObjectFilesystemIOManager
Expand Down Expand Up @@ -80,5 +83,6 @@ def asset2(asset1):


class AssetPickledObjectFilesystemIOManager(PickledObjectFilesystemIOManager):
def _get_path(self, context):
return os.path.join(self.base_dir, *context.get_asset_output_identifier())
def _get_path(self, context: Union[InputContext, OutputContext]) -> str:
identifier = context.get_asset_identifier()
return os.path.join(self.base_dir, *identifier)

0 comments on commit 8e8ee85

Please sign in to comment.