From dd28d08ca22f7ae874ba10e524ed322a15b4cacd Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 24 Mar 2020 17:55:45 -0500 Subject: [PATCH] Ensure Client connection pool semaphore attaches to the Client's event loop (#3546) * Add Node and ConnectionPool start methods * Make ConnectionPools awaitable --- distributed/client.py | 3 +++ distributed/core.py | 13 +++++++++++-- distributed/nanny.py | 3 +++ distributed/node.py | 7 ++++++- distributed/scheduler.py | 3 +++ distributed/tests/test_client.py | 5 +++++ distributed/tests/test_core.py | 10 +++++----- distributed/tests/test_scheduler.py | 6 +++--- distributed/tests/test_utils_comm.py | 16 ++++++++-------- distributed/worker.py | 2 ++ 10 files changed, 49 insertions(+), 19 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 24f25a79a0..4065aad17e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -923,6 +923,9 @@ def _send_to_scheduler(self, msg): ) async def _start(self, timeout=no_default, **kwargs): + + await super().start() + if timeout == no_default: timeout = self._timeout if timeout is not None: diff --git a/distributed/core.py b/distributed/core.py index 5bff3276e9..1bf3b172b6 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -835,8 +835,6 @@ def __init__( self.connection_args = connection_args self.timeout = timeout self._n_connecting = 0 - # Invariant: semaphore._value == limit - open - _n_connecting - self.semaphore = asyncio.Semaphore(self.limit) self.server = weakref.ref(server) if server else None self._created = weakref.WeakSet() self._instances.add(self) @@ -871,6 +869,17 @@ def __call__(self, addr=None, ip=None, port=None): addr, self, serializers=self.serializers, deserializers=self.deserializers ) + def __await__(self): + async def _(): + await self.start() + return self + + return _().__await__() + + async def start(self): + # Invariant: semaphore._value == limit - open - _n_connecting + self.semaphore = asyncio.Semaphore(self.limit) + 
async def connect(self, addr, timeout=None): """ Get a Comm to the given address. For internal use. diff --git a/distributed/nanny.py b/distributed/nanny.py index ec5397efb9..baa77e3ce1 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -241,6 +241,9 @@ def local_dir(self): async def start(self): """ Start nanny, start local process, start watching """ + + await super().start() + await self.listen(self._start_address, listen_args=self.listen_args) self.ip = get_address_host(self.address) diff --git a/distributed/node.py b/distributed/node.py index 4e26defeb0..af15b5a409 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -38,6 +38,9 @@ def __init__( server=self, ) + async def start(self): + await self.rpc.start() + class ServerNode(Node, Server): """ @@ -182,5 +185,7 @@ async def wait_for(future, timeout=None): future = wait_for(future, timeout=timeout) return future.__await__() - async def start(self): # subclasses should implement this + async def start(self): + # subclasses should implement their own start method which calls super().start() + await Node.start(self) return self diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8a61ba31fc..cea1f9fd13 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1408,6 +1408,9 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): async def start(self): """ Clear out old state and restart all running coroutines """ + + await super().start() + enable_gc_diagnosis() self.clear_task_state() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 5c85391655..d9633876cb 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5943,3 +5943,8 @@ def test_as_completed_condition_loop(c, s, a, b): seq = c.map(inc, range(5)) ac = as_completed(seq) assert ac.condition._loop == c.loop.asyncio_loop + + +def test_client_connectionpool_semaphore_loop(s, a, b): + with Client(s["address"]) as c: 
+ assert c.rpc.semaphore._loop is c.loop.asyncio_loop diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 0a9c48bc87..76f2b28550 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -526,7 +526,7 @@ async def ping(comm, delay=0.1): for server in servers: await server.listen(0) - rpc = ConnectionPool(limit=5) + rpc = await ConnectionPool(limit=5) # Reuse connections await asyncio.gather( @@ -583,7 +583,7 @@ async def do_ping(pool, port): for server in servers: await server.listen(0) - pool = ConnectionPool(limit=limit) + pool = await ConnectionPool(limit=limit) await asyncio.gather(*[do_ping(pool, s.port) for s in servers]) @@ -605,7 +605,7 @@ async def ping(comm, delay=0.01): for server in servers: await server.listen("tls://", listen_args=listen_args) - rpc = ConnectionPool(limit=5, connection_args=connection_args) + rpc = await ConnectionPool(limit=5, connection_args=connection_args) await asyncio.gather(*[rpc(s.address).ping() for s in servers[:5]]) await asyncio.gather(*[rpc(s.address).ping() for s in servers[::2]]) @@ -625,7 +625,7 @@ async def ping(comm, delay=0.01): for server in servers: await server.listen(0) - rpc = ConnectionPool(limit=10) + rpc = await ConnectionPool(limit=10) serv = servers.pop() await asyncio.gather(*[rpc(s.address).ping() for s in servers]) await asyncio.gather(*[rpc(serv.address).ping() for i in range(3)]) @@ -758,7 +758,7 @@ async def test_connection_pool_detects_remote_close(): await server.listen("tcp://") # open a connection, use it and give it back to the pool - p = ConnectionPool(limit=10) + p = await ConnectionPool(limit=10) conn = await p.connect(server.address) await send_recv(conn, op="ping") p.reuse(server.address, conn) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 5459716ca8..24a40dccda 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1887,7 +1887,7 @@ async def 
test_gather_failing_cnn_recover(c, s, a, b): orig_rpc = s.rpc x = await c.scatter({"x": 1}, workers=a.address) - s.rpc = FlakyConnectionPool(failing_connections=1) + s.rpc = await FlakyConnectionPool(failing_connections=1) with mock.patch("distributed.utils_comm.retry_count", 1): res = await s.gather(keys=["x"]) assert res["status"] == "OK" @@ -1898,7 +1898,7 @@ async def test_gather_failing_cnn_error(c, s, a, b): orig_rpc = s.rpc x = await c.scatter({"x": 1}, workers=a.address) - s.rpc = FlakyConnectionPool(failing_connections=10) + s.rpc = await FlakyConnectionPool(failing_connections=10) res = await s.gather(keys=["x"]) assert res["status"] == "error" assert list(res["keys"]) == ["x"] @@ -1949,7 +1949,7 @@ def reducer(x, y): z = c.submit(reducer, x, y) - s.rpc = FlakyConnectionPool(failing_connections=4) + s.rpc = await FlakyConnectionPool(failing_connections=4) with captured_logger( logging.getLogger("distributed.scheduler") diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 2d0159a2d3..7ab793e18e 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -30,11 +30,11 @@ def test_subs_multiple(): @gen_cluster(client=True) -def test_gather_from_workers_permissive(c, s, a, b): - rpc = ConnectionPool() - x = yield c.scatter({"x": 1}, workers=a.address) +async def test_gather_from_workers_permissive(c, s, a, b): + rpc = await ConnectionPool() + x = await c.scatter({"x": 1}, workers=a.address) - data, missing, bad_workers = yield gather_from_workers( + data, missing, bad_workers = await gather_from_workers( {"x": [a.address], "y": [b.address]}, rpc=rpc ) @@ -68,11 +68,11 @@ async def connect(self, *args, **kwargs): @gen_cluster(client=True) -def test_gather_from_workers_permissive_flaky(c, s, a, b): - x = yield c.scatter({"x": 1}, workers=a.address) +async def test_gather_from_workers_permissive_flaky(c, s, a, b): + x = await c.scatter({"x": 1}, workers=a.address) - rpc = 
BrokenConnectionPool() - data, missing, bad_workers = yield gather_from_workers({"x": [a.address]}, rpc=rpc) + rpc = await BrokenConnectionPool() + data, missing, bad_workers = await gather_from_workers({"x": [a.address]}, rpc=rpc) assert missing == {"x": [a.address]} assert bad_workers == [a.address] diff --git a/distributed/worker.py b/distributed/worker.py index 191e4df085..ba25c91d97 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1012,6 +1012,8 @@ async def start(self): return assert self.status is None, self.status + await super().start() + enable_gc_diagnosis() thread_state.on_event_loop_thread = True