diff --git a/src/engineio/async_drivers/aiohttp.py b/src/engineio/async_drivers/aiohttp.py index 41131d8..72dab64 100644 --- a/src/engineio/async_drivers/aiohttp.py +++ b/src/engineio/async_drivers/aiohttp.py @@ -79,11 +79,13 @@ class WebSocket: # pragma: no cover """ def __init__(self, handler, server): self.handler = handler + self.server = server self._sock = None async def __call__(self, environ): request = environ['aiohttp.request'] - self._sock = WebSocketResponse(max_msg_size=0) + self._sock = WebSocketResponse( + max_msg_size=self.server.max_http_buffer_size) await self._sock.prepare(request) self.environ = environ diff --git a/src/engineio/async_drivers/asgi.py b/src/engineio/async_drivers/asgi.py index c3581c6..9889047 100644 --- a/src/engineio/async_drivers/asgi.py +++ b/src/engineio/async_drivers/asgi.py @@ -136,10 +136,18 @@ def _ensure_trailing_slash(self, path): async def translate_request(scope, receive, send): class AwaitablePayload: # pragma: no cover - def __init__(self, payload): - self.payload = payload or b'' + def __init__(self, event): + self.event = event + self.payload = None async def read(self, length=None): + if self.payload is None and event['type'] == 'http.request': + # read payload from http request + self.payload = self.event.get('body') or b'' + while self.event.get('more_body'): + self.event = await receive() + if self.event['type'] == 'http.request': + self.payload += self.event.get('body') or b'' if length is None: r = self.payload self.payload = b'' @@ -149,16 +157,7 @@ async def read(self, length=None): return r event = await receive() - payload = b'' - if event['type'] == 'http.request': - payload += event.get('body') or b'' - while event.get('more_body'): - event = await receive() - if event['type'] == 'http.request': - payload += event.get('body') or b'' - elif event['type'] == 'websocket.connect': - pass - else: + if event['type'] not in ['http.request', 'websocket.connect']: return {} raw_uri = scope['path'] @@ -171,7 +170,7 @@ async def read(self, length=None): else: raw_uri += '?' + query_string environ = { - 'wsgi.input': AwaitablePayload(payload), + 'wsgi.input': AwaitablePayload(event), 'wsgi.errors': sys.stderr, 'wsgi.version': (1, 0), 'wsgi.async': True, diff --git a/src/engineio/async_server.py b/src/engineio/async_server.py index e0e252e..91339c3 100644 --- a/src/engineio/async_server.py +++ b/src/engineio/async_server.py @@ -462,7 +462,6 @@ async def _handle_connect(self, environ, transport, jsonp_index=None): 'maxPayload': self.max_http_buffer_size, }) await s.send(pkt) - s.schedule_ping() ret = await self._trigger_event('connect', sid, environ, run_async=False) @@ -471,6 +470,8 @@ async def _handle_connect(self, environ, transport, jsonp_index=None): self.logger.warning('Application rejected connection') return self._unauthorized(ret or None) + s.schedule_ping() + if transport == 'websocket': ret = await s.handle_get_request(environ) if s.closed and sid in self.sockets: diff --git a/src/engineio/async_socket.py b/src/engineio/async_socket.py index cfdbe1a..c1948ee 100644 --- a/src/engineio/async_socket.py +++ b/src/engineio/async_socket.py @@ -128,10 +128,12 @@ async def close(self, wait=True, abort=False, reason=None): await self.queue.join() def schedule_ping(self): - self.server.start_background_task(self._send_ping) + # only schedule a new ping if the previous ping wait cycle completed + if self.last_ping: + self.last_ping = None + self.server.start_background_task(self._send_ping) async def _send_ping(self): - self.last_ping = None await asyncio.sleep(self.server.ping_interval) if not self.closing and not self.closed: self.last_ping = time.time() diff --git a/src/engineio/base_socket.py b/src/engineio/base_socket.py index 6b5d7dc..9db39c6 100644 --- a/src/engineio/base_socket.py +++ b/src/engineio/base_socket.py @@ -1,3 +1,6 @@ +import time + + class BaseSocket: upgrade_protocols = ['websocket'] @@ -5,7 +8,7 @@ def __init__(self, server, sid): self.server = server self.sid = sid self.queue = self.server.create_queue() - self.last_ping = None + self.last_ping = time.time() self.connected = False self.upgrading = False self.upgraded = False diff --git a/src/engineio/server.py b/src/engineio/server.py index 917c53f..8117d60 100644 --- a/src/engineio/server.py +++ b/src/engineio/server.py @@ -401,7 +401,6 @@ def _handle_connect(self, environ, start_response, transport, 'maxPayload': self.max_http_buffer_size, }) s.send(pkt) - s.schedule_ping() # NOTE: some sections below are marked as "no cover" to workaround # what seems to be a bug in the coverage package. All the lines below @@ -413,6 +412,8 @@ def _handle_connect(self, environ, start_response, transport, self.logger.warning('Application rejected connection') return self._unauthorized(ret or None) + s.schedule_ping() + if transport == 'websocket': # pragma: no cover ret = s.handle_get_request(environ, start_response) if s.closed and sid in self.sockets: diff --git a/src/engineio/socket.py b/src/engineio/socket.py index 26bb94b..1ddd31d 100644 --- a/src/engineio/socket.py +++ b/src/engineio/socket.py @@ -130,10 +130,12 @@ def close(self, wait=True, abort=False, reason=None): self.queue.join() def schedule_ping(self): - self.server.start_background_task(self._send_ping) + # only schedule a new ping if the previous ping wait cycle completed + if self.last_ping: + self.last_ping = None + self.server.start_background_task(self._send_ping) def _send_ping(self): - self.last_ping = None self.server.sleep(self.server.ping_interval) if not self.closing and not self.closed: self.last_ping = time.time() diff --git a/tests/async/test_socket.py b/tests/async/test_socket.py index b6c4e60..ca7f862 100644 --- a/tests/async/test_socket.py +++ b/tests/async/test_socket.py @@ -95,14 +95,30 @@ async def test_schedule_ping(self): s = async_socket.AsyncSocket(mock_server, 'sid') s.send = mock.AsyncMock() - async def schedule_ping(): + async def schedule_ping_and_sleep(): s.schedule_ping() await asyncio.sleep(0.05) - await schedule_ping() + await schedule_ping_and_sleep() assert s.last_ping is not None assert s.send.await_args_list[0][0][0].encode() == '2' + async def test_schedule_ping_twice(self): + mock_server = self._get_mock_server() + mock_server.ping_interval = 0.01 + s = async_socket.AsyncSocket(mock_server, 'sid') + s.send = mock.AsyncMock() + + async def schedule_ping_and_sleep(): + s.schedule_ping() + await asyncio.sleep(0.05) + + s.schedule_ping() + assert s.last_ping is None + await schedule_ping_and_sleep() + assert s.last_ping is not None + assert s.send.await_count == 1 + async def test_schedule_ping_closed_socket(self): mock_server = self._get_mock_server() mock_server.ping_interval = 0.01 diff --git a/tests/common/test_socket.py b/tests/common/test_socket.py index a5e91cd..2197a14 100644 --- a/tests/common/test_socket.py +++ b/tests/common/test_socket.py @@ -21,6 +21,7 @@ def _get_mock_server(self): mock_server.ping_interval_grace_period = 0.001 mock_server.async_handlers = True mock_server.max_http_buffer_size = 128 + mock_server.sleep = time.sleep try: import queue @@ -104,6 +105,19 @@ def test_schedule_ping(self): assert s.last_ping is not None assert s.send.call_args_list[0][0][0].encode() == '2' + def test_schedule_ping_twice(self): + mock_server = self._get_mock_server() + mock_server.ping_interval = 0.1 + s = socket.Socket(mock_server, 'sid') + s.send = mock.MagicMock() + s.schedule_ping() + assert s.last_ping is None + time.sleep(0.01) + s.schedule_ping() + time.sleep(0.1) + assert s.last_ping is not None + assert s.send.call_count == 1 + def test_schedule_ping_closed_socket(self): mock_server = self._get_mock_server() mock_server.ping_interval = 0.01