diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8e73e42..4797dc4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,10 @@
 # Changelog
 
+## 1.0.0rc11 /2025-02-06
+* Reuses the websocket for sync Substrate by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/29
+* Feat/metadata v15 cache by @camfairchild in https://github.com/opentensor/async-substrate-interface/pull/30
+* Backmerge main to staging rc10 by @ibraheem-opentensor in https://github.com/opentensor/async-substrate-interface/pull/31
+
 ## 1.0.0rc10 /2025-02-04
 
 * Fixes decoding account ids for sync substrate
diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py
index efeb24b..3cb2189 100644
--- a/async_substrate_interface/async_substrate.py
+++ b/async_substrate_interface/async_substrate.py
@@ -49,7 +49,12 @@
     Preprocessed,
 )
 from async_substrate_interface.utils import hex_to_bytes, json
+from async_substrate_interface.utils.decoding import (
+    _determine_if_old_runtime_call,
+    _bt_decode_to_dict_or_list,
+)
 from async_substrate_interface.utils.storage import StorageKey
+from async_substrate_interface.type_registry import _TYPE_REGISTRY
 
 if TYPE_CHECKING:
     from websockets.asyncio.client import ClientConnection
@@ -685,7 +690,7 @@ def __init__(
         self.ws = Websocket(
             url,
             options={
-                "max_size": 2**32,
+                "max_size": self.ws_max_size,
                 "write_limit": 2**16,
             },
         )
@@ -706,6 +711,8 @@ def __init__(
             ss58_format=self.ss58_format, implements_scale_info=True
         )
         self._metadata_cache = {}
+        self._metadata_v15_cache = {}
+        self._old_metadata_v15 = None
         self._nonces = {}
         self.metadata_version_hex = "0x0f000000"  # v15
         self.reload_type_registry()
@@ -800,6 +807,20 @@ async def load_registry(self):
         )
         self.registry = PortableRegistry.from_metadata_v15(self.metadata_v15)
 
+    async def _load_registry_at_block(self, block_hash: str) -> MetadataV15:
+        # Should be called for any block that fails decoding.
+        # Possibly the metadata was different.
+        metadata_rpc_result = await self.rpc_request(
+            "state_call",
+            ["Metadata_metadata_at_version", self.metadata_version_hex],
+            block_hash=block_hash,
+        )
+        metadata_option_hex_str = metadata_rpc_result["result"]
+        metadata_option_bytes = bytes.fromhex(metadata_option_hex_str[2:])
+        old_metadata = MetadataV15.decode_from_metadata_option(metadata_option_bytes)
+
+        return old_metadata
+
     async def _wait_for_registry(self, _attempt: int = 1, _retries: int = 3) -> None:
         async def _waiter():
             while self.registry is None:
@@ -930,7 +951,10 @@ async def get_runtime(block_hash, block_id) -> Runtime:
             if (
                 (block_hash and block_hash == self.last_block_hash)
                 or (block_id and block_id == self.block_id)
-            ) and self._metadata is not None:
+            ) and all(
+                x is not None
+                for x in [self._metadata, self._old_metadata_v15, self.metadata_v15]
+            ):
                 return Runtime(
                     self.chain,
                     self.runtime_config,
@@ -976,9 +1000,9 @@ async def get_runtime(block_hash, block_id) -> Runtime:
                     f"No runtime information for block '{block_hash}'"
                 )
             # Check if runtime state already set to current block
