Fix P2P worker cleanup #7981

Merged · 8 commits · Jul 13, 2023

Changes from all commits
19 changes: 12 additions & 7 deletions distributed/diagnostics/tests/test_worker_plugin.py
@@ -53,6 +53,8 @@ async def test_create_with_client(c, s):
 
 @gen_cluster(client=True, nthreads=[])
 async def test_remove_with_client(c, s):
+    existing_plugins = s.worker_plugins.copy()
+    n_existing_plugins = len(existing_plugins)
     await c.register_worker_plugin(MyPlugin(123), name="foo")
     await c.register_worker_plugin(MyPlugin(546), name="bar")
 
@@ -62,18 +64,18 @@ async def test_remove_with_client(c, s):
     assert worker._my_plugin_status == "teardown"
 
     # check that on the scheduler registered worker plugins we only have 'bar'
-    assert len(s.worker_plugins) == 1
+    assert len(s.worker_plugins) == n_existing_plugins + 1
     assert "bar" in s.worker_plugins
 
     # check on the worker plugins that we only have 'bar'
-    assert len(worker.plugins) == 1
+    assert len(worker.plugins) == n_existing_plugins + 1
     assert "bar" in worker.plugins
 
     # let's remove 'bar' and we should have none worker plugins
     await c.unregister_worker_plugin("bar")
     assert worker._my_plugin_status == "teardown"
-    assert not s.worker_plugins
-    assert not worker.plugins
+    assert s.worker_plugins == existing_plugins
+    assert len(worker.plugins) == n_existing_plugins
 
 
 @gen_cluster(client=True, nthreads=[])
Expand All @@ -87,7 +89,9 @@ async def test_remove_with_client_raises(c, s):

@gen_cluster(client=True, worker_kwargs={"plugins": [MyPlugin(5)]})
async def test_create_on_construction(c, s, a, b):
assert len(a.plugins) == len(b.plugins) == 1
assert len(a.plugins) == len(b.plugins)
assert any(isinstance(plugin, MyPlugin) for plugin in a.plugins.values())
assert any(isinstance(plugin, MyPlugin) for plugin in b.plugins.values())
assert a._my_plugin_status == "setup"
assert a._my_plugin_data == 5

@@ -195,9 +199,10 @@ async def test_default_name(c, s, w):
     class MyCustomPlugin(WorkerPlugin):
         pass
 
+    n_existing_plugins = len(w.plugins)
     await c.register_worker_plugin(MyCustomPlugin())
-    assert len(w.plugins) == 1
-    assert next(iter(w.plugins)).startswith("MyCustomPlugin-")
+    assert len(w.plugins) == n_existing_plugins + 1
+    assert any(name.startswith("MyCustomPlugin-") for name in w.plugins)
 
 
 @gen_cluster(client=True, nthreads=[("", 1)])
4 changes: 2 additions & 2 deletions distributed/scheduler.py
@@ -98,7 +98,7 @@
 from distributed.recreate_tasks import ReplayTaskScheduler
 from distributed.security import Security
 from distributed.semaphore import SemaphoreExtension
-from distributed.shuffle import ShuffleSchedulerExtension
+from distributed.shuffle import ShuffleSchedulerPlugin
 from distributed.spans import SpansSchedulerExtension
 from distributed.stealing import WorkStealing
 from distributed.utils import (
@@ -170,7 +170,7 @@
     "events": EventExtension,
     "amm": ActiveMemoryManagerExtension,
     "memory_sampler": MemorySamplerExtension,
-    "shuffle": ShuffleSchedulerExtension,
+    "shuffle": ShuffleSchedulerPlugin,
     "spans": SpansSchedulerExtension,
     "stealing": WorkStealing,
 }
8 changes: 4 additions & 4 deletions distributed/shuffle/__init__.py
@@ -3,9 +3,9 @@
 from distributed.shuffle._arrow import check_minimal_arrow_version
 from distributed.shuffle._merge import HashJoinP2PLayer, hash_join_p2p
 from distributed.shuffle._rechunk import rechunk_p2p
-from distributed.shuffle._scheduler_extension import ShuffleSchedulerExtension
+from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
 from distributed.shuffle._shuffle import P2PShuffleLayer, rearrange_by_column_p2p
-from distributed.shuffle._worker_extension import ShuffleWorkerExtension
+from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 
 __all__ = [
     "check_minimal_arrow_version",
@@ -14,6 +14,6 @@
     "P2PShuffleLayer",
     "rearrange_by_column_p2p",
     "rechunk_p2p",
-    "ShuffleSchedulerExtension",
-    "ShuffleWorkerExtension",
+    "ShuffleSchedulerPlugin",
+    "ShuffleWorkerPlugin",
 ]
4 changes: 2 additions & 2 deletions distributed/shuffle/_merge.py
@@ -11,7 +11,7 @@
 
 from distributed.shuffle._shuffle import (
     ShuffleId,
-    _get_worker_extension,
+    _get_worker_plugin,
     barrier_key,
     shuffle_barrier,
     shuffle_transfer,
@@ -167,7 +167,7 @@ def merge_unpack(
 ):
     from dask.dataframe.multi import merge_chunk
 
-    ext = _get_worker_extension()
+    ext = _get_worker_plugin()
     # If the partition is empty, it doesn't contain the hash column name
     left = ext.get_output_partition(
         shuffle_id_left, barrier_left, output_partition, meta=meta_left
6 changes: 3 additions & 3 deletions distributed/shuffle/_rechunk.py
@@ -10,7 +10,7 @@
 from distributed.shuffle._shuffle import (
     ShuffleId,
     ShuffleType,
-    _get_worker_extension,
+    _get_worker_plugin,
     barrier_key,
     shuffle_barrier,
 )
@@ -36,7 +36,7 @@ def rechunk_transfer(
     old: ChunkedAxes,
 ) -> int:
     try:
-        return _get_worker_extension().add_partition(
+        return _get_worker_plugin().add_partition(
             input,
             partition_id=input_chunk,
             shuffle_id=id,
@@ -52,7 +52,7 @@ def rechunk_unpack(
     id: ShuffleId, output_chunk: NDIndex, barrier_run_id: int
 ) -> np.ndarray:
     try:
-        return _get_worker_extension().get_output_partition(
+        return _get_worker_plugin().get_output_partition(
             id, barrier_run_id, output_chunk
         )
     except Reschedule as e:
distributed/shuffle/_scheduler_extension.py → distributed/shuffle/_scheduler_plugin.py
@@ -12,13 +12,15 @@
 from typing import TYPE_CHECKING, Any, ClassVar
 
 from distributed.diagnostics.plugin import SchedulerPlugin
+from distributed.protocol.pickle import dumps
 from distributed.shuffle._rechunk import ChunkedAxes, NDIndex
 from distributed.shuffle._shuffle import (
     ShuffleId,
     ShuffleType,
     barrier_key,
     id_from_key,
 )
+from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 
 if TYPE_CHECKING:
     from distributed.scheduler import (
@@ -85,16 +87,16 @@ def to_msg(self) -> dict[str, Any]:
         }
 
 
-class ShuffleSchedulerExtension(SchedulerPlugin):
+class ShuffleSchedulerPlugin(SchedulerPlugin):
     """
-    Shuffle extension for the scheduler
+    Shuffle plugin for the scheduler
 
-    Today this mostly just collects heartbeat messages for the dashboard,
-    but in the future it may be responsible for more
+    This coordinates the individual worker plugins to ensure correctness
+    and collects heartbeat messages for the dashboard.
 
     See Also
     --------
-    ShuffleWorkerExtension
+    ShuffleWorkerPlugin
     """
 
     scheduler: Scheduler
@@ -115,7 +117,13 @@ def __init__(self, scheduler: Scheduler):
         self.heartbeats = defaultdict(lambda: defaultdict(dict))
         self.states = {}
         self.erred_shuffles = {}
-        self.scheduler.add_plugin(self)
+        self.scheduler.add_plugin(self, name="shuffle")
+
+    async def start(self, scheduler: Scheduler) -> None:
+        worker_plugin = ShuffleWorkerPlugin()
+        await self.scheduler.register_worker_plugin(
+            None, dumps(worker_plugin), name="shuffle"
+        )
Comment on lines +124 to +126

Contributor: So this is the magic that ensures things are cleaned up, because the client will teardown the scheduler which tears down all the worker plugins?

Contributor: And pretty much everything else is downstream renaming changes....

Member (Author): Pretty much, yes. IIUC, extensions don't get cleaned up at all. Worker plugins will get torn down on worker close, but for that the method has to be called teardown (as seen below). Similarly, we need to move some of the initialization to setup.

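To make the lifecycle discussed above concrete, here is a minimal sketch of the pattern this PR adopts. It uses only hooks that appear in this diff (SchedulerPlugin.start, WorkerPlugin.setup/teardown, Scheduler.register_worker_plugin); the MyWorkerPlugin/MySchedulerPlugin names and the _my_plugin_status attribute are illustrative, borrowed from the test plugin above, not part of the PR:

from distributed.diagnostics.plugin import SchedulerPlugin, WorkerPlugin
from distributed.protocol.pickle import dumps


class MyWorkerPlugin(WorkerPlugin):
    def setup(self, worker):
        # Runs when the plugin is registered on each current and future worker.
        worker._my_plugin_status = "setup"

    def teardown(self, worker):
        # Runs on Worker.close(); the hook must be named "teardown" to be called.
        worker._my_plugin_status = "teardown"


class MySchedulerPlugin(SchedulerPlugin):
    async def start(self, scheduler):
        # Registering the worker plugin through the scheduler ties its lifetime
        # to the cluster: tearing down the scheduler also tears down the
        # worker plugins, which is what fixes the P2P cleanup.
        await scheduler.register_worker_plugin(
            None, dumps(MyWorkerPlugin()), name="my-plugin"
        )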

     def shuffle_ids(self) -> set[ShuffleId]:
         return set(self.states)
16 changes: 8 additions & 8 deletions distributed/shuffle/_shuffle.py
@@ -22,7 +22,7 @@
     from dask.dataframe import DataFrame
 
     # circular dependency
-    from distributed.shuffle._worker_extension import ShuffleWorkerExtension
+    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 
 ShuffleId = NewType("ShuffleId", str)
@@ -32,7 +32,7 @@ class ShuffleType(Enum):
     ARRAY_RECHUNK = "ArrayRechunk"
 
 
-def _get_worker_extension() -> ShuffleWorkerExtension:
+def _get_worker_plugin() -> ShuffleWorkerPlugin:
     from distributed import get_worker
 
     try:
@@ -42,13 +42,13 @@ def _get_worker_extension() -> ShuffleWorkerExtension:
             "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
             "please confirm that you've created a distributed Client and are submitting this computation through it."
         ) from e
-    extension: ShuffleWorkerExtension | None = worker.extensions.get("shuffle")
-    if extension is None:
+    plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle")  # type: ignore
+    if plugin is None:
         raise RuntimeError(
             f"The worker {worker.address} does not have a ShuffleExtension. "
             "Is pandas installed on the worker?"
         )
-    return extension
+    return plugin


def shuffle_transfer(
Expand All @@ -60,7 +60,7 @@ def shuffle_transfer(
parts_out: set[int],
) -> int:
try:
return _get_worker_extension().add_partition(
return _get_worker_plugin().add_partition(
input,
shuffle_id=id,
type=ShuffleType.DATAFRAME,
@@ -77,7 +77,7 @@ def shuffle_unpack(
     id: ShuffleId, output_partition: int, barrier_run_id: int, meta: pd.DataFrame
 ) -> pd.DataFrame:
     try:
-        return _get_worker_extension().get_output_partition(
+        return _get_worker_plugin().get_output_partition(
             id, barrier_run_id, output_partition, meta=meta
         )
     except Reschedule as e:
@@ -88,7 +88,7 @@
 
 def shuffle_barrier(id: ShuffleId, run_ids: list[int]) -> int:
     try:
-        return _get_worker_extension().barrier(id, run_ids)
+        return _get_worker_plugin().barrier(id, run_ids)
     except Exception as e:
         raise RuntimeError(f"shuffle_barrier failed during shuffle {id}") from e

distributed/shuffle/_worker_extension.py → distributed/shuffle/_worker_plugin.py
@@ -19,6 +19,7 @@
 from dask.utils import parse_bytes
 
 from distributed.core import PooledRPCCall
+from distributed.diagnostics.plugin import WorkerPlugin
 from distributed.exceptions import Reschedule
 from distributed.protocol import to_serialize
 from distributed.shuffle._arrow import (
@@ -552,7 +553,7 @@
         return self.worker_for[id]
 
 
-class ShuffleWorkerExtension:
+class ShuffleWorkerPlugin(WorkerPlugin):
     """Interface between a Worker and a Shuffle.
 
     This extension is responsible for
@@ -571,7 +572,7 @@
     memory_limiter_disk: ResourceLimiter
     closed: bool
 
-    def __init__(self, worker: Worker) -> None:
+    def setup(self, worker: Worker) -> None:
         # Attach to worker
         worker.handlers["shuffle_receive"] = self.shuffle_receive
         worker.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done
@@ -588,10 +589,10 @@
         self._executor = ThreadPoolExecutor(self.worker.state.nthreads)
 
     def __str__(self) -> str:
-        return f"ShuffleWorkerExtension on {self.worker.address}"
+        return f"ShuffleWorkerPlugin on {self.worker.address}"
 
     def __repr__(self) -> str:
-        return f"<ShuffleWorkerExtension, worker={self.worker.address_safe!r}, closed={self.closed}>"
+        return f"<ShuffleWorkerPlugin, worker={self.worker.address_safe!r}, closed={self.closed}>"

Codecov (codecov/patch) annotation: added line distributed/shuffle/_worker_plugin.py#L595 was not covered by tests.

     # Handlers
     ##########
@@ -638,7 +639,7 @@
             exception = RuntimeError(message)
             shuffle.fail(exception)
 
-        async def _(extension: ShuffleWorkerExtension, shuffle: ShuffleRun) -> None:
+        async def _(extension: ShuffleWorkerPlugin, shuffle: ShuffleRun) -> None:
             await shuffle.close()
             extension._runs.remove(shuffle)
@@ -805,7 +806,7 @@
         existing.fail(RuntimeError("Stale Shuffle"))
 
         async def _(
-            extension: ShuffleWorkerExtension, shuffle: ShuffleRun
+            extension: ShuffleWorkerPlugin, shuffle: ShuffleRun
         ) -> None:
             await shuffle.close()
             extension._runs.remove(shuffle)
@@ -855,7 +856,7 @@
         self._runs.add(shuffle)
         return shuffle
 
-    async def close(self) -> None:
+    async def teardown(self, worker: Worker) -> None:
Contributor: Along with this so that the worker plugin now conforms to the WorkerPlugin interface.

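For context, the close-time machinery being referenced works roughly like the sketch below. This is a simplified assumption about the worker shutdown path, not the actual Worker.close implementation:

import inspect


async def _teardown_plugins(worker):
    # Simplified sketch (assumption: the real shutdown path does more
    # bookkeeping): on close, the worker calls teardown() on every registered
    # plugin that defines it, awaiting coroutine results. A method named
    # close() would never be invoked by this machinery, hence the rename.
    for name, plugin in list(worker.plugins.items()):
        teardown = getattr(plugin, "teardown", None)
        if teardown is not None:
            result = teardown(worker)
            if inspect.isawaitable(result):
                await result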
         assert not self.closed
 
         self.closed = True
4 changes: 2 additions & 2 deletions distributed/shuffle/tests/test_rechunk.py
@@ -19,9 +19,9 @@
 
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._rechunk import Split, split_axes
-from distributed.shuffle._scheduler_extension import get_worker_for_hash_sharding
+from distributed.shuffle._scheduler_plugin import get_worker_for_hash_sharding
 from distributed.shuffle._shuffle import ShuffleId
-from distributed.shuffle._worker_extension import ArrayRechunkRun
+from distributed.shuffle._worker_plugin import ArrayRechunkRun
 from distributed.shuffle.tests.utils import AbstractShuffleTestPool
 from distributed.utils_test import gen_cluster, gen_test, raises_with_cause
Expand Down