Skip to content

Commit

Permalink
handle keyboard interrupt during reconnect (Fixes #301)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jun 29, 2019
1 parent 8a4e5ff commit fa53e38
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 34 deletions.
11 changes: 10 additions & 1 deletion socketio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ async def _trigger_event(self, event, namespace, *args):
event, *args)

async def _handle_reconnect(self):
self._reconnect_abort.clear()
client.reconnecting_clients.append(self)
attempt_count = 0
current_delay = self.reconnection_delay
while True:
Expand All @@ -366,7 +368,12 @@ async def _handle_reconnect(self):
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
await self.sleep(delay)
try:
await asyncio.wait_for(self._reconnect_abort.wait(), delay)
self.logger.info('Reconnect task aborted')
break
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
attempt_count += 1
try:
await self.connect(self.connection_url,
Expand All @@ -385,6 +392,7 @@ async def _handle_reconnect(self):
self.logger.info(
'Maximum reconnection attempts reached, giving up')
break
client.reconnecting_clients.remove(self)

def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""
Expand Down Expand Up @@ -422,6 +430,7 @@ async def _handle_eio_message(self, data):
async def _handle_eio_disconnect(self):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
self._reconnect_abort.set()
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
await self._trigger_event('disconnect', namespace='/')
Expand Down
25 changes: 24 additions & 1 deletion socketio/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import logging
import random
import signal

import engineio
import six
Expand All @@ -10,6 +11,21 @@
from . import packet

default_logger = logging.getLogger('socketio.client')
reconnecting_clients = []


def signal_handler(sig, frame): # pragma: no cover
"""SIGINT handler.
Notify any clients that are in a reconnect loop to abort. Other
disconnection tasks are handled at the engine.io level.
"""
for client in reconnecting_clients[:]:
client._reconnect_abort.set()
return original_signal_handler(sig, frame)


original_signal_handler = signal.signal(signal.SIGINT, signal_handler)


class Client(object):
Expand Down Expand Up @@ -102,6 +118,7 @@ def __init__(self, reconnection=True, reconnection_attempts=0,
self.callbacks = {}
self._binary_packet = None
self._reconnect_task = None
self._reconnect_abort = self.eio.create_event()

def is_asyncio_based(self):
return False
Expand Down Expand Up @@ -486,6 +503,8 @@ def _trigger_event(self, event, namespace, *args):
event, *args)

def _handle_reconnect(self):
self._reconnect_abort.clear()
reconnecting_clients.append(self)
attempt_count = 0
current_delay = self.reconnection_delay
while True:
Expand All @@ -497,7 +516,10 @@ def _handle_reconnect(self):
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
self.sleep(delay)
print('***', self._reconnect_abort.wait)
if self._reconnect_abort.wait(delay):
self.logger.info('Reconnect task aborted')
break
attempt_count += 1
try:
self.connect(self.connection_url,
Expand All @@ -516,6 +538,7 @@ def _handle_reconnect(self):
self.logger.info(
'Maximum reconnection attempts reached, giving up')
break
reconnecting_clients.remove(self)

def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""
Expand Down
73 changes: 50 additions & 23 deletions tests/asyncio/test_asyncio_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from contextlib import contextmanager
import sys
import unittest

Expand Down Expand Up @@ -26,6 +27,19 @@ async def mock_coro(*args, **kwargs):
return mock_coro


@contextmanager
def mock_wait_for():
async def fake_wait_for(coro, timeout):
await coro
await fake_wait_for._mock(timeout)

original_wait_for = asyncio.wait_for
asyncio.wait_for = fake_wait_for
fake_wait_for._mock = AsyncMock()
yield
asyncio.wait_for = original_wait_for


def _run(coro):
"""Run the given coroutine."""
return asyncio.get_event_loop().run_until_complete(coro)
Expand Down Expand Up @@ -542,51 +556,64 @@ def on_foo(self, a, b):
_run(c._trigger_event('foo', '/', 1, '2'))
self.assertEqual(result, [1, '2'])

@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=asyncio.TimeoutError)
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect(self, random):
def test_handle_reconnect(self, random, wait_for):
c = asyncio_client.AsyncClient()
c._reconnect_task = 'foo'
c.sleep = AsyncMock()
c.connect = AsyncMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
_run(c._handle_reconnect())
self.assertEqual(c.sleep.mock.call_count, 3)
self.assertEqual(c.sleep.mock.call_args_list, [
mock.call(1.5),
mock.call(1.5),
mock.call(4.0)
])
self.assertEqual(wait_for.mock.call_count, 3)
self.assertEqual(
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
[1.5, 1.5, 4.0])
self.assertEqual(c._reconnect_task, None)

@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=asyncio.TimeoutError)
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect_max_delay(self, random):
def test_handle_reconnect_max_delay(self, random, wait_for):
c = asyncio_client.AsyncClient(reconnection_delay_max=3)
c._reconnect_task = 'foo'
c.sleep = AsyncMock()
c.connect = AsyncMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
_run(c._handle_reconnect())
self.assertEqual(c.sleep.mock.call_count, 3)
self.assertEqual(c.sleep.mock.call_args_list, [
mock.call(1.5),
mock.call(1.5),
mock.call(3.0)
])
self.assertEqual(wait_for.mock.call_count, 3)
self.assertEqual(
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
[1.5, 1.5, 3.0])
self.assertEqual(c._reconnect_task, None)

