diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py new file mode 100644 index 0000000000..fe8a38f5f0 --- /dev/null +++ b/distributed/shuffle/_core.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +import abc +import asyncio +import contextlib +import itertools +import time +from collections import defaultdict +from collections.abc import Callable, Iterator +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, TypeVar + +from distributed.core import PooledRPCCall +from distributed.exceptions import Reschedule +from distributed.protocol import to_serialize +from distributed.shuffle._comms import CommShardsBuffer +from distributed.shuffle._disk import DiskShardsBuffer +from distributed.shuffle._exceptions import ShuffleClosedError +from distributed.shuffle._limiter import ResourceLimiter + +if TYPE_CHECKING: + import pandas as pd + from typing_extensions import TypeAlias + + # avoid circular dependencies + from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin + +_T_partition_id = TypeVar("_T_partition_id") +_T_partition_type = TypeVar("_T_partition_type") +_T = TypeVar("_T") + +NDIndex: TypeAlias = tuple[int, ...] + +ShuffleId = NewType("ShuffleId", str) + + +class ShuffleRun(Generic[_T_partition_id, _T_partition_type]): + def __init__( + self, + id: ShuffleId, + run_id: int, + output_workers: set[str], + local_address: str, + directory: str, + executor: ThreadPoolExecutor, + rpc: Callable[[str], PooledRPCCall], + scheduler: PooledRPCCall, + memory_limiter_disk: ResourceLimiter, + memory_limiter_comms: ResourceLimiter, + ): + self.id = id + self.run_id = run_id + self.output_workers = output_workers + self.local_address = local_address + self.executor = executor + self.rpc = rpc + self.scheduler = scheduler + self.closed = False + + self._disk_buffer = DiskShardsBuffer( + directory=directory, + memory_limiter=memory_limiter_disk, + ) + + self._comm_buffer = CommShardsBuffer( + send=self.send, memory_limiter=memory_limiter_comms + ) + # TODO: reduce number of connections to number of workers + # MultiComm.max_connections = min(10, n_workers) + + self.diagnostics: dict[str, float] = defaultdict(float) + self.transferred = False + self.received: set[_T_partition_id] = set() + self.total_recvd = 0 + self.start_time = time.time() + self._exception: Exception | None = None + self._closed_event = asyncio.Event() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>" + + def __str__(self) -> str: + return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}" + + def __hash__(self) -> int: + return self.run_id + + @contextlib.contextmanager + def time(self, name: str) -> Iterator[None]: + start = time.time() + yield + stop = time.time() + self.diagnostics[name] += stop - start + + async def barrier(self) -> None: + self.raise_if_closed() + # TODO: Consider broadcast pinging once when the shuffle starts to warm + # up the comm pool on scheduler side + await self.scheduler.shuffle_barrier(id=self.id, run_id=self.run_id) + + async def send( + self, address: str, shards: list[tuple[_T_partition_id, bytes]] + ) -> None: + self.raise_if_closed() + return await self.rpc(address).shuffle_receive( + data=to_serialize(shards), + shuffle_id=self.id, + 
run_id=self.run_id, + ) + + async def offload(self, func: Callable[..., _T], *args: Any) -> _T: + self.raise_if_closed() + with self.time("cpu"): + return await asyncio.get_running_loop().run_in_executor( + self.executor, + func, + *args, + ) + + def heartbeat(self) -> dict[str, Any]: + comm_heartbeat = self._comm_buffer.heartbeat() + comm_heartbeat["read"] = self.total_recvd + return { + "disk": self._disk_buffer.heartbeat(), + "comm": comm_heartbeat, + "diagnostics": self.diagnostics, + "start": self.start_time, + } + + async def _write_to_comm( + self, data: dict[str, tuple[_T_partition_id, bytes]] + ) -> None: + self.raise_if_closed() + await self._comm_buffer.write(data) + + async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None: + self.raise_if_closed() + await self._disk_buffer.write( + {"_".join(str(i) for i in k): v for k, v in data.items()} + ) + + def raise_if_closed(self) -> None: + if self.closed: + if self._exception: + raise self._exception + raise ShuffleClosedError(f"{self} has already been closed") + + async def inputs_done(self) -> None: + self.raise_if_closed() + self.transferred = True + await self._flush_comm() + try: + self._comm_buffer.raise_on_exception() + except Exception as e: + self._exception = e + raise + + async def _flush_comm(self) -> None: + self.raise_if_closed() + await self._comm_buffer.flush() + + async def flush_receive(self) -> None: + self.raise_if_closed() + await self._disk_buffer.flush() + + async def close(self) -> None: + if self.closed: # pragma: no cover + await self._closed_event.wait() + return + + self.closed = True + await self._comm_buffer.close() + await self._disk_buffer.close() + self._closed_event.set() + + def fail(self, exception: Exception) -> None: + if not self.closed: + self._exception = exception + + def _read_from_disk(self, id: NDIndex) -> bytes: + self.raise_if_closed() + data: bytes = self._disk_buffer.read("_".join(str(i) for i in id)) + return data + + async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None: + await self._receive(data) + + async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None: + assigned_worker = self._get_assigned_worker(i) + + if assigned_worker != self.local_address: + result = await self.scheduler.shuffle_restrict_task( + id=self.id, run_id=self.run_id, key=key, worker=assigned_worker + ) + if result["status"] == "error": + raise RuntimeError(result["message"]) + assert result["status"] == "OK" + raise Reschedule() + + @abc.abstractmethod + def _get_assigned_worker(self, i: _T_partition_id) -> str: + """Get the address of the worker assigned to the output partition""" + + @abc.abstractmethod + async def _receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None: + """Receive shards belonging to output partitions of this shuffle run""" + + @abc.abstractmethod + async def add_partition( + self, data: _T_partition_type, partition_id: _T_partition_id + ) -> int: + """Add an input partition to the shuffle run""" + + @abc.abstractmethod + async def get_output_partition( + self, partition_id: _T_partition_id, key: str, meta: pd.DataFrame | None = None + ) -> _T_partition_type: + """Get an output partition to the shuffle run""" + + +def get_worker_plugin() -> ShuffleWorkerPlugin: + from distributed import get_worker + + try: + worker = get_worker() + except ValueError as e: + raise RuntimeError( + "`shuffle='p2p'` requires Dask's distributed scheduler. 
This task is not running on a Worker; " + "please confirm that you've created a distributed Client and are submitting this computation through it." + ) from e + plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle") # type: ignore + if plugin is None: + raise RuntimeError( + f"The worker {worker.address} does not have a ShuffleExtension. " + "Is pandas installed on the worker?" + ) + return plugin + + +_BARRIER_PREFIX = "shuffle-barrier-" + + +def barrier_key(shuffle_id: ShuffleId) -> str: + return _BARRIER_PREFIX + shuffle_id + + +def id_from_key(key: str) -> ShuffleId: + assert key.startswith(_BARRIER_PREFIX) + return ShuffleId(key.replace(_BARRIER_PREFIX, "")) + + +class ShuffleType(Enum): + DATAFRAME = "DataFrameShuffle" + ARRAY_RECHUNK = "ArrayRechunk" + + +@dataclass(eq=False) +class ShuffleState(abc.ABC): + _run_id_iterator: ClassVar[itertools.count] = itertools.count(1) + + id: ShuffleId + run_id: int + output_workers: set[str] + participating_workers: set[str] + _archived_by: str | None = field(default=None, init=False) + + @abc.abstractmethod + def to_msg(self) -> dict[str, Any]: + """Transform the shuffle state into a JSON-serializable message""" + + def __str__(self) -> str: + return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" + + def __hash__(self) -> int: + return hash(self.run_id) diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py index 34c7ef3a37..42ce1ce081 100644 --- a/distributed/shuffle/_merge.py +++ b/distributed/shuffle/_merge.py @@ -9,13 +9,8 @@ from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer -from distributed.shuffle._shuffle import ( - ShuffleId, - _get_worker_plugin, - barrier_key, - shuffle_barrier, - shuffle_transfer, -) +from distributed.shuffle._core import ShuffleId, barrier_key, get_worker_plugin +from distributed.shuffle._shuffle import shuffle_barrier, shuffle_transfer if TYPE_CHECKING: import pandas as pd @@ -167,7 +162,7 @@ def merge_unpack( ): from dask.dataframe.multi import merge_chunk - ext = _get_worker_plugin() + ext = get_worker_plugin() # If the partition is empty, it doesn't contain the hash column name left = ext.get_output_partition( shuffle_id_left, barrier_left, output_partition, meta=meta_left diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index a35b75d10a..0c6116646e 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -96,23 +96,36 @@ from __future__ import annotations -from typing import TYPE_CHECKING, NamedTuple +import pickle +from collections import defaultdict +from collections.abc import Callable, Sequence +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from io import BytesIO +from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple import dask from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph, MaterializedLayer +from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule -from distributed.shuffle._shuffle import ( +from distributed.shuffle._core import ( + NDIndex, ShuffleId, + ShuffleRun, + ShuffleState, ShuffleType, - _get_worker_plugin, barrier_key, - shuffle_barrier, + get_worker_plugin, ) +from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._shuffle import shuffle_barrier +from distributed.sizeof import sizeof if TYPE_CHECKING: import numpy as np + import pandas as pd from typing_extensions import TypeAlias import dask.array as da @@ -120,7 +133,6 @@ 
ChunkedAxis: TypeAlias = tuple[float, ...] # chunks must either be an int or NaN ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...] -NDIndex: TypeAlias = tuple[int, ...] NDSlice: TypeAlias = tuple[slice, ...] @@ -132,7 +144,7 @@ def rechunk_transfer( old: ChunkedAxes, ) -> int: try: - return _get_worker_plugin().add_partition( + return get_worker_plugin().add_partition( input, partition_id=input_chunk, shuffle_id=id, @@ -148,7 +160,7 @@ def rechunk_unpack( id: ShuffleId, output_chunk: NDIndex, barrier_run_id: int ) -> np.ndarray: try: - return _get_worker_plugin().get_output_partition( + return get_worker_plugin().get_output_partition( id, barrier_run_id, output_chunk ) except Reschedule as e: @@ -252,3 +264,226 @@ def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes: old_chunk.sort(key=lambda split: split.slice.start) axes.append(old_axis) return axes + + +def convert_chunk(data: bytes) -> np.ndarray: + import numpy as np + + from dask.array.core import concatenate3 + + file = BytesIO(data) + shards: dict[NDIndex, np.ndarray] = {} + + while file.tell() < len(data): + for index, shard in pickle.load(file): + shards[index] = shard + + subshape = [max(dim) + 1 for dim in zip(*shards.keys())] + assert len(shards) == np.prod(subshape) + + rec_cat_arg = np.empty(subshape, dtype="O") + for index, shard in shards.items(): + rec_cat_arg[tuple(index)] = shard + del data + del file + arrs = rec_cat_arg.tolist() + return concatenate3(arrs) + + +class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): + """State for a single active rechunk execution + + This object is responsible for splitting, sending, receiving and combining + data shards. + + It is entirely agnostic to the distributed system and can perform a shuffle + with other `Shuffle` instances using `rpc` and `broadcast`. + + The user of this needs to guarantee that only `Shuffle`s of the same unique + `ShuffleID` interact. + + Parameters + ---------- + worker_for: + A mapping partition_id -> worker_address. + output_workers: + A set of all participating worker (addresses). + old: + Existing chunking of the array per dimension. + new: + Desired chunking of the array per dimension. + id: + A unique `ShuffleID` this belongs to. + run_id: + A unique identifier of the specific execution of the shuffle this belongs to. + local_address: + The local address this Shuffle can be contacted by using `rpc`. + directory: + The scratch directory to buffer data in. + executor: + Thread pool to use for offloading compute. + loop: + The event loop. + rpc: + A callable returning a PooledRPCCall to contact other Shuffle instances. + Typically a ConnectionPool. + scheduler: + A PooledRPCCall to to contact the scheduler. + memory_limiter_disk: + memory_limiter_comm: + A ``ResourceLimiter`` limiting the total amount of memory used in either + buffer. 
+ """ + + def __init__( + self, + worker_for: dict[NDIndex, str], + output_workers: set, + old: ChunkedAxes, + new: ChunkedAxes, + id: ShuffleId, + run_id: int, + local_address: str, + directory: str, + executor: ThreadPoolExecutor, + rpc: Callable[[str], PooledRPCCall], + scheduler: PooledRPCCall, + memory_limiter_disk: ResourceLimiter, + memory_limiter_comms: ResourceLimiter, + ): + super().__init__( + id=id, + run_id=run_id, + output_workers=output_workers, + local_address=local_address, + directory=directory, + executor=executor, + rpc=rpc, + scheduler=scheduler, + memory_limiter_comms=memory_limiter_comms, + memory_limiter_disk=memory_limiter_disk, + ) + self.old = old + self.new = new + partitions_of = defaultdict(list) + for part, addr in worker_for.items(): + partitions_of[addr].append(part) + self.partitions_of = dict(partitions_of) + self.worker_for = worker_for + self.split_axes = split_axes(old, new) + + async def _receive(self, data: list[tuple[NDIndex, bytes]]) -> None: + self.raise_if_closed() + + filtered = [] + for d in data: + id, payload = d + if id in self.received: + continue + filtered.append(payload) + self.received.add(id) + self.total_recvd += sizeof(d) + del data + if not filtered: + return + try: + shards = await self.offload(self._repartition_shards, filtered) + del filtered + await self._write_to_disk(shards) + except Exception as e: + self._exception = e + raise + + def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]: + repartitioned: defaultdict[ + NDIndex, list[tuple[NDIndex, np.ndarray]] + ] = defaultdict(list) + for buffer in data: + for id, shard in pickle.loads(buffer): + repartitioned[id].append(shard) + return {k: pickle.dumps(v) for k, v in repartitioned.items()} + + async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int: + self.raise_if_closed() + if self.transferred: + raise RuntimeError(f"Cannot add more partitions to {self}") + + def _() -> dict[str, tuple[NDIndex, bytes]]: + """Return a mapping of worker addresses to a tuple of input partition + IDs and shard data. + + + TODO: Overhaul! + As shard data, we serialize the payload together with the sub-index of the + slice within the new chunk. To assemble the new chunk from its shards, it + needs the sub-index to know where each shard belongs within the chunk. + Adding the sub-index into the serialized payload on the sender allows us to + write the serialized payload directly to disk on the receiver. 
+ """ + out: dict[ + str, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]] + ] = defaultdict(list) + from itertools import product + + ndsplits = product( + *(axis[i] for axis, i in zip(self.split_axes, partition_id)) + ) + + for ndsplit in ndsplits: + chunk_index, shard_index, ndslice = zip(*ndsplit) + out[self.worker_for[chunk_index]].append( + (chunk_index, (shard_index, data[ndslice])) + ) + return {k: (partition_id, pickle.dumps(v)) for k, v in out.items()} + + out = await self.offload(_) + await self._write_to_comm(out) + return self.run_id + + async def get_output_partition( + self, partition_id: NDIndex, key: str, meta: pd.DataFrame | None = None + ) -> np.ndarray: + self.raise_if_closed() + assert meta is None + assert self.transferred, "`get_output_partition` called before barrier task" + + await self._ensure_output_worker(partition_id, key) + + await self.flush_receive() + + data = self._read_from_disk(partition_id) + + def _() -> np.ndarray: + return convert_chunk(data) + + return await self.offload(_) + + def _get_assigned_worker(self, id: NDIndex) -> str: + return self.worker_for[id] + + +@dataclass(eq=False) +class ArrayRechunkState(ShuffleState): + type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK + worker_for: dict[NDIndex, str] + old: ChunkedAxes + new: ChunkedAxes + + def to_msg(self) -> dict[str, Any]: + return { + "status": "OK", + "type": ArrayRechunkState.type, + "run_id": self.run_id, + "worker_for": self.worker_for, + "old": self.old, + "new": self.new, + "output_workers": self.output_workers, + } + + +def get_worker_for_hash_sharding( + output_partition: NDIndex, workers: Sequence[str] +) -> str: + """Get address of target worker for this output partition using hash sharding""" + i = hash(output_partition) % len(workers) + return workers[i] diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index ec670c0b07..9fdf8f10c2 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -1,25 +1,27 @@ from __future__ import annotations -import abc import contextlib -import itertools import logging from collections import defaultdict from collections.abc import Callable, Iterable, Sequence -from dataclasses import dataclass, field from functools import partial from itertools import product -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any from distributed.diagnostics.plugin import SchedulerPlugin from distributed.protocol.pickle import dumps -from distributed.shuffle._rechunk import ChunkedAxes, NDIndex -from distributed.shuffle._shuffle import ( +from distributed.shuffle._core import ( ShuffleId, + ShuffleState, ShuffleType, barrier_key, id_from_key, ) +from distributed.shuffle._rechunk import ArrayRechunkState, get_worker_for_hash_sharding +from distributed.shuffle._shuffle import ( + DataFrameShuffleState, + get_worker_for_range_sharding, +) from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin if TYPE_CHECKING: @@ -34,63 +36,6 @@ logger = logging.getLogger(__name__) -@dataclass(eq=False) -class ShuffleState(abc.ABC): - _run_id_iterator: ClassVar[itertools.count] = itertools.count(1) - - id: ShuffleId - run_id: int - output_workers: set[str] - participating_workers: set[str] - _archived_by: str | None = field(default=None, init=False) - - @abc.abstractmethod - def to_msg(self) -> dict[str, Any]: - """Transform the shuffle state into a JSON-serializable message""" - - def __str__(self) -> str: - return 
f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" - - def __hash__(self) -> int: - return hash(self.run_id) - - -@dataclass(eq=False) -class DataFrameShuffleState(ShuffleState): - type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME - worker_for: dict[int, str] - column: str - - def to_msg(self) -> dict[str, Any]: - return { - "status": "OK", - "type": DataFrameShuffleState.type, - "run_id": self.run_id, - "worker_for": self.worker_for, - "column": self.column, - "output_workers": self.output_workers, - } - - -@dataclass(eq=False) -class ArrayRechunkState(ShuffleState): - type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK - worker_for: dict[NDIndex, str] - old: ChunkedAxes - new: ChunkedAxes - - def to_msg(self) -> dict[str, Any]: - return { - "status": "OK", - "type": ArrayRechunkState.type, - "run_id": self.run_id, - "worker_for": self.worker_for, - "old": self.old, - "new": self.new, - "output_workers": self.output_workers, - } - - class ShuffleSchedulerPlugin(SchedulerPlugin): """ Shuffle plugin for the scheduler @@ -448,19 +393,3 @@ def restart(self, scheduler: Scheduler) -> None: self.heartbeats.clear() self._shuffles.clear() self._archived_by_stimulus.clear() - - -def get_worker_for_range_sharding( - npartitions: int, output_partition: int, workers: Sequence[str] -) -> str: - """Get address of target worker for this output partition using range sharding""" - i = len(workers) * output_partition // npartitions - return workers[i] - - -def get_worker_for_hash_sharding( - output_partition: NDIndex, workers: Sequence[str] -) -> str: - """Get address of target worker for this output partition using hash sharding""" - i = hash(output_partition) % len(workers) - return workers[i] diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 62bb13c90b..4a4f4601ea 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -1,56 +1,50 @@ from __future__ import annotations import logging -from collections.abc import Iterable, Iterator -from enum import Enum -from typing import TYPE_CHECKING, Any, NewType, Union +from collections import defaultdict +from collections.abc import Callable, Iterable, Iterator, Sequence +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, ClassVar, Union + +import toolz from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer +from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule -from distributed.shuffle._arrow import check_dtype_support, check_minimal_arrow_version +from distributed.shuffle._arrow import ( + check_dtype_support, + check_minimal_arrow_version, + convert_partition, + list_of_buffers_to_table, + serialize_table, +) +from distributed.shuffle._core import ( + NDIndex, + ShuffleId, + ShuffleRun, + ShuffleState, + ShuffleType, + barrier_key, + get_worker_plugin, +) from distributed.shuffle._exceptions import ShuffleClosedError +from distributed.shuffle._limiter import ResourceLimiter +from distributed.sizeof import sizeof logger = logging.getLogger("distributed.shuffle") if TYPE_CHECKING: import pandas as pd + import pyarrow as pa # TODO import from typing (requires Python >=3.10) from typing_extensions import TypeAlias from dask.dataframe import DataFrame - # circular dependency - from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin - -ShuffleId = NewType("ShuffleId", str) - - -class ShuffleType(Enum): - DATAFRAME = 
"DataFrameShuffle" - ARRAY_RECHUNK = "ArrayRechunk" - - -def _get_worker_plugin() -> ShuffleWorkerPlugin: - from distributed import get_worker - - try: - worker = get_worker() - except ValueError as e: - raise RuntimeError( - "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; " - "please confirm that you've created a distributed Client and are submitting this computation through it." - ) from e - plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle") # type: ignore - if plugin is None: - raise RuntimeError( - f"The worker {worker.address} does not have a ShuffleExtension. " - "Is pandas installed on the worker?" - ) - return plugin - def shuffle_transfer( input: pd.DataFrame, @@ -61,7 +55,7 @@ def shuffle_transfer( parts_out: set[int], ) -> int: try: - return _get_worker_plugin().add_partition( + return get_worker_plugin().add_partition( input, shuffle_id=id, type=ShuffleType.DATAFRAME, @@ -80,7 +74,7 @@ def shuffle_unpack( id: ShuffleId, output_partition: int, barrier_run_id: int, meta: pd.DataFrame ) -> pd.DataFrame: try: - return _get_worker_plugin().get_output_partition( + return get_worker_plugin().get_output_partition( id, barrier_run_id, output_partition, meta=meta ) except Reschedule as e: @@ -93,7 +87,7 @@ def shuffle_unpack( def shuffle_barrier(id: ShuffleId, run_ids: list[int]) -> int: try: - return _get_worker_plugin().barrier(id, run_ids) + return get_worker_plugin().barrier(id, run_ids) except Exception as e: raise RuntimeError(f"shuffle_barrier failed during shuffle {id}") from e @@ -270,13 +264,248 @@ def _construct_graph(self) -> _T_LowLevelGraph: return dsk -_BARRIER_PREFIX = "shuffle-barrier-" - - -def barrier_key(shuffle_id: ShuffleId) -> str: - return _BARRIER_PREFIX + shuffle_id - +def split_by_worker( + df: pd.DataFrame, + column: str, + worker_for: pd.Series, +) -> dict[Any, pa.Table]: + """ + Split data into many arrow batches, partitioned by destination worker + """ + import numpy as np + import pyarrow as pa + + df = df.merge( + right=worker_for.cat.codes.rename("_worker"), + left_on=column, + right_index=True, + how="inner", + ) + nrows = len(df) + if not nrows: + return {} + # assert len(df) == nrows # Not true if some outputs aren't wanted + # FIXME: If we do not preserve the index something is corrupting the + # bytestream such that it cannot be deserialized anymore + t = pa.Table.from_pandas(df, preserve_index=True) + t = t.sort_by("_worker") + codes = np.asarray(t["_worker"]) + t = t.drop(["_worker"]) + del df + + splits = np.where(codes[1:] != codes[:-1])[0] + 1 + splits = np.concatenate([[0], splits]) + + shards = [ + t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) + ] + shards.append(t.slice(offset=splits[-1], length=None)) + + unique_codes = codes[splits] + out = { + # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43 + worker_for.cat.categories[code]: shard + for code, shard in zip(unique_codes, shards) + } + assert sum(map(len, out.values())) == nrows + return out + + +def split_by_partition(t: pa.Table, column: str) -> dict[Any, pa.Table]: + """ + Split data into many arrow batches, partitioned by final partition + """ + import numpy as np + + partitions = t.select([column]).to_pandas()[column].unique() + partitions.sort() + t = t.sort_by(column) + + partition = np.asarray(t[column]) + splits = np.where(partition[1:] != partition[:-1])[0] + 1 + splits = np.concatenate([[0], splits]) + + shards = [ + t.slice(offset=a, length=b - a) for a, b in 
toolz.sliding_window(2, splits) + ] + shards.append(t.slice(offset=splits[-1], length=None)) + assert len(t) == sum(map(len, shards)) + assert len(partitions) == len(shards) + return dict(zip(partitions, shards)) + + +class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): + """State for a single active shuffle execution + + This object is responsible for splitting, sending, receiving and combining + data shards. + + It is entirely agnostic to the distributed system and can perform a shuffle + with other `Shuffle` instances using `rpc` and `broadcast`. + + The user of this needs to guarantee that only `Shuffle`s of the same unique + `ShuffleID` interact. + + Parameters + ---------- + worker_for: + A mapping partition_id -> worker_address. + output_workers: + A set of all participating worker (addresses). + column: + The data column we split the input partition by. + id: + A unique `ShuffleID` this belongs to. + run_id: + A unique identifier of the specific execution of the shuffle this belongs to. + local_address: + The local address this Shuffle can be contacted by using `rpc`. + directory: + The scratch directory to buffer data in. + executor: + Thread pool to use for offloading compute. + loop: + The event loop. + rpc: + A callable returning a PooledRPCCall to contact other Shuffle instances. + Typically a ConnectionPool. + scheduler: + A PooledRPCCall to to contact the scheduler. + memory_limiter_disk: + memory_limiter_comm: + A ``ResourceLimiter`` limiting the total amount of memory used in either + buffer. + """ -def id_from_key(key: str) -> ShuffleId: - assert key.startswith(_BARRIER_PREFIX) - return ShuffleId(key.replace(_BARRIER_PREFIX, "")) + def __init__( + self, + worker_for: dict[int, str], + output_workers: set, + column: str, + id: ShuffleId, + run_id: int, + local_address: str, + directory: str, + executor: ThreadPoolExecutor, + rpc: Callable[[str], PooledRPCCall], + scheduler: PooledRPCCall, + memory_limiter_disk: ResourceLimiter, + memory_limiter_comms: ResourceLimiter, + ): + import pandas as pd + + super().__init__( + id=id, + run_id=run_id, + output_workers=output_workers, + local_address=local_address, + directory=directory, + executor=executor, + rpc=rpc, + scheduler=scheduler, + memory_limiter_comms=memory_limiter_comms, + memory_limiter_disk=memory_limiter_disk, + ) + self.column = column + partitions_of = defaultdict(list) + for part, addr in worker_for.items(): + partitions_of[addr].append(part) + self.partitions_of = dict(partitions_of) + self.worker_for = pd.Series(worker_for, name="_workers").astype("category") + + async def receive(self, data: list[tuple[int, bytes]]) -> None: + await self._receive(data) + + async def _receive(self, data: list[tuple[int, bytes]]) -> None: + self.raise_if_closed() + + filtered = [] + for d in data: + if d[0] not in self.received: + filtered.append(d[1]) + self.received.add(d[0]) + self.total_recvd += sizeof(d) + del data + if not filtered: + return + try: + groups = await self.offload(self._repartition_buffers, filtered) + del filtered + await self._write_to_disk(groups) + except Exception as e: + self._exception = e + raise + + def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, bytes]: + table = list_of_buffers_to_table(data) + groups = split_by_partition(table, self.column) + assert len(table) == sum(map(len, groups.values())) + del data + return {(k,): serialize_table(v) for k, v in groups.items()} + + async def add_partition(self, data: pd.DataFrame, partition_id: int) -> int: + self.raise_if_closed() + 
if self.transferred: + raise RuntimeError(f"Cannot add more partitions to {self}") + + def _() -> dict[str, tuple[int, bytes]]: + out = split_by_worker( + data, + self.column, + self.worker_for, + ) + out = {k: (partition_id, serialize_table(t)) for k, t in out.items()} + return out + + out = await self.offload(_) + await self._write_to_comm(out) + return self.run_id + + async def get_output_partition( + self, partition_id: int, key: str, meta: pd.DataFrame | None = None + ) -> pd.DataFrame: + self.raise_if_closed() + assert meta is not None + assert self.transferred, "`get_output_partition` called before barrier task" + + await self._ensure_output_worker(partition_id, key) + + await self.flush_receive() + try: + data = self._read_from_disk((partition_id,)) + + def _() -> pd.DataFrame: + return convert_partition(data, meta) # type: ignore + + out = await self.offload(_) + except KeyError: + out = meta.copy() + return out + + def _get_assigned_worker(self, id: int) -> str: + return self.worker_for[id] + + +@dataclass(eq=False) +class DataFrameShuffleState(ShuffleState): + type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME + worker_for: dict[int, str] + column: str + + def to_msg(self) -> dict[str, Any]: + return { + "status": "OK", + "type": DataFrameShuffleState.type, + "run_id": self.run_id, + "worker_for": self.worker_for, + "column": self.column, + "output_workers": self.output_workers, + } + + +def get_worker_for_range_sharding( + npartitions: int, output_partition: int, workers: Sequence[str] +) -> str: + """Get address of target worker for this output partition using range sharding""" + i = len(workers) * output_partition // npartitions + return workers[i] diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 0f2fcf415e..7340b8fe6e 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -1,564 +1,31 @@ from __future__ import annotations -import abc import asyncio -import contextlib import logging import os -import pickle -import time -from collections import defaultdict -from collections.abc import Callable, Iterator from concurrent.futures import ThreadPoolExecutor -from io import BytesIO -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload - -import toolz +from typing import TYPE_CHECKING, Any, overload from dask.context import thread_state from dask.utils import parse_bytes -from distributed.core import PooledRPCCall from distributed.diagnostics.plugin import WorkerPlugin -from distributed.exceptions import Reschedule -from distributed.protocol import to_serialize -from distributed.shuffle._arrow import ( - convert_partition, - list_of_buffers_to_table, - serialize_table, -) -from distributed.shuffle._comms import CommShardsBuffer -from distributed.shuffle._disk import DiskShardsBuffer +from distributed.shuffle._core import NDIndex, ShuffleId, ShuffleRun, ShuffleType from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._rechunk import ChunkedAxes, NDIndex, split_axes -from distributed.shuffle._shuffle import ShuffleId, ShuffleType -from distributed.sizeof import sizeof +from distributed.shuffle._rechunk import ArrayRechunkRun +from distributed.shuffle._shuffle import DataFrameShuffleRun from distributed.utils import log_errors, sync if TYPE_CHECKING: # TODO import from typing (requires Python >=3.10) - import numpy as np import pandas as pd - import pyarrow as pa from distributed.worker 
import Worker -T_partition_id = TypeVar("T_partition_id") -T_partition_type = TypeVar("T_partition_type") -T = TypeVar("T") - logger = logging.getLogger(__name__) -class ShuffleRun(Generic[T_partition_id, T_partition_type]): - def __init__( - self, - id: ShuffleId, - run_id: int, - output_workers: set[str], - local_address: str, - directory: str, - executor: ThreadPoolExecutor, - rpc: Callable[[str], PooledRPCCall], - scheduler: PooledRPCCall, - memory_limiter_disk: ResourceLimiter, - memory_limiter_comms: ResourceLimiter, - ): - self.id = id - self.run_id = run_id - self.output_workers = output_workers - self.local_address = local_address - self.executor = executor - self.rpc = rpc - self.scheduler = scheduler - self.closed = False - - self._disk_buffer = DiskShardsBuffer( - directory=directory, - memory_limiter=memory_limiter_disk, - ) - - self._comm_buffer = CommShardsBuffer( - send=self.send, memory_limiter=memory_limiter_comms - ) - # TODO: reduce number of connections to number of workers - # MultiComm.max_connections = min(10, n_workers) - - self.diagnostics: dict[str, float] = defaultdict(float) - self.transferred = False - self.received: set[T_partition_id] = set() - self.total_recvd = 0 - self.start_time = time.time() - self._exception: Exception | None = None - self._closed_event = asyncio.Event() - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>" - - def __str__(self) -> str: - return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}" - - def __hash__(self) -> int: - return self.run_id - - @contextlib.contextmanager - def time(self, name: str) -> Iterator[None]: - start = time.time() - yield - stop = time.time() - self.diagnostics[name] += stop - start - - async def barrier(self) -> None: - self.raise_if_closed() - # TODO: Consider broadcast pinging once when the shuffle starts to warm - # up the comm pool on scheduler side - await self.scheduler.shuffle_barrier(id=self.id, run_id=self.run_id) - - async def send( - self, address: str, shards: list[tuple[T_partition_id, bytes]] - ) -> None: - self.raise_if_closed() - return await self.rpc(address).shuffle_receive( - data=to_serialize(shards), - shuffle_id=self.id, - run_id=self.run_id, - ) - - async def offload(self, func: Callable[..., T], *args: Any) -> T: - self.raise_if_closed() - with self.time("cpu"): - return await asyncio.get_running_loop().run_in_executor( - self.executor, - func, - *args, - ) - - def heartbeat(self) -> dict[str, Any]: - comm_heartbeat = self._comm_buffer.heartbeat() - comm_heartbeat["read"] = self.total_recvd - return { - "disk": self._disk_buffer.heartbeat(), - "comm": comm_heartbeat, - "diagnostics": self.diagnostics, - "start": self.start_time, - } - - async def _write_to_comm( - self, data: dict[str, tuple[T_partition_id, bytes]] - ) -> None: - self.raise_if_closed() - await self._comm_buffer.write(data) - - async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None: - self.raise_if_closed() - await self._disk_buffer.write( - {"_".join(str(i) for i in k): v for k, v in data.items()} - ) - - def raise_if_closed(self) -> None: - if self.closed: - if self._exception: - raise self._exception - raise ShuffleClosedError(f"{self} has already been closed") - - async def inputs_done(self) -> None: - self.raise_if_closed() - self.transferred = True - await self._flush_comm() - try: - 
self._comm_buffer.raise_on_exception() - except Exception as e: - self._exception = e - raise - - async def _flush_comm(self) -> None: - self.raise_if_closed() - await self._comm_buffer.flush() - - async def flush_receive(self) -> None: - self.raise_if_closed() - await self._disk_buffer.flush() - - async def close(self) -> None: - if self.closed: # pragma: no cover - await self._closed_event.wait() - return - - self.closed = True - await self._comm_buffer.close() - await self._disk_buffer.close() - self._closed_event.set() - - def fail(self, exception: Exception) -> None: - if not self.closed: - self._exception = exception - - def _read_from_disk(self, id: NDIndex) -> bytes: - self.raise_if_closed() - data: bytes = self._disk_buffer.read("_".join(str(i) for i in id)) - return data - - async def receive(self, data: list[tuple[T_partition_id, bytes]]) -> None: - await self._receive(data) - - async def _ensure_output_worker(self, i: T_partition_id, key: str) -> None: - assigned_worker = self._get_assigned_worker(i) - - if assigned_worker != self.local_address: - result = await self.scheduler.shuffle_restrict_task( - id=self.id, run_id=self.run_id, key=key, worker=assigned_worker - ) - if result["status"] == "error": - raise RuntimeError(result["message"]) - assert result["status"] == "OK" - raise Reschedule() - - @abc.abstractmethod - def _get_assigned_worker(self, i: T_partition_id) -> str: - """Get the address of the worker assigned to the output partition""" - - @abc.abstractmethod - async def _receive(self, data: list[tuple[T_partition_id, bytes]]) -> None: - """Receive shards belonging to output partitions of this shuffle run""" - - @abc.abstractmethod - async def add_partition( - self, data: T_partition_type, partition_id: T_partition_id - ) -> int: - """Add an input partition to the shuffle run""" - - @abc.abstractmethod - async def get_output_partition( - self, partition_id: T_partition_id, key: str, meta: pd.DataFrame | None = None - ) -> T_partition_type: - """Get an output partition to the shuffle run""" - - -class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): - """State for a single active rechunk execution - - This object is responsible for splitting, sending, receiving and combining - data shards. - - It is entirely agnostic to the distributed system and can perform a shuffle - with other `Shuffle` instances using `rpc` and `broadcast`. - - The user of this needs to guarantee that only `Shuffle`s of the same unique - `ShuffleID` interact. - - Parameters - ---------- - worker_for: - A mapping partition_id -> worker_address. - output_workers: - A set of all participating worker (addresses). - old: - Existing chunking of the array per dimension. - new: - Desired chunking of the array per dimension. - id: - A unique `ShuffleID` this belongs to. - run_id: - A unique identifier of the specific execution of the shuffle this belongs to. - local_address: - The local address this Shuffle can be contacted by using `rpc`. - directory: - The scratch directory to buffer data in. - executor: - Thread pool to use for offloading compute. - loop: - The event loop. - rpc: - A callable returning a PooledRPCCall to contact other Shuffle instances. - Typically a ConnectionPool. - scheduler: - A PooledRPCCall to to contact the scheduler. - memory_limiter_disk: - memory_limiter_comm: - A ``ResourceLimiter`` limiting the total amount of memory used in either - buffer. 
- """ - - def __init__( - self, - worker_for: dict[NDIndex, str], - output_workers: set, - old: ChunkedAxes, - new: ChunkedAxes, - id: ShuffleId, - run_id: int, - local_address: str, - directory: str, - executor: ThreadPoolExecutor, - rpc: Callable[[str], PooledRPCCall], - scheduler: PooledRPCCall, - memory_limiter_disk: ResourceLimiter, - memory_limiter_comms: ResourceLimiter, - ): - super().__init__( - id=id, - run_id=run_id, - output_workers=output_workers, - local_address=local_address, - directory=directory, - executor=executor, - rpc=rpc, - scheduler=scheduler, - memory_limiter_comms=memory_limiter_comms, - memory_limiter_disk=memory_limiter_disk, - ) - self.old = old - self.new = new - partitions_of = defaultdict(list) - for part, addr in worker_for.items(): - partitions_of[addr].append(part) - self.partitions_of = dict(partitions_of) - self.worker_for = worker_for - self.split_axes = split_axes(old, new) - - async def _receive(self, data: list[tuple[NDIndex, bytes]]) -> None: - self.raise_if_closed() - - filtered = [] - for d in data: - id, payload = d - if id in self.received: - continue - filtered.append(payload) - self.received.add(id) - self.total_recvd += sizeof(d) - del data - if not filtered: - return - try: - shards = await self.offload(self._repartition_shards, filtered) - del filtered - await self._write_to_disk(shards) - except Exception as e: - self._exception = e - raise - - def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]: - repartitioned: defaultdict[ - NDIndex, list[tuple[NDIndex, np.ndarray]] - ] = defaultdict(list) - for buffer in data: - for id, shard in pickle.loads(buffer): - repartitioned[id].append(shard) - return {k: pickle.dumps(v) for k, v in repartitioned.items()} - - async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int: - self.raise_if_closed() - if self.transferred: - raise RuntimeError(f"Cannot add more partitions to {self}") - - def _() -> dict[str, tuple[NDIndex, bytes]]: - """Return a mapping of worker addresses to a tuple of input partition - IDs and shard data. - - - TODO: Overhaul! - As shard data, we serialize the payload together with the sub-index of the - slice within the new chunk. To assemble the new chunk from its shards, it - needs the sub-index to know where each shard belongs within the chunk. - Adding the sub-index into the serialized payload on the sender allows us to - write the serialized payload directly to disk on the receiver. 
- """ - out: dict[ - str, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]] - ] = defaultdict(list) - from itertools import product - - ndsplits = product( - *(axis[i] for axis, i in zip(self.split_axes, partition_id)) - ) - - for ndsplit in ndsplits: - chunk_index, shard_index, ndslice = zip(*ndsplit) - out[self.worker_for[chunk_index]].append( - (chunk_index, (shard_index, data[ndslice])) - ) - return {k: (partition_id, pickle.dumps(v)) for k, v in out.items()} - - out = await self.offload(_) - await self._write_to_comm(out) - return self.run_id - - async def get_output_partition( - self, partition_id: NDIndex, key: str, meta: pd.DataFrame | None = None - ) -> np.ndarray: - self.raise_if_closed() - assert meta is None - assert self.transferred, "`get_output_partition` called before barrier task" - - await self._ensure_output_worker(partition_id, key) - - await self.flush_receive() - - data = self._read_from_disk(partition_id) - - def _() -> np.ndarray: - return convert_chunk(data) - - return await self.offload(_) - - def _get_assigned_worker(self, id: NDIndex) -> str: - return self.worker_for[id] - - -class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): - """State for a single active shuffle execution - - This object is responsible for splitting, sending, receiving and combining - data shards. - - It is entirely agnostic to the distributed system and can perform a shuffle - with other `Shuffle` instances using `rpc` and `broadcast`. - - The user of this needs to guarantee that only `Shuffle`s of the same unique - `ShuffleID` interact. - - Parameters - ---------- - worker_for: - A mapping partition_id -> worker_address. - output_workers: - A set of all participating worker (addresses). - column: - The data column we split the input partition by. - id: - A unique `ShuffleID` this belongs to. - run_id: - A unique identifier of the specific execution of the shuffle this belongs to. - local_address: - The local address this Shuffle can be contacted by using `rpc`. - directory: - The scratch directory to buffer data in. - executor: - Thread pool to use for offloading compute. - loop: - The event loop. - rpc: - A callable returning a PooledRPCCall to contact other Shuffle instances. - Typically a ConnectionPool. - scheduler: - A PooledRPCCall to to contact the scheduler. - memory_limiter_disk: - memory_limiter_comm: - A ``ResourceLimiter`` limiting the total amount of memory used in either - buffer. 
- """ - - def __init__( - self, - worker_for: dict[int, str], - output_workers: set, - column: str, - id: ShuffleId, - run_id: int, - local_address: str, - directory: str, - executor: ThreadPoolExecutor, - rpc: Callable[[str], PooledRPCCall], - scheduler: PooledRPCCall, - memory_limiter_disk: ResourceLimiter, - memory_limiter_comms: ResourceLimiter, - ): - import pandas as pd - - super().__init__( - id=id, - run_id=run_id, - output_workers=output_workers, - local_address=local_address, - directory=directory, - executor=executor, - rpc=rpc, - scheduler=scheduler, - memory_limiter_comms=memory_limiter_comms, - memory_limiter_disk=memory_limiter_disk, - ) - self.column = column - partitions_of = defaultdict(list) - for part, addr in worker_for.items(): - partitions_of[addr].append(part) - self.partitions_of = dict(partitions_of) - self.worker_for = pd.Series(worker_for, name="_workers").astype("category") - - async def receive(self, data: list[tuple[int, bytes]]) -> None: - await self._receive(data) - - async def _receive(self, data: list[tuple[int, bytes]]) -> None: - self.raise_if_closed() - - filtered = [] - for d in data: - if d[0] not in self.received: - filtered.append(d[1]) - self.received.add(d[0]) - self.total_recvd += sizeof(d) - del data - if not filtered: - return - try: - groups = await self.offload(self._repartition_buffers, filtered) - del filtered - await self._write_to_disk(groups) - except Exception as e: - self._exception = e - raise - - def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, bytes]: - table = list_of_buffers_to_table(data) - groups = split_by_partition(table, self.column) - assert len(table) == sum(map(len, groups.values())) - del data - return {(k,): serialize_table(v) for k, v in groups.items()} - - async def add_partition(self, data: pd.DataFrame, partition_id: int) -> int: - self.raise_if_closed() - if self.transferred: - raise RuntimeError(f"Cannot add more partitions to {self}") - - def _() -> dict[str, tuple[int, bytes]]: - out = split_by_worker( - data, - self.column, - self.worker_for, - ) - out = {k: (partition_id, serialize_table(t)) for k, t in out.items()} - return out - - out = await self.offload(_) - await self._write_to_comm(out) - return self.run_id - - async def get_output_partition( - self, partition_id: int, key: str, meta: pd.DataFrame | None = None - ) -> pd.DataFrame: - self.raise_if_closed() - assert meta is not None - assert self.transferred, "`get_output_partition` called before barrier task" - - await self._ensure_output_worker(partition_id, key) - - await self.flush_receive() - try: - data = self._read_from_disk((partition_id,)) - - def _() -> pd.DataFrame: - return convert_partition(data, meta) # type: ignore - - out = await self.offload(_) - except KeyError: - out = meta.copy() - return out - - def _get_assigned_worker(self, id: int) -> str: - return self.worker_for[id] - - class ShuffleWorkerPlugin(WorkerPlugin): """Interface between a Worker and a Shuffle. 
@@ -974,97 +441,3 @@ def get_output_partition( key=key, meta=meta, ) - - -def split_by_worker( - df: pd.DataFrame, - column: str, - worker_for: pd.Series, -) -> dict[Any, pa.Table]: - """ - Split data into many arrow batches, partitioned by destination worker - """ - import numpy as np - import pyarrow as pa - - df = df.merge( - right=worker_for.cat.codes.rename("_worker"), - left_on=column, - right_index=True, - how="inner", - ) - nrows = len(df) - if not nrows: - return {} - # assert len(df) == nrows # Not true if some outputs aren't wanted - # FIXME: If we do not preserve the index something is corrupting the - # bytestream such that it cannot be deserialized anymore - t = pa.Table.from_pandas(df, preserve_index=True) - t = t.sort_by("_worker") - codes = np.asarray(t["_worker"]) - t = t.drop(["_worker"]) - del df - - splits = np.where(codes[1:] != codes[:-1])[0] + 1 - splits = np.concatenate([[0], splits]) - - shards = [ - t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) - ] - shards.append(t.slice(offset=splits[-1], length=None)) - - unique_codes = codes[splits] - out = { - # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43 - worker_for.cat.categories[code]: shard - for code, shard in zip(unique_codes, shards) - } - assert sum(map(len, out.values())) == nrows - return out - - -def split_by_partition(t: pa.Table, column: str) -> dict[Any, pa.Table]: - """ - Split data into many arrow batches, partitioned by final partition - """ - import numpy as np - - partitions = t.select([column]).to_pandas()[column].unique() - partitions.sort() - t = t.sort_by(column) - - partition = np.asarray(t[column]) - splits = np.where(partition[1:] != partition[:-1])[0] + 1 - splits = np.concatenate([[0], splits]) - - shards = [ - t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) - ] - shards.append(t.slice(offset=splits[-1], length=None)) - assert len(t) == sum(map(len, shards)) - assert len(partitions) == len(shards) - return dict(zip(partitions, shards)) - - -def convert_chunk(data: bytes) -> np.ndarray: - import numpy as np - - from dask.array.core import concatenate3 - - file = BytesIO(data) - shards: dict[NDIndex, np.ndarray] = {} - - while file.tell() < len(data): - for index, shard in pickle.load(file): - shards[index] = shard - - subshape = [max(dim) + 1 for dim in zip(*shards.keys())] - assert len(shards) == np.prod(subshape) - - rec_cat_arg = np.empty(subshape, dtype="O") - for index, shard in shards.items(): - rec_cat_arg[tuple(index)] = shard - del data - del file - arrs = rec_cat_arg.tolist() - return concatenate3(arrs) diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 483ef7fc6d..62f3b1b9c9 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -17,11 +17,14 @@ from dask.array.rechunk import normalize_chunks, rechunk from dask.array.utils import assert_eq +from distributed.shuffle._core import ShuffleId from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._rechunk import Split, split_axes -from distributed.shuffle._scheduler_plugin import get_worker_for_hash_sharding -from distributed.shuffle._shuffle import ShuffleId -from distributed.shuffle._worker_plugin import ArrayRechunkRun +from distributed.shuffle._rechunk import ( + ArrayRechunkRun, + Split, + get_worker_for_hash_sharding, + split_axes, +) from distributed.shuffle.tests.utils import AbstractShuffleTestPool from distributed.utils_test 
import gen_cluster, gen_test, raises_with_cause diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index b5540892a0..8b6c00a211 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -15,6 +15,7 @@ import pytest +from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key from distributed.worker import Status pd = pytest.importorskip("pandas") @@ -27,22 +28,20 @@ from distributed.client import Client from distributed.scheduler import KilledWorker, Scheduler from distributed.scheduler import TaskState as SchedulerTaskState -from distributed.shuffle._arrow import serialize_table -from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._scheduler_plugin import ( - ShuffleSchedulerPlugin, - get_worker_for_range_sharding, -) -from distributed.shuffle._shuffle import ShuffleId, barrier_key -from distributed.shuffle._worker_plugin import ( - DataFrameShuffleRun, - ShuffleRun, - ShuffleWorkerPlugin, +from distributed.shuffle._arrow import ( convert_partition, list_of_buffers_to_table, + serialize_table, +) +from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin +from distributed.shuffle._shuffle import ( + DataFrameShuffleRun, + get_worker_for_range_sharding, split_by_partition, split_by_worker, ) +from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.shuffle.tests.utils import ( AbstractShuffleTestPool, invoke_annotation_chaos, diff --git a/distributed/shuffle/tests/test_shuffle_plugins.py b/distributed/shuffle/tests/test_shuffle_plugins.py index aad806cb93..d8fec10834 100644 --- a/distributed/shuffle/tests/test_shuffle_plugins.py +++ b/distributed/shuffle/tests/test_shuffle_plugins.py @@ -4,18 +4,17 @@ import pytest -pd = pytest.importorskip("pandas") -dd = pytest.importorskip("dask.dataframe") - -from distributed.shuffle._scheduler_plugin import ( - ShuffleSchedulerPlugin, +from distributed.shuffle._shuffle import ( get_worker_for_range_sharding, -) -from distributed.shuffle._worker_plugin import ( - ShuffleWorkerPlugin, split_by_partition, split_by_worker, ) + +pd = pytest.importorskip("pandas") +dd = pytest.importorskip("dask.dataframe") + +from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin +from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.utils_test import gen_cluster diff --git a/distributed/shuffle/tests/utils.py b/distributed/shuffle/tests/utils.py index c5f8e30f56..0fd25a0bed 100644 --- a/distributed/shuffle/tests/utils.py +++ b/distributed/shuffle/tests/utils.py @@ -8,8 +8,7 @@ from distributed.core import PooledRPCCall from distributed.diagnostics.plugin import SchedulerPlugin from distributed.scheduler import Scheduler, TaskStateState -from distributed.shuffle._shuffle import ShuffleId -from distributed.shuffle._worker_plugin import ShuffleRun +from distributed.shuffle._core import ShuffleId, ShuffleRun class PooledRPCShuffle(PooledRPCCall):
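
As an aside, the sharding and barrier-key helpers relocated by this patch are pure functions and can be exercised in isolation. A minimal sketch follows, with the function bodies copied from the diff above; the worker addresses and the shuffle id `"abc123"` are made-up placeholders, not values from the patch.

```python
from typing import NewType, Sequence

ShuffleId = NewType("ShuffleId", str)
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    # Copied from distributed/shuffle/_core.py in the diff above.
    return _BARRIER_PREFIX + shuffle_id


def id_from_key(key: str) -> ShuffleId:
    # Inverse of barrier_key; copied from distributed/shuffle/_core.py.
    assert key.startswith(_BARRIER_PREFIX)
    return ShuffleId(key.replace(_BARRIER_PREFIX, ""))


def get_worker_for_range_sharding(
    npartitions: int, output_partition: int, workers: Sequence[str]
) -> str:
    # Copied from distributed/shuffle/_shuffle.py in the diff above:
    # contiguous ranges of output partitions map onto the worker list.
    i = len(workers) * output_partition // npartitions
    return workers[i]


if __name__ == "__main__":
    workers = ["tcp://10.0.0.1:1234", "tcp://10.0.0.2:1234"]  # placeholder addresses

    # With 10 output partitions and 2 workers, partitions 0-4 land on the
    # first worker and 5-9 on the second.
    assignment = {p: get_worker_for_range_sharding(10, p, workers) for p in range(10)}
    assert assignment[0] == workers[0]
    assert assignment[9] == workers[1]

    # Barrier keys round-trip back to the shuffle id they encode.
    key = barrier_key(ShuffleId("abc123"))  # "shuffle-barrier-abc123"
    assert id_from_key(key) == "abc123"
```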