diff --git a/distributed/shuffle/_pickle.py b/distributed/shuffle/_pickle.py
index 4db706565be..c22da2170be 100644
--- a/distributed/shuffle/_pickle.py
+++ b/distributed/shuffle/_pickle.py
@@ -1,11 +1,16 @@
 from __future__ import annotations
 
 import pickle
-from collections.abc import Iterator
-from typing import Any
+from collections.abc import Iterable, Iterator
+from typing import TYPE_CHECKING, Any
+
+from toolz import first
 
 from distributed.protocol.utils import pack_frames_prelude, unpack_frames
 
+if TYPE_CHECKING:
+    import pandas as pd
+
 
 def pickle_bytelist(obj: object, prelude: bool = True) -> list[pickle.PickleBuffer]:
     """Variant of :func:`serialize_bytelist`, that doesn't support compression, locally
@@ -39,3 +44,76 @@ def unpickle_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]:
         if remainder.nbytes == 0:
             break
         b = remainder
+
+
+def pickle_dataframe_shard(
+    input_part_id: int,
+    shard: pd.DataFrame,
+) -> list[pickle.PickleBuffer]:
+    """Optimized pickler for pandas DataFrames. Discards all unnecessary metadata
+    (like the columns header).
+
+    Parameters
+    ----------
+    input_part_id:
+        ID of the input partition the shard originates from
+    shard:
+        DataFrame shard to pickle
+    """
+    return pickle_bytelist(
+        (input_part_id, shard.index, *shard._mgr.blocks), prelude=False
+    )
+
+
+def unpickle_and_concat_dataframe_shards(
+    parts: Iterable[Any], meta: pd.DataFrame
+) -> pd.DataFrame:
+    """Optimized unpickler for pandas DataFrames.
+
+    Parameters
+    ----------
+    parts:
+        output of ``unpickle_bytestream(b)``, where b is the memory-mapped blob of
+        pickled data which is the concatenation of the outputs of
+        :func:`pickle_dataframe_shard` in arbitrary order
+    meta:
+        DataFrame header
+
+    Returns
+    -------
+    Reconstructed output shard, sorted by input partition ID
+
+    **Roundtrip example**
+
+    .. code-block:: python
+
+        import random
+        import pandas as pd
+
+        df = pd.DataFrame(...)  # Input partition
+        meta = df.iloc[:0]
+        shards = df.iloc[0:10], df.iloc[10:20], ...
+        frames = [pickle_dataframe_shard(i, shard) for i, shard in enumerate(shards)]
+        random.shuffle(frames)  # Simulate the frames arriving in arbitrary order
+        frames = [f for fs in frames for f in fs]  # Flatten
+        blob = bytearray(b"".join(frames))  # Simulate disk roundtrip
+        parts = unpickle_bytestream(blob)
+        df2 = unpickle_and_concat_dataframe_shards(parts, meta)
+
+    """
+    import pandas as pd
+    from pandas.core.internals import BlockManager
+
+    # [(input_part_id, index, *blocks), ...]
+    parts = sorted(parts, key=first)
+    shards = []
+    for _, idx, *blocks in parts:
+        axes = [meta.columns, idx]
+        df = pd.DataFrame._from_mgr(  # type: ignore[attr-defined]
+            BlockManager(blocks, axes, verify_integrity=False), axes
+        )
+        shards.append(df)
+
+    # Actually load memory-mapped buffers into memory and close the file
+    # descriptors
+    return pd.concat(shards, copy=True)
diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index 8f2fd02fbd4..c9820046d95 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -17,7 +17,6 @@
 from pickle import PickleBuffer
 from typing import TYPE_CHECKING, Any
 
-from toolz import first
 from tornado.ioloop import IOLoop
 
 import dask
@@ -42,7 +41,10 @@
 )
 from distributed.shuffle._exceptions import DataUnavailable
 from distributed.shuffle._limiter import ResourceLimiter
-from distributed.shuffle._pickle import pickle_bytelist
+from distributed.shuffle._pickle import (
+    pickle_dataframe_shard,
+    unpickle_and_concat_dataframe_shards,
+)
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 from distributed.utils import nbytes
 
@@ -335,9 +337,7 @@ def split_by_worker(
         assert isinstance(output_part_id, int)
         if drop_column:
             del part[column]
-        frames = pickle_bytelist(
-            (input_part_id, part.index, *part._mgr.blocks), prelude=False
-        )
+        frames = pickle_dataframe_shard(input_part_id, part)
         out[worker_for[output_part_id]].append((output_part_id, frames))
 
     return {k: (input_part_id, v) for k, v in out.items()}
@@ -516,9 +516,6 @@ def _get_output_partition(
         key: Key,
         **kwargs: Any,
     ) -> pd.DataFrame:
-        import pandas as pd
-        from pandas.core.internals import BlockManager
-
         meta = self.meta.copy()
         if self.drop_column:
             meta = self.meta.drop(columns=self.column)
@@ -528,19 +525,7 @@ def _get_output_partition(
         except DataUnavailable:
             return meta
 
-        # [(input_part_id, index, *blocks), ...]
-        parts = sorted(parts, key=first)
-        shards = []
-        for _, idx, *blocks in parts:
-            axes = [meta.columns, idx]
-            df = pd.DataFrame._from_mgr(  # type: ignore[attr-defined]
-                BlockManager(blocks, axes, verify_integrity=False), axes
-            )
-            shards.append(df)
-
-        # Actually load memory-mapped buffers into memory and close the file
-        # descriptors
-        return pd.concat(shards, copy=True)
+        return unpickle_and_concat_dataframe_shards(parts, meta)
 
     def _get_assigned_worker(self, id: int) -> str:
         return self.worker_for[id]
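
For reference, a minimal usage sketch of ``unpickle_and_concat_dataframe_shards`` (illustrative only, with made-up data; in the real shuffle path ``parts`` comes from ``unpickle_bytestream`` over a memory-mapped spill file rather than being built by hand):

# Illustrative sketch -- not the actual shuffle code path. It builds the
# (input_part_id, index, *blocks) tuples in memory to show the expected shape
# of ``parts`` and the sort-by-input-partition-ID behaviour of the helper.
import pandas as pd

from distributed.shuffle._pickle import unpickle_and_concat_dataframe_shards

df = pd.DataFrame({"a": range(20), "b": range(20, 40)})  # input partition
meta = df.iloc[:0]  # header-only DataFrame
shards = [df.iloc[0:10], df.iloc[10:20]]

# Same payload that pickle_dataframe_shard pickles for each shard
parts = [(i, shard.index, *shard._mgr.blocks) for i, shard in enumerate(shards)]

# Shards may arrive in arbitrary order; the helper re-sorts them by part ID
out = unpickle_and_concat_dataframe_shards(reversed(parts), meta)
pd.testing.assert_frame_equal(out, df)

Because the pickled shards carry no columns header, ``meta`` has to supply it; the final ``pd.concat(shards, copy=True)`` inside the helper detaches the result from any memory-mapped buffers so their file descriptors can be closed.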