Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure inproc properly emulates serialization protocol #8622

Merged
merged 1 commit into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions distributed/comm/inproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector
from distributed.comm.registry import Backend, backends
from distributed.protocol import nested_deserialize
from distributed.protocol.serialize import _nested_deserialize
from distributed.utils import get_ip, is_python_shutting_down

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -218,8 +218,7 @@ async def read(self, deserializers="ignored"):
self._finalizer.detach()
raise CommClosedError()

if self.deserialize:
msg = nested_deserialize(msg)
msg = _nested_deserialize(msg, self.deserialize)
return msg

async def write(self, msg, serializers=None, on_error=None):
Expand Down
2 changes: 2 additions & 0 deletions distributed/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from distributed.protocol.core import decompress, dumps, loads, maybe_compress, msgpack
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.protocol.serialize import (
Pickled,
Serialize,
Serialized,
ToPickle,
dask_deserialize,
dask_serialize,
deserialize,
Expand Down
27 changes: 21 additions & 6 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import codecs
import importlib
import traceback
import warnings
from array import array
from enum import Enum
from functools import partial
Expand Down Expand Up @@ -621,6 +622,14 @@ def __ne__(self, other):


def nested_deserialize(x):
warnings.warn(
"nested_deserialize is deprecated and will be removed in a future release.",
DeprecationWarning,
)
return _nested_deserialize(x, emulate_deserialize=True)


def _nested_deserialize(x, emulate_deserialize=True):
"""
Replace all Serialize and Serialized values nested in *x*
with the original values. Returns a copy of *x*.
Expand All @@ -637,21 +646,27 @@ def replace_inner(x):
typ = type(v)
if typ is dict or typ is list:
x[k] = replace_inner(v)
elif typ is Serialize:
if emulate_deserialize:
if typ is Serialize:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)
if typ is ToPickle:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)

elif type(x) is list:
x = list(x)
for k, v in enumerate(x):
typ = type(v)
if typ is dict or typ is list:
x[k] = replace_inner(v)
elif typ is Serialize:
if emulate_deserialize:
if typ is Serialize:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)
if typ is ToPickle:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)

return x

Expand Down
17 changes: 15 additions & 2 deletions distributed/protocol/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from distributed.protocol import (
Serialize,
Serialized,
ToPickle,
dask_serialize,
deserialize,
deserialize_bytes,
dumps,
loads,
nested_deserialize,
register_serialization,
register_serialization_family,
serialize,
Expand All @@ -35,6 +35,7 @@
)
from distributed.protocol.serialize import (
_is_msgpack_serializable,
_nested_deserialize,
check_dask_serializable,
)
from distributed.utils import ensure_memoryview, nbytes
Expand Down Expand Up @@ -166,12 +167,24 @@ def test_nested_deserialize():
"x": [to_serialize(123), to_serialize(456), 789],
"y": {"a": ["abc", Serialized(*serialize("def"))], "b": b"ghi"},
}

x_orig = copy.deepcopy(x)
assert _nested_deserialize(x, emulate_deserialize=False) == x_orig

assert x == x_orig # x wasn't mutated
x["topickle"] = ToPickle(1)
x["topickle_nested"] = [1, ToPickle(2)]
x_orig = copy.deepcopy(x)
assert (out := _nested_deserialize(x, emulate_deserialize=False)) != x_orig
assert out["topickle"] == 1
assert out["topickle_nested"] == [1, 2]

assert nested_deserialize(x) == {
assert _nested_deserialize(x) == {
"op": "update",
"x": [123, 456, 789],
"y": {"a": ["abc", "def"], "b": b"ghi"},
"topickle": 1,
"topickle_nested": [1, 2],
}
assert x == x_orig # x wasn't mutated

Expand Down
3 changes: 0 additions & 3 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4676,9 +4676,6 @@ async def update_graph(
annotations: dict | None = None,
stimulus_id: str | None = None,
) -> None:
# FIXME: Apparently empty dicts arrive as a ToPickle object
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My initial diagnose was wrong. This has nothing to do with empty dicts

if isinstance(annotations, ToPickle):
annotations = annotations.data # type: ignore[unreachable]
start = time()
try:
try:
Expand Down
4 changes: 2 additions & 2 deletions distributed/shuffle/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ def __init__(self, shuffle: ShuffleRun):

def __getattr__(self, key):
async def _(**kwargs):
from distributed.protocol.serialize import nested_deserialize
from distributed.protocol.serialize import _nested_deserialize

method_name = key.replace("shuffle_", "")
kwargs.pop("shuffle_id", None)
kwargs.pop("run_id", None)
# TODO: This is a bit awkward. At some point the arguments are
# already getting wrapped with a `Serialize`. We only want to unwrap
# here.
kwargs = nested_deserialize(kwargs)
kwargs = _nested_deserialize(kwargs)
meth = getattr(self.shuffle, method_name)
return await meth(**kwargs)

Expand Down
Loading