Skip to content

Commit

Permalink
address comments from PR #3378: clear caches + add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fselmo committed May 9, 2024
1 parent ffef012 commit bc5601b
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 11 deletions.
30 changes: 30 additions & 0 deletions tests/core/providers/test_async_ipc_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,36 @@ def test_get_endpoint_uri_or_ipc_path_returns_ipc_path():
# -- async -- #


@pytest.mark.asyncio
async def test_disconnect_cleanup(
simple_ipc_server,
jsonrpc_ipc_pipe_path,
):
w3 = await AsyncWeb3(AsyncIPCProvider(pathlib.Path(jsonrpc_ipc_pipe_path)))
provider = w3.provider

assert provider._message_listener_task is not None
assert provider._reader is not None
assert provider._writer is not None

# put some items in each cache
provider._request_processor._request_response_cache.cache("0", "0x1337")
provider._request_processor._request_information_cache.cache("0", "0x1337")
provider._request_processor._subscription_response_queue.put_nowait({"id": "0"})
assert len(provider._request_processor._request_response_cache) == 1
assert len(provider._request_processor._request_information_cache) == 1
assert provider._request_processor._subscription_response_queue.qsize() == 1

await w3.provider.disconnect()

assert not provider._message_listener_task
assert not w3.provider._reader
assert not w3.provider._writer
assert len(provider._request_processor._request_response_cache) == 0
assert len(provider._request_processor._request_information_cache) == 0
assert provider._request_processor._subscription_response_queue.empty()


async def _raise_connection_closed(*_args, **_kwargs):
raise ConnectionClosed(None, None)

Expand Down
51 changes: 41 additions & 10 deletions tests/core/providers/test_websocket_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@

def _mock_ws(provider):
provider._ws = AsyncMock()
provider._ws.closed = False


async def _coro():
return None
async def _mocked_ws_conn():
_conn = AsyncMock()
_conn.closed = False
return _conn


class WSException(Exception):
Expand All @@ -58,13 +61,42 @@ def test_get_endpoint_uri_or_ipc_path_returns_endpoint_uri():
# -- async -- #


@pytest.mark.asyncio
async def test_disconnect_cleanup():
provider = WebSocketProvider("ws://mocked")

with patch(
"web3.providers.persistent.websocket.connect",
new=lambda *_1, **_2: _mocked_ws_conn(),
):
await provider.connect()

assert provider._ws is not None
assert provider._message_listener_task is not None

# put some items in each cache
provider._request_processor._request_response_cache.cache("0", "0x1337")
provider._request_processor._request_information_cache.cache("0", "0x1337")
provider._request_processor._subscription_response_queue.put_nowait({"id": "0"})
assert len(provider._request_processor._request_response_cache) == 1
assert len(provider._request_processor._request_information_cache) == 1
assert provider._request_processor._subscription_response_queue.qsize() == 1

await provider.disconnect()

assert provider._ws is None
assert len(provider._request_processor._request_response_cache) == 0
assert len(provider._request_processor._request_information_cache) == 0
assert provider._request_processor._subscription_response_queue.empty()


@pytest.mark.asyncio
async def test_async_make_request_returns_desired_response():
provider = WebSocketProvider("ws://mocked")

with patch(
"web3.providers.persistent.websocket.connect",
new=lambda *_1, **_2: _coro(),
new=lambda *_1, **_2: _mocked_ws_conn(),
):
await provider.connect()

Expand Down Expand Up @@ -125,15 +157,15 @@ async def test_async_make_request_times_out_of_while_loop_looking_for_response()


@pytest.mark.asyncio
async def test_msg_listener_task_starts_on_provider_connect_and_cancels_on_disconnect():
async def test_msg_listener_task_starts_on_provider_connect_and_clears_on_disconnect():
provider = WebSocketProvider("ws://mocked")
_mock_ws(provider)

assert provider._message_listener_task is None

with patch(
"web3.providers.persistent.websocket.connect",
new=lambda *_1, **_2: _coro(),
new=lambda *_1, **_2: _mocked_ws_conn(),
):
await provider.connect() # connect

Expand All @@ -142,8 +174,7 @@ async def test_msg_listener_task_starts_on_provider_connect_and_cancels_on_disco

await provider.disconnect() # disconnect

assert provider._message_listener_task.cancelled()
assert provider._message_listener_task.done()
assert not provider._message_listener_task


@pytest.mark.asyncio
Expand All @@ -153,7 +184,7 @@ async def test_msg_listener_task_raises_exceptions_by_default():

with patch(
"web3.providers.persistent.websocket.connect",
new=lambda *_1, **_2: _coro(),
new=lambda *_1, **_2: _mocked_ws_conn(),
):
await provider.connect()
assert provider._message_listener_task is not None
Expand All @@ -177,7 +208,7 @@ async def test_msg_listener_task_silences_exceptions_and_error_logs_when_configu

with patch(
"web3.providers.persistent.websocket.connect",
new=lambda *_1, **_2: _coro(),
new=lambda *_1, **_2: _mocked_ws_conn(),
):
await provider.connect()
assert provider._message_listener_task is not None
Expand Down Expand Up @@ -210,7 +241,7 @@ async def test_listen_event_awaits_msg_processing_when_subscription_queue_is_ful
"""
with patch(
"web3.providers.persistent.websocket.connect",
new=lambda *_1, **_2: _coro(),
new=lambda *_1, **_2: _mocked_ws_conn(),
):
async_w3 = await AsyncWeb3(WebSocketProvider("ws://mocked"))

Expand Down
2 changes: 1 addition & 1 deletion web3/_utils/module_testing/module_testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def __anext__(self) -> bytes:
return self.messages.popleft()

@staticmethod
async def pong() -> bool:
async def pong() -> Literal[False]:
return False

async def connect(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions web3/providers/persistent/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ async def disconnect(self) -> None:
except (asyncio.CancelledError, StopAsyncIteration, ConnectionClosed):
pass
finally:
self._message_listener_task = None
self.logger.info("Message listener background task successfully shut down.")

await self._provider_specific_disconnect()
self._request_processor.clear_caches()
self.logger.info(
f"Successfully disconnected from: {self.get_endpoint_uri_or_ipc_path()}"
)
Expand Down

0 comments on commit bc5601b

Please sign in to comment.