Skip to content

Commit

Permalink
New shutdown() method added to the client (Fixes #1333)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed May 19, 2024
1 parent 82ceaf7 commit 811e044
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 2 deletions.
24 changes: 22 additions & 2 deletions src/socketio/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,21 @@ async def disconnect(self):
namespace=n))
await self.eio.disconnect(abort=True)

async def shutdown(self):
"""Stop the client.
If the client is connected to a server, it is disconnected. If the
client is attempting to reconnect to server, the reconnection attempts
are stopped. If the client is not connected to a server and is not
attempting to reconnect, then this function does nothing.
"""
if self.connected:
await self.disconnect()
elif self._reconnect_task: # pragma: no branch
self._reconnect_abort.set()
print(self._reconnect_task)
await self._reconnect_task

def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
Expand Down Expand Up @@ -467,15 +482,20 @@ async def _handle_reconnect(self):
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
abort = False
try:
await asyncio.wait_for(self._reconnect_abort.wait(), delay)
abort = True
except asyncio.TimeoutError:
pass
except asyncio.CancelledError: # pragma: no cover
abort = True
if abort:
self.logger.info('Reconnect task aborted')
for n in self.connection_namespaces:
await self._trigger_event('__disconnect_final',
namespace=n)
break
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
attempt_count += 1
try:
await self.connect(self.connection_url,
Expand Down
14 changes: 14 additions & 0 deletions src/socketio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,20 @@ def disconnect(self):
packet.DISCONNECT, namespace=n))
self.eio.disconnect(abort=True)

def shutdown(self):
"""Stop the client.
If the client is connected to a server, it is disconnected. If the
client is attempting to reconnect to server, the reconnection attempts
are stopped. If the client is not connected to a server and is not
attempting to reconnect, then this function does nothing.
"""
if self.connected:
self.disconnect()
elif self._reconnect_task: # pragma: no branch
self._reconnect_abort.set()
self._reconnect_task.join()

def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
Expand Down
60 changes: 60 additions & 0 deletions tests/async/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,66 @@ def test_handle_reconnect_aborted(self, random, wait_for):
c._trigger_event.mock.assert_called_once_with('__disconnect_final',
namespace='/')

def test_shutdown_disconnect(self):
c = async_client.AsyncClient()
c.connected = True
c.namespaces = {'/': '1'}
c._trigger_event = AsyncMock()
c._send_packet = AsyncMock()
c.eio = mock.MagicMock()
c.eio.disconnect = AsyncMock()
c.eio.state = 'connected'
_run(c.shutdown())
assert c._trigger_event.mock.call_count == 0
assert c._send_packet.mock.call_count == 1
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/')
assert (
c._send_packet.mock.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
c.eio.disconnect.mock.assert_called_once_with(abort=True)

def test_shutdown_disconnect_namespaces(self):
c = async_client.AsyncClient()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._trigger_event = AsyncMock()
c._send_packet = AsyncMock()
c.eio = mock.MagicMock()
c.eio.disconnect = AsyncMock()
c.eio.state = 'connected'
_run(c.shutdown())
assert c._trigger_event.mock.call_count == 0
assert c._send_packet.mock.call_count == 2
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/foo')
assert (
c._send_packet.mock.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/bar')
assert (
c._send_packet.mock.call_args_list[1][0][0].encode()
== expected_packet.encode()
)

@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_shutdown_reconnect(self, random):
c = async_client.AsyncClient()
c.connection_namespaces = ['/']
c._reconnect_task = AsyncMock()()
c._trigger_event = AsyncMock()
c.connect = AsyncMock(side_effect=exceptions.ConnectionError)

async def r():
task = c.start_background_task(c._handle_reconnect)
await asyncio.sleep(0.1)
await c.shutdown()
await task

_run(r())
c._trigger_event.mock.assert_called_once_with('__disconnect_final',
namespace='/')

def test_handle_eio_connect(self):
c = async_client.AsyncClient()
c.connection_namespaces = ['/', '/foo']
Expand Down
57 changes: 57 additions & 0 deletions tests/common/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import time
import unittest
from unittest import mock

Expand Down Expand Up @@ -636,6 +637,7 @@ def test_disconnect(self):

def test_disconnect_namespaces(self):
c = client.Client()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._trigger_event = mock.MagicMock()
c._send_packet = mock.MagicMock()
Expand Down Expand Up @@ -1128,6 +1130,61 @@ def test_handle_reconnect_aborted(self, random):
c._trigger_event.assert_called_once_with('__disconnect_final',
namespace='/')

def test_shutdown_disconnect(self):
c = client.Client()
c.connected = True
c.namespaces = {'/': '1'}
c._trigger_event = mock.MagicMock()
c._send_packet = mock.MagicMock()
c.eio = mock.MagicMock()
c.eio.state = 'connected'
c.shutdown()
assert c._trigger_event.call_count == 0
assert c._send_packet.call_count == 1
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/')
assert (
c._send_packet.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
c.eio.disconnect.assert_called_once_with(abort=True)

def test_shutdown_disconnect_namespaces(self):
c = client.Client()
c.connected = True
c.namespaces = {'/foo': '1', '/bar': '2'}
c._trigger_event = mock.MagicMock()
c._send_packet = mock.MagicMock()
c.eio = mock.MagicMock()
c.eio.state = 'connected'
c.shutdown()
assert c._trigger_event.call_count == 0
assert c._send_packet.call_count == 2
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/foo')
assert (
c._send_packet.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
expected_packet = packet.Packet(packet.DISCONNECT, namespace='/bar')
assert (
c._send_packet.call_args_list[1][0][0].encode()
== expected_packet.encode()
)

@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
def test_shutdown_reconnect(self, random):
c = client.Client()
c.connection_namespaces = ['/']
c._reconnect_task = mock.MagicMock()
c._trigger_event = mock.MagicMock()
c.connect = mock.MagicMock(side_effect=exceptions.ConnectionError)
task = c.start_background_task(c._handle_reconnect)
time.sleep(0.1)
c.shutdown()
task.join()
c._trigger_event.assert_called_once_with('__disconnect_final',
namespace='/')
assert c._reconnect_task.join.called_once_with()

def test_handle_eio_connect(self):
c = client.Client()
c.connection_namespaces = ['/', '/foo']
Expand Down

0 comments on commit 811e044

Please sign in to comment.