Skip to content

Commit

Permalink
[graph] use the correct input definition for type checks (#7453)
Browse files Browse the repository at this point in the history
https://dagster.phacility.com/D6155 added things to the input source types to allow for removing all user code from the execution plan structure. 

`input_name` and `solid_handle` values were added to the source objects. These values are populated according to where along the mapping chain the dependency was resolved. This matters for `FromConfig`, since the config may be supplied at different places in the config structure, which map to different input mapping layers.
Unfortunately, for other sources such as `FromStepOutput` this information is more confusing than useful. 

This subtle confusion was fine with `@composite_solid` since it enforced that all types were equal along a mapping chain.

`@graph` dropped this requirement, which exposed this subtle "bug" and caused type checks to source the `dagster_type` from the mapping input def (which in the `@graph` case is likely `Any`) instead of from the `@op` input def.


This diff fixes the underlying issue by passing `input_def` directly and deprecating the confusing properties on the input sources that do not need them.


resolves #7452

## Test Plan

test added

Still need to ensure that the serdes changes are safe and do not require more extensive back-compat / forward-compat protections.
  • Loading branch information
alangenfeld committed Apr 18, 2022
1 parent 2c54844 commit 39c853a
Show file tree
Hide file tree
Showing 12 changed files with 333 additions and 220 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,19 @@ def output_defs(self) -> Sequence["OutputDefinition"]:
def output_dict(self) -> Mapping[str, "OutputDefinition"]:
return self._output_dict

def has_input(self, name):
def has_input(self, name) -> bool:
check.str_param(name, "name")
return name in self._input_dict

def input_def_named(self, name):
def input_def_named(self, name) -> "InputDefinition":
check.str_param(name, "name")
return self._input_dict[name]

def has_output(self, name):
def has_output(self, name) -> bool:
check.str_param(name, "name")
return name in self._output_dict

def output_def_named(self, name):
def output_def_named(self, name) -> "OutputDefinition":
check.str_param(name, "name")
return self._output_dict[name]

Expand Down
3 changes: 1 addition & 2 deletions python_modules/dagster/dagster/core/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,7 @@ def step_retry_event(
def step_input_event(
step_context: StepExecutionContext, step_input_data: "StepInputData"
) -> "DagsterEvent":
step_input = step_context.step.step_input_named(step_input_data.input_name)
input_def = step_input.source.get_input_def(step_context.pipeline_def)
input_def = step_context.solid_def.input_def_named(step_input_data.input_name)

return DagsterEvent.from_step(
event_type=DagsterEventType.STEP_INPUT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,9 @@ def asset_partitions_time_window(self) -> TimeWindow:

partition_key_range = self.asset_partition_key_range
return TimeWindow(
partitions_def.time_window_for_partition_key(partition_key_range.start).start,
partitions_def.time_window_for_partition_key(partition_key_range.end).end,
# mypy thinks partitions_def is <nothing> here because ????
partitions_def.time_window_for_partition_key(partition_key_range.start).start, # type: ignore
partitions_def.time_window_for_partition_key(partition_key_range.end).end, # type: ignore
)

def consume_events(self) -> Iterator["DagsterEvent"]:
Expand Down
15 changes: 7 additions & 8 deletions python_modules/dagster/dagster/core/execution/context/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from dagster.core.storage.pipeline_run import PipelineRun
from dagster.core.storage.tags import PARTITION_NAME_TAG
from dagster.core.system_config.objects import ResolvedRunConfig
from dagster.core.types.dagster_type import DagsterType, DagsterTypeKind
from dagster.core.types.dagster_type import DagsterType

from .input import InputContext
from .output import OutputContext, get_output_context
Expand Down Expand Up @@ -559,7 +559,6 @@ def add_output_metadata(
f"In {self.solid_def.node_type_str} '{self.solid.name}', attempted to log metadata for dynamic output '{output_def.name}' without providing a mapping key. When logging metadata for a dynamic output, it is necessary to provide a mapping key."
)

output_name = output_def.name
if output_name in self._output_metadata:
if not mapping_key or mapping_key in self._output_metadata[output_name]:
raise DagsterInvariantViolationError(
Expand Down Expand Up @@ -744,24 +743,24 @@ def asset_partitions_time_window_for_output(self, output_name: str) -> TimeWindo
"Tried to get asset partitions for an output that correponds to a partitioned "
"asset that is not partitioned with a TimeWindowPartitionsDefinition."
)

partition_key_range = self.asset_partition_key_range_for_output(output_name)
return TimeWindow(
partitions_def.time_window_for_partition_key(partition_key_range.start).start,
partitions_def.time_window_for_partition_key(partition_key_range.end).end,
# mypy thinks partitions_def is <nothing> here because ????
partitions_def.time_window_for_partition_key(partition_key_range.start).start, # type: ignore
partitions_def.time_window_for_partition_key(partition_key_range.end).end, # type: ignore
)

def get_input_lineage(self) -> List[AssetLineageInfo]:
if not self._input_lineage:

for step_input in self.step.step_inputs:
input_def = step_input.source.get_input_def(self.pipeline_def)
input_def = self.solid_def.input_def_named(step_input.name)
dagster_type = input_def.dagster_type

if dagster_type.kind == DagsterTypeKind.NOTHING:
if dagster_type.is_nothing:
continue

self._input_lineage.extend(step_input.source.get_asset_lineage(self))
self._input_lineage.extend(step_input.source.get_asset_lineage(self, input_def))

self._input_lineage = _dedup_asset_lineage(self._input_lineage)

Expand Down
32 changes: 24 additions & 8 deletions python_modules/dagster/dagster/core/execution/plan/execute_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from dagster.core.execution.resolve_versions import resolve_step_output_versions
from dagster.core.storage.io_manager import IOManager
from dagster.core.storage.tags import MEMOIZED_RUN_TAG
from dagster.core.types.dagster_type import DagsterType, DagsterTypeKind
from dagster.core.types.dagster_type import DagsterType
from dagster.utils import ensure_gen, iterate_with_context
from dagster.utils.backcompat import ExperimentalWarning, experimental_functionality_warning
from dagster.utils.timing import time_execution_scope
Expand Down Expand Up @@ -135,7 +135,7 @@ def _step_output_error_checked_user_event_sequence(
for step_output in step.step_outputs:
step_output_def = step_context.solid_def.output_def_named(step_output.name)
if not step_context.has_seen_output(step_output_def.name) and not step_output_def.optional:
if step_output_def.dagster_type.kind == DagsterTypeKind.NOTHING:
if step_output_def.dagster_type.is_nothing:
step_context.log.info(
f'Emitting implicit Nothing for output "{step_output_def.name}" on {op_label}'
)
Expand Down Expand Up @@ -186,13 +186,21 @@ def _create_step_input_event(


def _type_checked_event_sequence_for_input(
step_context: StepExecutionContext, input_name: str, input_value: Any
step_context: StepExecutionContext,
input_name: str,
input_value: Any,
) -> Iterator[DagsterEvent]:
check.inst_param(step_context, "step_context", StepExecutionContext)
check.str_param(input_name, "input_name")

step_input = step_context.step.step_input_named(input_name)
input_def = step_input.source.get_input_def(step_context.pipeline_def)
input_def = step_context.solid_def.input_def_named(step_input.name)

check.invariant(
input_def.name == input_name,
f"InputDefinition name does not match, expected {input_name} got {input_def.name}",
)

dagster_type = input_def.dagster_type
type_check_context = step_context.for_type(dagster_type)
input_type = type(input_value)
Expand Down Expand Up @@ -299,12 +307,15 @@ def core_dagster_event_sequence_for_step(
inputs = {}

for step_input in step_context.step.step_inputs:
input_def = step_input.source.get_input_def(step_context.pipeline_def)
input_def = step_context.solid_def.input_def_named(step_input.name)
dagster_type = input_def.dagster_type

if dagster_type.kind == DagsterTypeKind.NOTHING:
if dagster_type.is_nothing:
continue
for event_or_input_value in ensure_gen(step_input.source.load_input_object(step_context)):

for event_or_input_value in ensure_gen(
step_input.source.load_input_object(step_context, input_def)
):
if isinstance(event_or_input_value, DagsterEvent):
yield event_or_input_value
else:
Expand Down Expand Up @@ -652,7 +663,12 @@ def _create_type_materializations(
):
output_def = step_context.solid_def.output_def_named(step_output.name)
dagster_type = output_def.dagster_type
materializations = dagster_type.materializer.materialize_runtime_values(
materializer = dagster_type.materializer
if materializer is None:
check.failed(
"Unexpected attempt to materialize with no materializer available on dagster_type"
)
materializations = materializer.materialize_runtime_values(
step_context, output_spec, value
)

Expand Down

0 comments on commit 39c853a

Please sign in to comment.