Tweaks to update_graph (backport from dask#8185)
crusaderky committed Feb 6, 2024
1 parent 4425516 commit cbdb6de
Showing 2 changed files with 43 additions and 23 deletions.
2 changes: 1 addition & 1 deletion distributed/client.py
@@ -1576,7 +1576,7 @@ async def _handle_report(self):

                     breakout = False
                     for msg in msgs:
-                        logger.debug("Client receives message %s", msg)
+                        logger.debug("Client %s receives message %s", self.id, msg)

                         if "status" in msg and "error" in msg["status"]:
                             typ, exc, tb = clean_exception(**msg)
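The amended debug line identifies which Client instance received each scheduler message, which helps when several clients live in one process. A minimal sketch of how one might surface these messages, using only the standard library and assuming the module logger is named distributed.client (it is created via logging.getLogger(__name__) in client.py); nothing below is added by the commit:

import logging

# Route distributed.client debug output to the console; the line logged by
# _handle_report above will now carry the client's id as well as the payload.
logging.basicConfig(level=logging.INFO)
logging.getLogger("distributed.client").setLevel(logging.DEBUG)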
64 changes: 42 additions & 22 deletions distributed/scheduler.py
@@ -52,7 +52,8 @@
 from tornado.ioloop import IOLoop

 import dask
-from dask.core import get_deps, validate_key
+import dask.utils
+from dask.core import get_deps, iskey, validate_key
 from dask.typing import Key, no_default
 from dask.utils import (
     ensure_dict,
@@ -4721,6 +4722,7 @@ async def update_graph(
                 stimulus_id=stimulus_id or f"update-graph-{start}",
             )
         except RuntimeError as e:
+            logger.error(str(e))
             err = error_message(e)
             for key in keys:
                 self.report(
@@ -4729,7 +4731,10 @@
"key": key,
"exception": err["exception"],
"traceback": err["traceback"],
}
},
# This informs all clients in who_wants plus the current client
# (which may not have been added to who_wants yet)
client=client,
)
end = time()
self.digest_metric("update-graph-duration", end - start)
@@ -4755,8 +4760,21 @@ def _generate_taskstates(
             if ts is None:
                 ts = self.new_task(k, dsk.get(k), "released", computation=computation)
                 new_tasks.append(ts)
-            elif not ts.run_spec:
+            # It is possible to create the TaskState object before its runspec is known
+            # to the scheduler. For instance, this is possible when using a Variable:
+            # `f = c.submit(foo); await Variable().set(f)` since the Variable uses a
+            # different comm channel, so the `client_desires_key` message could arrive
+            # before `update_graph`.
+            # There are also anti-pattern processes possible;
+            # see for example test_scatter_creates_ts
+            elif ts.run_spec is None:
                 ts.run_spec = dsk.get(k)
+                # run_spec in the submitted graph may be None. This happens
+                # when an already persisted future is part of the graph
+            elif k in dsk:
+                # TODO run a health check to verify that run_spec and dependencies
+                # did not change. See https://github.com/dask/distributed/pull/8185
+                pass

             if ts.run_spec:
                 runnable.append(ts)
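The Variable scenario referenced in the new comment can be reproduced with ordinary client code. A minimal sketch, not part of this commit; the summed payload, the variable name "x", and the helper name example are arbitrary:

import asyncio

from distributed import Client, Variable

async def example() -> None:
    async with Client(asynchronous=True, processes=False) as c:
        fut = c.submit(sum, [1, 2, 3])   # the graph itself travels via update_graph
        var = Variable("x", client=c)
        await var.set(fut)               # the key travels via the Variable's own channel
        # Depending on message ordering, the scheduler may process client_desires_key
        # for fut.key before update_graph has attached a run_spec to the TaskState.
        assert await fut == 6

if __name__ == "__main__":
    asyncio.run(example())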
@@ -5538,28 +5556,28 @@ def report(
                 tasks: dict = self.tasks
                 ts = tasks.get(msg_key)

-        client_comms: dict = self.client_comms
-        if ts is None:
+        if ts is None and client is None:
             # Notify all clients
-            client_keys = list(client_comms)
-        elif client:
-            # Notify clients interested in key
-            client_keys = [cs.client_key for cs in ts.who_wants or ()]
+            client_keys = list(self.client_comms)
+        elif ts is None:
+            client_keys = [client]
         else:
             # Notify clients interested in key (including `client`)
+            # Note that, if report() was called by update_graph(), `client` won't be in
+            # ts.who_wants yet.
             client_keys = [
                 cs.client_key for cs in ts.who_wants or () if cs.client_key != client
             ]
-            client_keys.append(client)
+            if client is not None:
+                client_keys.append(client)

         k: str
         for k in client_keys:
-            c = client_comms.get(k)
+            c = self.client_comms.get(k)
             if c is None:
                 continue
             try:
                 c.send(msg)
-                # logger.debug("Scheduler sends message to client %s", msg)
+                # logger.debug("Scheduler sends message to client %s: %s", k, msg)
             except CommClosedError:
                 if self.status == Status.running:
                     logger.critical(
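Read together with the client=client argument added in update_graph above, the branch structure boils down to the following standalone restatement. This is an illustrative sketch only: the function name and parameters are made up, and who_wants stands in for the client keys recorded in ts.who_wants.

def select_recipients(all_clients, who_wants, ts_exists, client):
    # No task and no explicit client: broadcast to every connected client.
    if not ts_exists and client is None:
        return list(all_clients)
    # No task, but a client was named (e.g. update_graph reporting an error
    # before the task ever existed): only that client hears about it.
    if not ts_exists:
        return [client]
    # Task exists: everyone in who_wants plus the caller, without duplicates,
    # even if the caller has not been added to who_wants yet.
    recipients = [c for c in who_wants if c != client]
    if client is not None:
        recipients.append(client)
    return recipients

assert select_recipients(["a", "b"], [], ts_exists=False, client=None) == ["a", "b"]
assert select_recipients(["a", "b"], ["a"], ts_exists=True, client="c") == ["a", "c"]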
@@ -8724,26 +8742,28 @@ def _materialize_graph(
     dsk2 = {}
     fut_deps = {}
     for k, v in dsk.items():
-        dsk2[k], futs = unpack_remotedata(v, byte_keys=True)
+        v, futs = unpack_remotedata(v, byte_keys=True)
         if futs:
             fut_deps[k] = futs

+        # Remove aliases {x: x}.
+        # FIXME: This is an artifact generated by unpack_remotedata when using persisted
+        # collections. There should be a better way to achieve that tasks are not self
+        # referencing themselves.
+        if not iskey(v) or v != k:
+            dsk2[k] = v

     dsk = dsk2

     # - Add in deps for any tasks that depend on futures
     for k, futures in fut_deps.items():
-        dependencies[k].update(f.key for f in futures)
+        dependencies[k].update(f.key for f in futures if f.key != k)

     # Remove any self-dependencies (happens on test_publish_bag() and others)
     for k, v in dependencies.items():
         deps = set(v)
-        if k in deps:
-            deps.remove(k)
+        deps.discard(k)
         dependencies[k] = deps

-    # Remove aliases
-    for k in list(dsk):
-        if dsk[k] is k:
-            del dsk[k]
     dsk = valmap(_normalize_task, dsk)

     return dsk, dependencies, annotations_by_type
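The {x: x} aliases mentioned in the FIXME appear when unpack_remotedata replaces a persisted future that sits as the value of its own key, leaving a task that points at itself. A minimal sketch of the filtering condition added above; the helper drop_self_aliases and the toy graph are not part of the codebase, and it assumes a dask version that provides dask.core.iskey (which this commit imports):

from dask.core import iskey

def drop_self_aliases(dsk):
    # Keep a task only if its value is not simply its own key.
    return {k: v for k, v in dsk.items() if not iskey(v) or v != k}

# "x" aliases itself (the shape left behind by an unpacked persisted future)
# and is dropped; "y" is an ordinary task and survives.
assert drop_self_aliases({"x": "x", "y": (sum, ["x"])}) == {"y": (sum, ["x"])}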
