Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamic] fix how skips cascade #11561

Merged
merged 1 commit into from Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -164,6 +164,9 @@ def asset_observations_for_node(self, node_name: str) -> Sequence[AssetObservati
def get_step_success_events(self) -> Sequence[DagsterEvent]:
    """Return the events from ``self.all_events`` that report a step success.

    The relative order of ``self.all_events`` is preserved.
    """
    successes = []
    for candidate in self.all_events:
        if candidate.is_step_success:
            successes.append(candidate)
    return successes

def get_step_skipped_events(self) -> Sequence[DagsterEvent]:
    """Return the events from ``self.all_events`` that report a skipped step.

    The relative order of ``self.all_events`` is preserved.
    """
    skipped = []
    for candidate in self.all_events:
        if candidate.is_step_skipped:
            skipped.append(candidate)
    return skipped

def get_failed_step_keys(self) -> AbstractSet[str]:
failure_events = self.filter_events(
lambda event: event.is_step_failure or event.is_resource_init_failure
Expand Down
55 changes: 37 additions & 18 deletions python_modules/dagster/dagster/_core/execution/plan/active.py
Expand Up @@ -64,13 +64,16 @@ def __init__(

# track mapping keys from DynamicOutputs, step_key, output_name -> list of keys
# to _gathering while in flight
self._gathering_dynamic_outputs: Dict[str, Mapping[str, List[str]]] = {}
# then on success move to _successful
self._successful_dynamic_outputs: Dict[str, Mapping[str, Sequence[str]]] = (
self._gathering_dynamic_outputs: Dict[str, Mapping[str, Optional[List[str]]]] = {}
# then on resolution move to _completed
self._completed_dynamic_outputs: Dict[str, Mapping[str, Optional[Sequence[str]]]] = (
dict(self._plan.known_state.dynamic_mappings) if self._plan.known_state else {}
)
self._new_dynamic_mappings: bool = False

# track which upstream deps caused a step to skip
self._skipped_deps: Dict[str, Sequence[str]] = {}

# steps move in to these buckets as a result of _update calls
self._executable: List[str] = []
self._pending_skip: List[str] = []
Expand Down Expand Up @@ -160,7 +163,7 @@ def _update(self) -> None:
failed_or_abandoned_steps = self._failed | self._abandoned

if self._new_dynamic_mappings:
new_step_deps = self._plan.resolve(self._successful_dynamic_outputs)
new_step_deps = self._plan.resolve(self._completed_dynamic_outputs)
for step_key, deps in new_step_deps.items():
self._pending[step_key] = deps

Expand Down Expand Up @@ -192,6 +195,9 @@ def _update(self) -> None:
step_input.get_step_output_handle_dependencies()
):
should_skip = True
self._skipped_deps[step_key] = [
f"{h.step_key}.{h.output_name}" for h in missing_source_handles
]
break

if should_skip:
Expand Down Expand Up @@ -308,7 +314,8 @@ def get_steps_to_skip(self) -> Sequence[ExecutionStep]:
steps.append(step)
self._in_flight.add(key)
self._pending_skip.remove(key)
self._prep_for_dynamic_outputs(step)
self._gathering_dynamic_outputs
self._skip_for_dynamic_outputs(step)

return sorted(steps, key=self._sort_key_fn)

Expand All @@ -331,14 +338,9 @@ def plan_events_iterator(self, pipeline_context) -> Iterator[DagsterEvent]:
while steps_to_skip:
for step in steps_to_skip:
step_context = pipeline_context.for_step(step)
skipped_inputs: List[str] = []
for step_input in step.step_inputs:
skipped_inputs.extend(self._skipped.intersection(step_input.dependency_keys))

step_context.log.info(
"Skipping step {step} due to skipped dependencies: {skipped_inputs}.".format(
step=step.key, skipped_inputs=skipped_inputs
)
f"Skipping step {step.key} due to skipped dependencies:"
f" {self._skipped_deps[step.key]}."
)
yield DagsterEvent.step_skipped_event(step_context)

Expand Down Expand Up @@ -461,9 +463,11 @@ def handle_event(self, dagster_event: DagsterEvent) -> None:
event_specific_data = cast(StepOutputData, dagster_event.event_specific_data)
self.mark_step_produced_output(event_specific_data.step_output_handle)
if dagster_event.step_output_data.step_output_handle.mapping_key:
self._gathering_dynamic_outputs[step_key][
dagster_event.step_output_data.step_output_handle.output_name
].append(dagster_event.step_output_data.step_output_handle.mapping_key)
check.not_none(
self._gathering_dynamic_outputs[step_key][
dagster_event.step_output_data.step_output_handle.output_name
],
).append(dagster_event.step_output_data.step_output_handle.mapping_key)

def verify_complete(self, pipeline_context: PlanOrchestrationContext, step_key: str) -> None:
"""Ensure that a step has reached a terminal state, if it has not mark it as an unexpected failure
Expand Down Expand Up @@ -505,21 +509,36 @@ def retry_state(self) -> RetryState:
def get_known_state(self) -> KnownExecutionState:
    """Snapshot the execution state accumulated so far.

    Returns:
        A ``KnownExecutionState`` built from the current retry attempts, a
        shallow copy of the completed dynamic mappings, the step outputs
        produced so far, and the output versions / parent state carried over
        from the plan's own known state.
    """
    return KnownExecutionState(
        previous_retry_attempts=self._retry_state.snapshot_attempts(),
        # Copy so later resolutions do not mutate the returned snapshot's dict.
        # A mapping value of None marks a dynamic output that was skipped.
        dynamic_mappings=dict(self._completed_dynamic_outputs),
        ready_outputs=self._step_outputs,
        step_output_versions=self._plan.known_state.step_output_versions,
        parent_state=self._plan.known_state.parent_state,
    )

def _prep_for_dynamic_outputs(self, step: ExecutionStep) -> None:
    """Start gathering mapping keys for every dynamic output of ``step``.

    Each dynamic output name is registered under ``step.key`` with its own
    empty list. Steps without dynamic outputs are left untracked.

    Args:
        step: The step whose ``step_outputs`` are inspected.
    """
    dyn_outputs = [step_out for step_out in step.step_outputs if step_out.is_dynamic]

    if dyn_outputs:
        # A fresh list per output: mapping keys are appended to it independently.
        self._gathering_dynamic_outputs[step.key] = {out.name: [] for out in dyn_outputs}

def _skip_for_dynamic_outputs(self, step: ExecutionStep) -> None:
    """Mark every dynamic output of ``step`` as skipped.

    Each dynamic output name is registered under ``step.key`` with ``None``
    instead of a list. Steps without dynamic outputs are left untracked.

    Args:
        step: The step whose ``step_outputs`` are inspected.
    """
    dyn_outputs = [step_out for step_out in step.step_outputs if step_out.is_dynamic]
    if dyn_outputs:
        # place None to indicate the dynamic output was skipped, different than having 0 entries
        self._gathering_dynamic_outputs[step.key] = {out.name: None for out in dyn_outputs}

def _resolve_any_dynamic_outputs(self, step_key: str) -> None:
    """Move the gathered dynamic mappings of ``step_key`` into the completed set.

    For each gathered output the final value is either its sequence of mapping
    keys or ``None``. ``None`` is stored when the output was already marked as
    skipped, or when it produced no mapping keys and is not required. A
    required output with no keys keeps its empty list.

    Does nothing when ``step_key`` has no gathered dynamic outputs.

    Args:
        step_key: Key of the step whose dynamic outputs should be resolved.
    """
    if step_key not in self._gathering_dynamic_outputs:
        return

    step = self.get_step_by_key(step_key)
    completed_mappings: Dict[str, Optional[Sequence[str]]] = {}
    for output_name, mappings in self._gathering_dynamic_outputs[step_key].items():
        # if no dynamic outputs were returned and the output was marked is_required=False
        # set to None to indicate a skip should occur
        if not mappings and not step.step_output_dict[output_name].is_required:
            completed_mappings[output_name] = None
        else:
            completed_mappings[output_name] = mappings

    self._completed_dynamic_outputs[step_key] = completed_mappings
    # Only flag new mappings when something was actually resolved above.
    self._new_dynamic_mappings = True

def rebuild_from_events(
Expand Down
14 changes: 12 additions & 2 deletions python_modules/dagster/dagster/_core/execution/plan/inputs.py
Expand Up @@ -1037,7 +1037,17 @@ def get_step_output_handle_dep_with_placeholder(self) -> StepOutputHandle:
# Always returns a new empty set; the `_pipeline_def` argument is not consulted.
def required_resource_keys(self, _pipeline_def: PipelineDefinition) -> Set[str]:
return set()

def resolve(self, mapping_keys: Optional[Sequence[str]]):
    """Resolve this collect source once the upstream dynamic output is known.

    Args:
        mapping_keys: The mapping keys of the upstream dynamic output, or
            ``None`` when that dynamic output was skipped.

    Returns:
        A ``FromStepOutput`` pointing at the upstream dynamic output itself when
        ``mapping_keys`` is ``None``; otherwise a ``FromMultipleSources`` with
        one resolved source per mapping key.
    """
    if mapping_keys is None:
        # None means that the dynamic output was skipped, so create
        # a dependency on the dynamic output that will continue cascading the skip
        return FromStepOutput(
            step_output_handle=StepOutputHandle(
                step_key=self.resolved_by_step_key,
                output_name=self.resolved_by_output_name,
            ),
            fan_in=False,
        )
    return FromMultipleSources(
        sources=[self.source.resolve(map_key) for map_key in mapping_keys],
    )
Expand Down Expand Up @@ -1086,7 +1096,7 @@ def resolved_by_step_key(self) -> str:
# Delegates to the wrapped source: returns `self.source.resolved_by_output_name`.
def resolved_by_output_name(self) -> str:
return self.source.resolved_by_output_name

def resolve(self, mapping_keys: Sequence[str]) -> StepInput:
def resolve(self, mapping_keys: Optional[Sequence[str]]) -> StepInput:
return StepInput(
name=self.name,
dagster_type_key=self.dagster_type_key,
Expand Down
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/execution/plan/plan.py
Expand Up @@ -770,7 +770,7 @@ def get_executable_step_deps(self) -> Mapping[str, Set[str]]:

def resolve(
self,
mappings: Mapping[str, Mapping[str, Sequence[str]]],
mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]],
) -> Mapping[str, Set[str]]:
"""Resolve any dynamic map or collect steps with the resolved dynamic mappings"""

Expand Down Expand Up @@ -1197,7 +1197,7 @@ def _update_from_resolved_dynamic_outputs(
executable_map: Dict[str, Union[StepHandle, ResolvedFromDynamicStepHandle]],
resolvable_map: Dict[FrozenSet[str], Sequence[Union[StepHandle, UnresolvedStepHandle]]],
step_handles_to_execute: Sequence[StepHandleUnion],
dynamic_mappings: Mapping[str, Mapping[str, Sequence[str]]],
dynamic_mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]],
) -> None:
resolved_steps: List[ExecutionStep] = []
key_sets_to_clear: List[FrozenSet[str]] = []
Expand Down
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/execution/plan/state.py
Expand Up @@ -74,7 +74,7 @@ class KnownExecutionState(
# step_key -> count
("previous_retry_attempts", Mapping[str, int]),
# step_key -> output_name -> mapping_keys
("dynamic_mappings", Mapping[str, Mapping[str, Sequence[str]]]),
("dynamic_mappings", Mapping[str, Mapping[str, Optional[Sequence[str]]]]),
# step_output_handle -> version
("step_output_versions", Sequence[StepOutputVersionData]),
("ready_outputs", Set[StepOutputHandle]),
Expand All @@ -91,7 +91,7 @@ class KnownExecutionState(
def __new__(
cls,
previous_retry_attempts: Optional[Mapping[str, int]] = None,
dynamic_mappings: Optional[Mapping[str, Mapping[str, Sequence[str]]]] = None,
dynamic_mappings: Optional[Mapping[str, Mapping[str, Optional[Sequence[str]]]]] = None,
step_output_versions: Optional[Sequence[StepOutputVersionData]] = None,
ready_outputs: Optional[Set[StepOutputHandle]] = None,
parent_state: Optional[PastExecutionState] = None,
Expand Down
17 changes: 13 additions & 4 deletions python_modules/dagster/dagster/_core/execution/plan/step.py
Expand Up @@ -3,6 +3,7 @@
from typing import (
TYPE_CHECKING,
FrozenSet,
List,
Mapping,
NamedTuple,
Optional,
Expand Down Expand Up @@ -340,15 +341,21 @@ def resolved_by_step_keys(self) -> FrozenSet[str]:
return frozenset(keys)

def resolve(
self, mappings: Mapping[str, Mapping[str, Sequence[str]]]
self, mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]]
) -> Sequence[ExecutionStep]:
check.invariant(
all(key in mappings for key in self.resolved_by_step_keys),
"resolving with mappings that do not contain all required step keys",
)
execution_steps = []
execution_steps: List[ExecutionStep] = []

for mapped_key in mappings[self.resolved_by_step_key][self.resolved_by_output_name]:
mapping_keys = mappings[self.resolved_by_step_key][self.resolved_by_output_name]

# dynamic output skipped
if mapping_keys is None:
return execution_steps

for mapped_key in mapping_keys:
resolved_inputs = [_resolved_input(inp, mapped_key) for inp in self.step_inputs]

execution_steps.append(
Expand Down Expand Up @@ -472,7 +479,9 @@ def resolved_by_step_keys(self) -> FrozenSet[str]:

return frozenset(keys)

def resolve(self, mappings: Mapping[str, Mapping[str, Sequence[str]]]) -> ExecutionStep:
def resolve(
self, mappings: Mapping[str, Mapping[str, Optional[Sequence[str]]]]
) -> ExecutionStep:
check.invariant(
all(key in mappings for key in self.resolved_by_step_keys),
"resolving with mappings that do not contain all required step keys",
Expand Down
Expand Up @@ -571,17 +571,73 @@ def dyn_fork():
total.alias("grand_total")(
[
total.alias("nums_total")(emit_dyn(nums).map(echo).collect()),
total.alias("empty_total")(emit_dyn(empty).map(echo).collect()),
total.alias("skip_total")(emit_dyn(skip).map(echo).collect()),
total.alias("empty_total")(
emit_dyn.alias("emit_dyn_empty")(empty).map(echo.alias("echo_empty")).collect()
),
total.alias("skip_total")(
emit_dyn.alias("emit_dyn_skip")(skip).map(echo.alias("echo_skip")).collect()
),
]
)

result = dyn_fork.execute_in_process()
assert result.success
skips = {ev.step_key for ev in result.get_step_skipped_events()}

assert result.output_for_node("nums_total")
assert result.output_for_node("empty_total") == 0

assert result.output_for_node("skip_total") == 0 # arguably should skip
assert "skip_total" in skips

assert result.output_for_node("grand_total") == 6


def test_collect_optional():
    """A skipped optional upstream output must cascade through map and collect.

    ``optional_out_op`` never yields its non-required output, so the dynamic op
    that consumes it, the op collecting its results, and the final ``echo`` are
    all expected to emit step-skipped events.
    """

    @op(out=Out(is_required=False))
    def optional_out_op():
        # The unreachable yield keeps this a generator op that emits no output.
        if False:  # pylint: disable=using-constant-test
            yield None

    @op(out=DynamicOut())
    def dynamic_out_op(_in):
        yield DynamicOutput("a", "a")

    @op
    def collect_op(_in):
        # must never execute: the skip is expected to cascade through collect
        raise AssertionError("collect_op should have been skipped")

    @job
    def job1():
        echo(collect_op(dynamic_out_op(optional_out_op()).collect()))

    result = job1.execute_in_process()
    skips = {ev.step_key for ev in result.get_step_skipped_events()}
    assert "dynamic_out_op" in skips
    assert "collect_op" in skips
    assert "echo" in skips


def test_non_required_dynamic_collect_skips():
    """Collecting a non-required dynamic output that yielded nothing is skipped."""

    @op(out=DynamicOut(is_required=False))
    def producer():
        # The unreachable yield keeps this a generator op that emits no output.
        if False:  # pylint: disable=using-constant-test
            yield DynamicOutput("yay")

    @op
    def consumer1(item):
        pass

    @op
    def consumer2(items):
        pass

    @job()
    def my_job():
        items = producer()
        items.map(consumer1)
        consumer2(items.collect())

    result = my_job.execute_in_process()

    skipped_step_keys = set()
    for skip_event in result.get_step_skipped_events():
        skipped_step_keys.add(skip_event.step_key)

    assert "consumer2" in skipped_step_keys