Skip to content

Commit

Permalink
Async versions of enter_room and leave_room should be coroutines (bre…
Browse files Browse the repository at this point in the history
…aking change)
  • Loading branch information
miguelgrinberg committed Sep 19, 2023
1 parent 8da3c61 commit ab33cb7
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 53 deletions.
4 changes: 2 additions & 2 deletions examples/server/aiohttp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ async def my_broadcast_event(sid, message):

@sio.event
async def join(sid, message):
sio.enter_room(sid, message['room'])
await sio.enter_room(sid, message['room'])
await sio.emit('my_response', {'data': 'Entered room: ' + message['room']},
room=sid)


@sio.event
async def leave(sid, message):
sio.leave_room(sid, message['room'])
await sio.leave_room(sid, message['room'])
await sio.emit('my_response', {'data': 'Left room: ' + message['room']},
room=sid)

Expand Down
4 changes: 2 additions & 2 deletions examples/server/asgi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ async def test_broadcast_message(sid, message):

@sio.on('join')
async def join(sid, message):
sio.enter_room(sid, message['room'])
await sio.enter_room(sid, message['room'])
await sio.emit('my_response', {'data': 'Entered room: ' + message['room']},
room=sid)


@sio.on('leave')
async def leave(sid, message):
sio.leave_room(sid, message['room'])
await sio.leave_room(sid, message['room'])
await sio.emit('my_response', {'data': 'Left room: ' + message['room']},
room=sid)

Expand Down
4 changes: 2 additions & 2 deletions examples/server/sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ async def my_broadcast_event(sid, message):

@sio.event
async def join(sid, message):
sio.enter_room(sid, message['room'])
await sio.enter_room(sid, message['room'])
await sio.emit('my_response', {'data': 'Entered room: ' + message['room']},
room=sid)


@sio.event
async def leave(sid, message):
sio.leave_room(sid, message['room'])
await sio.leave_room(sid, message['room'])
await sio.emit('my_response', {'data': 'Left room: ' + message['room']},
room=sid)

Expand Down
4 changes: 2 additions & 2 deletions examples/server/tornado/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ async def my_broadcast_event(sid, message):

@sio.event
async def join(sid, message):
sio.enter_room(sid, message['room'])
await sio.enter_room(sid, message['room'])
await sio.emit('my_response', {'data': 'Entered room: ' + message['room']},
room=sid)


@sio.event
async def leave(sid, message):
sio.leave_room(sid, message['room'])
await sio.leave_room(sid, message['room'])
await sio.emit('my_response', {'data': 'Left room: ' + message['room']},
room=sid)

Expand Down
14 changes: 14 additions & 0 deletions src/socketio/asyncio_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ async def disconnect(self, sid, namespace, **kwargs):
"""
return super().disconnect(sid, namespace, **kwargs)

async def enter_room(self, sid, namespace, room, eio_sid=None):
"""Add a client to a room.
Note: this method is a coroutine.
"""
return super().enter_room(sid, namespace, room, eio_sid=eio_sid)

async def leave_room(self, sid, namespace, room):
"""Remove a client from a room.
Note: this method is a coroutine.
"""
return super().leave_room(sid, namespace, room)

async def close_room(self, room, namespace):
"""Remove all participants from a room.
Expand Down
24 changes: 24 additions & 0 deletions src/socketio/asyncio_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,30 @@ async def call(self, event, data=None, to=None, sid=None, namespace=None,
timeout=timeout,
ignore_queue=ignore_queue)

