Improved errors and reduced logging for P2P RPC calls (#8666)
hendrikmakait committed Jun 4, 2024
1 parent 366286e commit 7aea988
Showing 8 changed files with 164 additions and 95 deletions.
15 changes: 11 additions & 4 deletions distributed/shuffle/_comms.py
@@ -5,10 +5,10 @@

from dask.utils import parse_bytes

from distributed.core import ErrorMessage, OKMessage, clean_exception
from distributed.metrics import context_meter
from distributed.shuffle._disk import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils import log_errors


class CommShardsBuffer(ShardsBuffer):
@@ -53,7 +53,9 @@ class CommShardsBuffer(ShardsBuffer):

def __init__(
self,
send: Callable[[str, list[tuple[Any, Any]]], Awaitable[None]],
send: Callable[
[str, list[tuple[Any, Any]]], Awaitable[OKMessage | ErrorMessage]
],
memory_limiter: ResourceLimiter,
concurrency_limit: int = 10,
):
@@ -64,9 +66,14 @@ def __init__(
)
self.send = send

@log_errors
async def _process(self, address: str, shards: list[tuple[Any, Any]]) -> None:
"""Send one message off to a neighboring worker"""
# Consider boosting total_size a bit here to account for duplication
with context_meter.meter("send"):
await self.send(address, shards)
response = await self.send(address, shards)
status = response["status"]
if status == "error":
_, exc, tb = clean_exception(**response)
assert exc
raise exc.with_traceback(tb)
assert status == "OK"
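
Instead of fire-and-forget, `_process` now inspects the message returned by `send` and re-raises remote failures locally via `clean_exception`. Below is a minimal, self-contained sketch of that round trip in plain Python; `fake_send`, `process`, and the exact message fields are illustrative assumptions, not the distributed internals:

```python
import pickle
import traceback
from typing import Any


def fake_send(address: str, shards: list[tuple[Any, Any]]) -> dict[str, Any]:
    """Stand-in for the shuffle_receive RPC: reply with an OK or error message."""
    try:
        if not shards:
            raise ValueError("nothing to send")
        return {"status": "OK"}
    except Exception as exc:
        # Roughly what error_message() does: ship the failure back to the caller
        # instead of logging it on the receiving side.
        return {
            "status": "error",
            "exception": pickle.dumps(exc),
            "traceback_text": traceback.format_exc(),
        }


def process(address: str, shards: list[tuple[Any, Any]]) -> None:
    """Caller-side handling, analogous to CommShardsBuffer._process."""
    response = fake_send(address, shards)
    if response["status"] == "error":
        # Re-raise the remote failure locally, as clean_exception() enables above.
        raise pickle.loads(response["exception"])
    assert response["status"] == "OK"


process("tcp://worker:1234", [("partition-0", b"payload")])  # succeeds silently
# process("tcp://worker:1234", [])  # would re-raise the remote ValueError here
```
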
31 changes: 21 additions & 10 deletions distributed/shuffle/_core.py
@@ -29,13 +29,14 @@
from dask.typing import Key
from dask.utils import parse_timedelta

from distributed.core import PooledRPCCall
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import ShuffleClosedError
from distributed.shuffle._exceptions import P2PConsistencyError, ShuffleClosedError
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
@@ -59,6 +60,10 @@
_T = TypeVar("_T")


class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]


class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
@@ -199,7 +204,7 @@ async def barrier(self, run_ids: Sequence[int]) -> int:

async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> None:
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
@@ -209,7 +214,7 @@ async def send(

async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> None:
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
@@ -220,7 +225,7 @@ async def send(
else:
shards_or_bytes = shards

def _send() -> Coroutine[Any, Any, None]:
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)

return await retry(
@@ -302,11 +307,17 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))

async def receive(self, data: list[tuple[_T_partition_id, Any]] | bytes) -> None:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)

async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
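
`send` keeps the small-shard optimization: when the mean shard size is under 64 KiB, the shard list is pickled into one opaque blob and `receive` unpacks it again. A rough standalone sketch of that heuristic (the 65536-byte threshold mirrors the diff; the helpers and size measure here are simplified assumptions):

```python
from __future__ import annotations

import pickle
from typing import Any


def _mean_shard_size(shards: list[tuple[Any, Any]]) -> int:
    # Simplified size estimate; the real helper inspects buffer sizes directly.
    sizes = [len(payload) for _, payload in shards]
    return sum(sizes) // max(len(sizes), 1)


def pack_for_send(shards: list[tuple[Any, Any]]) -> list[tuple[Any, Any]] | bytes:
    if _mean_shard_size(shards) < 65536:
        # Don't send tiny buffers individually; merge them into one bytes blob.
        return pickle.dumps(shards)
    return shards


def unpack_on_receive(data: list[tuple[Any, Any]] | bytes) -> list[tuple[Any, Any]]:
    if isinstance(data, bytes):
        # Unpack the opaque blob, mirroring ShuffleRun.receive().
        return pickle.loads(data)
    return data


shards = [(0, b"a" * 100), (1, b"b" * 100)]
assert unpack_on_receive(pack_for_send(shards)) == shards
```
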
10 changes: 9 additions & 1 deletion distributed/shuffle/_exceptions.py
@@ -1,7 +1,15 @@
from __future__ import annotations


class ShuffleClosedError(RuntimeError):
class P2PIllegalStateError(RuntimeError):
pass


class P2PConsistencyError(RuntimeError):
pass


class ShuffleClosedError(P2PConsistencyError):
pass
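
Since `ShuffleClosedError` now derives from `P2PConsistencyError`, any handler that converts consistency errors into error messages also covers closed-shuffle failures, while `P2PIllegalStateError` stays a plain bug that should surface loudly. A tiny standalone illustration (the class bodies match the diff; the `handle` demo around them is assumed):

```python
class P2PIllegalStateError(RuntimeError):
    pass


class P2PConsistencyError(RuntimeError):
    pass


class ShuffleClosedError(P2PConsistencyError):
    pass


def handle(exc: Exception) -> str:
    """Mimic an RPC handler: consistency errors become replies, bugs propagate."""
    try:
        raise exc
    except P2PConsistencyError as e:
        return f"returned as error message: {e!r}"


print(handle(ShuffleClosedError("shuffle already closed")))  # caught via subclass
# handle(P2PIllegalStateError("broken invariant"))  # would propagate (a real bug)
```
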


129 changes: 78 additions & 51 deletions distributed/shuffle/_scheduler_plugin.py
@@ -8,18 +8,21 @@

from dask.typing import Key

from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors

@@ -98,77 +101,97 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
workers=list(shuffle.participating_workers),
)

def restrict_task(self, id: ShuffleId, run_id: int, key: Key, worker: str) -> dict:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
return {
"status": "error",
"message": f"Request stale, expected {run_id=} for {shuffle}",
}
elif shuffle.run_id < run_id:
return {
"status": "error",
"message": f"Request invalid, expected {run_id=} for {shuffle}",
}
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
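
The rewritten `restrict_task` distinguishes a stale request (the scheduler already tracks a newer run) from an invalid one (a run the scheduler has never handed out), and reports both as `P2PConsistencyError` rather than hand-built message dicts. A standalone sketch of just that comparison (the `ShuffleState` stub and function name are assumptions):

```python
from __future__ import annotations

from dataclasses import dataclass


class P2PConsistencyError(RuntimeError):
    pass


@dataclass
class ShuffleState:
    id: str
    run_id: int


def check_run_id(state: ShuffleState, run_id: int) -> None:
    """Raise if the requested run does not match the scheduler's current view."""
    if state.run_id > run_id:
        raise P2PConsistencyError(f"Request stale, expected {run_id=} for {state}")
    if state.run_id < run_id:
        raise P2PConsistencyError(f"Request invalid, expected {run_id=} for {state}")


check_run_id(ShuffleState("abc", run_id=3), run_id=3)    # matches, no error
# check_run_id(ShuffleState("abc", run_id=4), run_id=3)  # stale -> raises
```
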

def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)

def get(self, id: ShuffleId, worker: str) -> ToPickle[ShuffleRunSpec]:
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)

def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise RuntimeError(
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return ToPickle(state.run_spec)
return state.run_spec

def _create(self, spec: ShuffleSpec, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(spec.id)
self._raise_if_task_not_processing(key)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[spec.id] = state
self._shuffles[spec.id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
spec.id,
key,
worker,
)
return state.run_spec

def get_or_create(
self,
# FIXME: This should never be ToPickle[ShuffleSpec]
spec: ShuffleSpec | ToPickle[ShuffleSpec],
spec: ShuffleSpec,
key: Key,
worker: str,
) -> ToPickle[ShuffleRunSpec]:
# FIXME: Sometimes, this doesn't actually get pickled
if isinstance(spec, ToPickle):
spec = spec.data
) -> RunSpecMessage | ErrorMessage:
try:
return self.get(spec.id, worker)
run_spec = self._get(spec.id, worker)
except P2PConsistencyError as e:
return error_message(e)
except KeyError:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(spec.id)
self._raise_if_task_not_processing(key)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[spec.id] = state
self._shuffles[spec.id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
spec.id,
key,
worker,
)
return ToPickle(state.run_spec)
try:
run_spec = self._create(spec, key, worker)
except P2PConsistencyError as e:
return error_message(e)
return {"status": "OK", "run_spec": ToPickle(run_spec)}

def _raise_if_barrier_unknown(self, id: ShuffleId) -> None:
key = barrier_key(id)
try:
self.scheduler.tasks[key]
except KeyError:
raise RuntimeError(
raise P2PConsistencyError(
f"Barrier task with key {key!r} does not exist. This may be caused by "
"task fusion during graph generation. Please let us know that you ran "
"into this by leaving a comment at distributed#7816."
@@ -177,7 +200,9 @@ def _raise_if_barrier_unknown(self, id: ShuffleId) -> None:
def _raise_if_task_not_processing(self, key: Key) -> None:
task = self.scheduler.tasks[key]
if task.state != "processing":
raise RuntimeError(f"Expected {task} to be processing, is {task.state}.")
raise P2PConsistencyError(
f"Expected {task} to be processing, is {task.state}."
)

def _calculate_worker_for(self, spec: ShuffleSpec) -> dict[Any, str]:
"""Pin the outputs of a P2P shuffle to specific workers.
@@ -235,7 +260,7 @@ def _calculate_worker_for(self, spec: ShuffleSpec) -> dict[Any, str]:
if existing: # pragma: nocover
for shared_key in existing.keys() & current_worker_for.keys():
if existing[shared_key] != current_worker_for[shared_key]:
raise RuntimeError(
raise P2PIllegalStateError(
f"Failed to initialize shuffle {spec.id} because "
"it cannot align output partition mappings between "
f"existing shuffles {seen}. "
@@ -316,7 +341,7 @@ def _restart_recommendations(self, id: ShuffleId) -> Recs:
if barrier_task.state == "erred":
# This should never happen, a dependent of the barrier should already
# be `erred`
raise RuntimeError(
raise P2PIllegalStateError(
f"Expected dependents of {barrier_task=} to be 'erred' if "
"the barrier is."
) # pragma: no cover
@@ -326,7 +351,7 @@ def _restart_recommendations(self, id: ShuffleId) -> Recs:
if dt.state == "erred":
# This should never happen, a dependent of the barrier should already
# be `erred`
raise RuntimeError(
raise P2PIllegalStateError(
f"Expected barrier and its dependents to be "
f"'erred' if the barrier's dependency {dt} is."
) # pragma: no cover
@@ -366,7 +391,9 @@ def remove_worker(
shuffle_id,
stimulus_id,
)
exception = RuntimeError(f"Worker {worker} left during active {shuffle}")
exception = P2PConsistencyError(
f"Worker {worker} left during active {shuffle}"
)
self._fail_on_workers(shuffle, str(exception))
self._clean_on_scheduler(shuffle_id, stimulus_id)
