Skip to content

Commit

Permalink
Improve cleanup when closing during handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
horazont committed Jul 20, 2018
1 parent 762252b commit c280328
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
17 changes: 11 additions & 6 deletions aioopenssl/__init__.py
Expand Up @@ -207,11 +207,11 @@ def __init__(self, loop, rawsock, protocol, ssl_context_factory,

def _waiter_done(self, fut):
self._trace_logger.debug("_waiter future done (%r)", fut)
if fut.cancelled():
for chained in self._chained_pending:
self._trace_logger.debug("cancelling chained %r", chained)
chained.cancel()
self._chained_pending.clear()

for chained in self._chained_pending:
self._trace_logger.debug("cancelling chained %r", chained)
chained.cancel()
self._chained_pending.clear()

def _invalid_transition(self, via=None, to=None):
via_text = (" via {}".format(via)) if via is not None else ""
Expand Down Expand Up @@ -257,6 +257,8 @@ def _force_close(self, exc):
if self._buffer:
self._buffer.clear()

if self._waiter is not None and not self._waiter.done():
self._waiter.set_exception(ConnectionError("_force_close() called"))
self._loop.remove_reader(self._raw_fd)
self._loop.remove_writer(self._raw_fd)
self._loop.call_soon(self._call_connection_lost_and_clean_up, exc)
Expand Down Expand Up @@ -371,6 +373,9 @@ def _tls_post_handshake_done(self, task):
self._chained_pending.discard(task)
try:
task.result()
except asyncio.CancelledError:
# canceled due to closure or something similar
pass
except BaseException as err:
self._tls_post_handshake(err)
else:
Expand All @@ -379,9 +384,9 @@ def _tls_post_handshake_done(self, task):
def _tls_post_handshake(self, exc):
self._trace_logger.debug("_tls_post_handshake called")
if exc is not None:
self._fatal_error(exc, "Fatal error on post-handshake callback")
if self._waiter is not None and not self._waiter.done():
self._waiter.set_exception(exc)
self._fatal_error(exc, "Fatal error on post-handshake callback")
return

self._tls_read_wants_write = False
Expand Down
36 changes: 35 additions & 1 deletion tests/test_e2e.py
Expand Up @@ -447,14 +447,48 @@ def post_handshake_callback(transport):

s_reader, s_writer = yield from self.inbound_queue.get()


s_recv = yield from asyncio.wait_for(
s_reader.readexactly(6),
timeout=0.1,
)

self.assertEqual(s_recv, b"foobar")

@blocking
@asyncio.coroutine
def test_close_during_handshake(self):
cancelled = None

@asyncio.coroutine
def post_handshake_callback(transport):
nonlocal cancelled
try:
yield from asyncio.sleep(0.5)
cancelled = False
except asyncio.CancelledError:
cancelled = True

c_transport, c_reader, c_writer = yield from self._connect(
host="127.0.0.1",
port=PORT,
ssl_context_factory=lambda transport: OpenSSL.SSL.Context(
OpenSSL.SSL.SSLv23_METHOD
),
server_hostname="localhost",
use_starttls=True,
post_handshake_callback=post_handshake_callback,
)

starttls_task = asyncio.ensure_future(c_transport.starttls())
# ensure that handshake is in progress...
yield from asyncio.sleep(0.2)
c_transport.close()

with self.assertRaises(ConnectionError):
yield from starttls_task

self.assertTrue(cancelled)


class ServerThread(threading.Thread):
def __init__(self, ctx, port, loop, queue):
Expand Down

0 comments on commit c280328

Please sign in to comment.