Skip to content

Commit

Permalink
Disconnect Engine.IO connection when server disconnects a client (mig…
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Aug 3, 2019
1 parent d23581e commit 516a295
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 7 deletions.
22 changes: 19 additions & 3 deletions socketio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ async def connect(self, url, headers={}, transports=None,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
self.connected = True

async def wait(self):
"""Wait until the connection with the server ends.
Expand Down Expand Up @@ -232,6 +233,7 @@ async def disconnect(self):
namespace=n))
await self._send_packet(packet.Packet(
packet.DISCONNECT, namespace='/'))
self.connected = False
await self.eio.disconnect(abort=True)

def start_background_task(self, target, *args, **kwargs):
Expand Down Expand Up @@ -286,10 +288,18 @@ async def _handle_connect(self, namespace):
self.namespaces.append(namespace)

async def _handle_disconnect(self, namespace):
if not self.connected:
return
namespace = namespace or '/'
if namespace == '/':
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
self.namespaces = []
await self._trigger_event('disconnect', namespace=namespace)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.connected = False

async def _handle_event(self, namespace, id, data):
namespace = namespace or '/'
Expand Down Expand Up @@ -335,6 +345,9 @@ def _handle_error(self, namespace):
namespace))
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.namespaces = []
self.connected = False

async def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
Expand Down Expand Up @@ -431,9 +444,12 @@ async def _handle_eio_disconnect(self):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
self._reconnect_abort.set()
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
await self._trigger_event('disconnect', namespace='/')
if self.connected:
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
await self._trigger_event('disconnect', namespace='/')
self.namespaces = []
self.connected = False
self.callbacks = {}
self._binary_packet = None
self.sid = None
Expand Down
2 changes: 2 additions & 0 deletions socketio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ async def disconnect(self, sid, namespace=None):
namespace=namespace))
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
await self.eio.disconnect(sid)

async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client.
Expand Down
24 changes: 20 additions & 4 deletions socketio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(self, reconnection=True, reconnection_attempts=0,
self.socketio_path = None
self.sid = None

self.connected = False
self.namespaces = []
self.handlers = {}
self.namespace_handlers = {}
Expand Down Expand Up @@ -261,6 +262,7 @@ def connect(self, url, headers={}, transports=None,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
self.connected = True

def wait(self):
"""Wait until the connection with the server ends.
Expand Down Expand Up @@ -377,6 +379,7 @@ def disconnect(self):
self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n))
self._send_packet(packet.Packet(
packet.DISCONNECT, namespace='/'))
self.connected = False
self.eio.disconnect(abort=True)

def transport(self):
Expand Down Expand Up @@ -445,10 +448,18 @@ def _handle_connect(self, namespace):
self.namespaces.append(namespace)

def _handle_disconnect(self, namespace):
if not self.connected:
return
namespace = namespace or '/'
if namespace == '/':
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self.namespaces = []
self._trigger_event('disconnect', namespace=namespace)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.connected = False

def _handle_event(self, namespace, id, data):
namespace = namespace or '/'
Expand Down Expand Up @@ -490,6 +501,9 @@ def _handle_error(self, namespace):
namespace))
if namespace in self.namespaces:
self.namespaces.remove(namespace)
if namespace == '/':
self.namespaces = []
self.connected = False

def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
Expand All @@ -516,7 +530,6 @@ def _handle_reconnect(self):
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
print('***', self._reconnect_abort.wait)
if self._reconnect_abort.wait(delay):
self.logger.info('Reconnect task aborted')
break
Expand Down Expand Up @@ -576,9 +589,12 @@ def _handle_eio_message(self, data):
def _handle_eio_disconnect(self):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self._trigger_event('disconnect', namespace='/')
if self.connected:
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self._trigger_event('disconnect', namespace='/')
self.namespaces = []
self.connected = False
self.callbacks = {}
self._binary_packet = None
self.sid = None
Expand Down
2 changes: 2 additions & 0 deletions socketio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ def disconnect(self, sid, namespace=None):
namespace=namespace))
self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
if namespace == '/':
self.eio.disconnect(sid)

