Make P2P shuffle extensible #8096
Changes from 5 commits
@@ -10,7 +10,8 @@
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, TypeVar
+from functools import partial
+from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar

 from distributed.core import PooledRPCCall
 from distributed.exceptions import Reschedule
@@ -21,27 +22,28 @@
 from distributed.shuffle._limiter import ResourceLimiter

 if TYPE_CHECKING:
-    import pandas as pd
     # TODO import from typing (requires Python >=3.10)
     from typing_extensions import TypeAlias

     # avoid circular dependencies
+    from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
+
+    # circular dependencies
     from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

-ShuffleId = NewType("ShuffleId", str)
-NDIndex: TypeAlias = tuple[int, ...]

 _T_partition_id = TypeVar("_T_partition_id")
 _T_partition_type = TypeVar("_T_partition_type")
 _T = TypeVar("_T")

+NDIndex: TypeAlias = tuple[int, ...]
+
+ShuffleId = NewType("ShuffleId", str)


 class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
     def __init__(
         self,
         id: ShuffleId,
         run_id: int,
-        output_workers: set[str],
         local_address: str,
         directory: str,
         executor: ThreadPoolExecutor,
@@ -52,7 +54,6 @@
     ):
         self.id = id
         self.run_id = run_id
-        self.output_workers = output_workers
         self.local_address = local_address
         self.executor = executor
         self.rpc = rpc
@@ -215,7 +216,7 @@

     @abc.abstractmethod
     async def get_output_partition(
-        self, partition_id: _T_partition_id, key: str, meta: pd.DataFrame | None = None
+        self, partition_id: _T_partition_id, key: str, **kwargs: Any
     ) -> _T_partition_type:
         """Get an output partition to the shuffle run"""
@@ -230,13 +231,12 @@
             "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
             "please confirm that you've created a distributed Client and are submitting this computation through it."
         ) from e
-    plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle")  # type: ignore
-    if plugin is None:
+    try:
+        return worker.plugins["shuffle"]  # type: ignore
+    except KeyError as e:
         raise RuntimeError(
-            f"The worker {worker.address} does not have a ShuffleExtension. "
-            "Is pandas installed on the worker?"
-        )
-    return plugin
+            f"The worker {worker.address} does not have a P2P shuffle plugin."
+        ) from e


 _BARRIER_PREFIX = "shuffle-barrier-"
@@ -256,19 +256,60 @@
 ARRAY_RECHUNK = "ArrayRechunk"


-@dataclass(eq=False)
-class ShuffleState(abc.ABC):
-    _run_id_iterator: ClassVar[itertools.count] = itertools.count(1)
+@dataclass(frozen=True)
+class ShuffleRunSpec(Generic[_T_partition_id]):
+    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))  # type: ignore
+    spec: ShuffleSpec
+    worker_for: dict[_T_partition_id, str]
+
+    @property
+    def id(self) -> ShuffleId:
+        return self.spec.id
+
+
+@dataclass(frozen=True)
+class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
     id: ShuffleId
-    run_id: int
-    output_workers: set[str]
+
+    def create_new_run(
+        self,
+        plugin: ShuffleSchedulerPlugin,
+    ) -> SchedulerShuffleState:
+        worker_for = self._pin_output_workers(plugin)
+        return SchedulerShuffleState(
+            run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for),
+            participating_workers=set(worker_for.values()),
+        )
+
+    @abc.abstractmethod
+    def _pin_output_workers(
+        self, plugin: ShuffleSchedulerPlugin
+    ) -> dict[_T_partition_id, str]:
+        """TODO"""
+
+    @abc.abstractmethod
+    def initialize_run_on_worker(
+        self,
+        run_id: int,
+        worker_for: dict[_T_partition_id, str],
+        plugin: ShuffleWorkerPlugin,
+    ) -> ShuffleRun:
+        """TODO"""
Review comment: All you need to do now is implement concrete subclasses of `ShuffleSpec`.
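Picking up that review comment, here is a minimal sketch of such a subclass. `ListShuffleSpec` and `ListShuffleRun` are hypothetical names, not part of this PR; only the two abstract methods need overriding, and the constructor arguments mirror `ArrayRechunkSpec.initialize_run_on_worker` further down (imports from this PR's modules omitted):

    @dataclass(frozen=True)
    class ListShuffleSpec(ShuffleSpec[int]):
        npartitions: int

        def _pin_output_workers(
            self, plugin: ShuffleSchedulerPlugin
        ) -> dict[int, str]:
            # Map every integer output partition to a worker address via the
            # scheduler plugin's pinning helper; round-robin for simplicity.
            return plugin._pin_output_workers(
                self.id,
                range(self.npartitions),
                lambda part, workers: workers[part % len(workers)],
            )

        def initialize_run_on_worker(
            self,
            run_id: int,
            worker_for: dict[int, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            # Build the worker-side run from plugin-owned resources.
            return ListShuffleRun(
                worker_for=worker_for,
                id=self.id,
                run_id=run_id,
                directory=os.path.join(
                    plugin.worker.local_directory, f"shuffle-{self.id}-{run_id}"
                ),
                executor=plugin._executor,
                local_address=plugin.worker.address,
                rpc=plugin.worker.rpc,
                scheduler=plugin.worker.scheduler,
                memory_limiter_disk=plugin.memory_limiter_disk,
                memory_limiter_comms=plugin.memory_limiter_comms,
            )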
+
+
+@dataclass(eq=False)
+class SchedulerShuffleState(Generic[_T_partition_id]):
+    run_spec: ShuffleRunSpec
     participating_workers: set[str]
     _archived_by: str | None = field(default=None, init=False)

-    @abc.abstractmethod
-    def to_msg(self) -> dict[str, Any]:
-        """Transform the shuffle state into a JSON-serializable message"""
+    @property
+    def id(self) -> ShuffleId:
+        return self.run_spec.id
+
+    @property
+    def run_id(self) -> int:
+        return self.run_spec.run_id

     def __str__(self) -> str:
         return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
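Taken together, the scheduler-side flow becomes: receive a spec from a worker, let the spec pin outputs and mint a fresh run, then hand the immutable `ShuffleRunSpec` back to workers. A rough sketch of how a scheduler plugin might consume this API; `get_or_create` and `active_shuffles` are assumptions for illustration, not code from this PR:

    def get_or_create(self, spec: ShuffleSpec, worker: str) -> ShuffleRunSpec:
        try:
            state = self.active_shuffles[spec.id]
        except KeyError:
            # First participant: pin output partitions and allocate a run_id.
            state = spec.create_new_run(plugin=self)
            self.active_shuffles[spec.id] = state
        state.participating_workers.add(worker)
        return state.run_spec

The second file in this diff applies the same pattern to array rechunking.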
@@ -96,13 +96,15 @@

 from __future__ import annotations

+import os
 import pickle
 from collections import defaultdict
 from collections.abc import Callable, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple
+from itertools import product
+from typing import TYPE_CHECKING, Any, NamedTuple

 import dask
 from dask.base import tokenize
@@ -114,23 +116,21 @@
     NDIndex,
     ShuffleId,
     ShuffleRun,
-    ShuffleState,
-    ShuffleType,
-    barrier_key,
+    ShuffleSpec,
     get_worker_plugin,
 )
 from distributed.shuffle._limiter import ResourceLimiter
-from distributed.shuffle._shuffle import shuffle_barrier
+from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
+from distributed.shuffle._shuffle import barrier_key, shuffle_barrier
+from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 from distributed.sizeof import sizeof

 if TYPE_CHECKING:
     import numpy as np
-    import pandas as pd
     from typing_extensions import TypeAlias

     import dask.array as da


 ChunkedAxis: TypeAlias = tuple[float, ...]  # chunks must either be an int or NaN
 ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...]
 NDSlice: TypeAlias = tuple[slice, ...]
@@ -147,10 +147,7 @@
         return get_worker_plugin().add_partition(
             input,
             partition_id=input_chunk,
-            shuffle_id=id,
-            type=ShuffleType.ARRAY_RECHUNK,
-            new=new,
-            old=old,
+            spec=ArrayRechunkSpec(id=id, new=new, old=old),
         )
     except Exception as e:
         raise RuntimeError(f"rechunk_transfer failed during shuffle {id}") from e
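On the worker side, the spec carried by this first `add_partition` call is what ultimately materializes the run. A hypothetical worker-plugin counterpart (the surrounding names are assumed, not from this PR) could look like:

    # run_spec is the ShuffleRunSpec the scheduler plugin returned for this id.
    shuffle_run = run_spec.spec.initialize_run_on_worker(
        run_id=run_spec.run_id,
        worker_for=run_spec.worker_for,
        plugin=worker_plugin,  # the ShuffleWorkerPlugin instance
    )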
@@ -304,10 +301,9 @@

     Parameters
     ----------
+    # FIXME
     worker_for:
         A mapping partition_id -> worker_address.
-    output_workers:
-        A set of all participating worker (addresses).
     old:
         Existing chunking of the array per dimension.
     new:
@@ -338,7 +334,6 @@
     def __init__(
         self,
         worker_for: dict[NDIndex, str],
-        output_workers: set,
         old: ChunkedAxes,
         new: ChunkedAxes,
         id: ShuffleId,
@@ -354,7 +349,6 @@
         super().__init__(
             id=id,
             run_id=run_id,
-            output_workers=output_workers,
             local_address=local_address,
             directory=directory,
             executor=executor,
@@ -403,7 +397,9 @@
             repartitioned[id].append(shard)
         return {k: pickle.dumps(v) for k, v in repartitioned.items()}

-    async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int:
+    async def add_partition(
+        self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any
+    ) -> int:
         self.raise_if_closed()
         if self.transferred:
             raise RuntimeError(f"Cannot add more partitions to {self}")
@@ -441,47 +437,58 @@
         return self.run_id

     async def get_output_partition(
-        self, partition_id: NDIndex, key: str, meta: pd.DataFrame | None = None
+        self, partition_id: NDIndex, key: str, **kwargs: Any
Review comment: Is this switch to `**kwargs` [...]?
Reply: Future-proofing and making it completely transparent to the plugin.
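For illustration only (this call site does not appear in the diff): with `**kwargs` in the base signature, the worker plugin can forward backend-specific keyword arguments to every run uniformly, and runs that do not need them simply ignore them:

    # Hypothetical plugin call site: a DataFrame-based run would consume
    # `meta`, while ArrayRechunkRun swallows it via **kwargs.
    partition = await shuffle_run.get_output_partition(
        partition_id=partition_id, key=key, meta=meta
    )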
     ) -> np.ndarray:
         self.raise_if_closed()
-        assert meta is None
-        assert self.transferred, "`get_output_partition` called before barrier task"
+        if not self.transferred:
+            raise RuntimeError("`get_output_partition` called before barrier task")

         await self._ensure_output_worker(partition_id, key)

         await self.flush_receive()

         data = self._read_from_disk(partition_id)

-        def _() -> np.ndarray:
-            return convert_chunk(data)
-
-        return await self.offload(_)
+        return await self.offload(convert_chunk, data)

     def _get_assigned_worker(self, id: NDIndex) -> str:
         return self.worker_for[id]


-@dataclass(eq=False)
-class ArrayRechunkState(ShuffleState):
-    type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK
-    worker_for: dict[NDIndex, str]
-    old: ChunkedAxes
+@dataclass(frozen=True)
+class ArrayRechunkSpec(ShuffleSpec[NDIndex]):
     new: ChunkedAxes
+    old: ChunkedAxes

-    def to_msg(self) -> dict[str, Any]:
-        return {
-            "status": "OK",
-            "type": ArrayRechunkState.type,
-            "run_id": self.run_id,
-            "worker_for": self.worker_for,
-            "old": self.old,
-            "new": self.new,
-            "output_workers": self.output_workers,
-        }
+    def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[NDIndex, str]:
+        parts_out = product(*(range(len(c)) for c in self.new))
+        return plugin._pin_output_workers(
+            self.id, parts_out, _get_worker_for_hash_sharding
+        )
+
+    def initialize_run_on_worker(
+        self,
+        run_id: int,
+        worker_for: dict[NDIndex, str],
+        plugin: ShuffleWorkerPlugin,
+    ) -> ShuffleRun:
+        return ArrayRechunkRun(
+            worker_for=worker_for,
+            old=self.old,
+            new=self.new,
+            id=self.id,
+            run_id=run_id,
+            directory=os.path.join(
+                plugin.worker.local_directory,
+                f"shuffle-{self.id}-{run_id}",
+            ),
+            executor=plugin._executor,
+            local_address=plugin.worker.address,
+            rpc=plugin.worker.rpc,
+            scheduler=plugin.worker.scheduler,
+            memory_limiter_disk=plugin.memory_limiter_disk,
+            memory_limiter_comms=plugin.memory_limiter_comms,
+        )


-def get_worker_for_hash_sharding(
+def _get_worker_for_hash_sharding(
     output_partition: NDIndex, workers: Sequence[str]
 ) -> str:
     """Get address of target worker for this output partition using hash sharding"""
Review comment: FWIW, this is a known mypy bug.
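For context, a self-contained sketch (assumed, not code from the PR) of the `run_id` counter pattern that needs the `# type: ignore` above: mypy has trouble inferring `partial(next, itertools.count(1))` as a zero-argument factory returning `int`, even though it behaves that way at runtime.

    import itertools
    from dataclasses import dataclass, field
    from functools import partial

    @dataclass(frozen=True)
    class RunSpec:
        # Each instantiation draws the next value from a process-wide counter.
        run_id: int = field(
            init=False, default_factory=partial(next, itertools.count(1))
        )

    a, b = RunSpec(), RunSpec()
    assert (a.run_id, b.run_id) == (1, 2)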