diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index d7df725..5fce7b8 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -8,6 +8,7 @@ import inspect import logging import os +import socket import ssl import warnings from contextlib import suppress @@ -34,6 +35,7 @@ ss58_encode, MultiAccountId, ) +from websockets import CloseCode from websockets.asyncio.client import connect, ClientConnection from websockets.exceptions import ( ConnectionClosed, @@ -130,6 +132,16 @@ def __init__( self.__weight = None self.__total_fee_amount = None + def __str__(self): + return ( + f"AsyncExtrinsicReceipt({self.extrinsic_hash}), " + f"block_hash={self.block_hash}, block_number={self.block_number}), " + f"finalized={self.finalized})" + ) + + def __repr__(self): + return self.__str__() + async def get_extrinsic_identifier(self) -> str: """ Returns the on-chain identifier for this extrinsic in format "[block_number]-[extrinsic_idx]" e.g. 134324-2 @@ -571,6 +583,8 @@ def __init__( self._log_raw_websockets = _log_raw_websockets self._in_use_ids = set() self._max_retries = max_retries + self._last_activity = asyncio.Event() + self._last_activity.set() @property def state(self): @@ -588,44 +602,124 @@ async def __aenter__(self): async def loop_time() -> float: return asyncio.get_running_loop().time() + async def _reset_activity_timer(self): + """Reset the shared activity timeout""" + # Create a NEW event instead of reusing the same one + old_event = self._last_activity + self._last_activity = asyncio.Event() + self._last_activity.clear() # Start fresh + old_event.set() # Wake up anyone waiting on the old event + + async def _wait_with_activity_timeout(self, coro, timeout: float): + """ + Wait for a coroutine with a shared activity timeout. + Returns the result or raises TimeoutError if no activity for timeout seconds. + """ + activity_task = asyncio.create_task(self._last_activity.wait()) + + if isinstance(coro, asyncio.Task): + main_task = coro + else: + main_task = asyncio.create_task(coro) + + try: + done, pending = await asyncio.wait( + [main_task, activity_task], + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + + if not done: + logger.debug(f"Activity timeout after {timeout}s, no activity detected") + for task in pending: + task.cancel() + raise TimeoutError() + + if main_task in done: + activity_task.cancel() + + exc = main_task.exception() + if exc is not None: + raise exc + else: + return main_task.result() + else: + logger.debug("Activity detected, resetting timeout") + return await self._wait_with_activity_timeout(main_task, timeout) + + except asyncio.CancelledError: + main_task.cancel() + activity_task.cancel() + raise + async def _cancel(self): try: - self._send_recv_task.cancel() - await self.ws.close() - except ( - AttributeError, - asyncio.CancelledError, - WebSocketException, - ): + logger.debug("Cancelling send/recv tasks") + if self._send_recv_task is not None: + self._send_recv_task.cancel() + except asyncio.CancelledError: pass except Exception as e: logger.warning( f"{e} encountered while trying to close websocket connection." ) + try: + logger.debug("Closing websocket connection") + if self.ws is not None: + await self.ws.close() + except Exception as e: + logger.warning( + f"{e} encountered while trying to close websocket connection." + ) async def connect(self, force=False): - async with self._lock: - logger.debug(f"Websocket connecting to {self.ws_url}") - if self._sending is None or self._sending.empty(): - self._sending = asyncio.Queue() - if self._exit_task: - self._exit_task.cancel() - if self.state not in (State.OPEN, State.CONNECTING) or force: + if not force: + await self._lock.acquire() + else: + logger.debug("Proceeding without acquiring lock.") + logger.debug(f"Websocket connecting to {self.ws_url}") + if self._sending is None or self._sending.empty(): + self._sending = asyncio.Queue() + if self._exit_task: + self._exit_task.cancel() + logger.debug(f"self.state={self.state}") + if force and self.state == State.OPEN: + logger.debug(f"Attempting to reconnect while already connected.") + if self.ws is not None: + self.ws.protocol.fail(CloseCode.SERVICE_RESTART) + logger.debug(f"Open connection cancelled.") + await asyncio.sleep(1) + if self.state not in (State.OPEN, State.CONNECTING) or force: + if not force: try: + logger.debug("Attempting cancellation") await asyncio.wait_for(self._cancel(), timeout=10.0) except asyncio.TimeoutError: logger.debug(f"Timed out waiting for cancellation") pass - self.ws = await asyncio.wait_for( + logger.debug("Attempting connection") + try: + connection = await asyncio.wait_for( connect(self.ws_url, **self._options), timeout=10.0 ) - if self._send_recv_task is None or self._send_recv_task.done(): - self._send_recv_task = asyncio.get_running_loop().create_task( - self._handler(self.ws) - ) - logger.debug("Websocket handler attached.") + except socket.gaierror: + logger.debug(f"Hostname not known (this is just for testing") + await asyncio.sleep(10) + if self._lock.locked(): + self._lock.release() + return await self.connect(force=force) + logger.debug("Connection established") + self.ws = connection + if self._send_recv_task is None or self._send_recv_task.done(): + self._send_recv_task = asyncio.get_running_loop().create_task( + self._handler(self.ws) + ) + if self._lock.locked(): + self._lock.release() + return None async def _handler(self, ws: ClientConnection) -> Union[None, Exception]: + logger.debug("WS handler attached") recv_task = asyncio.create_task(self._start_receiving(ws)) send_task = asyncio.create_task(self._start_sending(ws)) done, pending = await asyncio.wait( @@ -635,34 +729,56 @@ async def _handler(self, ws: ClientConnection) -> Union[None, Exception]: loop = asyncio.get_running_loop() should_reconnect = False is_retry = False + for task in pending: task.cancel() + for task in done: task_res = task.result() + + # If ConnectionClosedOK, graceful shutdown - don't reconnect + if isinstance(task_res, websockets.exceptions.ConnectionClosedOK): + logger.debug("Graceful shutdown detected, not reconnecting") + return None # Clean exit + + # Check for timeout/connection errors that should trigger reconnect if isinstance( - task_res, (asyncio.TimeoutError, ConnectionClosed, TimeoutError) + task_res, (asyncio.TimeoutError, TimeoutError, ConnectionClosed) ): should_reconnect = True + logger.debug(f"Reconnection triggered by: {type(task_res).__name__}") + if isinstance(task_res, (asyncio.TimeoutError, TimeoutError)): self._attempts += 1 is_retry = True + if should_reconnect is True: if len(self._received_subscriptions) > 0: return SubstrateRequestException( f"Unable to reconnect because there are currently open subscriptions." ) - for original_id, payload in list(self._inflight.items()): - self._received[original_id] = loop.create_future() - to_send = json.loads(payload) - await self._sending.put(to_send) + if is_retry: - # Otherwise the connection was just closed due to no activity, which should not count against retries + if self._attempts >= self._max_retries: + logger.error("Max retries exceeded.") + return TimeoutError("Max retries exceeded.") logger.info( f"Timeout occurred. Reconnecting. Attempt {self._attempts} of {self._max_retries}" ) + + async with self._lock: + for original_id in list(self._inflight.keys()): + payload = self._inflight.pop(original_id) + self._received[original_id] = loop.create_future() + to_send = json.loads(payload) + logger.debug(f"Resubmitting {to_send['id']}") + await self._sending.put(to_send) + + logger.debug("Attempting reconnection...") await self.connect(True) - await self._handler(ws=self.ws) - return None + logger.debug(f"Reconnected. Send queue size: {self._sending.qsize()}") + # Recursively call handler + return await self._handler(self.ws) elif isinstance(e := recv_task.result(), Exception): return e elif isinstance(e := send_task.result(), Exception): @@ -699,6 +815,7 @@ async def _exit_with_timer(self): pass async def shutdown(self): + logger.debug("Shutdown requested") try: await asyncio.wait_for(self._cancel(), timeout=10.0) except asyncio.TimeoutError: @@ -712,11 +829,16 @@ async def _recv(self, recd: bytes) -> None: response = json.loads(recd) if "id" in response: async with self._lock: - self._inflight.pop(response["id"]) - with suppress(KeyError): - # These would be subscriptions that were unsubscribed + inflight_item = self._inflight.pop(response["id"], None) + if inflight_item is not None: + logger.debug(f"Popped {response['id']} from inflight") + else: + logger.debug( + f"Received response for {response['id']} which is no longer inflight (likely reconnection)" + ) + if self._received.get(response["id"]) is not None: self._received[response["id"]].set_result(response) - self._in_use_ids.remove(response["id"]) + self._in_use_ids.discard(response["id"]) elif "params" in response: sub_id = response["params"]["subscription"] if sub_id not in self._received_subscriptions: @@ -726,14 +848,18 @@ async def _recv(self, recd: bytes) -> None: raise KeyError(response) async def _start_receiving(self, ws: ClientConnection) -> Exception: + logger.debug("Starting receiving task") try: while True: - recd = await asyncio.wait_for( - ws.recv(decode=False), timeout=self.retry_timeout + recd = await self._wait_with_activity_timeout( + ws.recv(decode=False), self.retry_timeout ) - # reset the counter once we successfully receive something back + await self._reset_activity_timer() self._attempts = 0 await self._recv(recd) + except websockets.exceptions.ConnectionClosedOK as e: + logger.debug("ConnectionClosedOK") + return e except Exception as e: if isinstance(e, ssl.SSLError): e = ConnectionClosed @@ -745,17 +871,18 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception: if not fut.done(): fut.set_exception(e) fut.cancel() - elif isinstance(e, websockets.exceptions.ConnectionClosedOK): - logger.debug("Websocket connection closed.") else: - logger.debug(f"Timeout occurred.") + logger.debug(f"Timeout/ConnectionClosed occurred.") return e async def _start_sending(self, ws) -> Exception: + logger.debug("Starting sending task") to_send = None try: while True: + logger.debug(f"_sending, {self._sending.qsize()}") to_send_ = await self._sending.get() + logger.debug("Retrieved item from sending queue") self._sending.task_done() send_id = to_send_["id"] to_send = json.dumps(to_send_) @@ -764,6 +891,8 @@ async def _start_sending(self, ws) -> Exception: if self._log_raw_websockets: raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}") await ws.send(to_send) + logger.debug("Sent to websocket") + await self._reset_activity_timer() except Exception as e: if isinstance(e, ssl.SSLError): e = ConnectionClosed @@ -772,8 +901,9 @@ async def _start_sending(self, ws) -> Exception: ): logger.exception("Websocket sending exception", exc_info=e) if to_send is not None: - self._received[to_send["id"]].set_exception(e) - self._received[to_send["id"]].cancel() + to_send_ = json.loads(to_send) + self._received[to_send_["id"]].set_exception(e) + self._received[to_send_["id"]].cancel() else: for i in self._received.keys(): self._received[i].set_exception(e) @@ -855,6 +985,12 @@ async def retrieve(self, item_id: str) -> Optional[dict]: return subscription except asyncio.QueueEmpty: pass + except KeyError: + logger.debug( + f"Received item {item_id} not in received subscriptions. " + f"This indicates the response of the subscription was inflight when sending " + f"the unsubscribe request." + ) if self._send_recv_task is not None and self._send_recv_task.done(): if not self._send_recv_task.cancelled(): if isinstance((e := self._send_recv_task.exception()), Exception): @@ -863,7 +999,7 @@ async def retrieve(self, item_id: str) -> Optional[dict]: elif isinstance((e := self._send_recv_task.result()), Exception): logger.exception(f"Websocket sending exception: {e}") raise e - await asyncio.sleep(0.1) + await asyncio.sleep(0.01) return None @@ -941,7 +1077,7 @@ def __init__( "strict_scale_decode": True, } self.initialized = False - self._forgettable_task = None + self._forgettable_tasks = set() self.type_registry = type_registry self.type_registry_preset = type_registry_preset self.runtime_cache = RuntimeCache() @@ -1401,11 +1537,13 @@ async def result_handler( if subscription_result is not None: # Handler returned end result: unsubscribe from further updates - self._forgettable_task = asyncio.create_task( + unsub_task = asyncio.create_task( self.rpc_request( "state_unsubscribeStorage", [subscription_id] ) ) + self._forgettable_tasks.add(unsub_task) + unsub_task.add_done_callback(self._forgettable_tasks.discard) return result_found, subscription_result @@ -1430,6 +1568,9 @@ async def retrieve_pending_extrinsics(self) -> list: result_data = await self.rpc_request("author_pendingExtrinsics", []) if "error" in result_data: + logger.error( + f"Error in retrieving pending extrinsics: {result_data['error']}" + ) raise SubstrateRequestException(result_data["error"]["message"]) extrinsics = [] @@ -2470,6 +2611,8 @@ async def _make_rpc_request( if request_manager.is_complete: break + else: + await asyncio.sleep(0.01) return request_manager.get_results() @@ -2553,10 +2696,12 @@ async def rpc_request( bh = err_msg.split("State already discarded for ")[1].strip() raise StateDiscardedError(bh) else: + logger.error(f"Substrate Request Exception: {result[payload_id]}") raise SubstrateRequestException(err_msg) if "result" in result[payload_id][0]: return result[payload_id][0] else: + logger.error(f"Substrate Request Exception: {result[payload_id]}") raise SubstrateRequestException(result[payload_id][0]) @cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE) diff --git a/async_substrate_interface/substrate_addons.py b/async_substrate_interface/substrate_addons.py index 3b0f0ba..b8cfd75 100644 --- a/async_substrate_interface/substrate_addons.py +++ b/async_substrate_interface/substrate_addons.py @@ -321,11 +321,11 @@ async def _reinstantiate_substrate( await self.ws.shutdown() except AttributeError: pass - if self._forgettable_task is not None: - self._forgettable_task: asyncio.Task - self._forgettable_task.cancel() + _forgettable_task: asyncio.Task + for _forgettable_task in self._forgettable_tasks: + _forgettable_task.cancel() try: - await self._forgettable_task + await _forgettable_task except asyncio.CancelledError: pass self.chain_endpoint = next_network diff --git a/tests/helpers/proxy_server.py b/tests/helpers/proxy_server.py new file mode 100644 index 0000000..c561289 --- /dev/null +++ b/tests/helpers/proxy_server.py @@ -0,0 +1,57 @@ +import logging +import time + +from websockets.sync.server import serve, ServerConnection +from websockets.sync.client import connect + +logger = logging.getLogger("websockets.proxy") + + +class ProxyServer: + def __init__(self, upstream: str, time_til_pause: float, time_til_resume: float): + self.upstream_server = upstream + self.time_til_pause = time_til_pause + self.time_til_resume = time_til_resume + self.upstream_connection = None + self.connection_time = 0 + self.shutdown_time = 0 + self.resume_time = 0 + + def connect(self): + self.upstream_connection = connect(self.upstream_server) + self.connection_time = time.time() + self.shutdown_time = self.connection_time + self.time_til_pause + self.resume_time = self.shutdown_time + self.time_til_resume + + def close(self): + if self.upstream_connection: + self.upstream_connection.close() + self.server.shutdown() + + def proxy_request(self, websocket: ServerConnection): + for message in websocket: + self.upstream_connection.send(message) + recd = self.upstream_connection.recv() + current_time = time.time() + if self.shutdown_time < current_time < self.resume_time: + logger.info("Pausing") + time.sleep(self.time_til_resume) + websocket.send(recd) + + def serve(self): + with serve(self.proxy_request, "localhost", 8080) as self.server: + self.server.serve_forever() + + def connect_and_serve(self): + self.connect() + self.serve() + + +def run_proxy_server(time_til_pause: float = 20.0, time_til_resume: float = 30.0): + proxy = ProxyServer("wss://archive.sub.latent.to", time_til_pause, time_til_resume) + proxy.connect() + proxy.serve() + + +if __name__ == "__main__": + run_proxy_server() diff --git a/tests/integration_tests/test_async_substrate_interface.py b/tests/integration_tests/test_async_substrate_interface.py index 8dab260..8957314 100644 --- a/tests/integration_tests/test_async_substrate_interface.py +++ b/tests/integration_tests/test_async_substrate_interface.py @@ -1,12 +1,16 @@ import asyncio +import logging +import os.path import time +import threading import pytest from scalecodec import ss58_encode -from async_substrate_interface.async_substrate import AsyncSubstrateInterface +from async_substrate_interface.async_substrate import AsyncSubstrateInterface, logger from async_substrate_interface.types import ScaleObj from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT +from tests.helpers.proxy_server import ProxyServer @pytest.mark.asyncio @@ -174,3 +178,60 @@ async def test_query_map_with_odd_number_of_params(): first_record = qm.records[0] assert len(first_record) == 2 assert len(first_record[0]) == 4 + + +@pytest.mark.asyncio +async def test_improved_reconnection(): + ws_logger_path = "/tmp/websockets-proxy-test" + ws_logger = logging.getLogger("websockets.proxy") + if os.path.exists(ws_logger_path): + os.remove(ws_logger_path) + ws_logger.setLevel(logging.INFO) + ws_logger.addHandler(logging.FileHandler(ws_logger_path)) + + asi_logger_path = "/tmp/async-substrate-interface-test" + if os.path.exists(asi_logger_path): + os.remove(asi_logger_path) + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.FileHandler(asi_logger_path)) + + proxy = ProxyServer("wss://archive.sub.latent.to", 10, 20) + + server_thread = threading.Thread(target=proxy.connect_and_serve) + server_thread.start() + await asyncio.sleep(3) # give the server start up time + async with AsyncSubstrateInterface( + "ws://localhost:8080", + ss58_format=42, + chain_name="Bittensor", + retry_timeout=10.0, + ws_shutdown_timer=None, + ) as substrate: + blocks_to_check = [ + 5215000, + 5215001, + 5215002, + 5215003, + 5215004, + 5215005, + 5215006, + ] + tasks = [] + for block in blocks_to_check: + block_hash = await substrate.get_block_hash(block_id=block) + tasks.append( + substrate.query_map( + "SubtensorModule", "TotalHotkeyShares", block_hash=block_hash + ) + ) + records = await asyncio.gather(*tasks) + assert len(records) == len(blocks_to_check) + await substrate.close() + with open(ws_logger_path, "r") as f: + assert "Pausing" in f.read() + with open(asi_logger_path, "r") as f: + assert "Timeout/ConnectionClosed occurred." in f.read() + shutdown_thread = threading.Thread(target=proxy.close) + shutdown_thread.start() + shutdown_thread.join(timeout=5) + server_thread.join(timeout=5)