diff --git a/distributed/shuffle/_scheduler_extension.py b/distributed/shuffle/_scheduler_extension.py index 429125addb..8c421255b5 100644 --- a/distributed/shuffle/_scheduler_extension.py +++ b/distributed/shuffle/_scheduler_extension.py @@ -45,6 +45,9 @@ class ShuffleState(abc.ABC): def to_msg(self) -> dict[str, Any]: """Transform the shuffle state into a JSON-serializable message""" + def __str__(self) -> str: + return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" + @dataclass class DataFrameShuffleState(ShuffleState): @@ -119,6 +122,7 @@ def shuffle_ids(self) -> set[ShuffleId]: async def barrier(self, id: ShuffleId, run_id: int) -> None: shuffle = self.states[id] + assert shuffle.run_id == run_id, f"{run_id=} does not match {shuffle}" msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id} await self.scheduler.broadcast( msg=msg, workers=list(shuffle.participating_workers) @@ -126,8 +130,16 @@ async def barrier(self, id: ShuffleId, run_id: int) -> None: def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict: shuffle = self.states[id] - if shuffle.run_id != run_id: - return {"status": "error", "message": "Stale shuffle"} + if shuffle.run_id > run_id: + return { + "status": "error", + "message": f"Request stale, expected {run_id=} for {shuffle}", + } + elif shuffle.run_id < run_id: + return { + "status": "error", + "message": f"Request invalid, expected {run_id=} for {shuffle}", + } ts = self.scheduler.tasks[key] self._set_restriction(ts, worker) return {"status": "OK"} @@ -298,9 +310,7 @@ def remove_worker(self, scheduler: Scheduler, worker: str) -> None: for shuffle_id, shuffle in self.states.items(): if worker not in shuffle.participating_workers: continue - exception = RuntimeError( - f"Worker {worker} left during active shuffle {shuffle_id}" - ) + exception = RuntimeError(f"Worker {worker} left during active {shuffle}") self.erred_shuffles[shuffle_id] = exception self._fail_on_workers(shuffle, str(exception)) @@ -335,7 +345,7 @@ def transition( shuffle = self.states[shuffle_id] except KeyError: return - self._fail_on_workers(shuffle, message=f"Shuffle {shuffle_id} forgotten") + self._fail_on_workers(shuffle, message=f"{shuffle} forgotten") self._clean_on_scheduler(shuffle_id) def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None: diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 16497bcf1a..d284ee3b86 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -99,7 +99,10 @@ def __init__( self._closed_event = asyncio.Event() def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.id}[{self.run_id}] on {self.local_address}>" + return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>" + + def __str__(self) -> str: + return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}" def __hash__(self) -> int: return self.run_id @@ -162,9 +165,7 @@ def raise_if_closed(self) -> None: if self.closed: if self._exception: raise self._exception - raise ShuffleClosedError( - f"Shuffle {self.id} has been closed on {self.local_address}" - ) + raise ShuffleClosedError(f"{self} has already been closed") async def inputs_done(self) -> None: self.raise_if_closed() @@ -346,7 +347,7 @@ async def _receive(self, data: list[tuple[ArrayRechunkShardID, bytes]]) -> None: async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int: self.raise_if_closed() if self.transferred: - raise RuntimeError(f"Cannot add more partitions to shuffle {self}") + raise RuntimeError(f"Cannot add more partitions to {self}") def _() -> dict[str, list[tuple[ArrayRechunkShardID, bytes]]]: """Return a mapping of worker addresses to a list of tuples of shard IDs @@ -511,7 +512,7 @@ def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, list[bytes]]: async def add_partition(self, data: pd.DataFrame, partition_id: int) -> int: self.raise_if_closed() if self.transferred: - raise RuntimeError(f"Cannot add more partitions to shuffle {self}") + raise RuntimeError(f"Cannot add more partitions to {self}") def _() -> dict[str, list[tuple[int, bytes]]]: out = split_by_worker( @@ -586,6 +587,12 @@ def __init__(self, worker: Worker) -> None: self.closed = False self._executor = ThreadPoolExecutor(self.worker.state.nthreads) + def __str__(self) -> str: + return f"ShuffleWorkerExtension on {self.worker.address}" + + def __repr__(self) -> str: + return f"" + # Handlers ########## # NOTE: handlers are not threadsafe, but they're called from async comms, so that's okay @@ -695,11 +702,11 @@ async def _get_shuffle_run( shuffle = await self._refresh_shuffle( shuffle_id=shuffle_id, ) - if run_id < shuffle.run_id: - raise RuntimeError("Stale shuffle") - elif run_id > shuffle.run_id: - # This should never happen - raise RuntimeError("Invalid shuffle state") + + if shuffle.run_id > run_id: + raise RuntimeError(f"{run_id=} stale, got {shuffle}") + elif shuffle.run_id < run_id: + raise RuntimeError(f"{run_id=} invalid, got {shuffle}") if shuffle._exception: raise shuffle._exception @@ -729,9 +736,7 @@ async def _get_or_create_shuffle( ) if self.closed: - raise ShuffleClosedError( - f"{self.__class__.__name__} already closed on {self.worker.address}" - ) + raise ShuffleClosedError(f"{self} has already been closed") if shuffle._exception: raise shuffle._exception return shuffle @@ -790,9 +795,7 @@ async def _refresh_shuffle( assert result["status"] == "OK" if self.closed: - raise ShuffleClosedError( - f"{self.__class__.__name__} already closed on {self.worker.address}" - ) + raise ShuffleClosedError(f"{self} has already been closed") if shuffle_id in self.shuffles: existing = self.shuffles[shuffle_id] if existing.run_id >= result["run_id"]: diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 66382babdd..b11b678c79 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1601,7 +1601,7 @@ async def test_shuffle_run_consistency(c, s, a): # This should never occur, but fetching an ID larger than the ID available on # the scheduler should result in an error. - with pytest.raises(RuntimeError, match="Invalid shuffle state"): + with pytest.raises(RuntimeError, match="invalid"): await worker_ext._get_shuffle_run(shuffle_id, shuffle_dict["run_id"] + 1) # Finish first execution @@ -1628,7 +1628,7 @@ async def test_shuffle_run_consistency(c, s, a): assert await worker_ext._get_shuffle_run(shuffle_id, new_shuffle_dict["run_id"]) # Fetching a stale run from a worker aware of the new run raises an error - with pytest.raises(RuntimeError, match="Stale shuffle"): + with pytest.raises(RuntimeError, match="stale"): await worker_ext._get_shuffle_run(shuffle_id, shuffle_dict["run_id"]) worker_ext.block_barrier.set()