Skip to content

Commit

Permalink
Move ack functionality into BaseManager class
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Sep 14, 2015
1 parent ebea5aa commit ad12b83
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 61 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
],
tests_require=[
'mock',
'pbr<1.7.0', # temporary, to workaround bug in 1.7.0
],
test_suite='tests',
classifiers=[
Expand Down
34 changes: 32 additions & 2 deletions socketio/base_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools

import six


Expand All @@ -14,6 +16,7 @@ def __init__(self, server):
self.server = server
self.rooms = {}
self.pending_removals = []
self.callbacks = {}

def get_namespaces(self):
"""Return an iterable with the active namespace names."""
Expand Down Expand Up @@ -43,6 +46,10 @@ def disconnect(self, sid, namespace):
rooms.append(room_name)
for room in rooms:
self.leave_room(sid, namespace, room)
if sid in self.callbacks and namespace in self.callbacks[sid]:
del self.callbacks[sid][namespace]
if len(self.callbacks[sid]) == 0:
del self.callbacks[sid]

def enter_room(self, sid, namespace, room):
"""Add a client to a room."""
Expand Down Expand Up @@ -86,8 +93,31 @@ def emit(self, event, data, namespace, room=None, skip_sid=None,
return
for sid in self.get_participants(namespace, room):
if sid != skip_sid:
self.server._emit_internal(sid, event, data, namespace,
callback)
if callback is not None:
id = self.server._generate_ack_id(sid, namespace, callback)
else:
id = None
self.server._emit_internal(sid, event, data, namespace, id)

def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
raise ValueError('Unknown callback')
del self.callbacks[sid][namespace][id]
callback(*data)

def _generate_ack_id(self, sid, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if sid not in self.callbacks:
self.callbacks[sid] = {}
if namespace not in self.callbacks[sid]:
self.callbacks[sid][namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[sid][namespace][0])
self.callbacks[sid][namespace][id] = callback
return id

def _clean_rooms(self):
"""Remove all the inactive room participants."""
Expand Down
35 changes: 2 additions & 33 deletions socketio/server.py
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import logging

import engineio
Expand Down Expand Up @@ -83,7 +82,6 @@ def __init__(self, client_manager_class=None, logger=False, binary=False,

self.environ = {}
self.handlers = {}
self.callbacks = {}

self._binary_packet = None
self._attachment_count = 0
Expand Down Expand Up @@ -304,12 +302,8 @@ def handle_request(self, environ, start_response):
"""
return self.eio.handle_request(environ, start_response)

def _emit_internal(self, sid, event, data, namespace=None, callback=None):
def _emit_internal(self, sid, event, data, namespace=None, id=None):
"""Send a message to a client."""
if callback is not None:
id = self._generate_ack_id(sid, namespace, callback)
else:
id = None
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
Expand Down Expand Up @@ -353,13 +347,9 @@ def _handle_disconnect(self, sid, namespace):
if n != '/' and self.manager.is_connected(sid, n):
self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if sid in self.callbacks and n in self.callbacks[sid]:
del self.callbacks[sid][n]
if namespace == '/' and self.manager.is_connected(sid, namespace):
self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
if sid in self.callbacks:
del self.callbacks[sid]
if sid in self.environ:
del self.environ[sid]

Expand Down Expand Up @@ -390,34 +380,13 @@ def _handle_ack(self, sid, namespace, id, data):
"""Handle ACK packets from the client."""
namespace = namespace or '/'
self.logger.info('received ack from %s [%s]', sid, namespace)
self._trigger_callback(sid, namespace, id, data)
self.manager.trigger_callback(sid, namespace, id, data)

def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
if namespace in self.handlers and event in self.handlers[namespace]:
return self.handlers[namespace][event](*args)

def _generate_ack_id(self, sid, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if sid not in self.callbacks:
self.callbacks[sid] = {}
if namespace not in self.callbacks[sid]:
self.callbacks[sid][namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[sid][namespace][0])
self.callbacks[sid][namespace][id] = callback
return id

def _trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""
namespace = namespace or '/'
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
raise ValueError('Unknown callback')
del self.callbacks[sid][namespace][id]
callback(*data)

def _handle_eio_connect(self, sid, environ):
"""Handle the Engine.IO connection event."""
self.environ[sid] = environ
Expand Down
45 changes: 44 additions & 1 deletion tests/test_base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
class TestBaseManager(unittest.TestCase):
def setUp(self):
mock_server = mock.MagicMock()
mock_server.rooms = {}
self.bm = base_manager.BaseManager(mock_server)

def test_connect(self):
Expand Down Expand Up @@ -78,6 +77,40 @@ def test_disconnect_all(self):
self.bm._clean_rooms()
self.assertEqual(self.bm.rooms, {})

def test_disconnect_with_callbacks(self):
self.bm.connect('123', '/')
self.bm.connect('123', '/foo')
self.bm._generate_ack_id('123', '/', 'f')
self.bm._generate_ack_id('123', '/foo', 'g')
self.bm.disconnect('123', '/foo')
self.assertNotIn('/foo', self.bm.callbacks['123'])
self.bm.disconnect('123', '/')
self.assertNotIn('123', self.bm.callbacks)

def test_trigger_callback(self):
self.bm.connect('123', '/')
self.bm.connect('123', '/foo')
cb = mock.MagicMock()
id1 = self.bm._generate_ack_id('123', '/', cb)
id2 = self.bm._generate_ack_id('123', '/foo', cb)
self.bm.trigger_callback('123', '/', id1, ['foo'])
self.bm.trigger_callback('123', '/foo', id2, ['bar', 'baz'])
self.assertEqual(cb.call_count, 2)
cb.assert_any_call('foo')
cb.assert_any_call('bar', 'baz')

def test_invalid_callback(self):
self.bm.connect('123', '/')
cb = mock.MagicMock()
id = self.bm._generate_ack_id('123', '/', cb)
self.assertRaises(ValueError, self.bm.trigger_callback,
'124', '/', id, ['foo'])
self.assertRaises(ValueError, self.bm.trigger_callback,
'123', '/foo', id, ['foo'])
self.assertRaises(ValueError, self.bm.trigger_callback,
'123', '/', id + 1, ['foo'])
self.assertEqual(cb.call_count, 0)

def test_get_namespaces(self):
self.assertEqual(list(self.bm.get_namespaces()), [])
self.bm.connect('123', '/')
Expand Down Expand Up @@ -185,6 +218,16 @@ def test_emit_to_all_skip_one(self):
{'foo': 'bar'}, '/foo',
None)

def test_emit_with_callback(self):
self.bm.connect('123', '/foo')
self.bm.server._generate_ack_id.return_value = 11
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo',
callback='cb')
self.bm.server._emit_internal.assert_called_once_with('123',
'my event',
{'foo': 'bar'},
'/foo', 11)

def test_emit_to_invalid_room(self):
self.bm.emit('my event', {'foo': 'bar'}, namespace='/', room='123')

Expand Down
35 changes: 10 additions & 25 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def test_emit_internal(self, eio):

def test_emit_internal_with_callback(self, eio):
s = server.Server()
s._emit_internal('123', 'my event', 'my data', namespace='/foo',
callback='cb')
id = s.manager._generate_ack_id('123', '/foo', 'cb')
s._emit_internal('123', 'my event', 'my data', namespace='/foo', id=id)
s.eio.send.assert_called_once_with('123',
'2/foo,1["my event","my data"]',
binary=False)
Expand Down Expand Up @@ -323,40 +323,25 @@ def test_handle_invalid_packet(self, eio):

def test_send_with_ack(self, eio):
s = server.Server()
cb = mock.MagicMock()
s._handle_eio_connect('123', 'environ')
s._emit_internal('123', 'my event', ['foo'], callback=cb)
s._emit_internal('123', 'my event', ['bar'], callback=cb)
cb = mock.MagicMock()
id1 = s.manager._generate_ack_id('123', '/', cb)
id2 = s.manager._generate_ack_id('123', '/', cb)
s._emit_internal('123', 'my event', ['foo'], id=id1)
s._emit_internal('123', 'my event', ['bar'], id=id2)
s._handle_eio_message('123', '31["foo",2]')
cb.assert_called_once_with('foo', 2)
self.assertIn('123', s.callbacks)
s._handle_disconnect('123', '/')
self.assertNotIn('123', s.callbacks)

def test_send_with_ack_namespace(self, eio):
s = server.Server()
cb = mock.MagicMock()
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0/foo')
cb = mock.MagicMock()
id = s.manager._generate_ack_id('123', '/foo', cb)
s._emit_internal('123', 'my event', ['foo'], namespace='/foo',
callback=cb)
id=id)
s._handle_eio_message('123', '3/foo,1["foo",2]')
cb.assert_called_once_with('foo', 2)
self.assertIn('/foo', s.callbacks['123'])
s._handle_eio_disconnect('123')
self.assertNotIn('123', s.callbacks)

def test_invalid_callback(self, eio):
s = server.Server()
cb = mock.MagicMock()
s._handle_eio_connect('123', 'environ')
s._emit_internal('123', 'my event', ['foo'], callback=cb)
self.assertRaises(ValueError, s._handle_eio_message, '124',
'31["foo",2]')
self.assertRaises(ValueError, s._handle_eio_message, '123',
'3/foo,1["foo",2]')
self.assertRaises(ValueError, s._handle_eio_message, '123',
'32["foo",2]')

def test_disconnect(self, eio):
s = server.Server()
Expand Down

0 comments on commit ad12b83

Please sign in to comment.