Patch summary (free text before the first "diff --git" header):
  * ShuffleWorkerPlugin._close_shuffle_run now runs shuffle.close() inside
    log_errors() and a try/finally, so the run is removed from self._runs and
    waiters on _runs_cleanup_condition are notified even when close() raises.
  * _refresh_shuffle schedules ShuffleWorkerPlugin._close_shuffle_run instead
    of a duplicated inline coroutine.
  * Adds test_exception_on_close_cleans_up, which patches in a shuffle run
    whose close() raises and checks the error is logged and the worker is
    cleaned up.
NOTE(review): leading whitespace inside the hunks was reconstructed from Python
syntax and the hunk line counts; verify against the target tree before applying.
NOTE(review): the context line building RuntimeError("{existing!r} stale, ...")
has no f-prefix, so the placeholders are emitted literally; it is left as-is
here because it is unchanged context and should be fixed in a separate change.

diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py
index f41c9b5dfd..953042fc81 100644
--- a/distributed/shuffle/_worker_plugin.py
+++ b/distributed/shuffle/_worker_plugin.py
@@ -104,10 +104,13 @@ async def shuffle_inputs_done(self, shuffle_id: ShuffleId, run_id: int) -> None:
         await shuffle.inputs_done()
 
     async def _close_shuffle_run(self, shuffle: ShuffleRun) -> None:
-        await shuffle.close()
-        async with self._runs_cleanup_condition:
-            self._runs.remove(shuffle)
-            self._runs_cleanup_condition.notify_all()
+        with log_errors():
+            try:
+                await shuffle.close()
+            finally:
+                async with self._runs_cleanup_condition:
+                    self._runs.remove(shuffle)
+                    self._runs_cleanup_condition.notify_all()
 
     def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None:
         """Fails the shuffle run with the message as exception and triggers cleanup.
@@ -277,15 +280,9 @@ async def _refresh_shuffle(
                     RuntimeError("{existing!r} stale, expected run_id=={run_id}")
                 )
 
-                async def _(
-                    extension: ShuffleWorkerPlugin, shuffle: ShuffleRun
-                ) -> None:
-                    await shuffle.close()
-                    async with extension._runs_cleanup_condition:
-                        extension._runs.remove(shuffle)
-                        extension._runs_cleanup_condition.notify_all()
-
-                self.worker._ongoing_background_tasks.call_soon(_, self, existing)
+                self.worker._ongoing_background_tasks.call_soon(
+                    ShuffleWorkerPlugin._close_shuffle_run, self, existing
+                )
         shuffle: ShuffleRun = result.spec.create_run_on_worker(
             result.run_id, result.worker_for, self
         )
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 68f97a2558..90a5e5def6 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -3,6 +3,7 @@
 import asyncio
 import io
 import itertools
+import logging
 import os
 import random
 import shutil
@@ -610,6 +611,34 @@ async def test_closed_bystanding_worker_during_shuffle(c, s, w1, w2, w3):
         await check_scheduler_cleanup(s)
 
 
+class RaiseOnCloseShuffleRun(DataFrameShuffleRun):
+    async def close(self, *args, **kwargs):
+        raise RuntimeError("test-exception-on-close")
+
+
+@mock.patch(
+    "distributed.shuffle._shuffle.DataFrameShuffleRun",
+    RaiseOnCloseShuffleRun,
+)
+@gen_cluster(client=True, nthreads=[])
+async def test_exception_on_close_cleans_up(c, s, caplog):
+    # Ensure that everything is cleaned up and does not lock up if an exception
+    # is raised during shuffle close.
+    with caplog.at_level(logging.ERROR):
+        async with Worker(s.address) as w:
+            df = dask.datasets.timeseries(
+                start="2000-01-01",
+                end="2000-01-10",
+                dtypes={"x": float, "y": float},
+                freq="10 s",
+            )
+            shuffled = dd.shuffle.shuffle(df, "x", shuffle="p2p")
+            await c.compute([shuffled, df], sync=True)
+
+    assert any("test-exception-on-close" in record.message for record in caplog.records)
+    await check_worker_cleanup(w, closed=True)
+
+
 class BlockedInputsDoneShuffle(DataFrameShuffleRun):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)