@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=asyncio.TimeoutError)
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect_max_attempts(self, random):
def test_handle_reconnect_max_attempts(self, random, wait_for):
c = asyncio_client.AsyncClient(reconnection_attempts=2)
c._reconnect_task = 'foo'
c.sleep = AsyncMock()
c.connect = AsyncMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
_run(c._handle_reconnect())
self.assertEqual(c.sleep.mock.call_count, 2)
self.assertEqual(c.sleep.mock.call_args_list, [
mock.call(1.5),
mock.call(1.5)
])
self.assertEqual(wait_for.mock.call_count, 2)
self.assertEqual(
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
[1.5, 1.5])
self.assertEqual(c._reconnect_task, 'foo')

@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=[asyncio.TimeoutError, None])
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect_aborted(self, random, wait_for):
c = asyncio_client.AsyncClient()
c._reconnect_task = 'foo'
c.connect = AsyncMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
_run(c._handle_reconnect())
self.assertEqual(wait_for.mock.call_count, 2)
self.assertEqual(
[x[0][1] for x in asyncio.wait_for.mock.call_args_list],
[1.5, 1.5])
self.assertEqual(c._reconnect_task, 'foo')

def test_eio_connect(self):
Expand Down
32 changes: 23 additions & 9 deletions tests/common/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,12 @@ def on_foo(self, a, b):
def test_handle_reconnect(self, random):
c = client.Client()
c._reconnect_task = 'foo'
c.sleep = mock.MagicMock()
c._reconnect_abort.wait = mock.MagicMock(return_value=False)
c.connect = mock.MagicMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
c._handle_reconnect()
self.assertEqual(c.sleep.call_count, 3)
self.assertEqual(c.sleep.call_args_list, [
self.assertEqual(c._reconnect_abort.wait.call_count, 3)
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
mock.call(1.5),
mock.call(1.5),
mock.call(4.0)
Expand All @@ -687,12 +687,12 @@ def test_handle_reconnect(self, random):
def test_handle_reconnect_max_delay(self, random):
c = client.Client(reconnection_delay_max=3)
c._reconnect_task = 'foo'
c.sleep = mock.MagicMock()
c._reconnect_abort.wait = mock.MagicMock(return_value=False)
c.connect = mock.MagicMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
c._handle_reconnect()
self.assertEqual(c.sleep.call_count, 3)
self.assertEqual(c.sleep.call_args_list, [
self.assertEqual(c._reconnect_abort.wait.call_count, 3)
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
mock.call(1.5),
mock.call(1.5),
mock.call(3.0)
Expand All @@ -703,12 +703,26 @@ def test_handle_reconnect_max_delay(self, random):
def test_handle_reconnect_max_attempts(self, random):
c = client.Client(reconnection_attempts=2)
c._reconnect_task = 'foo'
c.sleep = mock.MagicMock()
c._reconnect_abort.wait = mock.MagicMock(return_value=False)
c.connect = mock.MagicMock(
side_effect=[ValueError, exceptions.ConnectionError, None])
c._handle_reconnect()
self.assertEqual(c.sleep.call_count, 2)
self.assertEqual(c.sleep.call_args_list, [
self.assertEqual(c._reconnect_abort.wait.call_count, 2)
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
mock.call(1.5),
mock.call(1.5)
])
self.assertEqual(c._reconnect_task, 'foo')

@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_handle_reconnect_aborted(self, random):
c = client.Client()
c._reconnect_task = 'foo'
c._reconnect_abort.wait = mock.MagicMock(side_effect=[False, True])
c.connect = mock.MagicMock(side_effect=exceptions.ConnectionError)
c._handle_reconnect()
self.assertEqual(c._reconnect_abort.wait.call_count, 2)
self.assertEqual(c._reconnect_abort.wait.call_args_list, [
mock.call(1.5),
mock.call(1.5)
])
Expand Down

0 comments on commit fa53e38

Please sign in to comment.