Skip to content

Commit

Permalink
handle closed comm
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Dec 22, 2023
1 parent e7a0dfe commit 2140f88
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
20 changes: 15 additions & 5 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,15 +1176,25 @@ async def _():

server._ongoing_background_tasks.call_soon(_)

async def watch_comm():
while True:
if self._bcomm.comm.closed():
fut.set_exception(CommClosedError)
break
await asyncio.sleep(0.1)

t = asyncio.create_task(watch_comm())

def is_next():
return server._waiting_for[0] == sig

async with server._ensure_order:
await server._ensure_order.wait_for(is_next)
try:
try:
async with server._ensure_order:
await server._ensure_order.wait_for(is_next)
return await fut
finally:
server._waiting_for.popleft()
finally:
t.cancel()
server._waiting_for.popleft()

return send_recv_from_rpc

Expand Down
37 changes: 37 additions & 0 deletions distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,3 +1544,40 @@ async def wait_to_unblock(error=False):
async with rpc(s2.address) as r:
assert not await r.do_work(other_addr=s1.address)
assert await r.do_work(other_addr=s1.address, ordered=True)


@pytest.mark.parametrize(
"use_side_channel",
[False, True],
)
@gen_test()
async def test_ordered_rpc_comm_closed(use_side_channel):
async def sleep(duration):
await asyncio.sleep(duration)

class MyServer(Server):
def __init__(self, *args, **kwargs):
handlers = {
"sleep": sleep,
"do_work": self.do_work,
"kill": self.kill,
}
super().__init__(handlers, *args, **kwargs)

async def kill(self):
await self.close()

async def do_work(self, other_addr):
r = await self.ordered_rpc(other_addr, use_side_channel=use_side_channel)
t1 = asyncio.create_task(r.sleep(duration=100000))
with contextlib.suppress(OSError):
await self.rpc(other_addr).kill()
with pytest.raises(CommClosedError):
await t1
return True

async with MyServer() as s1, MyServer() as s2:
await s1.listen()
await s2.listen()
async with rpc(s2.address) as r:
assert await r.do_work(other_addr=s1.address)

0 comments on commit 2140f88

Please sign in to comment.