Warn if tasks are submitted with identical keys but different `run_spec` (#8185)

Co-authored-by: crusaderky <crusaderky@gmail.com>
Co-authored-by: Hendrik Makait <hendrik.makait@gmail.com>
3 people committed Feb 16, 2024
1 parent 045dc64 commit d4380a7
Showing 4 changed files with 195 additions and 5 deletions.
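
A rough user-level sketch of the behaviour this commit adds, mirroring the new tests in distributed/tests/test_scheduler.py below (the local cluster setup and helper function are illustrative, not part of the diff):

from dask import delayed
from distributed import Client


def inc(x):
    return x + 1


client = Client(processes=False)  # assumes a local in-process cluster

# First submission defines key "y" as inc(1).
fut1 = client.submit(inc, 1, key="y")

# Re-using the same key with a different run_spec (inc(2) instead of inc(1))
# now makes the scheduler log "Detected different `run_spec` for key 'y' ...";
# the originally submitted run_spec is kept, so both futures resolve to 2.
fut2 = delayed(inc)(2, dask_key_name="y").persist()

print(fut1.result())  # 2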
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -154,7 +154,7 @@ jobs:
# Increase this value to reset cache if
# continuous_integration/environment-${{ matrix.environment }}.yaml has not
# changed. See also same variable in .pre-commit-config.yaml
CACHE_NUMBER: 0
CACHE_NUMBER: 1
id: cache

- name: Update environment
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -69,4 +69,4 @@ repos:

# Increase this value to clear the cache on GitHub actions if nothing else in this file
# has changed. See also same variable in .github/workflows/test.yaml
# CACHE_NUMBER: 0
# CACHE_NUMBER: 1
54 changes: 51 additions & 3 deletions distributed/scheduler.py
@@ -53,6 +53,7 @@

import dask
import dask.utils
from dask.base import TokenizationError, normalize_token, tokenize
from dask.core import get_deps, iskey, validate_key
from dask.typing import Key, no_default
from dask.utils import (
@@ -4752,6 +4753,7 @@ def _generate_taskstates(
stack = list(keys)
touched_keys = set()
touched_tasks = []
tgs_with_bad_run_spec = set()
while stack:
k = stack.pop()
if k in touched_keys:
@@ -4772,9 +4774,55 @@
# run_spec in the submitted graph may be None. This happens
# when an already persisted future is part of the graph
elif k in dsk:
# TODO run a health check to verify that run_spec and dependencies
# did not change. See https://github.com/dask/distributed/pull/8185
pass
# If both tokens are non-deterministic, skip comparison
try:
tok_lhs = tokenize(ts.run_spec, ensure_deterministic=True)
except TokenizationError:
tok_lhs = ""
try:
tok_rhs = tokenize(dsk[k], ensure_deterministic=True)
except TokenizationError:
tok_rhs = ""

# Additionally check dependency names. This should only be necessary
# if run_specs can't be tokenized deterministically.
deps_lhs = {dts.key for dts in ts.dependencies}
deps_rhs = dependencies.get(k, set())

# FIXME It would be a really healthy idea to change this to a hard
# failure. However, this is not possible at the moment because of
# https://github.com/dask/dask/issues/9888
if tok_lhs != tok_rhs or deps_lhs != deps_rhs:
if ts.group not in tgs_with_bad_run_spec:
tgs_with_bad_run_spec.add(ts.group)
logger.warning(
f"Detected different `run_spec` for key {ts.key!r} between "
"two consecutive calls to `update_graph`. "
"This can cause failures and deadlocks down the line. "
"Please ensure unique key names. "
"If you are using a standard dask collections, consider "
"releasing all the data before resubmitting another "
"computation. More details and help can be found at "
"https://github.com/dask/dask/issues/9888. "
+ textwrap.dedent(
f"""
Debugging information
---------------------
old task state: {ts.state}
old run_spec: {ts.run_spec!r}
new run_spec: {dsk[k]!r}
old token: {normalize_token(ts.run_spec)!r}
new token: {normalize_token(dsk[k])!r}
old dependencies: {deps_lhs}
new dependencies: {deps_rhs}
"""
)
)
else:
logger.debug(
f"Detected different `run_spec` for key {ts.key!r} between "
"two consecutive calls to `update_graph`."
)

if ts.run_spec:
runnable.append(ts)
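
The comparison in the hunk above leans on dask.base.tokenize with ensure_deterministic=True, which raises TokenizationError when an object cannot be hashed deterministically. A minimal sketch of that fallback follows; the helper name is hypothetical, and the behaviour of object() is inferred from the new tests:

from dask.base import TokenizationError, tokenize


def token_or_empty(run_spec):
    # Mirrors the fallback above: a run_spec that cannot be tokenized
    # deterministically maps to "", so two such specs compare equal and only
    # the dependency names are checked.
    try:
        return tokenize(run_spec, ensure_deterministic=True)
    except TokenizationError:
        return ""


# Identical, picklable run_specs yield identical tokens...
assert token_or_empty((sum, [1, 2])) == token_or_empty((sum, [1, 2]))
# ...while e.g. a bare object() has no deterministic token and falls back to "".
assert token_or_empty(object()) == ""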
142 changes: 142 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -4702,3 +4702,145 @@ async def test_html_repr(c, s, a, b):
await asyncio.sleep(0.01)

await f


@pytest.mark.parametrize("add_deps", [False, True])
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_different_task_same_key(c, s, add_deps):
"""If an intermediate key has a different run_spec (either the callable function or
the dependencies / arguments) that will conflict with what was previously defined,
it should raise an error since this can otherwise break in many different places and
cause either spurious exceptions or even deadlocks.
For a real world example where this can trigger, see
https://github.com/dask/dask/issues/9888
"""
y1 = c.submit(inc, 1, key="y")

x = delayed(inc)(1, dask_key_name="x") if add_deps else 2
y2 = delayed(inc)(x, dask_key_name="y")
z = delayed(inc)(y2, dask_key_name="z")

if add_deps: # add_deps=True corrupts the state machine
s.validate = False

with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
fut = c.compute(z)
await wait_for_state("z", "waiting", s)

assert "Detected different `run_spec` for key 'y'" in log.getvalue()

if not add_deps: # add_deps=True hangs
async with Worker(s.address):
assert await y1 == 2
assert await fut == 3


@gen_cluster(client=True, nthreads=[])
async def test_resubmit_different_task_same_key_many_clients(c, s):
"""Two different clients submit a task with the same key but different run_spec's."""
async with Client(s.address, asynchronous=True) as c2:
with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
x1 = c.submit(inc, 1, key="x")
x2 = c2.submit(inc, 2, key="x")

await wait_for_state("x", ("no-worker", "queued"), s)
who_wants = s.tasks["x"].who_wants
await async_poll_for(
lambda: {cs.client_key for cs in who_wants} == {c.id, c2.id}, timeout=5
)

assert "Detected different `run_spec` for key 'x'" in log.getvalue()

async with Worker(s.address):
assert await x1 == 2
assert await x2 == 2 # kept old run_spec


@pytest.mark.parametrize(
"before,after,expect_msg",
[
(object(), 123, True),
(123, object(), True),
(o := object(), o, False),
],
)
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_nondeterministic_task_same_deps(
c, s, before, after, expect_msg
):
"""Some run_specs can't be tokenized deterministically. Silently skip comparison on
the run_spec when both lhs and rhs are nondeterministic.
Dependencies must be the same.
"""
x1 = c.submit(lambda x: x, before, key="x")
x2 = delayed(lambda x: x)(after, dask_key_name="x")
y = delayed(lambda x: x)(x2, dask_key_name="y")

with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
fut = c.compute(y)
await async_poll_for(lambda: "y" in s.tasks, timeout=5)

has_msg = "Detected different `run_spec` for key 'x'" in log.getvalue()
assert has_msg == expect_msg

async with Worker(s.address):
assert type(await fut) is type(before)


@pytest.mark.parametrize("add_deps", [False, True])
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_nondeterministic_task_different_deps(c, s, add_deps):
"""Some run_specs can't be tokenized deterministically. Silently skip comparison on
the run_spec in those cases. However, fail anyway if dependencies have changed.
"""
o = object()
x1 = c.submit(inc, 1, key="x1") if not add_deps else 2
x2 = c.submit(inc, 2, key="x2")
y1 = delayed(lambda i, j: i)(x1, o, dask_key_name="y").persist()
y2 = delayed(lambda i, j: i)(x2, o, dask_key_name="y")
z = delayed(inc)(y2, dask_key_name="z")

if add_deps: # add_deps=True corrupts the state machine and hangs
s.validate = False

with captured_logger("distributed.scheduler", level=logging.WARNING) as log:
fut = c.compute(z)
await wait_for_state("z", "waiting", s)
assert "Detected different `run_spec` for key 'y'" in log.getvalue()

if not add_deps: # add_deps=True corrupts the state machine and hangs
async with Worker(s.address):
assert await fut == 3


@pytest.mark.parametrize(
"loglevel,expect_loglines", [(logging.DEBUG, 3), (logging.WARNING, 1)]
)
@gen_cluster(client=True, nthreads=[])
async def test_resubmit_different_task_same_key_warns_only_once(
c, s, loglevel, expect_loglines
):
"""If all tasks of a layer are affected by the same run_spec collision, warn
only once.
"""
x1s = c.map(inc, [0, 1, 2], key=[("x", 0), ("x", 1), ("x", 2)])
dsk = {
("x", 0): 3,
("x", 1): 4,
("x", 2): 5,
("y", 0): (inc, ("x", 0)),
("y", 1): (inc, ("x", 1)),
("y", 2): (inc, ("x", 2)),
}
with captured_logger("distributed.scheduler", level=loglevel) as log:
ys = c.get(dsk, [("y", 0), ("y", 1), ("y", 2)], sync=False)
await wait_for_state(("y", 2), "waiting", s)

actual_loglines = len(
re.findall("Detected different `run_spec` for key ", log.getvalue())
)
assert actual_loglines == expect_loglines

async with Worker(s.address):
assert await c.gather(ys) == [2, 3, 4]
