Make P2P shuffle extensible #8096

Merged (7 commits, Aug 17, 2023)
Changes from 5 commits
89 changes: 65 additions & 24 deletions distributed/shuffle/_core.py
@@ -10,7 +10,8 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, TypeVar
from functools import partial
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
@@ -21,27 +22,28 @@
from distributed.shuffle._limiter import ResourceLimiter

if TYPE_CHECKING:
import pandas as pd
# TODO import from typing (requires Python >=3.10)
from typing_extensions import TypeAlias

# avoid circular dependencies
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin

# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]


_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")

NDIndex: TypeAlias = tuple[int, ...]

ShuffleId = NewType("ShuffleId", str)


class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
def __init__(
self,
id: ShuffleId,
run_id: int,
output_workers: set[str],
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
@@ -52,7 +54,6 @@
):
self.id = id
self.run_id = run_id
self.output_workers = output_workers
self.local_address = local_address
self.executor = executor
self.rpc = rpc
@@ -215,7 +216,7 @@

@abc.abstractmethod
async def get_output_partition(
self, partition_id: _T_partition_id, key: str, meta: pd.DataFrame | None = None
self, partition_id: _T_partition_id, key: str, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""

@@ -230,13 +231,12 @@
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle") # type: ignore
if plugin is None:
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:

Codecov warning: added line distributed/shuffle/_core.py#L236 was not covered by tests.
raise RuntimeError(
f"The worker {worker.address} does not have a ShuffleExtension. "
"Is pandas installed on the worker?"
)
return plugin
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e


_BARRIER_PREFIX = "shuffle-barrier-"
@@ -256,19 +256,60 @@
ARRAY_RECHUNK = "ArrayRechunk"


@dataclass(eq=False)
class ShuffleState(abc.ABC):
_run_id_iterator: ClassVar[itertools.count] = itertools.count(1)
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1))) # type: ignore
Contributor: FWIW, this is this mypy bug

spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]

@property
def id(self) -> ShuffleId:
return self.spec.id
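
The run_id default above is the construct the mypy note refers to: a frozen dataclass that draws each instance's id from a single shared itertools.count. A minimal standalone sketch of the same idiom (the class name is invented for illustration):

import itertools
from dataclasses import dataclass, field
from functools import partial


@dataclass(frozen=True)
class _DemoRunSpec:
    # Each instantiation calls the partial, which advances the shared counter;
    # mypy currently mis-types partial(next, ...) as a default_factory, hence the ignore.
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))  # type: ignore


assert _DemoRunSpec().run_id == 1
assert _DemoRunSpec().run_id == 2  # ids keep increasing across instances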


@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
run_id: int
output_workers: set[str]

def create_new_run(
self,
plugin: ShuffleSchedulerPlugin,
) -> SchedulerShuffleState:
worker_for = self._pin_output_workers(plugin)
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for),
participating_workers=set(worker_for.values()),
)

@abc.abstractmethod
def _pin_output_workers(
self, plugin: ShuffleSchedulerPlugin
) -> dict[_T_partition_id, str]:
"""TODO"""

@abc.abstractmethod
def initialize_run_on_worker(
self,
run_id: int,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""TODO"""
Member Author: All you need to do now is implement concrete subclasses of ShuffleSpec and ShuffleRun.
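
A hedged, self-contained sketch of that contract, with every name below invented for illustration (the real array implementation, ArrayRechunkSpec, appears in the _rechunk.py diff further down): a concrete spec pins output partitions to workers on the scheduler side and constructs its run object on each worker.

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class _ToySpec:
    id: str
    npartitions: int

    def _pin_output_workers(self, workers: list[str]) -> dict[int, str]:
        # Scheduler side: round-robin the output partitions over the available workers.
        return {i: workers[i % len(workers)] for i in range(self.npartitions)}

    def initialize_run_on_worker(self, run_id: int, worker_for: dict[int, str]) -> _ToyRun:
        # Worker side: build the object that shards, stores, and reassembles data.
        return _ToyRun(id=self.id, run_id=run_id, worker_for=worker_for)


@dataclass
class _ToyRun:
    id: str
    run_id: int
    worker_for: dict[int, str]


spec = _ToySpec(id="demo", npartitions=4)
worker_for = spec._pin_output_workers(["tcp://a:1", "tcp://b:2"])
run = spec.initialize_run_on_worker(run_id=1, worker_for=worker_for)
assert run.worker_for[2] == "tcp://a:1"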



@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)

@abc.abstractmethod
def to_msg(self) -> dict[str, Any]:
"""Transform the shuffle state into a JSON-serializable message"""
@property
def id(self) -> ShuffleId:
return self.run_spec.id

@property
def run_id(self) -> int:
return self.run_spec.run_id

def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
91 changes: 49 additions & 42 deletions distributed/shuffle/_rechunk.py
@@ -96,13 +96,15 @@

from __future__ import annotations

import os
import pickle
from collections import defaultdict
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple

import dask
from dask.base import tokenize
@@ -114,23 +116,21 @@
NDIndex,
ShuffleId,
ShuffleRun,
ShuffleState,
ShuffleType,
barrier_key,
ShuffleSpec,
get_worker_plugin,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._shuffle import shuffle_barrier
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
from distributed.shuffle._shuffle import barrier_key, shuffle_barrier
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof

if TYPE_CHECKING:
import numpy as np
import pandas as pd
from typing_extensions import TypeAlias

import dask.array as da


ChunkedAxis: TypeAlias = tuple[float, ...] # chunks must either be an int or NaN
ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...]
NDSlice: TypeAlias = tuple[slice, ...]
@@ -147,10 +147,7 @@
return get_worker_plugin().add_partition(
input,
partition_id=input_chunk,
shuffle_id=id,
type=ShuffleType.ARRAY_RECHUNK,
new=new,
old=old,
spec=ArrayRechunkSpec(id=id, new=new, old=old),
)
except Exception as e:
raise RuntimeError(f"rechunk_transfer failed during shuffle {id}") from e
@@ -304,10 +301,9 @@

Parameters
----------
# FIXME
worker_for:
A mapping partition_id -> worker_address.
output_workers:
A set of all participating worker (addresses).
old:
Existing chunking of the array per dimension.
new:
@@ -338,7 +334,6 @@
def __init__(
self,
worker_for: dict[NDIndex, str],
output_workers: set,
old: ChunkedAxes,
new: ChunkedAxes,
id: ShuffleId,
@@ -354,7 +349,6 @@
super().__init__(
id=id,
run_id=run_id,
output_workers=output_workers,
local_address=local_address,
directory=directory,
executor=executor,
@@ -403,7 +397,9 @@
repartitioned[id].append(shard)
return {k: pickle.dumps(v) for k, v in repartitioned.items()}

async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int:
async def add_partition(
self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
@@ -441,47 +437,58 @@
return self.run_id

async def get_output_partition(
self, partition_id: NDIndex, key: str, meta: pd.DataFrame | None = None
self, partition_id: NDIndex, key: str, **kwargs: Any
Contributor: Is this switch to kwargs future-proofing?
Member Author: Future-proofing and making it completely transparent to the plugin.

) -> np.ndarray:
self.raise_if_closed()
assert meta is None
assert self.transferred, "`get_output_partition` called before barrier task"
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")

Codecov warning: added line distributed/shuffle/_rechunk.py#L444 was not covered by tests.

await self._ensure_output_worker(partition_id, key)

await self.flush_receive()

data = self._read_from_disk(partition_id)

def _() -> np.ndarray:
return convert_chunk(data)

return await self.offload(_)
return await self.offload(convert_chunk, data)
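
On the kwargs exchange above: because the base signature accepts **kwargs, partition-type-specific arguments (such as the old meta argument used for DataFrames) can be forwarded by the generic worker plugin without it knowing what they mean. A hypothetical illustration; both classes below are invented for this sketch:

import asyncio
from typing import Any


class _DemoDataFrameRun:
    async def get_output_partition(self, partition_id: int, key: str, **kwargs: Any) -> tuple:
        # A DataFrame-based run can pick partition-specific extras out of kwargs ...
        return ("dataframe", partition_id, kwargs.get("meta"))


class _DemoArrayRun:
    async def get_output_partition(self, partition_id: int, key: str, **kwargs: Any) -> tuple:
        # ... while an array-based run simply ignores them.
        return ("array", partition_id)


# The caller (standing in for the worker plugin) forwards the same keyword
# arguments to either run without knowing which extras are meaningful.
assert asyncio.run(_DemoArrayRun().get_output_partition(0, "k", meta="ignored")) == ("array", 0)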

def _get_assigned_worker(self, id: NDIndex) -> str:
return self.worker_for[id]


@dataclass(eq=False)
class ArrayRechunkState(ShuffleState):
type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK
worker_for: dict[NDIndex, str]
old: ChunkedAxes
@dataclass(frozen=True)
class ArrayRechunkSpec(ShuffleSpec[NDIndex]):
new: ChunkedAxes
old: ChunkedAxes

def to_msg(self) -> dict[str, Any]:
return {
"status": "OK",
"type": ArrayRechunkState.type,
"run_id": self.run_id,
"worker_for": self.worker_for,
"old": self.old,
"new": self.new,
"output_workers": self.output_workers,
}
def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[NDIndex, str]:
parts_out = product(*(range(len(c)) for c in self.new))
return plugin._pin_output_workers(
self.id, parts_out, _get_worker_for_hash_sharding
)

def initialize_run_on_worker(
self,
run_id: int,
worker_for: dict[NDIndex, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
return ArrayRechunkRun(
worker_for=worker_for,
old=self.old,
new=self.new,
id=self.id,
run_id=run_id,
directory=os.path.join(
plugin.worker.local_directory,
f"shuffle-{self.id}-{run_id}",
),
executor=plugin._executor,
local_address=plugin.worker.address,
rpc=plugin.worker.rpc,
scheduler=plugin.worker.scheduler,
memory_limiter_disk=plugin.memory_limiter_disk,
memory_limiter_comms=plugin.memory_limiter_comms,
)


def get_worker_for_hash_sharding(
def _get_worker_for_hash_sharding(
output_partition: NDIndex, workers: Sequence[str]
) -> str:
"""Get address of target worker for this output partition using hash sharding"""