Skip to content

Commit

Permalink
Allow functions to be used for URL, headers and auth data in client c…
Browse files Browse the repository at this point in the history
…onnection (Fixes #588)
  • Loading branch information
miguelgrinberg committed May 4, 2021
1 parent 2538df8 commit 7d2e7f7
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 10 deletions.
29 changes: 24 additions & 5 deletions socketio/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,19 @@ async def connect(self, url, headers={}, auth=None, transports=None,
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
query string parameters if required by the server.
query string parameters if required by the server. If a
function is provided, the client will invoke it to obtain
the URL each time a connection or reconnection is
attempted.
:param headers: A dictionary with custom headers to send with the
connection request.
connection request. If a function is provided, the
client will invoke it to obtain the headers dictionary
each time a connection or reconnection is attempted.
:param auth: Authentication data passed to the server with the
connection request, normally a dictionary with one or
more string key/value pairs.
more string key/value pairs. If a function is provided,
the client will invoke it to obtain the authentication
data each time a connection or reconnection is attempted.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
Expand Down Expand Up @@ -124,8 +131,10 @@ async def connect(self, url, headers={}, auth=None, transports=None,
self._connect_event = self.eio.create_event()
else:
self._connect_event.clear()
real_url = await self._get_real_value(self.connection_url)
real_headers = await self._get_real_value(self.connection_headers)
try:
await self.eio.connect(url, headers=headers,
await self.eio.connect(real_url, headers=real_headers,
transports=transports,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
Expand Down Expand Up @@ -320,6 +329,15 @@ async def sleep(self, seconds=0):
"""
return await self.eio.sleep(seconds)

async def _get_real_value(self, value):
"""Return the actual value, for parameters that can also be given as
callables."""
if not callable(value):
return value
if asyncio.iscoroutinefunction(value):
return await value()
return value()

async def _send_packet(self, pkt):
"""Send a Socket.IO packet to the server."""
encoded_packet = pkt.encode()
Expand Down Expand Up @@ -462,9 +480,10 @@ async def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
real_auth = await self._get_real_value(self.connection_auth)
for n in self.connection_namespaces:
await self._send_packet(packet.Packet(
packet.CONNECT, data=self.connection_auth, namespace=n))
packet.CONNECT, data=real_auth, namespace=n))

async def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""
Expand Down
28 changes: 23 additions & 5 deletions socketio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,19 @@ def connect(self, url, headers={}, auth=None, transports=None,
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
query string parameters if required by the server.
query string parameters if required by the server. If a
function is provided, the client will invoke it to obtain
the URL each time a connection or reconnection is
attempted.
:param headers: A dictionary with custom headers to send with the
connection request.
connection request. If a function is provided, the
client will invoke it to obtain the headers dictionary
each time a connection or reconnection is attempted.
:param auth: Authentication data passed to the server with the
connection request, normally a dictionary with one or
more string key/value pairs.
more string key/value pairs. If a function is provided,
the client will invoke it to obtain the authentication
data each time a connection or reconnection is attempted.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
Expand Down Expand Up @@ -294,8 +301,11 @@ def connect(self, url, headers={}, auth=None, transports=None,
self._connect_event = self.eio.create_event()
else:
self._connect_event.clear()
real_url = self._get_real_value(self.connection_url)
real_headers = self._get_real_value(self.connection_headers)
try:
self.eio.connect(url, headers=headers, transports=transports,
self.eio.connect(real_url, headers=real_headers,
transports=transports,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
self._trigger_event(
Expand Down Expand Up @@ -490,6 +500,13 @@ def sleep(self, seconds=0):
"""
return self.eio.sleep(seconds)

def _get_real_value(self, value):
"""Return the actual value, for parameters that can also be given as
callables."""
if not callable(value):
return value
return value()

def _send_packet(self, pkt):
"""Send a Socket.IO packet to the server."""
encoded_packet = pkt.encode()
Expand Down Expand Up @@ -628,9 +645,10 @@ def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
real_auth = self._get_real_value(self.connection_auth)
for n in self.connection_namespaces:
self._send_packet(packet.Packet(
packet.CONNECT, data=self.connection_auth, namespace=n))
packet.CONNECT, data=real_auth, namespace=n))

def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""
Expand Down
47 changes: 47 additions & 0 deletions tests/asyncio/test_asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,30 @@ def test_connect(self):
engineio_path='path',
)

def test_connect_functions(self):
async def headers():
return 'headers'

c = asyncio_client.AsyncClient()
c.eio.connect = AsyncMock()
_run(
c.connect(
lambda: 'url',
headers=headers,
auth='auth',
transports='transports',
namespaces=['/foo', '/', '/bar'],
socketio_path='path',
wait=False,
)
)
c.eio.connect.mock.assert_called_once_with(
'url',
headers='headers',
transports='transports',
engineio_path='path',
)

def test_connect_one_namespace(self):
c = asyncio_client.AsyncClient()
c.eio.connect = AsyncMock()
Expand Down Expand Up @@ -960,6 +984,29 @@ def test_handle_eio_connect(self):
== expected_packet.encode()
)

def test_handle_eio_connect_function(self):
c = asyncio_client.AsyncClient()
c.connection_namespaces = ['/', '/foo']
c.connection_auth = lambda: 'auth'
c._send_packet = AsyncMock()
c.eio.sid = 'foo'
assert c.sid is None
_run(c._handle_eio_connect())
assert c.sid == 'foo'
assert c._send_packet.mock.call_count == 2
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/')
assert (
c._send_packet.mock.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/foo')
assert (
c._send_packet.mock.call_args_list[1][0][0].encode()
== expected_packet.encode()
)

def test_handle_eio_message(self):
c = asyncio_client.AsyncClient()
c._handle_connect = AsyncMock()
Expand Down
42 changes: 42 additions & 0 deletions tests/common/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,25 @@ def test_connect(self):
engineio_path='path',
)

def test_connect_functions(self):
c = client.Client()
c.eio.connect = mock.MagicMock()
c.connect(
lambda: 'url',
headers=lambda: 'headers',
auth='auth',
transports='transports',
namespaces=['/foo', '/', '/bar'],
socketio_path='path',
wait=False,
)
c.eio.connect.assert_called_once_with(
'url',
headers='headers',
transports='transports',
engineio_path='path',
)

def test_connect_one_namespace(self):
c = client.Client()
c.eio.connect = mock.MagicMock()
Expand Down Expand Up @@ -1030,6 +1049,29 @@ def test_handle_eio_connect(self):
== expected_packet.encode()
)

def test_handle_eio_connect_function(self):
c = client.Client()
c.connection_namespaces = ['/', '/foo']
c.connection_auth = lambda: 'auth'
c._send_packet = mock.MagicMock()
c.eio.sid = 'foo'
assert c.sid is None
c._handle_eio_connect()
assert c.sid == 'foo'
assert c._send_packet.call_count == 2
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/')
assert (
c._send_packet.call_args_list[0][0][0].encode()
== expected_packet.encode()
)
expected_packet = packet.Packet(
packet.CONNECT, data='auth', namespace='/foo')
assert (
c._send_packet.call_args_list[1][0][0].encode()
== expected_packet.encode()
)

def test_handle_eio_message(self):
c = client.Client()
c._handle_connect = mock.MagicMock()
Expand Down

0 comments on commit 7d2e7f7

Please sign in to comment.