
Ensure Client connection pool semaphore attaches to the Client's event loop (#3546)

* Add Node and ConnectionPool start methods
* Make ConnectionPools awaitable
jrbourbeau committed Mar 24, 2020
1 parent 0f834e7 commit dd28d08
Showing 10 changed files with 49 additions and 19 deletions.
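
The two bullets in the commit message describe the mechanism: ConnectionPool gains an async start() that creates its semaphore, and awaiting a pool instance calls start() and returns the pool. Before the per-file diffs, a minimal usage sketch (not part of this commit) of the two equivalent start paths; it assumes distributed.core.ConnectionPool as of this change, and the final assert on the private semaphore._value mirrors the invariant comment in core.py below.

# Usage sketch (not from the diff): the two ways to initialize a ConnectionPool
# after this change.
import asyncio

from distributed.core import ConnectionPool


async def main():
    # Awaiting the instance runs __await__, which calls start() and returns self.
    pool = await ConnectionPool(limit=5)

    # Equivalently, construct first and start explicitly on the running loop.
    pool2 = ConnectionPool(limit=5)
    await pool2.start()

    # Either way the semaphore is created while this loop is running,
    # so it is attached to the loop that will actually use the pool.
    assert pool.semaphore._value == 5


asyncio.run(main())
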
3 changes: 3 additions & 0 deletions distributed/client.py
@@ -923,6 +923,9 @@ def _send_to_scheduler(self, msg):
            )

    async def _start(self, timeout=no_default, **kwargs):

        await super().start()

        if timeout == no_default:
            timeout = self._timeout
        if timeout is not None:
13 changes: 11 additions & 2 deletions distributed/core.py
@@ -835,8 +835,6 @@ def __init__(
        self.connection_args = connection_args
        self.timeout = timeout
        self._n_connecting = 0
        # Invariant: semaphore._value == limit - open - _n_connecting
        self.semaphore = asyncio.Semaphore(self.limit)
        self.server = weakref.ref(server) if server else None
        self._created = weakref.WeakSet()
        self._instances.add(self)
@@ -871,6 +869,17 @@ def __call__(self, addr=None, ip=None, port=None):
            addr, self, serializers=self.serializers, deserializers=self.deserializers
        )

    def __await__(self):
        async def _():
            await self.start()
            return self

        return _().__await__()

    async def start(self):
        # Invariant: semaphore._value == limit - open - _n_connecting
        self.semaphore = asyncio.Semaphore(self.limit)

    async def connect(self, addr, timeout=None):
        """
        Get a Comm to the given address. For internal use.
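
The reason the semaphore moves from __init__ to start() is the loop-binding behaviour of asyncio primitives on the Python versions current at the time (3.6-3.8): asyncio.Semaphore captures asyncio.get_event_loop() at construction and records it in the private _loop attribute, the same attribute the new test at the bottom of this diff inspects. A standalone sketch, independent of distributed and relying on that pre-3.10 behaviour, of the difference between constructing eagerly and constructing inside an async start:

# Standalone sketch (not from the commit) of why the semaphore is created in start().
# Relies on pre-3.10 behaviour: Semaphore captures get_event_loop() at construction.
import asyncio

# Built with no loop running: captures this thread's default event loop.
eager_sem = asyncio.Semaphore(5)


async def bound_to_running_loop(sem):
    # Compare the loop the semaphore captured with the loop actually running us.
    return sem._loop is asyncio.get_event_loop()


async def deferred():
    # Built inside a coroutine, i.e. while the target loop is already running;
    # this is what moving the Semaphore into an async start() achieves.
    return await bound_to_running_loop(asyncio.Semaphore(5))


other_loop = asyncio.new_event_loop()  # stands in for the Client's IOLoop
try:
    print(other_loop.run_until_complete(bound_to_running_loop(eager_sem)))  # False
    print(other_loop.run_until_complete(deferred()))                        # True
finally:
    other_loop.close()
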
3 changes: 3 additions & 0 deletions distributed/nanny.py
@@ -241,6 +241,9 @@ def local_dir(self):

    async def start(self):
        """ Start nanny, start local process, start watching """

        await super().start()

        await self.listen(self._start_address, listen_args=self.listen_args)
        self.ip = get_address_host(self.address)

7 changes: 6 additions & 1 deletion distributed/node.py
@@ -38,6 +38,9 @@ def __init__(
            server=self,
        )

    async def start(self):
        await self.rpc.start()


class ServerNode(Node, Server):
    """
@@ -182,5 +185,7 @@ async def wait_for(future, timeout=None):
            future = wait_for(future, timeout=timeout)
        return future.__await__()

    async def start(self):  # subclasses should implement this
    async def start(self):
        # subclasses should implement their own start method which calls super().start()
        await Node.start(self)
        return self
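
A hypothetical subclass, not part of the commit, illustrating the contract stated in the comment above: define your own start(), call super().start() first so self.rpc (the ConnectionPool) is started on the running loop, then do subclass-specific startup and return self. MyServer and its body are invented for illustration.

# Hypothetical subclass following the node.py contract above.
from distributed.node import ServerNode


class MyServer(ServerNode):
    async def start(self):
        await super().start()   # starts self.rpc's semaphore on this loop
        # ... listen, register services, etc. (omitted) ...
        return self
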
3 changes: 3 additions & 0 deletions distributed/scheduler.py
@@ -1408,6 +1408,9 @@ def get_worker_service_addr(self, worker, service_name, protocol=False):

    async def start(self):
        """ Clear out old state and restart all running coroutines """

        await super().start()

        enable_gc_diagnosis()

        self.clear_task_state()
5 changes: 5 additions & 0 deletions distributed/tests/test_client.py
@@ -5943,3 +5943,8 @@ def test_as_completed_condition_loop(c, s, a, b):
    seq = c.map(inc, range(5))
    ac = as_completed(seq)
    assert ac.condition._loop == c.loop.asyncio_loop


def test_client_connectionpool_semaphore_loop(s, a, b):
    with Client(s["address"]) as c:
        assert c.rpc.semaphore._loop is c.loop.asyncio_loop
10 changes: 5 additions & 5 deletions distributed/tests/test_core.py
@@ -526,7 +526,7 @@ async def ping(comm, delay=0.1):
    for server in servers:
        await server.listen(0)

    rpc = ConnectionPool(limit=5)
    rpc = await ConnectionPool(limit=5)

    # Reuse connections
    await asyncio.gather(
@@ -583,7 +583,7 @@ async def do_ping(pool, port):
    for server in servers:
        await server.listen(0)

    pool = ConnectionPool(limit=limit)
    pool = await ConnectionPool(limit=limit)

    await asyncio.gather(*[do_ping(pool, s.port) for s in servers])

@@ -605,7 +605,7 @@ async def ping(comm, delay=0.01):
    for server in servers:
        await server.listen("tls://", listen_args=listen_args)

    rpc = ConnectionPool(limit=5, connection_args=connection_args)
    rpc = await ConnectionPool(limit=5, connection_args=connection_args)

    await asyncio.gather(*[rpc(s.address).ping() for s in servers[:5]])
    await asyncio.gather(*[rpc(s.address).ping() for s in servers[::2]])
@@ -625,7 +625,7 @@ async def ping(comm, delay=0.01):
    for server in servers:
        await server.listen(0)

    rpc = ConnectionPool(limit=10)
    rpc = await ConnectionPool(limit=10)
    serv = servers.pop()
    await asyncio.gather(*[rpc(s.address).ping() for s in servers])
    await asyncio.gather(*[rpc(serv.address).ping() for i in range(3)])
@@ -758,7 +758,7 @@ async def test_connection_pool_detects_remote_close():
    await server.listen("tcp://")

    # open a connection, use it and give it back to the pool
    p = ConnectionPool(limit=10)
    p = await ConnectionPool(limit=10)
    conn = await p.connect(server.address)
    await send_recv(conn, op="ping")
    p.reuse(server.address, conn)
6 changes: 3 additions & 3 deletions distributed/tests/test_scheduler.py
@@ -1887,7 +1887,7 @@ async def test_gather_failing_cnn_recover(c, s, a, b):
    orig_rpc = s.rpc
    x = await c.scatter({"x": 1}, workers=a.address)

    s.rpc = FlakyConnectionPool(failing_connections=1)
    s.rpc = await FlakyConnectionPool(failing_connections=1)
    with mock.patch("distributed.utils_comm.retry_count", 1):
        res = await s.gather(keys=["x"])
    assert res["status"] == "OK"
@@ -1898,7 +1898,7 @@ async def test_gather_failing_cnn_error(c, s, a, b):
    orig_rpc = s.rpc
    x = await c.scatter({"x": 1}, workers=a.address)

    s.rpc = FlakyConnectionPool(failing_connections=10)
    s.rpc = await FlakyConnectionPool(failing_connections=10)
    res = await s.gather(keys=["x"])
    assert res["status"] == "error"
    assert list(res["keys"]) == ["x"]
@@ -1949,7 +1949,7 @@ def reducer(x, y):

    z = c.submit(reducer, x, y)

    s.rpc = FlakyConnectionPool(failing_connections=4)
    s.rpc = await FlakyConnectionPool(failing_connections=4)

    with captured_logger(
        logging.getLogger("distributed.scheduler")
16 changes: 8 additions & 8 deletions distributed/tests/test_utils_comm.py
@@ -30,11 +30,11 @@ def test_subs_multiple():


@gen_cluster(client=True)
def test_gather_from_workers_permissive(c, s, a, b):
    rpc = ConnectionPool()
    x = yield c.scatter({"x": 1}, workers=a.address)
async def test_gather_from_workers_permissive(c, s, a, b):
    rpc = await ConnectionPool()
    x = await c.scatter({"x": 1}, workers=a.address)

    data, missing, bad_workers = yield gather_from_workers(
    data, missing, bad_workers = await gather_from_workers(
        {"x": [a.address], "y": [b.address]}, rpc=rpc
    )

@@ -68,11 +68,11 @@ async def connect(self, *args, **kwargs):


@gen_cluster(client=True)
def test_gather_from_workers_permissive_flaky(c, s, a, b):
    x = yield c.scatter({"x": 1}, workers=a.address)
async def test_gather_from_workers_permissive_flaky(c, s, a, b):
    x = await c.scatter({"x": 1}, workers=a.address)

    rpc = BrokenConnectionPool()
    data, missing, bad_workers = yield gather_from_workers({"x": [a.address]}, rpc=rpc)
    rpc = await BrokenConnectionPool()
    data, missing, bad_workers = await gather_from_workers({"x": [a.address]}, rpc=rpc)

    assert missing == {"x": [a.address]}
    assert bad_workers == [a.address]
2 changes: 2 additions & 0 deletions distributed/worker.py
@@ -1012,6 +1012,8 @@ async def start(self):
            return
        assert self.status is None, self.status

        await super().start()

        enable_gc_diagnosis()
        thread_state.on_event_loop_thread = True

