Skip to content

Commit

Permalink
Optimize 'receive_json' for the happy path
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa committed Mar 6, 2022
1 parent b660381 commit e39deac
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 57 deletions.
81 changes: 43 additions & 38 deletions hikari/impl/shard.py
Expand Up @@ -177,53 +177,58 @@ async def send_json(
await self.send_str(pl, compress)

async def _receive_and_check(self, timeout: typing.Optional[float], /) -> str:
buff = bytearray()
message = await self.receive(timeout)

while True:
message = await self.receive(timeout)
if message.type == aiohttp.WSMsgType.TEXT:
return message.data # type: ignore

if message.type == aiohttp.WSMsgType.CLOSE:
close_code = int(message.data)
reason = message.extra
self.logger.error("connection closed with code %s (%s)", close_code, reason)

can_reconnect = close_code < 4000 or close_code in (
errors.ShardCloseCode.UNKNOWN_ERROR,
errors.ShardCloseCode.DECODE_ERROR,
errors.ShardCloseCode.INVALID_SEQ,
errors.ShardCloseCode.SESSION_TIMEOUT,
errors.ShardCloseCode.RATE_LIMITED,
)
elif message.type == aiohttp.WSMsgType.BINARY:
return await self._receive_complete_package(message.data, timeout)

raise errors.GatewayServerClosedConnectionError(reason, close_code, can_reconnect)
elif message.type == aiohttp.WSMsgType.CLOSE:
close_code = int(message.data)
reason = message.extra
self.logger.error("connection closed with code %s (%s)", close_code, reason)

elif message.type == aiohttp.WSMsgType.CLOSING or message.type == aiohttp.WSMsgType.CLOSED:
# May be caused by the server shutting us down.
# May be caused by Windows injecting an EOF if something disconnects, as some
# network drivers appear to do this.
raise errors.GatewayConnectionError("Socket has closed")
can_reconnect = close_code < 4000 or close_code in (
errors.ShardCloseCode.UNKNOWN_ERROR,
errors.ShardCloseCode.DECODE_ERROR,
errors.ShardCloseCode.INVALID_SEQ,
errors.ShardCloseCode.SESSION_TIMEOUT,
errors.ShardCloseCode.RATE_LIMITED,
)

elif len(buff) != 0 and message.type != aiohttp.WSMsgType.BINARY:
raise errors.GatewayError(f"Unexpected message type received {message.type.name}, expected BINARY")
raise errors.GatewayServerClosedConnectionError(reason, close_code, can_reconnect)

elif message.type == aiohttp.WSMsgType.BINARY:
buff.extend(message.data)
elif message.type == aiohttp.WSMsgType.CLOSING or message.type == aiohttp.WSMsgType.CLOSED:
# May be caused by the server shutting us down.
# May be caused by Windows injecting an EOF if something disconnects, as some
# network drivers appear to do this.
raise errors.GatewayConnectionError("Socket has closed")

if buff.endswith(b"\x00\x00\xff\xff"):
return self.zlib.decompress(buff).decode("utf-8")
else:
# Assume exception for now.
ex = self.exception()
self.logger.warning(
"encountered unexpected error: %s",
ex,
exc_info=ex if self.logger.isEnabledFor(logging.DEBUG) else None,
)
raise errors.GatewayError("Unexpected websocket exception from gateway") from ex

elif message.type == aiohttp.WSMsgType.TEXT:
return message.data # type: ignore
async def _receive_complete_package(self, initial_data: bytes, timeout: typing.Optional[float], /) -> str:
buff = bytearray(initial_data)

else:
# Assume exception for now.
ex = self.exception()
self.logger.warning(
"encountered unexpected error: %s",
ex,
exc_info=ex if self.logger.isEnabledFor(logging.DEBUG) else None,
)
raise errors.GatewayError("Unexpected websocket exception from gateway") from ex
while True:
message = await self.receive(timeout)

if message.type != aiohttp.WSMsgType.BINARY:
raise errors.GatewayError(f"Unexpected message type received {message.type.name}, expected BINARY")

buff.extend(message.data)

if buff.endswith(b"\x00\x00\xff\xff"):
return self.zlib.decompress(buff).decode("utf-8")

@classmethod
@contextlib.asynccontextmanager
Expand Down
49 changes: 30 additions & 19 deletions tests/hikari/impl/test_shard.py
Expand Up @@ -206,27 +206,14 @@ async def test__receive_and_check_when_message_type_is_CLOSED(self, transport_im

@pytest.mark.asyncio()
async def test__receive_and_check_when_message_type_is_BINARY(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some")
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff")
transport_impl.receive = mock.AsyncMock(side_effect=[response1, response2, response3])
transport_impl.zlib = mock.Mock(decompress=mock.Mock(return_value=b"utf-8 encoded bytes"))

assert await transport_impl._receive_and_check(10) == "utf-8 encoded bytes"
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data")
transport_impl.receive = mock.AsyncMock(return_value=response)
transport_impl._receive_complete_package = mock.AsyncMock()

transport_impl.receive.assert_awaited_with(10)
transport_impl.zlib.decompress.assert_called_once_with(bytearray(b"somedata\x00\x00\xff\xff"))

@pytest.mark.asyncio()
async def test__receive_and_check_when_buff_but_next_is_not_BINARY(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some")
response2 = StubResponse(type=aiohttp.WSMsgType.TEXT)
transport_impl.receive = mock.AsyncMock(side_effect=[response1, response2])
assert await transport_impl._receive_and_check(10) == transport_impl._receive_complete_package.return_value

with pytest.raises(errors.GatewayError, match="Unexpected message type received TEXT, expected BINARY"):
await transport_impl._receive_and_check(10)

transport_impl.receive.assert_awaited_with(10)
transport_impl.receive.assert_awaited_once_with(10)
transport_impl._receive_complete_package.assert_awaited_once_with(b"some initial data", 10)

@pytest.mark.asyncio()
async def test__receive_and_check_when_message_type_is_TEXT(self, transport_impl):
Expand All @@ -248,6 +235,30 @@ async def test__receive_and_check_when_message_type_is_unknown(self, transport_i

transport_impl.receive.assert_awaited_once_with(10)

@pytest.mark.asyncio()
async def test__receive_complete_package_when_not_BINARY(self, transport_impl):
response = StubResponse(type=aiohttp.WSMsgType.TEXT, data="not binary")
transport_impl.receive = mock.AsyncMock(return_value=response)

with pytest.raises(errors.GatewayError, match="Unexpected message type received TEXT, expected BINARY"):
await transport_impl._receive_complete_package(b"some", 10)

transport_impl.receive.assert_awaited_with(10)

@pytest.mark.asyncio()
async def test__receive_complete_package(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"more")
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff")
transport_impl.receive = mock.AsyncMock(side_effect=[response1, response2, response3])
transport_impl.zlib = mock.Mock(decompress=mock.Mock(return_value=b"decoded utf-8 encoded bytes"))

assert await transport_impl._receive_complete_package(b"some", 10) == "decoded utf-8 encoded bytes"

assert transport_impl.receive.call_count == 3
transport_impl.receive.assert_has_awaits([mock.call(10), mock.call(10), mock.call(10)])
transport_impl.zlib.decompress.assert_called_once_with(bytearray(b"somemoredata\x00\x00\xff\xff"))

@pytest.mark.asyncio()
async def test_connect_yields_websocket(self, http_settings, proxy_settings):
class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._GatewayTransport):
Expand Down

0 comments on commit e39deac

Please sign in to comment.