diff --git a/src/engineio/asyncio_socket.py b/src/engineio/asyncio_socket.py index 508ee3c..17265c2 100644 --- a/src/engineio/asyncio_socket.py +++ b/src/engineio/asyncio_socket.py @@ -143,12 +143,18 @@ async def _upgrade_websocket(self, environ): async def _websocket_handler(self, ws): """Engine.IO handler for websocket transport.""" + async def websocket_wait(): + data = await ws.wait() + if data and len(data) > self.server.max_http_buffer_size: + raise ValueError('packet is too large') + return data + if self.connected: # the socket was already connected, so this is an upgrade self.upgrading = True # hold packet sends during the upgrade try: - pkt = await ws.wait() + pkt = await websocket_wait() except IOError: # pragma: no cover return decoded_pkt = packet.Packet(encoded_packet=pkt) @@ -162,7 +168,7 @@ async def _websocket_handler(self, ws): await self.queue.put(packet.Packet(packet.NOOP)) # end poll try: - pkt = await ws.wait() + pkt = await websocket_wait() except IOError: # pragma: no cover self.upgrading = False return @@ -204,7 +210,7 @@ async def writer(): while True: p = None - wait_task = asyncio.ensure_future(ws.wait()) + wait_task = asyncio.ensure_future(websocket_wait()) try: p = await asyncio.wait_for( wait_task, diff --git a/src/engineio/socket.py b/src/engineio/socket.py index 1434b19..be0c83f 100644 --- a/src/engineio/socket.py +++ b/src/engineio/socket.py @@ -159,6 +159,12 @@ def _upgrade_websocket(self, environ, start_response): def _websocket_handler(self, ws): """Engine.IO handler for websocket transport.""" + def websocket_wait(): + data = ws.wait() + if data and len(data) > self.server.max_http_buffer_size: + raise ValueError('packet is too large') + return data + # try to set a socket timeout matching the configured ping interval # and timeout for attr in ['_sock', 'socket']: # pragma: no cover @@ -170,7 +176,7 @@ def _websocket_handler(self, ws): # the socket was already connected, so this is an upgrade self.upgrading = True # hold packet sends during the upgrade - pkt = ws.wait() + pkt = websocket_wait() decoded_pkt = packet.Packet(encoded_packet=pkt) if decoded_pkt.packet_type != packet.PING or \ decoded_pkt.data != 'probe': @@ -181,7 +187,7 @@ def _websocket_handler(self, ws): ws.send(packet.Packet(packet.PONG, data='probe').encode()) self.queue.put(packet.Packet(packet.NOOP)) # end poll - pkt = ws.wait() + pkt = websocket_wait() decoded_pkt = packet.Packet(encoded_packet=pkt) if decoded_pkt.packet_type != packet.UPGRADE: self.upgraded = False @@ -221,7 +227,7 @@ def writer(): while True: p = None try: - p = ws.wait() + p = websocket_wait() except Exception as e: # if the socket is already closed, we can assume this is a # downstream error of that diff --git a/tests/asyncio/test_asyncio_socket.py b/tests/asyncio/test_asyncio_socket.py index 147c1b7..6af50e9 100644 --- a/tests/asyncio/test_asyncio_socket.py +++ b/tests/asyncio/test_asyncio_socket.py @@ -42,6 +42,7 @@ def _get_mock_server(self): mock_server.ping_interval = 0.2 mock_server.ping_interval_grace_period = 0.001 mock_server.async_handlers = False + mock_server.max_http_buffer_size = 128 mock_server._async = { 'asyncio': True, 'create_route': mock.MagicMock(), @@ -456,6 +457,23 @@ def test_websocket_read_write_wait_fail(self): _run(s._websocket_handler(ws)) assert s.closed + def test_websocket_upgrade_with_large_packet(self): + mock_server = self._get_mock_server() + s = asyncio_socket.AsyncSocket(mock_server, 'sid') + s.connected = True + s.queue.join = AsyncMock(return_value=None) + probe = 'probe' + ws = mock.MagicMock() + ws.send = AsyncMock() + ws.wait = AsyncMock() + ws.wait.mock.side_effect = [ + packet.Packet(packet.PING, data=probe).encode(), + packet.Packet(packet.UPGRADE, data='2' * 128).encode(), + ] + with pytest.raises(ValueError): + _run(s._websocket_handler(ws)) + assert not s.upgraded + def test_websocket_ignore_invalid_packet(self): mock_server = self._get_mock_server() s = asyncio_socket.AsyncSocket(mock_server, 'sid') diff --git a/tests/common/test_socket.py b/tests/common/test_socket.py index 33785e9..84e4317 100644 --- a/tests/common/test_socket.py +++ b/tests/common/test_socket.py @@ -21,6 +21,7 @@ def _get_mock_server(self): mock_server.ping_interval = 0.2 mock_server.ping_interval_grace_period = 0.001 mock_server.async_handlers = True + mock_server.max_http_buffer_size = 128 try: import queue @@ -444,6 +445,22 @@ def test_websocket_read_write_wait_fail(self): self._join_bg_tasks() assert s.closed + def test_websocket_upgrade_with_large_packet(self): + mock_server = self._get_mock_server() + s = socket.Socket(mock_server, 'sid') + s.connected = True + s.queue.join = mock.MagicMock(return_value=None) + probe = 'probe' + ws = mock.MagicMock() + ws.wait.side_effect = [ + packet.Packet(packet.PING, data=probe).encode(), + packet.Packet(packet.UPGRADE, data='2' * 128).encode(), + ] + with pytest.raises(ValueError): + s._websocket_handler(ws) + self._join_bg_tasks() + assert not s.upgraded + def test_websocket_ignore_invalid_packet(self): mock_server = self._get_mock_server() s = socket.Socket(mock_server, 'sid')