diff --git a/distributed/client.py b/distributed/client.py index 0f1db78a47..462a12e327 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1930,7 +1930,7 @@ def submit( [skey], workers=workers, allow_other_workers=allow_other_workers, - priority={skey: 0}, + internal_priority={skey: 0}, user_priority=priority, resources=resources, retries=retries, @@ -2134,7 +2134,7 @@ def map( keys, workers=workers, allow_other_workers=allow_other_workers, - priority=internal_priority, + internal_priority=internal_priority, resources=resources, retries=retries, user_priority=priority, @@ -3035,7 +3035,7 @@ def _graph_to_futures( keys, workers=None, allow_other_workers=None, - priority=None, + internal_priority=None, user_priority=0, resources=None, retries=None, @@ -3071,21 +3071,33 @@ def _graph_to_futures( # Pack the high level graph before sending it to the scheduler keyset = set(keys) - dsk = dsk.__dask_distributed_pack__(self, keyset, annotations) # Create futures before sending graph (helps avoid contention) futures = {key: Future(key, self, inform=False) for key in keyset} - + # Circular import + from distributed.protocol import serialize + from distributed.protocol.serialize import ToPickle + + header, frames = serialize(ToPickle(dsk), on_error="raise") + nbytes = len(header) + sum(map(len, frames)) + if nbytes > 10_000_000: + warnings.warn( + f"Sending large graph of size {format_bytes(nbytes)}.\n" + "This may cause some slowdown.\n" + "Consider scattering data ahead of time and using futures." + ) self._send_to_scheduler( { - "op": "update-graph-hlg", - "hlg": dsk, + "op": "update-graph", + "graph_header": header, + "graph_frames": frames, "keys": list(map(stringify, keys)), - "priority": priority, + "internal_priority": internal_priority, "submitting_task": getattr(thread_state, "key", None), "fifo_timeout": fifo_timeout, "actors": actors, "code": self._get_computation_code(), + "annotations": ToPickle(annotations), } ) return futures diff --git a/distributed/diagnostics/graph_layout.py b/distributed/diagnostics/graph_layout.py index f0f438e48c..ba8003467a 100644 --- a/distributed/diagnostics/graph_layout.py +++ b/distributed/diagnostics/graph_layout.py @@ -46,7 +46,7 @@ def __init__(self, scheduler): ) def update_graph( - self, scheduler, dependencies=None, priority=None, tasks=None, **kwargs + self, scheduler, *, dependencies=None, priority=None, tasks=None, **kwargs ): stack = sorted(tasks, key=lambda k: priority.get(k, 0), reverse=True) while stack: diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 0434f8c8f3..de8cc79bc7 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -74,11 +74,45 @@ async def close(self) -> None: def update_graph( self, scheduler: Scheduler, + *, + client: str, keys: set[str], - restrictions: dict[str, float], + tasks: list[str], + annotations: dict[str, dict[str, Any]], + priority: dict[str, tuple[int | float, ...]], + dependencies: dict[str, set], **kwargs: Any, ) -> None: - """Run when a new graph / tasks enter the scheduler""" + """Run when a new graph / tasks enter the scheduler + + Parameters + ---------- + scheduler: + The `Scheduler` instance. + client: + The unique Client id. + keys: + The keys the Client is interested in when calling `update_graph`. + tasks: + The + annotations: + Fully resolved annotations as applied to the tasks in the format:: + + { + "annotation": { + "key": "value, + ... + }, + ... 
+ } + priority: + Task calculated priorities as assigned to the tasks. + dependencies: + A mapping that maps a key to its dependencies. + **kwargs: + It is recommended to allow plugins to accept more parameters to + ensure future compatibility. + """ def restart(self, scheduler: Scheduler) -> None: """Run when the scheduler restarts itself""" diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 877882470a..a4f661b6b6 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -243,3 +243,97 @@ async def close(self): assert "BEFORE_CLOSE" in text text = logger.getvalue() assert "AFTER_CLOSE" in text + + +@gen_cluster(client=True) +async def test_update_graph_hook_simple(c, s, a, b): + class UpdateGraph(SchedulerPlugin): + def __init__(self) -> None: + self.success = False + + def update_graph( # type: ignore + self, + scheduler, + client, + keys, + tasks, + annotations, + priority, + dependencies, + **kwargs, + ) -> None: + assert scheduler is s + assert client == c.id + # If new parameters are added we should add a test + assert not kwargs + assert keys == {"foo"} + assert tasks == ["foo"] + assert annotations == {} + assert len(priority) == 1 + assert isinstance(priority["foo"], tuple) + assert dependencies == {"foo": set()} + self.success = True + + plugin = UpdateGraph() + s.add_plugin(plugin, name="update-graph") + + await c.submit(inc, 5, key="foo") + assert plugin.success + + +import dask +from dask import delayed + + +@gen_cluster(client=True) +async def test_update_graph_hook_complex(c, s, a, b): + class UpdateGraph(SchedulerPlugin): + def __init__(self) -> None: + self.success = False + + def update_graph( # type: ignore + self, + scheduler, + client, + keys, + tasks, + annotations, + priority, + dependencies, + **kwargs, + ) -> None: + assert scheduler is s + assert client == c.id + # If new parameters are added we should add a test + assert not kwargs + assert keys == {"sum"} + assert set(tasks) == {"sum", "f1", "f3", "f2"} + assert annotations == { + "global_annot": {k: 24 for k in tasks}, + "layer": {"f2": "explicit"}, + "len_key": {"f3": 2}, + "priority": {"f2": 13}, + } + assert len(priority) == len(tasks), priority + assert priority["f2"][0] == -13 + for k in keys: + assert k in dependencies + assert dependencies["f1"] == set() + assert dependencies["sum"] == {"f1", "f3"} + + self.success = True + + plugin = UpdateGraph() + s.add_plugin(plugin, name="update-graph") + del_inc = delayed(inc) + f1 = del_inc(1, dask_key_name="f1") + with dask.annotate(layer="explicit", priority=13): + f2 = del_inc(2, dask_key_name="f2") + with dask.annotate(len_key=lambda x: len(x)): + f3 = del_inc(f2, dask_key_name="f3") + + f4 = delayed(sum)([f1, f3], dask_key_name="sum") + + with dask.annotate(global_annot=24): + await c.compute(f4) + assert plugin.success diff --git a/distributed/diagnostics/tests/test_widgets.py b/distributed/diagnostics/tests/test_widgets.py index 34dff54c09..6821ff50be 100644 --- a/distributed/diagnostics/tests/test_widgets.py +++ b/distributed/diagnostics/tests/test_widgets.py @@ -7,11 +7,9 @@ import pytest from packaging.version import parse as parse_version -from tlz import valmap from distributed.client import wait from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws -from distributed.worker import dumps_task ipywidgets = pytest.importorskip("ipywidgets") @@ -141,38 +139,6 @@ async def 
test_multi_progressbar_widget(c, s, a, b): assert sorted(capacities, reverse=True) == capacities -@mock_widget() -@gen_cluster() -async def test_multi_progressbar_widget_after_close(s, a, b): - s.update_graph( - tasks=valmap( - dumps_task, - { - "x-1": (inc, 1), - "x-2": (inc, "x-1"), - "x-3": (inc, "x-2"), - "y-1": (dec, "x-3"), - "y-2": (dec, "y-1"), - "e": (throws, "y-2"), - "other": (inc, 123), - }, - ), - keys=["e"], - dependencies={ - "x-2": {"x-1"}, - "x-3": {"x-2"}, - "y-1": {"x-3"}, - "y-2": {"y-1"}, - "e": {"y-2"}, - }, - ) - - p = MultiProgressWidget(["x-1", "x-2", "x-3"], scheduler=s.address) - await p.listen() - - assert "x" in p.bars - - @mock_widget() def test_values(client): L = [client.submit(inc, i) for i in range(5)] @@ -232,32 +198,17 @@ def test_progressbar_cancel(client): @mock_widget() -@gen_cluster() -async def test_multibar_complete(s, a, b): - s.update_graph( - tasks=valmap( - dumps_task, - { - "x-1": (inc, 1), - "x-2": (inc, "x-1"), - "x-3": (inc, "x-2"), - "y-1": (dec, "x-3"), - "y-2": (dec, "y-1"), - "e": (throws, "y-2"), - "other": (inc, 123), - }, - ), - keys=["e"], - dependencies={ - "x-2": {"x-1"}, - "x-3": {"x-2"}, - "y-1": {"x-3"}, - "y-2": {"y-1"}, - "e": {"y-2"}, - }, - ) - - p = MultiProgressWidget(["e"], scheduler=s.address, complete=True) +@gen_cluster(client=True) +async def test_multibar_complete(c, s, a, b): + x1 = c.submit(inc, 1, key="x-1") + x2 = c.submit(inc, x1, key="x-2") + x3 = c.submit(inc, x2, key="x-3") + y1 = c.submit(dec, x3, key="y-1") + y2 = c.submit(dec, y1, key="y-2") + e = c.submit(throws, y2, key="e") + other = c.submit(inc, 123, key="other") + + p = MultiProgressWidget([e.key], scheduler=s.address, complete=True) await p.listen() assert p._last_response["all"] == {"x": 3, "y": 2, "e": 1} diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 5882d6eb44..0f3d1998bf 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -342,6 +342,7 @@ def serialize( # type: ignore[no-untyped-def] return headers, frames tb = "" + exc = None for name in serializers: dumps, _, wants_context = families[name] @@ -351,11 +352,14 @@ def serialize( # type: ignore[no-untyped-def] return header, frames except NotImplementedError: continue - except Exception: + except Exception as e: + exc = e tb = traceback.format_exc() break - - msg = f"Could not serialize object of type {type(x).__name__}" + type_x = type(x) + if isinstance(x, (ToPickle, Serialize)): + type_x = type(x.data) + msg = f"Could not serialize object of type {type_x.__name__}" if on_error == "message": txt_frames = [msg] if tb: @@ -365,7 +369,7 @@ def serialize( # type: ignore[no-untyped-def] return {"serializer": "error"}, frames elif on_error == "raise": - raise TypeError(msg, str(x)[:10000]) + raise TypeError(msg, str(x)[:10000]) from exc else: # pragma: nocover raise ValueError(f"{on_error=}; expected 'message' or 'raise'") diff --git a/distributed/protocol/tests/test_highlevelgraph.py b/distributed/protocol/tests/test_highlevelgraph.py index adb150bfbc..540f5472d8 100644 --- a/distributed/protocol/tests/test_highlevelgraph.py +++ b/distributed/protocol/tests/test_highlevelgraph.py @@ -26,6 +26,7 @@ def add(x, y, z, extra_arg): y = c.submit(lambda x: x, 2) z = c.submit(lambda x: x, 3) + xx = await c.submit(lambda x: x + 1, y) x = da.blockwise( add, "x", diff --git a/distributed/protocol/tests/test_to_pickle.py b/distributed/protocol/tests/test_to_pickle.py index 5e2deb00bf..c526270dcd 100644 --- 
a/distributed/protocol/tests/test_to_pickle.py +++ b/distributed/protocol/tests/test_to_pickle.py @@ -1,11 +1,7 @@ from __future__ import annotations -import dask.config -from dask.highlevelgraph import HighLevelGraph, MaterializedLayer - from distributed.protocol import dumps, loads from distributed.protocol.serialize import ToPickle -from distributed.utils_test import gen_cluster def test_ToPickle(): @@ -17,30 +13,3 @@ def __init__(self, data): frames = dumps(msg) out = loads(frames) assert out["x"].data == 123 - - -class NonMsgPackSerializableLayer(MaterializedLayer): - """Layer that uses non-msgpack-serializable data""" - - def __dask_distributed_pack__(self, *args, **kwargs): - ret = super().__dask_distributed_pack__(*args, **kwargs) - # Some info that contains a `list`, which msgpack will convert to - # a tuple if getting the chance. - ret["myinfo"] = ["myinfo"] - return ToPickle(ret) - - @classmethod - def __dask_distributed_unpack__(cls, state, *args, **kwargs): - assert state["myinfo"] == ["myinfo"] - return super().__dask_distributed_unpack__(state, *args, **kwargs) - - -@gen_cluster(client=True) -async def test_non_msgpack_serializable_layer(c, s, a, b): - with dask.config.set({"distributed.scheduler.allowed-imports": "test_to_pickle"}): - a = NonMsgPackSerializableLayer({"x": 42}) - layers = {"a": a} - dependencies = {"a": set()} - hg = HighLevelGraph(layers, dependencies) - res = await c.get(hg, "x", sync=False) - assert res == 42 diff --git a/distributed/scheduler.py b/distributed/scheduler.py index eab4156146..c43f4d5415 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -14,6 +14,7 @@ import pickle import random import sys +import textwrap import uuid import warnings import weakref @@ -31,7 +32,6 @@ ) from contextlib import suppress from functools import partial -from numbers import Number from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, cast, overload import psutil @@ -50,13 +50,15 @@ from tornado.ioloop import IOLoop import dask -from dask.highlevelgraph import HighLevelGraph +import dask.utils +from dask.core import get_deps from dask.utils import ( format_bytes, format_time, key_split, parse_bytes, parse_timedelta, + stringify, tmpfile, ) from dask.widgets import get_template @@ -77,7 +79,7 @@ ) from distributed.comm.addressing import addresses_from_user_args from distributed.compatibility import PeriodicCallback -from distributed.core import Status, clean_exception, rpc, send_recv +from distributed.core import Status, clean_exception, error_message, rpc, send_recv from distributed.diagnostics.memory_sampler import MemorySamplerExtension from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name from distributed.event import EventExtension @@ -88,7 +90,7 @@ from distributed.node import ServerNode from distributed.proctitle import setproctitle from distributed.protocol.pickle import dumps, loads -from distributed.protocol.serialize import Serialized, serialize +from distributed.protocol.serialize import Serialized, ToPickle, serialize from distributed.publish import PublishExtension from distributed.pubsub import PubSubSchedulerExtension from distributed.queues import QueueExtension @@ -114,6 +116,7 @@ gather_from_workers, retry_operation, scatter_to_workers, + unpack_remotedata, ) from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis from distributed.variable import VariableExtension @@ -846,6 +849,7 @@ class Computation: groups: set[TaskGroup] code: SortedSet id: uuid.UUID + annotations: 
dict __slots__ = tuple(__annotations__) @@ -854,6 +858,7 @@ def __init__(self): self.groups = set() self.code = SortedSet() self.id = uuid.uuid4() + self.annotations = {} @property def stop(self) -> float: @@ -1133,7 +1138,7 @@ class TaskState: #: within a large graph that may be important, such as if they are on the critical #: path, or good to run in order to release many dependencies. This is explained #: further in :doc:`Scheduling Policy `. - priority: tuple[int, ...] + priority: tuple[float, ...] | None # Attribute underlying the state property _state: TaskStateState @@ -1336,7 +1341,7 @@ def __init__( self.suspicious = 0 self.retries = 0 self.nbytes = -1 - self.priority = None # type: ignore + self.priority = None self.who_wants = set() self.dependencies = set() self.dependents = set() @@ -1345,8 +1350,8 @@ def __init__( self.who_has = set() self.processing_on = None self.has_lost_dependencies = False - self.host_restrictions = None # type: ignore - self.worker_restrictions = None # type: ignore + self.host_restrictions = set() + self.worker_restrictions = set() self.resource_restrictions = {} self.loose_restrictions = False self.actor = False @@ -3260,7 +3265,7 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]: if duration < 0: duration = self.get_task_duration(ts) ts.run_id = next(TaskState._run_id_iterator) - + assert ts.priority, ts msg: dict[str, Any] = { "op": "compute-task", "key": ts.key, @@ -3576,7 +3581,6 @@ def __init__( client_handlers = { "update-graph": self.update_graph, - "update-graph-hlg": self.update_graph_hlg, "client-desires-keys": self.client_desires_keys, "update-data": self.update_data, "report-key": self.report_on_key, @@ -4245,103 +4249,75 @@ async def add_nanny(self) -> dict[str, Any]: } return msg - def update_graph_hlg( + def update_graph( self, - client=None, - hlg=None, - keys=None, - dependencies=None, - restrictions=None, - priority=None, - loose_restrictions=None, - resources=None, - submitting_task=None, - retries=None, - user_priority=0, - actors=None, - fifo_timeout=0, - code=None, - ): - unpacked_graph = HighLevelGraph.__dask_distributed_unpack__(hlg) - dsk = unpacked_graph["dsk"] - dependencies = unpacked_graph["deps"] - annotations = unpacked_graph["annotations"] - - # Remove any self-dependencies (happens on test_publish_bag() and others) - for k, v in dependencies.items(): - deps = set(v) - if k in deps: - deps.remove(k) - dependencies[k] = deps - - if priority is None: - # Removing all non-local keys before calling order() - dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys - } - priority = dask.order.order(dsk, dependencies=stripped_deps) + client: str, + graph_header: dict, + graph_frames: list[bytes], + keys: list[str], + internal_priority: dict[str, int] | None, + submitting_task: str | None, + user_priority: int | dict[str, int] = 0, + actors: bool | list[str] | None = None, + fifo_timeout: float = 0.0, + code: str | None = None, + annotations: dict | None = None, + stimulus_id: str | None = None, + ) -> None: + start = time() + try: + # TODO: deserialization + materialization should be offloaded to a + # thread since this is non-trivial compute time that blocks the + # event loop. 
This likely requires us to use a lock since we need to + # guarantee ordering of update_graph calls (as long as there is just + # a single offload thread, this is not a problem) + from distributed.protocol import deserialize + + graph = deserialize(graph_header, graph_frames).data + except Exception as e: + msg = """\ + Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments + """ + try: + raise RuntimeError(textwrap.dedent(msg)) from e + except RuntimeError as e: + err = error_message(e) + for key in keys: + self.report( + { + "op": "task-erred", + "key": key, + "exception": err["exception"], + "traceback": err["traceback"], + } + ) - return self.update_graph( - client, + return + annotations = annotations or {} + if isinstance(annotations, ToPickle): # type: ignore + # FIXME: what the heck? + annotations = annotations.data # type: ignore + stimulus_id = stimulus_id or f"update-graph-{time()}" + ( dsk, - keys, dependencies, - restrictions, - priority, - loose_restrictions, - resources, - submitting_task, - retries, - user_priority, - actors, - fifo_timeout, - annotations, - code=code, - stimulus_id=f"update-graph-{time()}", - ) - - def update_graph( - self, - client=None, - tasks=None, - keys=None, - dependencies=None, - restrictions=None, - priority=None, - loose_restrictions=None, - resources=None, - submitting_task=None, - retries=None, - user_priority=0, - actors=None, - fifo_timeout=0, - annotations=None, - code=None, - stimulus_id=None, - ): - """ - Add new computations to the internal dask graph + layer_annotations, + pre_stringified_keys, + ) = self.materialize_graph(graph) - This happens whenever the Client calls submit, map, get, or compute. - """ - stimulus_id = stimulus_id or f"update-graph-{time()}" - start = time() - fifo_timeout = parse_timedelta(fifo_timeout) - keys = set(keys) - if len(tasks) > 1: + requested_keys = set(keys) + del keys + if len(dsk) > 1: self.log_event( - ["all", client], {"action": "update_graph", "count": len(tasks)} + ["all", client], {"action": "update_graph", "count": len(dsk)} ) + self._pop_known_tasks(dsk, dependencies) - # Remove aliases - for k in list(tasks): - if tasks[k] is k: - del tasks[k] - - dependencies = dependencies or {} + if lost_keys := self._pop_lost_tasks(dsk, dependencies, requested_keys): + self.report({"op": "cancelled-keys", "keys": lost_keys}, client=client) + self.client_releases_keys( + keys=lost_keys, client=client, stimulus_id=stimulus_id + ) if self.total_occupancy > 1e-9 and self.computations: # Still working on something. 
Assign new tasks to same computation @@ -4349,63 +4325,99 @@ def update_graph( else: computation = Computation() self.computations.append(computation) - if code and code not in computation.code: # add new code blocks computation.code.add(code) + if annotations: + computation.annotations.update(annotations) - n = 0 - while len(tasks) != n: # walk through new tasks, cancel any bad deps - n = len(tasks) - for k, deps in list(dependencies.items()): - if any( - dep not in self.tasks and dep not in tasks for dep in deps - ): # bad key - logger.info("User asked for computation on lost data, %s", k) - del tasks[k] - del dependencies[k] - if k in keys: - keys.remove(k) - self.report({"op": "cancelled-key", "key": k}, client=client) - self.client_releases_keys( - keys=[k], client=client, stimulus_id=stimulus_id - ) + runnable, touched_tasks, new_tasks = self._generate_taskstates( + keys=requested_keys, + dsk=dsk, + dependencies=dependencies, + computation=computation, + ) + # FIXME: These "resolved_annotations" are a big duplication and are only + # required to satisfy the current plugin API. This should be + # reconsidered. + resolved_annotations = self._parse_and_apply_annotations( + tasks=new_tasks, + annotations=annotations, + layer_annotations=layer_annotations, + pre_stringified_keys=pre_stringified_keys, + ) - # Avoid computation that is already finished - already_in_memory = set() # tasks that are already done - for k, v in dependencies.items(): - if v and k in self.tasks: - ts = self.tasks[k] - if ts.state in ("memory", "erred"): - already_in_memory.add(k) + self._set_priorities( + internal_priority=internal_priority, + submitting_task=submitting_task, + user_priority=user_priority, + fifo_timeout=fifo_timeout, + start=start, + dsk=dsk, + tasks=runnable, + dependencies=dependencies, + ) - if already_in_memory: - dependents = dask.core.reverse_dict(dependencies) - stack = list(already_in_memory) - done = set(already_in_memory) - while stack: # remove unnecessary dependencies - key = stack.pop() - try: - deps = dependencies[key] - except KeyError: - deps = self.tasks[key].dependencies - for dep in deps: - if dep in dependents: - child_deps = dependents[dep] - elif dep in self.tasks: - child_deps = self.tasks[dep].dependencies - else: - child_deps = set() - if all(d in done for d in child_deps): - if dep in self.tasks and dep not in done: - done.add(dep) - stack.append(dep) + self.client_desires_keys(keys=requested_keys, client=client) + + # Add actors + if actors is True: + actors = list(requested_keys) + for actor in actors or []: + ts = self.tasks[actor] + ts.actor = True + + # Compute recommendations + recommendations: Recs = {} + priority = dict() + for ts in sorted( + runnable, + key=operator.attrgetter("priority"), + reverse=True, + ): + assert ts.priority # mypy + priority[ts.key] = ts.priority + assert ts.run_spec + if ts.state == "released": + recommendations[ts.key] = "waiting" + + for ts in runnable: + for dts in ts.dependencies: + if dts.exception_blame: + ts.exception_blame = dts.exception_blame + recommendations[ts.key] = "erred" + break + + for plugin in list(self.plugins.values()): + try: + plugin.update_graph( + self, + client=client, + tasks=[ts.key for ts in touched_tasks], + keys=requested_keys, + dependencies=dependencies, + annotations=resolved_annotations, + priority=priority, + ) + except Exception as e: + logger.exception(e) - for d in done: - tasks.pop(d, None) - dependencies.pop(d, None) + self.transitions(recommendations, stimulus_id) + for ts in touched_tasks: + 
if ts.state in ("memory", "erred"): + self.report_on_key(ts=ts, client=client) + + end = time() + self.digest_metric("update-graph-duration", end - start) + + def _generate_taskstates( + self, keys: set[str], dsk: dict, dependencies: dict, computation + ): # Get or create task states + runnable = [] + new_tasks = [] stack = list(keys) + tasks = [] touched_keys = set() touched_tasks = [] while stack: @@ -4414,17 +4426,19 @@ def update_graph( continue # XXX Have a method get_task_state(self, k) ? ts = self.tasks.get(k) + tasks.append(ts) if ts is None: - ts = self.new_task(k, tasks.get(k), "released", computation=computation) + ts = self.new_task(k, dsk.get(k), "released", computation=computation) + new_tasks.append(ts) elif not ts.run_spec: - ts.run_spec = tasks.get(k) + ts.run_spec = dsk.get(k) + if ts.run_spec: + runnable.append(ts) touched_keys.add(k) touched_tasks.append(ts) stack.extend(dependencies.get(k, ())) - self.client_desires_keys(keys=keys, client=client) - # Add dependencies for key, deps in dependencies.items(): ts = self.tasks.get(key) @@ -4434,58 +4448,112 @@ def update_graph( dts = self.tasks[dep] ts.add_dependency(dts) - # Compute priorities - if isinstance(user_priority, Number): - user_priority = {k: user_priority for k in tasks} - - annotations = annotations or {} - restrictions = restrictions or {} - loose_restrictions = loose_restrictions or [] - resources = resources or {} - retries = retries or {} - - # Override existing taxonomy with per task annotations - if annotations: - if "priority" in annotations: - user_priority.update(annotations["priority"]) - - if "workers" in annotations: - restrictions.update(annotations["workers"]) + if len(touched_tasks) < len(keys): + logger.info( + "Submitted graph with length %s but requested graph only includes %s keys", + len(touched_tasks), + len(keys), + ) + return runnable, touched_tasks, new_tasks - if "allow_other_workers" in annotations: - loose_restrictions.extend( - k for k, v in annotations["allow_other_workers"].items() if v - ) + def _parse_and_apply_annotations( + self, + tasks: Iterable[TaskState], + annotations: dict, + layer_annotations: dict[str, dict], + pre_stringified_keys: dict[Hashable, str], + ) -> dict[str, dict[str, Any]]: + """Apply the provided annotations to the provided `TaskState` objects. - if "retries" in annotations: - retries.update(annotations["retries"]) + The raw annotations will be stored in the `annotations` attribute. - if "resources" in annotations: - resources.update(annotations["resources"]) + Layer / key specific annotations will take precedence over global / generic annotations. - for a, kv in annotations.items(): - for k, v in kv.items(): - # Tasks might have been culled, in which case - # we have nothing to annotate. - ts = self.tasks.get(k) - if ts is not None: - ts.annotations[a] = v + Parameters + ---------- + tasks : Iterable[TaskState] + _description_ + annotations : dict + _description_ + layer_annotations : dict[str, dict] + _description_ - # Add actors - if actors is True: - actors = list(keys) - for actor in actors or []: - ts = self.tasks[actor] - ts.actor = True + Returns + ------- + resolved_annotations: dict + A mapping of all resolved annotations in the format:: - priority = priority or dask.order.order( - tasks - ) # TODO: define order wrt old graph + { + "annotation": { + "key": value, + ... + }, + ... 
+ } + """ + resolved_annotations: dict[str, dict[str, Any]] = defaultdict(dict) + for ts in tasks: + key = ts.key + # This could be a typed dict + if not annotations and key not in layer_annotations: + continue + out = annotations.copy() + out.update(layer_annotations.get(key, {})) + for annot, value in out.items(): + # Pop the key since names don't always match attributes + if callable(value): + value = value(pre_stringified_keys[key]) + out[annot] = value + resolved_annotations[annot][key] = value + + if annot in ("restrictions", "workers"): + if not isinstance(value, (list, tuple, set)): + value = [value] + host_restrictions = set() + worker_restrictions = set() + for w in value: + try: + w = self.coerce_address(w) + except ValueError: + # Not a valid address, but perhaps it's a hostname + host_restrictions.add(w) + else: + worker_restrictions.add(w) + if host_restrictions: + ts.host_restrictions = host_restrictions + if worker_restrictions: + ts.worker_restrictions = worker_restrictions + elif annot in ("loose_restrictions", "allow_other_workers"): + ts.loose_restrictions = value + elif annot == "resources": + assert isinstance(value, dict) + ts.resource_restrictions = value + elif annot == "priority": + # See Scheduler._set_priorities + continue + elif annot == "retries": + assert isinstance(value, int) + ts.retries = value + ts.annotations = out + return dict(resolved_annotations) + def _set_priorities( + self, + internal_priority: dict[str, int] | None, + submitting_task: str | None, + user_priority: int | dict[str, int], + fifo_timeout: int | float | str, + start: float, + dsk: dict, + tasks: list[TaskState], + dependencies: dict, + ): + fifo_timeout = parse_timedelta(fifo_timeout) if submitting_task: # sub-tasks get better priority than parent tasks - ts = self.tasks.get(submitting_task) - if ts is not None: - generation = ts.priority[0] - 0.01 + sts = self.tasks.get(submitting_task) + if sts is not None: + assert sts.priority + generation = sts.priority[0] - 0.01 else: # super-task already cleaned up generation = self.generation elif self._last_time + fifo_timeout < start: @@ -4495,104 +4563,148 @@ def update_graph( else: generation = self.generation - for key in set(priority) & touched_keys: - ts = self.tasks[key] - if ts.priority is None: - ts.priority = (-(user_priority.get(key, 0)), generation, priority[key]) - - # Ensure all runnables have a priority - runnables = [ts for ts in touched_tasks if ts.run_spec] - for ts in runnables: - if ts.priority is None and ts.run_spec: - ts.priority = (self.generation, 0) - - if restrictions: - # *restrictions* is a dict keying task ids to lists of - # restriction specifications (either worker names or addresses) - for k, v in restrictions.items(): - if v is None: - continue - ts = self.tasks.get(k) - if ts is None: - continue - ts.host_restrictions = set() - ts.worker_restrictions = set() - # Make sure `v` is a collection and not a single worker name / address - if not isinstance(v, (list, tuple, set)): - v = [v] - for w in v: - try: - w = self.coerce_address(w) - except ValueError: - # Not a valid address, but perhaps it's a hostname - ts.host_restrictions.add(w) - else: - ts.worker_restrictions.add(w) - - if loose_restrictions: - for k in loose_restrictions: - ts = self.tasks[k] - ts.loose_restrictions = True - - if resources: - for k, v in resources.items(): - if v is None: - continue - assert isinstance(v, dict) - ts = self.tasks.get(k) - if ts is None: - continue - ts.resource_restrictions = v - - if retries: - for k, v in 
retries.items(): - assert isinstance(v, int) - ts = self.tasks.get(k) - if ts is None: - continue - ts.retries = v - - # Compute recommendations - recommendations: Recs = {} + if internal_priority is None: + # Removing all non-local keys before calling order() + dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys + stripped_deps = { + k: v.intersection(dsk_keys) + for k, v in dependencies.items() + if k in dsk_keys + } + internal_priority = dask.order.order(dsk, dependencies=stripped_deps) + + for ts in tasks: + # Note: Under which circumstances would a task not have a + # prioritiy assigned by now? Are these scattered tasks + # exclusively or something else? + task_user_prio = user_priority + if isinstance(user_priority, dict): + task_user_prio = user_priority.get(ts.key, 0) + annotated_prio = ts.annotations.get("priority", {}) + if not annotated_prio: + annotated_prio = task_user_prio + + if not ts.priority and ts.key in internal_priority: + ts.priority = ( + -annotated_prio, + generation, + internal_priority[ts.key], + ) - for ts in sorted(runnables, key=operator.attrgetter("priority"), reverse=True): - if ts.state == "released" and ts.run_spec: - recommendations[ts.key] = "waiting" + if self.validate and ts.run_spec: + assert isinstance(ts.priority, tuple) and all( + isinstance(el, (int, float)) for el in ts.priority + ) - for ts in touched_tasks: - for dts in ts.dependencies: - if dts.exception_blame: - ts.exception_blame = dts.exception_blame - recommendations[ts.key] = "erred" - break + def _pop_lost_tasks( + self, dsk: dict, dependencies: dict, keys: set[str] + ) -> set[str]: + n = 0 + out = set() + while len(dsk) != n: # walk through new tasks, cancel any bad deps + n = len(dsk) + for k, deps in list(dependencies.items()): + if any( + dep not in self.tasks and dep not in dsk for dep in deps + ): # bad key + out.add(k) + logger.info("User asked for computation on lost data, %s", k) + del dsk[k] + del dependencies[k] + if k in keys: + keys.remove(k) + return out - for plugin in list(self.plugins.values()): - try: - plugin.update_graph( - self, - client=client, - tasks=tasks, - keys=keys, - restrictions=restrictions or {}, - dependencies=dependencies, - priority=priority, - loose_restrictions=loose_restrictions, - resources=resources, - annotations=annotations, - ) - except Exception as e: - logger.exception(e) + def _pop_known_tasks(self, dsk: dict, dependencies: dict) -> set[str]: + # Avoid computation that is already finished + already_in_memory = set() # tasks that are already done + for k, v in dependencies.items(): + if v and k in self.tasks: + ts = self.tasks[k] + if ts.state in ("memory", "erred"): + already_in_memory.add(k) - self.transitions(recommendations, stimulus_id) + done = set(already_in_memory) + if already_in_memory: + dependents = dask.core.reverse_dict(dependencies) + stack = list(already_in_memory) + while stack: # remove unnecessary dependencies + key = stack.pop() + try: + deps = dependencies[key] + except KeyError: + deps = self.tasks[key].dependencies + for dep in deps: + if dep in dependents: + child_deps = dependents[dep] + elif dep in self.tasks: + child_deps = self.tasks[dep].dependencies + else: + child_deps = set() + if all(d in done for d in child_deps): + if dep in self.tasks and dep not in done: + done.add(dep) + stack.append(dep) + for anc in done: + dsk.pop(anc, None) + dependencies.pop(anc, None) + return done - for ts in touched_tasks: - if ts.state in ("memory", "erred"): - self.report_on_key(ts=ts, client=client) + 
@staticmethod + def materialize_graph(hlg) -> tuple[dict, dict, dict, dict]: + from distributed.worker import dumps_task + + key_annotations = {} + dsk = dask.utils.ensure_dict(hlg) + + for layer in hlg.layers.values(): + if layer.annotations: + annot = layer.annotations + key_annotations.update({stringify(k): annot for k in layer}) + + dependencies, _ = get_deps(dsk) + + # Remove `Future` objects from graph and note any future dependencies + dsk2 = {} + fut_deps = {} + for k, v in dsk.items(): + dsk2[k], futs = unpack_remotedata(v, byte_keys=True) + if futs: + fut_deps[k] = futs + dsk = dsk2 + + # - Add in deps for any tasks that depend on futures + for k, futures in fut_deps.items(): + dependencies[k].update(f.key for f in futures) + new_dsk = {} + # Annotation callables are evaluated on the non-stringified version of + # the keys + pre_stringified_keys = {} + exclusive = set(hlg) + for k, v in dsk.items(): + new_k = stringify(k) + pre_stringified_keys[new_k] = k + new_dsk[new_k] = stringify(v, exclusive=exclusive) + assert len(new_dsk) == len(pre_stringified_keys) + dsk = new_dsk + dependencies = { + stringify(k): {stringify(dep) for dep in deps} + for k, deps in dependencies.items() + } - end = time() - self.digest_metric("update-graph-duration", end - start) + # Remove any self-dependencies (happens on test_publish_bag() and others) + for k, v in dependencies.items(): + deps = set(v) + if k in deps: + deps.remove(k) + dependencies[k] = deps - # TODO: balance workers + # Remove aliases + for k in list(dsk): + if dsk[k] is k: + del dsk[k] + dsk = valmap(dumps_task, dsk) + return dsk, dependencies, key_annotations, pre_stringified_keys def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """Respond to an event which may have opened spots on worker threadpools @@ -8176,8 +8288,8 @@ def __init__(self, scheduler: Scheduler, name: str): def update_graph( self, scheduler: Scheduler, + *, keys: set[str], - restrictions: dict[str, float], **kwargs: Any, ) -> None: self.keys.update(keys) diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py index 9fcc81b9a7..a4eeb4a995 100644 --- a/distributed/shuffle/_merge.py +++ b/distributed/shuffle/_merge.py @@ -4,15 +4,10 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Iterable, Sequence -import toolz - from dask.base import is_dask_collection, tokenize -from dask.core import keys_in_tasks from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer -from dask.utils import stringify, stringify_collection_keys -from distributed.protocol.serialize import to_serialize from distributed.shuffle._shuffle import ( ShuffleId, _get_worker_extension, @@ -165,10 +160,6 @@ def merge_unpack( ): from dask.dataframe.multi import merge_chunk - from distributed.protocol import deserialize - - # FIXME: This is odd. 
- result_meta = deserialize(result_meta.header, result_meta.frames) ext = _get_worker_extension() left = ext.get_output_partition( shuffle_id_left, barrier_left, output_partition @@ -381,59 +372,3 @@ def _construct_graph(self) -> dict[tuple | str, tuple]: self.suffixes, ) return dsk - - @classmethod - def __dask_distributed_unpack__( - cls, state: dict, dsk: dict, dependecies: dict - ) -> dict: - from distributed.worker import dumps_task - - to_list = [ - "left_on", - "suffixes", - "right_on", - ] - # msgpack will convert lists into tuples, here - # we convert them back to lists - for attr in to_list: - if isinstance(state[attr], tuple): - state[attr] = list(state[attr]) - - # Materialize the layer - layer_dsk = cls(**state)._construct_graph() - - # Convert all keys to strings and dump tasks - layer_dsk = { - stringify(k): stringify_collection_keys(v) for k, v in layer_dsk.items() - } - keys = layer_dsk.keys() | dsk.keys() - deps = {} - for k, v in layer_dsk.items(): - deps[k] = d = keys_in_tasks(keys, [v]) - assert d - return {"dsk": toolz.valmap(dumps_task, layer_dsk), "deps": deps} - - def __dask_distributed_pack__( - self, - all_hlg_keys: Any, - known_key_dependencies: Any, - client: Any, - client_keys: Any, - ) -> dict: - return { - "name": self.name, - "name_input_left": self.name_input_left, - "left_on": self.left_on, - "name_input_right": self.name_input_right, - "right_on": self.right_on, - "meta_output": to_serialize(self.meta_output), - "how": self.how, - "npartitions": self.npartitions, - "suffixes": self.suffixes, - "indicator": self.indicator, - "parts_out": self.parts_out, - "n_partitions_left": self.n_partitions_left, - "n_partitions_right": self.n_partitions_right, - "left_index": self.left_index, - "right_index": self.right_index, - } diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index f14fe6c479..ba749fe1fc 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -2,12 +2,12 @@ import logging from enum import Enum -from typing import TYPE_CHECKING, Any, NewType +from typing import TYPE_CHECKING, Any, Iterable, Iterator, NewType import dask from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph -from dask.layers import SimpleShuffleLayer +from dask.layers import Layer from distributed.shuffle._arrow import check_dtype_support, check_minimal_arrow_version @@ -15,6 +15,9 @@ if TYPE_CHECKING: import pandas as pd + # TODO import from typing (requires Python >=3.10) + from typing_extensions import TypeAlias + from dask.dataframe import DataFrame # circular dependency @@ -127,9 +130,7 @@ def rearrange_by_column_p2p( column, npartitions, npartitions_input=df.npartitions, - ignore_index=True, name_input=df._name, - meta_input=empty, ) return DataFrame( HighLevelGraph.from_collections(name, layer, [df]), @@ -139,58 +140,112 @@ def rearrange_by_column_p2p( ) -class P2PShuffleLayer(SimpleShuffleLayer): +# TODO remove quotes (requires Python >=3.9) +_T_Key: TypeAlias = "tuple[str, int] | str" +_T_LowLevelGraph: TypeAlias = "dict[_T_Key, tuple]" + + +class P2PShuffleLayer(Layer): def __init__( self, name: str, column: str, npartitions: int, npartitions_input: int, - ignore_index: bool, name_input: str, - meta_input: pd.DataFrame, - parts_out: list | None = None, + parts_out: Iterable | None = None, annotations: dict | None = None, ): check_minimal_arrow_version() annotations = annotations or {} annotations.update({"shuffle": lambda key: key[1]}) - super().__init__( - name, - column, - npartitions, - 
npartitions_input, - ignore_index, - name_input, - meta_input, - parts_out, - annotations=annotations, - ) - - def get_split_keys(self) -> list: - # TODO: This is doing some funky stuff to set priorities but we don't need this - return [] + self.name = name + self.column = column + self.npartitions = npartitions + self.name_input = name_input + if parts_out: + self.parts_out = set(parts_out) + else: + self.parts_out = set(range(self.npartitions)) + self.npartitions_input = npartitions_input + super().__init__(annotations=annotations) def __repr__(self) -> str: return ( f"{type(self).__name__}" ) - def _cull(self, parts_out: list) -> P2PShuffleLayer: + def get_output_keys(self) -> set[_T_Key]: + return {(self.name, part) for part in self.parts_out} + + def is_materialized(self) -> bool: + return hasattr(self, "_cached_dict") + + @property + def _dict(self) -> _T_LowLevelGraph: + """Materialize full dict representation""" + self._cached_dict: _T_LowLevelGraph + dsk: _T_LowLevelGraph + if hasattr(self, "_cached_dict"): + return self._cached_dict + else: + dsk = self._construct_graph() + self._cached_dict = dsk + return self._cached_dict + + def __getitem__(self, key: _T_Key) -> tuple: + return self._dict[key] + + def __iter__(self) -> Iterator[_T_Key]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + def _cull(self, parts_out: Iterable[int]) -> P2PShuffleLayer: return P2PShuffleLayer( self.name, self.column, self.npartitions, self.npartitions_input, - self.ignore_index, self.name_input, - self.meta_input, parts_out=parts_out, ) - def _construct_graph(self, deserializing: Any = None) -> dict[tuple | str, tuple]: + def _keys_to_parts(self, keys: Iterable[_T_Key]) -> set[int]: + """Simple utility to convert keys to partition indices.""" + parts = set() + for key in keys: + if isinstance(key, tuple) and len(key) == 2: + _name, _part = key + if _name != self.name: + continue + parts.add(_part) + return parts + + def cull( + self, keys: Iterable[_T_Key], all_keys: Any + ) -> tuple[P2PShuffleLayer, dict]: + """Cull a P2PShuffleLayer HighLevelGraph layer. + + The underlying graph will only include the necessary + tasks to produce the keys (indices) included in `parts_out`. + Therefore, "culling" the layer only requires us to reset this + parameter. 
+ """ + parts_out = self._keys_to_parts(keys) + input_parts = {(self.name_input, i) for i in range(self.npartitions_input)} + culled_deps = {(self.name, part): input_parts.copy() for part in parts_out} + + if parts_out != set(self.parts_out): + culled_layer = self._cull(parts_out) + return culled_layer, culled_deps + else: + return self, culled_deps + + def _construct_graph(self) -> _T_LowLevelGraph: token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out) - dsk: dict[tuple | str, tuple] = {} + dsk: _T_LowLevelGraph = {} _barrier_key = barrier_key(ShuffleId(token)) name = "shuffle-transfer-" + token transfer_keys = list() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index aa02962850..7d593df319 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -16,7 +16,6 @@ import threading import traceback import types -import warnings import weakref import zipfile from collections import deque, namedtuple @@ -113,6 +112,7 @@ nodebug, poll_for, popen, + raises_with_cause, randominc, save_sys_modules, slowadd, @@ -333,31 +333,6 @@ def test_retries_get(c): x.compute() -@gen_cluster(client=True) -async def test_compute_persisted_retries(c, s, a, b): - args = [ZeroDivisionError("one"), ZeroDivisionError("two"), 3] - - # Sanity check - x = c.persist(delayed(varying(args))()) - fut = c.compute(x) - with pytest.raises(ZeroDivisionError, match="one"): - await fut - - x = c.persist(delayed(varying(args))()) - fut = c.compute(x, retries=1) - with pytest.raises(ZeroDivisionError, match="two"): - await fut - - x = c.persist(delayed(varying(args))()) - fut = c.compute(x, retries=2) - assert await fut == 3 - - args.append(4) - x = c.persist(delayed(varying(args))()) - fut = c.compute(x, retries=3) - assert await fut == 3 - - @gen_cluster(client=True) async def test_persist_retries(c, s, a, b): # Same retries for all @@ -888,8 +863,8 @@ async def test_restrictions_submit(c, s, a, b): @gen_cluster(client=True, config=NO_AMM) async def test_restrictions_ip_port(c, s, a, b): - x = c.submit(inc, 1, workers={a.address}) - y = c.submit(inc, x, workers={b.address}) + x = c.submit(inc, 1, workers={a.address}, key="x") + y = c.submit(inc, x, workers={b.address}, key="y") await wait([x, y]) assert s.tasks[x.key].worker_restrictions == {a.address} @@ -1904,34 +1879,6 @@ def __setstate__(self, state): raise TypeError("hello!") -class FatallySerializedObject: - def __getstate__(self): - return 1 - - def __setstate__(self, state): - print("This should never have been deserialized, closing") - import sys - - sys.exit(0) - - -@gen_cluster(client=True) -async def test_badly_serialized_input(c, s, a, b): - o = BadlySerializedObject() - - future = c.submit(inc, o) - futures = c.map(inc, range(10)) - - L = await c.gather(futures) - assert list(L) == list(map(inc, range(10))) - assert future.status == "error" - - with pytest.raises(Exception) as info: - await future - - assert "hello!" 
in str(info.value) - - @pytest.mark.skip @gen_test() async def test_badly_serialized_input_stderr(capsys, c): @@ -2506,20 +2453,21 @@ async def test_futures_of_cancelled_raises(c, s, a, b): with pytest.raises(CancelledError): await x - + while x.key in s.tasks: + await asyncio.sleep(0.01) with pytest.raises(CancelledError): - await c.get({"x": (inc, x), "y": (inc, 2)}, ["x", "y"], sync=False) + get_obj = c.get({"x": (inc, x), "y": (inc, 2)}, ["x", "y"], sync=False) + gather_obj = c.gather(get_obj) + await gather_obj with pytest.raises(CancelledError): - c.submit(inc, x) + await c.submit(inc, x) with pytest.raises(CancelledError): - c.submit(add, 1, y=x) + await c.submit(add, 1, y=x) with pytest.raises(CancelledError): - c.map(add, [1], y=x) - - assert "y" not in s.tasks + await c.gather(c.map(add, [1], y=x)) @pytest.mark.skip @@ -2541,16 +2489,6 @@ async def test_dont_delete_recomputed_results(c, s, w): await asyncio.sleep(0.01) -@gen_cluster(nthreads=[], client=True) -async def test_fatally_serialized_input(c, s): - o = FatallySerializedObject() - - future = c.submit(inc, o) - - while not s.tasks: - await asyncio.sleep(0.01) - - @pytest.mark.skip(reason="Use fast random selection now") @gen_cluster(client=True) async def test_balance_tasks_by_stacks(c, s, a, b): @@ -2987,7 +2925,7 @@ async def test_submit_on_cancelled_future(c, s, a, b): await c.cancel(x) with pytest.raises(CancelledError): - c.submit(inc, x) + await c.submit(inc, x) @gen_cluster( @@ -4127,8 +4065,10 @@ async def test_persist_workers_annotate(e, s, a, b, c): @gen_cluster(nthreads=[("127.0.0.1", 1)] * 3, client=True) async def test_persist_workers_annotate2(e, s, a, b, c): + addr = a.address + def key_to_worker(key): - return a.address + return addr L1 = [delayed(inc)(i) for i in range(4)] for x in L1: @@ -4915,7 +4855,7 @@ class Foo: def __getstate__(self): raise MyException() - with pytest.raises(MyException): + with pytest.raises(TypeError, match="Could not serialize"): future = c.submit(identity, Foo()) futures = c.map(inc, range(10)) @@ -4935,7 +4875,9 @@ def __setstate__(self, state): raise MyException("hello") future = c.submit(identity, Foo()) - with pytest.raises(MyException): + await wait(future) + assert future.status == "error" + with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): await future futures = c.map(inc, range(10)) @@ -4958,7 +4900,9 @@ def __call__(self, *args): return 1 future = c.submit(Foo(), 1) - with pytest.raises(MyException): + await wait(future) + assert future.status == "error" + with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): await future futures = c.map(inc, range(10)) @@ -5797,22 +5741,14 @@ async def test_config_scheduler_address(s, a, b): assert sio.getvalue() == f"Config value `scheduler-address` found: {s.address}\n" +@pytest.mark.filterwarnings("ignore:Large object:UserWarning") @gen_cluster(client=True) async def test_warn_when_submitting_large_values(c, s, a, b): with pytest.warns( UserWarning, - match=r"Large object of size (2\.00 MB|1.91 MiB) detected in task graph:" - r" \n \(b'00000000000000000000000000000000000000000000000 \.\.\. 
000000000000',\)" - r"\nConsider scattering large objects ahead of time.*", + match="Sending large graph of size", ): - future = c.submit(lambda x: x + 1, b"0" * 2000000) - - with warnings.catch_warnings(record=True) as record: - data = b"0" * 2000000 - for i in range(10): - future = c.submit(lambda x, y: x, data, i) - - assert not record + future = c.submit(lambda x: x + 1, b"0" * 10_000_000) @gen_cluster(client=True) @@ -6070,15 +6006,23 @@ def f(): @gen_cluster() -async def test_mixing_clients(s, a, b): +async def test_mixing_clients_same_scheduler(s, a, b): async with Client(s.address, asynchronous=True) as c1, Client( s.address, asynchronous=True ) as c2: future = c1.submit(inc, 1) - with pytest.raises(ValueError): - c2.submit(inc, future) + assert await c2.submit(inc, future) == 3 + assert not s.tasks - assert not c2.futures # Don't create Futures on second Client + +@gen_cluster() +async def test_mixing_clients_different_scheduler(s, a, b): + async with Scheduler(port=open_port()) as s2, Worker(s2.address) as w1, Client( + s.address, asynchronous=True + ) as c1, Client(s2.address, asynchronous=True) as c2: + future = c1.submit(inc, 1) + with pytest.raises(CancelledError): + await c2.submit(inc, future) @gen_cluster(client=True) @@ -6914,7 +6858,7 @@ async def test_annotations_workers(c, s, a, b): with dask.config.set(optimization__fuse__active=False): x = await x.persist() - assert all({"workers": (a.address,)} == ts.annotations for ts in s.tasks.values()) + assert all({"workers": [a.address]} == ts.annotations for ts in s.tasks.values()) assert all({a.address} == ts.worker_restrictions for ts in s.tasks.values()) assert a.data assert not b.data @@ -7021,7 +6965,7 @@ async def test_annotations_loose_restrictions(c, s, a, b): assert all({"fake"} == ts.host_restrictions for ts in s.tasks.values()) assert all( [ - {"workers": ("fake",), "allow_other_workers": True} == ts.annotations + {"workers": ["fake"], "allow_other_workers": True} == ts.annotations for ts in s.tasks.values() ] ) @@ -7115,6 +7059,19 @@ def nested_call(): assert nested_call() == upper_frame_code +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_computation_store_annotations(c, s, a): + # We do not want to store layer annotations + with dask.annotate(layer="foo"): + f = delayed(inc)(1) + + with dask.annotate(job="very-important"): + assert await c.compute(f) == 2 + + assert len(s.computations) == 1 + assert s.computations[0].annotations == {"job": "very-important"} + + def test_computation_object_code_dask_compute(client): da = pytest.importorskip("dask.array") x = da.ones((10, 10), chunks=(3, 3)) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index c8221053af..3e70082ffc 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -18,11 +18,12 @@ import cloudpickle import psutil import pytest -from tlz import concat, first, merge, valmap +from tlz import concat, first, merge from tornado.ioloop import IOLoop import dask from dask import delayed +from dask.highlevelgraph import HighLevelGraph, MaterializedLayer from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename from distributed import ( @@ -40,7 +41,9 @@ from distributed.compatibility import LINUX, MACOS, WINDOWS, PeriodicCallback from distributed.core import ConnectionPool, Status, clean_exception, connect, rpc from distributed.metrics import time +from distributed.protocol import serialize from distributed.protocol.pickle import dumps, loads +from 
distributed.protocol.serialize import ToPickle from distributed.scheduler import KilledWorker, MemoryState, Scheduler, WorkerState from distributed.utils import TimeoutError, wait_for from distributed.utils_test import ( @@ -718,19 +721,20 @@ async def test_server_listens_to_other_ops(s, a, b): assert ident["id"].lower().startswith("scheduler") -@gen_cluster() -async def test_remove_worker_from_scheduler(s, a, b): - dsk = {("x-%d" % i): (inc, i) for i in range(20)} - s.update_graph( - tasks=valmap(dumps_task, dsk), - keys=list(dsk), - dependencies={k: set() for k in dsk}, - ) +@gen_cluster(client=True) +async def test_remove_worker_from_scheduler(c, s, a, b): + """see also test_ready_remove_worker""" + ev = Event() + futs = c.map(lambda x, ev: ev.wait(), range(20), ev=ev) + while len(s.tasks) != len(futs): + await asyncio.sleep(0.01) assert a.address in s.stream_comms await s.remove_worker(address=a.address, stimulus_id="test") assert a.address not in s.workers - assert len(s.workers[b.address].processing) + len(s.queued) == len(dsk) + assert len(s.workers[b.address].processing) + len(s.queued) == len(futs) + await ev.set() + await c.gather(futs) @gen_cluster() @@ -926,47 +930,6 @@ async def test_delete(c, s, a): s.report_on_key(key=x.key) -@gen_cluster() -async def test_filtered_communication(s, a, b): - c = await connect(s.address) - f = await connect(s.address) - await c.write({"op": "register-client", "client": "c", "versions": {}}) - await f.write({"op": "register-client", "client": "f", "versions": {}}) - await c.read() - await f.read() - - assert set(s.client_comms) == {"c", "f"} - - await c.write( - { - "op": "update-graph", - "tasks": {"x": dumps_task((inc, 1)), "y": dumps_task((inc, "x"))}, - "dependencies": {"x": [], "y": ["x"]}, - "client": "c", - "keys": ["y"], - } - ) - - await f.write( - { - "op": "update-graph", - "tasks": { - "x": dumps_task((inc, 1)), - "z": dumps_task((operator.add, "x", 10)), - }, - "dependencies": {"x": [], "z": ["x"]}, - "client": "f", - "keys": ["z"], - } - ) - (msg,) = await c.read() - assert msg["op"] == "key-in-memory" - assert msg["key"] == "y" - (msg,) = await f.read() - assert msg["op"] == "key-in-memory" - assert msg["key"] == "z" - - def test_dumps_function(): a = dumps_function(inc) assert cloudpickle.loads(a)(10) == 11 @@ -997,23 +960,21 @@ def f(x, y=2): @pytest.mark.parametrize("worker_saturation", [1.0, float("inf")]) -@gen_cluster() -async def test_ready_remove_worker(s, a, b, worker_saturation): +@gen_cluster(client=True) +async def test_ready_remove_worker(c, s, a, b, worker_saturation): + """see also test_remove_worker_from_scheduler""" s.WORKER_SATURATION = worker_saturation - s.update_graph( - tasks={"x-%d" % i: dumps_task((inc, i)) for i in range(20)}, - keys=["x-%d" % i for i in range(20)], - client="client", - dependencies={"x-%d" % i: [] for i in range(20)}, - ) + ev = Event() + futs = c.map(lambda x, ev: ev.wait(), range(20), ev=ev) + while len(s.tasks) != len(futs): + await asyncio.sleep(0.01) if s.WORKER_SATURATION == 1: cmp = operator.eq elif math.isinf(s.WORKER_SATURATION): cmp = operator.gt else: pytest.fail(f"{s.WORKER_SATURATION=}, must be 1 or inf") - assert all(cmp(len(w.processing), w.nthreads) for w in s.workers.values()), ( list(s.workers.values()), s.WORKER_SATURATION, @@ -1021,9 +982,7 @@ async def test_ready_remove_worker(s, a, b, worker_saturation): assert sum(len(w.processing) for w in s.workers.values()) + len(s.queued) == len( s.tasks ) - await s.remove_worker(address=a.address, stimulus_id="test") - assert 
set(s.workers) == {b.address} assert all(cmp(len(w.processing), w.nthreads) for w in s.workers.values()), ( list(s.workers.values()), @@ -1032,6 +991,7 @@ async def test_ready_remove_worker(s, a, b, worker_saturation): assert sum(len(w.processing) for w in s.workers.values()) + len(s.queued) == len( s.tasks ) + await ev.set() @pytest.mark.slow @@ -1371,15 +1331,33 @@ async def test_file_descriptors_dont_leak(s): @gen_cluster() async def test_update_graph_culls(s, a, b): - s.update_graph( - tasks={ - "x": dumps_task((inc, 1)), - "y": dumps_task((inc, "x")), - "z": dumps_task((inc, 2)), + # This is a rather low level API but the fact that update_graph actually + # culls is worth testing and hard to do so with high level user API. Most + # but not all HLGs are implementing culling themselves already, i.e. a graph + # like the one written here will rarely exist in reality. It's worth to + # consider dropping this from the scheduler iff graph materialization + # actually ensure this + dsk = HighLevelGraph( + layers={ + "foo": MaterializedLayer( + { + "x": dumps_task((inc, 1)), + "y": dumps_task((inc, "x")), + "z": dumps_task((inc, 2)), + } + ) }, + dependencies={"foo": set()}, + ) + + header, frames = serialize(ToPickle(dsk), on_error="raise") + s.update_graph( + graph_header=header, + graph_frames=frames, keys=["y"], - dependencies={"y": "x", "x": [], "z": []}, client="client", + internal_priority={k: 0 for k in "xyz"}, + submitting_task=None, ) assert "z" not in s.tasks diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 55b004f68e..0cfd66ada3 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -262,7 +262,6 @@ def blocked_task(x, lock): await steal.start() # A is still blocked by executing task f-1 so this can only pass if # workstealing moves the tasks to B - await asyncio.sleep(5) await c.gather(more_tasks) assert len(b.data) == 10 await first diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index bb0d751d28..e85fd0b34e 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2756,6 +2756,7 @@ async def test_forget_dependents_after_release(c, s, a): assert fut2.key not in {d.key for d in a.state.tasks[fut.key].dependents} +@pytest.mark.filterwarnings("ignore:Sending large graph of size") @pytest.mark.filterwarnings("ignore:Large object of size") @gen_cluster(client=True) async def test_steal_during_task_deserialization(c, s, a, b, monkeypatch): @@ -2788,24 +2789,6 @@ async def custom_worker_offload(func, *args): await fut3 -@gen_cluster(client=True) -async def test_run_spec_deserialize_fail(c, s, a, b): - class F: - def __call__(self): - pass - - def __reduce__(self): - return lambda: 1 / 0, () - - with captured_logger("distributed.worker") as logger: - fut = c.submit(F()) - assert isinstance(await fut.exception(), ZeroDivisionError) - - logvalue = logger.getvalue() - assert "Could not deserialize task" in logvalue - assert "return lambda: 1 / 0, ()" in logvalue - - @gen_cluster(client=True, config=NO_AMM) async def test_acquire_replicas(c, s, a, b): fut = c.submit(inc, 1, workers=[a.address])
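
Usage notes (illustrative sketches, not part of the patch itself):

The central API change above is the new keyword-only signature of SchedulerPlugin.update_graph, which now receives the fully resolved annotations, priorities, and dependencies instead of the old restrictions mapping. A minimal plugin written against that hook might look like the sketch below; only the signature is taken from the patch, while the class name GraphLogger and the print-based behaviour are purely hypothetical.

from __future__ import annotations

from typing import Any

from distributed.diagnostics.plugin import SchedulerPlugin


class GraphLogger(SchedulerPlugin):
    """Hypothetical plugin exercising the new keyword-only update_graph hook."""

    def update_graph(
        self,
        scheduler,
        *,
        client: str,
        keys: set[str],
        tasks: list[str],
        annotations: dict[str, dict[str, Any]],
        priority: dict[str, tuple[int | float, ...]],
        dependencies: dict[str, set],
        **kwargs: Any,
    ) -> None:
        # `tasks` holds the stringified keys of every task added by this
        # update_graph call; `keys` are only the ones the client asked for.
        print(f"{client} submitted {len(tasks)} tasks, wants {sorted(keys)}")
        # Annotations arrive fully resolved per key, e.g. {"retries": {"f2": 3}}.
        for name, per_key in annotations.items():
            print(f"annotation {name!r} set on {len(per_key)} task(s)")


# Registration is unchanged, e.g. scheduler.add_plugin(GraphLogger(), name="graph-logger"),
# as done in the tests above.

The client-side half of the change replaces the per-layer __dask_distributed_pack__ round trip with pickling the whole HighLevelGraph and shipping it in a single "update-graph" message. A rough sketch of what now goes over the wire, assuming a toy inc graph (the scheduler deserializes the frames and materializes and re-serializes the tasks itself in Scheduler.materialize_graph):

from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

from distributed.protocol import serialize
from distributed.protocol.serialize import ToPickle


def inc(x):
    return x + 1


hlg = HighLevelGraph(
    layers={"foo": MaterializedLayer({"x": (inc, 1), "y": (inc, "x")})},
    dependencies={"foo": set()},
)

# The client pickles the graph wholesale; a UserWarning is emitted once the
# serialized size exceeds 10 MB ("Sending large graph of size ...").
header, frames = serialize(ToPickle(hlg), on_error="raise")

msg = {
    "op": "update-graph",
    "graph_header": header,
    "graph_frames": frames,
    "keys": ["y"],
    "internal_priority": None,
    "submitting_task": None,
    "annotations": ToPickle({}),
    # The real client message also carries fifo_timeout, actors, and code.
}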