Skip to content

Commit

Permalink
Add namespaces argument to Server and AsyncServer (Fixes #822)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jul 9, 2022
1 parent d4e69fb commit efe87d8
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 7 deletions.
13 changes: 10 additions & 3 deletions src/socketio/asyncio_server.py
Expand Up @@ -40,6 +40,11 @@ class AsyncServer(server.Server):
connect handler and your client is confused when it
receives events before the connection acceptance.
In any other case use the default of ``False``.
:param namespaces: a list of namespaces that are accepted, in addition to
any namespaces for which handlers have been defined. The
default is `['/']`, which always accepts connections to
the default namespace. Set to `'*'` to accept all
namespaces.
:param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings:
Expand Down Expand Up @@ -97,11 +102,12 @@ class AsyncServer(server.Server):
``engineio_logger`` is ``False``.
"""
def __init__(self, client_manager=None, logger=False, json=None,
async_handlers=True, **kwargs):
async_handlers=True, namespaces=None, **kwargs):
if client_manager is None:
client_manager = asyncio_manager.AsyncManager()
super().__init__(client_manager=client_manager, logger=logger,
json=json, async_handlers=async_handlers, **kwargs)
json=json, async_handlers=async_handlers,
namespaces=namespaces, **kwargs)

def is_asyncio_based(self):
return True
Expand Down Expand Up @@ -443,7 +449,8 @@ async def _handle_connect(self, eio_sid, namespace, data):
"""Handle a client connection request."""
namespace = namespace or '/'
sid = None
if namespace in self.handlers or namespace in self.namespace_handlers:
if namespace in self.handlers or namespace in self.namespace_handlers \
or self.namespaces == '*' or namespace in self.namespaces:
sid = self.manager.connect(eio_sid, namespace)
if sid is None:
await self._send_packet(eio_sid, self.packet_class(
Expand Down
11 changes: 9 additions & 2 deletions src/socketio/server.py
Expand Up @@ -49,6 +49,11 @@ class Server(object):
connect handler and your client is confused when it
receives events before the connection acceptance.
In any other case use the default of ``False``.
:param namespaces: a list of namespaces that are accepted, in addition to
any namespaces for which handlers have been defined. The
default is `['/']`, which always accepts connections to
the default namespace. Set to `'*'` to accept all
namespaces.
:param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings:
Expand Down Expand Up @@ -110,7 +115,7 @@ class Server(object):

def __init__(self, client_manager=None, logger=False, serializer='default',
json=None, async_handlers=True, always_connect=False,
**kwargs):
namespaces=None, **kwargs):
engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
Expand Down Expand Up @@ -157,6 +162,7 @@ def __init__(self, client_manager=None, logger=False, serializer='default',

self.async_handlers = async_handlers
self.always_connect = always_connect
self.namespaces = namespaces or ['/']

self.async_mode = self.eio.async_mode

Expand Down Expand Up @@ -650,7 +656,8 @@ def _handle_connect(self, eio_sid, namespace, data):
"""Handle a client connection request."""
namespace = namespace or '/'
sid = None
if namespace in self.handlers or namespace in self.namespace_handlers:
if namespace in self.handlers or namespace in self.namespace_handlers \
or self.namespaces == '*' or namespace in self.namespaces:
sid = self.manager.connect(eio_sid, namespace)
if sid is None:
self._send_packet(eio_sid, self.packet_class(
Expand Down
22 changes: 21 additions & 1 deletion tests/asyncio/test_asyncio_server.py
Expand Up @@ -425,12 +425,32 @@ def test_handle_connect_async(self, eio):
_run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1

def test_handle_connect_with_bad_namespace(self, eio):
def test_handle_connect_with_default_implied_namespaces(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0'))
_run(s._handle_eio_message('123', '0/foo,'))
assert s.manager.is_connected('1', '/')
assert not s.manager.is_connected('2', '/foo')

def test_handle_connect_with_implied_namespaces(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer(namespaces=['/foo'])
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0'))
_run(s._handle_eio_message('123', '0/foo,'))
assert not s.manager.is_connected('1', '/')
assert s.manager.is_connected('1', '/foo')

def test_handle_connect_with_all_implied_namespaces(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer(namespaces='*')
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0'))
_run(s._handle_eio_message('123', '0/foo,'))
assert s.manager.is_connected('1', '/')
assert s.manager.is_connected('2', '/foo')

def test_handle_connect_namespace(self, eio):
eio.return_value.send = AsyncMock()
Expand Down
20 changes: 19 additions & 1 deletion tests/common/test_server.py
Expand Up @@ -356,11 +356,29 @@ def test_handle_connect_with_auth_none(self, eio):
s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1

def test_handle_connect_with_bad_namespace(self, eio):
def test_handle_connect_with_default_implied_namespaces(self, eio):
s = server.Server()
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0')
s._handle_eio_message('123', '0/foo,')
assert s.manager.is_connected('1', '/')
assert not s.manager.is_connected('2', '/foo')

def test_handle_connect_with_implied_namespaces(self, eio):
s = server.Server(namespaces=['/foo'])
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0')
s._handle_eio_message('123', '0/foo,')
assert not s.manager.is_connected('1', '/')
assert s.manager.is_connected('1', '/foo')

def test_handle_connect_with_all_implied_namespaces(self, eio):
s = server.Server(namespaces='*')
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0')
s._handle_eio_message('123', '0/foo,')
assert s.manager.is_connected('1', '/')
assert s.manager.is_connected('2', '/foo')

def test_handle_connect_namespace(self, eio):
s = server.Server()
Expand Down

0 comments on commit efe87d8

Please sign in to comment.