From 840c74caad709103b0f5fdf8e7917126e4767b56 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 24 Aug 2022 11:38:25 +0200 Subject: [PATCH] Ensure restart clears taskgroups et al --- distributed/scheduler.py | 31 +++++++++++++++++------------ distributed/tests/test_scheduler.py | 8 ++++++++ 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e38447c70b..1823f2094e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1431,6 +1431,21 @@ def new_task( return ts + def _clear_task_state(self): + + logger.debug("Clear task state") + for collection in [ + self.unrunnable, + self.erred_tasks, + self.computations, + self.task_prefixes, + self.task_groups, + self.task_metadata, + self.unknown_durations, + self.replicated_tasks, + ]: + collection.clear() + ##################### # State Transitions # ##################### @@ -3063,8 +3078,6 @@ def __init__( resources = {} aliases = {} - self._task_state_collections = [unrunnable] - self._worker_collections = [ workers, host_info, @@ -3365,7 +3378,7 @@ async def start_unsafe(self): enable_gc_diagnosis() - self.clear_task_state() + self._clear_task_state() for addr in self._start_address: await self.listen( @@ -5143,13 +5156,6 @@ async def gather(self, keys, serializers=None): self.log_event("all", {"action": "gather", "count": len(keys)}) return result - def clear_task_state(self): - # XXX what about nested state such as ClientState.wants_what - # (see also fire-and-forget...) - logger.info("Clear task state") - for collection in self._task_state_collections: - collection.clear() - @log_errors async def restart(self, client=None, timeout=30, wait_for_workers=True): """ @@ -5189,9 +5195,8 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True): stimulus_id=stimulus_id, ) - self.clear_task_state() - self.erred_tasks.clear() - self.computations.clear() + self._clear_task_state() + assert not self.tasks self.report({"op": "restart"}) for plugin in list(self.plugins.values()): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 44237027b1..58b2fc7bd2 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -615,11 +615,19 @@ async def test_ready_remove_worker(s, a, b): @gen_cluster(client=True, Worker=Nanny, timeout=60) async def test_restart(c, s, a, b): + from distributed.scheduler import TaskState + + before = TaskState._instances futures = c.map(inc, range(20)) await wait(futures) await s.restart() + assert TaskState._instances == before + assert not s.computations + assert not s.task_prefixes + assert not s.task_groups + assert len(s.workers) == 2 for ws in s.workers.values():