-            if (
-                runtime_info.get("specVersion") == self.runtime_version
-                and self._metadata is not None
+            if runtime_info.get("specVersion") == self.runtime_version and all(
+                x is not None
+                for x in [self._metadata, self._old_metadata_v15, self.metadata_v15]
             ):
                 return Runtime(
                     self.chain,
@@ -1002,6 +1026,8 @@ async def get_runtime(block_hash, block_id) -> Runtime:
                     self.runtime_version
                 ]
             else:
+                # TODO when I get time, I'd like to add this and the metadata v15 as tasks with callbacks
+                # TODO to update the caches, but I don't have time now.
                 metadata = self._metadata = await self.get_block_metadata(
                     block_hash=runtime_block_hash, decode=True
                 )
@@ -1015,6 +1041,30 @@ async def get_runtime(block_hash, block_id) -> Runtime:
                 self._metadata_cache[self.runtime_version] = self._metadata
             else:
                 metadata = self._metadata
+
+            if self.runtime_version in self._metadata_v15_cache:
+                # Get metadata v15 from cache
+                logging.debug(
+                    "Retrieved metadata v15 for {} from memory".format(
+                        self.runtime_version
+                    )
+                )
+                metadata_v15 = self._old_metadata_v15 = self._metadata_v15_cache[
+                    self.runtime_version
+                ]
+            else:
+                metadata_v15 = (
+                    self._old_metadata_v15
+                ) = await self._load_registry_at_block(block_hash=runtime_block_hash)
+                logging.debug(
+                    "Retrieved metadata v15 for {} from Substrate node".format(
+                        self.runtime_version
+                    )
+                )
+
+                # Update metadata v15 cache
+                self._metadata_v15_cache[self.runtime_version] = metadata_v15
+
             # Update type registry
             self.reload_type_registry(use_remote_preset=False, auto_discover=True)
 
@@ -2487,6 +2537,56 @@ async def get_chain_finalised_head(self):
 
         return response.get("result")
 
+    async def _do_runtime_call_old(
+        self,
+        api: str,
+        method: str,
+        params: Optional[Union[list, dict]] = None,
+        block_hash: Optional[str] = None,
+    ) -> ScaleType:
+        logging.debug(
+            f"Decoding old runtime call: {api}.{method} with params: {params} at block hash: {block_hash}"
+        )
+        runtime_call_def = _TYPE_REGISTRY["runtime_api"][api]["methods"][method]
+
+        # Encode params
+        param_data = b""
+
+        if "encoder" in runtime_call_def:
+            param_data = runtime_call_def["encoder"](params)
+        else:
+            for idx, param in enumerate(runtime_call_def["params"]):
+                param_type_string = f"{param['type']}"
+                if isinstance(params, list):
+                    param_data += await self.encode_scale(
+                        param_type_string, params[idx]
+                    )
+                else:
+                    if param["name"] not in params:
+                        raise ValueError(
+                            f"Runtime Call param '{param['name']}' is missing"
+                        )
+
+                    param_data += await self.encode_scale(
+                        param_type_string, params[param["name"]]
+                    )
+
+        # RPC request
+        result_data = await self.rpc_request(
+            "state_call", [f"{api}_{method}", param_data.hex(), block_hash]
+        )
+        result_vec_u8_bytes = hex_to_bytes(result_data["result"])
+        result_bytes = await self.decode_scale("Vec<u8>", result_vec_u8_bytes)
+
+        # Decode result
+        # Get correct type
+        result_decoded = runtime_call_def["decoder"](bytes(result_bytes))
+        as_dict = _bt_decode_to_dict_or_list(result_decoded)
+        logging.debug("Decoded old runtime call result: ", as_dict)
+        result_obj = ScaleObj(as_dict)
+
+        return result_obj
+
     async def runtime_call(
         self,
         api: str,
@@ -2513,14 +2613,27 @@
             params = {}
 
         try:
-            metadata_v15 = self.metadata_v15.value()
-            apis = {entry["name"]: entry for entry in metadata_v15["apis"]}
+            if block_hash:
+                # Use old metadata v15 from init_runtime call
+                metadata_v15 = self._old_metadata_v15
+            else:
+                metadata_v15 = self.metadata_v15
+
+            self.registry = PortableRegistry.from_metadata_v15(metadata_v15)
+            metadata_v15_value = metadata_v15.value()
+
+            apis = {entry["name"]: entry for entry in metadata_v15_value["apis"]}
             api_entry = apis[api]
             methods = {entry["name"]: entry for entry in api_entry["methods"]}
             runtime_call_def = methods[method]
         except KeyError:
             raise ValueError(f"Runtime API Call '{api}.{method}' not found in registry")
 
+        if _determine_if_old_runtime_call(runtime_call_def, metadata_v15_value):
+            result = await self._do_runtime_call_old(api, method, params, block_hash)
+
+            return result
+
         if isinstance(params, list) and len(params) != len(runtime_call_def["inputs"]):
             raise ValueError(
                 f"Number of parameter provided ({len(params)}) does not "
diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py
index 0cb87d1..f0a9ef2 100644
--- a/async_substrate_interface/sync_substrate.py
+++ b/async_substrate_interface/sync_substrate.py
@@ -15,6 +15,7 @@
 )
 from scalecodec.base import RuntimeConfigurationObject, ScaleBytes, ScaleType
 from websockets.sync.client import connect
+from websockets.exceptions import ConnectionClosed
 
 from async_substrate_interface.errors import (
     ExtrinsicNotFound,
@@ -30,7 +31,12 @@
     ScaleObj,
 )
 from async_substrate_interface.utils import hex_to_bytes, json
+from async_substrate_interface.utils.decoding import (
+    _determine_if_old_runtime_call,
+    _bt_decode_to_dict_or_list,
+)
 from async_substrate_interface.utils.storage import StorageKey
+from async_substrate_interface.type_registry import _TYPE_REGISTRY
 
 ResultHandler = Callable[[dict, Any], tuple[dict, bool]]
 
@@ -505,8 +511,11 @@ def __init__(
             ss58_format=self.ss58_format, implements_scale_info=True
         )
         self._metadata_cache = {}
+        self._metadata_v15_cache = {}
+        self._old_metadata_v15 = None
         self.metadata_version_hex = "0x0f000000"  # v15
         self.reload_type_registry()
+        self.ws = self.connect(init=True)
 
         if not _mock:
             self.initialize()
@@ -527,7 +536,7 @@ def initialize(self):
             self.initialized = True
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
+        self.ws.close()
 
     @property
     def properties(self):
@@ -562,6 +571,15 @@ def name(self):
             self._name = self.rpc_request("system_name", []).get("result")
         return self._name
 
+    def connect(self, init=False):
+        if init is True:
+            return connect(self.chain_endpoint, max_size=self.ws_max_size)
+        else:
+            if not self.ws.close_code:
+                return self.ws
+            else:
+                return connect(self.chain_endpoint, max_size=self.ws_max_size)
+
     def get_storage_item(self, module: str, storage_function: str):
         if not self._metadata:
             self.init_runtime()
@@ -593,6 +611,20 @@ def load_registry(self):
         )
         self.registry = PortableRegistry.from_metadata_v15(self.metadata_v15)
 
+    def _load_registry_at_block(self, block_hash: str) -> MetadataV15:
+        # Should be called for any block that fails decoding.
+        # Possibly the metadata was different.
+        metadata_rpc_result = self.rpc_request(
+            "state_call",
+            ["Metadata_metadata_at_version", self.metadata_version_hex],
+            block_hash=block_hash,
+        )
+        metadata_option_hex_str = metadata_rpc_result["result"]
+        metadata_option_bytes = bytes.fromhex(metadata_option_hex_str[2:])
+        old_metadata = MetadataV15.decode_from_metadata_option(metadata_option_bytes)
+
+        return old_metadata
+
     def decode_scale(
         self,
         type_string: str,
@@ -674,7 +706,10 @@ def get_runtime(block_hash, block_id) -> Runtime:
             if (
                 (block_hash and block_hash == self.last_block_hash)
                 or (block_id and block_id == self.block_id)
-            ) and self._metadata is not None:
+            ) and all(
+                x is not None
+                for x in [self._metadata, self._old_metadata_v15, self.metadata_v15]
+            ):
                 return Runtime(
                     self.chain,
                     self.runtime_config,
@@ -716,9 +751,9 @@ def get_runtime(block_hash, block_id) -> Runtime:
                     f"No runtime information for block '{block_hash}'"
                 )
             # Check if runtime state already set to current block
-            if (
-                runtime_info.get("specVersion") == self.runtime_version
-                and self._metadata is not None
+            if runtime_info.get("specVersion") == self.runtime_version and all(
+                x is not None
+                for x in [self._metadata, self._old_metadata_v15, self.metadata_v15]
             ):
                 return Runtime(
                     self.chain,
@@ -755,6 +790,29 @@ def get_runtime(block_hash, block_id) -> Runtime:
                 self._metadata_cache[self.runtime_version] = self._metadata
             else:
                 metadata = self._metadata
+
+            if self.runtime_version in self._metadata_v15_cache:
+                # Get metadata v15 from cache
+                logging.debug(
+                    "Retrieved metadata v15 for {} from memory".format(
+                        self.runtime_version
+                    )
+                )
+                metadata_v15 = self._old_metadata_v15 = self._metadata_v15_cache[
+                    self.runtime_version
+                ]
+            else:
+                metadata_v15 = self._old_metadata_v15 = self._load_registry_at_block(
+                    block_hash=runtime_block_hash
+                )
+                logging.debug(
+                    "Retrieved metadata v15 for {} from Substrate node".format(
+                        self.runtime_version
+                    )
+                )
+                # Update metadata v15 cache
+                self._metadata_v15_cache[self.runtime_version] = metadata_v15
+
             # Update type registry
             self.reload_type_registry(use_remote_preset=False, auto_discover=True)
 
@@ -1620,69 +1678,67 @@ def _make_rpc_request(
         _received = {}
         subscription_added = False
 
-        with connect(self.chain_endpoint, max_size=2**32) as ws:
-            item_id = 0
-            for payload in payloads:
-                item_id += 1
-                ws.send(json.dumps({**payload["payload"], **{"id": item_id}}))
-                request_manager.add_request(item_id, payload["id"])
-
-            while True:
-                try:
-                    response = json.loads(
-                        ws.recv(timeout=self.retry_timeout, decode=False)
+        ws = self.connect(init=False if attempt == 1 else True)
+        item_id = 0
+        for payload in payloads:
+            item_id += 1
+            ws.send(json.dumps({**payload["payload"], **{"id": item_id}}))
+            request_manager.add_request(item_id, payload["id"])
+
+        while True:
+            try:
+                response = json.loads(ws.recv(timeout=self.retry_timeout, decode=False))
+            except (TimeoutError, ConnectionClosed):
+                if attempt >= self.max_retries:
+                    logging.warning(
+                        f"Timed out waiting for RPC requests {attempt} times. Exiting."
                     )
-                except TimeoutError:
-                    if attempt >= self.max_retries:
-                        logging.warning(
-                            f"Timed out waiting for RPC requests {attempt} times. Exiting."
-                        )
-                        raise SubstrateRequestException("Max retries reached.")
-                    else:
-                        return self._make_rpc_request(
-                            payloads,
+                    raise SubstrateRequestException("Max retries reached.")
+                else:
+                    return self._make_rpc_request(
+                        payloads,
+                        value_scale_type,
+                        storage_item,
+                        result_handler,
+                        attempt + 1,
+                    )
+            if "id" in response:
+                _received[response["id"]] = response
+            elif "params" in response:
+                _received[response["params"]["subscription"]] = response
+            else:
+                raise SubstrateRequestException(response)
+            for item_id in list(request_manager.response_map.keys()):
+                if item_id not in request_manager.responses or isinstance(
+                    result_handler, Callable
+                ):
+                    if response := _received.pop(item_id):
+                        if (
+                            isinstance(result_handler, Callable)
+                            and not subscription_added
+                        ):
+                            # handles subscriptions, overwrites the previous mapping of {item_id : payload_id}
+                            # with {subscription_id : payload_id}
+                            try:
+                                item_id = request_manager.overwrite_request(
+                                    item_id, response["result"]
+                                )
+                                subscription_added = True
+                            except KeyError:
+                                raise SubstrateRequestException(str(response))
+                        decoded_response, complete = self._process_response(
+                            response,
+                            item_id,
                             value_scale_type,
                             storage_item,
                             result_handler,
-                            attempt + 1,
                         )
-                if "id" in response:
-                    _received[response["id"]] = response
-                elif "params" in response:
-                    _received[response["params"]["subscription"]] = response
-                else:
-                    raise SubstrateRequestException(response)
-                for item_id in list(request_manager.response_map.keys()):
-                    if item_id not in request_manager.responses or isinstance(
-                        result_handler, Callable
-                    ):
-                        if response := _received.pop(item_id):
-                            if (
-                                isinstance(result_handler, Callable)
-                                and not subscription_added
-                            ):
-                                # handles subscriptions, overwrites the previous mapping of {item_id : payload_id}
-                                # with {subscription_id : payload_id}
-                                try:
-                                    item_id = request_manager.overwrite_request(
-                                        item_id, response["result"]
-                                    )
-                                    subscription_added = True
-                                except KeyError:
-                                    raise SubstrateRequestException(str(response))
-                            decoded_response, complete = self._process_response(
-                                response,
-                                item_id,
-                                value_scale_type,
-                                storage_item,
-                                result_handler,
-                            )
-                            request_manager.add_response(
-                                item_id, decoded_response, complete
-                            )
+                        request_manager.add_response(
+                            item_id, decoded_response, complete
+                        )
 
-                if request_manager.is_complete:
-                    break
+            if request_manager.is_complete:
+                break
 
         return request_manager.get_results()
 
@@ -2202,6 +2258,54 @@ def get_chain_finalised_head(self):
 
         return response.get("result")
 
+    def _do_runtime_call_old(
+        self,
+        api: str,
+        method: str,
+        params: Optional[Union[list, dict]] = None,
+        block_hash: Optional[str] = None,
+    ) -> ScaleType:
+        logging.debug(
+            f"Decoding old runtime call: {api}.{method} with params: {params} at block hash: {block_hash}"
+        )
+        runtime_call_def = _TYPE_REGISTRY["runtime_api"][api]["methods"][method]
+
+        # Encode params
+        param_data = b""
+
+        if "encoder" in runtime_call_def:
+            param_data = runtime_call_def["encoder"](params)
+        else:
+            for idx, param in enumerate(runtime_call_def["params"]):
+                param_type_string = f"{param['type']}"
+                if isinstance(params, list):
+                    param_data += self.encode_scale(param_type_string, params[idx])
+                else:
+                    if param["name"] not in params:
+                        raise ValueError(
+                            f"Runtime Call param '{param['name']}' is missing"
+                        )
+
+                    param_data += self.encode_scale(
+                        param_type_string, params[param["name"]]
+                    )
+
+        # RPC request
+        result_data = self.rpc_request(
+            "state_call", [f"{api}_{method}", param_data.hex(), block_hash]
+        )
+        result_vec_u8_bytes = hex_to_bytes(result_data["result"])
+        result_bytes = self.decode_scale("Vec<u8>", result_vec_u8_bytes)
+
+        # Decode result
+        # Get correct type
+        result_decoded = runtime_call_def["decoder"](bytes(result_bytes))
+        as_dict = _bt_decode_to_dict_or_list(result_decoded)
+        logging.debug("Decoded old runtime call result: ", as_dict)
+        result_obj = ScaleObj(as_dict)
+
+        return result_obj
+
     def runtime_call(
         self,
         api: str,
@@ -2228,14 +2332,27 @@
             params = {}
 
         try:
-            metadata_v15 = self.metadata_v15.value()
-            apis = {entry["name"]: entry for entry in metadata_v15["apis"]}
+            if block_hash:
+                # Use old metadata v15 from init_runtime call
+                metadata_v15 = self._old_metadata_v15
+            else:
+                metadata_v15 = self.metadata_v15
+
+            self.registry = PortableRegistry.from_metadata_v15(metadata_v15)
+            metadata_v15_value = metadata_v15.value()
+
+            apis = {entry["name"]: entry for entry in metadata_v15_value["apis"]}
             api_entry = apis[api]
             methods = {entry["name"]: entry for entry in api_entry["methods"]}
             runtime_call_def = methods[method]
         except KeyError:
             raise ValueError(f"Runtime API Call '{api}.{method}' not found in registry")
 
+        if _determine_if_old_runtime_call(runtime_call_def, metadata_v15_value):
+            result = self._do_runtime_call_old(api, method, params, block_hash)
+
+            return result
+
         if isinstance(params, list) and len(params) != len(runtime_call_def["inputs"]):
             raise ValueError(
                 f"Number of parameter provided ({len(params)}) does not "
@@ -2874,9 +2991,8 @@ def close(self):
         """
         Closes the substrate connection, and the websocket connection.
         """
-        # TODO change this logic
         try:
-            self.ws.shutdown()
+            self.ws.close()
         except AttributeError:
             pass
 
diff --git a/async_substrate_interface/type_registry.py b/async_substrate_interface/type_registry.py
new file mode 100644
index 0000000..0f224e8
--- /dev/null
+++ b/async_substrate_interface/type_registry.py
@@ -0,0 +1,163 @@
+from bt_decode import (
+    NeuronInfo,
+    NeuronInfoLite,
+    DelegateInfo,
+    StakeInfo,
+    SubnetHyperparameters,
+    SubnetInfo,
+    SubnetInfoV2,
+    encode,
+)
+from scalecodec import ss58_encode
+
+_TYPE_REGISTRY: dict[str, dict] = {
+    "types": {
+        "Balance": "u64",  # Need to override default u128
+    },
+    "runtime_api": {
+        "DelegateInfoRuntimeApi": {
+            "methods": {
+                "get_delegated": {
+                    "params": [
+                        {
+                            "name": "coldkey",
+                            "type": "Vec<u8>",
+                        },
+                    ],
+                    "encoder": lambda addr: encode(ss58_encode(addr), "Vec<u8>"),
+                    "type": "Vec<u8>",
+                    "decoder": DelegateInfo.decode_delegated,
+                },
+                "get_delegates": {
+                    "params": [],
+                    "type": "Vec<u8>",
+                    "decoder": DelegateInfo.decode_vec,
+                },
+            }
+        },
+        "NeuronInfoRuntimeApi": {
+            "methods": {
+                "get_neuron_lite": {
+                    "params": [
+                        {
+                            "name": "netuid",
+                            "type": "u16",
+                        },
+                        {
+                            "name": "uid",
+                            "type": "u16",
+                        },
+                    ],
+                    "type": "Vec<u8>",
+                    "decoder": NeuronInfoLite.decode,
+                },
+                "get_neurons_lite": {
+                    "params": [
+                        {
+                            "name": "netuid",
+                            "type": "u16",
+                        },
+                    ],
+                    "type": "Vec<u8>",
+                    "decoder": NeuronInfoLite.decode_vec,
+                },
+                "get_neuron": {
+                    "params": [
+                        {
+                            "name": "netuid",
+                            "type": "u16",
+                        },
+                        {
+                            "name": "uid",
+                            "type": "u16",
+                        },
+                    ],
+                    "type": "Vec<u8>",
+                    "decoder": NeuronInfo.decode,
+                },
+                "get_neurons": {
+                    "params": [
+                        {
+                            "name": "netuid",
+                            "type": "u16",
+                        },
+                    ],
+                    "type": "Vec<u8>",
+                    "decoder": NeuronInfo.decode_vec,
+                },
+            }
+        },
+        "StakeInfoRuntimeApi": {
+            "methods": {
+                "get_stake_info_for_coldkey": {
+                    "params": [
+                        {
+                            "name": "coldkey_account_vec",
+                            "type": "Vec<u8>",
+                        },
+                    ],
+                    "type": "Vec<u8>",
+                    "encoder": lambda addr: encode(ss58_encode(addr), "Vec<u8>"),
+                    "decoder": StakeInfo.decode_vec,
+ }, + "get_stake_info_for_coldkeys": { + "params": [ + { + "name": "coldkey_account_vecs", + "type": "Vec>", + }, + ], + "type": "Vec", + "encoder": lambda addrs: encode( + [ss58_encode(addr) for addr in addrs], "Vec>" + ), + "decoder": StakeInfo.decode_vec_tuple_vec, + }, + }, + }, + "SubnetInfoRuntimeApi": { + "methods": { + "get_subnet_hyperparams": { + "params": [ + { + "name": "netuid", + "type": "u16", + }, + ], + "type": "Vec", + "decoder": SubnetHyperparameters.decode_option, + }, + "get_subnet_info": { + "params": [ + { + "name": "netuid", + "type": "u16", + }, + ], + "type": "Vec", + "decoder": SubnetInfo.decode_option, + }, + "get_subnet_info_v2": { + "params": [ + { + "name": "netuid", + "type": "u16", + }, + ], + "type": "Vec", + "decoder": SubnetInfoV2.decode_option, + }, + "get_subnets_info": { + "params": [], + "type": "Vec", + "decoder": SubnetInfo.decode_vec_option, + }, + "get_subnets_info_v2": { + "params": [], + "type": "Vec", + "decoder": SubnetInfo.decode_vec_option, + }, + } + }, + }, +} diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py index b4c38ed..daaaafc 100644 --- a/async_substrate_interface/types.py +++ b/async_substrate_interface/types.py @@ -344,6 +344,7 @@ class SubstrateMixin(ABC): runtime_config: RuntimeConfigurationObject type_registry: Optional[dict] ss58_format: Optional[int] + ws_max_size = 2**32 @property def chain(self): diff --git a/async_substrate_interface/utils/decoding.py b/async_substrate_interface/utils/decoding.py new file mode 100644 index 0000000..3162fe4 --- /dev/null +++ b/async_substrate_interface/utils/decoding.py @@ -0,0 +1,44 @@ +from bt_decode import AxonInfo, PrometheusInfo + + +def _determine_if_old_runtime_call(runtime_call_def, metadata_v15_value) -> bool: + # Check if the output type is a Vec + # If so, call the API using the old method + output_type_def = [ + x + for x in metadata_v15_value["types"]["types"] + if x["id"] == runtime_call_def["output"] + ] + if output_type_def: + output_type_def = output_type_def[0] + + if "sequence" in output_type_def["type"]["def"]: + output_type_seq_def_id = output_type_def["type"]["def"]["sequence"]["type"] + output_type_seq_def = [ + x + for x in metadata_v15_value["types"]["types"] + if x["id"] == output_type_seq_def_id + ] + if output_type_seq_def: + output_type_seq_def = output_type_seq_def[0] + if ( + "primitive" in output_type_seq_def["type"]["def"] + and output_type_seq_def["type"]["def"]["primitive"] == "u8" + ): + return True + return False + + +def _bt_decode_to_dict_or_list(obj) -> dict | list[dict]: + if isinstance(obj, list): + return [_bt_decode_to_dict_or_list(item) for item in obj] + + as_dict = {} + for key in dir(obj): + if not key.startswith("_"): + val = getattr(obj, key) + if isinstance(val, (AxonInfo, PrometheusInfo)): + as_dict[key] = _bt_decode_to_dict_or_list(val) + else: + as_dict[key] = val + return as_dict diff --git a/pyproject.toml b/pyproject.toml index 24cdd13..c6281f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "async-substrate-interface" -version = "1.0.0rc10" +version = "1.0.0rc11" description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface" readme = "README.md" license = { file = "LICENSE" } @@ -10,7 +10,7 @@ dependencies = [ "wheel", "asyncstdlib~=3.13.0", "bittensor-wallet>=2.1.3", - "bt-decode==v0.5.0-a0", + "bt-decode==v0.5.0-a2", "scalecodec~=1.2.11", "websockets>=14.1", "xxhash"