Skip to content

Commit

Permalink
Refactor restart() and restart_workers()
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 1, 2024
1 parent 1602d74 commit c8c2f03
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 85 deletions.
39 changes: 16 additions & 23 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
)
from dask.widgets import get_template

from distributed.core import ErrorMessage, OKMessage
from distributed.core import OKMessage
from distributed.protocol.serialize import _is_dumpable
from distributed.utils import Deadline, wait_for

Expand Down Expand Up @@ -3641,44 +3641,36 @@ async def _restart_workers(
workers: list[str],
timeout: int | float | None = None,
raise_for_error: bool = True,
) -> dict[str, str | ErrorMessage]:
) -> dict[str, Literal["OK"] | Exception]:
info = self.scheduler_info()
name_to_addr = {meta["name"]: addr for addr, meta in info["workers"].items()}
worker_addrs = [name_to_addr.get(w, w) for w in workers]

restart_out: dict[str, str | ErrorMessage] = await self.scheduler.broadcast(
msg={"op": "restart", "timeout": timeout},
out: dict[
str, Literal["OK"] | Exception
] = await self.scheduler.restart_workers(
workers=worker_addrs,
nanny=True,
timeout=timeout,
on_error="raise" if raise_for_error else "return",
)

# Map keys back to original `workers` input names/addresses
results = {w: restart_out[w_addr] for w, w_addr in zip(workers, worker_addrs)}

timeout_workers = [w for w, status in results.items() if status == "timed out"]
if timeout_workers and raise_for_error:
raise TimeoutError(
f"The following workers failed to restart with {timeout} seconds: {timeout_workers}"
)

errored: list[ErrorMessage] = [m for m in results.values() if "exception" in m] # type: ignore
if errored and raise_for_error:
raise pickle.loads(errored[0]["exception"]) # type: ignore
return results
out = {w: out[w_addr] for w, w_addr in zip(workers, worker_addrs)}
if raise_for_error:
assert all(v == "OK" for v in out.values())
return out

Check warning on line 3660 in distributed/client.py

View check run for this annotation

Codecov / codecov/patch

distributed/client.py#L3658-L3660

Added lines #L3658 - L3660 were not covered by tests

def restart_workers(
self,
workers: list[str],
timeout: int | float | None = None,
raise_for_error: bool = True,
) -> dict[str, str]:
) -> dict[str, Literal["OK"] | Exception]:
"""Restart a specified set of workers
.. note::
Only workers being monitored by a :class:`distributed.Nanny` can be restarted.
See ``Nanny.restart`` for more details.
See ``Nanny.restart`` for more details.
Parameters
----------
Expand All @@ -3693,7 +3685,7 @@ def restart_workers(
Returns
-------
dict[str, str]
dict[str, str | Exception]
Mapping of worker and restart status, the keys will match the original
values passed in via ``workers``.
Expand Down Expand Up @@ -3727,7 +3719,8 @@ def restart_workers(
for worker, meta in info["workers"].items():
if (worker in workers or meta["name"] in workers) and meta["nanny"] is None:
raise ValueError(
f"Restarting workers requires a nanny to be used. Worker {worker} has type {info['workers'][worker]['type']}."
f"Restarting workers requires a nanny to be used. Worker "
f"{worker} has type {info['workers'][worker]['type']}."
)
return self.sync(
self._restart_workers,
Expand Down
1 change: 1 addition & 0 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,7 @@ async def kill(
assert self.status in (
Status.running,
Status.failed, # process failed to start, but hasn't been joined yet
Status.closing_gracefully,
), self.status
self.status = Status.stopping
logger.info("Nanny asking worker to close. Reason: %s", reason)
Expand Down
216 changes: 154 additions & 62 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6176,39 +6176,30 @@ async def gather(
return {"status": "error", "keys": list(failed_keys)}

@log_errors
async def restart(self, client=None, timeout=30, wait_for_workers=True):
"""
Restart all workers. Reset local state. Optionally wait for workers to return.
Workers without nannies are shut down, hoping an external deployment system
will restart them. Therefore, if not using nannies and your deployment system
does not automatically restart workers, ``restart`` will just shut down all
workers, then time out!
After ``restart``, all connected workers are new, regardless of whether ``TimeoutError``
was raised. Any workers that failed to shut down in time are removed, and
may or may not shut down on their own in the future.
async def restart(
self,
*,
client: str | None = None,
timeout: float = 30,
wait_for_workers: bool = True,
stimulus_id: str,
) -> None:
"""Forget all tasks and call restart_workers on all workers.
Parameters
----------
timeout:
How long to wait for workers to shut down and come back, if ``wait_for_workers``
is True, otherwise just how long to wait for workers to shut down.
Raises ``asyncio.TimeoutError`` if this is exceeded.
See restart_workers
wait_for_workers:
Whether to wait for all workers to reconnect, or just for them to shut down
(default True). Use ``restart(wait_for_workers=False)`` combined with
:meth:`Client.wait_for_workers` for granular control over how many workers to
wait for.
See restart_workers
See also
--------
Client.restart
Client.restart_workers
Scheduler.restart_workers
"""
stimulus_id = f"restart-{time()}"

