Fix scheduler transition error on memory->erred (#8549)
Co-authored-by: crusaderky <crusaderky@gmail.com>
hendrikmakait and crusaderky committed Mar 8, 2024
1 parent e16a7af commit 91350ab
Showing 2 changed files with 164 additions and 6 deletions.
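For context, the regression test added below (test_fan_out_pattern_deadlock, referencing https://github.com/dask/distributed/issues/8548) builds a small fan-out graph in which an input task fails only after its downstream result already sits in memory on another worker. The sketch below shows just that graph shape, reusing the key names from the test ('f', 'g', 'ha', 'hb'); the worker resources, Event-based blocking, and worker removals that actually force the race are omitted here, so this snippet alone does not reproduce the bug.

from dask import delayed


def inc(x):
    return x + 1


# Input task; in the regression test this is the task that eventually errs.
f = delayed(inc)(1, dask_key_name="f")
# Intermediate result that ends up in memory on a surviving worker.
g = delayed(inc)(f, dask_key_name="g")
# Fan-out: two dependents of 'g', pinned to different workers in the test.
ha = delayed(inc)(g, dask_key_name="ha")
hb = delayed(inc)(g, dask_key_name="hb")

Judging from the diff, the failure mode was that when 'f' erred, the scheduler recommended "erred" for every dependent, including 'g' which was already in memory, producing the invalid memory->erred transition named in the commit title; the changes below instead iterate over ts.waiters and skip the "waiting" recommendation when a dependency is erred.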
34 changes: 28 additions & 6 deletions distributed/scheduler.py
@@ -1965,6 +1965,7 @@ def _transition(
)

v = a_recs.get(key, finish)
+ # The inner rec has higher priority? Is that always desired?
func = self._TRANSITIONS_TABLE["released", v]
b_recs, b_cmsgs, b_wmsgs = func(self, key, stimulus_id)

@@ -2083,7 +2084,11 @@ def _transition_released_waiting(self, key: Key, stimulus_id: str) -> RecsMsgs:
assert not ts.who_has
assert not ts.processing_on
for dts in ts.dependencies:
- assert dts.state not in {"forgotten", "erred"}
+ assert dts.state not in {"forgotten", "erred"}, (
+     str(ts),
+     str(dts),
+     self.transition_log,
+ )

if ts.has_lost_dependencies:
return {key: "forgotten"}, {}, {}
@@ -2481,7 +2486,9 @@ def _transition_memory_released(
recommendations[key] = "forgotten"
elif ts.has_lost_dependencies:
recommendations[key] = "forgotten"
- elif ts.who_wants or ts.waiters:
+ elif (ts.who_wants or ts.waiters) and not any(
+     dts.state == "erred" for dts in ts.dependencies
+ ):
recommendations[key] = "waiting"

for dts in ts.waiters or ():
@@ -2506,14 +2513,13 @@ def _transition_released_erred(self, key: Key, stimulus_id: str) -> RecsMsgs:
assert ts.exception_blame
assert not ts.who_has
assert not ts.waiting_on
- assert not ts.waiters

failing_ts = ts.exception_blame
assert failing_ts

for dts in ts.dependents:
- dts.exception_blame = failing_ts
if not dts.who_has:
+     dts.exception_blame = failing_ts
recommendations[dts.key] = "erred"

report_msg = {
@@ -2548,6 +2554,9 @@ def _transition_erred_released(self, key: Key, stimulus_id: str) -> RecsMsgs:

for dts in ts.dependents:
if dts.state == "erred":
+ # Does this make sense?
+ # This goes via released
+ # dts -> released -> waiting
recommendations[dts.key] = "waiting"

w_msg = {
@@ -2622,8 +2631,8 @@ def _transition_processing_erred(
self,
key: Key,
stimulus_id: str,
- *,
worker: str,
+ *,
cause: Key | None = None,
exception: Serialized | None = None,
traceback: Serialized | None = None,
@@ -2699,7 +2708,7 @@ def _transition_processing_erred(
)
)

- for dts in ts.dependents:
+ for dts in ts.waiters or set():
dts.exception_blame = failing_ts
recommendations[dts.key] = "erred"

@@ -5040,6 +5049,19 @@ def stimulus_task_finished(
"stimulus_id": stimulus_id,
}
]
+ elif ts.state == "erred":
+     logger.debug(
+         "Received already erred task, worker: %s" ", key: %s",
+         worker,
+         key,
+     )
+     worker_msgs[worker] = [
+         {
+             "op": "free-keys",
+             "keys": [key],
+             "stimulus_id": stimulus_id,
+         }
+     ]
elif ts.run_id != run_id:
if not ts.processing_on or ts.processing_on.address != worker:
logger.debug(
136 changes: 136 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -4890,3 +4890,139 @@ async def test_resubmit_different_task_same_key_warns_only_once(

async with Worker(s.address):
assert await c.gather(zs) == [2, 3, 4] # Kept old ys


def block(x, in_event, block_event):
in_event.set()
block_event.wait()
return x


@gen_cluster(
client=True,
nthreads=[("", 1, {"resources": {"a": 1}})],
config={"distributed.scheduler.allowed-failures": 0},
)
async def test_fan_out_pattern_deadlock(c, s, a):
"""Regression test for https://github.com/dask/distributed/issues/8548
This test heavily uses resources to force scheduling decisions.
"""
in_f, block_f = Event(), Event()
in_ha, block_ha = Event(), Event()
in_hb, block_hb = Event(), Event()

# Input task to 'g' that we can fail
with dask.annotate(resources={"b": 1}):
f = delayed(block)(1, in_f, block_f, dask_key_name="f")
g = delayed(inc)(f, dask_key_name="g")

# Fan-out from 'g' and run 'ha' and 'hb' on different workers
hb = delayed(block)(g, in_hb, block_hb, dask_key_name="hb")
with dask.annotate(resources={"a": 1}):
ha = delayed(block)(g, in_ha, block_ha, dask_key_name="ha")

f, ha, hb = c.compute([f, ha, hb])
with captured_logger("distributed.scheduler", level=logging.ERROR) as logger:
async with Worker(s.address, nthreads=1, resources={"b": 1}) as b:
await block_f.set()
await in_ha.wait()
await in_hb.wait()
await in_f.clear()

# Make sure that the scheduler knows that both workers hold 'g' in memory
await async_poll_for(lambda: len(s.tasks["g"].who_has) == 2, timeout=5)
# Remove worker 'b' while it's processing 'hb'
await s.remove_worker(b.address, stimulus_id="remove_b1")
await block_hb.set()
await block_f.clear()

# Remove the new instance of the 'b' worker while it processes 'f'
# to trigger a transition for 'f' to 'erred'
async with Worker(s.address, nthreads=1, resources={"b": 1}) as b:
await in_f.wait()
await in_f.clear()
await s.remove_worker(b.address, stimulus_id="remove_b2")
await block_f.set()
await block_f.clear()

await block_ha.set()
await ha

with pytest.raises(KilledWorker, match="Attempted to run task 'hb'"):
await hb

del ha, hb
# Make sure that 'ha' gets forgotten on worker 'a'
await async_poll_for(lambda: not a.state.tasks, timeout=5)
# Ensure that no other errors including transition failures were logged
assert (
logger.getvalue()
== "Task hb marked as failed because 1 workers died while trying to run it\nTask f marked as failed because 1 workers died while trying to run it\n"
)


@gen_cluster(
client=True,
nthreads=[("", 1, {"resources": {"a": 1}})],
config={"distributed.scheduler.allowed-failures": 0},
)
async def test_stimulus_from_erred_task(c, s, a):
"""This test heavily uses resources to force scheduling decisions."""
in_f, block_f = Event(), Event()
in_g, block_g = Event(), Event()

with dask.annotate(resources={"b": 1}):
f = delayed(block)(1, in_f, block_f, dask_key_name="f")

with dask.annotate(resources={"a": 1}):
g = delayed(block)(f, in_g, block_g, dask_key_name="g")

f, g = c.compute([f, g])
with captured_logger("distributed.scheduler", level=logging.ERROR) as logger:
frozen_stream_from_a_ctx = freeze_batched_send(a.batched_stream)
frozen_stream_from_a_ctx.__enter__()

async with Worker(s.address, nthreads=1, resources={"b": 1}) as b1:
await block_f.set()
await in_g.wait()
await in_f.clear()
frozen_stream_to_a_ctx = freeze_batched_send(s.stream_comms[a.address])
frozen_stream_to_a_ctx.__enter__()
await s.remove_worker(b1.address, stimulus_id="remove_b1")
await block_f.clear()

# Remove the new instance of the 'b' worker while it processes 'f'
# to trigger a transition for 'f' to 'erred'
async with Worker(s.address, nthreads=1, resources={"b": 1}) as b2:
await in_f.wait()
await in_f.clear()
await s.remove_worker(b2.address, stimulus_id="remove_b2")
await block_f.set()

with pytest.raises(KilledWorker, match="Attempted to run task 'f'"):
await f

# g has already been transitioned to 'erred' because 'f' failed
with pytest.raises(KilledWorker, match="Attempted to run task 'f'"):
await g

# Finish 'g' and let the scheduler know so it can trigger cleanup
await block_g.set()
with mock.patch.object(
s, "stimulus_task_finished", wraps=s.stimulus_task_finished
) as wrapped_stimulus:
frozen_stream_from_a_ctx.__exit__(None, None, None)
# Make sure the `stimulus_task_finished` gets processed
await async_poll_for(lambda: wrapped_stimulus.call_count == 1, timeout=5)

# Allow the scheduler to talk to the worker again
frozen_stream_to_a_ctx.__exit__(None, None, None)
# Make sure all data gets forgotten on worker 'a'
await async_poll_for(lambda: not a.state.tasks, timeout=5)

# Ensure that no other errors including transition failures were logged
assert (
logger.getvalue()
== "Task f marked as failed because 1 workers died while trying to run it\n"
)
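To run just the two regression tests added in this commit against a development checkout of distributed, an invocation along these lines should work (plain pytest selection by test name; any extra plugins or markers required by your environment are not shown):

import pytest

# Select only the two tests added by this commit.
pytest.main(
    [
        "distributed/tests/test_scheduler.py",
        "-k",
        "test_fan_out_pattern_deadlock or test_stimulus_from_erred_task",
        "-v",
    ]
)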
