Skip to content

Commit

Permalink
Pass auth information sent by client to the connect handler
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 27, 2020
1 parent 3349b02 commit 11b6f1a
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 11 deletions.
8 changes: 5 additions & 3 deletions docs/server.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ The ``connect`` and ``disconnect`` events are special; they are invoked
automatically when a client connects or disconnects from the server::

@sio.event
def connect(sid, environ):
def connect(sid, environ, auth):
print('connect ', sid)

@sio.event
Expand All @@ -193,8 +193,10 @@ The ``connect`` event is an ideal place to perform user authentication, and
any necessary mapping between user entities in the application and the ``sid``
that was assigned to the client. The ``environ`` argument is a dictionary in
standard WSGI format containing the request information, including HTTP
headers. After inspecting the request, the connect event handler can return
``False`` to reject the connection with the client.
headers. The ``auth`` argument contains any authentication details passed by
the client, or ``None`` if the client did not pass anything. After inspecting
the request, the connect event handler can return ``False`` to reject the
connection with the client.

Sometimes it is useful to pass data back to the client being rejected. In that
case instead of returning ``False``
Expand Down
16 changes: 12 additions & 4 deletions socketio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ async def _send_packet(self, eio_sid, pkt):
else:
await self.eio.send(eio_sid, encoded_packet)

async def _handle_connect(self, eio_sid, namespace):
async def _handle_connect(self, eio_sid, namespace, data):
"""Handle a client connection request."""
namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace)
Expand All @@ -442,8 +442,16 @@ async def _handle_connect(self, eio_sid, namespace):
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args
try:
success = await self._trigger_event('connect', namespace, sid,
self.environ[eio_sid])
if data:
success = await self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], data)
else:
try:
success = await self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid])
except TypeError:
success = await self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], None)
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
Expand Down Expand Up @@ -552,7 +560,7 @@ async def _handle_eio_message(self, eio_sid, data):
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(eio_sid, pkt.namespace)
await self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:
await self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:
Expand Down
16 changes: 12 additions & 4 deletions socketio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def _send_packet(self, eio_sid, pkt):
else:
self.eio.send(eio_sid, encoded_packet)

def _handle_connect(self, eio_sid, namespace):
def _handle_connect(self, eio_sid, namespace, data):
"""Handle a client connection request."""
namespace = namespace or '/'
sid = self.manager.connect(eio_sid, namespace)
Expand All @@ -628,8 +628,16 @@ def _handle_connect(self, eio_sid, namespace):
packet.CONNECT, {'sid': sid}, namespace=namespace))
fail_reason = exceptions.ConnectionRefusedError().error_args
try:
success = self._trigger_event('connect', namespace, sid,
self.environ[eio_sid])
if data:
success = self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], data)
else:
try:
success = self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid])
except TypeError:
success = self._trigger_event(
'connect', namespace, sid, self.environ[eio_sid], None)
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
Expand Down Expand Up @@ -729,7 +737,7 @@ def _handle_eio_message(self, eio_sid, data):
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(eio_sid, pkt.namespace)
self._handle_connect(eio_sid, pkt.namespace, pkt.data)
elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(eio_sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:
Expand Down
32 changes: 32 additions & 0 deletions tests/asyncio/test_asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,38 @@ def test_handle_connect(self, eio):
_run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1

def test_handle_connect_with_auth(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock()
s.on('connect', handler)
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0{"token":"abc"}'))
assert s.manager.is_connected('1', '/')
handler.assert_called_once_with('1', 'environ', {'token': 'abc'})
s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
_run(s._handle_eio_connect('456', 'environ'))
_run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1

def test_handle_connect_with_auth_none(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock(side_effect=[TypeError, None, None])
s.on('connect', handler)
_run(s._handle_eio_connect('123', 'environ'))
_run(s._handle_eio_message('123', '0'))
assert s.manager.is_connected('1', '/')
handler.assert_called_with('1', 'environ', None)
s.eio.send.mock.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
_run(s._handle_eio_connect('456', 'environ'))
_run(s._handle_eio_message('456', '0'))
assert s.manager.initialize.call_count == 1

def test_handle_connect_async(self, eio):
eio.return_value.send = AsyncMock()
s = asyncio_server.AsyncServer()
Expand Down
28 changes: 28 additions & 0 deletions tests/common/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,34 @@ def test_handle_connect(self, eio):
s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1

def test_handle_connect_with_auth(self, eio):
s = server.Server()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock()
s.on('connect', handler)
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0{"token":"abc"}')
assert s.manager.is_connected('1', '/')
handler.assert_called_with('1', 'environ', {'token': 'abc'})
s.eio.send.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1

def test_handle_connect_with_auth_none(self, eio):
s = server.Server()
s.manager.initialize = mock.MagicMock()
handler = mock.MagicMock(side_effect=[TypeError, None])
s.on('connect', handler)
s._handle_eio_connect('123', 'environ')
s._handle_eio_message('123', '0')
assert s.manager.is_connected('1', '/')
handler.assert_called_with('1', 'environ', None)
s.eio.send.assert_called_once_with('123', '0{"sid":"1"}')
assert s.manager.initialize.call_count == 1
s._handle_eio_connect('456', 'environ')
assert s.manager.initialize.call_count == 1

def test_handle_connect_namespace(self, eio):
s = server.Server()
handler = mock.MagicMock()
Expand Down

0 comments on commit 11b6f1a

Please sign in to comment.