Skip to content

Commit

Permalink
Optimize 'receive_json' for the happy path
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa committed Mar 6, 2022
1 parent b660381 commit c5492b9
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 57 deletions.
1 change: 1 addition & 0 deletions changes/1058.feature.md
@@ -0,0 +1 @@
Optimize receiving websocket JSON for the happy path.
86 changes: 47 additions & 39 deletions hikari/impl/shard.py
Expand Up @@ -177,53 +177,61 @@ async def send_json(
await self.send_str(pl, compress)

async def _receive_and_check(self, timeout: typing.Optional[float], /) -> str:
buff = bytearray()
message = await self.receive(timeout)

if message.type == aiohttp.WSMsgType.TEXT:
assert isinstance(message.data, str)
return message.data

elif message.type == aiohttp.WSMsgType.BINARY:
return await self._receive_and_check_complete_zlib_package(message.data, timeout)

elif message.type == aiohttp.WSMsgType.CLOSE:
close_code = int(message.data)
reason = message.extra
self.logger.error("connection closed with code %s (%s)", close_code, reason)

can_reconnect = close_code < 4000 or close_code in (
errors.ShardCloseCode.UNKNOWN_ERROR,
errors.ShardCloseCode.DECODE_ERROR,
errors.ShardCloseCode.INVALID_SEQ,
errors.ShardCloseCode.SESSION_TIMEOUT,
errors.ShardCloseCode.RATE_LIMITED,
)

while True:
message = await self.receive(timeout)
raise errors.GatewayServerClosedConnectionError(reason, close_code, can_reconnect)

if message.type == aiohttp.WSMsgType.CLOSE:
close_code = int(message.data)
reason = message.extra
self.logger.error("connection closed with code %s (%s)", close_code, reason)

can_reconnect = close_code < 4000 or close_code in (
errors.ShardCloseCode.UNKNOWN_ERROR,
errors.ShardCloseCode.DECODE_ERROR,
errors.ShardCloseCode.INVALID_SEQ,
errors.ShardCloseCode.SESSION_TIMEOUT,
errors.ShardCloseCode.RATE_LIMITED,
)
elif message.type == aiohttp.WSMsgType.CLOSING or message.type == aiohttp.WSMsgType.CLOSED:
# May be caused by the server shutting us down.
# May be caused by Windows injecting an EOF if something disconnects, as some
# network drivers appear to do this.
raise errors.GatewayConnectionError("Socket has closed")

raise errors.GatewayServerClosedConnectionError(reason, close_code, can_reconnect)
else:
# Assume exception for now.
ex = self.exception()
self.logger.warning(
"encountered unexpected error: %s",
ex,
exc_info=ex if self.logger.isEnabledFor(logging.DEBUG) else None,
)
raise errors.GatewayError("Unexpected websocket exception from gateway") from ex

elif message.type == aiohttp.WSMsgType.CLOSING or message.type == aiohttp.WSMsgType.CLOSED:
# May be caused by the server shutting us down.
# May be caused by Windows injecting an EOF if something disconnects, as some
# network drivers appear to do this.
raise errors.GatewayConnectionError("Socket has closed")
async def _receive_and_check_complete_zlib_package(
self, initial_data: bytes, timeout: typing.Optional[float], /
) -> str:
buff = bytearray(initial_data)

elif len(buff) != 0 and message.type != aiohttp.WSMsgType.BINARY:
raise errors.GatewayError(f"Unexpected message type received {message.type.name}, expected BINARY")

elif message.type == aiohttp.WSMsgType.BINARY:
buff.extend(message.data)
while True:
message = await self.receive(timeout)

if buff.endswith(b"\x00\x00\xff\xff"):
return self.zlib.decompress(buff).decode("utf-8")
if message.type != aiohttp.WSMsgType.BINARY:
raise errors.GatewayError(f"Unexpected message type received {message.type.name}, expected BINARY")

elif message.type == aiohttp.WSMsgType.TEXT:
return message.data # type: ignore
buff.extend(message.data)

else:
# Assume exception for now.
ex = self.exception()
self.logger.warning(
"encountered unexpected error: %s",
ex,
exc_info=ex if self.logger.isEnabledFor(logging.DEBUG) else None,
)
raise errors.GatewayError("Unexpected websocket exception from gateway") from ex
if buff.endswith(b"\x00\x00\xff\xff"):
return self.zlib.decompress(buff).decode("utf-8")

@classmethod
@contextlib.asynccontextmanager
Expand Down
52 changes: 34 additions & 18 deletions tests/hikari/impl/test_shard.py
Expand Up @@ -206,27 +206,17 @@ async def test__receive_and_check_when_message_type_is_CLOSED(self, transport_im

@pytest.mark.asyncio()
async def test__receive_and_check_when_message_type_is_BINARY(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some")
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff")
transport_impl.receive = mock.AsyncMock(side_effect=[response1, response2, response3])
transport_impl.zlib = mock.Mock(decompress=mock.Mock(return_value=b"utf-8 encoded bytes"))

assert await transport_impl._receive_and_check(10) == "utf-8 encoded bytes"
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data")
transport_impl.receive = mock.AsyncMock(return_value=response)
transport_impl._receive_and_check_complete_zlib_package = mock.AsyncMock()

transport_impl.receive.assert_awaited_with(10)
transport_impl.zlib.decompress.assert_called_once_with(bytearray(b"somedata\x00\x00\xff\xff"))

@pytest.mark.asyncio()
async def test__receive_and_check_when_buff_but_next_is_not_BINARY(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some")
response2 = StubResponse(type=aiohttp.WSMsgType.TEXT)
transport_impl.receive = mock.AsyncMock(side_effect=[response1, response2])

with pytest.raises(errors.GatewayError, match="Unexpected message type received TEXT, expected BINARY"):
assert (
await transport_impl._receive_and_check(10)
== transport_impl._receive_and_check_complete_zlib_package.return_value
)

transport_impl.receive.assert_awaited_with(10)
transport_impl.receive.assert_awaited_once_with(10)
transport_impl._receive_and_check_complete_zlib_package.assert_awaited_once_with(b"some initial data", 10)

@pytest.mark.asyncio()
async def test__receive_and_check_when_message_type_is_TEXT(self, transport_impl):
Expand All @@ -248,6 +238,32 @@ async def test__receive_and_check_when_message_type_is_unknown(self, transport_i

transport_impl.receive.assert_awaited_once_with(10)

@pytest.mark.asyncio()
async def test__receive_and_check_complete_zlib_package_when_not_BINARY(self, transport_impl):
response = StubResponse(type=aiohttp.WSMsgType.TEXT, data="not binary")
transport_impl.receive = mock.AsyncMock(return_value=response)

with pytest.raises(errors.GatewayError, match="Unexpected message type received TEXT, expected BINARY"):
await transport_impl._receive_and_check_complete_zlib_package(b"some", 10)

transport_impl.receive.assert_awaited_with(10)

@pytest.mark.asyncio()
async def test__receive_and_check_complete_zlib_package(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"more")
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff")
transport_impl.receive = mock.AsyncMock(side_effect=[response1, response2, response3])
transport_impl.zlib = mock.Mock(decompress=mock.Mock(return_value=b"decoded utf-8 encoded bytes"))

assert (
await transport_impl._receive_and_check_complete_zlib_package(b"some", 10) == "decoded utf-8 encoded bytes"
)

assert transport_impl.receive.call_count == 3
transport_impl.receive.assert_has_awaits([mock.call(10), mock.call(10), mock.call(10)])
transport_impl.zlib.decompress.assert_called_once_with(bytearray(b"somemoredata\x00\x00\xff\xff"))

@pytest.mark.asyncio()
async def test_connect_yields_websocket(self, http_settings, proxy_settings):
class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._GatewayTransport):
Expand Down

0 comments on commit c5492b9

Please sign in to comment.