Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/engineio/async_drivers/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ class WebSocket: # pragma: no cover
"""
def __init__(self, handler, server):
self.handler = handler
self.server = server
self._sock = None

async def __call__(self, environ):
request = environ['aiohttp.request']
self._sock = WebSocketResponse(max_msg_size=0)
self._sock = WebSocketResponse(
max_msg_size=self.server.max_http_buffer_size)
await self._sock.prepare(request)

self.environ = environ
Expand Down
25 changes: 12 additions & 13 deletions src/engineio/async_drivers/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,18 @@ def _ensure_trailing_slash(self, path):

async def translate_request(scope, receive, send):
class AwaitablePayload: # pragma: no cover
def __init__(self, payload):
self.payload = payload or b''
def __init__(self, event):
self.event = event
self.payload = None

async def read(self, length=None):
if self.payload is None and event['type'] == 'http.request':
# read payload from http request
self.payload = self.event.get('body') or b''
while self.event.get('more_body'):
self.event = await receive()
if self.event['type'] == 'http.request':
self.payload += self.event.get('body') or b''
if length is None:
r = self.payload
self.payload = b''
Expand All @@ -149,16 +157,7 @@ async def read(self, length=None):
return r

event = await receive()
payload = b''
if event['type'] == 'http.request':
payload += event.get('body') or b''
while event.get('more_body'):
event = await receive()
if event['type'] == 'http.request':
payload += event.get('body') or b''
elif event['type'] == 'websocket.connect':
pass
else:
if event['type'] not in ['http.request', 'websocket.connect']:
return {}

raw_uri = scope['path']
Expand All @@ -171,7 +170,7 @@ async def read(self, length=None):
else:
raw_uri += '?' + query_string
environ = {
'wsgi.input': AwaitablePayload(payload),
'wsgi.input': AwaitablePayload(event),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
Expand Down
3 changes: 2 additions & 1 deletion src/engineio/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,6 @@ async def _handle_connect(self, environ, transport, jsonp_index=None):
'maxPayload': self.max_http_buffer_size,
})
await s.send(pkt)
s.schedule_ping()

ret = await self._trigger_event('connect', sid, environ,
run_async=False)
Expand All @@ -471,6 +470,8 @@ async def _handle_connect(self, environ, transport, jsonp_index=None):
self.logger.warning('Application rejected connection')
return self._unauthorized(ret or None)

s.schedule_ping()

if transport == 'websocket':
ret = await s.handle_get_request(environ)
if s.closed and sid in self.sockets:
Expand Down
6 changes: 4 additions & 2 deletions src/engineio/async_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,12 @@ async def close(self, wait=True, abort=False, reason=None):
await self.queue.join()

def schedule_ping(self):
self.server.start_background_task(self._send_ping)
# only schedule a new ping if the previous ping wait cycle completed
if self.last_ping:
self.last_ping = None
self.server.start_background_task(self._send_ping)

async def _send_ping(self):
self.last_ping = None
await asyncio.sleep(self.server.ping_interval)
if not self.closing and not self.closed:
self.last_ping = time.time()
Expand Down
5 changes: 4 additions & 1 deletion src/engineio/base_socket.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import time


class BaseSocket:
upgrade_protocols = ['websocket']

def __init__(self, server, sid):
self.server = server
self.sid = sid
self.queue = self.server.create_queue()
self.last_ping = None
self.last_ping = time.time()
self.connected = False
self.upgrading = False
self.upgraded = False
Expand Down
3 changes: 2 additions & 1 deletion src/engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,6 @@ def _handle_connect(self, environ, start_response, transport,
'maxPayload': self.max_http_buffer_size,
})
s.send(pkt)
s.schedule_ping()

# NOTE: some sections below are marked as "no cover" to workaround
# what seems to be a bug in the coverage package. All the lines below
Expand All @@ -413,6 +412,8 @@ def _handle_connect(self, environ, start_response, transport,
self.logger.warning('Application rejected connection')
return self._unauthorized(ret or None)

s.schedule_ping()

if transport == 'websocket': # pragma: no cover
ret = s.handle_get_request(environ, start_response)
if s.closed and sid in self.sockets:
Expand Down
6 changes: 4 additions & 2 deletions src/engineio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,12 @@ def close(self, wait=True, abort=False, reason=None):
self.queue.join()

def schedule_ping(self):
self.server.start_background_task(self._send_ping)
# only schedule a new ping if the previous ping wait cycle completed
if self.last_ping:
self.last_ping = None
self.server.start_background_task(self._send_ping)

def _send_ping(self):
self.last_ping = None
self.server.sleep(self.server.ping_interval)
if not self.closing and not self.closed:
self.last_ping = time.time()
Expand Down
20 changes: 18 additions & 2 deletions tests/async/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,30 @@ async def test_schedule_ping(self):
s = async_socket.AsyncSocket(mock_server, 'sid')
s.send = mock.AsyncMock()

async def schedule_ping():
async def schedule_ping_and_sleep():
s.schedule_ping()
await asyncio.sleep(0.05)

await schedule_ping()
await schedule_ping_and_sleep()
assert s.last_ping is not None
assert s.send.await_args_list[0][0][0].encode() == '2'

async def test_schedule_ping_twice(self):
mock_server = self._get_mock_server()
mock_server.ping_interval = 0.01
s = async_socket.AsyncSocket(mock_server, 'sid')
s.send = mock.AsyncMock()

async def schedule_ping_and_sleep():
s.schedule_ping()
await asyncio.sleep(0.05)

s.schedule_ping()
assert s.last_ping is None
await schedule_ping_and_sleep()
assert s.last_ping is not None
assert s.send.await_count == 1

async def test_schedule_ping_closed_socket(self):
mock_server = self._get_mock_server()
mock_server.ping_interval = 0.01
Expand Down
14 changes: 14 additions & 0 deletions tests/common/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def _get_mock_server(self):
mock_server.ping_interval_grace_period = 0.001
mock_server.async_handlers = True
mock_server.max_http_buffer_size = 128
mock_server.sleep = time.sleep

try:
import queue
Expand Down Expand Up @@ -104,6 +105,19 @@ def test_schedule_ping(self):
assert s.last_ping is not None
assert s.send.call_args_list[0][0][0].encode() == '2'

def test_schedule_ping_twice(self):
mock_server = self._get_mock_server()
mock_server.ping_interval = 0.1
s = socket.Socket(mock_server, 'sid')
s.send = mock.MagicMock()
s.schedule_ping()
assert s.last_ping is None
time.sleep(0.01)
s.schedule_ping()
time.sleep(0.1)
assert s.last_ping is not None
assert s.send.call_count == 1

def test_schedule_ping_closed_socket(self):
mock_server = self._get_mock_server()
mock_server.ping_interval = 0.01
Expand Down
Loading