diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 0a511145bc..048fdfb310 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -115,10 +115,14 @@ def get(self, id: ShuffleId, worker: str) -> ToPickle[ShuffleRunSpec]: def get_or_create( self, - spec: ShuffleSpec, + # FIXME: This should never be ToPickle[ShuffleSpec] + spec: ShuffleSpec | ToPickle[ShuffleSpec], key: str, worker: str, ) -> ToPickle[ShuffleRunSpec]: + # FIXME: Sometimes, this doesn't actually get pickled + if isinstance(spec, ToPickle): + spec = spec.data try: return self.get(spec.id, worker) except KeyError: diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index bc8b1f3d4d..f41c9b5dfd 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -218,7 +218,7 @@ async def _get_or_create_shuffle( if shuffle is None: shuffle = await self._refresh_shuffle( shuffle_id=spec.id, - spec=ToPickle(spec), + spec=spec, key=key, ) @@ -239,7 +239,7 @@ async def _refresh_shuffle( async def _refresh_shuffle( self, shuffle_id: ShuffleId, - spec: ToPickle, + spec: ShuffleSpec, key: str, ) -> ShuffleRun: ... @@ -247,10 +247,11 @@ async def _refresh_shuffle( async def _refresh_shuffle( self, shuffle_id: ShuffleId, - spec: ToPickle | None = None, + spec: ShuffleSpec | None = None, key: str | None = None, ) -> ShuffleRun: - result: ShuffleRunSpec + # FIXME: This should never be ToPickle[ShuffleRunSpec] + result: ShuffleRunSpec | ToPickle[ShuffleRunSpec] if spec is None: result = await self.worker.scheduler.shuffle_get( id=shuffle_id, @@ -258,14 +259,12 @@ async def _refresh_shuffle( ) else: result = await self.worker.scheduler.shuffle_get_or_create( - spec=spec, + spec=ToPickle(spec), key=key, worker=self.worker.address, ) - # if result["status"] == "error": - # raise RuntimeError(result["message"]) - # assert result["status"] == "OK" - + if isinstance(result, ToPickle): + result = result.data if self.closed: raise ShuffleClosedError(f"{self} has already been closed") if shuffle_id in self.shuffles: @@ -287,7 +286,6 @@ async def _( extension._runs_cleanup_condition.notify_all() self.worker._ongoing_background_tasks.call_soon(_, self, existing) - shuffle: ShuffleRun = result.spec.create_run_on_worker( result.run_id, result.worker_for, self ) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 0e73cd5efd..4e62952d21 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -23,7 +23,7 @@ dd = pytest.importorskip("dask.dataframe") import dask -from dask.distributed import Event, Nanny, Worker +from dask.distributed import Event, LocalCluster, Nanny, Worker from dask.utils import stringify from distributed.client import Client @@ -187,6 +187,28 @@ async def test_basic_integration(c, s, a, b, lose_annotations, npartitions): await check_scheduler_cleanup(s) +@pytest.mark.parametrize("processes", [True, False]) +@gen_test() +async def test_basic_integration_local_cluster(processes): + async with LocalCluster( + n_workers=2, + processes=processes, + asynchronous=True, + dashboard_address=":0", + ) as cluster: + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + c = cluster.get_client() + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + x, y = c.compute([df, out]) + x, y = await c.gather([x, y]) + dd.assert_eq(x, y) + + @pytest.mark.parametrize("npartitions", [None, 1, 20]) @gen_cluster(client=True) async def test_shuffle_with_array_conversion(c, s, a, b, lose_annotations, npartitions):