diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 52d6ea1ba7..dad7350568 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -13,7 +13,7 @@ from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector from distributed.comm.registry import Backend, backends -from distributed.protocol import nested_deserialize +from distributed.protocol.serialize import _nested_deserialize from distributed.utils import get_ip, is_python_shutting_down logger = logging.getLogger(__name__) @@ -218,8 +218,7 @@ async def read(self, deserializers="ignored"): self._finalizer.detach() raise CommClosedError() - if self.deserialize: - msg = nested_deserialize(msg) + msg = _nested_deserialize(msg, self.deserialize) return msg async def write(self, msg, serializers=None, on_error=None): diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index ad43feb493..4aa87deadd 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -6,8 +6,10 @@ from distributed.protocol.core import decompress, dumps, loads, maybe_compress, msgpack from distributed.protocol.cuda import cuda_deserialize, cuda_serialize from distributed.protocol.serialize import ( + Pickled, Serialize, Serialized, + ToPickle, dask_deserialize, dask_serialize, deserialize, diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index eb74a61407..fe2d6879a4 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -3,6 +3,7 @@ import codecs import importlib import traceback +import warnings from array import array from enum import Enum from functools import partial @@ -621,6 +622,14 @@ def __ne__(self, other): def nested_deserialize(x): + warnings.warn( + "nested_deserialize is deprecated and will be removed in a future release.", + DeprecationWarning, + ) + return _nested_deserialize(x, emulate_deserialize=True) + + +def _nested_deserialize(x, emulate_deserialize=True): """ Replace all Serialize and Serialized values nested in *x* with the original values. Returns a copy of *x*. @@ -637,10 +646,13 @@ def replace_inner(x): typ = type(v) if typ is dict or typ is list: x[k] = replace_inner(v) - elif typ is Serialize: + if emulate_deserialize: + if typ is Serialize: + x[k] = v.data + elif typ is Serialized: + x[k] = deserialize(v.header, v.frames) + if typ is ToPickle: x[k] = v.data - elif typ is Serialized: - x[k] = deserialize(v.header, v.frames) elif type(x) is list: x = list(x) @@ -648,10 +660,13 @@ def replace_inner(x): typ = type(v) if typ is dict or typ is list: x[k] = replace_inner(v) - elif typ is Serialize: + if emulate_deserialize: + if typ is Serialize: + x[k] = v.data + elif typ is Serialized: + x[k] = deserialize(v.header, v.frames) + if typ is ToPickle: x[k] = v.data - elif typ is Serialized: - x[k] = deserialize(v.header, v.frames) return x diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 8cc85b5db4..d91d160d02 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -20,12 +20,12 @@ from distributed.protocol import ( Serialize, Serialized, + ToPickle, dask_serialize, deserialize, deserialize_bytes, dumps, loads, - nested_deserialize, register_serialization, register_serialization_family, serialize, @@ -35,6 +35,7 @@ ) from distributed.protocol.serialize import ( _is_msgpack_serializable, + _nested_deserialize, check_dask_serializable, ) from distributed.utils import ensure_memoryview, nbytes @@ -166,12 +167,24 @@ def test_nested_deserialize(): "x": [to_serialize(123), to_serialize(456), 789], "y": {"a": ["abc", Serialized(*serialize("def"))], "b": b"ghi"}, } + + x_orig = copy.deepcopy(x) + assert _nested_deserialize(x, emulate_deserialize=False) == x_orig + + assert x == x_orig # x wasn't mutated + x["topickle"] = ToPickle(1) + x["topickle_nested"] = [1, ToPickle(2)] x_orig = copy.deepcopy(x) + assert (out := _nested_deserialize(x, emulate_deserialize=False)) != x_orig + assert out["topickle"] == 1 + assert out["topickle_nested"] == [1, 2] - assert nested_deserialize(x) == { + assert _nested_deserialize(x) == { "op": "update", "x": [123, 456, 789], "y": {"a": ["abc", "def"], "b": b"ghi"}, + "topickle": 1, + "topickle_nested": [1, 2], } assert x == x_orig # x wasn't mutated diff --git a/distributed/scheduler.py b/distributed/scheduler.py index dc87640677..ec63548009 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4676,9 +4676,6 @@ async def update_graph( annotations: dict | None = None, stimulus_id: str | None = None, ) -> None: - # FIXME: Apparently empty dicts arrive as a ToPickle object - if isinstance(annotations, ToPickle): - annotations = annotations.data # type: ignore[unreachable] start = time() try: try: diff --git a/distributed/shuffle/tests/utils.py b/distributed/shuffle/tests/utils.py index 7191be5d05..1d4fb7319a 100644 --- a/distributed/shuffle/tests/utils.py +++ b/distributed/shuffle/tests/utils.py @@ -22,7 +22,7 @@ def __init__(self, shuffle: ShuffleRun): def __getattr__(self, key): async def _(**kwargs): - from distributed.protocol.serialize import nested_deserialize + from distributed.protocol.serialize import _nested_deserialize method_name = key.replace("shuffle_", "") kwargs.pop("shuffle_id", None) @@ -30,7 +30,7 @@ async def _(**kwargs): # TODO: This is a bit awkward. At some point the arguments are # already getting wrapped with a `Serialize`. We only want to unwrap # here. - kwargs = nested_deserialize(kwargs) + kwargs = _nested_deserialize(kwargs) meth = getattr(self.shuffle, method_name) return await meth(**kwargs)