diff --git a/packages/stompman/stompman/connection_manager.py b/packages/stompman/stompman/connection_manager.py index 0ee1258..4b4a464 100644 --- a/packages/stompman/stompman/connection_manager.py +++ b/packages/stompman/stompman/connection_manager.py @@ -55,10 +55,12 @@ class ConnectionManager: _reconnect_lock: asyncio.Lock = field(init=False, default_factory=asyncio.Lock) _task_group: asyncio.TaskGroup = field(init=False, default_factory=asyncio.TaskGroup) _send_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False) + _check_server_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False) async def __aenter__(self) -> Self: await self._task_group.__aenter__() self._send_heartbeat_task = self._task_group.create_task(asyncio.sleep(0)) + self._check_server_heartbeat_task = self._task_group.create_task(asyncio.sleep(0)) self._active_connection_state = await self._get_active_connection_state(is_initial_call=True) return self @@ -66,7 +68,8 @@ async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None ) -> None: self._send_heartbeat_task.cancel() - await asyncio.wait([self._send_heartbeat_task]) + self._check_server_heartbeat_task.cancel() + await asyncio.wait([self._send_heartbeat_task, self._check_server_heartbeat_task]) await self._task_group.__aexit__(exc_type, exc_value, traceback) if not self._active_connection_state: @@ -77,11 +80,15 @@ async def __aexit__( return await self._active_connection_state.connection.close() - def _restart_heartbeat_task(self, server_heartbeat: Heartbeat) -> None: + def _restart_heartbeat_tasks(self, server_heartbeat: Heartbeat) -> None: self._send_heartbeat_task.cancel() + self._check_server_heartbeat_task.cancel() self._send_heartbeat_task = self._task_group.create_task( self._send_heartbeats_forever(server_heartbeat.want_to_receive_interval_ms) ) + self._check_server_heartbeat_task = self._task_group.create_task( + self._check_server_heartbeat_forever(server_heartbeat.will_send_interval_ms) + ) async def _send_heartbeats_forever(self, send_heartbeat_interval_ms: int) -> None: send_heartbeat_interval_seconds = send_heartbeat_interval_ms / 1000 @@ -89,6 +96,15 @@ async def _send_heartbeats_forever(self, send_heartbeat_interval_ms: int) -> Non await self.write_heartbeat_reconnecting() await asyncio.sleep(send_heartbeat_interval_seconds) + async def _check_server_heartbeat_forever(self, receive_heartbeat_interval_ms: int) -> None: + receive_heartbeat_interval_seconds = receive_heartbeat_interval_ms / 1000 + while True: + await asyncio.sleep(receive_heartbeat_interval_seconds * self.check_server_alive_interval_factor) + if not self._active_connection_state: + continue + if not self._active_connection_state.is_alive(self.check_server_alive_interval_factor): + self._clear_active_connection_state() + async def _create_connection_to_one_server( self, server: ConnectionParameters ) -> tuple[AbstractConnection, ConnectionParameters] | None: @@ -119,7 +135,7 @@ async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionI lifespan = self.lifespan_factory( connection=connection, connection_parameters=connection_parameters, - set_heartbeat_interval=self._restart_heartbeat_task, + set_heartbeat_interval=self._restart_heartbeat_tasks, ) try: diff --git a/packages/stompman/test_stompman/test_connection_lifespan.py b/packages/stompman/test_stompman/test_connection_lifespan.py index 3b48095..09d73f7 100644 --- a/packages/stompman/test_stompman/test_connection_lifespan.py +++ b/packages/stompman/test_stompman/test_connection_lifespan.py @@ -148,7 +148,7 @@ async def mock_sleep(delay: float) -> None: async with EnrichedClient(connection_class=connection_class): await real_sleep(0) - assert sleep_calls == [0, 1, 1, 1] + assert sleep_calls == [0, 0, 1, 3, 1, 3, 1, 3] assert write_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call(), mock.call()] diff --git a/packages/stompman/test_stompman/test_connection_manager.py b/packages/stompman/test_stompman/test_connection_manager.py index ff51aff..533ad42 100644 --- a/packages/stompman/test_stompman/test_connection_manager.py +++ b/packages/stompman/test_stompman/test_connection_manager.py @@ -143,12 +143,12 @@ async def test_get_active_connection_state_lifespan_flaky_ok() -> None: mock.call( connection=BaseMockConnection(), connection_parameters=manager.servers[0], - set_heartbeat_interval=manager._restart_heartbeat_task, + set_heartbeat_interval=manager._restart_heartbeat_tasks, ), mock.call( connection=BaseMockConnection(), connection_parameters=manager.servers[0], - set_heartbeat_interval=manager._restart_heartbeat_task, + set_heartbeat_interval=manager._restart_heartbeat_tasks, ), ]