Skip to content

Commit

Permalink
Ensure restart clears taskgroups et al
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Aug 24, 2022
1 parent b7e184a commit 840c74c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
31 changes: 18 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,21 @@ def new_task(

return ts

def _clear_task_state(self):
    """Empty the scheduler's task-related bookkeeping collections.

    Emits a debug log record and then calls ``.clear()`` on each of
    ``unrunnable``, ``erred_tasks``, ``computations``, ``task_prefixes``,
    ``task_groups``, ``task_metadata``, ``unknown_durations`` and
    ``replicated_tasks``.  The containers are emptied in place rather than
    rebound, so any other reference to the same objects sees them empty.
    """
    logger.debug("Clear task state")

    # Resolve every attribute up front so that a missing attribute raises
    # before anything has been cleared.
    task_collections = (
        self.unrunnable,
        self.erred_tasks,
        self.computations,
        self.task_prefixes,
        self.task_groups,
        self.task_metadata,
        self.unknown_durations,
        self.replicated_tasks,
    )
    for container in task_collections:
        container.clear()

#####################
# State Transitions #
#####################
Expand Down Expand Up @@ -3063,8 +3078,6 @@ def __init__(
resources = {}
aliases = {}

self._task_state_collections = [unrunnable]

self._worker_collections = [
workers,
host_info,
Expand Down Expand Up @@ -3365,7 +3378,7 @@ async def start_unsafe(self):

enable_gc_diagnosis()

self.clear_task_state()
self._clear_task_state()

for addr in self._start_address:
await self.listen(
Expand Down Expand Up @@ -5143,13 +5156,6 @@ async def gather(self, keys, serializers=None):
self.log_event("all", {"action": "gather", "count": len(keys)})
return result

def clear_task_state(self):
    """Clear every collection registered in ``self._task_state_collections``.

    Logs "Clear task state" at INFO level, then empties each registered
    collection in place via its ``.clear()`` method.
    """
    # XXX what about nested state such as ClientState.wants_what
    # (see also fire-and-forget...)
    logger.info("Clear task state")
    registered = self._task_state_collections
    for state in registered:
        state.clear()

@log_errors
async def restart(self, client=None, timeout=30, wait_for_workers=True):
"""
Expand Down Expand Up @@ -5189,9 +5195,8 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True):
stimulus_id=stimulus_id,
)

self.clear_task_state()
self.erred_tasks.clear()
self.computations.clear()
self._clear_task_state()
assert not self.tasks
self.report({"op": "restart"})

for plugin in list(self.plugins.values()):
Expand Down
8 changes: 8 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,11 +615,19 @@ async def test_ready_remove_worker(s, a, b):

@gen_cluster(client=True, Worker=Nanny, timeout=60)
async def test_restart(c, s, a, b):
from distributed.scheduler import TaskState

before = TaskState._instances
futures = c.map(inc, range(20))
await wait(futures)

await s.restart()

assert TaskState._instances == before
assert not s.computations
assert not s.task_prefixes
assert not s.task_groups

assert len(s.workers) == 2

for ws in s.workers.values():
Expand Down

0 comments on commit 840c74c

Please sign in to comment.