From 499ae2c074769a48bc4129ee9957f56a8b23075e Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 7 Nov 2017 22:46:57 +0100 Subject: [PATCH 1/2] Fix Client.close() with asyncio --- distributed/client.py | 75 +++++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 306705847bd..f78540164d6 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -806,41 +806,44 @@ def _release_key(self, key): def _handle_report(self): """ Listen to scheduler """ with log_errors(): - while True: - try: - msgs = yield self.scheduler_comm.comm.read() - except CommClosedError: - if self.status == 'running': - logger.warning("Client report stream closed to scheduler") - logger.info("Reconnecting...") - self.status = 'connecting' - yield self._reconnect() - continue - else: - break - if not isinstance(msgs, list): - msgs = [msgs] + try: + while True: + try: + msgs = yield self.scheduler_comm.comm.read() + except CommClosedError: + if self.status == 'running': + logger.warning("Client report stream closed to scheduler") + logger.info("Reconnecting...") + self.status = 'connecting' + yield self._reconnect() + continue + else: + break + if not isinstance(msgs, list): + msgs = [msgs] - breakout = False - for msg in msgs: - logger.debug("Client receives message %s", msg) + breakout = False + for msg in msgs: + logger.debug("Client receives message %s", msg) - if 'status' in msg and 'error' in msg['status']: - six.reraise(*clean_exception(**msg)) + if 'status' in msg and 'error' in msg['status']: + six.reraise(*clean_exception(**msg)) - op = msg.pop('op') + op = msg.pop('op') - if op == 'close' or op == 'stream-closed': - breakout = True - break + if op == 'close' or op == 'stream-closed': + breakout = True + break - try: - handler = self._handlers[op] - handler(**msg) - except Exception as e: - logger.exception(e) - if breakout: - break + try: + handler = self._handlers[op] + handler(**msg) + except Exception as e: + logger.exception(e) + if breakout: + break + except CancelledError: + pass def _handle_key_in_memory(self, key=None, type=None, workers=None): state = self.futures.get(key) @@ -922,10 +925,18 @@ def _close(self, fast=False): self.status = 'closed' if _get_global_client() is self: _set_global_client(None) + coroutines = set(self.coroutines) + for f in self.coroutines: + # cancel() works on asyncio futures + # but is a no-op on Tornado futures + f.cancel() + if f.cancelled: + coroutines.remove(f) + del self.coroutines[:] if not fast: with ignoring(TimeoutError): - yield [gen.with_timeout(timedelta(seconds=2), f) - for f in self.coroutines] + yield gen.with_timeout(timedelta(seconds=2), + list(coroutines)) with ignoring(AttributeError): self.scheduler.close_rpc() self.scheduler = None From 0250319c7f67d96f8475828adadf30113a62f6ce Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 7 Nov 2017 23:33:35 +0100 Subject: [PATCH 2/2] Fix bug in inproc comms --- distributed/client.py | 17 ++++++++---- distributed/comm/inproc.py | 1 + distributed/comm/tests/test_comms.py | 41 ++++++++++++++-------------- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index f78540164d6..18dcbad9048 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -900,6 +900,8 @@ def _handle_error(self, exception=None): @gen.coroutine def _close(self, fast=False): """ Send close signal and wait until scheduler completes """ + self.status = 'closing' + with log_errors(): for pc in self._periodic_callbacks: pc.stop() @@ -927,10 +929,10 @@ def _close(self, fast=False): _set_global_client(None) coroutines = set(self.coroutines) for f in self.coroutines: - # cancel() works on asyncio futures + # cancel() works on asyncio futures (Tornado 5) # but is a no-op on Tornado futures f.cancel() - if f.cancelled: + if f.cancelled(): coroutines.remove(f) del self.coroutines[:] if not fast: @@ -941,6 +943,8 @@ def _close(self, fast=False): self.scheduler.close_rpc() self.scheduler = None + self.status = 'closed' + _shutdown = _close def close(self, timeout=10): @@ -955,16 +959,17 @@ def close(self, timeout=10): -------- Client.restart """ + # XXX handling of self.status here is not thread-safe + if self.status == 'closed': + return + self.status = 'closing' + if self.asynchronous: future = self._close() if timeout: future = gen.with_timeout(timedelta(seconds=timeout), future) return future - # XXX handling of self.status here is not thread-safe - if self.status == 'closed': - return - self.status = 'closing' if self._start_arg is None: with ignoring(AttributeError): self.cluster.close() diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 5f00f6341e3..11ed6b25132 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -209,6 +209,7 @@ def abort(self): if not self.closed(): # Putting EOF is cheap enough that we do it on abort() too self._write_loop.add_callback(self._write_q.put_nowait, _EOF) + self._read_q.put_nowait(_EOF) self._write_q = self._read_q = None self._closed = True self._finalizer.detach() diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 8accad7d370..9fff08eb604 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -324,6 +324,8 @@ def client_communicate(key, delay=0): msg = yield comm.read() assert msg == {'op': 'pong', 'data': key} l.append(key) + with pytest.raises(CommClosedError): + yield comm.read() yield comm.close() client_communicate = partial(run_client, client_communicate) @@ -649,29 +651,26 @@ def test_inproc_comm_closed_implicit(): @gen.coroutine def check_comm_closed_explicit(addr, listen_args=None, connect_args=None): - @gen.coroutine - def handle_comm(comm): - # Wait - try: - yield comm.read() - except CommClosedError: - pass - - listener = listen(addr, handle_comm, connection_args=listen_args) - listener.start() - contact_addr = listener.contact_address - - comm = yield connect(contact_addr, connection_args=connect_args) - comm.close() + a, b = yield get_comm_pair(addr, listen_args=listen_args, connect_args=connect_args) + a_read = a.read() + b_read = b.read() + yield a.close() + # In-flight reads should abort with CommClosedError with pytest.raises(CommClosedError): - yield comm.write({}) - - comm = yield connect(contact_addr, connection_args=connect_args) - comm.close() + yield a_read with pytest.raises(CommClosedError): - yield comm.read() - - yield gen.moment + yield b_read + # New reads as well + with pytest.raises(CommClosedError): + yield a.read() + with pytest.raises(CommClosedError): + yield b.read() + # And writes + with pytest.raises(CommClosedError): + yield a.write({}) + with pytest.raises(CommClosedError): + yield b.write({}) + yield b.close() @gen_test()