Skip to content

Commit

Permalink
Support for callbacks across servers
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 7, 2015
1 parent 47620bb commit 63f5ed3
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 26 deletions.
10 changes: 7 additions & 3 deletions socketio/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,16 @@ def emit(self, event, data, namespace, room=None, skip_sid=None,

def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""
callback = None
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
raise ValueError('Unknown callback')
del self.callbacks[sid][namespace][id]
callback(*data)
# if we get an unknown callback we just ignore it
self.server.logger.warning('Unknown callback received, ignoring.')
else:
del self.callbacks[sid][namespace][id]
if callback is not None:
callback(*data)

def _generate_ack_id(self, sid, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
Expand Down
10 changes: 7 additions & 3 deletions socketio/kombu_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,12 @@ def _listen(self):
if isinstance(message.payload, six.binary_type):
try:
data = pickle.loads(message.payload)
except pickle.PickleError:
except:
pass
if data is None:
data = json.loads(message.payload)
yield data
try:
data = json.loads(message.payload)
except:
pass
if data:
yield data
63 changes: 53 additions & 10 deletions socketio/pubsub_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from functools import partial
import uuid

from .base_manager import BaseManager


Expand All @@ -18,6 +21,7 @@ class PubSubManager(BaseManager):
def __init__(self, channel='socketio'):
super(PubSubManager, self).__init__()
self.channel = channel
self.host_id = uuid.uuid4().hex

def initialize(self, server):
super(PubSubManager, self).initialize(server)
Expand All @@ -34,8 +38,14 @@ def emit(self, event, data, namespace=None, room=None, skip_sid=None,
The parameters are the same as in :meth:`.Server.emit`.
"""
namespace = namespace or '/'
if callback is not None:
id = self._generate_ack_id(room, namespace, callback)
callback = (room, namespace, id)
else:
callback = None
self._publish({'method': 'emit', 'event': event, 'data': data,
'namespace': namespace or '/', 'room': room,
'namespace': namespace, 'room': room,
'skip_sid': skip_sid, 'callback': callback})

def close_room(self, room, namespace=None):
Expand All @@ -61,17 +71,50 @@ def _listen(self):
raise NotImplementedError('This method must be implemented in a '
'subclass.')

def _handle_emit(self, message):
# Events with callbacks are very tricky to handle across hosts
# Here in the receiving end we set up a local callback that preserves
# the callback host and id from the sender
remote_callback = message.get('callback')
if remote_callback is not None and len(remote_callback) == 3:
callback = partial(self._return_callback, self.host_id,
*remote_callback)
else:
callback = None
super(PubSubManager, self).emit(message['event'], message['data'],
namespace=message.get('namespace'),
room=message.get('room'),
skip_sid=message.get('skip_sid'),
callback=callback)

def _handle_callback(self, message):
if self.host_id == message.get('host_id'):
try:
sid = message['sid']
namespace = message['namespace']
id = message['id']
args = message['args']
except KeyError:
return
self.trigger_callback(sid, namespace, id, args)

def _return_callback(self, host_id, sid, namespace, callback_id, *args):
# When an event callback is received, the callback is returned back
# the sender, which is identified by the host_id
self._publish({'method': 'callback', 'host_id': host_id,
'sid': sid, 'namespace': namespace, 'id': callback_id,
'args': args})

def _handle_close_room(self, message):
super(PubSubManager, self).close_room(
room=message.get('room'), namespace=message.get('namespace'))

def _thread(self):
for message in self._listen():
if 'method' in message:
if message['method'] == 'emit':
super(PubSubManager, self).emit(
message['event'], message['data'],
namespace=message.get('namespace'),
room=message.get('room'),
skip_sid=message.get('skip_sid'),
callback=message.get('callback'))
self._handle_emit(message)
elif message['method'] == 'callback':
self._handle_callback(message)
elif message['method'] == 'close_room':
super(PubSubManager, self).close_room(
room=message.get('room'),
namespace=message.get('namespace'))
self._handle_close_room(message)
10 changes: 7 additions & 3 deletions socketio/redis_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ def _listen(self):
if isinstance(message['data'], six.binary_type):
try:
data = pickle.loads(message['data'])
except pickle.PickleError:
except:
pass
if data is None:
data = json.loads(message['data'])
yield data
try:
data = json.loads(message['data'])
except:
pass
if data:
yield data
self.pubsub.unsubscribe(self.channel)
11 changes: 5 additions & 6 deletions tests/test_base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,11 @@ def test_invalid_callback(self):
self.bm.connect('123', '/')
cb = mock.MagicMock()
id = self.bm._generate_ack_id('123', '/', cb)
self.assertRaises(ValueError, self.bm.trigger_callback,
'124', '/', id, ['foo'])
self.assertRaises(ValueError, self.bm.trigger_callback,
'123', '/foo', id, ['foo'])
self.assertRaises(ValueError, self.bm.trigger_callback,
'123', '/', id + 1, ['foo'])

# these should not raise an exception
self.bm.trigger_callback('124', '/', id, ['foo'])
self.bm.trigger_callback('123', '/foo', id, ['foo'])
self.bm.trigger_callback('123', '/', id + 1, ['foo'])
self.assertEqual(cb.call_count, 0)

def test_get_namespaces(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ def test_handle_event_binary_ack(self, eio):
s._handle_eio_message('123', '61-1["my message","a",'
'{"_placeholder":true,"num":0}]')
self.assertEqual(s._attachment_count, 1)
self.assertRaises(ValueError, s._handle_eio_message, '123', b'foo')
# the following call should not raise an exception in spite of the
# callback id being invalid
s._handle_eio_message('123', b'foo')

def test_handle_event_with_ack(self, eio):
s = server.Server()
Expand Down

0 comments on commit 63f5ed3

Please sign in to comment.