Use pickle for graph submissions from client to scheduler #7564

Merged (27 commits, Mar 27, 2023)
28 changes: 20 additions & 8 deletions distributed/client.py
@@ -1930,7 +1930,7 @@ def submit(
[skey],
workers=workers,
allow_other_workers=allow_other_workers,
priority={skey: 0},
internal_priority={skey: 0},
user_priority=priority,
resources=resources,
retries=retries,
@@ -2134,7 +2134,7 @@ def map(
keys,
workers=workers,
allow_other_workers=allow_other_workers,
priority=internal_priority,
internal_priority=internal_priority,
resources=resources,
retries=retries,
user_priority=priority,
@@ -3035,7 +3035,7 @@ def _graph_to_futures(
keys,
workers=None,
allow_other_workers=None,
priority=None,
internal_priority=None,
user_priority=0,
resources=None,
retries=None,
@@ -3071,21 +3071,33 @@ def _graph_to_futures(

# Pack the high level graph before sending it to the scheduler
keyset = set(keys)
dsk = dsk.__dask_distributed_pack__(self, keyset, annotations)

# Create futures before sending graph (helps avoid contention)
futures = {key: Future(key, self, inform=False) for key in keyset}

# Circular import
from distributed.protocol import serialize
from distributed.protocol.serialize import ToPickle

header, frames = serialize(ToPickle(dsk), on_error="raise")
nbytes = len(header) + sum(map(len, frames))
if nbytes > 10_000_000:
warnings.warn(
f"Sending large graph of size {format_bytes(nbytes)}.\n"
"This may cause some slowdown.\n"
"Consider scattering data ahead of time and using futures."
)
self._send_to_scheduler(
{
"op": "update-graph-hlg",
"hlg": dsk,
"op": "update-graph",
"graph_header": header,
"graph_frames": frames,
"keys": list(map(stringify, keys)),
"priority": priority,
"internal_priority": internal_priority,
"submitting_task": getattr(thread_state, "key", None),
"fifo_timeout": fifo_timeout,
"actors": actors,
"code": self._get_computation_code(),
"annotations": ToPickle(annotations),
}
)
return futures
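Context for reviewers: with this change the client serializes the whole graph as a single pickled payload and warns once it exceeds 10 MB. A minimal sketch of that pattern, mirroring the new `_graph_to_futures` code (the helper name `estimate_graph_payload` is illustrative and not part of this PR):

from dask.utils import format_bytes

from distributed.protocol import serialize
from distributed.protocol.serialize import ToPickle


def estimate_graph_payload(dsk):
    # Serialize the graph the same way Client._graph_to_futures now does:
    # wrap it in ToPickle so the whole object travels as one pickle payload.
    header, frames = serialize(ToPickle(dsk), on_error="raise")
    nbytes = len(header) + sum(map(len, frames))
    if nbytes > 10_000_000:
        print(f"Large graph of size {format_bytes(nbytes)}; "
              "consider scattering data ahead of time and using futures")
    return nbytes

Large payloads usually come from big literal values embedded in the task definitions; pre-scattering them (for example, data_future = client.scatter(big_data)) and passing the resulting futures keeps the submitted graph small.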
2 changes: 1 addition & 1 deletion distributed/diagnostics/graph_layout.py
@@ -46,7 +46,7 @@ def __init__(self, scheduler):
)

def update_graph(
self, scheduler, dependencies=None, priority=None, tasks=None, **kwargs
self, scheduler, *, dependencies=None, priority=None, tasks=None, **kwargs
):
stack = sorted(tasks, key=lambda k: priority.get(k, 0), reverse=True)
while stack:
38 changes: 36 additions & 2 deletions distributed/diagnostics/plugin.py
@@ -74,11 +74,45 @@ async def close(self) -> None:
def update_graph(
self,
scheduler: Scheduler,
*,
client: str,
keys: set[str],
restrictions: dict[str, float],
tasks: list[str],
annotations: dict[str, dict[str, Any]],
priority: dict[str, tuple[int | float, ...]],
dependencies: dict[str, set],
**kwargs: Any,
) -> None:
"""Run when a new graph / tasks enter the scheduler"""
"""Run when a new graph / tasks enter the scheduler

Parameters
----------
scheduler:
The `Scheduler` instance.
client:
The unique Client id.
keys:
The keys the Client is interested in when calling `update_graph`.
tasks:
The keys of the tasks that are part of the submitted graph.
annotations:
Fully resolved annotations as applied to the tasks in the format::

{
"annotation": {
"key": "value,
...
},
...
}
priority:
The priorities calculated for and assigned to the tasks.
dependencies:
A mapping from a task key to the set of keys it depends on.
**kwargs:
It is recommended that plugins accept additional keyword arguments
to remain compatible if new parameters are added in the future.
"""

def restart(self, scheduler: Scheduler) -> None:
"""Run when the scheduler restarts itself"""
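To illustrate the hook documented above, here is a minimal sketch of a SchedulerPlugin that records what each graph submission asks for; the plugin name and attribute are illustrative, and registration goes through the usual Client.register_scheduler_plugin path:

from distributed.diagnostics.plugin import SchedulerPlugin


class GraphLogger(SchedulerPlugin):
    """Record the requested keys of every graph submitted to the scheduler."""

    def __init__(self):
        self.submissions = []

    def update_graph(
        self, scheduler, *, client, keys, tasks, annotations, priority, dependencies, **kwargs
    ):
        # keys: what the submitting client asked to compute
        # tasks: all task keys contained in the submitted graph
        self.submissions.append((client, set(keys), len(tasks)))


# Typical registration from the client side:
# client.register_scheduler_plugin(GraphLogger(), name="graph-logger")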
94 changes: 94 additions & 0 deletions distributed/diagnostics/tests/test_scheduler_plugin.py
@@ -243,3 +243,97 @@ async def close(self):
assert "BEFORE_CLOSE" in text
text = logger.getvalue()
assert "AFTER_CLOSE" in text


@gen_cluster(client=True)
async def test_update_graph_hook_simple(c, s, a, b):
class UpdateGraph(SchedulerPlugin):
def __init__(self) -> None:
self.success = False

def update_graph( # type: ignore
self,
scheduler,
client,
keys,
tasks,
annotations,
priority,
dependencies,
**kwargs,
) -> None:
assert scheduler is s
assert client == c.id
# If new parameters are added we should add a test
assert not kwargs
assert keys == {"foo"}
assert tasks == ["foo"]
assert annotations == {}
assert len(priority) == 1
assert isinstance(priority["foo"], tuple)
assert dependencies == {"foo": set()}
self.success = True

plugin = UpdateGraph()
s.add_plugin(plugin, name="update-graph")

await c.submit(inc, 5, key="foo")
assert plugin.success


import dask
from dask import delayed


@gen_cluster(client=True)
async def test_update_graph_hook_complex(c, s, a, b):
class UpdateGraph(SchedulerPlugin):
def __init__(self) -> None:
self.success = False

def update_graph( # type: ignore
self,
scheduler,
client,
keys,
tasks,
annotations,
priority,
dependencies,
**kwargs,
) -> None:
assert scheduler is s
assert client == c.id
# If new parameters are added we should add a test
assert not kwargs
assert keys == {"sum"}
assert set(tasks) == {"sum", "f1", "f3", "f2"}
assert annotations == {
"global_annot": {k: 24 for k in tasks},
"layer": {"f2": "explicit"},
"len_key": {"f3": 2},
"priority": {"f2": 13},
}
assert len(priority) == len(tasks), priority
assert priority["f2"][0] == -13
for k in keys:
assert k in dependencies
assert dependencies["f1"] == set()
assert dependencies["sum"] == {"f1", "f3"}

self.success = True

plugin = UpdateGraph()
s.add_plugin(plugin, name="update-graph")
del_inc = delayed(inc)
f1 = del_inc(1, dask_key_name="f1")
with dask.annotate(layer="explicit", priority=13):
f2 = del_inc(2, dask_key_name="f2")
with dask.annotate(len_key=lambda x: len(x)):
f3 = del_inc(f2, dask_key_name="f3")

f4 = delayed(sum)([f1, f3], dask_key_name="sum")

with dask.annotate(global_annot=24):
await c.compute(f4)
assert plugin.success
71 changes: 11 additions & 60 deletions distributed/diagnostics/tests/test_widgets.py
@@ -7,11 +7,9 @@

import pytest
from packaging.version import parse as parse_version
from tlz import valmap

from distributed.client import wait
from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws
from distributed.worker import dumps_task

ipywidgets = pytest.importorskip("ipywidgets")

@@ -141,38 +139,6 @@ async def test_multi_progressbar_widget(c, s, a, b):
assert sorted(capacities, reverse=True) == capacities


@mock_widget()
@gen_cluster()
async def test_multi_progressbar_widget_after_close(s, a, b):
s.update_graph(
tasks=valmap(
dumps_task,
{
"x-1": (inc, 1),
"x-2": (inc, "x-1"),
"x-3": (inc, "x-2"),
"y-1": (dec, "x-3"),
"y-2": (dec, "y-1"),
"e": (throws, "y-2"),
"other": (inc, 123),
},
),
keys=["e"],
dependencies={
"x-2": {"x-1"},
"x-3": {"x-2"},
"y-1": {"x-3"},
"y-2": {"y-1"},
"e": {"y-2"},
},
)

p = MultiProgressWidget(["x-1", "x-2", "x-3"], scheduler=s.address)
await p.listen()

assert "x" in p.bars


@mock_widget()
def test_values(client):
L = [client.submit(inc, i) for i in range(5)]
@@ -232,32 +198,17 @@ def test_progressbar_cancel(client):


@mock_widget()
@gen_cluster()
async def test_multibar_complete(s, a, b):
s.update_graph(
tasks=valmap(
dumps_task,
{
"x-1": (inc, 1),
"x-2": (inc, "x-1"),
"x-3": (inc, "x-2"),
"y-1": (dec, "x-3"),
"y-2": (dec, "y-1"),
"e": (throws, "y-2"),
"other": (inc, 123),
},
),
keys=["e"],
dependencies={
"x-2": {"x-1"},
"x-3": {"x-2"},
"y-1": {"x-3"},
"y-2": {"y-1"},
"e": {"y-2"},
},
)

p = MultiProgressWidget(["e"], scheduler=s.address, complete=True)
@gen_cluster(client=True)
async def test_multibar_complete(c, s, a, b):
x1 = c.submit(inc, 1, key="x-1")
x2 = c.submit(inc, x1, key="x-2")
x3 = c.submit(inc, x2, key="x-3")
y1 = c.submit(dec, x3, key="y-1")
y2 = c.submit(dec, y1, key="y-2")
e = c.submit(throws, y2, key="e")
other = c.submit(inc, 123, key="other")

p = MultiProgressWidget([e.key], scheduler=s.address, complete=True)
await p.listen()

assert p._last_response["all"] == {"x": 3, "y": 2, "e": 1}
12 changes: 8 additions & 4 deletions distributed/protocol/serialize.py
@@ -342,6 +342,7 @@ def serialize( # type: ignore[no-untyped-def]
return headers, frames

tb = ""
exc = None

for name in serializers:
dumps, _, wants_context = families[name]
@@ -351,11 +352,14 @@
return header, frames
except NotImplementedError:
continue
except Exception:
except Exception as e:
exc = e
tb = traceback.format_exc()
break

msg = f"Could not serialize object of type {type(x).__name__}"
type_x = type(x)
if isinstance(x, (ToPickle, Serialize)):
type_x = type(x.data)
msg = f"Could not serialize object of type {type_x.__name__}"
if on_error == "message":
txt_frames = [msg]
if tb:
@@ -365,7 +369,7 @@

return {"serializer": "error"}, frames
elif on_error == "raise":
raise TypeError(msg, str(x)[:10000])
raise TypeError(msg, str(x)[:10000]) from exc
else: # pragma: nocover
raise ValueError(f"{on_error=}; expected 'message' or 'raise'")

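To make the effect of the error-handling change concrete, a small sketch of the intended behavior (assuming the default serializers, which fall back to pickle): a failing ToPickle payload now raises a TypeError that names the wrapped object's type and chains the original exception.

from distributed.protocol import serialize
from distributed.protocol.serialize import ToPickle


class Unpicklable:
    def __reduce__(self):
        raise RuntimeError("refusing to pickle")


try:
    serialize(ToPickle(Unpicklable()), on_error="raise")
except TypeError as err:
    # The message reads "Could not serialize object of type Unpicklable"
    # (previously it named the ToPickle wrapper), and err.__cause__ is the
    # RuntimeError raised by __reduce__.
    print(err)
    print(repr(err.__cause__))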
1 change: 1 addition & 0 deletions distributed/protocol/tests/test_highlevelgraph.py
@@ -26,6 +26,7 @@ def add(x, y, z, extra_arg):

y = c.submit(lambda x: x, 2)
z = c.submit(lambda x: x, 3)
xx = await c.submit(lambda x: x + 1, y)
x = da.blockwise(
add,
"x",