Skip to content

Commit

Permalink
[dynamic] fix how skips cascade (#11561)
Browse files Browse the repository at this point in the history
resolves #10292
resolves #5948

### How I Tested These Changes

Added test cases.
  • Loading branch information
alangenfeld committed Jan 9, 2023
1 parent 260e3e3 commit 29b9d48
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 31 deletions.
Expand Up @@ -164,6 +164,9 @@ def asset_observations_for_node(self, node_name: str) -> Sequence[AssetObservati
def get_step_success_events(self) -> Sequence[DagsterEvent]:
    """Return the events from ``self.all_events`` that are step successes.

    An event is kept when its ``is_step_success`` attribute is truthy; the
    relative order of ``self.all_events`` is preserved.
    """
    succeeded = []
    for candidate in self.all_events:
        if candidate.is_step_success:
            succeeded.append(candidate)
    return succeeded

def get_step_skipped_events(self) -> Sequence[DagsterEvent]:
    """Return the events from ``self.all_events`` that are step skips.

    An event is kept when its ``is_step_skipped`` attribute is truthy; the
    relative order of ``self.all_events`` is preserved.
    """
    skipped = []
    for candidate in self.all_events:
        if candidate.is_step_skipped:
            skipped.append(candidate)
    return skipped

def get_failed_step_keys(self) -> AbstractSet[str]:
failure_events = self.filter_events(
lambda event: event.is_step_failure or event.is_resource_init_failure
Expand Down
55 changes: 37 additions & 18 deletions python_modules/dagster/dagster/_core/execution/plan/active.py
Expand Up @@ -64,13 +64,16 @@ def __init__(

# track mapping keys from DynamicOutputs, step_key, output_name -> list of keys
# to _gathering while in flight
self._gathering_dynamic_outputs: Dict[str, Mapping[str, List[str]]] = {}
# then on success move to _successful
self._successful_dynamic_outputs: Dict[str, Mapping[str, Sequence[str]]] = (
self._gathering_dynamic_outputs: Dict[str, Mapping[str, Optional[List[str]]]] = {}
# then on resolution move to _completed
self._completed_dynamic_outputs: Dict[str, Mapping[str, Optional[Sequence[str]]]] = (
dict(self._plan.known_state.dynamic_mappings) if self._plan.known_state else {}
)
self._new_dynamic_mappings: bool = False

# track which upstream deps caused a step to skip
self._skipped_deps: Dict[str, Sequence[str]] = {}

# steps move in to these buckets as a result of _update calls
self._executable: List[str] = []
self._pending_skip: List[str] = []
Expand Down Expand Up @@ -160,7 +163,7 @@ def _update(self) -> None:
failed_or_abandoned_steps = self._failed | self._abandoned

if self._new_dynamic_mappings:
new_step_deps = self._plan.resolve(self._successful_dynamic_outputs)
new_step_deps = self._plan.resolve(self._completed_dynamic_outputs)
for step_key, deps in new_step_deps.items():
self._pending[step_key] = deps

Expand Down Expand Up @@ -192,6 +195,9 @@ def _update(self) -> None:
step_input.get_step_output_handle_dependencies()
):
should_skip = True
self._skipped_deps[step_key] = [
f"{h.step_key}.{h.output_name}" for h in missing_source_handles
]
break

if should_skip:
Expand Down Expand Up @@ -308,7 +314,8 @@ def get_steps_to_skip(self) -> Sequence[ExecutionStep]:
steps.append(step)
self._in_flight.add(key)
self._pending_skip.remove(key)
self._prep_for_dynamic_outputs(step)
self._gathering_dynamic_outputs
self._skip_for_dynamic_outputs(step)

return sorted(steps, key=self._sort_key_fn)

Expand All @@ -331,14 +338,9 @@ def plan_events_iterator(self, pipeline_context) -> Iterator[DagsterEvent]:
while steps_to_skip:
for step in steps_to_skip:
step_context = pipeline_context.for_step(step)
skipped_inputs: List[str] = []
for step_input in step.step_inputs:
skipped_inputs.extend(self._skipped.intersection(step_input.dependency_keys))

step_context.log.info(
"Skipping step {step} due to skipped dependencies: {skipped_inputs}.".format(
step=step.key, skipped_inputs=skipped_inputs
)
f"Skipping step {step.key} due to skipped dependencies:"
f" {self._skipped_deps[step.key]}."
)
yield DagsterEvent.step_skipped_event(step_context)

Expand Down Expand Up @@ -461,9 +463,11 @@ def handle_event(self, dagster_event: DagsterEvent) -> None:
event_specific_data = cast(StepOutputData, dagster_event.event_specific_data)
self.mark_step_produced_output(event_specific_data.step_output_handle)
if dagster_event.step_output_data.step_output_handle.mapping_key:
self._gathering_dynamic_outputs[step_key][
dagster_event.step_output_data.step_output_handle.output_name
].append(dagster_event.step_output_data.step_output_handle.mapping_key)
check.not_none(
self._gathering_dynamic_outputs[step_key][
dagster_event.step_output_data.step_output_handle.output_name
],
).append(dagster_event.step_output_data.step_output_handle.mapping_key)

def verify_complete(self, pipeline_context: PlanOrchestrationContext, step_key: str) -> None:
"""Ensure that a step has reached a terminal state, if it has not mark it as an unexpected failure
Expand Down Expand Up @@ -505,21 +509,36 @@ def retry_state(self) -> RetryState:
def get_known_state(self) -> KnownExecutionState:
return KnownExecutionState(
previous_retry_attempts=self._retry_state.snapshot_attempts(),
dynamic_mappings=dict(self._successful_dynamic_outputs),
dynamic_mappings=dict(self._completed_dynamic_outputs),
ready_outputs=self._step_outputs,
step_output_versions=self._plan.known_state.step_output_versions,
parent_state=self._plan.known_state.parent_state,
)

def _prep_for_dynamic_outputs(self, step: ExecutionStep):
    """Begin gathering mapping keys for every dynamic output of ``step``.

    Registers an empty list per dynamic output name under ``step.key`` in
    ``self._gathering_dynamic_outputs``. A step with no dynamic outputs is
    not registered at all.
    """
    gathering = {}
    for step_out in step.step_outputs:
        if step_out.is_dynamic:
            gathering[step_out.name] = []

    if gathering:
        self._gathering_dynamic_outputs[step.key] = gathering

def _skip_for_dynamic_outputs(self, step: ExecutionStep):
    """Record every dynamic output of ``step`` as skipped.

    Stores ``None`` per dynamic output name under ``step.key`` in
    ``self._gathering_dynamic_outputs``. A step with no dynamic outputs is
    not registered at all.
    """
    skipped = {}
    for step_out in step.step_outputs:
        if step_out.is_dynamic:
            # None indicates the dynamic output was skipped, which is
            # different than having 0 entries
            skipped[step_out.name] = None

    if skipped:
        self._gathering_dynamic_outputs[step.key] = skipped

def _resolve_any_dynamic_outputs(self, step_key: str) -> None:
if step_key in self._gathering_dynamic_outputs:
self._successful_dynamic_outputs[step_key] = self._gathering_dynamic_outputs[step_key]
step = self.get_step_by_key(step_key)
completed_mappings: Dict[str, Optional[Sequence[str]]] = {}
for output_name, mappings in self._gathering_dynamic_outputs[step_key].items():
# if no dynamic outputs were returned and the output was marked is_required=False
# set to None to indicate a skip should occur
if not mappings and not step.step_output_dict[output_name].is_required:
completed_mappings[output_name] = None
else:
completed_mappings[output_name] = mappings

self._completed_dynamic_outputs[step_key] = completed_mappings
self._new_dynamic_mappings = True

def rebuild_from_events(
Expand Down
14 changes: 12 additions & 2 deletions python_modules/dagster/dagster/_core/execution/plan/inputs.py
Expand Up @@ -1037,7 +1037,17 @@ def get_step_output_handle_dep_with_placeholder(self) -> StepOutputHandle:
def required_resource_keys(self, _pipeline_def: PipelineDefinition) -> Set[str]:
    """Return an empty set: no resource keys are reported here.

    ``_pipeline_def`` is accepted but not consulted.
    """
    no_keys: Set[str] = set()
    return no_keys

def resolve(self, mapping_keys):
def resolve(self, mapping_keys: Optional[Sequence[str]]):
if mapping_keys is None:
# None means that the dynamic output was skipped, so create
# a dependency on the dynamic output that will continue cascading the skip
return FromStepOutput(
step_output_handle=StepOutputHandle(
step_key=self.resolved_by_step_key,
output_name=self.resolved_by_output_name,
),
fan_in=False,
)
return FromMultipleSources(
sources=[self.source.resolve(map_key) for map_key in mapping_keys],
)
Expand Down Expand Up @@ -1086,7 +1096,7 @@ def resolved_by_step_key(self) -> str:
def resolved_by_output_name(self) -> str:
return self.source.resolved_by_output_name

def resolve(self, mapping_keys: Sequence[str]) -> StepInput:
def resolve(self, mapping_keys: Optional[Sequence[str]]) -> StepInput:
return StepInput(
name=self.name,
dagster_type_key=self.dagster_type_key,
Expand Down
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/execution/plan/plan.py
Expand Up @@ -770,7 +770,7 @@ def get_executable_step_deps(self) -> Mapping[str, Set[str]]:

def resolve(
self,
mappings: Mapping[str, Mapping[str, Sequence[str]]],
mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]],
) -> Mapping[str, Set[str]]:
"""Resolve any dynamic map or collect steps with the resolved dynamic mappings"""

Expand Down Expand Up @@ -1197,7 +1197,7 @@ def _update_from_resolved_dynamic_outputs(
executable_map: Dict[str, Union[StepHandle, ResolvedFromDynamicStepHandle]],
resolvable_map: Dict[FrozenSet[str], Sequence[Union[StepHandle, UnresolvedStepHandle]]],
step_handles_to_execute: Sequence[StepHandleUnion],
dynamic_mappings: Mapping[str, Mapping[str, Sequence[str]]],
dynamic_mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]],
) -> None:
resolved_steps: List[ExecutionStep] = []
key_sets_to_clear: List[FrozenSet[str]] = []
Expand Down
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/execution/plan/state.py
Expand Up @@ -74,7 +74,7 @@ class KnownExecutionState(
# step_key -> count
("previous_retry_attempts", Mapping[str, int]),
# step_key -> output_name -> mapping_keys
("dynamic_mappings", Mapping[str, Mapping[str, Sequence[str]]]),
("dynamic_mappings", Mapping[str, Mapping[str, Optional[Sequence[str]]]]),
# step_output_handle -> version
("step_output_versions", Sequence[StepOutputVersionData]),
("ready_outputs", Set[StepOutputHandle]),
Expand All @@ -91,7 +91,7 @@ class KnownExecutionState(
def __new__(
cls,
previous_retry_attempts: Optional[Mapping[str, int]] = None,
dynamic_mappings: Optional[Mapping[str, Mapping[str, Sequence[str]]]] = None,
dynamic_mappings: Optional[Mapping[str, Mapping[str, Optional[Sequence[str]]]]] = None,
step_output_versions: Optional[Sequence[StepOutputVersionData]] = None,
ready_outputs: Optional[Set[StepOutputHandle]] = None,
parent_state: Optional[PastExecutionState] = None,
Expand Down
17 changes: 13 additions & 4 deletions python_modules/dagster/dagster/_core/execution/plan/step.py
Expand Up @@ -3,6 +3,7 @@
from typing import (
TYPE_CHECKING,
FrozenSet,
List,
Mapping,
NamedTuple,
Optional,
Expand Down Expand Up @@ -340,15 +341,21 @@ def resolved_by_step_keys(self) -> FrozenSet[str]:
return frozenset(keys)

def resolve(
self, mappings: Mapping[str, Mapping[str, Sequence[str]]]
self, mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]]
) -> Sequence[ExecutionStep]:
check.invariant(
all(key in mappings for key in self.resolved_by_step_keys),
"resolving with mappings that do not contain all required step keys",
)
execution_steps = []
execution_steps: List[ExecutionStep] = []

for mapped_key in mappings[self.resolved_by_step_key][self.resolved_by_output_name]:
mapping_keys = mappings[self.resolved_by_step_key][self.resolved_by_output_name]

# dynamic output skipped
if mapping_keys is None:
return execution_steps

for mapped_key in mapping_keys:
resolved_inputs = [_resolved_input(inp, mapped_key) for inp in self.step_inputs]

execution_steps.append(
Expand Down Expand Up @@ -472,7 +479,9 @@ def resolved_by_step_keys(self) -> FrozenSet[str]:

return frozenset(keys)

def resolve(self, mappings: Mapping[str, Mapping[str, Sequence[str]]]) -> ExecutionStep:
def resolve(
self, mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]]
) -> ExecutionStep:
check.invariant(
all(key in mappings for key in self.resolved_by_step_keys),
"resolving with mappings that do not contain all required step keys",
Expand Down
Expand Up @@ -571,17 +571,73 @@ def dyn_fork():
total.alias("grand_total")(
[
total.alias("nums_total")(emit_dyn(nums).map(echo).collect()),
total.alias("empty_total")(emit_dyn(empty).map(echo).collect()),
total.alias("skip_total")(emit_dyn(skip).map(echo).collect()),
total.alias("empty_total")(
emit_dyn.alias("emit_dyn_empty")(empty).map(echo.alias("echo_empty")).collect()
),
total.alias("skip_total")(
emit_dyn.alias("emit_dyn_skip")(skip).map(echo.alias("echo_skip")).collect()
),
]
)

result = dyn_fork.execute_in_process()
assert result.success
skips = {ev.step_key for ev in result.get_step_skipped_events()}

assert result.output_for_node("nums_total")
assert result.output_for_node("empty_total") == 0

assert result.output_for_node("skip_total") == 0 # arguably should skip
assert "skip_total" in skips

assert result.output_for_node("grand_total") == 6


def test_collect_optional():
    """An unproduced optional output should skip the dynamic op, its collect, and ``echo``."""

    @op(out=Out(is_required=False))
    def optional_out_op():
        # generator that never yields, so the optional output is never produced
        if False:  # pylint: disable=using-constant-test
            yield None

    @op(out=DynamicOut())
    def dynamic_out_op(_in):
        yield DynamicOutput("a", "a")

    @op
    def collect_op(_in):
        # this step is expected to be skipped (asserted below), so the body
        # must never actually run
        assert False

    @job
    def job1():
        collected = dynamic_out_op(optional_out_op()).collect()
        echo(collect_op(collected))

    result = job1.execute_in_process()
    skipped_step_keys = {event.step_key for event in result.get_step_skipped_events()}
    for expected_key in ("dynamic_out_op", "collect_op", "echo"):
        assert expected_key in skipped_step_keys


def test_non_required_dynamic_collect_skips():
    """Collecting a non-required dynamic output that yielded nothing should skip the consumer."""

    @op(out=DynamicOut(is_required=False))
    def producer():
        # generator that never yields, so no dynamic outputs are emitted
        if False:  # pylint: disable=using-constant-test
            yield DynamicOutput("yay")

    @op
    def consumer1(item):
        pass

    @op
    def consumer2(items):
        pass

    @job()
    def my_job():
        produced = producer()
        produced.map(consumer1)
        consumer2(produced.collect())

    result = my_job.execute_in_process()
    skipped_step_keys = set()
    for event in result.get_step_skipped_events():
        skipped_step_keys.add(event.step_key)
    assert "consumer2" in skipped_step_keys

0 comments on commit 29b9d48

Please sign in to comment.