diff --git a/socketio/asyncio_client.py b/socketio/asyncio_client.py index 998846d3..d71ccecf 100644 --- a/socketio/asyncio_client.py +++ b/socketio/asyncio_client.py @@ -355,6 +355,8 @@ async def _trigger_event(self, event, namespace, *args): event, *args) async def _handle_reconnect(self): + self._reconnect_abort.clear() + client.reconnecting_clients.append(self) attempt_count = 0 current_delay = self.reconnection_delay while True: @@ -366,7 +368,12 @@ async def _handle_reconnect(self): self.logger.info( 'Connection failed, new attempt in {:.02f} seconds'.format( delay)) - await self.sleep(delay) + try: + await asyncio.wait_for(self._reconnect_abort.wait(), delay) + self.logger.info('Reconnect task aborted') + break + except (asyncio.TimeoutError, asyncio.CancelledError): + pass attempt_count += 1 try: await self.connect(self.connection_url, @@ -385,6 +392,7 @@ async def _handle_reconnect(self): self.logger.info( 'Maximum reconnection attempts reached, giving up') break + client.reconnecting_clients.remove(self) def _handle_eio_connect(self): """Handle the Engine.IO connection event.""" @@ -422,6 +430,7 @@ async def _handle_eio_message(self, data): async def _handle_eio_disconnect(self): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') + self._reconnect_abort.set() for n in self.namespaces: await self._trigger_event('disconnect', namespace=n) await self._trigger_event('disconnect', namespace='/') diff --git a/socketio/client.py b/socketio/client.py index a3766406..c0bdc1ab 100644 --- a/socketio/client.py +++ b/socketio/client.py @@ -1,6 +1,7 @@ import itertools import logging import random +import signal import engineio import six @@ -10,6 +11,21 @@ from . import packet default_logger = logging.getLogger('socketio.client') +reconnecting_clients = [] + + +def signal_handler(sig, frame): # pragma: no cover + """SIGINT handler. + + Notify any clients that are in a reconnect loop to abort. Other + disconnection tasks are handled at the engine.io level. + """ + for client in reconnecting_clients[:]: + client._reconnect_abort.set() + return original_signal_handler(sig, frame) + + +original_signal_handler = signal.signal(signal.SIGINT, signal_handler) class Client(object): @@ -102,6 +118,7 @@ def __init__(self, reconnection=True, reconnection_attempts=0, self.callbacks = {} self._binary_packet = None self._reconnect_task = None + self._reconnect_abort = self.eio.create_event() def is_asyncio_based(self): return False @@ -486,6 +503,8 @@ def _trigger_event(self, event, namespace, *args): event, *args) def _handle_reconnect(self): + self._reconnect_abort.clear() + reconnecting_clients.append(self) attempt_count = 0 current_delay = self.reconnection_delay while True: @@ -497,7 +516,10 @@ def _handle_reconnect(self): self.logger.info( 'Connection failed, new attempt in {:.02f} seconds'.format( delay)) - self.sleep(delay) + print('***', self._reconnect_abort.wait) + if self._reconnect_abort.wait(delay): + self.logger.info('Reconnect task aborted') + break attempt_count += 1 try: self.connect(self.connection_url, @@ -516,6 +538,7 @@ def _handle_reconnect(self): self.logger.info( 'Maximum reconnection attempts reached, giving up') break + reconnecting_clients.remove(self) def _handle_eio_connect(self): """Handle the Engine.IO connection event.""" diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py index 4ba4b40b..27a3298a 100644 --- a/tests/asyncio/test_asyncio_client.py +++ b/tests/asyncio/test_asyncio_client.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import contextmanager import sys import unittest @@ -26,6 +27,19 @@ async def mock_coro(*args, **kwargs): return mock_coro +@contextmanager +def mock_wait_for(): + async def fake_wait_for(coro, timeout): + await coro + await fake_wait_for._mock(timeout) + + original_wait_for = asyncio.wait_for + asyncio.wait_for = fake_wait_for + fake_wait_for._mock = AsyncMock() + yield + asyncio.wait_for = original_wait_for + + def _run(coro): """Run the given coroutine.""" return asyncio.get_event_loop().run_until_complete(coro) @@ -542,51 +556,64 @@ def on_foo(self, a, b): _run(c._trigger_event('foo', '/', 1, '2')) self.assertEqual(result, [1, '2']) + @mock.patch('asyncio.wait_for', new_callable=AsyncMock, + side_effect=asyncio.TimeoutError) @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) - def test_handle_reconnect(self, random): + def test_handle_reconnect(self, random, wait_for): c = asyncio_client.AsyncClient() c._reconnect_task = 'foo' - c.sleep = AsyncMock() c.connect = AsyncMock( side_effect=[ValueError, exceptions.ConnectionError, None]) _run(c._handle_reconnect()) - self.assertEqual(c.sleep.mock.call_count, 3) - self.assertEqual(c.sleep.mock.call_args_list, [ - mock.call(1.5), - mock.call(1.5), - mock.call(4.0) - ]) + self.assertEqual(wait_for.mock.call_count, 3) + self.assertEqual( + [x[0][1] for x in asyncio.wait_for.mock.call_args_list], + [1.5, 1.5, 4.0]) self.assertEqual(c._reconnect_task, None) + @mock.patch('asyncio.wait_for', new_callable=AsyncMock, + side_effect=asyncio.TimeoutError) @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) - def test_handle_reconnect_max_delay(self, random): + def test_handle_reconnect_max_delay(self, random, wait_for): c = asyncio_client.AsyncClient(reconnection_delay_max=3) c._reconnect_task = 'foo' - c.sleep = AsyncMock() c.connect = AsyncMock( side_effect=[ValueError, exceptions.ConnectionError, None]) _run(c._handle_reconnect()) - self.assertEqual(c.sleep.mock.call_count, 3) - self.assertEqual(c.sleep.mock.call_args_list, [ - mock.call(1.5), - mock.call(1.5), - mock.call(3.0) - ]) + self.assertEqual(wait_for.mock.call_count, 3) + self.assertEqual( + [x[0][1] for x in asyncio.wait_for.mock.call_args_list], + [1.5, 1.5, 3.0]) self.assertEqual(c._reconnect_task, None) + @mock.patch('asyncio.wait_for', new_callable=AsyncMock, + side_effect=asyncio.TimeoutError) @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) - def test_handle_reconnect_max_attempts(self, random): + def test_handle_reconnect_max_attempts(self, random, wait_for): c = asyncio_client.AsyncClient(reconnection_attempts=2) c._reconnect_task = 'foo' - c.sleep = AsyncMock() c.connect = AsyncMock( side_effect=[ValueError, exceptions.ConnectionError, None]) _run(c._handle_reconnect()) - self.assertEqual(c.sleep.mock.call_count, 2) - self.assertEqual(c.sleep.mock.call_args_list, [ - mock.call(1.5), - mock.call(1.5) - ]) + self.assertEqual(wait_for.mock.call_count, 2) + self.assertEqual( + [x[0][1] for x in asyncio.wait_for.mock.call_args_list], + [1.5, 1.5]) + self.assertEqual(c._reconnect_task, 'foo') + + @mock.patch('asyncio.wait_for', new_callable=AsyncMock, + side_effect=[asyncio.TimeoutError, None]) + @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) + def test_handle_reconnect_aborted(self, random, wait_for): + c = asyncio_client.AsyncClient() + c._reconnect_task = 'foo' + c.connect = AsyncMock( + side_effect=[ValueError, exceptions.ConnectionError, None]) + _run(c._handle_reconnect()) + self.assertEqual(wait_for.mock.call_count, 2) + self.assertEqual( + [x[0][1] for x in asyncio.wait_for.mock.call_args_list], + [1.5, 1.5]) self.assertEqual(c._reconnect_task, 'foo') def test_eio_connect(self): diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 6fef6c31..da21f442 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -671,12 +671,12 @@ def on_foo(self, a, b): def test_handle_reconnect(self, random): c = client.Client() c._reconnect_task = 'foo' - c.sleep = mock.MagicMock() + c._reconnect_abort.wait = mock.MagicMock(return_value=False) c.connect = mock.MagicMock( side_effect=[ValueError, exceptions.ConnectionError, None]) c._handle_reconnect() - self.assertEqual(c.sleep.call_count, 3) - self.assertEqual(c.sleep.call_args_list, [ + self.assertEqual(c._reconnect_abort.wait.call_count, 3) + self.assertEqual(c._reconnect_abort.wait.call_args_list, [ mock.call(1.5), mock.call(1.5), mock.call(4.0) @@ -687,12 +687,12 @@ def test_handle_reconnect(self, random): def test_handle_reconnect_max_delay(self, random): c = client.Client(reconnection_delay_max=3) c._reconnect_task = 'foo' - c.sleep = mock.MagicMock() + c._reconnect_abort.wait = mock.MagicMock(return_value=False) c.connect = mock.MagicMock( side_effect=[ValueError, exceptions.ConnectionError, None]) c._handle_reconnect() - self.assertEqual(c.sleep.call_count, 3) - self.assertEqual(c.sleep.call_args_list, [ + self.assertEqual(c._reconnect_abort.wait.call_count, 3) + self.assertEqual(c._reconnect_abort.wait.call_args_list, [ mock.call(1.5), mock.call(1.5), mock.call(3.0) @@ -703,12 +703,26 @@ def test_handle_reconnect_max_delay(self, random): def test_handle_reconnect_max_attempts(self, random): c = client.Client(reconnection_attempts=2) c._reconnect_task = 'foo' - c.sleep = mock.MagicMock() + c._reconnect_abort.wait = mock.MagicMock(return_value=False) c.connect = mock.MagicMock( side_effect=[ValueError, exceptions.ConnectionError, None]) c._handle_reconnect() - self.assertEqual(c.sleep.call_count, 2) - self.assertEqual(c.sleep.call_args_list, [ + self.assertEqual(c._reconnect_abort.wait.call_count, 2) + self.assertEqual(c._reconnect_abort.wait.call_args_list, [ + mock.call(1.5), + mock.call(1.5) + ]) + self.assertEqual(c._reconnect_task, 'foo') + + @mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5]) + def test_handle_reconnect_aborted(self, random): + c = client.Client() + c._reconnect_task = 'foo' + c._reconnect_abort.wait = mock.MagicMock(side_effect=[False, True]) + c.connect = mock.MagicMock(side_effect=exceptions.ConnectionError) + c._handle_reconnect() + self.assertEqual(c._reconnect_abort.wait.call_count, 2) + self.assertEqual(c._reconnect_abort.wait.call_args_list, [ mock.call(1.5), mock.call(1.5) ])