Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure workers do not kill on restart #8611

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion distributed/cli/dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,11 @@

async def wait_for_nannies_to_finish():
"""Wait for all nannies to initialize and finish"""
await asyncio.gather(*nannies)
try:
await asyncio.gather(*nannies)
except Exception:
if not signal_fired:
raise

Check warning on line 424 in distributed/cli/dask_worker.py

View check run for this annotation

Codecov / codecov/patch

distributed/cli/dask_worker.py#L420-L424

Added lines #L420 - L424 were not covered by tests
await asyncio.gather(*(n.finished() for n in nannies))

async def wait_for_signals_and_close():
Expand Down
4 changes: 3 additions & 1 deletion distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,9 @@ async def _handle_stream(self, stream, address):
try:
await self.on_connection(comm)
except CommClosedError:
logger.info("Connection from %s closed before handshake completed", address)
logger.debug(
"Connection from %s closed before handshake completed", address
)
return

await self.comm_handler(comm)
Expand Down
49 changes: 43 additions & 6 deletions distributed/deploy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Awaitable, Generator
from contextlib import suppress
from inspect import isawaitable
from time import time
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from tornado import gen
Expand Down Expand Up @@ -389,28 +390,64 @@
# proper teardown.
await asyncio.gather(*worker_futs)

def _update_worker_status(self, op, msg):
def _update_worker_status(self, op, worker_addr):
if op == "remove":
name = self.scheduler_info["workers"][msg]["name"]
worker_info = self.scheduler_info["workers"][worker_addr].copy()
name = worker_info["name"]

from distributed import Nanny, Worker

def f():
# FIXME: SpecCluster is tracking workers by `name` which are
# not necessarily unique.
# Clusters with Nannies (default) are susceptible to falsely
# removing the Nannies on restart due to this logic since the
# restart emits a op==remove signal on the worker address but
# the SpecCluster only tracks the names, i.e. after
# `lost-worker-timeout` the Nanny is still around and this logic
# could trigger a false close. The below code should handle this
# but it would be cleaner if the cluster tracked by address
# instead of name just like the scheduler does
if (
name in self.workers
and msg not in self.scheduler_info["workers"]
and worker_addr not in self.scheduler_info["workers"]
and not any(
d["name"] == name
for d in self.scheduler_info["workers"].values()
)
):
self._futures.add(asyncio.ensure_future(self.workers[name].close()))
del self.workers[name]
w = self.workers[name]

Check warning on line 419 in distributed/deploy/spec.py

View check run for this annotation

Codecov / codecov/patch

distributed/deploy/spec.py#L419

Added line #L419 was not covered by tests

async def remove_worker():
await w.close(reason=f"lost-worker-timeout-{time()}")
self.workers.pop(name, None)

Check warning on line 423 in distributed/deploy/spec.py

View check run for this annotation

Codecov / codecov/patch

distributed/deploy/spec.py#L421-L423

Added lines #L421 - L423 were not covered by tests

if (

Check warning on line 425 in distributed/deploy/spec.py

View check run for this annotation

Codecov / codecov/patch

distributed/deploy/spec.py#L425

Added line #L425 was not covered by tests
worker_info["type"] == "Worker"
and (isinstance(w, Nanny) and w.worker_address == worker_addr)
or (isinstance(w, Worker) and w.address == worker_addr)
):
self._futures.add(

Check warning on line 430 in distributed/deploy/spec.py

View check run for this annotation

Codecov / codecov/patch

distributed/deploy/spec.py#L430

Added line #L430 was not covered by tests
asyncio.create_task(
remove_worker(),
name="remove-worker-lost-worker-timeout",
)
)
elif worker_info["type"] == "Nanny":

Check warning on line 436 in distributed/deploy/spec.py

View check run for this annotation

Codecov / codecov/patch

distributed/deploy/spec.py#L436

Added line #L436 was not covered by tests
# This should never happen
logger.critical(

Check warning on line 438 in distributed/deploy/spec.py

View check run for this annotation

Codecov / codecov/patch

distributed/deploy/spec.py#L438

Added line #L438 was not covered by tests
"Unexpected signal encountered. WorkerStatusPlugin "
"emitted an op==remove signal for a Nanny which "
"should not happen. This might cause a lingering "
"Nanny process."
)

delay = parse_timedelta(
dask.config.get("distributed.deploy.lost-worker-timeout")
)

asyncio.get_running_loop().call_later(delay, f)
super()._update_worker_status(op, msg)
super()._update_worker_status(op, worker_addr)

def __await__(self: Self) -> Generator[Any, Any, Self]:
async def _() -> Self:
Expand Down
15 changes: 15 additions & 0 deletions distributed/deploy/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
from tornado.httpclient import AsyncHTTPClient

import dask
from dask.system import CPU_COUNT

from distributed import Client, LocalCluster, Nanny, Worker, get_client
Expand Down Expand Up @@ -1285,3 +1286,17 @@ def test_localcluster_get_client(loop):
with Client(cluster) as client2:
assert client1 != client2
assert client2 == cluster.get_client()


@pytest.mark.slow()
def test_localcluster_restart(loop):
    """Restarting the cluster repeatedly must neither lose nor leak workers.

    Uses a short ``lost-worker-timeout`` so any spurious worker teardown
    triggered by the restart signal would surface within the test window.
    """
    timeout_config = {"distributed.deploy.lost-worker-timeout": "0.5s"}
    with (
        dask.config.set(timeout_config),
        LocalCluster(asynchronous=False, dashboard_address=":0", loop=loop) as cluster,
        cluster.get_client() as client,
    ):

        def count_workers():
            # client.run returns one entry per connected worker
            return len(client.run(lambda: None))

        expected = count_workers()
        for _ in range(10):
            assert count_workers() == expected
            client.restart()
        assert count_workers() == expected
Loading
Loading