diff --git a/python_modules/dagster/dagster/_core/execution/execution_result.py b/python_modules/dagster/dagster/_core/execution/execution_result.py index 9cffc8268995..eac03e73121e 100644 --- a/python_modules/dagster/dagster/_core/execution/execution_result.py +++ b/python_modules/dagster/dagster/_core/execution/execution_result.py @@ -164,6 +164,9 @@ def asset_observations_for_node(self, node_name: str) -> Sequence[AssetObservati def get_step_success_events(self) -> Sequence[DagsterEvent]: return [event for event in self.all_events if event.is_step_success] + def get_step_skipped_events(self) -> Sequence[DagsterEvent]: + return [event for event in self.all_events if event.is_step_skipped] + def get_failed_step_keys(self) -> AbstractSet[str]: failure_events = self.filter_events( lambda event: event.is_step_failure or event.is_resource_init_failure diff --git a/python_modules/dagster/dagster/_core/execution/plan/active.py b/python_modules/dagster/dagster/_core/execution/plan/active.py index 996ce22f6413..871d6544a3ab 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/active.py +++ b/python_modules/dagster/dagster/_core/execution/plan/active.py @@ -64,13 +64,16 @@ def __init__( # track mapping keys from DynamicOutputs, step_key, output_name -> list of keys # to _gathering while in flight - self._gathering_dynamic_outputs: Dict[str, Mapping[str, List[str]]] = {} - # then on success move to _successful - self._successful_dynamic_outputs: Dict[str, Mapping[str, Sequence[str]]] = ( + self._gathering_dynamic_outputs: Dict[str, Mapping[str, Optional[List[str]]]] = {} + # then on resolution move to _completed + self._completed_dynamic_outputs: Dict[str, Mapping[str, Optional[Sequence[str]]]] = ( dict(self._plan.known_state.dynamic_mappings) if self._plan.known_state else {} ) self._new_dynamic_mappings: bool = False + # track which upstream deps caused a step to skip + self._skipped_deps: Dict[str, Sequence[str]] = {} + # steps move in to these buckets as a result of _update calls self._executable: List[str] = [] self._pending_skip: List[str] = [] @@ -160,7 +163,7 @@ def _update(self) -> None: failed_or_abandoned_steps = self._failed | self._abandoned if self._new_dynamic_mappings: - new_step_deps = self._plan.resolve(self._successful_dynamic_outputs) + new_step_deps = self._plan.resolve(self._completed_dynamic_outputs) for step_key, deps in new_step_deps.items(): self._pending[step_key] = deps @@ -192,6 +195,9 @@ def _update(self) -> None: step_input.get_step_output_handle_dependencies() ): should_skip = True + self._skipped_deps[step_key] = [ + f"{h.step_key}.{h.output_name}" for h in missing_source_handles + ] break if should_skip: @@ -308,7 +314,8 @@ def get_steps_to_skip(self) -> Sequence[ExecutionStep]: steps.append(step) self._in_flight.add(key) self._pending_skip.remove(key) - self._prep_for_dynamic_outputs(step) + self._gathering_dynamic_outputs + self._skip_for_dynamic_outputs(step) return sorted(steps, key=self._sort_key_fn) @@ -331,14 +338,9 @@ def plan_events_iterator(self, pipeline_context) -> Iterator[DagsterEvent]: while steps_to_skip: for step in steps_to_skip: step_context = pipeline_context.for_step(step) - skipped_inputs: List[str] = [] - for step_input in step.step_inputs: - skipped_inputs.extend(self._skipped.intersection(step_input.dependency_keys)) - step_context.log.info( - "Skipping step {step} due to skipped dependencies: {skipped_inputs}.".format( - step=step.key, skipped_inputs=skipped_inputs - ) + f"Skipping step {step.key} due to skipped dependencies:" + f" {self._skipped_deps[step.key]}." ) yield DagsterEvent.step_skipped_event(step_context) @@ -461,9 +463,11 @@ def handle_event(self, dagster_event: DagsterEvent) -> None: event_specific_data = cast(StepOutputData, dagster_event.event_specific_data) self.mark_step_produced_output(event_specific_data.step_output_handle) if dagster_event.step_output_data.step_output_handle.mapping_key: - self._gathering_dynamic_outputs[step_key][ - dagster_event.step_output_data.step_output_handle.output_name - ].append(dagster_event.step_output_data.step_output_handle.mapping_key) + check.not_none( + self._gathering_dynamic_outputs[step_key][ + dagster_event.step_output_data.step_output_handle.output_name + ], + ).append(dagster_event.step_output_data.step_output_handle.mapping_key) def verify_complete(self, pipeline_context: PlanOrchestrationContext, step_key: str) -> None: """Ensure that a step has reached a terminal state, if it has not mark it as an unexpected failure @@ -505,7 +509,7 @@ def retry_state(self) -> RetryState: def get_known_state(self) -> KnownExecutionState: return KnownExecutionState( previous_retry_attempts=self._retry_state.snapshot_attempts(), - dynamic_mappings=dict(self._successful_dynamic_outputs), + dynamic_mappings=dict(self._completed_dynamic_outputs), ready_outputs=self._step_outputs, step_output_versions=self._plan.known_state.step_output_versions, parent_state=self._plan.known_state.parent_state, @@ -513,13 +517,28 @@ def get_known_state(self) -> KnownExecutionState: def _prep_for_dynamic_outputs(self, step: ExecutionStep): dyn_outputs = [step_out for step_out in step.step_outputs if step_out.is_dynamic] - if dyn_outputs: self._gathering_dynamic_outputs[step.key] = {out.name: [] for out in dyn_outputs} + def _skip_for_dynamic_outputs(self, step: ExecutionStep): + dyn_outputs = [step_out for step_out in step.step_outputs if step_out.is_dynamic] + if dyn_outputs: + # place None to indicate the dynamic output was skipped, different than having 0 entries + self._gathering_dynamic_outputs[step.key] = {out.name: None for out in dyn_outputs} + def _resolve_any_dynamic_outputs(self, step_key: str) -> None: if step_key in self._gathering_dynamic_outputs: - self._successful_dynamic_outputs[step_key] = self._gathering_dynamic_outputs[step_key] + step = self.get_step_by_key(step_key) + completed_mappings: Dict[str, Optional[Sequence[str]]] = {} + for output_name, mappings in self._gathering_dynamic_outputs[step_key].items(): + # if no dynamic outputs were returned and the output was marked is_required=False + # set to None to indicate a skip should occur + if not mappings and not step.step_output_dict[output_name].is_required: + completed_mappings[output_name] = None + else: + completed_mappings[output_name] = mappings + + self._completed_dynamic_outputs[step_key] = completed_mappings self._new_dynamic_mappings = True def rebuild_from_events( diff --git a/python_modules/dagster/dagster/_core/execution/plan/inputs.py b/python_modules/dagster/dagster/_core/execution/plan/inputs.py index 86fd943538c9..c1874a9887b5 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/inputs.py +++ b/python_modules/dagster/dagster/_core/execution/plan/inputs.py @@ -1037,7 +1037,17 @@ def get_step_output_handle_dep_with_placeholder(self) -> StepOutputHandle: def required_resource_keys(self, _pipeline_def: PipelineDefinition) -> Set[str]: return set() - def resolve(self, mapping_keys): + def resolve(self, mapping_keys: Optional[Sequence[str]]): + if mapping_keys is None: + # None means that the dynamic output was skipped, so create + # a dependency on the dynamic output that will continue cascading the skip + return FromStepOutput( + step_output_handle=StepOutputHandle( + step_key=self.resolved_by_step_key, + output_name=self.resolved_by_output_name, + ), + fan_in=False, + ) return FromMultipleSources( sources=[self.source.resolve(map_key) for map_key in mapping_keys], ) @@ -1086,7 +1096,7 @@ def resolved_by_step_key(self) -> str: def resolved_by_output_name(self) -> str: return self.source.resolved_by_output_name - def resolve(self, mapping_keys: Sequence[str]) -> StepInput: + def resolve(self, mapping_keys: Optional[Sequence[str]]) -> StepInput: return StepInput( name=self.name, dagster_type_key=self.dagster_type_key, diff --git a/python_modules/dagster/dagster/_core/execution/plan/plan.py b/python_modules/dagster/dagster/_core/execution/plan/plan.py index 9b0e9e4dca8c..c77b8624275c 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/plan.py +++ b/python_modules/dagster/dagster/_core/execution/plan/plan.py @@ -770,7 +770,7 @@ def get_executable_step_deps(self) -> Mapping[str, Set[str]]: def resolve( self, - mappings: Mapping[str, Mapping[str, Sequence[str]]], + mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]], ) -> Mapping[str, Set[str]]: """Resolve any dynamic map or collect steps with the resolved dynamic mappings""" @@ -1197,7 +1197,7 @@ def _update_from_resolved_dynamic_outputs( executable_map: Dict[str, Union[StepHandle, ResolvedFromDynamicStepHandle]], resolvable_map: Dict[FrozenSet[str], Sequence[Union[StepHandle, UnresolvedStepHandle]]], step_handles_to_execute: Sequence[StepHandleUnion], - dynamic_mappings: Mapping[str, Mapping[str, Sequence[str]]], + dynamic_mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]], ) -> None: resolved_steps: List[ExecutionStep] = [] key_sets_to_clear: List[FrozenSet[str]] = [] diff --git a/python_modules/dagster/dagster/_core/execution/plan/state.py b/python_modules/dagster/dagster/_core/execution/plan/state.py index 4695eb87862a..504412b0e099 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/state.py +++ b/python_modules/dagster/dagster/_core/execution/plan/state.py @@ -74,7 +74,7 @@ class KnownExecutionState( # step_key -> count ("previous_retry_attempts", Mapping[str, int]), # step_key -> output_name -> mapping_keys - ("dynamic_mappings", Mapping[str, Mapping[str, Sequence[str]]]), + ("dynamic_mappings", Mapping[str, Mapping[str, Optional[Sequence[str]]]]), # step_output_handle -> version ("step_output_versions", Sequence[StepOutputVersionData]), ("ready_outputs", Set[StepOutputHandle]), @@ -91,7 +91,7 @@ class KnownExecutionState( def __new__( cls, previous_retry_attempts: Optional[Mapping[str, int]] = None, - dynamic_mappings: Optional[Mapping[str, Mapping[str, Sequence[str]]]] = None, + dynamic_mappings: Optional[Mapping[str, Mapping[str, Optional[Sequence[str]]]]] = None, step_output_versions: Optional[Sequence[StepOutputVersionData]] = None, ready_outputs: Optional[Set[StepOutputHandle]] = None, parent_state: Optional[PastExecutionState] = None, diff --git a/python_modules/dagster/dagster/_core/execution/plan/step.py b/python_modules/dagster/dagster/_core/execution/plan/step.py index 564e7637a32a..f861a4731674 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/step.py +++ b/python_modules/dagster/dagster/_core/execution/plan/step.py @@ -3,6 +3,7 @@ from typing import ( TYPE_CHECKING, FrozenSet, + List, Mapping, NamedTuple, Optional, @@ -340,15 +341,21 @@ def resolved_by_step_keys(self) -> FrozenSet[str]: return frozenset(keys) def resolve( - self, mappings: Mapping[str, Mapping[str, Sequence[str]]] + self, mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]] ) -> Sequence[ExecutionStep]: check.invariant( all(key in mappings for key in self.resolved_by_step_keys), "resolving with mappings that do not contain all required step keys", ) - execution_steps = [] + execution_steps: List[ExecutionStep] = [] - for mapped_key in mappings[self.resolved_by_step_key][self.resolved_by_output_name]: + mapping_keys = mappings[self.resolved_by_step_key][self.resolved_by_output_name] + + # dynamic output skipped + if mapping_keys is None: + return execution_steps + + for mapped_key in mapping_keys: resolved_inputs = [_resolved_input(inp, mapped_key) for inp in self.step_inputs] execution_steps.append( @@ -472,7 +479,9 @@ def resolved_by_step_keys(self) -> FrozenSet[str]: return frozenset(keys) - def resolve(self, mappings: Mapping[str, Mapping[str, Sequence[str]]]) -> ExecutionStep: + def resolve( + self, mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]] + ) -> ExecutionStep: check.invariant( all(key in mappings for key in self.resolved_by_step_keys), "resolving with mappings that do not contain all required step keys", diff --git a/python_modules/dagster/dagster_tests/execution_tests/dynamic_tests/test_dynamic_execution.py b/python_modules/dagster/dagster_tests/execution_tests/dynamic_tests/test_dynamic_execution.py index 50aa850af17f..fda33a8391f0 100644 --- a/python_modules/dagster/dagster_tests/execution_tests/dynamic_tests/test_dynamic_execution.py +++ b/python_modules/dagster/dagster_tests/execution_tests/dynamic_tests/test_dynamic_execution.py @@ -571,17 +571,73 @@ def dyn_fork(): total.alias("grand_total")( [ total.alias("nums_total")(emit_dyn(nums).map(echo).collect()), - total.alias("empty_total")(emit_dyn(empty).map(echo).collect()), - total.alias("skip_total")(emit_dyn(skip).map(echo).collect()), + total.alias("empty_total")( + emit_dyn.alias("emit_dyn_empty")(empty).map(echo.alias("echo_empty")).collect() + ), + total.alias("skip_total")( + emit_dyn.alias("emit_dyn_skip")(skip).map(echo.alias("echo_skip")).collect() + ), ] ) result = dyn_fork.execute_in_process() assert result.success + skips = {ev.step_key for ev in result.get_step_skipped_events()} assert result.output_for_node("nums_total") assert result.output_for_node("empty_total") == 0 - assert result.output_for_node("skip_total") == 0 # arguably should skip + assert "skip_total" in skips assert result.output_for_node("grand_total") == 6 + + +def test_collect_optional(): + @op(out=Out(is_required=False)) + def optional_out_op(): + if False: # pylint: disable=using-constant-test + yield None + + @op(out=DynamicOut()) + def dynamic_out_op(_in): + yield DynamicOutput("a", "a") + + @op + def collect_op(_in): + # this assert gets hit + assert False + + @job + def job1(): + echo(collect_op(dynamic_out_op(optional_out_op()).collect())) + + result = job1.execute_in_process() + skips = {ev.step_key for ev in result.get_step_skipped_events()} + assert "dynamic_out_op" in skips + assert "collect_op" in skips + assert "echo" in skips + + +def test_non_required_dynamic_collect_skips(): + @op(out=DynamicOut(is_required=False)) + def producer(): + if False: # pylint: disable=using-constant-test + yield DynamicOutput("yay") + + @op + def consumer1(item): + pass + + @op + def consumer2(items): + pass + + @job() + def my_job(): + items = producer() + items.map(consumer1) + consumer2(items.collect()) + + result = my_job.execute_in_process() + skips = {ev.step_key for ev in result.get_step_skipped_events()} + assert "consumer2" in skips