From 2140f88be7e0536f25dde7a4363c19ed03bde3db Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 22 Dec 2023 12:13:52 +0100 Subject: [PATCH] handle closed comm --- distributed/core.py | 20 +++++++++++++----- distributed/tests/test_core.py | 37 ++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 8740320b38..6d80706570 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1176,15 +1176,25 @@ async def _(): server._ongoing_background_tasks.call_soon(_) + async def watch_comm(): + while True: + if self._bcomm.comm.closed(): + fut.set_exception(CommClosedError) + break + await asyncio.sleep(0.1) + + t = asyncio.create_task(watch_comm()) + def is_next(): return server._waiting_for[0] == sig - async with server._ensure_order: - await server._ensure_order.wait_for(is_next) - try: + try: + async with server._ensure_order: + await server._ensure_order.wait_for(is_next) return await fut - finally: - server._waiting_for.popleft() + finally: + t.cancel() + server._waiting_for.popleft() return send_recv_from_rpc diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 90e71eb26a..e5db1402dc 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -1544,3 +1544,40 @@ async def wait_to_unblock(error=False): async with rpc(s2.address) as r: assert not await r.do_work(other_addr=s1.address) assert await r.do_work(other_addr=s1.address, ordered=True) + + +@pytest.mark.parametrize( + "use_side_channel", + [False, True], +) +@gen_test() +async def test_ordered_rpc_comm_closed(use_side_channel): + async def sleep(duration): + await asyncio.sleep(duration) + + class MyServer(Server): + def __init__(self, *args, **kwargs): + handlers = { + "sleep": sleep, + "do_work": self.do_work, + "kill": self.kill, + } + super().__init__(handlers, *args, **kwargs) + + async def kill(self): + await self.close() + + async def do_work(self, other_addr): + r = await self.ordered_rpc(other_addr, use_side_channel=use_side_channel) + t1 = asyncio.create_task(r.sleep(duration=100000)) + with contextlib.suppress(OSError): + await self.rpc(other_addr).kill() + with pytest.raises(CommClosedError): + await t1 + return True + + async with MyServer() as s1, MyServer() as s2: + await s1.listen() + await s2.listen() + async with rpc(s2.address) as r: + assert await r.do_work(other_addr=s1.address)