Skip to content

Commit

Permalink
Support msgpack and custom packet serializers (Fixes #749)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jul 20, 2021
1 parent a813bde commit 5159e84
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 39 deletions.
10 changes: 5 additions & 5 deletions src/socketio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ async def emit(self, event, data=None, namespace=None, callback=None):
data = [data]
else:
data = []
await self._send_packet(packet.Packet(
await self._send_packet(self.packet_class(
packet.EVENT, namespace=namespace, data=[event] + data, id=id))

async def send(self, data, namespace=None, callback=None):
Expand Down Expand Up @@ -296,7 +296,7 @@ async def disconnect(self):
# here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces:
await self._send_packet(packet.Packet(packet.DISCONNECT,
await self._send_packet(self.packet_class(packet.DISCONNECT,
namespace=n))
await self.eio.disconnect(abort=True)

Expand Down Expand Up @@ -379,7 +379,7 @@ async def _handle_event(self, namespace, id, data):
data = list(r)
else:
data = [r]
await self._send_packet(packet.Packet(
await self._send_packet(self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data))

async def _handle_ack(self, namespace, id, data):
Expand Down Expand Up @@ -482,7 +482,7 @@ async def _handle_eio_connect(self):
self.sid = self.eio.sid
real_auth = await self._get_real_value(self.connection_auth)
for n in self.connection_namespaces:
await self._send_packet(packet.Packet(
await self._send_packet(self.packet_class(
packet.CONNECT, data=real_auth, namespace=n))

async def _handle_eio_message(self, data):
Expand All @@ -496,7 +496,7 @@ async def _handle_eio_message(self, data):
else:
await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:
Expand Down
16 changes: 8 additions & 8 deletions src/socketio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ async def disconnect(self, sid, namespace=None, ignore_queue=False):
if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace)
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, namespace=namespace))
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
Expand Down Expand Up @@ -423,7 +423,7 @@ async def _emit_internal(self, sid, event, data, namespace=None, id=None):
data = [data]
else:
data = []
await self._send_packet(sid, packet.Packet(
await self._send_packet(sid, self.packet_class(
packet.EVENT, namespace=namespace, data=[event] + data, id=id))

async def _send_packet(self, eio_sid, pkt):
Expand All @@ -440,7 +440,7 @@ async def _handle_connect(self, eio_sid, namespace, data):
namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace)
if self.always_connect:
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args
try:
Expand All @@ -461,15 +461,15 @@ async def _handle_connect(self, eio_sid, namespace, data):
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
else:
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace))
self.manager.disconnect(sid, namespace)
elif not self.always_connect:
await self._send_packet(eio_sid, packet.Packet(
await self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace))

async def _handle_disconnect(self, eio_sid, namespace):
Expand Down Expand Up @@ -511,7 +511,7 @@ async def _handle_event_internal(self, server, sid, eio_sid, data,
data = list(r)
else:
data = [r]
await server._send_packet(eio_sid, packet.Packet(
await server._send_packet(eio_sid, self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data))

async def _handle_ack(self, eio_sid, namespace, id, data):
Expand Down Expand Up @@ -560,7 +560,7 @@ async def _handle_eio_message(self, eio_sid, data):
await self._handle_ack(eio_sid, pkt.namespace, pkt.id,
pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:
Expand Down
34 changes: 25 additions & 9 deletions src/socketio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ class Client(object):
use. To disable logging set to ``False``. The default is
``False``. Note that fatal errors are logged even when
``logger`` is ``False``.
:param serializer: The serialization method to use when transmitting
packets. Valid values are ``'default'``, ``'pickle'``,
``'msgpack'`` and ``'cbor'``. Alternatively, a subclass
of the :class:`Packet` class with custom implementations
of the ``encode()`` and ``decode()`` methods can be
provided. Client and server must use compatible
serializers.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
Expand All @@ -82,7 +89,8 @@ class Client(object):
"""
def __init__(self, reconnection=True, reconnection_attempts=0,
reconnection_delay=1, reconnection_delay_max=5,
randomization_factor=0.5, logger=False, json=None, **kwargs):
randomization_factor=0.5, logger=False, serializer='default',
json=None, **kwargs):
global original_signal_handler
if original_signal_handler is None and \
threading.current_thread() == threading.main_thread():
Expand All @@ -98,8 +106,15 @@ def __init__(self, reconnection=True, reconnection_attempts=0,
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
engineio_options['logger'] = engineio_logger
if serializer == 'default':
self.packet_class = packet.Packet
elif serializer == 'msgpack':
from . import msgpack_packet
self.packet_class = msgpack_packet.MsgPackPacket
else:
self.packet_class = serializer
if json is not None:
packet.Packet.json = json
self.packet_class.json = json
engineio_options['json'] = json

self.eio = self._engineio_client_class()(**engineio_options)
Expand Down Expand Up @@ -381,8 +396,8 @@ def emit(self, event, data=None, namespace=None, callback=None):
data = [data]
else:
data = []
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
data=[event] + data, id=id))
self._send_packet(self.packet_class(packet.EVENT, namespace=namespace,
data=[event] + data, id=id))

def send(self, data, namespace=None, callback=None):
"""Send a message to one or more connected clients.
Expand Down Expand Up @@ -448,7 +463,8 @@ def disconnect(self):
# here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces:
self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n))
self._send_packet(self.packet_class(
packet.DISCONNECT, namespace=n))
self.eio.disconnect(abort=True)

def get_sid(self, namespace=None):
Expand Down Expand Up @@ -557,8 +573,8 @@ def _handle_event(self, namespace, id, data):
data = list(r)
else:
data = [r]
self._send_packet(packet.Packet(packet.ACK, namespace=namespace,
id=id, data=data))
self._send_packet(self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data))

def _handle_ack(self, namespace, id, data):
namespace = namespace or '/'
Expand Down Expand Up @@ -647,7 +663,7 @@ def _handle_eio_connect(self):
self.sid = self.eio.sid
real_auth = self._get_real_value(self.connection_auth)
for n in self.connection_namespaces:
self._send_packet(packet.Packet(
self._send_packet(self.packet_class(
packet.CONNECT, data=real_auth, namespace=n))

def _handle_eio_message(self, data):
Expand All @@ -661,7 +677,7 @@ def _handle_eio_message(self, data):
else:
self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:
Expand Down
16 changes: 16 additions & 0 deletions src/socketio/msgpack_packet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import msgpack
from . import packet


class MsgPackPacket(packet.Packet):
def encode(self):
"""Encode the packet for transmission."""
return msgpack.dumps(self._to_dict())

def decode(self, encoded_packet):
"""Decode a transmitted package."""
decoded = msgpack.loads(encoded_packet)
self.packet_type = decoded['type']
self.data = decoded['data']
self.id = decoded.get('id')
self.namespace = decoded['nsp']
12 changes: 11 additions & 1 deletion src/socketio/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, packet_type=EVENT, data=None, namespace=None, id=None,
self.attachment_count = 0
self.attachments = []
if encoded_packet:
self.attachment_count = self.decode(encoded_packet)
self.attachment_count = self.decode(encoded_packet) or 0

def encode(self):
"""Encode the packet for transmission.
Expand Down Expand Up @@ -175,3 +175,13 @@ def _data_is_binary(self, data):
False)
else:
return False

def _to_dict(self):
d = {
'type': self.packet_type,
'data': self.data,
'nsp': self.namespace,
}
if self.id:
d['id'] = self.id
return d
46 changes: 31 additions & 15 deletions src/socketio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ class Server(object):
use. To disable logging set to ``False``. The default is
``False``. Note that fatal errors are logged even when
``logger`` is ``False``.
:param serializer: The serialization method to use when transmitting
packets. Valid values are ``'default'``, ``'pickle'``,
``'msgpack'`` and ``'cbor'``. Alternatively, a subclass
of the :class:`Packet` class with custom implementations
of the ``encode()`` and ``decode()`` methods can be
provided. Client and server must use compatible
serializers.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
Expand All @@ -48,10 +55,11 @@ class Server(object):
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "threading",
"eventlet", "gevent" and "gevent_uwsgi". If this
argument is not given, "eventlet" is tried first, then
"gevent_uwsgi", then "gevent", and finally "threading".
available options. Valid async modes are
``'threading'``, ``'eventlet'``, ``'gevent'`` and
``'gevent_uwsgi'``. If this argument is not given,
``'eventlet'`` is tried first, then ``'gevent_uwsgi'``,
then ``'gevent'``, and finally ``'threading'``.
The first async mode that has all its dependencies
installed is then one that is chosen.
:param ping_interval: The interval in seconds at which the server pings
Expand Down Expand Up @@ -98,14 +106,22 @@ class Server(object):
fatal errors are logged even when
``engineio_logger`` is ``False``.
"""
def __init__(self, client_manager=None, logger=False, json=None,
async_handlers=True, always_connect=False, **kwargs):
def __init__(self, client_manager=None, logger=False, serializer='default',
json=None, async_handlers=True, always_connect=False,
**kwargs):
engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
engineio_options['logger'] = engineio_logger
if serializer == 'default':
self.packet_class = packet.Packet
elif serializer == 'msgpack':
from . import msgpack_packet
self.packet_class = msgpack_packet.MsgPackPacket
else:
self.packet_class = serializer
if json is not None:
packet.Packet.json = json
self.packet_class.json = json
engineio_options['json'] = json
engineio_options['async_handlers'] = False
self.eio = self._engineio_server_class()(**engineio_options)
Expand Down Expand Up @@ -531,7 +547,7 @@ def disconnect(self, sid, namespace=None, ignore_queue=False):
if delete_it:
self.logger.info('Disconnecting %s [%s]', sid, namespace)
eio_sid = self.manager.pre_disconnect(sid, namespace=namespace)
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, namespace=namespace))
self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
Expand Down Expand Up @@ -609,7 +625,7 @@ def _emit_internal(self, eio_sid, event, data, namespace=None, id=None):
data = [data]
else:
data = []
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.EVENT, namespace=namespace, data=[event] + data, id=id))

def _send_packet(self, eio_sid, pkt):
Expand All @@ -626,7 +642,7 @@ def _handle_connect(self, eio_sid, namespace, data):
namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace)
if self.always_connect:
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args
try:
Expand All @@ -647,15 +663,15 @@ def _handle_connect(self, eio_sid, namespace, data):
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
else:
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.CONNECT_ERROR, data=fail_reason,
namespace=namespace))
self.manager.disconnect(sid, namespace)
elif not self.always_connect:
self._send_packet(eio_sid, packet.Packet(
self._send_packet(eio_sid, self.packet_class(
packet.CONNECT, {'sid': sid}, namespace=namespace))

def _handle_disconnect(self, eio_sid, namespace):
Expand Down Expand Up @@ -697,7 +713,7 @@ def _handle_event_internal(self, server, sid, eio_sid, data, namespace,
data = list(r)
else:
data = [r]
server._send_packet(eio_sid, packet.Packet(
server._send_packet(eio_sid, self.packet_class(
packet.ACK, namespace=namespace, id=id, data=data))

def _handle_ack(self, eio_sid, namespace, id, data):
Expand Down Expand Up @@ -737,7 +753,7 @@ def _handle_eio_message(self, eio_sid, data):
else:
self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
pkt = self.packet_class(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:
Expand Down
15 changes: 14 additions & 1 deletion tests/common/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from socketio import asyncio_namespace
from socketio import client
from socketio import exceptions
from socketio import msgpack_packet
from socketio import namespace
from socketio import packet

Expand Down Expand Up @@ -49,8 +50,20 @@ def test_create(self, engineio_client_class):
assert c.callbacks == {}
assert c._binary_packet is None
assert c._reconnect_task is None
assert c.packet_class == packet.Packet

def test_custon_json(self):
def test_msgpack(self):
c = client.Client(serializer='msgpack')
assert c.packet_class == msgpack_packet.MsgPackPacket

def test_custom_serializer(self):
class CustomPacket(packet.Packet):
pass

c = client.Client(serializer=CustomPacket)
assert c.packet_class == CustomPacket

def test_custom_json(self):
client.Client()
assert packet.Packet.json == json
assert engineio_packet.Packet.json == json
Expand Down
Loading

0 comments on commit 5159e84

Please sign in to comment.