From 39c98522de205c63ec2caef769273b62ea9d31aa Mon Sep 17 00:00:00 2001
From: bkb2135
Date: Sat, 9 Aug 2025 10:27:27 +0100
Subject: [PATCH 01/47] Add Spec Version for Weight Set

---
 apex/__init__.py           | 11 +++++++++++
 apex/common/async_chain.py |  2 ++
 2 files changed, 13 insertions(+)

diff --git a/apex/__init__.py b/apex/__init__.py
index 93de2e624..ceb6e43cb 100644
--- a/apex/__init__.py
+++ b/apex/__init__.py
@@ -7,6 +7,17 @@
 __version__ = version("apex")


+def _version_to_int(version_str: str) -> int:
+    version_split = version_str.split(".") + ["0", "0"]  # in case a version doesn't have third element, e.g. 3.0
+    major = int(version_split[0])
+    minor = int(version_split[1])
+    patch = int(version_split[2])
+    return (10000 * major) + (100 * minor) + patch
+
+
+__spec_version__ = _version_to_int(__version__)
+
+
 def setup_logger(log_file_path: str | Path | None = None, level: str = "INFO") -> Any:
     """Set up the loguru logger with optional file logging and specified log level.

diff --git a/apex/common/async_chain.py b/apex/common/async_chain.py
index 28453f86a..cd948ad11 100644
--- a/apex/common/async_chain.py
+++ b/apex/common/async_chain.py
@@ -5,6 +5,7 @@
 from bittensor.core.metagraph import AsyncMetagraph
 from loguru import logger

+from apex import __spec_version__
 from apex.common.utils import async_cache

 _METAGRAPH_TTL: int = 10 * 60
@@ -132,6 +133,7 @@ async def set_weights(self, rewards: dict[str, float]) -> bool:
             netuid=self._netuid,
             uids=list(weights.keys()),
             weights=list(weights.values()),
+            version_key=__spec_version__,
             wait_for_inclusion=True,
             wait_for_finalization=True,
         )

From d3e87363820a06c6082170be8bf158beb40bc3a9 Mon Sep 17 00:00:00 2001
From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
Date: Sat, 9 Aug 2025 11:57:02 +0200
Subject: [PATCH 02/47] WIP

---
 apex/common/epistula.py         |  68 +++++++++++++++++
 apex/validator/pipeline.py      |   7 +-
 apex/validator/weight_syncer.py | 130 ++++++++++++++++++++++++++++++++
 scripts/autoupdater.py          |   0
 validator.py                    |   1 -
 5 files changed, 200 insertions(+), 6 deletions(-)
 create mode 100644 apex/validator/weight_syncer.py
 create mode 100644 scripts/autoupdater.py

diff --git a/apex/common/epistula.py b/apex/common/epistula.py
index 947c0a6d4..67b16d32b 100644
--- a/apex/common/epistula.py
+++ b/apex/common/epistula.py
@@ -1,9 +1,12 @@
+# import json
 import time
 from hashlib import sha256
 from math import ceil
 from typing import Any
 from uuid import uuid4

+# from fastapi import HTTPException, Request
+# from loguru import logger
 from substrateinterface import Keypair


@@ -33,3 +36,68 @@ async def generate_header(
             "0x" + hotkey.sign(str(timestamp_interval + 1) + "."
+ signed_for).hex() ) return headers + + +# def verify_signature( +# signature: str, body: bytes, timestamp: int, uuid: str, signed_for: str, signed_by: str, now: float +# ) -> Annotated[str, "Error Message"] | None: +# if not isinstance(signature, str): +# return "Invalid Signature" +# timestamp = int(timestamp) +# if not isinstance(timestamp, int): +# return "Invalid Timestamp" +# if not isinstance(signed_by, str): +# return "Invalid Sender key" +# if not isinstance(signed_for, str): +# return "Invalid receiver key" +# if not isinstance(uuid, str): +# return "Invalid uuid" +# if not isinstance(body, bytes): +# return "Body is not of type bytes" +# allowed_delta_ms = 8000 +# keypair = Keypair(ss58_address=signed_by) +# if timestamp + allowed_delta_ms < now: +# return "Request is too stale" +# message = f"{sha256(body).hexdigest()}.{uuid}.{timestamp}.{signed_for}" +# verified = keypair.verify(message, signature) +# if not verified: +# return "Signature Mismatch" +# return None +# +# +# async def verify_weight_signature(request: Request): +# signed_by = request.headers.get("Epistula-Signed-By") +# signed_for = request.headers.get("Epistula-Signed-For") +# if not signed_by or not signed_for: +# raise HTTPException(400, "Missing Epistula-Signed-* headers") +# +# if signed_for != shared_settings.WALLET.hotkey.ss58_address: +# logger.error("Bad Request, message is not intended for self") +# raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self") +# validator_hotkeys = [shared_settings.METAGRAPH.hotkeys[uid] for uid in WHITELISTED_VALIDATORS_UIDS] +# if signed_by not in validator_hotkeys: +# logger.error(f"Signer not the expected ss58 address: {signed_by}") +# raise HTTPException(status_code=401, detail="Signer not the expected ss58 address") +# +# now = time.time() +# body: bytes = await request.body() +# try: +# payload = json.loads(body) +# except json.JSONDecodeError: +# raise HTTPException(400, "Invalid JSON body") +# +# if payload.get("uid") != get_uid_from_hotkey(signed_by): +# raise HTTPException(400, "Invalid uid in body") +# +# err = verify_signature( +# request.headers.get("Epistula-Request-Signature"), +# body, +# request.headers.get("Epistula-Timestamp"), +# request.headers.get("Epistula-Uuid"), +# signed_for, +# signed_by, +# now, +# ) +# if err: +# logger.error(err) +# raise HTTPException(status_code=400, detail=err) diff --git a/apex/validator/pipeline.py b/apex/validator/pipeline.py index ea8dfaf9a..632262ee5 100644 --- a/apex/validator/pipeline.py +++ b/apex/validator/pipeline.py @@ -6,7 +6,6 @@ from loguru import logger -from apex.common.config import Config from apex.common.models import QueryTask from apex.services.deep_research.deep_research_base import DeepResearchBase from apex.services.llm.llm_base import LLMBase @@ -19,20 +18,18 @@ class Pipeline: def __init__( self, - config: Config, websearch: WebSearchBase, miner_sampler: MinerSampler, llm: LLMBase, deep_research: DeepResearchBase, logger_apex: LoggerApex | None = None, num_consumers: int = 10, - timeout_consumer: float = 60, - timeout_producer: float = 6, + timeout_consumer: float = 300, + timeout_producer: float = 30, queue_size: int = 10_000, redundancy_rate: float = 0.1, # The rate that references are generated in addition to generator steps reference_rate: float = 0.5, # The rate that references are generated as opposed to generator steps ): - self.config = config self.websearch = websearch self.miner_registry = miner_sampler self.llm = llm diff --git 
a/apex/validator/weight_syncer.py b/apex/validator/weight_syncer.py new file mode 100644 index 000000000..a9a5576c7 --- /dev/null +++ b/apex/validator/weight_syncer.py @@ -0,0 +1,130 @@ +# import asyncio +# import json +# +# import aiohttp +# import httpx +# import numpy as np +# from loguru import logger +# +# from apex.common.async_chain import AsyncChain +# from apex.common.epistula import generate_header, verify_weight_signature + + +# class WeightSyncer: +# def __init__(self, chain: AsyncChain, weight_dict: dict[int, list[float]]): +# self.chain = chain +# self.wallet = self.chain.wallet +# self.current_hotkey = self.wallet.hotkey.ss58_address +# self.latest_weights: str = {} +# metagraph = await self.chain.metagraph +# self.uid = metagraph.hotkeys.index(self.current_hotkey) +# self.validator_uids = np.where(np.array(metagraph.validator_permit))[0].tolist() +# +# self.weight_matrix = np.zeros((len(self.validator_uids), metagraph.n.item())) +# self.stake_matrix = np.array([metagraph.S[uid] for uid in self.validator_uids]) +# +# self.validator_hotkeys = np.array([metagraph.hotkeys[uid] for uid in self.validator_uids]) +# self.validator_addresses = np.array( +# [ +# f"{metagraph.axons[uid].ip}:{metagraph.axons[uid].port}" +# for uid in self.validator_uids +# if uid < metagraph.n.item() +# ] +# ) +# +# self.weight_dict = weight_dict +# +# self.request_tracker = np.zeros(len(self.validator_uids)) +# +# @router.post("/receive_weight_matrix") +# async def receive_weight_matrix( +# request: Request, +# verification_data: dict = Depends(verify_weight_signature), +# weight_dict=Depends(get_weight_dict), +# ): +# """Endpoint to receive weight matrix updates from validators.""" +# await verify_weight_signature(request) +# +# body = await request.json() +# if not isinstance(body, dict) or "weights" not in body: +# raise HTTPException(status_code=400, detail="Invalid request body format") +# +# try: +# uid = body["uid"] +# weights = list(body["weights"]) +# weight_dict[uid] = weights +# return {"status": "success", "message": "Weight matrix updated successfully"} +# except Exception as e: +# logger.error(f"Error processing weight matrix: {e}") +# raise HTTPException(status_code=500, detail="Error processing weight matrix") +# +# async def send_rewards(self, rewards: dict[str, float], validator_address: str, validator_hotkey: str): +# try: +# async with aiohttp.ClientSession() as session: +# headers = await generate_header( +# self.chain.wallet.hotkey, body=json.dumps(body).encode("utf-8"), signed_for=validator_hotkey +# ) +# async with session.post( +# endpoint + "/v1/chat/completions", +# headers=headers, +# json=body, +# ) as resp: +# result = await resp.text() +# except BaseException: +# # Error during miner query, return empty string. 
+# return "" +# return str(result) +# +# try: +# vali_url = f"http://{validator_address}/receive_weight_matrix" +# timeout = httpx.Timeout(timeout=40.0) +# async with httpx.AsyncClient( +# timeout=timeout, +# event_hooks={"request": [create_header_hook(self.wallet.hotkey, validator_hotkey)]}, +# ) as client: +# response = await client.post( +# url=vali_url, +# json={"weights": weight_matrix.tolist(), "uid": self.uid}, +# headers={"Content-Type": "application/json"}, +# ) +# if response.status_code != 200: +# raise Exception( +# f"Status code {response.status_code} response for validator {validator_hotkey} - {vali_url}: " +# f"{response.status_code} for uids {len(weight_matrix)}" +# ) +# logger.debug(f"Successfully forwarded response to uid {validator_hotkey} - {vali_url}") +# except httpx.ConnectError as e: +# logger.warning( +# f"Couldn't connect to validator {validator_hotkey} {vali_url} for weight setting. Exception: {e}" +# ) +# except Exception as e: +# logger.warning( +# f"Error while forwarding weight matrix to validator {validator_hotkey} {vali_url}. Exception: {e}" +# ) +# +# async def get_augmented_weights(self, weights: np.ndarray, uid: int) -> np.ndarray: +# """Get the augmented weights for the given uid, sends the weights to the validators.""" +# await self.send_weight_matrixes(weights) +# +# await self.process_weight_dict() +# +# return np.average(self.weight_matrix, axis=0, weights=self.stake_matrix * self.request_tracker) +# +# async def send_weight_matrixes(self, weight_matrix: np.ndarray): +# tasks = [ +# self.send_weights(weight_matrix, validator_address, validator_hotkey) +# for validator_address, validator_hotkey in zip( +# self.validator_addresses, self.validator_hotkeys, strict=False +# ) +# ] +# +# await asyncio.gather(*tasks) +# +# async def process_weight_dict(self): +# for uid, weights in self.weight_dict.items(): +# if uid in self.validator_uids: +# validator_index = self.validator_uids.index(uid) +# self.weight_matrix[validator_index] = weights +# self.request_tracker[validator_index] = 1 +# else: +# logger.warning(f"UID {uid} is not a validator, skipping") diff --git a/scripts/autoupdater.py b/scripts/autoupdater.py new file mode 100644 index 000000000..e69de29bb diff --git a/validator.py b/validator.py index 754afa9a0..53d50fd64 100644 --- a/validator.py +++ b/validator.py @@ -53,7 +53,6 @@ async def main() -> None: deep_research = DeepResearchLangchain(websearch=websearch, **config.deep_research.kwargs) pipeline = Pipeline( - config=config, websearch=websearch, miner_sampler=miner_sampler, llm=llm, From cc0484012a4ca51c6297daa58f841bbbafbe3368 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Sat, 9 Aug 2025 10:59:55 +0100 Subject: [PATCH 03/47] Update unit tests --- tests/common/mock_async_chain.py | 2 ++ tests/common/test_async_chain.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/tests/common/mock_async_chain.py b/tests/common/mock_async_chain.py index 2161c344b..7bbbb0980 100644 --- a/tests/common/mock_async_chain.py +++ b/tests/common/mock_async_chain.py @@ -54,6 +54,7 @@ async def set_weights( netuid: int, uids: Iterable[int], weights: Iterable[float], + version_key: int, wait_for_inclusion: bool, wait_for_finalization: bool, ) -> bool: @@ -62,6 +63,7 @@ async def set_weights( "netuid": netuid, "uids": list(uids), "weights": list(weights), + "version_key": version_key, "wait_for_inclusion": wait_for_inclusion, "wait_for_finalization": wait_for_finalization, } diff --git a/tests/common/test_async_chain.py b/tests/common/test_async_chain.py index 
6d83fa262..a16d88ad2 100644 --- a/tests/common/test_async_chain.py +++ b/tests/common/test_async_chain.py @@ -121,6 +121,10 @@ async def test_set_weights_happy_path(monkeypatch): assert stub.last_set_weights is not None assert stub.last_set_weights["uids"] == [2] assert stub.last_set_weights["weights"] == [0.7] + # ensure we pass spec version as version_key + from apex import __spec_version__ + + assert stub.last_set_weights["version_key"] == __spec_version__ @pytest.mark.asyncio From 7c153f104b87309e7e9ac2da2f9acd27fe4791f1 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 12:10:29 +0200 Subject: [PATCH 04/47] Add logs, reduce frequency --- apex/validator/miner_sampler.py | 5 +++-- apex/validator/miner_scorer.py | 4 ++-- apex/validator/pipeline.py | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/apex/validator/miner_sampler.py b/apex/validator/miner_sampler.py index d8c71d16a..95f1543fc 100644 --- a/apex/validator/miner_sampler.py +++ b/apex/validator/miner_sampler.py @@ -38,7 +38,7 @@ def __init__( self, chain: AsyncChain, sample_mode: Literal["random", "sequential"] = "sequential", - sample_size: int = 50, + sample_size: int = 100, logger_db: LoggerDB | None = None, available_uids: Sequence[int] | None = None, available_addresses: Sequence[str] | None = None, @@ -151,9 +151,10 @@ async def query_generators(self, query: str) -> MinerGeneratorResults: hotkeys: list[str] = [] tasks: list[Coroutine[str, str, Any]] = [] + + logger.debug(f"Querying {len(miner_information)} miner generators") for miner_info in miner_information: hotkeys.append(miner_info.hotkey) - logger.debug(f"Querying miner generator at {miner_info.address} with uid: {miner_info.uid}") tasks.append(self.query_miners(body=body, endpoint=miner_info.address, hotkey=miner_info.hotkey)) generator_results = await asyncio.gather(*tasks) return MinerGeneratorResults(query=query, generator_hotkeys=hotkeys, generator_results=generator_results) diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py index 5ee18c646..c2108a110 100644 --- a/apex/validator/miner_scorer.py +++ b/apex/validator/miner_scorer.py @@ -29,9 +29,9 @@ async def start_loop(self) -> None: self._running = True while self._running: await asyncio.sleep(self.interval) + logger.debug("Attempting to set weights") success = await self.set_scores() - if not success: - logger.error("Failed to set weights") + logger.log("INFO" if success else "ERROR", f"Set weights: {'success' if success else 'fail'}") async def shutdown(self) -> None: self._running = False diff --git a/apex/validator/pipeline.py b/apex/validator/pipeline.py index ea8dfaf9a..24d5b8118 100644 --- a/apex/validator/pipeline.py +++ b/apex/validator/pipeline.py @@ -26,10 +26,10 @@ def __init__( deep_research: DeepResearchBase, logger_apex: LoggerApex | None = None, num_consumers: int = 10, - timeout_consumer: float = 60, - timeout_producer: float = 6, + timeout_consumer: float = 180, + timeout_producer: float = 18, queue_size: int = 10_000, - redundancy_rate: float = 0.1, # The rate that references are generated in addition to generator steps + redundancy_rate: float = 0.05, # The rate that references are generated in addition to generator steps reference_rate: float = 0.5, # The rate that references are generated as opposed to generator steps ): self.config = config From 4b3562498cb9860348eb86fe05479d77a690b127 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> 
Date: Sat, 9 Aug 2025 12:28:20 +0200 Subject: [PATCH 05/47] Add verbose logs --- README.md | 2 +- apex/validator/miner_scorer.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0192cf38d..64ad1e079 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ Subnet 1 is the most intelligent inference model on Bittensor. As the first agen 3. **Install the project and its development dependencies:** ```bash - uv venv && uv python install 3.11 && uv python pin 3.11 && uv venv --python=3.11 && uv pip install -e '.[dev]' + uv venv --python=3.11 && uv pip install '.[dev]' ``` 4. **Activate python environment:** diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py index c2108a110..df35a62aa 100644 --- a/apex/validator/miner_scorer.py +++ b/apex/validator/miner_scorer.py @@ -50,6 +50,7 @@ async def set_scores(self) -> bool: expose each one as plain python objects so that downstream code can work with them, and remove rows that are older than the time window. """ + logger.debug("Retrieving miner's performance history") async with self._db() as conn: # type: aiosqlite.Connection # Calculate the cutoff timestamp (current time - window hours). cutoff_timestamp = int(time.time() - SCORE_MA_WINDOW_HOURS * 3600) @@ -70,6 +71,7 @@ async def set_scores(self) -> bool: return False # 2. Iterate over the in-memory list so that the caller can process freely. + logger.debug("Pre-processing miner's rewards") hkey_agg_rewards: dict[str, float] = {} for generator_hotkey, generator_score, disc_hotkeys_json, disc_scores_json in rows: # Deserialize JSON columns. @@ -88,6 +90,7 @@ async def set_scores(self) -> bool: hkey_agg_rewards[hotkey] = float(hkey_agg_rewards.get(hotkey, 0.0)) + float(reward) # 3. Delete rows that are older than the time window. + logger.debug("Cleaning up miner's outdated history") await conn.execute( "DELETE FROM discriminator_results WHERE timestamp < ?", (cutoff_timestamp,), @@ -102,8 +105,10 @@ async def set_scores(self) -> bool: record_str: str = json.dumps(record) fh.write(f"{record_str}\n") # TODO: Flush the db only on set_weights_result is True. + logger.debug("Setting weights") set_weights_result = await self.chain.set_weights(hkey_agg_rewards) # 4. Flush all deletions in a single commit. 
+ logger.debug("Updating rewards DB") await conn.commit() return set_weights_result From 915671f458a2d7d9e39567bd46de1ad5a3858aec Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 13:05:38 +0200 Subject: [PATCH 06/47] Add logs file and fix tests --- .gitignore | 1 + apex/__init__.py | 2 +- apex/validator/miner_scorer.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index a68384611..70cbedd8b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*.log requirements.txt **/*.ipynb debug_rewards.jsonl diff --git a/apex/__init__.py b/apex/__init__.py index ceb6e43cb..680ddbc64 100644 --- a/apex/__init__.py +++ b/apex/__init__.py @@ -44,4 +44,4 @@ def setup_logger(log_file_path: str | Path | None = None, level: str = "INFO") - return logger -setup_logger(level="DEBUG") +setup_logger(log_file_path="logs.log", level="DEBUG") diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py index df35a62aa..46a225b78 100644 --- a/apex/validator/miner_scorer.py +++ b/apex/validator/miner_scorer.py @@ -31,7 +31,10 @@ async def start_loop(self) -> None: await asyncio.sleep(self.interval) logger.debug("Attempting to set weights") success = await self.set_scores() - logger.log("INFO" if success else "ERROR", f"Set weights: {'success' if success else 'fail'}") + if success: + logger.info(f"Set weights: {success}") + else: + logger.error("Failed to set weights") async def shutdown(self) -> None: self._running = False From 7c0219a87211178af0a93c39792f9675ae24bdc4 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 13:06:07 +0200 Subject: [PATCH 07/47] Bump patch version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ec1d2dc28..31ed9eca4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "apex" -version = "3.0.0" +version = "3.0.1" description = "Bittensor Subnet 1: Apex" readme = "README.md" requires-python = "~=3.11" From dad2cb86124fc7df78f1a77a33f48aa2faf5bca2 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 14:07:39 +0200 Subject: [PATCH 08/47] Add more verbose logs --- apex/validator/miner_sampler.py | 5 ++++- validator.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/apex/validator/miner_sampler.py b/apex/validator/miner_sampler.py index 95f1543fc..1014ce8a0 100644 --- a/apex/validator/miner_sampler.py +++ b/apex/validator/miner_sampler.py @@ -124,6 +124,9 @@ async def _sample_miners(self) -> list[MinerInfo]: else: raise ValueError(f"Unknown sampling mode: {self._sample_mode}") + logger.debug( + f"Sampled uids (sample size = {self._sample_size}): {sorted([miner.uid for miner in miners_sample])}" + ) return miners_sample async def query_miners(self, body: dict[str, Any], endpoint: str, hotkey: str | None = None) -> str: @@ -134,7 +137,7 @@ async def query_miners(self, body: dict[str, Any], endpoint: str, hotkey: str | self._chain.wallet.hotkey, body=json.dumps(body).encode("utf-8"), signed_for=hotkey ) async with session.post( - endpoint + "/v1/chat/completions", + f"{endpoint}/v1/chat/completions", headers=headers, json=body, ) as resp: diff --git a/validator.py b/validator.py index 754afa9a0..87f72a903 100644 --- a/validator.py +++ b/validator.py @@ -4,6 +4,7 @@ from loguru import logger +from apex import 
__version__ from apex.common.async_chain import AsyncChain from apex.common.config import Config from apex.services.deep_research.deep_research_langchain import DeepResearchLangchain @@ -32,25 +33,33 @@ async def read_args() -> argparse.Namespace: async def main() -> None: args = await read_args() config = Config.from_file(path=args.config) + logger.debug(f"Starting validator v{__version__} with config: {args.config}") chain = AsyncChain(**config.chain.kwargs) await chain.start() + logger.debug(f"Connected to the chain with coldkey '{chain.coldkey[:3]}**', hotkey '{chain.hotkey[:3]}**'") logger_db = LoggerDB(**config.logger_db.kwargs) asyncio.create_task(logger_db.start_loop()) + logger.debug(f"Started DB at: '{logger_db.db_path}") # logger_apex = LoggerApex(async_chain=chain) websearch = WebSearchTavily(**config.websearch.kwargs) + logger.debug("Started web search tool") miner_sampler = MinerSampler(chain=chain, logger_db=logger_db, **config.miner_sampler.kwargs) + logger.debug("Started miner sampler") miner_scorer = MinerScorer(chain=chain, **config.miner_scorer.kwargs) asyncio.create_task(miner_scorer.start_loop()) + logger.debug(f"Started miner scorer with interval={miner_scorer.interval}") llm = LLM(**config.llm.kwargs) + logger.debug("Started LLM provider") deep_research = DeepResearchLangchain(websearch=websearch, **config.deep_research.kwargs) + logger.debug("Started Deep Researcher") pipeline = Pipeline( config=config, @@ -62,6 +71,7 @@ async def main() -> None: **config.pipeline.kwargs, ) try: + logger.debug("Starting pipeline loop...") await pipeline.start_loop() except KeyboardInterrupt: logger.warning("Keyboard interrupt caught, exiting validator") From 7355c967526519f7ec2ced7b9be1407819b7e670 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 14:10:25 +0200 Subject: [PATCH 09/47] Reduce loops to 5 --- apex/validator/pipeline.py | 2 +- validator.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/apex/validator/pipeline.py b/apex/validator/pipeline.py index 24d5b8118..bc4adbc25 100644 --- a/apex/validator/pipeline.py +++ b/apex/validator/pipeline.py @@ -25,7 +25,7 @@ def __init__( llm: LLMBase, deep_research: DeepResearchBase, logger_apex: LoggerApex | None = None, - num_consumers: int = 10, + num_consumers: int = 5, timeout_consumer: float = 180, timeout_producer: float = 18, queue_size: int = 10_000, diff --git a/validator.py b/validator.py index 87f72a903..bc1161e19 100644 --- a/validator.py +++ b/validator.py @@ -37,11 +37,11 @@ async def main() -> None: chain = AsyncChain(**config.chain.kwargs) await chain.start() - logger.debug(f"Connected to the chain with coldkey '{chain.coldkey[:3]}**', hotkey '{chain.hotkey[:3]}**'") + logger.debug(f"Connected to the chain with coldkey '{chain.coldkey[:3]}***', hotkey '{chain.hotkey[:3]}***'") logger_db = LoggerDB(**config.logger_db.kwargs) asyncio.create_task(logger_db.start_loop()) - logger.debug(f"Started DB at: '{logger_db.db_path}") + logger.debug(f"Started DB at: '{logger_db.db_path}'") # logger_apex = LoggerApex(async_chain=chain) From 7c22d89b26e466fd820475fe2cef5876e170e380 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 14:31:47 +0200 Subject: [PATCH 10/47] Adjust logging --- validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validator.py b/validator.py index bc1161e19..5146ab1a6 100644 --- a/validator.py +++ b/validator.py @@ -37,7 +37,7 @@ 
async def main() -> None: chain = AsyncChain(**config.chain.kwargs) await chain.start() - logger.debug(f"Connected to the chain with coldkey '{chain.coldkey[:3]}***', hotkey '{chain.hotkey[:3]}***'") + logger.debug(f"Connected to the chain with coldkey '{chain.coldkey[:3]}***', hotkey '{chain.hotkey[:2]}***'") logger_db = LoggerDB(**config.logger_db.kwargs) asyncio.create_task(logger_db.start_loop()) From 06c43fcc83c008a2dceb906d4eee354b7d97dae7 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 14:33:54 +0200 Subject: [PATCH 11/47] Adjust timeouts --- apex/validator/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apex/validator/pipeline.py b/apex/validator/pipeline.py index bc4adbc25..023b9fa51 100644 --- a/apex/validator/pipeline.py +++ b/apex/validator/pipeline.py @@ -27,7 +27,7 @@ def __init__( logger_apex: LoggerApex | None = None, num_consumers: int = 5, timeout_consumer: float = 180, - timeout_producer: float = 18, + timeout_producer: float = 36, queue_size: int = 10_000, redundancy_rate: float = 0.05, # The rate that references are generated in addition to generator steps reference_rate: float = 0.5, # The rate that references are generated as opposed to generator steps From cebee2153e2a70f5da673e2d31752131d66a3238 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 16:31:29 +0200 Subject: [PATCH 12/47] Fix sampling --- apex/validator/miner_sampler.py | 17 ++++++++++------- tests/validator/test_miner_sampler.py | 12 ++++++++---- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/apex/validator/miner_sampler.py b/apex/validator/miner_sampler.py index 1014ce8a0..665956479 100644 --- a/apex/validator/miner_sampler.py +++ b/apex/validator/miner_sampler.py @@ -2,6 +2,7 @@ import json import random import time +from collections import deque from collections.abc import Coroutine, Sequence from typing import Any, Literal @@ -69,7 +70,8 @@ def __init__( if self._available_uids and self._available_addresses: equal_length = len(self._available_uids) == len(self._available_addresses) assert equal_length, "Test UIDs and addresses must be the same length." - self._remaining_epoch_miners: set[MinerInfo] = set() + self._epoch_deque: deque[MinerInfo] = deque() + self._sample_lock = asyncio.Lock() @async_cache(_TTL_UIDS_RESYNC) async def _get_all_miners(self) -> list[MinerInfo]: @@ -114,12 +116,13 @@ async def _sample_miners(self) -> list[MinerInfo]: miners_sample = random.sample(miners, self._sample_size) elif self._sample_mode == "sequential": - if len(self._remaining_epoch_miners) < self._sample_size: - self._remaining_epoch_miners = set(miners) - logger.debug(f"Starting new miner sampling epoch, miners amount: {len(self._remaining_epoch_miners)}") - indices_sample = sorted(random.sample(range(len(self._remaining_epoch_miners)), self._sample_size)) - miners_sample = [miners[i] for i in indices_sample] - self._remaining_epoch_miners -= set(miners_sample) + async with self._sample_lock: + miners_sample: list[MinerInfo] = [] + while len(miners_sample) < self._sample_size: + if not self._epoch_deque: + # Get shuffled deque of miners. 
+ self._epoch_deque: deque[MinerInfo] = deque(random.sample(miners, len(miners))) + miners_sample.append(self._epoch_deque.popleft()) else: raise ValueError(f"Unknown sampling mode: {self._sample_mode}") diff --git a/tests/validator/test_miner_sampler.py b/tests/validator/test_miner_sampler.py index d2b8c04ba..d184ed41f 100644 --- a/tests/validator/test_miner_sampler.py +++ b/tests/validator/test_miner_sampler.py @@ -166,20 +166,24 @@ async def test_sample_miners_sequential(monkeypatch: MagicMock, miner_sampler: M monkeypatch.setattr(miner_sampler, "_get_all_miners", AsyncMock(return_value=all_miners)) # 1st call in epoch. - with patch("random.sample", return_value=[0, 2]): + with patch( + "random.sample", + return_value=[MinerInfo(uid=1, address="", hotkey="1"), MinerInfo(uid=5, address="", hotkey="5")], + ): miners1 = await miner_sampler._sample_miners() assert len(miners1) == 2 assert {m.uid for m in miners1} == {all_miners[0].uid, all_miners[2].uid} - assert len(miner_sampler._remaining_epoch_miners) == 1 # 2nd call, new epoch starts as remaining (1) < sample_size (2). - with patch("random.sample", return_value=[1, 2]): + with patch( + "random.sample", + return_value=[MinerInfo(uid=3, address="", hotkey="3"), MinerInfo(uid=5, address="", hotkey="5")], + ): miners2 = await miner_sampler._sample_miners() assert len(miners2) == 2 assert {m.uid for m in miners2} == {all_miners[1].uid, all_miners[2].uid} - assert len(miner_sampler._remaining_epoch_miners) == 1 @pytest.mark.asyncio From 784715938445ef12bd8ca9523a8f404ea8c43015 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 17:16:05 +0200 Subject: [PATCH 13/47] Add error handling --- apex/common/async_chain.py | 26 +++++++++++++------------- apex/validator/miner_scorer.py | 5 +++-- apex/validator/pipeline.py | 26 ++++++++++++++++---------- 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/apex/common/async_chain.py b/apex/common/async_chain.py index cd948ad11..e9b91d0b2 100644 --- a/apex/common/async_chain.py +++ b/apex/common/async_chain.py @@ -112,22 +112,22 @@ def network(self) -> list[str]: return self._network async def set_weights(self, rewards: dict[str, float]) -> bool: - metagraph = await self.metagraph() - subtensor = await self.subtensor() - weights: dict[int, float] = {} + try: + metagraph = await self.metagraph() + subtensor = await self.subtensor() + weights: dict[int, float] = {} - for hotkey, reward in rewards.items(): - try: - idx = metagraph.hotkeys.index(hotkey) - except ValueError: - # Hotkey not found in the metagraph (e.g., deregistered). Skip it. - continue + for hotkey, reward in rewards.items(): + try: + idx = metagraph.hotkeys.index(hotkey) + except ValueError: + # Hotkey not found in the metagraph (e.g., deregistered). Skip it. + continue - uid = metagraph.uids[idx] - weights[uid] = reward + uid = metagraph.uids[idx] + weights[uid] = reward - # Set the weights. - try: + # Set the weights. 
result = await subtensor.set_weights( wallet=self._wallet, netuid=self._netuid, diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py index 46a225b78..0828217ff 100644 --- a/apex/validator/miner_scorer.py +++ b/apex/validator/miner_scorer.py @@ -7,6 +7,7 @@ from pathlib import Path import aiosqlite +import numpy as np from loguru import logger from apex.common.async_chain import AsyncChain @@ -32,7 +33,7 @@ async def start_loop(self) -> None: logger.debug("Attempting to set weights") success = await self.set_scores() if success: - logger.info(f"Set weights: {success}") + logger.info("Successfully set weights") else: logger.error("Failed to set weights") @@ -108,7 +109,7 @@ async def set_scores(self) -> bool: record_str: str = json.dumps(record) fh.write(f"{record_str}\n") # TODO: Flush the db only on set_weights_result is True. - logger.debug("Setting weights") + logger.debug(f"Setting weights, mean reward={np.mean(list(hkey_agg_rewards.values())):.4f}") set_weights_result = await self.chain.set_weights(hkey_agg_rewards) # 4. Flush all deletions in a single commit. diff --git a/apex/validator/pipeline.py b/apex/validator/pipeline.py index 023b9fa51..48f8427a8 100644 --- a/apex/validator/pipeline.py +++ b/apex/validator/pipeline.py @@ -81,21 +81,27 @@ async def run_single(self, task: QueryTask) -> str: logger.debug("Generating task query") query = await generate_query(llm=self.llm, websearch=self.websearch) + reference = None + tool_history = [] if random.random() < self.reference_rate: + try: + generator_results = None + ground_truth = 0 + logger.debug(f"Generating task reference for query: {query[:20]}..") + reference, tool_history = await generate_reference(llm=self.deep_research, query=query) + except BaseException as exc: + logger.exception(f"Failed to generate reference: {exc}") + + if reference is None: ground_truth = 1 logger.debug(f"Querying generators with query: {query[:20]}..") generator_results = await self.miner_registry.query_generators(query=query) if random.random() < self.redundancy_rate: - logger.debug(f"Generating redundant task reference for query: {query[:20]}..") - reference, tool_history = await generate_reference(llm=self.deep_research, query=query) - else: - reference = None - tool_history = [] - else: - generator_results = None - ground_truth = 0 - logger.debug(f"Generating task reference for query: {query[:20]}..") - reference, tool_history = await generate_reference(llm=self.deep_research, query=query) + try: + logger.debug(f"Generating redundant task reference for query: {query[:20]}..") + reference, tool_history = await generate_reference(llm=self.deep_research, query=query) + except BaseException as exc: + logger.warning(f"Failed to generate redundant reference: {exc}") discriminator_results = await self.miner_registry.query_discriminators( query=query, generator_results=generator_results, reference=reference, ground_truth=ground_truth From d8e96d6cd3341dc47d4cd7f69087d43530046d69 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 17:20:47 +0200 Subject: [PATCH 14/47] Reduce logs persistency --- apex/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apex/__init__.py b/apex/__init__.py index 680ddbc64..367567422 100644 --- a/apex/__init__.py +++ b/apex/__init__.py @@ -39,7 +39,7 @@ def setup_logger(log_file_path: str | Path | None = None, level: str = "INFO") - # Add file handler if a path is provided. 
if log_file_path: file_log_format = "{time:YYYY-MM-DD HH:mm:ss} [{file}:{line}] {message}" - logger.add(str(log_file_path), level=level, format=file_log_format, rotation="10 MB", retention="7 days") + logger.add(str(log_file_path), level=level, format=file_log_format, rotation="5 MB", retention="3 days") return logger From 86e61c3b1a45f227bb8725d5f53292e89115f0d5 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 19:05:59 +0200 Subject: [PATCH 15/47] Fix set weights result parsing --- apex/common/async_chain.py | 9 ++++----- apex/validator/miner_scorer.py | 3 ++- validator.py | 5 ++++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/apex/common/async_chain.py b/apex/common/async_chain.py index e9b91d0b2..c02db9bbf 100644 --- a/apex/common/async_chain.py +++ b/apex/common/async_chain.py @@ -128,7 +128,7 @@ async def set_weights(self, rewards: dict[str, float]) -> bool: weights[uid] = reward # Set the weights. - result = await subtensor.set_weights( + success, err = await subtensor.set_weights( wallet=self._wallet, netuid=self._netuid, uids=list(weights.keys()), @@ -137,10 +137,9 @@ async def set_weights(self, rewards: dict[str, float]) -> bool: wait_for_inclusion=True, wait_for_finalization=True, ) - if not result: - logger.error(f"Error setting weights: {result}") - return False - return True + if not success: + logger.error(f"Error setting weights: {err}") + return success except BaseException as exc: logger.exception(f"Error setting weights: {exc}") return False diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py index 0828217ff..655a06333 100644 --- a/apex/validator/miner_scorer.py +++ b/apex/validator/miner_scorer.py @@ -109,7 +109,8 @@ async def set_scores(self) -> bool: record_str: str = json.dumps(record) fh.write(f"{record_str}\n") # TODO: Flush the db only on set_weights_result is True. - logger.debug(f"Setting weights, mean reward={np.mean(list(hkey_agg_rewards.values())):.4f}") + rewards_array = np.array(list(hkey_agg_rewards.values())) + logger.debug(f"Setting weights, reward mean={rewards_array.mean():.4f} min={rewards_array.min():.4f}") set_weights_result = await self.chain.set_weights(hkey_agg_rewards) # 4. Flush all deletions in a single commit. 
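For orientation on the weight-synchronization thread of this series (the commented-out WeightSyncer drafted in PATCH 02 above and wired in by PATCH 23 below): the merged weights are a stake-weighted average over the per-validator weight matrix, with a request mask zeroing out validators that never reported. A minimal numpy sketch with made-up numbers, not the actual implementation:

    # Illustrative only: mirrors the commented-out WeightSyncer.get_augmented_weights
    # reduction, i.e. np.average(..., weights=stake * request_tracker).
    import numpy as np

    weight_matrix = np.array(
        [
            [0.2, 0.8],  # weights reported by validator A
            [0.5, 0.5],  # weights reported by validator B
            [0.0, 0.0],  # validator C never responded
        ]
    )
    stake = np.array([1000.0, 500.0, 2000.0])
    received = np.array([1.0, 1.0, 0.0])  # request_tracker mask

    merged = np.average(weight_matrix, axis=0, weights=stake * received)
    # merged == array([0.3, 0.7]); validator C's stake carries no weight here.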
diff --git a/validator.py b/validator.py index 5146ab1a6..b0ae22229 100644 --- a/validator.py +++ b/validator.py @@ -37,7 +37,10 @@ async def main() -> None: chain = AsyncChain(**config.chain.kwargs) await chain.start() - logger.debug(f"Connected to the chain with coldkey '{chain.coldkey[:3]}***', hotkey '{chain.hotkey[:2]}***'") + logger.debug( + f"Connected to the chain netuid={chain.netuid} with coldkey '{chain.coldkey[:2]}***', " + f"hotkey '{chain.hotkey[:2]}***'" + ) logger_db = LoggerDB(**config.logger_db.kwargs) asyncio.create_task(logger_db.start_loop()) From 53a5f29294d66ff87e5b443f426419e27fe89b2e Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 20:29:07 +0200 Subject: [PATCH 16/47] Revert sample size to 50 --- apex/validator/miner_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apex/validator/miner_sampler.py b/apex/validator/miner_sampler.py index 665956479..3b900ec91 100644 --- a/apex/validator/miner_sampler.py +++ b/apex/validator/miner_sampler.py @@ -39,7 +39,7 @@ def __init__( self, chain: AsyncChain, sample_mode: Literal["random", "sequential"] = "sequential", - sample_size: int = 100, + sample_size: int = 50, logger_db: LoggerDB | None = None, available_uids: Sequence[int] | None = None, available_addresses: Sequence[str] | None = None, From 4583804d4c750d5b80ec3453a0db5a6b540e0ab5 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 22:55:06 +0200 Subject: [PATCH 17/47] Fix mypy --- apex/common/async_chain.py | 7 +++---- apex/validator/miner_sampler.py | 8 ++++---- apex/validator/pipeline.py | 2 +- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/apex/common/async_chain.py b/apex/common/async_chain.py index c02db9bbf..dbaa73ade 100644 --- a/apex/common/async_chain.py +++ b/apex/common/async_chain.py @@ -111,7 +111,7 @@ def netuid(self) -> int: def network(self) -> list[str]: return self._network - async def set_weights(self, rewards: dict[str, float]) -> bool: + async def set_weights(self, rewards: dict[str, float]) -> bool: # type: ignore try: metagraph = await self.metagraph() subtensor = await self.subtensor() @@ -127,7 +127,6 @@ async def set_weights(self, rewards: dict[str, float]) -> bool: uid = metagraph.uids[idx] weights[uid] = reward - # Set the weights. 
success, err = await subtensor.set_weights( wallet=self._wallet, netuid=self._netuid, @@ -138,10 +137,10 @@ async def set_weights(self, rewards: dict[str, float]) -> bool: wait_for_finalization=True, ) if not success: - logger.error(f"Error setting weights: {err}") + logger.error(f"Error during weight set: {err}") return success except BaseException as exc: - logger.exception(f"Error setting weights: {exc}") + logger.exception(f"Error during weight set: {exc}") return False async def mask_network(self) -> list[str]: diff --git a/apex/validator/miner_sampler.py b/apex/validator/miner_sampler.py index 3b900ec91..f4948efe2 100644 --- a/apex/validator/miner_sampler.py +++ b/apex/validator/miner_sampler.py @@ -112,16 +112,16 @@ async def _get_all_miners(self) -> list[MinerInfo]: async def _sample_miners(self) -> list[MinerInfo]: miners = await self._get_all_miners() + miners_sample: list[MinerInfo] = [] if self._sample_mode == "random": miners_sample = random.sample(miners, self._sample_size) elif self._sample_mode == "sequential": async with self._sample_lock: - miners_sample: list[MinerInfo] = [] while len(miners_sample) < self._sample_size: if not self._epoch_deque: # Get shuffled deque of miners. - self._epoch_deque: deque[MinerInfo] = deque(random.sample(miners, len(miners))) + self._epoch_deque = deque(random.sample(miners, len(miners))) miners_sample.append(self._epoch_deque.popleft()) else: @@ -224,7 +224,7 @@ async def query_discriminators( choice_content = "None" parsed_discriminator_results.append(choice_content) - # Apply scoring logic based on selected generator type + # Apply scoring logic based on selected generator type. if choice_content == str(ground_truth): discriminator_score = score_per_miner else: @@ -232,7 +232,7 @@ async def query_discriminators( discriminator_results_float.append(discriminator_score) - # Generator result is 1 minus sum of discriminator results + # Generator result is 1 minus sum of discriminator results. generator_result_float = 1.0 - sum(discriminator_results_float) miner_discriminator_results = MinerDiscriminatorResults( query=query, diff --git a/apex/validator/pipeline.py b/apex/validator/pipeline.py index 48f8427a8..26b55ca1c 100644 --- a/apex/validator/pipeline.py +++ b/apex/validator/pipeline.py @@ -82,7 +82,7 @@ async def run_single(self, task: QueryTask) -> str: query = await generate_query(llm=self.llm, websearch=self.websearch) reference = None - tool_history = [] + tool_history: list[dict[str, str]] = [] if random.random() < self.reference_rate: try: generator_results = None From 8fdecf6d6d6673f0a1bca9ab10d867665ad06c3d Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 23:04:45 +0200 Subject: [PATCH 18/47] Fix mypy --- .pre-commit-config.yaml | 19 +++++++++++++++++++ apex/common/async_chain.py | 4 ++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7852ed5cc..e43ae52d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,4 +20,23 @@ repos: hooks: - id: ruff args: [--fix] + stages: [pre-commit] - id: ruff-format + stages: [pre-commit] + +- repo: local + hooks: + - id: mypy + name: mypy + entry: mypy . + language: system + pass_filenames: false + stages: [pre-commit] + + # Run tests on push so local "push" matches CI. 
+ - id: pytest + name: pytest + entry: pytest tests/ --verbose --failed-first --exitfirst --disable-warnings + language: system + pass_filenames: false + stages: [pre-push] diff --git a/apex/common/async_chain.py b/apex/common/async_chain.py index dbaa73ade..a87d96a88 100644 --- a/apex/common/async_chain.py +++ b/apex/common/async_chain.py @@ -111,7 +111,7 @@ def netuid(self) -> int: def network(self) -> list[str]: return self._network - async def set_weights(self, rewards: dict[str, float]) -> bool: # type: ignore + async def set_weights(self, rewards: dict[str, float]) -> bool: try: metagraph = await self.metagraph() subtensor = await self.subtensor() @@ -138,7 +138,7 @@ async def set_weights(self, rewards: dict[str, float]) -> bool: # type: ignore ) if not success: logger.error(f"Error during weight set: {err}") - return success + return bool(success) except BaseException as exc: logger.exception(f"Error during weight set: {exc}") return False From 710e4d1a494426692e027319ff35d22663f46e98 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sat, 9 Aug 2025 23:15:20 +0200 Subject: [PATCH 19/47] Fix tests --- .pre-commit-config.yaml | 3 +-- apex/validator/miner_scorer.py | 7 +++++-- tests/common/mock_async_chain.py | 4 ++-- tests/common/test_async_chain.py | 4 +--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e43ae52d9..08c1e7862 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,10 +33,9 @@ repos: pass_filenames: false stages: [pre-commit] - # Run tests on push so local "push" matches CI. - id: pytest name: pytest entry: pytest tests/ --verbose --failed-first --exitfirst --disable-warnings language: system pass_filenames: false - stages: [pre-push] + stages: [pre-commit] diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py index 655a06333..6539e6689 100644 --- a/apex/validator/miner_scorer.py +++ b/apex/validator/miner_scorer.py @@ -109,8 +109,11 @@ async def set_scores(self) -> bool: record_str: str = json.dumps(record) fh.write(f"{record_str}\n") # TODO: Flush the db only on set_weights_result is True. - rewards_array = np.array(list(hkey_agg_rewards.values())) - logger.debug(f"Setting weights, reward mean={rewards_array.mean():.4f} min={rewards_array.min():.4f}") + if hkey_agg_rewards: + rewards_array = np.array(list(hkey_agg_rewards.values())) + logger.debug(f"Setting weights, reward mean={rewards_array.mean():.4f} min={rewards_array.min():.4f}") + else: + logger.warning(f"Setting empty rewards: {hkey_agg_rewards}") set_weights_result = await self.chain.set_weights(hkey_agg_rewards) # 4. Flush all deletions in a single commit. 
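For reference, the version_key passed to subtensor.set_weights (and asserted against __spec_version__ in the test change just below) is the integer produced by _version_to_int from PATCH 01. A quick sanity check of that arithmetic, restated here as an illustrative snippet:

    # Illustrative only: mirrors _version_to_int from apex/__init__.py (PATCH 01).
    def version_to_int(version_str: str) -> int:
        parts = version_str.split(".") + ["0", "0"]  # pad in case the patch element is missing
        return 10000 * int(parts[0]) + 100 * int(parts[1]) + int(parts[2])

    assert version_to_int("3.0.1") == 30001  # version after the bump in PATCH 07
    assert version_to_int("3.0") == 30000    # missing patch element falls back to 0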
diff --git a/tests/common/mock_async_chain.py b/tests/common/mock_async_chain.py index 7bbbb0980..1c0312fdd 100644 --- a/tests/common/mock_async_chain.py +++ b/tests/common/mock_async_chain.py @@ -57,7 +57,7 @@ async def set_weights( version_key: int, wait_for_inclusion: bool, wait_for_finalization: bool, - ) -> bool: + ) -> tuple[bool, str | None]: self.last_set_weights = { "wallet": wallet, "netuid": netuid, @@ -67,7 +67,7 @@ async def set_weights( "wait_for_inclusion": wait_for_inclusion, "wait_for_finalization": wait_for_finalization, } - return self.weights_result + return self.weights_result, "" def patch_wallet(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/tests/common/test_async_chain.py b/tests/common/test_async_chain.py index a16d88ad2..85092880c 100644 --- a/tests/common/test_async_chain.py +++ b/tests/common/test_async_chain.py @@ -2,6 +2,7 @@ import pytest +from apex import __spec_version__ from apex.common.async_chain import AsyncChain # noqa: E402 from tests.common.mock_async_chain import DummyMetagraph, DummySubtensor, patch_subtensor, patch_wallet @@ -121,9 +122,6 @@ async def test_set_weights_happy_path(monkeypatch): assert stub.last_set_weights is not None assert stub.last_set_weights["uids"] == [2] assert stub.last_set_weights["weights"] == [0.7] - # ensure we pass spec version as version_key - from apex import __spec_version__ - assert stub.last_set_weights["version_key"] == __spec_version__ From de71240f7debcaab025d93eb3ab07612190240ab Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Sun, 10 Aug 2025 15:23:36 +0100 Subject: [PATCH 20/47] Add Autoupdater --- apex/validator/auto_update.py | 87 +++++++++++++++++++++++++++++++++++ validator.py | 21 +++++++++ 2 files changed, 108 insertions(+) create mode 100644 apex/validator/auto_update.py diff --git a/apex/validator/auto_update.py b/apex/validator/auto_update.py new file mode 100644 index 000000000..3ad374831 --- /dev/null +++ b/apex/validator/auto_update.py @@ -0,0 +1,87 @@ +import asyncio +import subprocess +import sys +import time +from loguru import logger +from pathlib import Path +from shlex import split + +ROOT_DIR = Path(__file__).parent.parent + + +def get_version() -> str: + """Extract the version as current git commit hash""" + result = subprocess.run( + split("git rev-parse HEAD"), + check=True, + capture_output=True, + cwd=ROOT_DIR, + ) + commit = result.stdout.decode().strip() + assert len(commit) == 40, f"Invalid commit hash: {commit}" + return commit[:8] + + +def pull_latest_version() -> None: + """ + Pull the latest version from git. + This uses `git pull --rebase`, so if any changes were made to the local repository, + this will try to apply them on top of origin's changes. This is intentional, as we + don't want to overwrite any local changes. However, if there are any conflicts, + this will abort the rebase and return to the original state. + The conflicts are expected to happen rarely since validator is expected + to be used as-is. + """ + try: + subprocess.run(split("git pull --rebase --autostash"), check=True, cwd=ROOT_DIR) + except subprocess.CalledProcessError as exc: + logger.error("Failed to pull, reverting: %s", exc) + subprocess.run(split("git rebase --abort"), check=True, cwd=ROOT_DIR) + + +def upgrade_packages() -> None: + """ + Upgrade python packages by running `pip install --upgrade -r requirements.txt`. + Notice: this won't work if some package in `requirements.txt` is downgraded. + Ignored as this is unlikely to happen. 
+ """ + + logger.info("Upgrading packages") + try: + subprocess.run( + split(f"{sys.executable} -m pip install -e ."), + check=True, + cwd=ROOT_DIR, + ) + except subprocess.CalledProcessError as exc: + logger.error("Failed to upgrade packages, proceeding anyway. %s", exc) + + +async def autoupdate_loop() -> None: + """ + Async version of autoupdate that runs alongside the validator. + Checks for updates every hour and applies them if available. + """ + current_version = latest_version = get_version() + logger.info("Current version: %s", current_version) + + try: + while True: + await asyncio.sleep(3600) # Wait 1 hour between checks + + pull_latest_version() + latest_version = get_version() + logger.info("Latest version: %s", latest_version) + + if latest_version != current_version: + logger.info( + "Upgraded to latest version: %s -> %s", + current_version, + latest_version, + ) + upgrade_packages() + current_version = latest_version + + except asyncio.CancelledError: + logger.info("Autoupdate task cancelled") + raise diff --git a/validator.py b/validator.py index 754afa9a0..e92ce7e13 100644 --- a/validator.py +++ b/validator.py @@ -13,6 +13,7 @@ from apex.validator.miner_sampler import MinerSampler from apex.validator.miner_scorer import MinerScorer from apex.validator.pipeline import Pipeline +from apex.validator.auto_update import autoupdate_loop async def read_args() -> argparse.Namespace: @@ -25,6 +26,12 @@ async def read_args() -> argparse.Namespace: help="Config file path (e.g. config/mainnet.yaml).", type=Path, ) + parser.add_argument( + "--no-autoupdate", + action="store_true", + default=False, + help="Disable automatic updates (checks every hour) (default: enabled)", + ) args = parser.parse_args() return args @@ -48,6 +55,12 @@ async def main() -> None: miner_scorer = MinerScorer(chain=chain, **config.miner_scorer.kwargs) asyncio.create_task(miner_scorer.start_loop()) + # Start autoupdate task if enabled + autoupdate_task = None + if not args.no_autoupdate: + logger.info("Autoupdate enabled - will check for updates every hour") + autoupdate_task = asyncio.create_task(autoupdate_loop()) + llm = LLM(**config.llm.kwargs) deep_research = DeepResearchLangchain(websearch=websearch, **config.deep_research.kwargs) @@ -68,6 +81,14 @@ async def main() -> None: except BaseException as exc: logger.exception(f"Unknown exception caught, exiting validator: {exc}") finally: + # Cancel autoupdate task if it was started + if autoupdate_task is not None: + autoupdate_task.cancel() + try: + await autoupdate_task + except asyncio.CancelledError: + pass + await chain.shutdown() await logger_db.shutdown() await miner_scorer.shutdown() From 9bbdc3f5882791b00d8476d99192eb939002a64f Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Sun, 10 Aug 2025 15:25:05 +0100 Subject: [PATCH 21/47] Refactor autoupdate module: reorder imports and improve docstring formatting --- apex/validator/auto_update.py | 14 +++++--------- validator.py | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/apex/validator/auto_update.py b/apex/validator/auto_update.py index 3ad374831..767ab5856 100644 --- a/apex/validator/auto_update.py +++ b/apex/validator/auto_update.py @@ -1,11 +1,11 @@ import asyncio import subprocess import sys -import time -from loguru import logger from pathlib import Path from shlex import split +from loguru import logger + ROOT_DIR = Path(__file__).parent.parent @@ -23,8 +23,7 @@ def get_version() -> str: def pull_latest_version() -> None: - """ - Pull the latest version from git. 
+ """Pull the latest version from git. This uses `git pull --rebase`, so if any changes were made to the local repository, this will try to apply them on top of origin's changes. This is intentional, as we don't want to overwrite any local changes. However, if there are any conflicts, @@ -40,12 +39,10 @@ def pull_latest_version() -> None: def upgrade_packages() -> None: - """ - Upgrade python packages by running `pip install --upgrade -r requirements.txt`. + """Upgrade python packages by running `pip install --upgrade -r requirements.txt`. Notice: this won't work if some package in `requirements.txt` is downgraded. Ignored as this is unlikely to happen. """ - logger.info("Upgrading packages") try: subprocess.run( @@ -58,8 +55,7 @@ def upgrade_packages() -> None: async def autoupdate_loop() -> None: - """ - Async version of autoupdate that runs alongside the validator. + """Async version of autoupdate that runs alongside the validator. Checks for updates every hour and applies them if available. """ current_version = latest_version = get_version() diff --git a/validator.py b/validator.py index e92ce7e13..53653ae72 100644 --- a/validator.py +++ b/validator.py @@ -9,11 +9,11 @@ from apex.services.deep_research.deep_research_langchain import DeepResearchLangchain from apex.services.llm.llm import LLM from apex.services.websearch.websearch_tavily import WebSearchTavily +from apex.validator.auto_update import autoupdate_loop from apex.validator.logger_db import LoggerDB from apex.validator.miner_sampler import MinerSampler from apex.validator.miner_scorer import MinerScorer from apex.validator.pipeline import Pipeline -from apex.validator.auto_update import autoupdate_loop async def read_args() -> argparse.Namespace: From ce3401996711b6df172c03ac1c799bcd9f1ba997 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Sun, 10 Aug 2025 15:43:01 +0100 Subject: [PATCH 22/47] Enhance docstrings and improve autoupdate termination handling in auto_update.py --- apex/validator/auto_update.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/apex/validator/auto_update.py b/apex/validator/auto_update.py index 767ab5856..ab28b551f 100644 --- a/apex/validator/auto_update.py +++ b/apex/validator/auto_update.py @@ -10,7 +10,7 @@ def get_version() -> str: - """Extract the version as current git commit hash""" + """Extract the version as current git commit hash.""" result = subprocess.run( split("git rev-parse HEAD"), check=True, @@ -24,6 +24,7 @@ def get_version() -> str: def pull_latest_version() -> None: """Pull the latest version from git. + This uses `git pull --rebase`, so if any changes were made to the local repository, this will try to apply them on top of origin's changes. This is intentional, as we don't want to overwrite any local changes. However, if there are any conflicts, @@ -40,6 +41,7 @@ def pull_latest_version() -> None: def upgrade_packages() -> None: """Upgrade python packages by running `pip install --upgrade -r requirements.txt`. + Notice: this won't work if some package in `requirements.txt` is downgraded. Ignored as this is unlikely to happen. """ @@ -56,6 +58,7 @@ def upgrade_packages() -> None: async def autoupdate_loop() -> None: """Async version of autoupdate that runs alongside the validator. + Checks for updates every hour and applies them if available. 
""" current_version = latest_version = get_version() @@ -71,12 +74,13 @@ async def autoupdate_loop() -> None: if latest_version != current_version: logger.info( - "Upgraded to latest version: %s -> %s", + "Upgraded to latest version: %s -> %s, terminating program for restart", current_version, latest_version, ) upgrade_packages() - current_version = latest_version + logger.info("Program terminating, please run with persistent process manager") + sys.exit(0) except asyncio.CancelledError: logger.info("Autoupdate task cancelled") From a354d703e0b8f4be81abc3850ca9a20439d4f5cf Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sun, 10 Aug 2025 17:48:03 +0200 Subject: [PATCH 23/47] Add weight synchronizer --- apex/common/async_chain.py | 1 - apex/common/config.py | 1 + apex/common/constants.py | 8 + apex/common/epistula.py | 143 ++++++----- apex/validator/miner_scorer.py | 26 +- apex/validator/weight_syncer.py | 334 ++++++++++++++++---------- pyproject.toml | 1 + tests/validator/test_weight_syncer.py | 141 +++++++++++ uv.lock | 13 +- validator.py | 8 + 10 files changed, 475 insertions(+), 201 deletions(-) create mode 100644 tests/validator/test_weight_syncer.py diff --git a/apex/common/async_chain.py b/apex/common/async_chain.py index a87d96a88..d4e8a2a61 100644 --- a/apex/common/async_chain.py +++ b/apex/common/async_chain.py @@ -16,7 +16,6 @@ def __init__(self, coldkey: str, hotkey: str, netuid: int, network: list[str] | if isinstance(network, str): network = [network] self._network: list[str] = network - self._coldkey = coldkey self._hotkey = hotkey self._netuid = netuid diff --git a/apex/common/config.py b/apex/common/config.py index fc3fff308..9cfcad02b 100644 --- a/apex/common/config.py +++ b/apex/common/config.py @@ -15,6 +15,7 @@ class Config(BaseModel): chain: ConfigClass = Field(default_factory=ConfigClass) websearch: ConfigClass = Field(default_factory=ConfigClass) logger_db: ConfigClass = Field(default_factory=ConfigClass) + weight_syncer: ConfigClass = Field(default_factory=ConfigClass) miner_sampler: ConfigClass = Field(default_factory=ConfigClass) miner_scorer: ConfigClass = Field(default_factory=ConfigClass) llm: ConfigClass = Field(default_factory=ConfigClass) diff --git a/apex/common/constants.py b/apex/common/constants.py index 6315e4485..9f7698ab4 100644 --- a/apex/common/constants.py +++ b/apex/common/constants.py @@ -2,6 +2,14 @@ TEMPERATURE: float = 0.1 WEBPAGE_MAXSIZE: int = 500 VALIDATOR_REFERENCE_LABEL = "Validator" +VALIDATOR_VERIFIED_HOTKEYS = { + "5CGLCBndTR1BvQZzn429ckT8GyxduzyjMgt4K1UVTYa8gKfb": "167.99.236.79:8001", # Macrocosmos. + "5CUbyC2Ez7tWYYmnFSSwjqkw26dFNo9cXH8YmcxBSfxi2XSG": None, # Yuma. + "5C8Em1kDZi5rxgDN4zZtfoT7dUqJ7FFbTzS3yTP5GPgVUsn1": None, # RoundTable21. + "5HmkM6X1D3W3CuCSPuHhrbYyZNBy2aGAiZy9NczoJmtY25H7": None, # Crucible. + "5GeR3cDuuFKJ7p66wKGjY65MWjWnYqffq571ZMV4gKMnJqK5": None, # OTF. + "5D1saVvssckE1XoPwPzdHrqYZtvBJ3vESsrPNxZ4zAxbKGs1": None, # Rizzo. +} _ENGLISH_WORDS: tuple[str, ...] 
| None = None diff --git a/apex/common/epistula.py b/apex/common/epistula.py index 67b16d32b..95e7449d4 100644 --- a/apex/common/epistula.py +++ b/apex/common/epistula.py @@ -1,14 +1,19 @@ -# import json +import json import time from hashlib import sha256 from math import ceil -from typing import Any +from typing import Annotated, Any from uuid import uuid4 # from fastapi import HTTPException, Request # from loguru import logger +from fastapi import HTTPException, Request +from loguru import logger from substrateinterface import Keypair +from apex.common.async_chain import AsyncChain +from apex.common.constants import VALIDATOR_VERIFIED_HOTKEYS + async def generate_header( hotkey: Keypair, @@ -38,66 +43,74 @@ async def generate_header( return headers -# def verify_signature( -# signature: str, body: bytes, timestamp: int, uuid: str, signed_for: str, signed_by: str, now: float -# ) -> Annotated[str, "Error Message"] | None: -# if not isinstance(signature, str): -# return "Invalid Signature" -# timestamp = int(timestamp) -# if not isinstance(timestamp, int): -# return "Invalid Timestamp" -# if not isinstance(signed_by, str): -# return "Invalid Sender key" -# if not isinstance(signed_for, str): -# return "Invalid receiver key" -# if not isinstance(uuid, str): -# return "Invalid uuid" -# if not isinstance(body, bytes): -# return "Body is not of type bytes" -# allowed_delta_ms = 8000 -# keypair = Keypair(ss58_address=signed_by) -# if timestamp + allowed_delta_ms < now: -# return "Request is too stale" -# message = f"{sha256(body).hexdigest()}.{uuid}.{timestamp}.{signed_for}" -# verified = keypair.verify(message, signature) -# if not verified: -# return "Signature Mismatch" -# return None -# -# -# async def verify_weight_signature(request: Request): -# signed_by = request.headers.get("Epistula-Signed-By") -# signed_for = request.headers.get("Epistula-Signed-For") -# if not signed_by or not signed_for: -# raise HTTPException(400, "Missing Epistula-Signed-* headers") -# -# if signed_for != shared_settings.WALLET.hotkey.ss58_address: -# logger.error("Bad Request, message is not intended for self") -# raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self") -# validator_hotkeys = [shared_settings.METAGRAPH.hotkeys[uid] for uid in WHITELISTED_VALIDATORS_UIDS] -# if signed_by not in validator_hotkeys: -# logger.error(f"Signer not the expected ss58 address: {signed_by}") -# raise HTTPException(status_code=401, detail="Signer not the expected ss58 address") -# -# now = time.time() -# body: bytes = await request.body() -# try: -# payload = json.loads(body) -# except json.JSONDecodeError: -# raise HTTPException(400, "Invalid JSON body") -# -# if payload.get("uid") != get_uid_from_hotkey(signed_by): -# raise HTTPException(400, "Invalid uid in body") -# -# err = verify_signature( -# request.headers.get("Epistula-Request-Signature"), -# body, -# request.headers.get("Epistula-Timestamp"), -# request.headers.get("Epistula-Uuid"), -# signed_for, -# signed_by, -# now, -# ) -# if err: -# logger.error(err) -# raise HTTPException(status_code=400, detail=err) +def verify_signature( + signature: str | None, + body: bytes, + timestamp: str | None, + uuid: str | None, + signed_for: str, + signed_by: str, + now: float, +) -> Annotated[str, "Error Message"] | None: + if not isinstance(signature, str): + return "Invalid Signature" + if not isinstance(timestamp, str) or not timestamp.isdigit(): + return "Invalid Timestamp" + timestamp_as_int = int(timestamp) + if not 
isinstance(signed_by, str): + return "Invalid Sender key" + if not isinstance(signed_for, str): + return "Invalid receiver key" + if not isinstance(uuid, str): + return "Invalid uuid" + if not isinstance(body, bytes): + return "Body is not of type bytes" + allowed_delta_ms = 8000 + keypair = Keypair(ss58_address=signed_by) + if timestamp_as_int + allowed_delta_ms < now: + return "Request is too stale" + message = f"{sha256(body).hexdigest()}.{uuid}.{timestamp}.{signed_for}" + verified = keypair.verify(message, signature) + if not verified: + return "Signature Mismatch" + return None + + +async def verify_weight_signature(request: Request, chain: AsyncChain) -> None: + signed_by = request.headers.get("Epistula-Signed-By") + signed_for = request.headers.get("Epistula-Signed-For") + if not signed_by or not signed_for: + logger.error("Missing Epistula-Signed-* headers") + raise HTTPException(400, "Missing Epistula-Signed-* headers") + + wallet = chain.wallet + if signed_for != wallet.hotkey.ss58_address: + logger.error("Bad Request, message is not intended for self") + raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self") + + if signed_by not in VALIDATOR_VERIFIED_HOTKEYS: + logger.error(f"Signer not the expected ss58 address: {signed_by}") + raise HTTPException(status_code=401, detail="Signer not the expected ss58 address") + + now = time.time() + body: bytes = await request.body() + try: + json.loads(body) + except json.JSONDecodeError as e: + raise HTTPException(400, "Invalid JSON body") from e + + # if payload.get("uid") != get_uid_from_hotkey(signed_by): + # raise HTTPException(400, "Invalid uid in body") + + err = verify_signature( + request.headers.get("Epistula-Request-Signature"), + body, + request.headers.get("Epistula-Timestamp"), + request.headers.get("Epistula-Uuid"), + signed_for, + signed_by, + now, + ) + if err: + logger.error(err) + raise HTTPException(status_code=400, detail=err) diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py index 6539e6689..c4f97a8b9 100644 --- a/apex/validator/miner_scorer.py +++ b/apex/validator/miner_scorer.py @@ -11,7 +11,8 @@ from loguru import logger from apex.common.async_chain import AsyncChain -from apex.common.constants import VALIDATOR_REFERENCE_LABEL +from apex.common.constants import VALIDATOR_REFERENCE_LABEL, VALIDATOR_VERIFIED_HOTKEYS +from apex.validator.weight_syncer import WeightSyncer # Scoring moving average in hours. Set to be: immunity_period - post_reg_threshold. SCORE_MA_WINDOW_HOURS = 23.75 @@ -19,12 +20,21 @@ class MinerScorer: - def __init__(self, chain: AsyncChain, interval: float = SCORE_INTERVAL_DEFAULT, debug: bool = False): + def __init__( + self, + chain: AsyncChain, + interval: float = SCORE_INTERVAL_DEFAULT, + debug: bool = False, + enable_weight_sync: bool = True, + ): self.chain = chain self.interval = interval - self._running = True self._debug = debug + self._weight_syncer: WeightSyncer | None = None + if enable_weight_sync: + self._weight_syncer = WeightSyncer(chain=chain, verified_hotkeys=VALIDATOR_VERIFIED_HOTKEYS) self._debug_rewards_path = Path("debug_rewards.jsonl") + self._running = True async def start_loop(self) -> None: self._running = True @@ -108,12 +118,20 @@ async def set_scores(self) -> bool: with self._debug_rewards_path.open("a+") as fh: record_str: str = json.dumps(record) fh.write(f"{record_str}\n") - # TODO: Flush the db only on set_weights_result is True. 
+ + if self._weight_syncer is not None: + try: + hkey_agg_rewards = await self._weight_syncer.compute_weighted_rewards(hkey_agg_rewards) + except BaseException as exc: + logger.error(f"Failed to compute weighted average rewards over the network, skipping: {exc}") + if hkey_agg_rewards: rewards_array = np.array(list(hkey_agg_rewards.values())) logger.debug(f"Setting weights, reward mean={rewards_array.mean():.4f} min={rewards_array.min():.4f}") else: logger.warning(f"Setting empty rewards: {hkey_agg_rewards}") + + # TODO: Flush the db only on set_weights_result is True. set_weights_result = await self.chain.set_weights(hkey_agg_rewards) # 4. Flush all deletions in a single commit. diff --git a/apex/validator/weight_syncer.py b/apex/validator/weight_syncer.py index a9a5576c7..37fc858bb 100644 --- a/apex/validator/weight_syncer.py +++ b/apex/validator/weight_syncer.py @@ -1,130 +1,204 @@ -# import asyncio -# import json -# -# import aiohttp -# import httpx -# import numpy as np -# from loguru import logger -# -# from apex.common.async_chain import AsyncChain -# from apex.common.epistula import generate_header, verify_weight_signature - - -# class WeightSyncer: -# def __init__(self, chain: AsyncChain, weight_dict: dict[int, list[float]]): -# self.chain = chain -# self.wallet = self.chain.wallet -# self.current_hotkey = self.wallet.hotkey.ss58_address -# self.latest_weights: str = {} -# metagraph = await self.chain.metagraph -# self.uid = metagraph.hotkeys.index(self.current_hotkey) -# self.validator_uids = np.where(np.array(metagraph.validator_permit))[0].tolist() -# -# self.weight_matrix = np.zeros((len(self.validator_uids), metagraph.n.item())) -# self.stake_matrix = np.array([metagraph.S[uid] for uid in self.validator_uids]) -# -# self.validator_hotkeys = np.array([metagraph.hotkeys[uid] for uid in self.validator_uids]) -# self.validator_addresses = np.array( -# [ -# f"{metagraph.axons[uid].ip}:{metagraph.axons[uid].port}" -# for uid in self.validator_uids -# if uid < metagraph.n.item() -# ] -# ) -# -# self.weight_dict = weight_dict -# -# self.request_tracker = np.zeros(len(self.validator_uids)) -# -# @router.post("/receive_weight_matrix") -# async def receive_weight_matrix( -# request: Request, -# verification_data: dict = Depends(verify_weight_signature), -# weight_dict=Depends(get_weight_dict), -# ): -# """Endpoint to receive weight matrix updates from validators.""" -# await verify_weight_signature(request) -# -# body = await request.json() -# if not isinstance(body, dict) or "weights" not in body: -# raise HTTPException(status_code=400, detail="Invalid request body format") -# -# try: -# uid = body["uid"] -# weights = list(body["weights"]) -# weight_dict[uid] = weights -# return {"status": "success", "message": "Weight matrix updated successfully"} -# except Exception as e: -# logger.error(f"Error processing weight matrix: {e}") -# raise HTTPException(status_code=500, detail="Error processing weight matrix") -# -# async def send_rewards(self, rewards: dict[str, float], validator_address: str, validator_hotkey: str): -# try: -# async with aiohttp.ClientSession() as session: -# headers = await generate_header( -# self.chain.wallet.hotkey, body=json.dumps(body).encode("utf-8"), signed_for=validator_hotkey -# ) -# async with session.post( -# endpoint + "/v1/chat/completions", -# headers=headers, -# json=body, -# ) as resp: -# result = await resp.text() -# except BaseException: -# # Error during miner query, return empty string. 
-# return "" -# return str(result) -# -# try: -# vali_url = f"http://{validator_address}/receive_weight_matrix" -# timeout = httpx.Timeout(timeout=40.0) -# async with httpx.AsyncClient( -# timeout=timeout, -# event_hooks={"request": [create_header_hook(self.wallet.hotkey, validator_hotkey)]}, -# ) as client: -# response = await client.post( -# url=vali_url, -# json={"weights": weight_matrix.tolist(), "uid": self.uid}, -# headers={"Content-Type": "application/json"}, -# ) -# if response.status_code != 200: -# raise Exception( -# f"Status code {response.status_code} response for validator {validator_hotkey} - {vali_url}: " -# f"{response.status_code} for uids {len(weight_matrix)}" -# ) -# logger.debug(f"Successfully forwarded response to uid {validator_hotkey} - {vali_url}") -# except httpx.ConnectError as e: -# logger.warning( -# f"Couldn't connect to validator {validator_hotkey} {vali_url} for weight setting. Exception: {e}" -# ) -# except Exception as e: -# logger.warning( -# f"Error while forwarding weight matrix to validator {validator_hotkey} {vali_url}. Exception: {e}" -# ) -# -# async def get_augmented_weights(self, weights: np.ndarray, uid: int) -> np.ndarray: -# """Get the augmented weights for the given uid, sends the weights to the validators.""" -# await self.send_weight_matrixes(weights) -# -# await self.process_weight_dict() -# -# return np.average(self.weight_matrix, axis=0, weights=self.stake_matrix * self.request_tracker) -# -# async def send_weight_matrixes(self, weight_matrix: np.ndarray): -# tasks = [ -# self.send_weights(weight_matrix, validator_address, validator_hotkey) -# for validator_address, validator_hotkey in zip( -# self.validator_addresses, self.validator_hotkeys, strict=False -# ) -# ] -# -# await asyncio.gather(*tasks) -# -# async def process_weight_dict(self): -# for uid, weights in self.weight_dict.items(): -# if uid in self.validator_uids: -# validator_index = self.validator_uids.index(uid) -# self.weight_matrix[validator_index] = weights -# self.request_tracker[validator_index] = 1 -# else: -# logger.warning(f"UID {uid} is not a validator, skipping") +import asyncio +import time +from typing import cast + +import httpx +import netaddr +import requests +import uvicorn +from bittensor.core.async_subtensor import AsyncMetagraph +from bittensor.core.extrinsics.asyncex.serving import serve_extrinsic +from fastapi import APIRouter, FastAPI, HTTPException, Request +from loguru import logger +from pydantic import BaseModel + +from apex.common.async_chain import AsyncChain +from apex.common.constants import VALIDATOR_VERIFIED_HOTKEYS +from apex.common.epistula import generate_header, verify_weight_signature + + +class ValidatorInfo(BaseModel): + uid: int + hotkey: str + address: str + stake: float + + +class WeightSyncer: + REWARD_EXPIRATION_SEC: float = 60 * 60 + + def __init__( + self, + chain: AsyncChain, + min_alpha_stake: float = 100_000, + verified_hotkeys: dict[str, str | None] | None = None, + enable_receive: bool = True, + enable_send: bool = True, + port: int = 8001, + ) -> None: + """Validator weight synchronizer.""" + self.chain = chain + self._min_alpha_stake = min_alpha_stake + self.verified_hotkeys = verified_hotkeys or VALIDATOR_VERIFIED_HOTKEYS + self.wallet = self.chain.wallet + self.current_hotkey = self.wallet.hotkey.ss58_address + self.receive_enabled = enable_receive + self.send_enabled = enable_send + self.port = port + self.server: uvicorn.Server | None = None + self.server_task: asyncio.Task[None] | None = None + self.hotkey_rewards: 
dict[str, float] | None = None + self._last_update_time: float = 0 + + async def start(self) -> None: + if not self.send_enabled: + logger.warning("Weight synchronization API is disabled for incoming reward requests") + return + + try: + # Setup and run the API server in the background. + app = FastAPI() + app.include_router(self.get_router()) + + # Running uvicorn in a background task. + config = uvicorn.Config(app=app, host="0.0.0.0", port=self.port, log_level="info") + self.server = uvicorn.Server(config) + self.server_task = asyncio.create_task(self.server.serve()) + + logger.info(f"Started weight synchronization API on port {self.port}.") + + # Announce the axon on the network. + external_ip = requests.get("https://checkip.amazonaws.com").text.strip() + netaddr.IPAddress(external_ip) + sub = await self.chain.subtensor() + serve_success = await serve_extrinsic( + subtensor=sub, + wallet=self.chain.wallet, + ip=external_ip, + port=self.port, + protocol=4, + netuid=self.chain.netuid, + ) + if serve_success: + logger.success(f"Serving weight syncer axon on subtensor at {external_ip}:{self.port}") + else: + logger.error("Failed to serve weight syncer axon on subtensor") + except BaseException as e: + logger.warning(f"Failed to announce weight syncer axon on subtensor: {e}") + + async def shutdown(self) -> None: + if self.server is not None: + self.server.should_exit = True + if self.server_task is not None: + await self.server_task + + def get_router(self) -> APIRouter: + """Creates and returns a FastAPI router with the endpoints for this class.""" + router = APIRouter() + + @router.post("/v1/get_rewards") + async def get_rewards_endpoint(request: Request) -> dict[str, float]: + """FastAPI endpoint to get the rewards of this validator.""" + await verify_weight_signature(request=request, chain=self.chain) + + if time.time() - self._last_update_time > self.REWARD_EXPIRATION_SEC or self.hotkey_rewards is None: + raise HTTPException(status_code=404, detail="Rewards not available or expired") + if not self.send_enabled: + raise HTTPException(status_code=404, detail="API is disabled") + return self.hotkey_rewards + + return router + + async def compute_weighted_rewards(self, hotkey_rewards: dict[str, float]) -> dict[str, float]: + """Computes weighted rewards by fetching rewards from other validators and averaging them by stake.""" + self.hotkey_rewards = hotkey_rewards + self._last_update_time = time.time() + if not self.receive_enabled: + logger.warning("Rewards weight averaging is disable, using raw rewards") + return hotkey_rewards + + metagraph = await self.chain.metagraph() + + try: + own_uid = metagraph.hotkeys.index(self.current_hotkey) + except ValueError: + logger.error(f"Could not find own hotkey {self.current_hotkey} in metagraph, returning raw rewards") + return hotkey_rewards + + validator_uids: list[int] = [] + for uid in metagraph.uids: + if uid == own_uid: + continue + + stake = metagraph.stake[uid] + hotkey = metagraph.hotkeys[uid] + is_verified = hotkey in self.verified_hotkeys + is_validator = metagraph.validator_permit[uid] + + if stake >= self._min_alpha_stake or is_verified or is_validator: + validator_uids.append(uid) + + validator_rewards_tasks: dict[int, asyncio.Task[dict[str, float]]] = { + uid: asyncio.create_task(self.receive_rewards(metagraph, uid)) for uid in validator_uids + } + + results = await asyncio.gather(*validator_rewards_tasks.values(), return_exceptions=True) + + validator_rewards: dict[int, dict[str, float]] = {} + for uid, result in zip(validator_uids, 
results, strict=True): + if isinstance(result, BaseException): + logger.warning(f"Cannot receive rewards from uid {uid}: {result}") + continue + validator_rewards[uid] = result + + all_validator_uids = [own_uid] + list(validator_rewards.keys()) + total_stake = sum(metagraph.stake[uid] for uid in all_validator_uids) + + if total_stake == 0: + logger.warning("Total stake of responding validators is zero, returning original rewards") + return hotkey_rewards + + own_stake = metagraph.stake[own_uid] + + weighted_rewards: dict[str, float] = {} + for miner_hkey in hotkey_rewards: + own_reward = hotkey_rewards.get(miner_hkey, 0.0) + total_weighted_reward = own_reward * own_stake + + for uid, rewards in validator_rewards.items(): + validator_reward = rewards.get(miner_hkey, 0.0) + validator_stake = metagraph.stake[uid].item() + total_weighted_reward += validator_reward * validator_stake + + weighted_rewards[miner_hkey] = total_weighted_reward / total_stake + + logger.debug( + f"Averaged rewards over {len(all_validator_uids)} validators. " + f"Self stake percentage: {100 * own_stake / total_stake:.2f}" + ) + return weighted_rewards + + async def receive_rewards(self, metagraph: AsyncMetagraph, uid: int) -> dict[str, float]: + """Receive rewards from the given validator uid.""" + try: + target_hotkey = metagraph.hotkeys[uid] + if (address := VALIDATOR_VERIFIED_HOTKEYS.get(target_hotkey)) is None: + axon = metagraph.axons[uid] + address = f"{axon.ip}:{axon.port}" + async with httpx.AsyncClient() as client: + headers = await generate_header( + self.chain.wallet.hotkey, + body=b"", + signed_for=target_hotkey, + ) + resp = await client.post( + f"http://{address}/v1/get_rewards", + headers=headers, + content=b"", + ) + resp.raise_for_status() + return cast(dict[str, float], resp.json()) + + except BaseException as exc: + logger.warning(f"Cannot receive rewards from uid {uid}: {exc}") + return {} diff --git a/pyproject.toml b/pyproject.toml index 31ed9eca4..9e9296c66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "bittensor>=9.7.0", "rouge>=1.0.1", "substrate-interface>=1.7.11", + "types-netaddr>=1.3.0.20240530", ] diff --git a/tests/validator/test_weight_syncer.py b/tests/validator/test_weight_syncer.py new file mode 100644 index 000000000..e5269d7af --- /dev/null +++ b/tests/validator/test_weight_syncer.py @@ -0,0 +1,141 @@ +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from apex.validator.weight_syncer import WeightSyncer +from tests.common.mock_async_chain import DummyMetagraph + + +class UidList(list): + """A list that has a .tolist() method for compatibility with torch tensors.""" + + def tolist(self): + return self + + +@pytest.fixture +def mock_axon(): + """Returns a function to create a mock axon.""" + + def _mock_axon(ip, port, is_serving=True): + axon = MagicMock() + axon.ip = ip + axon.port = port + axon.is_serving = is_serving + return axon + + return _mock_axon + + +@pytest.fixture +def mock_metagraph(mock_axon): + """Returns a mock metagraph based on DummyMetagraph.""" + metagraph = DummyMetagraph( + hotkeys=["hotkey0_self", "hotkey1_validator", "hotkey2_validator"], + ) + # Overwrite uids with our special UidList to provide .tolist() + metagraph.uids = UidList([0, 1, 2]) + metagraph.stake = [np.float32(1000.0), np.float32(2000.0), np.float32(500.0)] + metagraph.validator_permit = [True, True, False] + metagraph.axons = [ + 
mock_axon("1.1.1.1", 8000), + mock_axon("2.2.2.2", 8001), + mock_axon("3.3.3.3", 8002), + ] + return metagraph + + +@pytest.fixture +def mock_chain(mock_metagraph): + """Returns a mock chain with a mock metagraph.""" + chain = MagicMock() + chain.wallet.hotkey.ss58_address = "hotkey0_self" + chain.metagraph = AsyncMock(return_value=mock_metagraph) + return chain + + +@pytest.fixture +def weight_syncer(mock_chain): + """Returns a WeightSyncer instance with a mock chain.""" + return WeightSyncer(chain=mock_chain, min_alpha_stake=1000) + + +@pytest.mark.asyncio +async def test_compute_weighted_rewards_happy_path(weight_syncer, mock_metagraph): + """Test that weighted rewards are computed correctly in the ideal case.""" + local_rewards = {"miner1": 0.9, "miner2": 0.1} + validator1_rewards = {"miner1": 0.85, "miner2": 0.82, "miner3": 0.7} + + with patch.object(weight_syncer, "receive_rewards", new_callable=AsyncMock) as mock_receive: + mock_receive.side_effect = [validator1_rewards, {}] # UID 2 has low stake + + weighted_rewards = await weight_syncer.compute_weighted_rewards(local_rewards) + + # self (1000) + validator1 (2000) = 3000 total stake + # miner1: (0.9 * 1000 + 0.85 * 2000) / 3000 = 0.8666 + # miner2: (0.1 * 1000 + 0.82 * 2000) / 3000 = 0.58 + assert mock_receive.call_count == 1 + assert mock_receive.call_args.args[1] == 1 # Called for UID 1 + assert pytest.approx(weighted_rewards["miner1"], 0.001) == 0.8666 + assert pytest.approx(weighted_rewards["miner2"], 0.001) == 0.58 + assert "miner3" not in weighted_rewards + + +@pytest.mark.asyncio +async def test_compute_weighted_rewards_self_not_in_metagraph(weight_syncer, mock_metagraph): + """Test that local rewards are returned if the validator's hotkey is not in the metagraph.""" + mock_metagraph.hotkeys = ["other1", "other2", "other3"] + local_rewards = {"miner1": 0.9} + weighted_rewards = await weight_syncer.compute_weighted_rewards(local_rewards) + assert weighted_rewards == local_rewards + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient") +async def test_receive_rewards_success(mock_async_client, weight_syncer, mock_metagraph): + """Test successfully receiving rewards from another validator.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"miner1": 0.9} + mock_async_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response) + + rewards = await weight_syncer.receive_rewards(mock_metagraph, 1) + assert rewards == {"miner1": 0.9} + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient") +async def test_receive_rewards_http_error(mock_async_client, weight_syncer, mock_metagraph): + """Test that an empty dict is returned on HTTP error.""" + mock_async_client.return_value.__aenter__.return_value.post.side_effect = Exception("HTTP Error") + rewards = await weight_syncer.receive_rewards(mock_metagraph, 1) + assert rewards == {} + + +@patch("apex.validator.weight_syncer.verify_weight_signature", new_callable=AsyncMock) +def test_get_rewards_endpoint(mock_verify_signature, weight_syncer): + """Test the FastAPI endpoint for serving rewards.""" + app = FastAPI() + app.include_router(weight_syncer.get_router()) + client = TestClient(app) + + # Case 1: No rewards set yet + response = client.post("/v1/get_rewards") + assert response.status_code == 404 + assert response.json() == {"detail": "Rewards not available or expired"} + + # Case 2: Rewards are set and not expired + weight_syncer.hotkey_rewards = {"miner1": 0.95} + weight_syncer._last_update_time = 
time.time() + response = client.post("/v1/get_rewards") + assert response.status_code == 200 + assert response.json() == {"miner1": 0.95} + + # Case 3: Rewards are expired + weight_syncer._last_update_time = time.time() - WeightSyncer.REWARD_EXPIRATION_SEC - 1 + response = client.post("/v1/get_rewards") + assert response.status_code == 404 diff --git a/uv.lock b/uv.lock index d2db4f48f..f81505ed0 100644 --- a/uv.lock +++ b/uv.lock @@ -141,7 +141,7 @@ wheels = [ [[package]] name = "apex" -version = "3.0.0" +version = "3.0.1" source = { virtual = "." } dependencies = [ { name = "aiohttp" }, @@ -168,6 +168,7 @@ dependencies = [ { name = "rouge" }, { name = "substrate-interface" }, { name = "tavily-python" }, + { name = "types-netaddr" }, ] [package.optional-dependencies] @@ -235,6 +236,7 @@ requires-dist = [ { name = "substrate-interface", specifier = ">=1.7.11" }, { name = "tavily-python", specifier = ">=0.7.10" }, { name = "types-cachetools", marker = "extra == 'dev'", specifier = ">=6.0.0.20250525" }, + { name = "types-netaddr", specifier = ">=1.3.0.20240530" }, { name = "types-pyyaml", marker = "extra == 'dev'", specifier = ">=6.0.12.20250516" }, ] provides-extras = ["dev"] @@ -2865,6 +2867,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/47/8c/4ab0a17ece30fe608270b89cf066387051862899fff9f54ab12511fc7fdd/types_cachetools-6.0.0.20250525-py3-none-any.whl", hash = "sha256:1de8f0fe4bdcb187a48d2026c1e3672830f67943ad2bf3486abe031b632f1252", size = 8938 }, ] +[[package]] +name = "types-netaddr" +version = "1.3.0.20240530" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/10/b4bad18b6918eec74db1087d1376cc094fdef5fa1261aa905b6cd94b9408/types-netaddr-1.3.0.20240530.tar.gz", hash = "sha256:742c2ec1f202b666f544223e2616b34f1f13df80c91e5aeaaa93a72e4d0774ea", size = 10459 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/de/1eaa9f59ea51eb8b64c04f2be5af8c91940072f1c5b0cfa5870a7665a028/types_netaddr-1.3.0.20240530-py3-none-any.whl", hash = "sha256:354998d018e326da4f1d9b005fc91137b7c2c473aaf03c4ef64bf83c6861b440", size = 14335 }, +] + [[package]] name = "types-pyyaml" version = "6.0.12.20250516" diff --git a/validator.py b/validator.py index 8cace967c..b8d0eedd7 100644 --- a/validator.py +++ b/validator.py @@ -14,6 +14,7 @@ from apex.validator.miner_sampler import MinerSampler from apex.validator.miner_scorer import MinerScorer from apex.validator.pipeline import Pipeline +from apex.validator.weight_syncer import WeightSyncer async def read_args() -> argparse.Namespace: @@ -54,6 +55,13 @@ async def main() -> None: miner_sampler = MinerSampler(chain=chain, logger_db=logger_db, **config.miner_sampler.kwargs) logger.debug("Started miner sampler") + weight_syncer = WeightSyncer(**config.weight_syncer.kwargs) + await weight_syncer.start() + logger.debug( + f"Started weight synchronizer, receive enabled: {weight_syncer.receive_enabled}, " + f"send enabled: {weight_syncer.send_enabled}, port: {weight_syncer.port}" + ) + miner_scorer = MinerScorer(chain=chain, **config.miner_scorer.kwargs) asyncio.create_task(miner_scorer.start_loop()) logger.debug(f"Started miner scorer with interval={miner_scorer.interval}") From 59057b1a8fa1624ba0f2768818c703cd3db09a24 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sun, 10 Aug 2025 22:58:32 +0200 Subject: [PATCH 24/47] Fix weight syncer --- apex/common/epistula.py | 20 +++++--- apex/validator/miner_scorer.py | 8 ++- 
apex/validator/weight_syncer.py | 73 ++++++++++++++++----------- config/mainnet.yaml.example | 13 ++++- config/testnet.yaml.example | 9 ++++ pyproject.toml | 4 +- tests/validator/test_weight_syncer.py | 11 ++-- validator.py | 8 +-- 8 files changed, 89 insertions(+), 57 deletions(-) diff --git a/apex/common/epistula.py b/apex/common/epistula.py index 95e7449d4..e5748b3dd 100644 --- a/apex/common/epistula.py +++ b/apex/common/epistula.py @@ -76,7 +76,7 @@ def verify_signature( return None -async def verify_weight_signature(request: Request, chain: AsyncChain) -> None: +async def verify_validator_signature(request: Request, chain: AsyncChain, min_stake: float = 1024) -> None: signed_by = request.headers.get("Epistula-Signed-By") signed_for = request.headers.get("Epistula-Signed-For") if not signed_by or not signed_for: @@ -88,7 +88,16 @@ async def verify_weight_signature(request: Request, chain: AsyncChain) -> None: logger.error("Bad Request, message is not intended for self") raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self") - if signed_by not in VALIDATOR_VERIFIED_HOTKEYS: + is_validator = True + if min_stake > 0: + metagraph = await chain.metagraph() + try: + caller_uid = metagraph.hotkeys.index(signed_by) + except ValueError as exc: + raise HTTPException(status_code=401, detail="Signer is not in metagraph") from exc + is_validator = metagraph.stake[caller_uid] > min_stake + + if signed_by not in VALIDATOR_VERIFIED_HOTKEYS and not is_validator: logger.error(f"Signer not the expected ss58 address: {signed_by}") raise HTTPException(status_code=401, detail="Signer not the expected ss58 address") @@ -96,11 +105,8 @@ async def verify_weight_signature(request: Request, chain: AsyncChain) -> None: body: bytes = await request.body() try: json.loads(body) - except json.JSONDecodeError as e: - raise HTTPException(400, "Invalid JSON body") from e - - # if payload.get("uid") != get_uid_from_hotkey(signed_by): - # raise HTTPException(400, "Invalid uid in body") + except json.JSONDecodeError as exc: + raise HTTPException(400, "Invalid JSON body") from exc err = verify_signature( request.headers.get("Epistula-Request-Signature"), diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py index c4f97a8b9..b24490cbb 100644 --- a/apex/validator/miner_scorer.py +++ b/apex/validator/miner_scorer.py @@ -11,7 +11,7 @@ from loguru import logger from apex.common.async_chain import AsyncChain -from apex.common.constants import VALIDATOR_REFERENCE_LABEL, VALIDATOR_VERIFIED_HOTKEYS +from apex.common.constants import VALIDATOR_REFERENCE_LABEL from apex.validator.weight_syncer import WeightSyncer # Scoring moving average in hours. Set to be: immunity_period - post_reg_threshold. 
@@ -23,16 +23,14 @@ class MinerScorer: def __init__( self, chain: AsyncChain, + weight_syncer: WeightSyncer | None = None, interval: float = SCORE_INTERVAL_DEFAULT, debug: bool = False, - enable_weight_sync: bool = True, ): self.chain = chain self.interval = interval self._debug = debug - self._weight_syncer: WeightSyncer | None = None - if enable_weight_sync: - self._weight_syncer = WeightSyncer(chain=chain, verified_hotkeys=VALIDATOR_VERIFIED_HOTKEYS) + self._weight_syncer = weight_syncer self._debug_rewards_path = Path("debug_rewards.jsonl") self._running = True diff --git a/apex/validator/weight_syncer.py b/apex/validator/weight_syncer.py index 37fc858bb..05b85599c 100644 --- a/apex/validator/weight_syncer.py +++ b/apex/validator/weight_syncer.py @@ -1,4 +1,5 @@ import asyncio +import os import time from typing import cast @@ -14,7 +15,7 @@ from apex.common.async_chain import AsyncChain from apex.common.constants import VALIDATOR_VERIFIED_HOTKEYS -from apex.common.epistula import generate_header, verify_weight_signature +from apex.common.epistula import generate_header, verify_validator_signature class ValidatorInfo(BaseModel): @@ -38,7 +39,7 @@ def __init__( ) -> None: """Validator weight synchronizer.""" self.chain = chain - self._min_alpha_stake = min_alpha_stake + self.min_alpha_stake = min_alpha_stake self.verified_hotkeys = verified_hotkeys or VALIDATOR_VERIFIED_HOTKEYS self.wallet = self.chain.wallet self.current_hotkey = self.wallet.hotkey.ss58_address @@ -48,7 +49,7 @@ def __init__( self.server: uvicorn.Server | None = None self.server_task: asyncio.Task[None] | None = None self.hotkey_rewards: dict[str, float] | None = None - self._last_update_time: float = 0 + self.last_update_time: float = 0 async def start(self) -> None: if not self.send_enabled: @@ -56,16 +57,21 @@ async def start(self) -> None: return try: - # Setup and run the API server in the background. app = FastAPI() app.include_router(self.get_router()) - # Running uvicorn in a background task. - config = uvicorn.Config(app=app, host="0.0.0.0", port=self.port, log_level="info") + config = uvicorn.Config( + app=app, + host="0.0.0.0", + port=self.port, + log_level="info", + workers=1, + reload=False, + loop="asyncio", + ) self.server = uvicorn.Server(config) self.server_task = asyncio.create_task(self.server.serve()) - - logger.info(f"Started weight synchronization API on port {self.port}.") + logger.info(f"Started weight synchronization API on port {self.port}. pid={os.getpid()} self_id={id(self)}") # Announce the axon on the network. 
external_ip = requests.get("https://checkip.amazonaws.com").text.strip() @@ -98,13 +104,18 @@ def get_router(self) -> APIRouter: @router.post("/v1/get_rewards") async def get_rewards_endpoint(request: Request) -> dict[str, float]: - """FastAPI endpoint to get the rewards of this validator.""" - await verify_weight_signature(request=request, chain=self.chain) - - if time.time() - self._last_update_time > self.REWARD_EXPIRATION_SEC or self.hotkey_rewards is None: - raise HTTPException(status_code=404, detail="Rewards not available or expired") + await verify_validator_signature(request=request, chain=self.chain, min_stake=self.min_alpha_stake) + + outdated = time.time() - self.last_update_time + if (outdated := time.time() - self.last_update_time) > self.REWARD_EXPIRATION_SEC: + logger.warning(f"Rewards expired: {outdated:.2f}s - {self.last_update_time}") + raise HTTPException(status_code=503, detail="Rewards expired") + if self.hotkey_rewards is None: + logger.warning("Rewards not available") + raise HTTPException(status_code=503, detail="Rewards not available") if not self.send_enabled: - raise HTTPException(status_code=404, detail="API is disabled") + logger.warning("API is disabled") + raise HTTPException(status_code=405, detail="API is disabled") return self.hotkey_rewards return router @@ -112,7 +123,11 @@ async def get_rewards_endpoint(request: Request) -> dict[str, float]: async def compute_weighted_rewards(self, hotkey_rewards: dict[str, float]) -> dict[str, float]: """Computes weighted rewards by fetching rewards from other validators and averaging them by stake.""" self.hotkey_rewards = hotkey_rewards - self._last_update_time = time.time() + self.last_update_time = time.time() + # logger.debug(f"Updating rewards at: {self.last_update_time}") + import os + + logger.debug(f"Updating rewards at: {self.last_update_time} pid={os.getpid()} self_id={id(self)}") if not self.receive_enabled: logger.warning("Rewards weight averaging is disable, using raw rewards") return hotkey_rewards @@ -125,7 +140,7 @@ async def compute_weighted_rewards(self, hotkey_rewards: dict[str, float]) -> di logger.error(f"Could not find own hotkey {self.current_hotkey} in metagraph, returning raw rewards") return hotkey_rewards - validator_uids: list[int] = [] + validator_rewards_tasks: dict[int, asyncio.Task[dict[str, float]]] = {} for uid in metagraph.uids: if uid == own_uid: continue @@ -135,21 +150,18 @@ async def compute_weighted_rewards(self, hotkey_rewards: dict[str, float]) -> di is_verified = hotkey in self.verified_hotkeys is_validator = metagraph.validator_permit[uid] - if stake >= self._min_alpha_stake or is_verified or is_validator: - validator_uids.append(uid) - - validator_rewards_tasks: dict[int, asyncio.Task[dict[str, float]]] = { - uid: asyncio.create_task(self.receive_rewards(metagraph, uid)) for uid in validator_uids - } + if (stake >= self.min_alpha_stake and is_validator) or is_verified: + validator_rewards_tasks[uid] = asyncio.create_task(self.receive_rewards(metagraph, uid)) results = await asyncio.gather(*validator_rewards_tasks.values(), return_exceptions=True) validator_rewards: dict[int, dict[str, float]] = {} - for uid, result in zip(validator_uids, results, strict=True): - if isinstance(result, BaseException): + for uid, result in zip(validator_rewards_tasks, results, strict=True): + if isinstance(result, BaseException) or not result: logger.warning(f"Cannot receive rewards from uid {uid}: {result}") continue validator_rewards[uid] = result + logger.debug(f"Received rewards from 
validator {uid} with stake {metagraph.stake[uid]}") all_validator_uids = [own_uid] + list(validator_rewards.keys()) total_stake = sum(metagraph.stake[uid] for uid in all_validator_uids) @@ -167,14 +179,13 @@ async def compute_weighted_rewards(self, hotkey_rewards: dict[str, float]) -> di for uid, rewards in validator_rewards.items(): validator_reward = rewards.get(miner_hkey, 0.0) - validator_stake = metagraph.stake[uid].item() - total_weighted_reward += validator_reward * validator_stake + total_weighted_reward += validator_reward * metagraph.stake[uid] weighted_rewards[miner_hkey] = total_weighted_reward / total_stake logger.debug( f"Averaged rewards over {len(all_validator_uids)} validators. " - f"Self stake percentage: {100 * own_stake / total_stake:.2f}" + f"Self stake: {100 * own_stake / total_stake:.2f}%" ) return weighted_rewards @@ -185,16 +196,18 @@ async def receive_rewards(self, metagraph: AsyncMetagraph, uid: int) -> dict[str if (address := VALIDATOR_VERIFIED_HOTKEYS.get(target_hotkey)) is None: axon = metagraph.axons[uid] address = f"{axon.ip}:{axon.port}" + async with httpx.AsyncClient() as client: + body = b"{}" headers = await generate_header( - self.chain.wallet.hotkey, - body=b"", + hotkey=self.chain.wallet.hotkey, + body=body, signed_for=target_hotkey, ) resp = await client.post( f"http://{address}/v1/get_rewards", headers=headers, - content=b"", + content=body, ) resp.raise_for_status() return cast(dict[str, float], resp.json()) diff --git a/config/mainnet.yaml.example b/config/mainnet.yaml.example index 3d8bffaf6..7013b1764 100644 --- a/config/mainnet.yaml.example +++ b/config/mainnet.yaml.example @@ -1,8 +1,8 @@ chain: kwargs: netuid: 1 - coldkey: "validator" - hotkey: "default" + coldkey: "YOUR_COLDKEY" + hotkey: "YOUR_HOTKEY" network: - finney # - ws://LOCAL_SUBTENSOR_FALLBACK_1 @@ -27,3 +27,12 @@ deep_research: research_model: "Qwen/Qwen3-235B-A22B-Instruct-2507" compression_model: "deepseek-ai/DeepSeek-V3-0324" final_model: "deepseek-ai/DeepSeek-V3-0324" + +weight_syncer: + kwargs: + # Change port if needed. + port: 8001 + # When enabled performs weight synchronization across validators in order to improve vTrust. + enable_receive: True + # When enabled, other validators can request your rewards for their weight averaging. + enable_send: True diff --git a/config/testnet.yaml.example b/config/testnet.yaml.example index 9b24697c8..f3c8864a5 100644 --- a/config/testnet.yaml.example +++ b/config/testnet.yaml.example @@ -34,3 +34,12 @@ miner_sampler: # For testing purposes one can specify available pool of uids. # available_uids: [1, 2] # available_addresses: ["http://0.0.0.0:8081", "http://0.0.0.0:8082"] + +weight_syncer: + kwargs: + # Change port if needed. + port: 8001 + # When enabled performs weight synchronization across validators in order to improve vTrust. + enable_receive: True + # When enabled, other validators can request your rewards for their weight averaging. + enable_send: True diff --git a/pyproject.toml b/pyproject.toml index 9e9296c66..832ad83e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,10 +83,8 @@ exclude = [ "^tests/", "^venv/", '^\.venv/', - # TODO: Enable once fixed. 
"scripts/", - "apex/services/", - "apex/validator/", + "drafts/", ] [[tool.mypy.overrides]] diff --git a/tests/validator/test_weight_syncer.py b/tests/validator/test_weight_syncer.py index e5269d7af..26355565b 100644 --- a/tests/validator/test_weight_syncer.py +++ b/tests/validator/test_weight_syncer.py @@ -116,7 +116,7 @@ async def test_receive_rewards_http_error(mock_async_client, weight_syncer, mock assert rewards == {} -@patch("apex.validator.weight_syncer.verify_weight_signature", new_callable=AsyncMock) +@patch("apex.validator.weight_syncer.verify_validator_signature", new_callable=AsyncMock) def test_get_rewards_endpoint(mock_verify_signature, weight_syncer): """Test the FastAPI endpoint for serving rewards.""" app = FastAPI() @@ -125,17 +125,16 @@ def test_get_rewards_endpoint(mock_verify_signature, weight_syncer): # Case 1: No rewards set yet response = client.post("/v1/get_rewards") - assert response.status_code == 404 - assert response.json() == {"detail": "Rewards not available or expired"} + assert response.status_code == 503 # Case 2: Rewards are set and not expired weight_syncer.hotkey_rewards = {"miner1": 0.95} - weight_syncer._last_update_time = time.time() + weight_syncer.last_update_time = time.time() response = client.post("/v1/get_rewards") assert response.status_code == 200 assert response.json() == {"miner1": 0.95} # Case 3: Rewards are expired - weight_syncer._last_update_time = time.time() - WeightSyncer.REWARD_EXPIRATION_SEC - 1 + weight_syncer.last_update_time = time.time() - WeightSyncer.REWARD_EXPIRATION_SEC - 1 response = client.post("/v1/get_rewards") - assert response.status_code == 404 + assert response.status_code == 503 diff --git a/validator.py b/validator.py index b8d0eedd7..42ecc00ed 100644 --- a/validator.py +++ b/validator.py @@ -22,8 +22,8 @@ async def read_args() -> argparse.Namespace: parser.add_argument( "-c", "--config", - # default="config/testnet.yaml", - default="config/mainnet.yaml", + default="config/testnet.yaml", + # default="config/mainnet.yaml", help="Config file path (e.g. 
config/mainnet.yaml).", type=Path, ) @@ -55,14 +55,14 @@ async def main() -> None: miner_sampler = MinerSampler(chain=chain, logger_db=logger_db, **config.miner_sampler.kwargs) logger.debug("Started miner sampler") - weight_syncer = WeightSyncer(**config.weight_syncer.kwargs) + weight_syncer = WeightSyncer(chain=chain, **config.weight_syncer.kwargs) await weight_syncer.start() logger.debug( f"Started weight synchronizer, receive enabled: {weight_syncer.receive_enabled}, " f"send enabled: {weight_syncer.send_enabled}, port: {weight_syncer.port}" ) - miner_scorer = MinerScorer(chain=chain, **config.miner_scorer.kwargs) + miner_scorer = MinerScorer(chain=chain, weight_syncer=weight_syncer, **config.miner_scorer.kwargs) asyncio.create_task(miner_scorer.start_loop()) logger.debug(f"Started miner scorer with interval={miner_scorer.interval}") From d65aeed3691cf43cd63630af0d393b0bc74019a1 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sun, 10 Aug 2025 23:04:00 +0200 Subject: [PATCH 25/47] Clean up the code --- apex/common/epistula.py | 2 -- config/mainnet.yaml.example | 6 +++--- config/testnet.yaml.example | 6 +++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/apex/common/epistula.py b/apex/common/epistula.py index e5748b3dd..8d5aface7 100644 --- a/apex/common/epistula.py +++ b/apex/common/epistula.py @@ -5,8 +5,6 @@ from typing import Annotated, Any from uuid import uuid4 -# from fastapi import HTTPException, Request -# from loguru import logger from fastapi import HTTPException, Request from loguru import logger from substrateinterface import Keypair diff --git a/config/mainnet.yaml.example b/config/mainnet.yaml.example index 7013b1764..bff750080 100644 --- a/config/mainnet.yaml.example +++ b/config/mainnet.yaml.example @@ -30,9 +30,9 @@ deep_research: weight_syncer: kwargs: - # Change port if needed. + # Change the port if necessary. port: 8001 - # When enabled performs weight synchronization across validators in order to improve vTrust. + # When enabled, performs weight synchronization across validators, drastically improves vTrust. enable_receive: True - # When enabled, other validators can request your rewards for their weight averaging. + # When enabled, allows other validators to request your rewards, slightly improves vTrust. enable_send: True diff --git a/config/testnet.yaml.example b/config/testnet.yaml.example index f3c8864a5..adea06149 100644 --- a/config/testnet.yaml.example +++ b/config/testnet.yaml.example @@ -37,9 +37,9 @@ miner_sampler: weight_syncer: kwargs: - # Change port if needed. + # Change the port if necessary. port: 8001 - # When enabled performs weight synchronization across validators in order to improve vTrust. + # When enabled, performs weight synchronization across validators, drastically improves vTrust. enable_receive: True - # When enabled, other validators can request your rewards for their weight averaging. + # When enabled, allows other validators to request your rewards, slightly improves vTrust. 
enable_send: True From 150628c72ce3800087e2e74cbae6fc22d1e4c9ad Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sun, 10 Aug 2025 23:05:35 +0200 Subject: [PATCH 26/47] Set main config by default --- validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/validator.py b/validator.py index 42ecc00ed..177c58ff7 100644 --- a/validator.py +++ b/validator.py @@ -22,8 +22,8 @@ async def read_args() -> argparse.Namespace: parser.add_argument( "-c", "--config", - default="config/testnet.yaml", - # default="config/mainnet.yaml", + # default="config/testnet.yaml", + default="config/mainnet.yaml", help="Config file path (e.g. config/mainnet.yaml).", type=Path, ) From 32b1d63a9717cdba93143d458076ec8a60cd0103 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sun, 10 Aug 2025 23:19:29 +0200 Subject: [PATCH 27/47] Add shutdown for weight syncer --- validator.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/validator.py b/validator.py index 21d2adfd4..8076fec97 100644 --- a/validator.py +++ b/validator.py @@ -10,7 +10,6 @@ from apex.services.deep_research.deep_research_langchain import DeepResearchLangchain from apex.services.llm.llm import LLM from apex.services.websearch.websearch_tavily import WebSearchTavily -from apex.validator.auto_update import autoupdate_loop from apex.validator.logger_db import LoggerDB from apex.validator.miner_sampler import MinerSampler from apex.validator.miner_scorer import MinerScorer @@ -28,12 +27,6 @@ async def read_args() -> argparse.Namespace: help="Config file path (e.g. config/mainnet.yaml).", type=Path, ) - parser.add_argument( - "--no-autoupdate", - action="store_true", - default=False, - help="Disable automatic updates (checks every hour) (default: enabled)", - ) args = parser.parse_args() return args @@ -73,12 +66,6 @@ async def main() -> None: asyncio.create_task(miner_scorer.start_loop()) logger.debug(f"Started miner scorer with interval={miner_scorer.interval}") - # Start autoupdate task if enabled - autoupdate_task = None - if not args.no_autoupdate: - logger.info("Autoupdate enabled - will check for updates every hour") - autoupdate_task = asyncio.create_task(autoupdate_loop()) - llm = LLM(**config.llm.kwargs) logger.debug("Started LLM provider") @@ -101,17 +88,10 @@ async def main() -> None: except BaseException as exc: logger.exception(f"Unknown exception caught, exiting validator: {exc}") finally: - # Cancel autoupdate task if it was started - if autoupdate_task is not None: - autoupdate_task.cancel() - try: - await autoupdate_task - except asyncio.CancelledError: - pass - await chain.shutdown() await logger_db.shutdown() await miner_scorer.shutdown() + await weight_syncer.shutdown() if __name__ == "__main__": From 491c11dbebffa92f8bf5ae45c238afaeed5d40b1 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sun, 10 Aug 2025 23:49:05 +0200 Subject: [PATCH 28/47] Add tests to autoupdater --- apex/validator/auto_update.py | 87 -------------- pyproject.toml | 1 + scripts/autoupdater.py | 94 +++++++++++++++ tests/scripts/test_autoupdater.py | 184 ++++++++++++++++++++++++++++++ uv.lock | 14 +++ 5 files changed, 293 insertions(+), 87 deletions(-) delete mode 100644 apex/validator/auto_update.py create mode 100644 tests/scripts/test_autoupdater.py diff --git a/apex/validator/auto_update.py b/apex/validator/auto_update.py deleted file mode 100644 index 
ab28b551f..000000000 --- a/apex/validator/auto_update.py +++ /dev/null @@ -1,87 +0,0 @@ -import asyncio -import subprocess -import sys -from pathlib import Path -from shlex import split - -from loguru import logger - -ROOT_DIR = Path(__file__).parent.parent - - -def get_version() -> str: - """Extract the version as current git commit hash.""" - result = subprocess.run( - split("git rev-parse HEAD"), - check=True, - capture_output=True, - cwd=ROOT_DIR, - ) - commit = result.stdout.decode().strip() - assert len(commit) == 40, f"Invalid commit hash: {commit}" - return commit[:8] - - -def pull_latest_version() -> None: - """Pull the latest version from git. - - This uses `git pull --rebase`, so if any changes were made to the local repository, - this will try to apply them on top of origin's changes. This is intentional, as we - don't want to overwrite any local changes. However, if there are any conflicts, - this will abort the rebase and return to the original state. - The conflicts are expected to happen rarely since validator is expected - to be used as-is. - """ - try: - subprocess.run(split("git pull --rebase --autostash"), check=True, cwd=ROOT_DIR) - except subprocess.CalledProcessError as exc: - logger.error("Failed to pull, reverting: %s", exc) - subprocess.run(split("git rebase --abort"), check=True, cwd=ROOT_DIR) - - -def upgrade_packages() -> None: - """Upgrade python packages by running `pip install --upgrade -r requirements.txt`. - - Notice: this won't work if some package in `requirements.txt` is downgraded. - Ignored as this is unlikely to happen. - """ - logger.info("Upgrading packages") - try: - subprocess.run( - split(f"{sys.executable} -m pip install -e ."), - check=True, - cwd=ROOT_DIR, - ) - except subprocess.CalledProcessError as exc: - logger.error("Failed to upgrade packages, proceeding anyway. %s", exc) - - -async def autoupdate_loop() -> None: - """Async version of autoupdate that runs alongside the validator. - - Checks for updates every hour and applies them if available. - """ - current_version = latest_version = get_version() - logger.info("Current version: %s", current_version) - - try: - while True: - await asyncio.sleep(3600) # Wait 1 hour between checks - - pull_latest_version() - latest_version = get_version() - logger.info("Latest version: %s", latest_version) - - if latest_version != current_version: - logger.info( - "Upgraded to latest version: %s -> %s, terminating program for restart", - current_version, - latest_version, - ) - upgrade_packages() - logger.info("Program terminating, please run with persistent process manager") - sys.exit(0) - - except asyncio.CancelledError: - logger.info("Autoupdate task cancelled") - raise diff --git a/pyproject.toml b/pyproject.toml index 832ad83e3..c79d27cb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -203,5 +203,6 @@ dev = [ "pydantic>=2.11.7", "pytest>=8.4.1", "pytest-asyncio>=1.0.0", + "pytest-mock>=3.14.1", "types-pyyaml>=6.0.12.20250516", ] diff --git a/scripts/autoupdater.py b/scripts/autoupdater.py index e69de29bb..dbd467c89 100644 --- a/scripts/autoupdater.py +++ b/scripts/autoupdater.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +import os +import signal +import subprocess +import sys +import time + +CHECK_INTERVAL = 15 * 60 + + +def venv_python() -> str: + return os.path.join(".venv", "bin", "python") + + +def read_python_version() -> str | None: + try: + with open(".python-version", encoding="utf-8") as f: + # Take first non-empty token (pyenv format e.g. "3.11.9"). 
+        return f.read().strip().split()[0]
+    except FileNotFoundError:
+        return None
+
+
+def start_proc() -> subprocess.Popen:
+    py_ver = read_python_version()
+    if py_ver:
+        subprocess.run(["uv", "venv", "--python", py_ver], check=True)
+    else:
+        subprocess.run(["uv", "venv"], check=True)
+
+    # Install the project with dev extras into the venv.
+    subprocess.run(["uv", "pip", "install", ".[dev]"], check=True)
+
+    # Run validator.
+    return subprocess.Popen([venv_python(), "validator.py"])
+
+
+def stop_proc(process: subprocess.Popen) -> None:
+    if process and process.poll() is None:
+        process.terminate()
+        try:
+            process.wait(timeout=10)
+        except subprocess.TimeoutExpired:
+            process.kill()
+
+
+def remote_has_updates() -> bool:
+    try:
+        subprocess.run(["git", "fetch", "--quiet"], check=True)
+        out = subprocess.check_output(
+            ["git", "rev-list", "--left-right", "--count", "@{u}...HEAD"], stderr=subprocess.STDOUT, text=True
+        ).strip()
+        left, right = map(int, out.split())
+        # Remote is ahead.
+        return left > 0
+    except subprocess.CalledProcessError:
+        # No upstream or git issue; treat as no updates.
+        return False
+
+
+def git_pull_ff_only() -> None:
+    try:
+        subprocess.run(["git", "pull", "--ff-only"], check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"Error: Git pull failed due to conflicts or other issues: {e}", file=sys.stderr)
+        print("Staying on the current version.", file=sys.stderr)
+
+
+def main() -> None:
+    proc = start_proc()
+
+    def handle_sigint(sig, frame):
+        stop_proc(proc)
+        sys.exit(0)
+
+    signal.signal(signal.SIGINT, handle_sigint)
+
+    while True:
+        time.sleep(CHECK_INTERVAL)
+        print("Checking for updates...")
+
+        # If child exited, propagate its code.
+        if proc.poll() is not None:
+            sys.exit(proc.returncode)
+
+        if remote_has_updates():
+            print("Updates detected, restarting process")
+            stop_proc(proc)
+            git_pull_ff_only()
+            proc = start_proc()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/scripts/test_autoupdater.py b/tests/scripts/test_autoupdater.py
new file mode 100644
index 000000000..5f619fc6c
--- /dev/null
+++ b/tests/scripts/test_autoupdater.py
@@ -0,0 +1,184 @@
+import os
+import subprocess
+import sys
+from unittest import mock
+
+import pytest
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../scripts")))
+import autoupdater  # isort: skip
+
+
+def test_venv_python(mocker):
+    mocker.patch("os.path.join", return_value=".venv/bin/python")
+    assert autoupdater.venv_python() == ".venv/bin/python"
+    autoupdater.os.path.join.assert_called_once_with(".venv", "bin", "python")
+
+
+def test_read_python_version_exists(mocker):
+    mocker.patch("autoupdater.open", mocker.mock_open(read_data="3.11.9"))
+    assert autoupdater.read_python_version() == "3.11.9"
+    autoupdater.open.assert_called_once_with(".python-version", encoding="utf-8")
+
+
+def test_read_python_version_not_found(mocker):
+    mocker.patch("autoupdater.open", side_effect=FileNotFoundError)
+    assert autoupdater.read_python_version() is None
+    autoupdater.open.assert_called_once_with(".python-version", encoding="utf-8")
+
+
+def test_start_proc_with_version(mocker):
+    mocker.patch("autoupdater.read_python_version", return_value="3.11.9")
+    mock_run = mocker.patch("subprocess.run")
+    mock_popen = mocker.patch("subprocess.Popen", return_value=mocker.Mock())
+    mocker.patch("autoupdater.venv_python", return_value="mock_python")
+
+    proc = autoupdater.start_proc()
+    autoupdater.read_python_version.assert_called_once()
+    mock_run.assert_has_calls(
+        [
+            mock.call(["uv",
"venv", "--python", "3.11.9"], check=True), + mock.call(["uv", "pip", "install", ".[dev]"], check=True), + ] + ) + mock_popen.assert_called_once_with(["mock_python", "validator.py"]) + assert proc is not None + + +def test_start_proc_without_version(mocker): + mocker.patch("autoupdater.read_python_version", return_value=None) + mock_run = mocker.patch("subprocess.run") + mock_popen = mocker.patch("subprocess.Popen", return_value=mocker.Mock()) + mocker.patch("autoupdater.venv_python", return_value="mock_python") + + proc = autoupdater.start_proc() + autoupdater.read_python_version.assert_called_once() + mock_run.assert_has_calls( + [ + mock.call(["uv", "venv"], check=True), + mock.call(["uv", "pip", "install", ".[dev]"], check=True), + ] + ) + mock_popen.assert_called_once_with(["mock_python", "validator.py"]) + assert proc is not None + + +def test_stop_proc_running(): + mock_proc = mock.Mock() + mock_proc.poll.return_value = None # Process is running + autoupdater.stop_proc(mock_proc) + mock_proc.terminate.assert_called_once() + mock_proc.wait.assert_called_once_with(timeout=10) + mock_proc.kill.assert_not_called() + + +def test_stop_proc_timeout(): + mock_proc = mock.Mock() + mock_proc.poll.return_value = None # Process is running + mock_proc.wait.side_effect = subprocess.TimeoutExpired(cmd="test", timeout=10) + autoupdater.stop_proc(mock_proc) + mock_proc.terminate.assert_called_once() + mock_proc.wait.assert_called_once_with(timeout=10) + mock_proc.kill.assert_called_once() + + +def test_stop_proc_already_stopped(): + mock_proc = mock.Mock() + mock_proc.poll.return_value = 0 # Process already stopped + autoupdater.stop_proc(mock_proc) + mock_proc.terminate.assert_not_called() + mock_proc.wait.assert_not_called() + mock_proc.kill.assert_not_called() + + +def test_remote_has_updates_true(mocker): + mock_run = mocker.patch("subprocess.run") + mock_check_output = mocker.patch("subprocess.check_output", return_value="1\t0") + assert autoupdater.remote_has_updates() is True + mock_run.assert_called_once_with(["git", "fetch", "--quiet"], check=True) + mock_check_output.assert_called_once_with( + ["git", "rev-list", "--left-right", "--count", "@{u}...HEAD"], stderr=subprocess.STDOUT, text=True + ) + + +def test_remote_has_updates_false_no_diff(mocker): + mocker.patch("subprocess.run") + mocker.patch("subprocess.check_output", return_value="0\t0") + assert autoupdater.remote_has_updates() is False + + +def test_remote_has_updates_error(mocker): + mock_run = mocker.patch("subprocess.run") + mocker.patch("subprocess.check_output", side_effect=subprocess.CalledProcessError(1, "cmd")) + assert autoupdater.remote_has_updates() is False + mock_run.assert_called_once_with(["git", "fetch", "--quiet"], check=True) + autoupdater.subprocess.check_output.assert_called_once() + + +def test_git_pull_ff_only_success(mocker): + mock_run = mocker.patch("subprocess.run") + autoupdater.git_pull_ff_only() + mock_run.assert_called_once_with(["git", "pull", "--ff-only"], check=True) + + +def test_git_pull_ff_only_conflict(mocker): + mocker.patch("subprocess.run", side_effect=subprocess.CalledProcessError(1, "cmd", stderr="conflict")) + mock_stderr = mocker.patch("sys.stderr", new_callable=mock.MagicMock) + autoupdater.git_pull_ff_only() + autoupdater.subprocess.run.assert_called_once_with(["git", "pull", "--ff-only"], check=True) + # Check for individual calls to write, as print adds newlines separately + mock_stderr.write.assert_any_call( + "Error: Git pull failed due to conflicts or other issues: Command 'cmd' 
returned non-zero exit status 1." + ) + mock_stderr.write.assert_any_call("\n") + mock_stderr.write.assert_any_call("Staying on the current version.") + mock_stderr.write.assert_any_call("\n") + + +def test_main_loop_with_update(mocker): + mock_start_proc = mocker.patch("autoupdater.start_proc", return_value=mocker.Mock()) + mock_start_proc.return_value.returncode = 0 # Ensure returncode is 0 for sys.exit(0) + mock_stop_proc = mocker.patch("autoupdater.stop_proc") + mock_remote_updates = mocker.patch("autoupdater.remote_has_updates", side_effect=[False, True, False]) + mock_git_pull = mocker.patch("autoupdater.git_pull_ff_only") + mock_sleep = mocker.patch( + "time.sleep", side_effect=[None, None, Exception("StopLoop")] + ) # Allow 2 calls to remote_has_updates + mock_sys_exit = mocker.patch("sys.exit") + + mock_start_proc.return_value.poll.side_effect = [None, 0] # proc.poll() returns 0 on second iteration + + with pytest.raises(Exception) as cm: + autoupdater.main() + assert str(cm.value) == "StopLoop" + + assert mock_sleep.call_count == 3 + mock_remote_updates.assert_has_calls([mock.call(), mock.call()]) + mock_stop_proc.assert_called_once_with(mock_start_proc.return_value) + mock_git_pull.assert_called_once() + mock_start_proc.assert_called_with() # Called initially and after update + mock_sys_exit.assert_called_once_with(0) + + +def test_main_loop_no_update(mocker): + mock_start_proc = mocker.patch("autoupdater.start_proc", return_value=mocker.Mock()) + mock_start_proc.return_value.returncode = 0 # Ensure returncode is 0 for sys.exit(0) + mock_stop_proc = mocker.patch("autoupdater.stop_proc") + mock_remote_updates = mocker.patch("autoupdater.remote_has_updates", return_value=False) + mock_git_pull = mocker.patch("autoupdater.git_pull_ff_only") + mock_sleep = mocker.patch( + "time.sleep", side_effect=[None, None, Exception("StopLoop")] + ) # Allow 2 calls to remote_has_updates + mock_sys_exit = mocker.patch("sys.exit") + + mock_start_proc.return_value.poll.side_effect = [None, 0] # proc.poll() returns 0 on second iteration + + with pytest.raises(Exception) as cm: + autoupdater.main() + assert str(cm.value) == "StopLoop" + + assert mock_sleep.call_count == 3 + assert mock_remote_updates.call_count == 2 + mock_stop_proc.assert_not_called() + mock_git_pull.assert_not_called() + mock_sys_exit.assert_called_once_with(0) diff --git a/uv.lock b/uv.lock index f81505ed0..5b5774d5b 100644 --- a/uv.lock +++ b/uv.lock @@ -196,6 +196,7 @@ dev = [ { name = "pydantic" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-mock" }, { name = "types-pyyaml" }, ] @@ -249,6 +250,7 @@ dev = [ { name = "pydantic", specifier = ">=2.11.7" }, { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-asyncio", specifier = ">=1.0.0" }, + { name = "pytest-mock", specifier = ">=3.14.1" }, { name = "types-pyyaml", specifier = ">=6.0.12.20250516" }, ] @@ -2414,6 +2416,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644 }, ] +[[package]] +name = "pytest-mock" +version = "3.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/28/67172c96ba684058a4d24ffe144d64783d2a270d0af0d9e792737bddc75c/pytest_mock-3.14.1.tar.gz", hash = 
"sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e", size = 33241 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/05/77b60e520511c53d1c1ca75f1930c7dd8e971d0c4379b7f4b3f9644685ba/pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0", size = 9923 }, +] + [[package]] name = "python-dotenv" version = "1.1.1" From 606b91245a06be700ae172f6eb3f4fb4aec568d1 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Sun, 10 Aug 2025 23:53:44 +0200 Subject: [PATCH 29/47] Clean up packages --- pyproject.toml | 13 +++++-------- uv.lock | 20 ++++++-------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c79d27cb0..898b171df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,10 +16,11 @@ dependencies = [ "aiohttp>=3.8.3", "beautifulsoup4>=4.13.3", "langchain>=0.3.26", + "langchain-core>=0.3.68", "langchain-community>=0.0.59", - "faiss-cpu>=1.8.0", "langchain-openai>=0.3.28", "langchain-sandbox>=0.0.6", + "faiss-cpu>=1.8.0", "dotenv>=0.9.9", "rich>=14.0.0", "loguru>=0.7.3", @@ -29,6 +30,9 @@ dependencies = [ "rouge>=1.0.1", "substrate-interface>=1.7.11", "types-netaddr>=1.3.0.20240530", + "types-pyyaml>=6.0.12.20250516", + "types-cachetools>=6.0.0.20250525", + "dotenv>=0.9.9", ] @@ -36,13 +40,6 @@ dependencies = [ dev = [ "mypy==1.17.0", "ruff==0.12.5", - "types-pyyaml>=6.0.12.20250516", - "types-cachetools>=6.0.0.20250525", - "langchain>=0.3.26", - "dotenv>=0.9.9", - "langchain-openai>=0.3.28", - "langchain-core>=0.3.68", - "langchain-sandbox>=0.0.6", "pytest>=8.4.1", "pytest-asyncio>=1.0.0", "pytest-cov>=5.0.0", diff --git a/uv.lock b/uv.lock index 5b5774d5b..eccdbe19c 100644 --- a/uv.lock +++ b/uv.lock @@ -153,6 +153,7 @@ dependencies = [ { name = "faiss-cpu" }, { name = "langchain" }, { name = "langchain-community" }, + { name = "langchain-core" }, { name = "langchain-openai" }, { name = "langchain-sandbox" }, { name = "loguru" }, @@ -168,24 +169,19 @@ dependencies = [ { name = "rouge" }, { name = "substrate-interface" }, { name = "tavily-python" }, + { name = "types-cachetools" }, { name = "types-netaddr" }, + { name = "types-pyyaml" }, ] [package.optional-dependencies] dev = [ - { name = "dotenv" }, - { name = "langchain" }, - { name = "langchain-core" }, - { name = "langchain-openai" }, - { name = "langchain-sandbox" }, { name = "mypy" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "ruff" }, - { name = "types-cachetools" }, - { name = "types-pyyaml" }, ] [package.dev-dependencies] @@ -208,16 +204,12 @@ requires-dist = [ { name = "bittensor", specifier = ">=9.7.0" }, { name = "cachetools", specifier = ">=5.0.0" }, { name = "dotenv", specifier = ">=0.9.9" }, - { name = "dotenv", marker = "extra == 'dev'", specifier = ">=0.9.9" }, { name = "faiss-cpu", specifier = ">=1.8.0" }, { name = "langchain", specifier = ">=0.3.26" }, - { name = "langchain", marker = "extra == 'dev'", specifier = ">=0.3.26" }, { name = "langchain-community", specifier = ">=0.0.59" }, - { name = "langchain-core", marker = "extra == 'dev'", specifier = ">=0.3.68" }, + { name = "langchain-core", specifier = ">=0.3.68" }, { name = "langchain-openai", specifier = ">=0.3.28" }, - { name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.3.28" }, { name = "langchain-sandbox", specifier = ">=0.0.6" }, - { name = "langchain-sandbox", marker = "extra 
== 'dev'", specifier = ">=0.0.6" }, { name = "loguru", specifier = ">=0.7.3" }, { name = "macrocosmos", specifier = ">=0.1.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = "==1.17.0" }, @@ -236,9 +228,9 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = "==0.12.5" }, { name = "substrate-interface", specifier = ">=1.7.11" }, { name = "tavily-python", specifier = ">=0.7.10" }, - { name = "types-cachetools", marker = "extra == 'dev'", specifier = ">=6.0.0.20250525" }, + { name = "types-cachetools", specifier = ">=6.0.0.20250525" }, { name = "types-netaddr", specifier = ">=1.3.0.20240530" }, - { name = "types-pyyaml", marker = "extra == 'dev'", specifier = ">=6.0.12.20250516" }, + { name = "types-pyyaml", specifier = ">=6.0.12.20250516" }, ] provides-extras = ["dev"] From dddeadaf1073a4de7c236d844296ba6074380e36 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 00:02:22 +0200 Subject: [PATCH 30/47] Testing autoupdater --- testing_updater.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 testing_updater.txt diff --git a/testing_updater.txt b/testing_updater.txt new file mode 100644 index 000000000..e69de29bb From 8aff43487128d9f01eb570b85aa2c256e81871c4 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 00:25:37 +0200 Subject: [PATCH 31/47] Testing autoupdater 2 --- testing_updater.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/testing_updater.txt b/testing_updater.txt index e69de29bb..9daeafb98 100644 --- a/testing_updater.txt +++ b/testing_updater.txt @@ -0,0 +1 @@ +test From b38f591905e582f27a268a335604cf24d1771aa4 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 00:26:37 +0200 Subject: [PATCH 32/47] Cleaning up files --- testing_updater.txt | 1 - 1 file changed, 1 deletion(-) delete mode 100644 testing_updater.txt diff --git a/testing_updater.txt b/testing_updater.txt deleted file mode 100644 index 9daeafb98..000000000 --- a/testing_updater.txt +++ /dev/null @@ -1 +0,0 @@ -test From 10072e47ff32a5c484136fd7575ab7777a8e8d1f Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 00:37:17 +0200 Subject: [PATCH 33/47] Add pm2 script --- scripts/autoupdater_pm2.sh | 73 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 scripts/autoupdater_pm2.sh diff --git a/scripts/autoupdater_pm2.sh b/scripts/autoupdater_pm2.sh new file mode 100644 index 000000000..758cd9693 --- /dev/null +++ b/scripts/autoupdater_pm2.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash +set -euo pipefail + +APP_NAME="sn1" +PY_VERSION_FILE=".python-version" +UV_INSTALL_URL="https://astral.sh/uv/install.sh" + +# Ensure common user bin dirs are in PATH. +export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$HOME/bin:$PATH" + +# 1. Ensure uv exists. +if ! command -v uv >/dev/null 2>&1; then + echo "[info] uv not found; installing..." + if ! command -v curl >/dev/null 2>&1; then + echo "[error] curl is required to install uv." >&2 + exit 1 + fi + curl -LsSf "$UV_INSTALL_URL" | sh + export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$HOME/bin:$PATH" + hash -r + if ! command -v uv >/dev/null 2>&1; then + echo "[error] uv installation completed but 'uv' not found in PATH." >&2 + exit 1 + fi +else + echo "[info] uv found: $(command -v uv)" +fi + +# 2. Ensure pm2 exists. +if ! 
command -v pm2 >/dev/null 2>&1; then + echo "[info] pm2 not found; installing globally with npm..." + if ! command -v npm >/dev/null 2>&1; then + echo "[error] npm is required to install pm2. Please install Node.js first." >&2 + exit 1 + fi + npm install -g pm2 + export PATH="$HOME/.npm-global/bin:$PATH" + hash -r + if ! command -v pm2 >/dev/null 2>&1; then + echo "[error] pm2 installation completed but 'pm2' not found in PATH." >&2 + exit 1 + fi +else + echo "[info] pm2 found: $(command -v pm2)" +fi + +# 3. Determine Python version. +PY_VERSION_FILE=".python-version" +PY_VER="3.11" +if [[ -s "$PY_VERSION_FILE" ]]; then + CANDIDATE="$(tr -d ' \t\r\n' < "$PY_VERSION_FILE")" + if [[ "$CANDIDATE" =~ ^[0-9]+(\.[0-9]+){0,2}$ ]]; then + PY_VER="$CANDIDATE" + +echo "[info] Using Python version: $PY_VER" +# 4. Create/update venv and install deps. +echo "[info] Creating/refreshing .venv with uv…" +uv venv --python="$PY_VER" +if [[ ! -x ".venv/bin/python" ]]; then + echo "[error] .venv/bin/python not created. Check that uv can resolve Python $PY_VER." >&2 + exit 1 +fi + +echo "[info] Installing project dependencies…" +uv pip install '.[dev]' + +# 5. Start with pm2. +pm2 delete "$APP_NAME" >/dev/null 2>&1 || true +pm2 start ".venv/bin/python" --name "$APP_NAME" -- \ + scripts/autoupdater.py -c configs/mainnet.yaml + +echo "[done] pm2 process '$APP_NAME' started." +pm2 status "$APP_NAME" From 014064b21a2cde9e68c8c65e987eff72dd3d4d49 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 00:41:26 +0200 Subject: [PATCH 34/47] Fix pm2 --- scripts/autoupdater_pm2.sh | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/scripts/autoupdater_pm2.sh b/scripts/autoupdater_pm2.sh index 758cd9693..3290b0856 100644 --- a/scripts/autoupdater_pm2.sh +++ b/scripts/autoupdater_pm2.sh @@ -4,11 +4,12 @@ set -euo pipefail APP_NAME="sn1" PY_VERSION_FILE=".python-version" UV_INSTALL_URL="https://astral.sh/uv/install.sh" +CONFIG="configs/mainnet.yaml" # Ensure common user bin dirs are in PATH. -export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$HOME/bin:$PATH" +export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$HOME/bin:$HOME/.npm-global/bin:$PATH" -# 1. Ensure uv exists. +# 1) Ensure uv exists if ! command -v uv >/dev/null 2>&1; then echo "[info] uv not found; installing..." if ! command -v curl >/dev/null 2>&1; then @@ -16,7 +17,6 @@ if ! command -v uv >/dev/null 2>&1; then exit 1 fi curl -LsSf "$UV_INSTALL_URL" | sh - export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$HOME/bin:$PATH" hash -r if ! command -v uv >/dev/null 2>&1; then echo "[error] uv installation completed but 'uv' not found in PATH." >&2 @@ -26,7 +26,7 @@ else echo "[info] uv found: $(command -v uv)" fi -# 2. Ensure pm2 exists. +# 2) Ensure pm2 exists if ! command -v pm2 >/dev/null 2>&1; then echo "[info] pm2 not found; installing globally with npm..." if ! command -v npm >/dev/null 2>&1; then @@ -34,7 +34,6 @@ if ! command -v pm2 >/dev/null 2>&1; then exit 1 fi npm install -g pm2 - export PATH="$HOME/.npm-global/bin:$PATH" hash -r if ! command -v pm2 >/dev/null 2>&1; then echo "[error] pm2 installation completed but 'pm2' not found in PATH." >&2 @@ -44,16 +43,21 @@ else echo "[info] pm2 found: $(command -v pm2)" fi -# 3. Determine Python version. 
-PY_VERSION_FILE=".python-version" +# 3) Determine Python version (read from .python-version; fallback 3.11) PY_VER="3.11" if [[ -s "$PY_VERSION_FILE" ]]; then CANDIDATE="$(tr -d ' \t\r\n' < "$PY_VERSION_FILE")" if [[ "$CANDIDATE" =~ ^[0-9]+(\.[0-9]+){0,2}$ ]]; then PY_VER="$CANDIDATE" - + else + echo "[warn] Invalid version in $PY_VERSION_FILE: '$CANDIDATE' — using $PY_VER" + fi +else + echo "[warn] $PY_VERSION_FILE missing or empty — using $PY_VER" +fi echo "[info] Using Python version: $PY_VER" -# 4. Create/update venv and install deps. + +# 4) Create/update venv and install deps echo "[info] Creating/refreshing .venv with uv…" uv venv --python="$PY_VER" if [[ ! -x ".venv/bin/python" ]]; then @@ -64,10 +68,10 @@ fi echo "[info] Installing project dependencies…" uv pip install '.[dev]' -# 5. Start with pm2. +# 5) Start with pm2 pm2 delete "$APP_NAME" >/dev/null 2>&1 || true pm2 start ".venv/bin/python" --name "$APP_NAME" -- \ - scripts/autoupdater.py -c configs/mainnet.yaml + scripts/autoupdater.py -c "$CONFIG" echo "[done] pm2 process '$APP_NAME' started." pm2 status "$APP_NAME" From 893d2ffd32f89362a63c9f50c33a8f66054d3773 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 01:10:10 +0200 Subject: [PATCH 35/47] Fix tests --- scripts/autoupdater.py | 25 +++++++++-- scripts/autoupdater_pm2.sh | 7 ++- tests/scripts/test_autoupdater.py | 73 +++++++++++++++++++------------ 3 files changed, 70 insertions(+), 35 deletions(-) diff --git a/scripts/autoupdater.py b/scripts/autoupdater.py index dbd467c89..101753e20 100644 --- a/scripts/autoupdater.py +++ b/scripts/autoupdater.py @@ -1,9 +1,11 @@ #!/usr/bin/env python3 +import argparse import os import signal import subprocess import sys import time +from pathlib import Path CHECK_INTERVAL = 15 * 60 @@ -21,7 +23,7 @@ def read_python_version() -> str | None: return None -def start_proc() -> subprocess.Popen: +def start_proc(config: Path) -> subprocess.Popen: py_ver = read_python_version() if py_ver: subprocess.run(["uv", "venv", "--python", py_ver], check=True) @@ -32,7 +34,7 @@ def start_proc() -> subprocess.Popen: subprocess.run(["uv", "pip", "install", ".[dev]"], check=True) # Run validator. - return subprocess.Popen([venv_python(), "validator.py"]) + return subprocess.Popen([venv_python(), "validator.py", "-c", str(config)]) def stop_proc(process: subprocess.Popen) -> None: @@ -66,8 +68,23 @@ def git_pull_ff_only() -> None: print("Staying on the current version.", file=sys.stderr) +def read_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Apex validator") + parser.add_argument( + "-c", + "--config", + # default="config/testnet.yaml", + default="config/mainnet.yaml", + help="Config file path (e.g. 
config/mainnet.yaml).", + type=Path, + ) + args = parser.parse_args() + return args + + def main() -> None: - proc = start_proc() + args = read_args() + proc = start_proc(config=args.config) def handle_sigint(sig, frame): stop_proc(proc) @@ -87,7 +104,7 @@ def handle_sigint(sig, frame): print("Updates detected, restaring process") stop_proc(proc) git_pull_ff_only() - proc = start_proc() + proc = start_proc(config=args.config) if __name__ == "__main__": diff --git a/scripts/autoupdater_pm2.sh b/scripts/autoupdater_pm2.sh index 3290b0856..209186740 100644 --- a/scripts/autoupdater_pm2.sh +++ b/scripts/autoupdater_pm2.sh @@ -4,7 +4,7 @@ set -euo pipefail APP_NAME="sn1" PY_VERSION_FILE=".python-version" UV_INSTALL_URL="https://astral.sh/uv/install.sh" -CONFIG="configs/mainnet.yaml" +CONFIG="config/mainnet.yaml" # Ensure common user bin dirs are in PATH. export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$HOME/bin:$HOME/.npm-global/bin:$PATH" @@ -70,8 +70,7 @@ uv pip install '.[dev]' # 5) Start with pm2 pm2 delete "$APP_NAME" >/dev/null 2>&1 || true -pm2 start ".venv/bin/python" --name "$APP_NAME" -- \ - scripts/autoupdater.py -c "$CONFIG" +pm2 start ".venv/bin/python" --name "$APP_NAME" -- scripts/autoupdater.py -c "$CONFIG" echo "[done] pm2 process '$APP_NAME' started." -pm2 status "$APP_NAME" +pm2 logs "$APP_NAME" diff --git a/tests/scripts/test_autoupdater.py b/tests/scripts/test_autoupdater.py index 5f619fc6c..84412aa7d 100644 --- a/tests/scripts/test_autoupdater.py +++ b/tests/scripts/test_autoupdater.py @@ -10,7 +10,7 @@ def test_venv_python(mocker): - mocker.patch("os.path.join", return_value=".venv/bin/python") + mocker.patch("autoupdater.os.path.join", return_value=".venv/bin/python") assert autoupdater.venv_python() == ".venv/bin/python" autoupdater.os.path.join.assert_called_once_with(".venv", "bin", "python") @@ -33,7 +33,7 @@ def test_start_proc_with_version(mocker): mock_popen = mocker.patch("subprocess.Popen", return_value=mocker.Mock()) mocker.patch("autoupdater.venv_python", return_value="mock_python") - proc = autoupdater.start_proc() + proc = autoupdater.start_proc(config="mock.yaml") autoupdater.read_python_version.assert_called_once() mock_run.assert_has_calls( [ @@ -41,7 +41,7 @@ def test_start_proc_with_version(mocker): mock.call(["uv", "pip", "install", ".[dev]"], check=True), ] ) - mock_popen.assert_called_once_with(["mock_python", "validator.py"]) + mock_popen.assert_called_once_with(["mock_python", "validator.py", "-c", "mock.yaml"]) assert proc is not None @@ -51,7 +51,7 @@ def test_start_proc_without_version(mocker): mock_popen = mocker.patch("subprocess.Popen", return_value=mocker.Mock()) mocker.patch("autoupdater.venv_python", return_value="mock_python") - proc = autoupdater.start_proc() + proc = autoupdater.start_proc(config="mock.yaml") autoupdater.read_python_version.assert_called_once() mock_run.assert_has_calls( [ @@ -59,7 +59,7 @@ def test_start_proc_without_version(mocker): mock.call(["uv", "pip", "install", ".[dev]"], check=True), ] ) - mock_popen.assert_called_once_with(["mock_python", "validator.py"]) + mock_popen.assert_called_once_with(["mock_python", "validator.py", "-c", "mock.yaml"]) assert proc is not None @@ -136,49 +136,68 @@ def test_git_pull_ff_only_conflict(mocker): def test_main_loop_with_update(mocker): + # start_proc returns the same Mock proc; we drive its poll() via side_effect mock_start_proc = mocker.patch("autoupdater.start_proc", return_value=mocker.Mock()) - mock_start_proc.return_value.returncode = 0 # Ensure returncode is 0 for 
sys.exit(0) + mock_start_proc.return_value.returncode = 0 mock_stop_proc = mocker.patch("autoupdater.stop_proc") - mock_remote_updates = mocker.patch("autoupdater.remote_has_updates", side_effect=[False, True, False]) + mock_remote_updates = mocker.patch("autoupdater.remote_has_updates", side_effect=[False, True]) mock_git_pull = mocker.patch("autoupdater.git_pull_ff_only") - mock_sleep = mocker.patch( - "time.sleep", side_effect=[None, None, Exception("StopLoop")] - ) # Allow 2 calls to remote_has_updates - mock_sys_exit = mocker.patch("sys.exit") + mock_sleep = mocker.patch("time.sleep", return_value=None) - mock_start_proc.return_value.poll.side_effect = [None, 0] # proc.poll() returns 0 on second iteration + def _raise(code=0): + raise SystemExit(code) - with pytest.raises(Exception) as cm: + mock_sys_exit = mocker.patch("sys.exit", side_effect=_raise) + + mock_args = mocker.Mock() + mock_args.config = "mock.yaml" + mocker.patch("autoupdater.read_args", return_value=mock_args) + + # First loop: running (None), no update + # Second loop: still running (None), update happens -> restart + # Third loop: process seen as exited (0) -> sys.exit(0) + mock_start_proc.return_value.poll.side_effect = [None, None, 0] + + with pytest.raises(SystemExit) as pytest_wrapped_e: autoupdater.main() - assert str(cm.value) == "StopLoop" + assert pytest_wrapped_e.type is SystemExit + assert pytest_wrapped_e.value.code == 0 - assert mock_sleep.call_count == 3 - mock_remote_updates.assert_has_calls([mock.call(), mock.call()]) + assert mock_sleep.call_count == 3 # before each iteration, including after restart + assert mock_remote_updates.call_count == 2 mock_stop_proc.assert_called_once_with(mock_start_proc.return_value) mock_git_pull.assert_called_once() - mock_start_proc.assert_called_with() # Called initially and after update + mock_start_proc.assert_called_with(config=mock.ANY) mock_sys_exit.assert_called_once_with(0) def test_main_loop_no_update(mocker): mock_start_proc = mocker.patch("autoupdater.start_proc", return_value=mocker.Mock()) - mock_start_proc.return_value.returncode = 0 # Ensure returncode is 0 for sys.exit(0) + mock_start_proc.return_value.returncode = 0 mock_stop_proc = mocker.patch("autoupdater.stop_proc") mock_remote_updates = mocker.patch("autoupdater.remote_has_updates", return_value=False) mock_git_pull = mocker.patch("autoupdater.git_pull_ff_only") - mock_sleep = mocker.patch( - "time.sleep", side_effect=[None, None, Exception("StopLoop")] - ) # Allow 2 calls to remote_has_updates - mock_sys_exit = mocker.patch("sys.exit") + mock_sleep = mocker.patch("time.sleep", return_value=None) - mock_start_proc.return_value.poll.side_effect = [None, 0] # proc.poll() returns 0 on second iteration + def _raise(code=0): + raise SystemExit(code) - with pytest.raises(Exception) as cm: + mock_sys_exit = mocker.patch("sys.exit", side_effect=_raise) + + mock_args = mocker.Mock() + mock_args.config = "mock.yaml" + mocker.patch("autoupdater.read_args", return_value=mock_args) + + # First loop running; second loop exits -> sys.exit(0) before checking for updates + mock_start_proc.return_value.poll.side_effect = [None, 0] + + with pytest.raises(SystemExit) as pytest_wrapped_e: autoupdater.main() - assert str(cm.value) == "StopLoop" + assert pytest_wrapped_e.type is SystemExit + assert pytest_wrapped_e.value.code == 0 - assert mock_sleep.call_count == 3 - assert mock_remote_updates.call_count == 2 + assert mock_sleep.call_count == 2 + assert mock_remote_updates.call_count == 1 # second loop exits before 
calling this mock_stop_proc.assert_not_called() mock_git_pull.assert_not_called() mock_sys_exit.assert_called_once_with(0) From 9b8c1107e910bca2e54dc39c7ca32c79490eaff2 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 01:21:51 +0200 Subject: [PATCH 36/47] Add deps --- pyproject.toml | 1 + scripts/autoupdater.py | 3 ++- scripts/autoupdater_pm2.sh | 3 ++- uv.lock | 2 ++ 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 898b171df..f8602d46a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "types-pyyaml>=6.0.12.20250516", "types-cachetools>=6.0.0.20250525", "dotenv>=0.9.9", + "pytest-mock>=3.14.1", ] diff --git a/scripts/autoupdater.py b/scripts/autoupdater.py index 101753e20..df87d9391 100644 --- a/scripts/autoupdater.py +++ b/scripts/autoupdater.py @@ -7,7 +7,8 @@ import time from pathlib import Path -CHECK_INTERVAL = 15 * 60 +# CHECK_INTERVAL = 15 * 60 +CHECK_INTERVAL = 15 def venv_python() -> str: diff --git a/scripts/autoupdater_pm2.sh b/scripts/autoupdater_pm2.sh index 209186740..4cb62200d 100644 --- a/scripts/autoupdater_pm2.sh +++ b/scripts/autoupdater_pm2.sh @@ -4,7 +4,8 @@ set -euo pipefail APP_NAME="sn1" PY_VERSION_FILE=".python-version" UV_INSTALL_URL="https://astral.sh/uv/install.sh" -CONFIG="config/mainnet.yaml" +# CONFIG="config/mainnet.yaml" +CONFIG="config/testnet.yaml" # Ensure common user bin dirs are in PATH. export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$HOME/bin:$HOME/.npm-global/bin:$PATH" diff --git a/uv.lock b/uv.lock index eccdbe19c..70e3f580f 100644 --- a/uv.lock +++ b/uv.lock @@ -163,6 +163,7 @@ dependencies = [ { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, { name = "pip" }, { name = "pydantic" }, + { name = "pytest-mock" }, { name = "pyyaml" }, { name = "requests" }, { name = "rich" }, @@ -221,6 +222,7 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.4.1" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.0.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=5.0.0" }, + { name = "pytest-mock", specifier = ">=3.14.1" }, { name = "pyyaml", specifier = ">=6.0.0" }, { name = "requests", specifier = ">=2.31.0" }, { name = "rich", specifier = ">=14.0.0" }, From 5be01c146085bc60421a0b97726f63bff22b13b0 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 01:25:25 +0200 Subject: [PATCH 37/47] Update README --- README.md | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 64ad1e079..7b63eff9b 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Subnet 1 is the most intelligent inference model on Bittensor. As the first agen --- -## Installation +## Run Validator 1. **Clone the repository:** ```bash @@ -32,33 +32,26 @@ Subnet 1 is the most intelligent inference model on Bittensor. As the first agen cd apex ``` -2. **Install `uv`:** - Follow the instructions at [https://github.com/astral-sh/uv](https://github.com/astral-sh/uv) to install `uv`. For example: + +2. **Prepare config file:** ```bash - curl -LsSf https://astral.sh/uv/install.sh | sh + cp config/mainnet.yaml.example config/mainnet.yaml + # Fill in the required values in config/mainnet.yaml ``` -3. **Install the project and its development dependencies:** +3. 
**[Recommended] Run validator with auto-updater:** ```bash - uv venv --python=3.11 && uv pip install '.[dev]' + python scripts/autoupdater.py -c config/mainnet.yaml ``` -4. **Activate python environment:** - ```bash - . .venv/bin/activate - ``` - -## Run Mainnet Validator - -1. Prepare config file: +3a. **[Alternative #1] Run validator with pm2 and auto-updater:** ```bash - cp config/mainnet.yaml.example config/mainnet.yaml - # Fill in the required values in config/mainnet.yaml + bash scripts/autoupdater_pm2.sh ``` -2. **Run the validator:** +3b. **[Alternative #2] Install dependencies and run validator without auto-updater:** ```bash - python validator.py -c config/mainnet.yaml + uv venv --python 3.11 && uv pip install '.[dev]' && python validator.py -c config/mainnet.yaml ``` ## Run Testnet Validator @@ -69,9 +62,9 @@ Subnet 1 is the most intelligent inference model on Bittensor. As the first agen # Fill in the required values in config/testnet.yaml ``` -2. **Run the validator:** +2. Install dependencies and run validator: ```bash - python validator.py -c config/testnet.yaml + uv venv --python 3.11 && uv pip install '.[dev]' && python validator.py -c config/testnet.yaml ``` ## Base Miner (for showcase purposes only) From 21fcc3aaf74f5fb92e64b9bb4a52eaa05571396b Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 01:32:51 +0200 Subject: [PATCH 38/47] Update pm2 --- scripts/autoupdater_pm2.sh | 43 ++++++-------------------------------- 1 file changed, 6 insertions(+), 37 deletions(-) diff --git a/scripts/autoupdater_pm2.sh b/scripts/autoupdater_pm2.sh index 4cb62200d..8c39f2a2a 100644 --- a/scripts/autoupdater_pm2.sh +++ b/scripts/autoupdater_pm2.sh @@ -1,16 +1,14 @@ #!/usr/bin/env bash -set -euo pipefail - APP_NAME="sn1" -PY_VERSION_FILE=".python-version" -UV_INSTALL_URL="https://astral.sh/uv/install.sh" # CONFIG="config/mainnet.yaml" CONFIG="config/testnet.yaml" +UV_INSTALL_URL="https://astral.sh/uv/install.sh" + # Ensure common user bin dirs are in PATH. export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$HOME/bin:$HOME/.npm-global/bin:$PATH" -# 1) Ensure uv exists +# 1) Ensure uv exists. if ! command -v uv >/dev/null 2>&1; then echo "[info] uv not found; installing..." if ! command -v curl >/dev/null 2>&1; then @@ -27,7 +25,7 @@ else echo "[info] uv found: $(command -v uv)" fi -# 2) Ensure pm2 exists +# 2) Ensure pm2 exists. if ! command -v pm2 >/dev/null 2>&1; then echo "[info] pm2 not found; installing globally with npm..." if ! command -v npm >/dev/null 2>&1; then @@ -44,34 +42,5 @@ else echo "[info] pm2 found: $(command -v pm2)" fi -# 3) Determine Python version (read from .python-version; fallback 3.11) -PY_VER="3.11" -if [[ -s "$PY_VERSION_FILE" ]]; then - CANDIDATE="$(tr -d ' \t\r\n' < "$PY_VERSION_FILE")" - if [[ "$CANDIDATE" =~ ^[0-9]+(\.[0-9]+){0,2}$ ]]; then - PY_VER="$CANDIDATE" - else - echo "[warn] Invalid version in $PY_VERSION_FILE: '$CANDIDATE' — using $PY_VER" - fi -else - echo "[warn] $PY_VERSION_FILE missing or empty — using $PY_VER" -fi -echo "[info] Using Python version: $PY_VER" - -# 4) Create/update venv and install deps -echo "[info] Creating/refreshing .venv with uv…" -uv venv --python="$PY_VER" -if [[ ! -x ".venv/bin/python" ]]; then - echo "[error] .venv/bin/python not created. Check that uv can resolve Python $PY_VER." 
>&2 - exit 1 -fi - -echo "[info] Installing project dependencies…" -uv pip install '.[dev]' - -# 5) Start with pm2 -pm2 delete "$APP_NAME" >/dev/null 2>&1 || true -pm2 start ".venv/bin/python" --name "$APP_NAME" -- scripts/autoupdater.py -c "$CONFIG" - -echo "[done] pm2 process '$APP_NAME' started." -pm2 logs "$APP_NAME" +pm2 start scripts/autoupdater.py --interpreter .venv/bin/python --name sn1 -- -c config/testnet.yaml +pm2 logs sn1 From 7cf1218ab96a93b60eca2b96b8525758d3b5c84c Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 01:33:58 +0200 Subject: [PATCH 39/47] Remove test values --- apex/validator/weight_syncer.py | 6 +----- scripts/autoupdater.py | 3 +-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/apex/validator/weight_syncer.py b/apex/validator/weight_syncer.py index 05b85599c..a496b4680 100644 --- a/apex/validator/weight_syncer.py +++ b/apex/validator/weight_syncer.py @@ -71,7 +71,7 @@ async def start(self) -> None: ) self.server = uvicorn.Server(config) self.server_task = asyncio.create_task(self.server.serve()) - logger.info(f"Started weight synchronization API on port {self.port}. pid={os.getpid()} self_id={id(self)}") + logger.info(f"Started weight synchronization API on port {self.port}") # Announce the axon on the network. external_ip = requests.get("https://checkip.amazonaws.com").text.strip() @@ -124,10 +124,6 @@ async def compute_weighted_rewards(self, hotkey_rewards: dict[str, float]) -> di """Computes weighted rewards by fetching rewards from other validators and averaging them by stake.""" self.hotkey_rewards = hotkey_rewards self.last_update_time = time.time() - # logger.debug(f"Updating rewards at: {self.last_update_time}") - import os - - logger.debug(f"Updating rewards at: {self.last_update_time} pid={os.getpid()} self_id={id(self)}") if not self.receive_enabled: logger.warning("Rewards weight averaging is disable, using raw rewards") return hotkey_rewards diff --git a/scripts/autoupdater.py b/scripts/autoupdater.py index df87d9391..101753e20 100644 --- a/scripts/autoupdater.py +++ b/scripts/autoupdater.py @@ -7,8 +7,7 @@ import time from pathlib import Path -# CHECK_INTERVAL = 15 * 60 -CHECK_INTERVAL = 15 +CHECK_INTERVAL = 15 * 60 def venv_python() -> str: From 640eaf08be81abcbb960bd7cfecb5365c4616f1e Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 01:35:48 +0200 Subject: [PATCH 40/47] Remove test values in pm2 script --- scripts/autoupdater_pm2.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/autoupdater_pm2.sh b/scripts/autoupdater_pm2.sh index 8c39f2a2a..7adca2022 100644 --- a/scripts/autoupdater_pm2.sh +++ b/scripts/autoupdater_pm2.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash APP_NAME="sn1" -# CONFIG="config/mainnet.yaml" -CONFIG="config/testnet.yaml" +CONFIG="config/mainnet.yaml" +# CONFIG="config/testnet.yaml" UV_INSTALL_URL="https://astral.sh/uv/install.sh" @@ -42,5 +42,5 @@ else echo "[info] pm2 found: $(command -v pm2)" fi -pm2 start scripts/autoupdater.py --interpreter .venv/bin/python --name sn1 -- -c config/testnet.yaml +pm2 start scripts/autoupdater.py --interpreter .venv/bin/python --name sn1 -- -c $CONFIG pm2 logs sn1 From f73ae7ae039649bc67c2dade82524230960f4c31 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Mon, 11 Aug 2025 01:39:47 +0200 Subject: [PATCH 41/47] Fix logging --- scripts/autoupdater.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/autoupdater.py b/scripts/autoupdater.py
index 101753e20..b700db9d5 100644
--- a/scripts/autoupdater.py
+++ b/scripts/autoupdater.py
@@ -101,7 +101,7 @@ def handle_sigint(sig, frame):
             sys.exit(proc.returncode)
 
         if remote_has_updates():
-            print("Updates detected, restaring process")
+            print("Updates detected, restarting validator")
             stop_proc(proc)
             git_pull_ff_only()
             proc = start_proc(config=args.config)

From 8fcbadf510eaa1f737fd2cf0402deb08bf9421c6 Mon Sep 17 00:00:00 2001
From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
Date: Mon, 11 Aug 2025 01:41:07 +0200
Subject: [PATCH 42/47] Lower task generation frequency

---
 apex/validator/pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/apex/validator/pipeline.py b/apex/validator/pipeline.py
index 5b695d7ea..a87b9cf28 100644
--- a/apex/validator/pipeline.py
+++ b/apex/validator/pipeline.py
@@ -24,8 +24,8 @@ def __init__(
         deep_research: DeepResearchBase,
         logger_apex: LoggerApex | None = None,
         num_consumers: int = 5,
-        timeout_consumer: float = 180,
-        timeout_producer: float = 36,
+        timeout_consumer: float = 1200,
+        timeout_producer: float = 240,
         queue_size: int = 10_000,
         redundancy_rate: float = 0.05,  # The rate that references are generated in addition to generator steps
         reference_rate: float = 0.5,  # The rate that references are generated as opposed to generator steps

From e3ce414f00094539bfa8a489816b1cb5a2e2273f Mon Sep 17 00:00:00 2001
From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
Date: Mon, 11 Aug 2025 01:42:41 +0200
Subject: [PATCH 43/47] Fix readme

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 7b63eff9b..83b109275 100644
--- a/README.md
+++ b/README.md
@@ -44,12 +44,12 @@ Subnet 1 is the most intelligent inference model on Bittensor. As the first agen
     python scripts/autoupdater.py -c config/mainnet.yaml
     ```
 
-3a. **[Alternative #1] Run validator with pm2 and auto-updater:**
+4. **[Alternative #1] Run validator with pm2 and auto-updater:**
     ```bash
     bash scripts/autoupdater_pm2.sh
     ```
 
-3b. **[Alternative #2] Install dependencies and run validator without auto-updater:**
+5. **[Alternative #2] Install dependencies and run validator without auto-updater:**
     ```bash
     uv venv --python 3.11 && uv pip install '.[dev]' && python validator.py -c config/mainnet.yaml
     ```

From 3123c98bae562db3713f38efd59381ce0e8c440d Mon Sep 17 00:00:00 2001
From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
Date: Mon, 11 Aug 2025 01:45:24 +0200
Subject: [PATCH 44/47] Fix formatting

---
 apex/validator/weight_syncer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/apex/validator/weight_syncer.py b/apex/validator/weight_syncer.py
index a496b4680..72a46974b 100644
--- a/apex/validator/weight_syncer.py
+++ b/apex/validator/weight_syncer.py
@@ -1,5 +1,4 @@
 import asyncio
-import os
 import time
 from typing import cast

From 6802c3bcf6c6cf4de2be23c6ae7b4d66bc06470c Mon Sep 17 00:00:00 2001
From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
Date: Mon, 11 Aug 2025 02:19:49 +0200
Subject: [PATCH 45/47] Ensure port is of type int

---
 apex/validator/weight_syncer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/apex/validator/weight_syncer.py b/apex/validator/weight_syncer.py
index 72a46974b..5327e335f 100644
--- a/apex/validator/weight_syncer.py
+++ b/apex/validator/weight_syncer.py
@@ -44,7 +44,7 @@ def __init__(
         self.current_hotkey = self.wallet.hotkey.ss58_address
         self.receive_enabled = enable_receive
         self.send_enabled = enable_send
-        self.port = port
+        self.port = int(port)
         self.server: uvicorn.Server | None = None
         self.server_task: asyncio.Task[None] | None = None
         self.hotkey_rewards: dict[str, float] | None = None

From 91f9801911e32351787827d084c6f9f79e85d115 Mon Sep 17 00:00:00 2001
From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
Date: Mon, 11 Aug 2025 02:36:19 +0200
Subject: [PATCH 46/47] Add constants timeout variable

---
 apex/common/constants.py        |  1 +
 apex/validator/miner_sampler.py | 13 +++++++++++--
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/apex/common/constants.py b/apex/common/constants.py
index 9f7698ab4..7524b98e8 100644
--- a/apex/common/constants.py
+++ b/apex/common/constants.py
@@ -1,3 +1,4 @@
+TIMEOUT: float = 20
 MAX_TOKENS: int = 2048
 TEMPERATURE: float = 0.1
 WEBPAGE_MAXSIZE: int = 500
diff --git a/apex/validator/miner_sampler.py b/apex/validator/miner_sampler.py
index f4948efe2..cfdbc64de 100644
--- a/apex/validator/miner_sampler.py
+++ b/apex/validator/miner_sampler.py
@@ -11,7 +11,7 @@ from pydantic import BaseModel
 
 from apex.common.async_chain import AsyncChain
-from apex.common.constants import VALIDATOR_REFERENCE_LABEL
+from apex.common.constants import TIMEOUT, VALIDATOR_REFERENCE_LABEL
 from apex.common.epistula import generate_header
 from apex.common.models import MinerDiscriminatorResults, MinerGeneratorResults
 from apex.common.utils import async_cache
@@ -132,9 +132,16 @@ async def _sample_miners(self) -> list[MinerInfo]:
         )
         return miners_sample
 
-    async def query_miners(self, body: dict[str, Any], endpoint: str, hotkey: str | None = None) -> str:
+    async def query_miners(
+        self,
+        body: dict[str, Any],
+        endpoint: str,
+        hotkey: str | None = None,
+        timeout: float = TIMEOUT
+    ) -> str:
         """Query the miners for the query."""
         try:
+            client_timeout = aiohttp.ClientTimeout(total=timeout)
             async with aiohttp.ClientSession() as session:
                 headers = await generate_header(
                     self._chain.wallet.hotkey, body=json.dumps(body).encode("utf-8"), signed_for=hotkey
                     f"{endpoint}/v1/chat/completions",
                     headers=headers,
                     json=body,
+                    timeout=client_timeout,
                 ) as resp:
+                    resp.raise_for_status()
                     result = await resp.text()
         except BaseException:
             # Error during miner query, return empty string.

From 676b18764777acaa04a39ebbdfb166e725c92e2 Mon Sep 17 00:00:00 2001
From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
Date: Mon, 11 Aug 2025 02:49:37 +0200
Subject: [PATCH 47/47] Run formatter

---
 apex/validator/miner_sampler.py       | 6 +-----
 tests/validator/test_miner_sampler.py | 8 ++++++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/apex/validator/miner_sampler.py b/apex/validator/miner_sampler.py
index cfdbc64de..1a34ea099 100644
--- a/apex/validator/miner_sampler.py
+++ b/apex/validator/miner_sampler.py
@@ -133,11 +133,7 @@ async def _sample_miners(self) -> list[MinerInfo]:
         return miners_sample
 
     async def query_miners(
-        self,
-        body: dict[str, Any],
-        endpoint: str,
-        hotkey: str | None = None,
-        timeout: float = TIMEOUT
+        self, body: dict[str, Any], endpoint: str, hotkey: str | None = None, timeout: float = TIMEOUT
     ) -> str:
         """Query the miners for the query."""
         try:
diff --git a/tests/validator/test_miner_sampler.py b/tests/validator/test_miner_sampler.py
index d184ed41f..8a4b2fcb1 100644
--- a/tests/validator/test_miner_sampler.py
+++ b/tests/validator/test_miner_sampler.py
@@ -3,6 +3,7 @@ from typing import Any
 from unittest.mock import AsyncMock, MagicMock, patch
 
+import aiohttp
 import pytest
 from pydantic import BaseModel
 from pytest import MonkeyPatch
@@ -213,12 +214,15 @@ async def test_query_miners() -> None:
         patch("time.time", return_value=12345),
     ):
         mock_generate_header.return_value = {"some": "header"}
-        result = await sampler.query_miners(body, endpoint)
+        timeout = 20
+        result = await sampler.query_miners(body, endpoint, timeout=timeout)
 
     mock_client_session.assert_called_once()
     expected_body = {"test": "data"}
+
+    client_timeout = aiohttp.ClientTimeout(total=timeout)
     mock_session.post.assert_called_with(
-        endpoint + "/v1/chat/completions", headers={"some": "header"}, json=expected_body
+        endpoint + "/v1/chat/completions", headers={"some": "header"}, json=expected_body, timeout=client_timeout
     )
     mock_generate_header.assert_called_with(
         mock_chain.wallet.hotkey, body=json.dumps(body).encode("utf-8"), signed_for=None