Skip to content

Commit

Permalink
Fixed remote async disconnects via message queue (Fixes #1003)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Aug 15, 2022
1 parent f56ef6f commit 104d656
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 20 deletions.
7 changes: 7 additions & 0 deletions src/socketio/asyncio_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ async def emit(self, event, data, namespace, room=None, skip_sid=None,
return
await asyncio.wait(tasks)

async def disconnect(self, sid, namespace, **kwargs):
"""Disconnect a client.
Note: this method is a coroutine.
"""
return super().disconnect(sid, namespace, **kwargs)

async def close_room(self, room, namespace):
"""Remove all participants from a room.
Expand Down
9 changes: 8 additions & 1 deletion src/socketio/asyncio_pubsub_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,14 @@ async def can_disconnect(self, sid, namespace):
else:
# client is in another server, so we post request to the queue
await self._publish({'method': 'disconnect', 'sid': sid,
'namespace': namespace or '/'})
'namespace': namespace or '/'})

async def disconnect(self, sid, namespace, **kwargs):
if kwargs.get('ignore_queue'):
return await super(AsyncPubSubManager, self).disconnect(
sid, namespace=namespace)
await self._publish({'method': 'disconnect', 'sid': sid,
'namespace': namespace or '/'})

async def close_room(self, room, namespace=None):
await self._publish({'method': 'close_room', 'room': room,
Expand Down
7 changes: 4 additions & 3 deletions src/socketio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ async def disconnect(self, sid, namespace=None, ignore_queue=False):
await self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, namespace=namespace))
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
await self.manager.disconnect(sid, namespace=namespace,
ignore_queue=True)

async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client.
Expand Down Expand Up @@ -486,7 +487,7 @@ async def _handle_connect(self, eio_sid, namespace, data):
await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace))
self.manager.disconnect(sid, namespace)
await self.manager.disconnect(sid, namespace, ignore_queue=True)
elif not self.always_connect:
await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace))
Expand All @@ -499,7 +500,7 @@ async def _handle_disconnect(self, eio_sid, namespace):
return
self.manager.pre_disconnect(sid, namespace=namespace)
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace)
await self.manager.disconnect(sid, namespace, ignore_queue=True)

async def _handle_event(self, eio_sid, namespace, id, data):
"""Handle an incoming client event."""
Expand Down
1 change: 1 addition & 0 deletions src/socketio/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def is_connected(self, sid, namespace):
return self.rooms[namespace][None][sid] is not None
except KeyError:
pass
return False

def sid_from_eio_sid(self, eio_sid, namespace):
try:
Expand Down
28 changes: 14 additions & 14 deletions tests/asyncio/test_asyncio_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,17 @@ def test_pre_disconnect(self):
assert self.bm.pre_disconnect(sid2, '/foo') == '456'
assert self.bm.pending_disconnect == {'/foo': [sid1, sid2]}
assert not self.bm.is_connected(sid2, '/foo')
self.bm.disconnect(sid1, '/foo')
_run(self.bm.disconnect(sid1, '/foo'))
assert self.bm.pending_disconnect == {'/foo': [sid2]}
self.bm.disconnect(sid2, '/foo')
_run(self.bm.disconnect(sid2, '/foo'))
assert self.bm.pending_disconnect == {}

def test_disconnect(self):
sid1 = self.bm.connect('123', '/foo')
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room(sid2, '/foo', 'baz')
self.bm.disconnect(sid1, '/foo')
_run(self.bm.disconnect(sid1, '/foo'))
assert dict(self.bm.rooms['/foo'][None]) == {sid2: '456'}
assert dict(self.bm.rooms['/foo'][sid2]) == {sid2: '456'}
assert dict(self.bm.rooms['/foo']['baz']) == {sid2: '456'}
Expand All @@ -83,10 +83,10 @@ def test_disconnect_default_namespace(self):
assert self.bm.is_connected(sid2, '/foo')
assert not self.bm.is_connected(sid2, '/')
assert not self.bm.is_connected(sid1, '/foo')
self.bm.disconnect(sid1, '/')
_run(self.bm.disconnect(sid1, '/'))
assert not self.bm.is_connected(sid1, '/')
assert self.bm.is_connected(sid2, '/foo')
self.bm.disconnect(sid2, '/foo')
_run(self.bm.disconnect(sid2, '/foo'))
assert not self.bm.is_connected(sid2, '/foo')
assert dict(self.bm.rooms['/'][None]) == {sid3: '456'}
assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'}
Expand All @@ -98,10 +98,10 @@ def test_disconnect_twice(self):
sid2 = self.bm.connect('123', '/foo')
sid3 = self.bm.connect('456', '/')
sid4 = self.bm.connect('456', '/foo')
self.bm.disconnect(sid1, '/')
self.bm.disconnect(sid2, '/foo')
self.bm.disconnect(sid1, '/')
self.bm.disconnect(sid2, '/foo')
_run(self.bm.disconnect(sid1, '/'))
_run(self.bm.disconnect(sid2, '/foo'))
_run(self.bm.disconnect(sid1, '/'))
_run(self.bm.disconnect(sid2, '/foo'))
assert dict(self.bm.rooms['/'][None]) == {sid3: '456'}
assert dict(self.bm.rooms['/'][sid3]) == {sid3: '456'}
assert dict(self.bm.rooms['/foo'][None]) == {sid4: '456'}
Expand All @@ -112,8 +112,8 @@ def test_disconnect_all(self):
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room(sid2, '/foo', 'baz')
self.bm.disconnect(sid1, '/foo')
self.bm.disconnect(sid2, '/foo')
_run(self.bm.disconnect(sid1, '/foo'))
_run(self.bm.disconnect(sid2, '/foo'))
assert self.bm.rooms == {}

def test_disconnect_with_callbacks(self):
Expand All @@ -123,9 +123,9 @@ def test_disconnect_with_callbacks(self):
self.bm._generate_ack_id(sid1, 'f')
self.bm._generate_ack_id(sid2, 'g')
self.bm._generate_ack_id(sid3, 'h')
self.bm.disconnect(sid2, '/foo')
_run(self.bm.disconnect(sid2, '/foo'))
assert sid2 not in self.bm.callbacks
self.bm.disconnect(sid1, '/')
_run(self.bm.disconnect(sid1, '/'))
assert sid1 not in self.bm.callbacks
assert sid3 in self.bm.callbacks

Expand Down Expand Up @@ -176,7 +176,7 @@ def test_get_participants(self):
sid1 = self.bm.connect('123', '/')
sid2 = self.bm.connect('456', '/')
sid3 = self.bm.connect('789', '/')
self.bm.disconnect(sid3, '/')
_run(self.bm.disconnect(sid3, '/'))
assert sid3 not in self.bm.rooms['/'][None]
participants = list(self.bm.get_participants('/', None))
assert len(participants) == 2
Expand Down
13 changes: 13 additions & 0 deletions tests/asyncio/test_asyncio_pubsub_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,19 @@ def test_can_disconnect(self):
{'method': 'disconnect', 'sid': sid, 'namespace': '/foo'}
)

def test_disconnect(self):
_run(self.pm.disconnect('foo', '/'))
self.pm._publish.mock.assert_called_once_with(
{'method': 'disconnect', 'sid': 'foo', 'namespace': '/'}
)

def test_disconnect_ignore_queue(self):
sid = self.pm.connect('123', '/')
self.pm.pre_disconnect(sid, '/')
_run(self.pm.disconnect(sid, '/', ignore_queue=True))
self.pm._publish.mock.assert_not_called()
assert self.pm.is_connected(sid, '/') is False

def test_close_room(self):
_run(self.pm.close_room('foo'))
self.pm._publish.mock.assert_called_once_with(
Expand Down
5 changes: 3 additions & 2 deletions tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,14 +597,15 @@ def test_handle_connect_namespace_rejected_with_empty_exception(self, eio):
def test_handle_disconnect(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()
s.manager.disconnect = mock.MagicMock()
s.manager.disconnect = AsyncMock()
handler = mock.MagicMock()
s.on('disconnect', handler)
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0'))
_run(s._handle_eio_disconnect('123'))
handler.assert_called_once_with('1')
s.manager.disconnect.assert_called_once_with('1', '/')
s.manager.disconnect.mock.assert_called_once_with(
'1', '/', ignore_queue=True)
assert s.environ == {}

def test_handle_disconnect_namespace(self, eio):
Expand Down

0 comments on commit 104d656

Please sign in to comment.