logger.info("Restarting workers and releasing all keys.")
logger.info(f"Restarting workers and releasing all keys ({stimulus_id=})")

Check warning on line 6202 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6202

Added line #L6202 was not covered by tests
for cs in self.clients.values():
self.client_releases_keys(
keys=[ts.key for ts in cs.wants_what],
Expand All @@ -6226,19 +6217,91 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True):
except Exception as e:
logger.exception(e)

n_workers = len(self.workers)
await self.restart_workers(

Check warning on line 6220 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6220

Added line #L6220 was not covered by tests
client=client,
timeout=timeout,
wait_for_workers=wait_for_workers,
stimulus_id=stimulus_id,
)

@log_errors
async def restart_workers(
self,
workers: list[str] | None = None,
*,
client: str | None = None,
timeout: float = 30,
wait_for_workers: bool = True,
on_error: Literal["raise", "return"] = "raise",
stimulus_id: str,
) -> dict[str, Literal["OK", "no_nanny"] | Exception]:
"""Restart selected workers. Optionally wait for workers to return.
Workers without nannies are shut down, hoping an external deployment system
will restart them. Therefore, if not using nannies and your deployment system
does not automatically restart workers, ``restart`` will just shut down all
workers, then time out!
After ``restart``, all connected workers are new, regardless of whether
``TimeoutError`` was raised. Any workers that failed to shut down in time are
removed, and may or may not shut down on their own in the future.
Parameters
----------
workers:
List of worker addresses to restart. If omitted, restart all workers.
timeout:
How long to wait for workers to shut down and come back, if ``wait_for_workers``
is True, otherwise just how long to wait for workers to shut down.
Raises ``asyncio.TimeoutError`` if this is exceeded.
wait_for_workers:
Whether to wait for all workers to reconnect, or just for them to shut down
(default True). Use ``restart(wait_for_workers=False)`` combined with
:meth:`Client.wait_for_workers` for granular control over how many workers to
wait for.
on_error:
If 'raise' (the default), raise if any nanny times out while restarting the
worker. If 'return', return error messages.
Returns
-------
{worker address: "OK", "no nanny", or "timed out" or error message}
See also
--------
Client.restart
Client.restart_workers
Scheduler.restart
"""
n_workers_before_restart = len(self.workers)
if workers is None:
workers = list(self.workers)
logger.info(f"Restarting all workers ({stimulus_id=}")

Check warning on line 6279 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6276-L6279

Added lines #L6276 - L6279 were not covered by tests
else:
workers = list(set(workers).intersection(self.workers))
logger.info(f"Restarting {len(workers)} workers: {workers} ({stimulus_id=}")

Check warning on line 6282 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6281-L6282

Added lines #L6281 - L6282 were not covered by tests

nanny_workers = {
addr: ws.nanny for addr, ws in self.workers.items() if ws.nanny
addr: self.workers[addr].nanny
for addr in workers
if self.workers[addr].nanny
}
# Close non-Nanny workers. We have no way to restart them, so we just let them go,
# and assume a deployment system is going to restart them for us.
await asyncio.gather(
*(
self.remove_worker(address=addr, stimulus_id=stimulus_id)
for addr in self.workers
if addr not in nanny_workers
# Close non-Nanny workers. We have no way to restart them, so we just let them
# go, and assume a deployment system is going to restart them for us.
no_nanny_workers = [addr for addr in workers if addr not in nanny_workers]
if no_nanny_workers:
logger.warning(

Check warning on line 6293 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6291-L6293

Added lines #L6291 - L6293 were not covered by tests
f"Workers {no_nanny_workers} do not use a nanny and will be terminated "
"without restarting them"
)
)
await asyncio.gather(

Check warning on line 6297 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6297

Added line #L6297 was not covered by tests
*(
self.remove_worker(address=addr, stimulus_id=stimulus_id)
for addr in no_nanny_workers
)
)
out: dict[str, Literal["OK", "no_nanny"] | Exception]
out = {addr: "no_nanny" for addr in no_nanny_workers}

Check warning on line 6304 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6304

Added line #L6304 was not covered by tests

logger.debug("Send kill signal to nannies: %s", nanny_workers)
async with contextlib.AsyncExitStack() as stack:
Expand All @@ -6258,10 +6321,7 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True):
# FIXME does not raise if the process fails to shut down,
# see https://github.com/dask/distributed/pull/6427/files#r894917424
# NOTE: Nanny will automatically restart worker process when it's killed
nanny.kill(
reason="scheduler-restart",
timeout=timeout,
),
nanny.kill(reason=stimulus_id, timeout=timeout),
timeout,
)
for nanny in nannies
Expand All @@ -6273,46 +6333,78 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True):