def transport(self, sid):
"""Return the name of the transport used by the client.
Expand Down
52 changes: 52 additions & 0 deletions tests/asyncio/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,28 +411,51 @@ def test_handle_connect_namespace(self):

def test_handle_disconnect(self):
c = asyncio_client.AsyncClient()
c.connected = True
c._trigger_event = AsyncMock()
_run(c._handle_disconnect('/'))
c._trigger_event.mock.assert_called_once_with(
'disconnect', namespace='/')
self.assertFalse(c.connected)
_run(c._handle_disconnect('/'))
self.assertEqual(c._trigger_event.mock.call_count, 1)

def test_handle_disconnect_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = AsyncMock()
_run(c._handle_disconnect('/foo'))
c._trigger_event.mock.assert_called_once_with(
'disconnect', namespace='/foo')
self.assertEqual(c.namespaces, ['/bar'])
self.assertTrue(c.connected)

def test_handle_disconnect_unknown_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = AsyncMock()
_run(c._handle_disconnect('/baz'))
c._trigger_event.mock.assert_called_once_with(
'disconnect', namespace='/baz')
self.assertEqual(c.namespaces, ['/foo', '/bar'])
self.assertTrue(c.connected)

def test_handle_disconnect_all_namespaces(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = AsyncMock()
_run(c._handle_disconnect('/'))
c._trigger_event.mock.assert_any_call(
'disconnect', namespace='/')
c._trigger_event.mock.assert_any_call(
'disconnect', namespace='/foo')
c._trigger_event.mock.assert_any_call(
'disconnect', namespace='/bar')
self.assertEqual(c.namespaces, [])
self.assertFalse(c.connected)

def test_handle_event(self):
c = asyncio_client.AsyncClient()
Expand Down Expand Up @@ -519,15 +542,27 @@ def test_handle_ack_not_found(self):

def test_handle_error(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._handle_error('/')
self.assertEqual(c.namespaces, [])
self.assertFalse(c.connected)

def test_handle_error_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._handle_error('/bar')
self.assertEqual(c.namespaces, ['/foo'])
self.assertTrue(c.connected)

def test_handle_error_unknown_namespace(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._handle_error('/baz')
self.assertEqual(c.namespaces, ['/foo', '/bar'])
self.assertTrue(c.connected)

def test_trigger_event(self):
c = asyncio_client.AsyncClient()
Expand Down Expand Up @@ -556,6 +591,19 @@ def on_foo(self, a, b):
_run(c._trigger_event('foo', '/', 1, '2'))
self.assertEqual(result, [1, '2'])

def test_trigger_event_unknown_namespace(self):
c = asyncio_client.AsyncClient()
result = []

class MyNamespace(asyncio_namespace.AsyncClientNamespace):
def on_foo(self, a, b):
result.append(a)
result.append(b)

c.register_namespace(MyNamespace('/'))
_run(c._trigger_event('foo', '/bar', 1, '2'))
self.assertEqual(result, [])

@mock.patch('asyncio.wait_for', new_callable=AsyncMock,
side_effect=asyncio.TimeoutError)
@mock.patch('socketio.client.random.random', side_effect=[1, 0, 0.5])
Expand Down Expand Up @@ -663,16 +711,19 @@ def test_handle_eio_message(self):

def test_eio_disconnect(self):
c = asyncio_client.AsyncClient()
c.connected = True
c._trigger_event = AsyncMock()
c.sid = 'foo'
c.eio.state = 'connected'
_run(c._handle_eio_disconnect())
c._trigger_event.mock.assert_called_once_with(
'disconnect', namespace='/')
self.assertIsNone(c.sid)
self.assertFalse(c.connected)

def test_eio_disconnect_namespaces(self):
c = asyncio_client.AsyncClient()
c.connected = True
c.namespaces = ['/foo', '/bar']
c._trigger_event = AsyncMock()
c.sid = 'foo'
Expand All @@ -682,6 +733,7 @@ def test_eio_disconnect_namespaces(self):
c._trigger_event.mock.assert_any_call('disconnect', namespace='/bar')
c._trigger_event.mock.assert_any_call('disconnect', namespace='/')
self.assertIsNone(c.sid)
self.assertFalse(c.connected)

def test_eio_disconnect_reconnect(self):
c = asyncio_client.AsyncClient(reconnection=True)
Expand Down
6 changes: 6 additions & 0 deletions tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,27 +596,33 @@ async def _test():

def test_disconnect(self, eio):
eio.return_value.send = AsyncMock()
eio.return_value.disconnect = AsyncMock()
s = asyncio_server.AsyncServer()
_run(s._handle_eio_connect('123', 'environ'))
_run(s.disconnect('123'))
s.eio.send.mock.assert_any_call('123', '1', binary=False)
s.eio.disconnect.mock.assert_called_once_with('123')

def test_disconnect_namespace(self, eio):
eio.return_value.send = AsyncMock()
eio.return_value.disconnect = AsyncMock()
s = asyncio_server.AsyncServer()
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0/foo'))
_run(s.disconnect('123', namespace='/foo'))
s.eio.send.mock.assert_any_call('123', '1/foo', binary=False)
s.eio.disconnect.mock.assert_not_called()

def test_disconnect_twice(self, eio):
eio.return_value.send = AsyncMock()
eio.return_value.disconnect = AsyncMock()
s = asyncio_server.AsyncServer()
_run(s._handle_eio_connect('123', 'environ'))
_run(s.disconnect('123'))
calls = s.eio.send.mock.call_count
_run(s.disconnect('123'))
self.assertEqual(calls, s.eio.send.mock.call_count)
self.assertEqual(s.eio.disconnect.mock.call_count, 1)

def test_disconnect_twice_namespace(self, eio):
eio.return_value.send = AsyncMock()
Expand Down
Loading

0 comments on commit 516a295

Please sign in to comment.