2 changes: 1 addition & 1 deletion distributed/client.py
@@ -1298,7 +1298,7 @@ async def _close(self, fast=False):
            with ignoring(TimeoutError):
                await gen.with_timeout(timedelta(seconds=2), list(coroutines))
            with ignoring(AttributeError):
-                self.scheduler.close_rpc()
+                await self.scheduler.close_rpc()
            self.scheduler = None

        self.status = "closed"
7 changes: 3 additions & 4 deletions distributed/comm/tcp.py
@@ -213,14 +213,13 @@ async def read(self, deserializers=None):
                    raise CommClosedError("aborted stream on truncated data")
        return msg

-    @gen.coroutine
-    def write(self, msg, serializers=None, on_error="message"):
+    async def write(self, msg, serializers=None, on_error="message"):
        stream = self.stream
        bytes_since_last_yield = 0
        if stream is None:
            raise CommClosedError

-        frames = yield to_frames(
+        frames = await to_frames(
            msg,
            serializers=serializers,
            on_error=on_error,
@@ -247,7 +246,7 @@ def write(self, msg, serializers=None, on_error="message"):
                future = stream.write(frame)
                bytes_since_last_yield += nbytes(frame)
                if bytes_since_last_yield > 32e6:
-                    yield future
+                    await future
                    bytes_since_last_yield = 0
        except StreamClosedError as e:
            stream = None
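The write path above queues every frame on the Tornado stream but only awaits the write future once roughly 32 MB have accumulated, which bounds buffered memory without paying an event-loop hop per frame. A minimal self-contained sketch of that idea (the names and the fake stream below are illustrative, not distributed's API):

```python
import asyncio

async def write_frames(stream_write, frames, threshold=32e6):
    """Queue every frame, awaiting the stream only every `threshold` bytes."""
    bytes_since_last_yield = 0
    for frame in frames:
        future = stream_write(frame)  # queues the write, returns an awaitable
        bytes_since_last_yield += len(frame)
        if bytes_since_last_yield > threshold:
            # Awaiting occasionally applies backpressure (bounds buffered
            # bytes) without a scheduler round-trip per frame.
            await future
            bytes_since_last_yield = 0

async def main():
    sent = []

    def fake_write(frame):  # stand-in for tornado.iostream.IOStream.write
        sent.append(frame)
        fut = asyncio.get_event_loop().create_future()
        fut.set_result(None)  # pretend the flush completed immediately
        return fut

    await write_frames(fake_write, [b"x" * 1024] * 10, threshold=2048)
    assert len(sent) == 10

asyncio.run(main())
```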
3 changes: 1 addition & 2 deletions distributed/comm/ucx.py
@@ -220,7 +220,6 @@ async def connect(self, address: str, deserialize=True, **connection_args) -> UCX:


class UCXListener(Listener):
-    # MAX_LISTENERS 256 in ucx-py
    prefix = UCXConnector.prefix
    comm_class = UCXConnector.comm_class
    encrypted = UCXConnector.encrypted
@@ -250,7 +249,7 @@ async def serve_forever(client_ep):
            ucx = UCX(
                client_ep,
                local_addr=self.address,
-                peer_addr=self.address,  # TODO: https://github.com/Akshay-Venkatesh/ucx-py/issues/111
+                peer_addr=self.address,
                deserialize=self.deserialize,
            )
            if self.comm_handler:
16 changes: 6 additions & 10 deletions distributed/comm/utils.py
@@ -1,8 +1,6 @@
import logging
import socket

-from tornado import gen
-
from .. import protocol
from ..utils import get_ip, get_ipv6, nbytes, offload
@@ -16,8 +14,7 @@
FRAME_OFFLOAD_THRESHOLD = 10 * 1024 ** 2  # 10 MB


-@gen.coroutine
-def to_frames(msg, serializers=None, on_error="message", context=None):
+async def to_frames(msg, serializers=None, on_error="message", context=None):
    """
    Serialize a message into a list of Distributed protocol frames.
    """
@@ -34,13 +31,12 @@ def _to_frames():
        logger.exception(e)
        raise

-    res = yield offload(_to_frames)
+    res = await offload(_to_frames)

-    raise gen.Return(res)
+    return res


-@gen.coroutine
-def from_frames(frames, deserialize=True, deserializers=None):
+async def from_frames(frames, deserialize=True, deserializers=None):
    """
    Unserialize a list of Distributed protocol frames.
    """
@@ -61,11 +57,11 @@ def _from_frames():
        raise

    if deserialize and size > FRAME_OFFLOAD_THRESHOLD:
-        res = yield offload(_from_frames)
+        res = await offload(_from_frames)
    else:
        res = _from_frames()

-    raise gen.Return(res)
+    return res


def get_tcp_server_address(tcp_server):
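The from_frames change keeps an optimization worth noting: deserialization is offloaded to a thread only when the payload exceeds FRAME_OFFLOAD_THRESHOLD, since for small messages the thread hand-off costs more than it saves. A hedged sketch of the same shape (the parsing function is a stand-in, not distributed's protocol code):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)
THRESHOLD = 10 * 1024 ** 2  # mirrors FRAME_OFFLOAD_THRESHOLD (10 MB)

async def from_frames(frames):
    def _from_frames():
        return b"".join(frames)  # stand-in for real deserialization

    size = sum(len(f) for f in frames)
    if size > THRESHOLD:
        # Large payloads: parse on a worker thread so the event loop stays
        # responsive while the bytes are being crunched.
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(_executor, _from_frames)
    # Small payloads: run inline; the executor hand-off would dominate.
    return _from_frames()

print(asyncio.run(from_frames([b"abc", b"def"])))  # b'abcdef'
```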
18 changes: 11 additions & 7 deletions distributed/core.py
@@ -641,25 +641,29 @@ async def live_comm(self):
        return comm

    def close_comms(self):
-        @gen.coroutine
-        def _close_comm(comm):
+        async def _close_comm(comm):
            # Make sure we tell the peer to close
            try:
                if not comm.closed():
-                    yield comm.write({"op": "close", "reply": False})
-                    yield comm.close()
+                    await comm.write({"op": "close", "reply": False})
+                    await comm.close()
            except EnvironmentError:
                comm.abort()

+        tasks = []
        for comm in list(self.comms):
            if comm and not comm.closed():
+                # IOLoop.current().add_callback(_close_comm, comm)
+                task = asyncio.ensure_future(_close_comm(comm))
+                tasks.append(task)
        for comm in list(self._created):
            if comm and not comm.closed():
+                # IOLoop.current().add_callback(_close_comm, comm)
+                task = asyncio.ensure_future(_close_comm(comm))
+                tasks.append(task)

        self.comms.clear()
+        return tasks

    def __getattr__(self, key):
        async def send_recv_from_rpc(**kwargs):
@@ -685,13 +689,13 @@ def close_rpc(self):
        if self.status != "closed":
            rpc.active.discard(self)
        self.status = "closed"
-        self.close_comms()
+        return asyncio.gather(*self.close_comms())

    def __enter__(self):
        return self

    def __exit__(self, *args):
-        self.close_rpc()
+        asyncio.ensure_future(self.close_rpc())

    def __del__(self):
        if self.status != "closed":
@@ -744,7 +748,7 @@ async def send_recv_from_rpc(**kwargs):

        return send_recv_from_rpc

-    def close_rpc(self):
+    async def close_rpc(self):
        pass

    # For compatibility with rpc()
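The shape of the core.py change: close_comms now spawns one task per comm and returns the task list instead of firing callbacks into the IOLoop, and close_rpc wraps those tasks in asyncio.gather so callers can await a complete shutdown. A minimal sketch of that pattern under illustrative names (Pool and _close_one are not distributed APIs):

```python
import asyncio

class Pool:
    def __init__(self, n):
        self.conns = list(range(n))

    def close_comms(self):
        async def _close_one(conn):
            await asyncio.sleep(0)  # stand-in for comm.write / comm.close

        tasks = [asyncio.ensure_future(_close_one(c)) for c in self.conns]
        self.conns.clear()
        return tasks  # callers decide when to await these

    def close_rpc(self):
        # gather() returns an awaitable, so close_rpc need not be async:
        # `await pool.close_rpc()` works either way.
        return asyncio.gather(*self.close_comms())

async def main():
    await Pool(3).close_rpc()

asyncio.run(main())
```

This is also why __exit__ above falls back to asyncio.ensure_future: a synchronous context manager cannot await, so the best it can do is schedule the close.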
2 changes: 1 addition & 1 deletion distributed/deploy/cluster.py
@@ -74,7 +74,7 @@ async def _close(self):

        for pc in self.periodic_callbacks.values():
            pc.stop()
-        self.scheduler_comm.close_rpc()
+        await self.scheduler_comm.close_rpc()

        self.status = "closed"

16 changes: 15 additions & 1 deletion distributed/deploy/tests/test_local.py
@@ -1,3 +1,4 @@
+import asyncio
from functools import partial
import gc
import subprocess
@@ -455,7 +456,7 @@ def test_silent_startup():

if __name__ == "__main__":
    with LocalCluster(1, dashboard_address=None, scheduler_port=0):
-        sleep(1.5)
+        sleep(.1)
"""

    out = subprocess.check_output(
@@ -1004,3 +1005,16 @@ async def test_capture_security(cleanup, temporary):
    ) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            assert client.security == cluster.security


+@pytest.mark.asyncio
+@pytest.mark.skipif(
+    sys.version_info < (3, 7), reason="asyncio.all_tasks not implemented"
+)
+async def test_no_dangling_asyncio_tasks(cleanup):
+    start = asyncio.all_tasks()
+    async with LocalCluster(asynchronous=True, processes=False):
+        await asyncio.sleep(0.01)
+
+    tasks = asyncio.all_tasks()
+    assert tasks == start
Member:

I'm not very familiar with asyncio. Does this mean that if the assert were inside the async with block it would fail (i.e., more tasks than at start)?

Member:

Yes. This is the collection of all ongoing coroutines/tasks in asyncio. During Dask execution there are several concurrent asyncio tasks. We want to verify that they all get cleaned up.

These tests don't fix anything, but they're nice to have around regardless. I'm still searching for what is going on here.

Member:

Makes sense, thanks for the explanation!
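To make the discussion above concrete, here is a minimal sketch of bracketing an operation with asyncio.all_tasks() to detect leaked tasks (illustrative only; requires Python 3.7+):

```python
import asyncio

async def main():
    before = asyncio.all_tasks()  # includes the currently running main() task

    # Spawn background work, a stand-in for Dask's internal tasks.
    task = asyncio.ensure_future(asyncio.sleep(0.1))
    assert asyncio.all_tasks() - before  # the new task is visible while pending

    await task
    # all_tasks() only reports unfinished tasks, so once the work completes
    # (and nothing dangles) we are back to the starting set.
    assert asyncio.all_tasks() == before

asyncio.run(main())
```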

3 changes: 1 addition & 2 deletions distributed/scheduler.py
@@ -2642,8 +2642,7 @@ async def restart(self, client=None, timeout=3):
                    "timeout. Continuing with restart process"
                )
        finally:
-            for nanny in nannies:
-                nanny.close_rpc()
+            await asyncio.gather(*[nanny.close_rpc() for nanny in nannies])

        await self.start()
5 changes: 3 additions & 2 deletions distributed/tests/test_core.py
@@ -299,13 +299,14 @@ def test_rpc_inproc():
    yield check_rpc("inproc://", None)


-def test_rpc_inputs():
+@pytest.mark.asyncio
+async def test_rpc_inputs():
    L = [rpc("127.0.0.1:8884"), rpc(("127.0.0.1", 8884)), rpc("tcp://127.0.0.1:8884")]

    assert all(r.address == "tcp://127.0.0.1:8884" for r in L), L

    for r in L:
-        r.close_rpc()
+        await r.close_rpc()


async def check_rpc_message_lifetime(*listen_args):
2 changes: 1 addition & 1 deletion distributed/tests/test_nanny.py
@@ -115,7 +115,7 @@ def test_nanny_process_failure(c, s):
    assert not os.path.exists(second_dir)
    assert not os.path.exists(first_dir)
    assert first_dir != n.worker_dir
-    ww.close_rpc()
+    yield ww.close_rpc()
    s.stop()

16 changes: 16 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -1,3 +1,4 @@
+import asyncio
import cloudpickle
import pickle
from collections import defaultdict
@@ -1688,3 +1689,18 @@ def test_get_task_duration():
    assert s.get_task_duration(ts_pref2_2) == 0.5  # default
    assert len(s.unknown_durations) == 1
    assert len(s.unknown_durations["prefix_2"]) == 2


+@pytest.mark.asyncio
+@pytest.mark.skipif(
+    sys.version_info < (3, 7), reason="asyncio.all_tasks not implemented"
+)
+async def test_no_dangling_asyncio_tasks(cleanup):
+    start = asyncio.all_tasks()
+    async with Scheduler(port=0) as s:
+        async with Worker(s.address, name="0") as a:
+            async with Client(s.address, asynchronous=True) as c:
+                await asyncio.sleep(0.01)
+
+    tasks = asyncio.all_tasks()
+    assert tasks == start
33 changes: 14 additions & 19 deletions distributed/tests/test_worker.py
@@ -173,7 +173,7 @@ def dont_test_delete_data_with_missing_worker(c, a, b):
    assert not c.has_what[bad]
    assert not c.has_what[a.address]

-    cc.close_rpc()
+    yield cc.close_rpc()


@gen_cluster(client=True)
@@ -998,32 +998,27 @@ def test_worker_fds(s):


@gen_cluster(nthreads=[])
-def test_service_hosts_match_worker(s):
+async def test_service_hosts_match_worker(s):
    pytest.importorskip("bokeh")
    from distributed.dashboard import BokehWorker

-    services = {("dashboard", ":0"): BokehWorker}
-
-    w = yield Worker(
+    async with Worker(
        s.address, services={("dashboard", ":0"): BokehWorker}, host="tcp://0.0.0.0"
-    )
-    sock = first(w.services["dashboard"].server._http._sockets.values())
-    assert sock.getsockname()[0] in ("::", "0.0.0.0")
-    yield w.close()
+    ) as w:
+        sock = first(w.services["dashboard"].server._http._sockets.values())
+        assert sock.getsockname()[0] in ("::", "0.0.0.0")

-    w = yield Worker(
+    async with Worker(
        s.address, services={("dashboard", ":0"): BokehWorker}, host="tcp://127.0.0.1"
-    )
-    sock = first(w.services["dashboard"].server._http._sockets.values())
-    assert sock.getsockname()[0] in ("::", "0.0.0.0")
-    yield w.close()
+    ) as w:
+        sock = first(w.services["dashboard"].server._http._sockets.values())
+        assert sock.getsockname()[0] in ("::", "0.0.0.0")

-    w = yield Worker(
+    async with Worker(
        s.address, services={("dashboard", 0): BokehWorker}, host="tcp://127.0.0.1"
-    )
-    sock = first(w.services["dashboard"].server._http._sockets.values())
-    assert sock.getsockname()[0] == "127.0.0.1"
-    yield w.close()
+    ) as w:
+        sock = first(w.services["dashboard"].server._http._sockets.values())
+        assert sock.getsockname()[0] == "127.0.0.1"


@gen_cluster(nthreads=[])
6 changes: 3 additions & 3 deletions distributed/utils.py
@@ -1365,6 +1365,6 @@ def is_valid_xml(text):
weakref.finalize(_offload_executor, _offload_executor.shutdown)


-@gen.coroutine
-def offload(fn, *args, **kwargs):
-    return (yield _offload_executor.submit(fn, *args, **kwargs))
+async def offload(fn, *args, **kwargs):
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(_offload_executor, fn, *args, **kwargs)
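One caveat on the new offload: loop.run_in_executor forwards positional arguments only, so passing **kwargs straight through as written would raise a TypeError whenever keyword arguments are actually supplied; binding them with functools.partial avoids that. A self-contained sketch of that variant (the executor below is a stand-in for _offload_executor):

```python
import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)

async def offload(fn, *args, **kwargs):
    loop = asyncio.get_event_loop()
    # run_in_executor only accepts positional args; bind kwargs up front.
    return await loop.run_in_executor(
        _executor, functools.partial(fn, *args, **kwargs)
    )

async def main():
    text = await offload(str.join, ", ", ["a", "b"])  # runs on the thread
    assert text == "a, b"
    num = await offload(int, "ff", base=16)  # kwargs work via partial
    assert num == 255

asyncio.run(main())
```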
4 changes: 2 additions & 2 deletions distributed/utils_comm.py
@@ -84,7 +84,7 @@ async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None):
        response.update(r["data"])
    finally:
        for r in rpcs.values():
-            r.close_rpc()
+            await r.close_rpc()

    bad_addresses |= {v for k, v in rev.items() if k not in response}
    results.update(response)
@@ -148,7 +148,7 @@ async def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None):
        )
    finally:
        for r in rpcs.values():
-            r.close_rpc()
+            await r.close_rpc()

    nbytes = merge(o["nbytes"] for o in out)
2 changes: 1 addition & 1 deletion distributed/worker.py
@@ -1063,7 +1063,7 @@ async def close(
                address=self.contact_address, safe=safe
            ),
        )
-        self.scheduler.close_rpc()
+        await self.scheduler.close_rpc()
        self._workdir.release()

        for k, v in self.services.items():