# Remove any workers that failed to shut down, so we can guarantee
# that after `restart`, there are no old workers around.
bad_nannies = [
addr for addr, resp in zip(nanny_workers, resps) if resp is not None
]
bad_nannies = set()
for addr, resp in zip(nanny_workers, resps):
if resp is None:
out[addr] = "OK"

Check warning on line 6339 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6336-L6339

Added lines #L6336 - L6339 were not covered by tests
else:
assert isinstance(resp, Exception)
bad_nannies.add(addr)
out[addr] = resp

Check warning on line 6343 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6341-L6343

Added lines #L6341 - L6343 were not covered by tests

if bad_nannies:
logger.error(

Check warning on line 6346 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6346

Added line #L6346 was not covered by tests
f"Workers {list(bad_nannies)} did not shut down within {timeout}s; "
"force closing"
)
await asyncio.gather(
*(
self.remove_worker(addr, stimulus_id=stimulus_id)
for addr in bad_nannies
)
)
if on_error == "raise":
raise TimeoutError(

Check warning on line 6357 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6356-L6357

Added lines #L6356 - L6357 were not covered by tests
f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not "
f"shut down within {timeout}s"
)

raise TimeoutError(
f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not shut down within {timeout}s"
)

self.log_event([client, "all"], {"action": "restart", "client": client})
if client:
self.log_event(client, {"action": "restart-workers", "workers": workers})
self.log_event(

Check warning on line 6364 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6362-L6364

Added lines #L6362 - L6364 were not covered by tests
"all", {"action": "restart-workers", "workers": workers, "client": client}
)

if wait_for_workers:
while len(self.workers) < n_workers:
# NOTE: if new (unrelated) workers join while we're waiting, we may return before
# our shut-down workers have come back up. That's fine; workers are interchangeable.
expect_workers = (

Check warning on line 6369 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6369

Added line #L6369 was not covered by tests
n_workers_before_restart - len(no_nanny_workers) - len(bad_nannies)
)
while len(self.workers) < expect_workers:

Check warning on line 6372 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6372

Added line #L6372 was not covered by tests
# NOTE: if new (unrelated) workers join while we're waiting, we may
# return before our shut-down workers have come back up. That's fine;
# workers are interchangeable.
if monotonic() < start + timeout:
await asyncio.sleep(0.2)
else:
msg = (
f"Waited for {n_workers} worker(s) to reconnect after restarting, "
f"but after {timeout}s, only {len(self.workers)} have returned. "
"Consider a longer timeout, or `wait_for_workers=False`."
continue

Check warning on line 6378 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6378

Added line #L6378 was not covered by tests

msg = (

Check warning on line 6380 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6380

Added line #L6380 was not covered by tests
f"Waited for {len(nanny_workers) - len(bad_nannies)} worker(s) to reconnect after "
f"restarting but, after {timeout}s, "
f"{expect_workers - len(self.workers)} have not returned. "
"Consider a longer timeout, or `wait_for_workers=False`."
)

if no_nanny_workers:
msg += (

Check warning on line 6388 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6387-L6388

Added lines #L6387 - L6388 were not covered by tests
f" The {len(no_nanny_workers)} worker(s) not using Nannies were just shut "
"down instead of restarted (restart is only possible with Nannies). If "
"your deployment system does not automatically re-launch terminated "
"processes, then those workers will never come back, and `Client.restart` "
"will always time out. Do not use `Client.restart` in that case."
)

if (n_nanny := len(nanny_workers)) < n_workers:
msg += (
f" The {n_workers - n_nanny} worker(s) not using Nannies were just shut "
"down instead of restarted (restart is only possible with Nannies). If "
"your deployment system does not automatically re-launch terminated "
"processes, then those workers will never come back, and `Client.restart` "
"will always time out. Do not use `Client.restart` in that case."
)
raise TimeoutError(msg) from None
logger.info("Restarting finished.")
exc = TimeoutError(msg)

Check warning on line 6396 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6396

Added line #L6396 was not covered by tests

new_nannies = {ws.nanny for ws in self.workers.values() if ws.nanny}
for worker_addr, nanny_addr in nanny_workers.items():
if nanny_addr not in new_nannies:
out[worker_addr] = exc

Check warning on line 6401 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6398-L6401

Added lines #L6398 - L6401 were not covered by tests

if on_error == "raise":
raise exc from None

Check warning on line 6404 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6403-L6404

Added lines #L6403 - L6404 were not covered by tests

logger.info(f"Workers restart finished ({stimulus_id=}")
return out

Check warning on line 6407 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L6406-L6407

Added lines #L6406 - L6407 were not covered by tests

async def broadcast(
self,
Expand Down

0 comments on commit c8c2f03

Please sign in to comment.