async def enter_room(self, sid, room, namespace=None):
"""Enter a room.
The only difference with the :func:`socketio.Server.enter_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.enter_room(
sid, room, namespace=namespace or self.namespace)

async def leave_room(self, sid, room, namespace=None):
"""Leave a room.
The only difference with the :func:`socketio.Server.leave_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.leave_room(
sid, room, namespace=namespace or self.namespace)

async def close_room(self, room, namespace=None):
"""Close a room.
Expand Down
34 changes: 34 additions & 0 deletions src/socketio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,40 @@ def event_callback(*args):
else callback_args[0][0] if len(callback_args[0]) == 1 \
else None

async def enter_room(self, sid, room, namespace=None):
"""Enter a room.
This function adds the client to a room. The :func:`emit` and
:func:`send` functions can optionally broadcast events to all the
clients in a room.
:param sid: Session ID of the client.
:param room: Room name. If the room does not exist it is created.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
self.logger.info('%s is entering room %s [%s]', sid, room, namespace)
await self.manager.enter_room(sid, namespace, room)

async def leave_room(self, sid, room, namespace=None):
"""Leave a room.
This function removes the client from a room.
:param sid: Session ID of the client.
:param room: Room name.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
self.logger.info('%s is leaving room %s [%s]', sid, room, namespace)
await self.manager.leave_room(sid, namespace, room)

async def close_room(self, room, namespace=None):
"""Close a room.
Expand Down
20 changes: 14 additions & 6 deletions src/socketio/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def connect(self, eio_sid, namespace):
"""Register a client connection to a namespace."""
sid = self.server.eio.generate_id()
try:
self.enter_room(sid, namespace, None, eio_sid=eio_sid)
self.basic_enter_room(sid, namespace, None, eio_sid=eio_sid)
except ValueDuplicationError:
# already connected
return None
self.enter_room(sid, namespace, sid, eio_sid=eio_sid)
self.basic_enter_room(sid, namespace, sid, eio_sid=eio_sid)
return sid

def is_connected(self, sid, namespace):
Expand Down Expand Up @@ -106,7 +106,7 @@ def disconnect(self, sid, namespace, **kwargs):
if sid in room:
rooms.append(room_name)
for room in rooms:
self.leave_room(sid, namespace, room)
self.basic_leave_room(sid, namespace, room)
if sid in self.callbacks:
del self.callbacks[sid]
if namespace in self.pending_disconnect and \
Expand All @@ -115,7 +115,7 @@ def disconnect(self, sid, namespace, **kwargs):
if len(self.pending_disconnect[namespace]) == 0:
del self.pending_disconnect[namespace]

def enter_room(self, sid, namespace, room, eio_sid=None):
def basic_enter_room(self, sid, namespace, room, eio_sid=None):
"""Add a client to a room."""
if eio_sid is None and namespace not in self.rooms:
raise ValueError('sid is not connected to requested namespace')
Expand All @@ -127,7 +127,7 @@ def enter_room(self, sid, namespace, room, eio_sid=None):
eio_sid = self.rooms[namespace][None][sid]
self.rooms[namespace][room][sid] = eio_sid

def leave_room(self, sid, namespace, room):
def basic_leave_room(self, sid, namespace, room):
"""Remove a client from a room."""
try:
del self.rooms[namespace][room][sid]
Expand All @@ -138,11 +138,19 @@ def leave_room(self, sid, namespace, room):
except KeyError:
pass

def enter_room(self, sid, namespace, room, eio_sid=None):
"""Add a client to a room."""
self.basic_enter_room(sid, namespace, room, eio_sid=eio_sid)

def leave_room(self, sid, namespace, room):
"""Remove a client from a room."""
self.basic_leave_room(sid, namespace, room)

def close_room(self, room, namespace):
"""Remove all participants from a room."""
try:
for sid, _ in self.get_participants(namespace, room):
self.leave_room(sid, namespace, room)
self.basic_leave_room(sid, namespace, room)
except KeyError:
pass

Expand Down
44 changes: 23 additions & 21 deletions tests/asyncio/test_asyncio_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def test_pre_disconnect(self):
def test_disconnect(self):
sid1 = self.bm.connect('123', '/foo')
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room(sid2, '/foo', 'baz')
_run(self.bm.enter_room(sid1, '/foo', 'bar'))
_run(self.bm.enter_room(sid2, '/foo', 'baz'))
_run(self.bm.disconnect(sid1, '/foo'))
assert dict(self.bm.rooms['/foo'][None]) == {sid2: '456'}
assert dict(self.bm.rooms['/foo'][sid2]) == {sid2: '456'}
Expand Down Expand Up @@ -97,8 +97,8 @@ def test_disconnect_twice(self):
def test_disconnect_all(self):
sid1 = self.bm.connect('123', '/foo')
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
self.bm.enter_room(sid2, '/foo', 'baz')
_run(self.bm.enter_room(sid1, '/foo', 'bar'))
_run(self.bm.enter_room(sid2, '/foo', 'baz'))
_run(self.bm.disconnect(sid1, '/foo'))
_run(self.bm.disconnect(sid2, '/foo'))
assert self.bm.rooms == {}
Expand Down Expand Up @@ -173,8 +173,8 @@ def test_get_participants(self):

def test_leave_invalid_room(self):
sid = self.bm.connect('123', '/foo')
self.bm.leave_room(sid, '/foo', 'baz')
self.bm.leave_room(sid, '/bar', 'baz')
_run(self.bm.leave_room(sid, '/foo', 'baz'))
_run(self.bm.leave_room(sid, '/bar', 'baz'))

def test_no_room(self):
rooms = self.bm.get_rooms('123', '/foo')
Expand All @@ -184,17 +184,19 @@ def test_close_room(self):
sid = self.bm.connect('123', '/foo')
self.bm.connect('456', '/foo')
self.bm.connect('789', '/foo')
self.bm.enter_room(sid, '/foo', 'bar')
self.bm.enter_room(sid, '/foo', 'bar')
_run(self.bm.enter_room(sid, '/foo', 'bar'))
_run(self.bm.enter_room(sid, '/foo', 'bar'))
_run(self.bm.close_room('bar', '/foo'))
from pprint import pprint
pprint(self.bm.rooms)
assert 'bar' not in self.bm.rooms['/foo']

def test_close_invalid_room(self):
self.bm.close_room('bar', '/foo')

def test_rooms(self):
sid = self.bm.connect('123', '/foo')
self.bm.enter_room(sid, '/foo', 'bar')
_run(self.bm.enter_room(sid, '/foo', 'bar'))
r = self.bm.get_rooms(sid, '/foo')
assert len(r) == 2
assert sid in r
Expand All @@ -216,9 +218,9 @@ def test_emit_to_sid(self):

def test_emit_to_room(self):
sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
_run(self.bm.enter_room(sid1, '/foo', 'bar'))
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid2, '/foo', 'bar')
_run(self.bm.enter_room(sid2, '/foo', 'bar'))
self.bm.connect('789', '/foo')
_run(
self.bm.emit(
Expand All @@ -237,12 +239,12 @@ def test_emit_to_room(self):

def test_emit_to_rooms(self):
sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
_run(self.bm.enter_room(sid1, '/foo', 'bar'))
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid2, '/foo', 'bar')
self.bm.enter_room(sid2, '/foo', 'baz')
_run(self.bm.enter_room(sid2, '/foo', 'bar'))
_run(self.bm.enter_room(sid2, '/foo', 'baz'))
sid3 = self.bm.connect('789', '/foo')
self.bm.enter_room(sid3, '/foo', 'baz')
_run(self.bm.enter_room(sid3, '/foo', 'baz'))
_run(
self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo',
room=['bar', 'baz'])
Expand All @@ -263,9 +265,9 @@ def test_emit_to_rooms(self):

def test_emit_to_all(self):
sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
_run(self.bm.enter_room(sid1, '/foo', 'bar'))
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid2, '/foo', 'bar')
_run(self.bm.enter_room(sid2, '/foo', 'bar'))
self.bm.connect('789', '/foo')
self.bm.connect('abc', '/bar')
_run(self.bm.emit('my event', {'foo': 'bar'}, namespace='/foo'))
Expand All @@ -285,9 +287,9 @@ def test_emit_to_all(self):

def test_emit_to_all_skip_one(self):
sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
_run(self.bm.enter_room(sid1, '/foo', 'bar'))
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid2, '/foo', 'bar')
_run(self.bm.enter_room(sid2, '/foo', 'bar'))
self.bm.connect('789', '/foo')
self.bm.connect('abc', '/bar')
_run(
Expand All @@ -307,9 +309,9 @@ def test_emit_to_all_skip_one(self):

def test_emit_to_all_skip_two(self):
sid1 = self.bm.connect('123', '/foo')
self.bm.enter_room(sid1, '/foo', 'bar')
_run(self.bm.enter_room(sid1, '/foo', 'bar'))
sid2 = self.bm.connect('456', '/foo')
self.bm.enter_room(sid2, '/foo', 'bar')
_run(self.bm.enter_room(sid2, '/foo', 'bar'))
sid3 = self.bm.connect('789', '/foo')
self.bm.connect('abc', '/bar')
_run(
Expand Down
24 changes: 14 additions & 10 deletions tests/asyncio/test_asyncio_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,25 +176,29 @@ def test_call(self):

def test_enter_room(self):
ns = asyncio_namespace.AsyncNamespace('/foo')
ns._set_server(mock.MagicMock())
ns.enter_room('sid', 'room')
ns.server.enter_room.assert_called_with(
mock_server = mock.MagicMock()
mock_server.enter_room = AsyncMock()
ns._set_server(mock_server)
_run(ns.enter_room('sid', 'room'))
ns.server.enter_room.mock.assert_called_with(
'sid', 'room', namespace='/foo'
)
ns.enter_room('sid', 'room', namespace='/bar')
ns.server.enter_room.assert_called_with(
_run(ns.enter_room('sid', 'room', namespace='/bar'))
ns.server.enter_room.mock.assert_called_with(
'sid', 'room', namespace='/bar'
)

def test_leave_room(self):
ns = asyncio_namespace.AsyncNamespace('/foo')
ns._set_server(mock.MagicMock())
ns.leave_room('sid', 'room')
ns.server.leave_room.assert_called_with(
mock_server = mock.MagicMock()
mock_server.leave_room = AsyncMock()
ns._set_server(mock_server)
_run(ns.leave_room('sid', 'room'))
ns.server.leave_room.mock.assert_called_with(
'sid', 'room', namespace='/foo'
)
ns.leave_room('sid', 'room', namespace='/bar')
ns.server.leave_room.assert_called_with(
_run(ns.leave_room('sid', 'room', namespace='/bar'))
ns.server.leave_room.mock.assert_called_with(
'sid', 'room', namespace='/bar'
)

Expand Down
Loading

0 comments on commit ab33cb7

Please sign in to comment.