Skip to content

Commit

Permalink
Fix P2P worker cleanup (#7981)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait committed Jul 13, 2023
1 parent 7b21399 commit 8e3e0f6
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 242 deletions.
19 changes: 12 additions & 7 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ async def test_create_with_client(c, s):

@gen_cluster(client=True, nthreads=[])
async def test_remove_with_client(c, s):
existing_plugins = s.worker_plugins.copy()
n_existing_plugins = len(existing_plugins)
await c.register_worker_plugin(MyPlugin(123), name="foo")
await c.register_worker_plugin(MyPlugin(546), name="bar")

Expand All @@ -62,18 +64,18 @@ async def test_remove_with_client(c, s):
assert worker._my_plugin_status == "teardown"

# check that on the scheduler registered worker plugins we only have 'bar'
assert len(s.worker_plugins) == 1
assert len(s.worker_plugins) == n_existing_plugins + 1
assert "bar" in s.worker_plugins

# check on the worker plugins that we only have 'bar'
assert len(worker.plugins) == 1
assert len(worker.plugins) == n_existing_plugins + 1
assert "bar" in worker.plugins

# let's remove 'bar' and we should have none worker plugins
await c.unregister_worker_plugin("bar")
assert worker._my_plugin_status == "teardown"
assert not s.worker_plugins
assert not worker.plugins
assert s.worker_plugins == existing_plugins
assert len(worker.plugins) == n_existing_plugins


@gen_cluster(client=True, nthreads=[])
Expand All @@ -87,7 +89,9 @@ async def test_remove_with_client_raises(c, s):

@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
async def test_create_on_construction(c, s, a, b):
assert len(a.plugins) == len(b.plugins) == 1
assert len(a.plugins) == len(b.plugins)
assert any(isinstance(plugin, MyPlugin) for plugin in a.plugins.values())
assert any(isinstance(plugin, MyPlugin) for plugin in b.plugins.values())
assert a._my_plugin_status == "setup"
assert a._my_plugin_data == 5

Expand Down Expand Up @@ -195,9 +199,10 @@ async def test_default_name(c, s, w):
class MyCustomPlugin(WorkerPlugin):
pass

n_existing_plugins = len(w.plugins)
await c.register_worker_plugin(MyCustomPlugin())
assert len(w.plugins) == 1
assert next(iter(w.plugins)).startswith("MyCustomPlugin-")
assert len(w.plugins) == n_existing_plugins + 1
assert any(name.startswith("MyCustomPlugin-") for name in w.plugins)


@gen_cluster(client=True, nthreads=[("", 1)])
Expand Down
4 changes: 2 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
from distributed.recreate_tasks import ReplayTaskScheduler
from distributed.security import Security
from distributed.semaphore import SemaphoreExtension
from distributed.shuffle import ShuffleSchedulerExtension
from distributed.shuffle import ShuffleSchedulerPlugin
from distributed.spans import SpansSchedulerExtension
from distributed.stealing import WorkStealing
from distributed.utils import (
Expand Down Expand Up @@ -170,7 +170,7 @@
"events": EventExtension,
"amm": ActiveMemoryManagerExtension,
"memory_sampler": MemorySamplerExtension,
"shuffle": ShuffleSchedulerExtension,
"shuffle": ShuffleSchedulerPlugin,
"spans": SpansSchedulerExtension,
"stealing": WorkStealing,
}
Expand Down
8 changes: 4 additions & 4 deletions distributed/shuffle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from distributed.shuffle._arrow import check_minimal_arrow_version
from distributed.shuffle._merge import HashJoinP2PLayer, hash_join_p2p
from distributed.shuffle._rechunk import rechunk_p2p
from distributed.shuffle._scheduler_extension import ShuffleSchedulerExtension
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
from distributed.shuffle._shuffle import P2PShuffleLayer, rearrange_by_column_p2p
from distributed.shuffle._worker_extension import ShuffleWorkerExtension
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

__all__ = [
"check_minimal_arrow_version",
Expand All @@ -14,6 +14,6 @@
"P2PShuffleLayer",
"rearrange_by_column_p2p",
"rechunk_p2p",
"ShuffleSchedulerExtension",
"ShuffleWorkerExtension",
"ShuffleSchedulerPlugin",
"ShuffleWorkerPlugin",
]
4 changes: 2 additions & 2 deletions distributed/shuffle/_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from distributed.shuffle._shuffle import (
ShuffleId,
_get_worker_extension,
_get_worker_plugin,
barrier_key,
shuffle_barrier,
shuffle_transfer,
Expand Down Expand Up @@ -167,7 +167,7 @@ def merge_unpack(
):
from dask.dataframe.multi import merge_chunk

ext = _get_worker_extension()
ext = _get_worker_plugin()
# If the partition is empty, it doesn't contain the hash column name
left = ext.get_output_partition(
shuffle_id_left, barrier_left, output_partition, meta=meta_left
Expand Down
6 changes: 3 additions & 3 deletions distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from distributed.shuffle._shuffle import (
ShuffleId,
ShuffleType,
_get_worker_extension,
_get_worker_plugin,
barrier_key,
shuffle_barrier,
)
Expand All @@ -36,7 +36,7 @@ def rechunk_transfer(
old: ChunkedAxes,
) -> int:
try:
return _get_worker_extension().add_partition(
return _get_worker_plugin().add_partition(
input,
partition_id=input_chunk,
shuffle_id=id,
Expand All @@ -52,7 +52,7 @@ def rechunk_unpack(
id: ShuffleId, output_chunk: NDIndex, barrier_run_id: int
) -> np.ndarray:
try:
return _get_worker_extension().get_output_partition(
return _get_worker_plugin().get_output_partition(
id, barrier_run_id, output_chunk
)
except Reschedule as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
from typing import TYPE_CHECKING, Any, ClassVar

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.protocol.pickle import dumps
from distributed.shuffle._rechunk import ChunkedAxes, NDIndex
from distributed.shuffle._shuffle import (
ShuffleId,
ShuffleType,
barrier_key,
id_from_key,
)
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

if TYPE_CHECKING:
from distributed.scheduler import (
Expand Down Expand Up @@ -85,16 +87,16 @@ def to_msg(self) -> dict[str, Any]:
}


class ShuffleSchedulerExtension(SchedulerPlugin):
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle extension for the scheduler
Shuffle plugin for the scheduler
Today this mostly just collects heartbeat messages for the dashboard,
but in the future it may be responsible for more
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerExtension
ShuffleWorkerPlugin
"""

scheduler: Scheduler
Expand All @@ -115,7 +117,13 @@ def __init__(self, scheduler: Scheduler):
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.states = {}
self.erred_shuffles = {}
self.scheduler.add_plugin(self)
self.scheduler.add_plugin(self, name="shuffle")

async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle"
)

def shuffle_ids(self) -> set[ShuffleId]:
return set(self.states)
Expand Down
16 changes: 8 additions & 8 deletions distributed/shuffle/_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dask.dataframe import DataFrame

# circular dependency
from distributed.shuffle._worker_extension import ShuffleWorkerExtension
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

ShuffleId = NewType("ShuffleId", str)

Expand All @@ -32,7 +32,7 @@ class ShuffleType(Enum):
ARRAY_RECHUNK = "ArrayRechunk"


def _get_worker_extension() -> ShuffleWorkerExtension:
def _get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker

try:
Expand All @@ -42,13 +42,13 @@ def _get_worker_extension() -> ShuffleWorkerExtension:
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
extension: ShuffleWorkerExtension | None = worker.extensions.get("shuffle")
if extension is None:
plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle") # type: ignore
if plugin is None:
raise RuntimeError(
f"The worker {worker.address} does not have a ShuffleExtension. "
"Is pandas installed on the worker?"
)
return extension
return plugin


def shuffle_transfer(
Expand All @@ -60,7 +60,7 @@ def shuffle_transfer(
parts_out: set[int],
) -> int:
try:
return _get_worker_extension().add_partition(
return _get_worker_plugin().add_partition(
input,
shuffle_id=id,
type=ShuffleType.DATAFRAME,
Expand All @@ -77,7 +77,7 @@ def shuffle_unpack(
id: ShuffleId, output_partition: int, barrier_run_id: int, meta: pd.DataFrame
) -> pd.DataFrame:
try:
return _get_worker_extension().get_output_partition(
return _get_worker_plugin().get_output_partition(
id, barrier_run_id, output_partition, meta=meta
)
except Reschedule as e:
Expand All @@ -88,7 +88,7 @@ def shuffle_unpack(

def shuffle_barrier(id: ShuffleId, run_ids: list[int]) -> int:
try:
return _get_worker_extension().barrier(id, run_ids)
return _get_worker_plugin().barrier(id, run_ids)
except Exception as e:
raise RuntimeError(f"shuffle_barrier failed during shuffle {id}") from e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dask.utils import parse_bytes

from distributed.core import PooledRPCCall
from distributed.diagnostics.plugin import WorkerPlugin
from distributed.exceptions import Reschedule
from distributed.protocol import to_serialize
from distributed.shuffle._arrow import (
Expand Down Expand Up @@ -552,7 +553,7 @@ def _get_assigned_worker(self, id: int) -> str:
return self.worker_for[id]


class ShuffleWorkerExtension:
class ShuffleWorkerPlugin(WorkerPlugin):
"""Interface between a Worker and a Shuffle.
This extension is responsible for
Expand All @@ -571,7 +572,7 @@ class ShuffleWorkerExtension:
memory_limiter_disk: ResourceLimiter
closed: bool

def __init__(self, worker: Worker) -> None:
def setup(self, worker: Worker) -> None:
# Attach to worker
worker.handlers["shuffle_receive"] = self.shuffle_receive
worker.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done
Expand All @@ -588,10 +589,10 @@ def __init__(self, worker: Worker) -> None:
self._executor = ThreadPoolExecutor(self.worker.state.nthreads)

def __str__(self) -> str:
return f"ShuffleWorkerExtension on {self.worker.address}"
return f"ShuffleWorkerPlugin on {self.worker.address}"

def __repr__(self) -> str:
return f"<ShuffleWorkerExtension, worker={self.worker.address_safe!r}, closed={self.closed}>"
return f"<ShuffleWorkerPlugin, worker={self.worker.address_safe!r}, closed={self.closed}>"

# Handlers
##########
Expand Down Expand Up @@ -638,7 +639,7 @@ def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None
exception = RuntimeError(message)
shuffle.fail(exception)

async def _(extension: ShuffleWorkerExtension, shuffle: ShuffleRun) -> None:
async def _(extension: ShuffleWorkerPlugin, shuffle: ShuffleRun) -> None:
await shuffle.close()
extension._runs.remove(shuffle)

Expand Down Expand Up @@ -805,7 +806,7 @@ async def _refresh_shuffle(
existing.fail(RuntimeError("Stale Shuffle"))

async def _(
extension: ShuffleWorkerExtension, shuffle: ShuffleRun
extension: ShuffleWorkerPlugin, shuffle: ShuffleRun
) -> None:
await shuffle.close()
extension._runs.remove(shuffle)
Expand Down Expand Up @@ -855,7 +856,7 @@ async def _(
self._runs.add(shuffle)
return shuffle

async def close(self) -> None:
async def teardown(self, worker: Worker) -> None:
assert not self.closed

self.closed = True
Expand Down
4 changes: 2 additions & 2 deletions distributed/shuffle/tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._rechunk import Split, split_axes
from distributed.shuffle._scheduler_extension import get_worker_for_hash_sharding
from distributed.shuffle._scheduler_plugin import get_worker_for_hash_sharding
from distributed.shuffle._shuffle import ShuffleId
from distributed.shuffle._worker_extension import ArrayRechunkRun
from distributed.shuffle._worker_plugin import ArrayRechunkRun
from distributed.shuffle.tests.utils import AbstractShuffleTestPool
from distributed.utils_test import gen_cluster, gen_test, raises_with_cause

Expand Down

0 comments on commit 8e3e0f6

Please sign in to comment.