Skip to content

Commit

Permalink
Simplify usage of Queues in nanny
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Aug 31, 2022
1 parent f07f384 commit 7ba727b
Showing 1 changed file with 12 additions and 38 deletions.
50 changes: 12 additions & 38 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,11 +590,8 @@ async def close(self, timeout=5):
await asyncio.gather(*(td for td in teardowns if isawaitable(td)))

self.stop()
try:
if self.process is not None:
await self.kill(timeout=timeout)
except Exception:
logger.exception("Error in Nanny killing Worker subprocess")
if self.process is not None:
await self.kill(timeout=timeout)
self.process = None
await self.rpc.close()
self.status = Status.closed
Expand Down Expand Up @@ -655,9 +652,9 @@ async def start(self) -> Status:
if self.status == Status.starting:
await self.running.wait()
return self.status

self.init_result_q = init_q = get_mp_context().Queue()
self.child_stop_q = get_mp_context().Queue()
mp_ctx = get_mp_context()
self.init_result_q = init_q = mp_ctx.Queue()
self.child_stop_q = mp_ctx.Queue()
uid = uuid.uuid4().hex

self.process = AsyncProcess(
Expand Down Expand Up @@ -698,6 +695,8 @@ async def start(self) -> Status:
await self.process.terminate()
self.status = Status.failed
raise
finally:
init_q.close()
if not msg:
return self.status
self.worker_address = msg["address"]
Expand All @@ -706,8 +705,6 @@ async def start(self) -> Status:
self.status = Status.running
self.running.set()

init_q.close()

return self.status

def _on_exit(self, proc):
Expand Down Expand Up @@ -771,11 +768,7 @@ async def kill(self, timeout: float = 2, executor_wait: bool = True) -> None:
if self.status == Status.stopping:
await self.stopped.wait()
return
assert self.status in (
Status.starting,
Status.running,
Status.failed, # process failed to start, but hasn't been joined yet
), self.status
assert self.status in (Status.starting, Status.running)
self.status = Status.stopping
logger.info("Nanny asking worker to close")

Expand All @@ -791,9 +784,6 @@ async def kill(self, timeout: float = 2, executor_wait: bool = True) -> None:
"executor_wait": executor_wait,
}
)
await asyncio.sleep(0) # otherwise we get broken pipe errors
queue.close()
del queue

try:
try:
Expand Down Expand Up @@ -825,7 +815,7 @@ async def _wait_until_connected(self, uid):
continue

if msg["uid"] != uid: # ensure that we didn't cross queues
continue
raise RuntimeError("Encountered message from a different queue.")

if "exception" in msg:
raise msg["exception"]
Expand Down Expand Up @@ -881,7 +871,6 @@ def watch_stop_q():
logger.error("Worker process died unexpectedly")
msg = {"op": "stop"}
finally:
child_stop_q.close()
assert msg["op"] == "stop", msg
del msg["op"]
loop.add_callback(do_stop, **msg)
Expand All @@ -901,7 +890,6 @@ async def run():
except Exception as e:
logger.exception("Failed to start worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for at least
# one interval for the outside to pick up this message.
# Otherwise we arrive in a race condition where the process
Expand All @@ -923,14 +911,12 @@ async def run():
"uid": uid,
}
)
init_result_q.close()
await worker.finished()
logger.info("Worker closed")

except Exception as e:
logger.exception("Failed to initialize Worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for at least one
# interval for the outside to pick up this message. Otherwise we
# arrive in a race condition where the process cleanup wipes the
Expand All @@ -948,20 +934,8 @@ async def run():
# do_stop() explicitly.
loop.run_sync(do_stop)
finally:
with suppress(ValueError):
child_stop_q.put({"op": "stop"}) # usually redundant
with suppress(ValueError):
child_stop_q.close() # usually redundant
child_stop_q.put({"op": "stop"})
thread.join()
child_stop_q.close()
child_stop_q.join_thread()
thread.join(timeout=2)


def _get_env_variables(config_key: str) -> dict[str, str]:
    """Return the environment-variable mapping configured under ``config_key``.

    The value is looked up with ``dask.config.get`` and must be a dict.
    For each key, a variable of the same name present in ``os.environ``
    takes precedence; otherwise the configured value is used, converted
    to ``str``.

    Raises
    ------
    TypeError
        If the configured value is not a dict.
    """
    cfg = dask.config.get(config_key)
    if isinstance(cfg, dict):
        resolved = {}
        for name, value in cfg.items():
            # A variable explicitly defined in the OS environment overrides
            # the value coming from the dask config.
            resolved[name] = os.environ.get(name, str(value))
        return resolved
    raise TypeError(  # pragma: nocover
        f"{config_key} configuration must be of type dict. Instead got {type(cfg)}"
    )

0 comments on commit 7ba727b

Please sign in to comment.