diff --git a/.gitignore b/.gitignore index a68384611..70cbedd8b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*.log requirements.txt **/*.ipynb debug_rewards.jsonl diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7852ed5cc..08c1e7862 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,4 +20,22 @@ repos: hooks: - id: ruff args: [--fix] + stages: [pre-commit] - id: ruff-format + stages: [pre-commit] + +- repo: local + hooks: + - id: mypy + name: mypy + entry: mypy . + language: system + pass_filenames: false + stages: [pre-commit] + + - id: pytest + name: pytest + entry: pytest tests/ --verbose --failed-first --exitfirst --disable-warnings + language: system + pass_filenames: false + stages: [pre-commit] diff --git a/README.md b/README.md index 0192cf38d..83b109275 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Subnet 1 is the most intelligent inference model on Bittensor. As the first agen --- -## Installation +## Run Validator 1. **Clone the repository:** ```bash @@ -32,33 +32,26 @@ Subnet 1 is the most intelligent inference model on Bittensor. As the first agen cd apex ``` -2. **Install `uv`:** - Follow the instructions at [https://github.com/astral-sh/uv](https://github.com/astral-sh/uv) to install `uv`. For example: + +2. **Prepare config file:** ```bash - curl -LsSf https://astral.sh/uv/install.sh | sh + cp config/mainnet.yaml.example config/mainnet.yaml + # Fill in the required values in config/mainnet.yaml ``` -3. **Install the project and its development dependencies:** +3. **[Recommended] Run validator with auto-updater:** ```bash - uv venv && uv python install 3.11 && uv python pin 3.11 && uv venv --python=3.11 && uv pip install -e '.[dev]' + python scripts/autoupdater.py -c config/mainnet.yaml ``` -4. **Activate python environment:** - ```bash - . .venv/bin/activate - ``` - -## Run Mainnet Validator - -1. Prepare config file: +4. **[Alternative #1] Run validator with pm2 and auto-updater:** ```bash - cp config/mainnet.yaml.example config/mainnet.yaml - # Fill in the required values in config/mainnet.yaml + bash scripts/autoupdater_pm2.sh ``` -2. **Run the validator:** +5. **[Alternative #2] Install dependencies and run validator without auto-updater:** ```bash - python validator.py -c config/mainnet.yaml + uv venv --python 3.11 && uv pip install '.[dev]' && python validator.py -c config/mainnet.yaml ``` ## Run Testnet Validator @@ -69,9 +62,9 @@ Subnet 1 is the most intelligent inference model on Bittensor. As the first agen # Fill in the required values in config/testnet.yaml ``` -2. **Run the validator:** +2. Install dependencies and run validator: ```bash - python validator.py -c config/testnet.yaml + uv venv --python 3.11 && uv pip install '.[dev]' && python validator.py -c config/testnet.yaml ``` ## Base Miner (for showcase purposes only) diff --git a/apex/__init__.py b/apex/__init__.py index 93de2e624..367567422 100644 --- a/apex/__init__.py +++ b/apex/__init__.py @@ -7,6 +7,17 @@ __version__ = version("apex") +def _version_to_int(version_str: str) -> int: + version_split = version_str.split(".") + ["0", "0"] # in case a version doesn't have third element, e.g. 
3.0 + major = int(version_split[0]) + minor = int(version_split[1]) + patch = int(version_split[2]) + return (10000 * major) + (100 * minor) + patch + + +__spec_version__ = _version_to_int(__version__) + + def setup_logger(log_file_path: str | Path | None = None, level: str = "INFO") -> Any: """Set up the loguru logger with optional file logging and specified log level. @@ -28,9 +39,9 @@ def setup_logger(log_file_path: str | Path | None = None, level: str = "INFO") - # Add file handler if a path is provided. if log_file_path: file_log_format = "{time:YYYY-MM-DD HH:mm:ss} [{file}:{line}] {message}" - logger.add(str(log_file_path), level=level, format=file_log_format, rotation="10 MB", retention="7 days") + logger.add(str(log_file_path), level=level, format=file_log_format, rotation="5 MB", retention="3 days") return logger -setup_logger(level="DEBUG") +setup_logger(log_file_path="logs.log", level="DEBUG") diff --git a/apex/common/async_chain.py b/apex/common/async_chain.py index 28453f86a..d4e8a2a61 100644 --- a/apex/common/async_chain.py +++ b/apex/common/async_chain.py @@ -5,6 +5,7 @@ from bittensor.core.metagraph import AsyncMetagraph from loguru import logger +from apex import __spec_version__ from apex.common.utils import async_cache _METAGRAPH_TTL: int = 10 * 60 @@ -15,7 +16,6 @@ def __init__(self, coldkey: str, hotkey: str, netuid: int, network: list[str] | if isinstance(network, str): network = [network] self._network: list[str] = network - self._coldkey = coldkey self._hotkey = hotkey self._netuid = netuid @@ -111,36 +111,35 @@ def network(self) -> list[str]: return self._network async def set_weights(self, rewards: dict[str, float]) -> bool: - metagraph = await self.metagraph() - subtensor = await self.subtensor() - weights: dict[int, float] = {} + try: + metagraph = await self.metagraph() + subtensor = await self.subtensor() + weights: dict[int, float] = {} - for hotkey, reward in rewards.items(): - try: - idx = metagraph.hotkeys.index(hotkey) - except ValueError: - # Hotkey not found in the metagraph (e.g., deregistered). Skip it. - continue + for hotkey, reward in rewards.items(): + try: + idx = metagraph.hotkeys.index(hotkey) + except ValueError: + # Hotkey not found in the metagraph (e.g., deregistered). Skip it. + continue - uid = metagraph.uids[idx] - weights[uid] = reward + uid = metagraph.uids[idx] + weights[uid] = reward - # Set the weights. 
- try: - result = await subtensor.set_weights( + success, err = await subtensor.set_weights( wallet=self._wallet, netuid=self._netuid, uids=list(weights.keys()), weights=list(weights.values()), + version_key=__spec_version__, wait_for_inclusion=True, wait_for_finalization=True, ) - if not result: - logger.error(f"Error setting weights: {result}") - return False - return True + if not success: + logger.error(f"Error during weight set: {err}") + return bool(success) except BaseException as exc: - logger.exception(f"Error setting weights: {exc}") + logger.exception(f"Error during weight set: {exc}") return False async def mask_network(self) -> list[str]: diff --git a/apex/common/config.py b/apex/common/config.py index fc3fff308..9cfcad02b 100644 --- a/apex/common/config.py +++ b/apex/common/config.py @@ -15,6 +15,7 @@ class Config(BaseModel): chain: ConfigClass = Field(default_factory=ConfigClass) websearch: ConfigClass = Field(default_factory=ConfigClass) logger_db: ConfigClass = Field(default_factory=ConfigClass) + weight_syncer: ConfigClass = Field(default_factory=ConfigClass) miner_sampler: ConfigClass = Field(default_factory=ConfigClass) miner_scorer: ConfigClass = Field(default_factory=ConfigClass) llm: ConfigClass = Field(default_factory=ConfigClass) diff --git a/apex/common/constants.py b/apex/common/constants.py index 6315e4485..7524b98e8 100644 --- a/apex/common/constants.py +++ b/apex/common/constants.py @@ -1,7 +1,16 @@ +TIMEOUT: float = 20 MAX_TOKENS: int = 2048 TEMPERATURE: float = 0.1 WEBPAGE_MAXSIZE: int = 500 VALIDATOR_REFERENCE_LABEL = "Validator" +VALIDATOR_VERIFIED_HOTKEYS = { + "5CGLCBndTR1BvQZzn429ckT8GyxduzyjMgt4K1UVTYa8gKfb": "167.99.236.79:8001", # Macrocosmos. + "5CUbyC2Ez7tWYYmnFSSwjqkw26dFNo9cXH8YmcxBSfxi2XSG": None, # Yuma. + "5C8Em1kDZi5rxgDN4zZtfoT7dUqJ7FFbTzS3yTP5GPgVUsn1": None, # RoundTable21. + "5HmkM6X1D3W3CuCSPuHhrbYyZNBy2aGAiZy9NczoJmtY25H7": None, # Crucible. + "5GeR3cDuuFKJ7p66wKGjY65MWjWnYqffq571ZMV4gKMnJqK5": None, # OTF. + "5D1saVvssckE1XoPwPzdHrqYZtvBJ3vESsrPNxZ4zAxbKGs1": None, # Rizzo. +} _ENGLISH_WORDS: tuple[str, ...] | None = None diff --git a/apex/common/epistula.py b/apex/common/epistula.py index 947c0a6d4..8d5aface7 100644 --- a/apex/common/epistula.py +++ b/apex/common/epistula.py @@ -1,11 +1,17 @@ +import json import time from hashlib import sha256 from math import ceil -from typing import Any +from typing import Annotated, Any from uuid import uuid4 +from fastapi import HTTPException, Request +from loguru import logger from substrateinterface import Keypair +from apex.common.async_chain import AsyncChain +from apex.common.constants import VALIDATOR_VERIFIED_HOTKEYS + async def generate_header( hotkey: Keypair, @@ -33,3 +39,82 @@ async def generate_header( "0x" + hotkey.sign(str(timestamp_interval + 1) + "." 
+ signed_for).hex() ) return headers + + +def verify_signature( + signature: str | None, + body: bytes, + timestamp: str | None, + uuid: str | None, + signed_for: str, + signed_by: str, + now: float, +) -> Annotated[str, "Error Message"] | None: + if not isinstance(signature, str): + return "Invalid Signature" + if not isinstance(timestamp, str) or not timestamp.isdigit(): + return "Invalid Timestamp" + timestamp_as_int = int(timestamp) + if not isinstance(signed_by, str): + return "Invalid Sender key" + if not isinstance(signed_for, str): + return "Invalid receiver key" + if not isinstance(uuid, str): + return "Invalid uuid" + if not isinstance(body, bytes): + return "Body is not of type bytes" + allowed_delta_ms = 8000 + keypair = Keypair(ss58_address=signed_by) + if timestamp_as_int + allowed_delta_ms < now: + return "Request is too stale" + message = f"{sha256(body).hexdigest()}.{uuid}.{timestamp}.{signed_for}" + verified = keypair.verify(message, signature) + if not verified: + return "Signature Mismatch" + return None + + +async def verify_validator_signature(request: Request, chain: AsyncChain, min_stake: float = 1024) -> None: + signed_by = request.headers.get("Epistula-Signed-By") + signed_for = request.headers.get("Epistula-Signed-For") + if not signed_by or not signed_for: + logger.error("Missing Epistula-Signed-* headers") + raise HTTPException(400, "Missing Epistula-Signed-* headers") + + wallet = chain.wallet + if signed_for != wallet.hotkey.ss58_address: + logger.error("Bad Request, message is not intended for self") + raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self") + + is_validator = True + if min_stake > 0: + metagraph = await chain.metagraph() + try: + caller_uid = metagraph.hotkeys.index(signed_by) + except ValueError as exc: + raise HTTPException(status_code=401, detail="Signer is not in metagraph") from exc + is_validator = metagraph.stake[caller_uid] > min_stake + + if signed_by not in VALIDATOR_VERIFIED_HOTKEYS and not is_validator: + logger.error(f"Signer not the expected ss58 address: {signed_by}") + raise HTTPException(status_code=401, detail="Signer not the expected ss58 address") + + now = time.time() + body: bytes = await request.body() + try: + json.loads(body) + except json.JSONDecodeError as exc: + raise HTTPException(400, "Invalid JSON body") from exc + + err = verify_signature( + request.headers.get("Epistula-Request-Signature"), + body, + request.headers.get("Epistula-Timestamp"), + request.headers.get("Epistula-Uuid"), + signed_for, + signed_by, + now, + ) + if err: + logger.error(err) + raise HTTPException(status_code=400, detail=err) diff --git a/apex/validator/miner_sampler.py b/apex/validator/miner_sampler.py index d8c71d16a..1a34ea099 100644 --- a/apex/validator/miner_sampler.py +++ b/apex/validator/miner_sampler.py @@ -2,6 +2,7 @@ import json import random import time +from collections import deque from collections.abc import Coroutine, Sequence from typing import Any, Literal @@ -10,7 +11,7 @@ from pydantic import BaseModel from apex.common.async_chain import AsyncChain -from apex.common.constants import VALIDATOR_REFERENCE_LABEL +from apex.common.constants import TIMEOUT, VALIDATOR_REFERENCE_LABEL from apex.common.epistula import generate_header from apex.common.models import MinerDiscriminatorResults, MinerGeneratorResults from apex.common.utils import async_cache @@ -69,7 +70,8 @@ def __init__( if self._available_uids and self._available_addresses: equal_length = len(self._available_uids) == 
len(self._available_addresses) assert equal_length, "Test UIDs and addresses must be the same length." - self._remaining_epoch_miners: set[MinerInfo] = set() + self._epoch_deque: deque[MinerInfo] = deque() + self._sample_lock = asyncio.Lock() @async_cache(_TTL_UIDS_RESYNC) async def _get_all_miners(self) -> list[MinerInfo]: @@ -110,34 +112,43 @@ async def _get_all_miners(self) -> list[MinerInfo]: async def _sample_miners(self) -> list[MinerInfo]: miners = await self._get_all_miners() + miners_sample: list[MinerInfo] = [] if self._sample_mode == "random": miners_sample = random.sample(miners, self._sample_size) elif self._sample_mode == "sequential": - if len(self._remaining_epoch_miners) < self._sample_size: - self._remaining_epoch_miners = set(miners) - logger.debug(f"Starting new miner sampling epoch, miners amount: {len(self._remaining_epoch_miners)}") - indices_sample = sorted(random.sample(range(len(self._remaining_epoch_miners)), self._sample_size)) - miners_sample = [miners[i] for i in indices_sample] - self._remaining_epoch_miners -= set(miners_sample) + async with self._sample_lock: + while len(miners_sample) < self._sample_size: + if not self._epoch_deque: + # Get shuffled deque of miners. + self._epoch_deque = deque(random.sample(miners, len(miners))) + miners_sample.append(self._epoch_deque.popleft()) else: raise ValueError(f"Unknown sampling mode: {self._sample_mode}") + logger.debug( + f"Sampled uids (sample size = {self._sample_size}): {sorted([miner.uid for miner in miners_sample])}" + ) return miners_sample - async def query_miners(self, body: dict[str, Any], endpoint: str, hotkey: str | None = None) -> str: + async def query_miners( + self, body: dict[str, Any], endpoint: str, hotkey: str | None = None, timeout: float = TIMEOUT + ) -> str: """Query the miners for the query.""" try: + client_timeout = aiohttp.ClientTimeout(total=timeout) async with aiohttp.ClientSession() as session: headers = await generate_header( self._chain.wallet.hotkey, body=json.dumps(body).encode("utf-8"), signed_for=hotkey ) async with session.post( - endpoint + "/v1/chat/completions", + f"{endpoint}/v1/chat/completions", headers=headers, json=body, + timeout=client_timeout, ) as resp: + resp.raise_for_status() result = await resp.text() except BaseException: # Error during miner query, return empty string. @@ -151,9 +162,10 @@ async def query_generators(self, query: str) -> MinerGeneratorResults: hotkeys: list[str] = [] tasks: list[Coroutine[str, str, Any]] = [] + + logger.debug(f"Querying {len(miner_information)} miner generators") for miner_info in miner_information: hotkeys.append(miner_info.hotkey) - logger.debug(f"Querying miner generator at {miner_info.address} with uid: {miner_info.uid}") tasks.append(self.query_miners(body=body, endpoint=miner_info.address, hotkey=miner_info.hotkey)) generator_results = await asyncio.gather(*tasks) return MinerGeneratorResults(query=query, generator_hotkeys=hotkeys, generator_results=generator_results) @@ -217,7 +229,7 @@ async def query_discriminators( choice_content = "None" parsed_discriminator_results.append(choice_content) - # Apply scoring logic based on selected generator type + # Apply scoring logic based on selected generator type. if choice_content == str(ground_truth): discriminator_score = score_per_miner else: @@ -225,7 +237,7 @@ async def query_discriminators( discriminator_results_float.append(discriminator_score) - # Generator result is 1 minus sum of discriminator results + # Generator result is 1 minus sum of discriminator results. 
generator_result_float = 1.0 - sum(discriminator_results_float)
 
         miner_discriminator_results = MinerDiscriminatorResults(
             query=query,
diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py
index 5ee18c646..b24490cbb 100644
--- a/apex/validator/miner_scorer.py
+++ b/apex/validator/miner_scorer.py
@@ -7,10 +7,12 @@ from pathlib import Path
 
 import aiosqlite
+import numpy as np
 from loguru import logger
 
 from apex.common.async_chain import AsyncChain
 from apex.common.constants import VALIDATOR_REFERENCE_LABEL
+from apex.validator.weight_syncer import WeightSyncer
 
 # Scoring moving average in hours. Set to be: immunity_period - post_reg_threshold.
 SCORE_MA_WINDOW_HOURS = 23.75
@@ -18,19 +20,29 @@
 
 
 class MinerScorer:
-    def __init__(self, chain: AsyncChain, interval: float = SCORE_INTERVAL_DEFAULT, debug: bool = False):
+    def __init__(
+        self,
+        chain: AsyncChain,
+        weight_syncer: WeightSyncer | None = None,
+        interval: float = SCORE_INTERVAL_DEFAULT,
+        debug: bool = False,
+    ):
         self.chain = chain
         self.interval = interval
-        self._running = True
         self._debug = debug
+        self._weight_syncer = weight_syncer
         self._debug_rewards_path = Path("debug_rewards.jsonl")
+        self._running = True
 
     async def start_loop(self) -> None:
         self._running = True
         while self._running:
             await asyncio.sleep(self.interval)
+            logger.debug("Attempting to set weights")
             success = await self.set_scores()
-            if not success:
+            if success:
+                logger.info("Successfully set weights")
+            else:
                 logger.error("Failed to set weights")
 
     async def shutdown(self) -> None:
@@ -50,6 +62,7 @@ async def set_scores(self) -> bool:
         expose each one as plain python objects so that downstream code can work
         with them, and remove rows that are older than the time window.
         """
+        logger.debug("Retrieving miners' performance history")
         async with self._db() as conn:  # type: aiosqlite.Connection
             # Calculate the cutoff timestamp (current time - window hours).
             cutoff_timestamp = int(time.time() - SCORE_MA_WINDOW_HOURS * 3600)
@@ -70,6 +83,7 @@
                 return False
 
             # 2. Iterate over the in-memory list so that the caller can process freely.
+            logger.debug("Pre-processing miners' rewards")
             hkey_agg_rewards: dict[str, float] = {}
             for generator_hotkey, generator_score, disc_hotkeys_json, disc_scores_json in rows:
                 # Deserialize JSON columns.
@@ -88,6 +102,7 @@
                 hkey_agg_rewards[hotkey] = float(hkey_agg_rewards.get(hotkey, 0.0)) + float(reward)
 
             # 3. Delete rows that are older than the time window.
+            logger.debug("Cleaning up miners' outdated history")
             await conn.execute(
                 "DELETE FROM discriminator_results WHERE timestamp < ?",
                 (cutoff_timestamp,),
@@ -101,9 +116,23 @@
                 with self._debug_rewards_path.open("a+") as fh:
                     record_str: str = json.dumps(record)
                     fh.write(f"{record_str}\n")
+
+            if self._weight_syncer is not None:
+                try:
+                    hkey_agg_rewards = await self._weight_syncer.compute_weighted_rewards(hkey_agg_rewards)
+                except BaseException as exc:
+                    logger.error(f"Failed to compute weighted average rewards over the network, skipping: {exc}")
+
+            if hkey_agg_rewards:
+                rewards_array = np.array(list(hkey_agg_rewards.values()))
+                logger.debug(f"Setting weights, reward mean={rewards_array.mean():.4f} min={rewards_array.min():.4f}")
+            else:
+                logger.warning(f"Setting empty rewards: {hkey_agg_rewards}")
+
+            # TODO: Flush the db only when set_weights_result is True.
             set_weights_result = await self.chain.set_weights(hkey_agg_rewards)
 
             # 4. Flush all deletions in a single commit.
+ logger.debug("Updating rewards DB") await conn.commit() return set_weights_result diff --git a/apex/validator/pipeline.py b/apex/validator/pipeline.py index ea8dfaf9a..a87b9cf28 100644 --- a/apex/validator/pipeline.py +++ b/apex/validator/pipeline.py @@ -6,7 +6,6 @@ from loguru import logger -from apex.common.config import Config from apex.common.models import QueryTask from apex.services.deep_research.deep_research_base import DeepResearchBase from apex.services.llm.llm_base import LLMBase @@ -19,20 +18,18 @@ class Pipeline: def __init__( self, - config: Config, websearch: WebSearchBase, miner_sampler: MinerSampler, llm: LLMBase, deep_research: DeepResearchBase, logger_apex: LoggerApex | None = None, - num_consumers: int = 10, - timeout_consumer: float = 60, - timeout_producer: float = 6, + num_consumers: int = 5, + timeout_consumer: float = 1200, + timeout_producer: float = 240, queue_size: int = 10_000, - redundancy_rate: float = 0.1, # The rate that references are generated in addition to generator steps + redundancy_rate: float = 0.05, # The rate that references are generated in addition to generator steps reference_rate: float = 0.5, # The rate that references are generated as opposed to generator steps ): - self.config = config self.websearch = websearch self.miner_registry = miner_sampler self.llm = llm @@ -81,21 +78,27 @@ async def run_single(self, task: QueryTask) -> str: logger.debug("Generating task query") query = await generate_query(llm=self.llm, websearch=self.websearch) + reference = None + tool_history: list[dict[str, str]] = [] if random.random() < self.reference_rate: + try: + generator_results = None + ground_truth = 0 + logger.debug(f"Generating task reference for query: {query[:20]}..") + reference, tool_history = await generate_reference(llm=self.deep_research, query=query) + except BaseException as exc: + logger.exception(f"Failed to generate reference: {exc}") + + if reference is None: ground_truth = 1 logger.debug(f"Querying generators with query: {query[:20]}..") generator_results = await self.miner_registry.query_generators(query=query) if random.random() < self.redundancy_rate: - logger.debug(f"Generating redundant task reference for query: {query[:20]}..") - reference, tool_history = await generate_reference(llm=self.deep_research, query=query) - else: - reference = None - tool_history = [] - else: - generator_results = None - ground_truth = 0 - logger.debug(f"Generating task reference for query: {query[:20]}..") - reference, tool_history = await generate_reference(llm=self.deep_research, query=query) + try: + logger.debug(f"Generating redundant task reference for query: {query[:20]}..") + reference, tool_history = await generate_reference(llm=self.deep_research, query=query) + except BaseException as exc: + logger.warning(f"Failed to generate redundant reference: {exc}") discriminator_results = await self.miner_registry.query_discriminators( query=query, generator_results=generator_results, reference=reference, ground_truth=ground_truth diff --git a/apex/validator/weight_syncer.py b/apex/validator/weight_syncer.py new file mode 100644 index 000000000..5327e335f --- /dev/null +++ b/apex/validator/weight_syncer.py @@ -0,0 +1,212 @@ +import asyncio +import time +from typing import cast + +import httpx +import netaddr +import requests +import uvicorn +from bittensor.core.async_subtensor import AsyncMetagraph +from bittensor.core.extrinsics.asyncex.serving import serve_extrinsic +from fastapi import APIRouter, FastAPI, HTTPException, Request +from loguru import 
logger
+from pydantic import BaseModel
+
+from apex.common.async_chain import AsyncChain
+from apex.common.constants import VALIDATOR_VERIFIED_HOTKEYS
+from apex.common.epistula import generate_header, verify_validator_signature
+
+
+class ValidatorInfo(BaseModel):
+    uid: int
+    hotkey: str
+    address: str
+    stake: float
+
+
+class WeightSyncer:
+    REWARD_EXPIRATION_SEC: float = 60 * 60
+
+    def __init__(
+        self,
+        chain: AsyncChain,
+        min_alpha_stake: float = 100_000,
+        verified_hotkeys: dict[str, str | None] | None = None,
+        enable_receive: bool = True,
+        enable_send: bool = True,
+        port: int = 8001,
+    ) -> None:
+        """Validator weight synchronizer."""
+        self.chain = chain
+        self.min_alpha_stake = min_alpha_stake
+        self.verified_hotkeys = verified_hotkeys or VALIDATOR_VERIFIED_HOTKEYS
+        self.wallet = self.chain.wallet
+        self.current_hotkey = self.wallet.hotkey.ss58_address
+        self.receive_enabled = enable_receive
+        self.send_enabled = enable_send
+        self.port = int(port)
+        self.server: uvicorn.Server | None = None
+        self.server_task: asyncio.Task[None] | None = None
+        self.hotkey_rewards: dict[str, float] | None = None
+        self.last_update_time: float = 0
+
+    async def start(self) -> None:
+        if not self.send_enabled:
+            logger.warning("Weight synchronization API is disabled for incoming reward requests")
+            return
+
+        try:
+            app = FastAPI()
+            app.include_router(self.get_router())
+
+            config = uvicorn.Config(
+                app=app,
+                host="0.0.0.0",
+                port=self.port,
+                log_level="info",
+                workers=1,
+                reload=False,
+                loop="asyncio",
+            )
+            self.server = uvicorn.Server(config)
+            self.server_task = asyncio.create_task(self.server.serve())
+            logger.info(f"Started weight synchronization API on port {self.port}")
+
+            # Announce the axon on the network.
+            external_ip = requests.get("https://checkip.amazonaws.com").text.strip()
+            netaddr.IPAddress(external_ip)
+            sub = await self.chain.subtensor()
+            serve_success = await serve_extrinsic(
+                subtensor=sub,
+                wallet=self.chain.wallet,
+                ip=external_ip,
+                port=self.port,
+                protocol=4,
+                netuid=self.chain.netuid,
+            )
+            if serve_success:
+                logger.success(f"Serving weight syncer axon on subtensor at {external_ip}:{self.port}")
+            else:
+                logger.error("Failed to serve weight syncer axon on subtensor")
+        except BaseException as e:
+            logger.warning(f"Failed to announce weight syncer axon on subtensor: {e}")
+
+    async def shutdown(self) -> None:
+        if self.server is not None:
+            self.server.should_exit = True
+        if self.server_task is not None:
+            await self.server_task
+
+    def get_router(self) -> APIRouter:
+        """Creates and returns a FastAPI router with the endpoints for this class."""
+        router = APIRouter()
+
+        @router.post("/v1/get_rewards")
+        async def get_rewards_endpoint(request: Request) -> dict[str, float]:
+            await verify_validator_signature(request=request, chain=self.chain, min_stake=self.min_alpha_stake)
+
+            if (outdated := time.time() - self.last_update_time) > self.REWARD_EXPIRATION_SEC:
+                logger.warning(f"Rewards expired: {outdated:.2f}s - {self.last_update_time}")
+                raise HTTPException(status_code=503, detail="Rewards expired")
+            if self.hotkey_rewards is None:
+                logger.warning("Rewards not available")
+                raise HTTPException(status_code=503, detail="Rewards not available")
+            if not self.send_enabled:
+                logger.warning("API is disabled")
+                raise HTTPException(status_code=405, detail="API is disabled")
+            return self.hotkey_rewards
+
+        return router
+
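+    # Stake-weighted consensus, implemented by compute_weighted_rewards below:
+    # for each miner hotkey h, the published reward is
+    #     w(h) = (stake_self * r_self(h) + sum_v stake_v * r_v(h)) / (stake_self + sum_v stake_v),
+    # where v ranges over the validators that actually responded. Validators that
+    # fail to respond drop out of both sums, so their stake never dilutes the average.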
+    async def compute_weighted_rewards(self, hotkey_rewards: dict[str, float]) -> dict[str, float]:
+        """Computes weighted rewards by fetching rewards from other validators and averaging them by stake."""
+        self.hotkey_rewards = hotkey_rewards
+        self.last_update_time = time.time()
+        if not self.receive_enabled:
+            logger.warning("Rewards weight averaging is disabled, using raw rewards")
+            return hotkey_rewards
+
+        metagraph = await self.chain.metagraph()
+
+        try:
+            own_uid = metagraph.hotkeys.index(self.current_hotkey)
+        except ValueError:
+            logger.error(f"Could not find own hotkey {self.current_hotkey} in metagraph, returning raw rewards")
+            return hotkey_rewards
+
+        validator_rewards_tasks: dict[int, asyncio.Task[dict[str, float]]] = {}
+        for uid in metagraph.uids:
+            if uid == own_uid:
+                continue
+
+            stake = metagraph.stake[uid]
+            hotkey = metagraph.hotkeys[uid]
+            is_verified = hotkey in self.verified_hotkeys
+            is_validator = metagraph.validator_permit[uid]
+
+            if (stake >= self.min_alpha_stake and is_validator) or is_verified:
+                validator_rewards_tasks[uid] = asyncio.create_task(self.receive_rewards(metagraph, uid))
+
+        results = await asyncio.gather(*validator_rewards_tasks.values(), return_exceptions=True)
+
+        validator_rewards: dict[int, dict[str, float]] = {}
+        for uid, result in zip(validator_rewards_tasks, results, strict=True):
+            if isinstance(result, BaseException) or not result:
+                logger.warning(f"Cannot receive rewards from uid {uid}: {result}")
+                continue
+            validator_rewards[uid] = result
+            logger.debug(f"Received rewards from validator {uid} with stake {metagraph.stake[uid]}")
+
+        all_validator_uids = [own_uid] + list(validator_rewards.keys())
+        total_stake = sum(metagraph.stake[uid] for uid in all_validator_uids)
+
+        if total_stake == 0:
+            logger.warning("Total stake of responding validators is zero, returning original rewards")
+            return hotkey_rewards
+
+        own_stake = metagraph.stake[own_uid]
+
+        weighted_rewards: dict[str, float] = {}
+        for miner_hkey in hotkey_rewards:
+            own_reward = hotkey_rewards.get(miner_hkey, 0.0)
+            total_weighted_reward = own_reward * own_stake
+
+            for uid, rewards in validator_rewards.items():
+                validator_reward = rewards.get(miner_hkey, 0.0)
+                total_weighted_reward += validator_reward * metagraph.stake[uid]
+
+            weighted_rewards[miner_hkey] = total_weighted_reward / total_stake
+
+        logger.debug(
+            f"Averaged rewards over {len(all_validator_uids)} validators. "
+            f"Self stake: {100 * own_stake / total_stake:.2f}%"
+        )
+        return weighted_rewards
+
+    async def receive_rewards(self, metagraph: AsyncMetagraph, uid: int) -> dict[str, float]:
+        """Receive rewards from the given validator uid."""
+        try:
+            target_hotkey = metagraph.hotkeys[uid]
+            if (address := self.verified_hotkeys.get(target_hotkey)) is None:
+                axon = metagraph.axons[uid]
+                address = f"{axon.ip}:{axon.port}"
+
+            async with httpx.AsyncClient() as client:
+                body = b"{}"
+                headers = await generate_header(
+                    hotkey=self.chain.wallet.hotkey,
+                    body=body,
+                    signed_for=target_hotkey,
+                )
+                resp = await client.post(
+                    f"http://{address}/v1/get_rewards",
+                    headers=headers,
+                    content=body,
+                )
+                resp.raise_for_status()
+                return cast(dict[str, float], resp.json())
+
+        except BaseException as exc:
+            logger.warning(f"Cannot receive rewards from uid {uid}: {exc}")
+            return {}
diff --git a/config/mainnet.yaml.example b/config/mainnet.yaml.example
index 3d8bffaf6..bff750080 100644
--- a/config/mainnet.yaml.example
+++ b/config/mainnet.yaml.example
@@ -1,8 +1,8 @@
 chain:
   kwargs:
     netuid: 1
-    coldkey: "validator"
-    hotkey: "default"
+    coldkey: "YOUR_COLDKEY"
+    hotkey: "YOUR_HOTKEY"
     network:
       - finney
       # - ws://LOCAL_SUBTENSOR_FALLBACK_1
@@ -27,3 +27,12 @@ deep_research:
   research_model: "Qwen/Qwen3-235B-A22B-Instruct-2507"
   compression_model: "deepseek-ai/DeepSeek-V3-0324"
   final_model: "deepseek-ai/DeepSeek-V3-0324"
+
+weight_syncer:
+  kwargs:
+    # Change the port if necessary.
+    port: 8001
+    # When enabled, performs weight synchronization across validators, drastically improves vTrust.
+    enable_receive: True
+    # When enabled, allows other validators to request your rewards, slightly improves vTrust.
+    enable_send: True
diff --git a/config/testnet.yaml.example b/config/testnet.yaml.example
index 9b24697c8..adea06149 100644
--- a/config/testnet.yaml.example
+++ b/config/testnet.yaml.example
@@ -34,3 +34,12 @@ miner_sampler:
   # For testing purposes one can specify available pool of uids.
   # available_uids: [1, 2]
   # available_addresses: ["http://0.0.0.0:8081", "http://0.0.0.0:8082"]
+
+weight_syncer:
+  kwargs:
+    # Change the port if necessary.
+    port: 8001
+    # When enabled, performs weight synchronization across validators, drastically improves vTrust.
+    enable_receive: True
+    # When enabled, allows other validators to request your rewards, slightly improves vTrust.
+ enable_send: True diff --git a/pyproject.toml b/pyproject.toml index ec1d2dc28..f8602d46a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "apex" -version = "3.0.0" +version = "3.0.1" description = "Bittensor Subnet 1: Apex" readme = "README.md" requires-python = "~=3.11" @@ -16,10 +16,11 @@ dependencies = [ "aiohttp>=3.8.3", "beautifulsoup4>=4.13.3", "langchain>=0.3.26", + "langchain-core>=0.3.68", "langchain-community>=0.0.59", - "faiss-cpu>=1.8.0", "langchain-openai>=0.3.28", "langchain-sandbox>=0.0.6", + "faiss-cpu>=1.8.0", "dotenv>=0.9.9", "rich>=14.0.0", "loguru>=0.7.3", @@ -28,6 +29,11 @@ dependencies = [ "bittensor>=9.7.0", "rouge>=1.0.1", "substrate-interface>=1.7.11", + "types-netaddr>=1.3.0.20240530", + "types-pyyaml>=6.0.12.20250516", + "types-cachetools>=6.0.0.20250525", + "dotenv>=0.9.9", + "pytest-mock>=3.14.1", ] @@ -35,13 +41,6 @@ dependencies = [ dev = [ "mypy==1.17.0", "ruff==0.12.5", - "types-pyyaml>=6.0.12.20250516", - "types-cachetools>=6.0.0.20250525", - "langchain>=0.3.26", - "dotenv>=0.9.9", - "langchain-openai>=0.3.28", - "langchain-core>=0.3.68", - "langchain-sandbox>=0.0.6", "pytest>=8.4.1", "pytest-asyncio>=1.0.0", "pytest-cov>=5.0.0", @@ -82,10 +81,8 @@ exclude = [ "^tests/", "^venv/", '^\.venv/', - # TODO: Enable once fixed. "scripts/", - "apex/services/", - "apex/validator/", + "drafts/", ] [[tool.mypy.overrides]] @@ -204,5 +201,6 @@ dev = [ "pydantic>=2.11.7", "pytest>=8.4.1", "pytest-asyncio>=1.0.0", + "pytest-mock>=3.14.1", "types-pyyaml>=6.0.12.20250516", ] diff --git a/scripts/autoupdater.py b/scripts/autoupdater.py new file mode 100644 index 000000000..b700db9d5 --- /dev/null +++ b/scripts/autoupdater.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +import argparse +import os +import signal +import subprocess +import sys +import time +from pathlib import Path + +CHECK_INTERVAL = 15 * 60 + + +def venv_python() -> str: + return os.path.join(".venv", "bin", "python") + + +def read_python_version() -> str | None: + try: + with open(".python-version", encoding="utf-8") as f: + # Take first non-empty token (pyenv format e.g. "3.11.9"). + return f.read().strip().split()[0] + except FileNotFoundError: + return None + + +def start_proc(config: Path) -> subprocess.Popen: + py_ver = read_python_version() + if py_ver: + subprocess.run(["uv", "venv", "--python", py_ver], check=True) + else: + subprocess.run(["uv", "venv"], check=True) + + # Install project in dev mode into the venv. + subprocess.run(["uv", "pip", "install", ".[dev]"], check=True) + + # Run validator. + return subprocess.Popen([venv_python(), "validator.py", "-c", str(config)]) + + +def stop_proc(process: subprocess.Popen) -> None: + if process and process.poll() is None: + process.terminate() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + process.kill() + + +def remote_has_updates() -> bool: + try: + subprocess.run(["git", "fetch", "--quiet"], check=True) + out = subprocess.check_output( + ["git", "rev-list", "--left-right", "--count", "@{u}...HEAD"], stderr=subprocess.STDOUT, text=True + ).strip() + left, right = map(int, out.split()) + # Remote is ahead. + return left > 0 + except subprocess.CalledProcessError: + # No upstream or git issue; treat as no updates. 
+ return False + + +def git_pull_ff_only() -> None: + try: + subprocess.run(["git", "pull", "--ff-only"], check=True) + except subprocess.CalledProcessError as e: + print(f"Error: Git pull failed due to conflicts or other issues: {e}", file=sys.stderr) + print("Staying on the current version.", file=sys.stderr) + + +def read_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Apex validator") + parser.add_argument( + "-c", + "--config", + # default="config/testnet.yaml", + default="config/mainnet.yaml", + help="Config file path (e.g. config/mainnet.yaml).", + type=Path, + ) + args = parser.parse_args() + return args + + +def main() -> None: + args = read_args() + proc = start_proc(config=args.config) + + def handle_sigint(sig, frame): + stop_proc(proc) + sys.exit(0) + + signal.signal(signal.SIGINT, handle_sigint) + + while True: + time.sleep(CHECK_INTERVAL) + print("Checking for updates...") + + # If child exited, propagate its code. + if proc.poll() is not None: + sys.exit(proc.returncode) + + if remote_has_updates(): + print("Updates detected, restarting validator") + stop_proc(proc) + git_pull_ff_only() + proc = start_proc(config=args.config) + + +if __name__ == "__main__": + main() diff --git a/scripts/autoupdater_pm2.sh b/scripts/autoupdater_pm2.sh new file mode 100644 index 000000000..7adca2022 --- /dev/null +++ b/scripts/autoupdater_pm2.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +APP_NAME="sn1" +CONFIG="config/mainnet.yaml" +# CONFIG="config/testnet.yaml" + +UV_INSTALL_URL="https://astral.sh/uv/install.sh" + +# Ensure common user bin dirs are in PATH. +export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$HOME/bin:$HOME/.npm-global/bin:$PATH" + +# 1) Ensure uv exists. +if ! command -v uv >/dev/null 2>&1; then + echo "[info] uv not found; installing..." + if ! command -v curl >/dev/null 2>&1; then + echo "[error] curl is required to install uv." >&2 + exit 1 + fi + curl -LsSf "$UV_INSTALL_URL" | sh + hash -r + if ! command -v uv >/dev/null 2>&1; then + echo "[error] uv installation completed but 'uv' not found in PATH." >&2 + exit 1 + fi +else + echo "[info] uv found: $(command -v uv)" +fi + +# 2) Ensure pm2 exists. +if ! command -v pm2 >/dev/null 2>&1; then + echo "[info] pm2 not found; installing globally with npm..." + if ! command -v npm >/dev/null 2>&1; then + echo "[error] npm is required to install pm2. Please install Node.js first." >&2 + exit 1 + fi + npm install -g pm2 + hash -r + if ! command -v pm2 >/dev/null 2>&1; then + echo "[error] pm2 installation completed but 'pm2' not found in PATH." 
>&2 + exit 1 + fi +else + echo "[info] pm2 found: $(command -v pm2)" +fi + +pm2 start scripts/autoupdater.py --interpreter .venv/bin/python --name sn1 -- -c $CONFIG +pm2 logs sn1 diff --git a/tests/common/mock_async_chain.py b/tests/common/mock_async_chain.py index 2161c344b..1c0312fdd 100644 --- a/tests/common/mock_async_chain.py +++ b/tests/common/mock_async_chain.py @@ -54,18 +54,20 @@ async def set_weights( netuid: int, uids: Iterable[int], weights: Iterable[float], + version_key: int, wait_for_inclusion: bool, wait_for_finalization: bool, - ) -> bool: + ) -> tuple[bool, str | None]: self.last_set_weights = { "wallet": wallet, "netuid": netuid, "uids": list(uids), "weights": list(weights), + "version_key": version_key, "wait_for_inclusion": wait_for_inclusion, "wait_for_finalization": wait_for_finalization, } - return self.weights_result + return self.weights_result, "" def patch_wallet(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/tests/common/test_async_chain.py b/tests/common/test_async_chain.py index 6d83fa262..85092880c 100644 --- a/tests/common/test_async_chain.py +++ b/tests/common/test_async_chain.py @@ -2,6 +2,7 @@ import pytest +from apex import __spec_version__ from apex.common.async_chain import AsyncChain # noqa: E402 from tests.common.mock_async_chain import DummyMetagraph, DummySubtensor, patch_subtensor, patch_wallet @@ -121,6 +122,7 @@ async def test_set_weights_happy_path(monkeypatch): assert stub.last_set_weights is not None assert stub.last_set_weights["uids"] == [2] assert stub.last_set_weights["weights"] == [0.7] + assert stub.last_set_weights["version_key"] == __spec_version__ @pytest.mark.asyncio diff --git a/tests/scripts/test_autoupdater.py b/tests/scripts/test_autoupdater.py new file mode 100644 index 000000000..84412aa7d --- /dev/null +++ b/tests/scripts/test_autoupdater.py @@ -0,0 +1,203 @@ +import os +import subprocess +import sys +from unittest import mock + +import pytest + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../scripts"))) +import autoupdater # isort: skip + + +def test_venv_python(mocker): + mocker.patch("autoupdater.os.path.join", return_value=".venv/bin/python") + assert autoupdater.venv_python() == ".venv/bin/python" + autoupdater.os.path.join.assert_called_once_with(".venv", "bin", "python") + + +def test_read_python_version_exists(mocker): + mocker.patch("autoupdater.open", mocker.mock_open(read_data="3.11.9")) + assert autoupdater.read_python_version() == "3.11.9" + autoupdater.open.assert_called_once_with(".python-version", encoding="utf-8") + + +def test_read_python_version_not_found(mocker): + mocker.patch("autoupdater.open", side_effect=FileNotFoundError) + assert autoupdater.read_python_version() is None + autoupdater.open.assert_called_once_with(".python-version", encoding="utf-8") + + +def test_start_proc_with_version(mocker): + mocker.patch("autoupdater.read_python_version", return_value="3.11.9") + mock_run = mocker.patch("subprocess.run") + mock_popen = mocker.patch("subprocess.Popen", return_value=mocker.Mock()) + mocker.patch("autoupdater.venv_python", return_value="mock_python") + + proc = autoupdater.start_proc(config="mock.yaml") + autoupdater.read_python_version.assert_called_once() + mock_run.assert_has_calls( + [ + mock.call(["uv", "venv", "--python", "3.11.9"], check=True), + mock.call(["uv", "pip", "install", ".[dev]"], check=True), + ] + ) + mock_popen.assert_called_once_with(["mock_python", "validator.py", "-c", "mock.yaml"]) + assert proc is not None + + +def 
test_start_proc_without_version(mocker): + mocker.patch("autoupdater.read_python_version", return_value=None) + mock_run = mocker.patch("subprocess.run") + mock_popen = mocker.patch("subprocess.Popen", return_value=mocker.Mock()) + mocker.patch("autoupdater.venv_python", return_value="mock_python") + + proc = autoupdater.start_proc(config="mock.yaml") + autoupdater.read_python_version.assert_called_once() + mock_run.assert_has_calls( + [ + mock.call(["uv", "venv"], check=True), + mock.call(["uv", "pip", "install", ".[dev]"], check=True), + ] + ) + mock_popen.assert_called_once_with(["mock_python", "validator.py", "-c", "mock.yaml"]) + assert proc is not None + + +def test_stop_proc_running(): + mock_proc = mock.Mock() + mock_proc.poll.return_value = None # Process is running + autoupdater.stop_proc(mock_proc) + mock_proc.terminate.assert_called_once() + mock_proc.wait.assert_called_once_with(timeout=10) + mock_proc.kill.assert_not_called() + + +def test_stop_proc_timeout(): + mock_proc = mock.Mock() + mock_proc.poll.return_value = None # Process is running + mock_proc.wait.side_effect = subprocess.TimeoutExpired(cmd="test", timeout=10) + autoupdater.stop_proc(mock_proc) + mock_proc.terminate.assert_called_once() + mock_proc.wait.assert_called_once_with(timeout=10) + mock_proc.kill.assert_called_once() + + +def test_stop_proc_already_stopped(): + mock_proc = mock.Mock() + mock_proc.poll.return_value = 0 # Process already stopped + autoupdater.stop_proc(mock_proc) + mock_proc.terminate.assert_not_called() + mock_proc.wait.assert_not_called() + mock_proc.kill.assert_not_called() + + +def test_remote_has_updates_true(mocker): + mock_run = mocker.patch("subprocess.run") + mock_check_output = mocker.patch("subprocess.check_output", return_value="1\t0") + assert autoupdater.remote_has_updates() is True + mock_run.assert_called_once_with(["git", "fetch", "--quiet"], check=True) + mock_check_output.assert_called_once_with( + ["git", "rev-list", "--left-right", "--count", "@{u}...HEAD"], stderr=subprocess.STDOUT, text=True + ) + + +def test_remote_has_updates_false_no_diff(mocker): + mocker.patch("subprocess.run") + mocker.patch("subprocess.check_output", return_value="0\t0") + assert autoupdater.remote_has_updates() is False + + +def test_remote_has_updates_error(mocker): + mock_run = mocker.patch("subprocess.run") + mocker.patch("subprocess.check_output", side_effect=subprocess.CalledProcessError(1, "cmd")) + assert autoupdater.remote_has_updates() is False + mock_run.assert_called_once_with(["git", "fetch", "--quiet"], check=True) + autoupdater.subprocess.check_output.assert_called_once() + + +def test_git_pull_ff_only_success(mocker): + mock_run = mocker.patch("subprocess.run") + autoupdater.git_pull_ff_only() + mock_run.assert_called_once_with(["git", "pull", "--ff-only"], check=True) + + +def test_git_pull_ff_only_conflict(mocker): + mocker.patch("subprocess.run", side_effect=subprocess.CalledProcessError(1, "cmd", stderr="conflict")) + mock_stderr = mocker.patch("sys.stderr", new_callable=mock.MagicMock) + autoupdater.git_pull_ff_only() + autoupdater.subprocess.run.assert_called_once_with(["git", "pull", "--ff-only"], check=True) + # Check for individual calls to write, as print adds newlines separately + mock_stderr.write.assert_any_call( + "Error: Git pull failed due to conflicts or other issues: Command 'cmd' returned non-zero exit status 1." 
+ ) + mock_stderr.write.assert_any_call("\n") + mock_stderr.write.assert_any_call("Staying on the current version.") + mock_stderr.write.assert_any_call("\n") + + +def test_main_loop_with_update(mocker): + # start_proc returns the same Mock proc; we drive its poll() via side_effect + mock_start_proc = mocker.patch("autoupdater.start_proc", return_value=mocker.Mock()) + mock_start_proc.return_value.returncode = 0 + mock_stop_proc = mocker.patch("autoupdater.stop_proc") + mock_remote_updates = mocker.patch("autoupdater.remote_has_updates", side_effect=[False, True]) + mock_git_pull = mocker.patch("autoupdater.git_pull_ff_only") + mock_sleep = mocker.patch("time.sleep", return_value=None) + + def _raise(code=0): + raise SystemExit(code) + + mock_sys_exit = mocker.patch("sys.exit", side_effect=_raise) + + mock_args = mocker.Mock() + mock_args.config = "mock.yaml" + mocker.patch("autoupdater.read_args", return_value=mock_args) + + # First loop: running (None), no update + # Second loop: still running (None), update happens -> restart + # Third loop: process seen as exited (0) -> sys.exit(0) + mock_start_proc.return_value.poll.side_effect = [None, None, 0] + + with pytest.raises(SystemExit) as pytest_wrapped_e: + autoupdater.main() + assert pytest_wrapped_e.type is SystemExit + assert pytest_wrapped_e.value.code == 0 + + assert mock_sleep.call_count == 3 # before each iteration, including after restart + assert mock_remote_updates.call_count == 2 + mock_stop_proc.assert_called_once_with(mock_start_proc.return_value) + mock_git_pull.assert_called_once() + mock_start_proc.assert_called_with(config=mock.ANY) + mock_sys_exit.assert_called_once_with(0) + + +def test_main_loop_no_update(mocker): + mock_start_proc = mocker.patch("autoupdater.start_proc", return_value=mocker.Mock()) + mock_start_proc.return_value.returncode = 0 + mock_stop_proc = mocker.patch("autoupdater.stop_proc") + mock_remote_updates = mocker.patch("autoupdater.remote_has_updates", return_value=False) + mock_git_pull = mocker.patch("autoupdater.git_pull_ff_only") + mock_sleep = mocker.patch("time.sleep", return_value=None) + + def _raise(code=0): + raise SystemExit(code) + + mock_sys_exit = mocker.patch("sys.exit", side_effect=_raise) + + mock_args = mocker.Mock() + mock_args.config = "mock.yaml" + mocker.patch("autoupdater.read_args", return_value=mock_args) + + # First loop running; second loop exits -> sys.exit(0) before checking for updates + mock_start_proc.return_value.poll.side_effect = [None, 0] + + with pytest.raises(SystemExit) as pytest_wrapped_e: + autoupdater.main() + assert pytest_wrapped_e.type is SystemExit + assert pytest_wrapped_e.value.code == 0 + + assert mock_sleep.call_count == 2 + assert mock_remote_updates.call_count == 1 # second loop exits before calling this + mock_stop_proc.assert_not_called() + mock_git_pull.assert_not_called() + mock_sys_exit.assert_called_once_with(0) diff --git a/tests/validator/test_miner_sampler.py b/tests/validator/test_miner_sampler.py index d2b8c04ba..8a4b2fcb1 100644 --- a/tests/validator/test_miner_sampler.py +++ b/tests/validator/test_miner_sampler.py @@ -3,6 +3,7 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock, patch +import aiohttp import pytest from pydantic import BaseModel from pytest import MonkeyPatch @@ -166,20 +167,24 @@ async def test_sample_miners_sequential(monkeypatch: MagicMock, miner_sampler: M monkeypatch.setattr(miner_sampler, "_get_all_miners", AsyncMock(return_value=all_miners)) # 1st call in epoch. 
- with patch("random.sample", return_value=[0, 2]): + with patch( + "random.sample", + return_value=[MinerInfo(uid=1, address="", hotkey="1"), MinerInfo(uid=5, address="", hotkey="5")], + ): miners1 = await miner_sampler._sample_miners() assert len(miners1) == 2 assert {m.uid for m in miners1} == {all_miners[0].uid, all_miners[2].uid} - assert len(miner_sampler._remaining_epoch_miners) == 1 # 2nd call, new epoch starts as remaining (1) < sample_size (2). - with patch("random.sample", return_value=[1, 2]): + with patch( + "random.sample", + return_value=[MinerInfo(uid=3, address="", hotkey="3"), MinerInfo(uid=5, address="", hotkey="5")], + ): miners2 = await miner_sampler._sample_miners() assert len(miners2) == 2 assert {m.uid for m in miners2} == {all_miners[1].uid, all_miners[2].uid} - assert len(miner_sampler._remaining_epoch_miners) == 1 @pytest.mark.asyncio @@ -209,12 +214,15 @@ async def test_query_miners() -> None: patch("time.time", return_value=12345), ): mock_generate_header.return_value = {"some": "header"} - result = await sampler.query_miners(body, endpoint) + timeout = 20 + result = await sampler.query_miners(body, endpoint, timeout=timeout) mock_client_session.assert_called_once() expected_body = {"test": "data"} + + client_timeout = aiohttp.ClientTimeout(total=timeout) mock_session.post.assert_called_with( - endpoint + "/v1/chat/completions", headers={"some": "header"}, json=expected_body + endpoint + "/v1/chat/completions", headers={"some": "header"}, json=expected_body, timeout=client_timeout ) mock_generate_header.assert_called_with( mock_chain.wallet.hotkey, body=json.dumps(body).encode("utf-8"), signed_for=None diff --git a/tests/validator/test_weight_syncer.py b/tests/validator/test_weight_syncer.py new file mode 100644 index 000000000..26355565b --- /dev/null +++ b/tests/validator/test_weight_syncer.py @@ -0,0 +1,140 @@ +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from apex.validator.weight_syncer import WeightSyncer +from tests.common.mock_async_chain import DummyMetagraph + + +class UidList(list): + """A list that has a .tolist() method for compatibility with torch tensors.""" + + def tolist(self): + return self + + +@pytest.fixture +def mock_axon(): + """Returns a function to create a mock axon.""" + + def _mock_axon(ip, port, is_serving=True): + axon = MagicMock() + axon.ip = ip + axon.port = port + axon.is_serving = is_serving + return axon + + return _mock_axon + + +@pytest.fixture +def mock_metagraph(mock_axon): + """Returns a mock metagraph based on DummyMetagraph.""" + metagraph = DummyMetagraph( + hotkeys=["hotkey0_self", "hotkey1_validator", "hotkey2_validator"], + ) + # Overwrite uids with our special UidList to provide .tolist() + metagraph.uids = UidList([0, 1, 2]) + metagraph.stake = [np.float32(1000.0), np.float32(2000.0), np.float32(500.0)] + metagraph.validator_permit = [True, True, False] + metagraph.axons = [ + mock_axon("1.1.1.1", 8000), + mock_axon("2.2.2.2", 8001), + mock_axon("3.3.3.3", 8002), + ] + return metagraph + + +@pytest.fixture +def mock_chain(mock_metagraph): + """Returns a mock chain with a mock metagraph.""" + chain = MagicMock() + chain.wallet.hotkey.ss58_address = "hotkey0_self" + chain.metagraph = AsyncMock(return_value=mock_metagraph) + return chain + + +@pytest.fixture +def weight_syncer(mock_chain): + """Returns a WeightSyncer instance with a mock chain.""" + return 
WeightSyncer(chain=mock_chain, min_alpha_stake=1000) + + +@pytest.mark.asyncio +async def test_compute_weighted_rewards_happy_path(weight_syncer, mock_metagraph): + """Test that weighted rewards are computed correctly in the ideal case.""" + local_rewards = {"miner1": 0.9, "miner2": 0.1} + validator1_rewards = {"miner1": 0.85, "miner2": 0.82, "miner3": 0.7} + + with patch.object(weight_syncer, "receive_rewards", new_callable=AsyncMock) as mock_receive: + mock_receive.side_effect = [validator1_rewards, {}] # UID 2 has low stake + + weighted_rewards = await weight_syncer.compute_weighted_rewards(local_rewards) + + # self (1000) + validator1 (2000) = 3000 total stake + # miner1: (0.9 * 1000 + 0.85 * 2000) / 3000 = 0.8666 + # miner2: (0.1 * 1000 + 0.82 * 2000) / 3000 = 0.58 + assert mock_receive.call_count == 1 + assert mock_receive.call_args.args[1] == 1 # Called for UID 1 + assert pytest.approx(weighted_rewards["miner1"], 0.001) == 0.8666 + assert pytest.approx(weighted_rewards["miner2"], 0.001) == 0.58 + assert "miner3" not in weighted_rewards + + +@pytest.mark.asyncio +async def test_compute_weighted_rewards_self_not_in_metagraph(weight_syncer, mock_metagraph): + """Test that local rewards are returned if the validator's hotkey is not in the metagraph.""" + mock_metagraph.hotkeys = ["other1", "other2", "other3"] + local_rewards = {"miner1": 0.9} + weighted_rewards = await weight_syncer.compute_weighted_rewards(local_rewards) + assert weighted_rewards == local_rewards + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient") +async def test_receive_rewards_success(mock_async_client, weight_syncer, mock_metagraph): + """Test successfully receiving rewards from another validator.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"miner1": 0.9} + mock_async_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response) + + rewards = await weight_syncer.receive_rewards(mock_metagraph, 1) + assert rewards == {"miner1": 0.9} + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient") +async def test_receive_rewards_http_error(mock_async_client, weight_syncer, mock_metagraph): + """Test that an empty dict is returned on HTTP error.""" + mock_async_client.return_value.__aenter__.return_value.post.side_effect = Exception("HTTP Error") + rewards = await weight_syncer.receive_rewards(mock_metagraph, 1) + assert rewards == {} + + +@patch("apex.validator.weight_syncer.verify_validator_signature", new_callable=AsyncMock) +def test_get_rewards_endpoint(mock_verify_signature, weight_syncer): + """Test the FastAPI endpoint for serving rewards.""" + app = FastAPI() + app.include_router(weight_syncer.get_router()) + client = TestClient(app) + + # Case 1: No rewards set yet + response = client.post("/v1/get_rewards") + assert response.status_code == 503 + + # Case 2: Rewards are set and not expired + weight_syncer.hotkey_rewards = {"miner1": 0.95} + weight_syncer.last_update_time = time.time() + response = client.post("/v1/get_rewards") + assert response.status_code == 200 + assert response.json() == {"miner1": 0.95} + + # Case 3: Rewards are expired + weight_syncer.last_update_time = time.time() - WeightSyncer.REWARD_EXPIRATION_SEC - 1 + response = client.post("/v1/get_rewards") + assert response.status_code == 503 diff --git a/uv.lock b/uv.lock index d2db4f48f..70e3f580f 100644 --- a/uv.lock +++ b/uv.lock @@ -141,7 +141,7 @@ wheels = [ [[package]] name = "apex" -version = "3.0.0" +version = "3.0.1" source = { virtual = "." 
} dependencies = [ { name = "aiohttp" }, @@ -153,6 +153,7 @@ dependencies = [ { name = "faiss-cpu" }, { name = "langchain" }, { name = "langchain-community" }, + { name = "langchain-core" }, { name = "langchain-openai" }, { name = "langchain-sandbox" }, { name = "loguru" }, @@ -162,29 +163,26 @@ dependencies = [ { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, { name = "pip" }, { name = "pydantic" }, + { name = "pytest-mock" }, { name = "pyyaml" }, { name = "requests" }, { name = "rich" }, { name = "rouge" }, { name = "substrate-interface" }, { name = "tavily-python" }, + { name = "types-cachetools" }, + { name = "types-netaddr" }, + { name = "types-pyyaml" }, ] [package.optional-dependencies] dev = [ - { name = "dotenv" }, - { name = "langchain" }, - { name = "langchain-core" }, - { name = "langchain-openai" }, - { name = "langchain-sandbox" }, { name = "mypy" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "ruff" }, - { name = "types-cachetools" }, - { name = "types-pyyaml" }, ] [package.dev-dependencies] @@ -195,6 +193,7 @@ dev = [ { name = "pydantic" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-mock" }, { name = "types-pyyaml" }, ] @@ -206,16 +205,12 @@ requires-dist = [ { name = "bittensor", specifier = ">=9.7.0" }, { name = "cachetools", specifier = ">=5.0.0" }, { name = "dotenv", specifier = ">=0.9.9" }, - { name = "dotenv", marker = "extra == 'dev'", specifier = ">=0.9.9" }, { name = "faiss-cpu", specifier = ">=1.8.0" }, { name = "langchain", specifier = ">=0.3.26" }, - { name = "langchain", marker = "extra == 'dev'", specifier = ">=0.3.26" }, { name = "langchain-community", specifier = ">=0.0.59" }, - { name = "langchain-core", marker = "extra == 'dev'", specifier = ">=0.3.68" }, + { name = "langchain-core", specifier = ">=0.3.68" }, { name = "langchain-openai", specifier = ">=0.3.28" }, - { name = "langchain-openai", marker = "extra == 'dev'", specifier = ">=0.3.28" }, { name = "langchain-sandbox", specifier = ">=0.0.6" }, - { name = "langchain-sandbox", marker = "extra == 'dev'", specifier = ">=0.0.6" }, { name = "loguru", specifier = ">=0.7.3" }, { name = "macrocosmos", specifier = ">=0.1.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = "==1.17.0" }, @@ -227,6 +222,7 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.4.1" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.0.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=5.0.0" }, + { name = "pytest-mock", specifier = ">=3.14.1" }, { name = "pyyaml", specifier = ">=6.0.0" }, { name = "requests", specifier = ">=2.31.0" }, { name = "rich", specifier = ">=14.0.0" }, @@ -234,8 +230,9 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = "==0.12.5" }, { name = "substrate-interface", specifier = ">=1.7.11" }, { name = "tavily-python", specifier = ">=0.7.10" }, - { name = "types-cachetools", marker = "extra == 'dev'", specifier = ">=6.0.0.20250525" }, - { name = "types-pyyaml", marker = "extra == 'dev'", specifier = ">=6.0.12.20250516" }, + { name = "types-cachetools", specifier = ">=6.0.0.20250525" }, + { name = "types-netaddr", specifier = ">=1.3.0.20240530" }, + { name = "types-pyyaml", specifier = ">=6.0.12.20250516" }, ] provides-extras = ["dev"] @@ -247,6 +244,7 @@ dev = [ { name = "pydantic", specifier = ">=2.11.7" }, { name = "pytest", 
specifier = ">=8.4.1" }, { name = "pytest-asyncio", specifier = ">=1.0.0" }, + { name = "pytest-mock", specifier = ">=3.14.1" }, { name = "types-pyyaml", specifier = ">=6.0.12.20250516" }, ] @@ -2412,6 +2410,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644 }, ] +[[package]] +name = "pytest-mock" +version = "3.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/28/67172c96ba684058a4d24ffe144d64783d2a270d0af0d9e792737bddc75c/pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e", size = 33241 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/05/77b60e520511c53d1c1ca75f1930c7dd8e971d0c4379b7f4b3f9644685ba/pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0", size = 9923 }, +] + [[package]] name = "python-dotenv" version = "1.1.1" @@ -2865,6 +2875,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/47/8c/4ab0a17ece30fe608270b89cf066387051862899fff9f54ab12511fc7fdd/types_cachetools-6.0.0.20250525-py3-none-any.whl", hash = "sha256:1de8f0fe4bdcb187a48d2026c1e3672830f67943ad2bf3486abe031b632f1252", size = 8938 }, ] +[[package]] +name = "types-netaddr" +version = "1.3.0.20240530" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/10/b4bad18b6918eec74db1087d1376cc094fdef5fa1261aa905b6cd94b9408/types-netaddr-1.3.0.20240530.tar.gz", hash = "sha256:742c2ec1f202b666f544223e2616b34f1f13df80c91e5aeaaa93a72e4d0774ea", size = 10459 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/de/1eaa9f59ea51eb8b64c04f2be5af8c91940072f1c5b0cfa5870a7665a028/types_netaddr-1.3.0.20240530-py3-none-any.whl", hash = "sha256:354998d018e326da4f1d9b005fc91137b7c2c473aaf03c4ef64bf83c6861b440", size = 14335 }, +] + [[package]] name = "types-pyyaml" version = "6.0.12.20250516" diff --git a/validator.py b/validator.py index 754afa9a0..8076fec97 100644 --- a/validator.py +++ b/validator.py @@ -4,6 +4,7 @@ from loguru import logger +from apex import __version__ from apex.common.async_chain import AsyncChain from apex.common.config import Config from apex.services.deep_research.deep_research_langchain import DeepResearchLangchain @@ -13,6 +14,7 @@ from apex.validator.miner_sampler import MinerSampler from apex.validator.miner_scorer import MinerScorer from apex.validator.pipeline import Pipeline +from apex.validator.weight_syncer import WeightSyncer async def read_args() -> argparse.Namespace: @@ -32,28 +34,45 @@ async def read_args() -> argparse.Namespace: async def main() -> None: args = await read_args() config = Config.from_file(path=args.config) + logger.debug(f"Starting validator v{__version__} with config: {args.config}") chain = AsyncChain(**config.chain.kwargs) await chain.start() + logger.debug( + f"Connected to the chain netuid={chain.netuid} with coldkey '{chain.coldkey[:2]}***', " + f"hotkey '{chain.hotkey[:2]}***'" + ) logger_db = LoggerDB(**config.logger_db.kwargs) asyncio.create_task(logger_db.start_loop()) + logger.debug(f"Started DB at: '{logger_db.db_path}'") # logger_apex = LoggerApex(async_chain=chain) websearch = WebSearchTavily(**config.websearch.kwargs) + 
logger.debug("Started web search tool") miner_sampler = MinerSampler(chain=chain, logger_db=logger_db, **config.miner_sampler.kwargs) + logger.debug("Started miner sampler") + + weight_syncer = WeightSyncer(chain=chain, **config.weight_syncer.kwargs) + await weight_syncer.start() + logger.debug( + f"Started weight synchronizer, receive enabled: {weight_syncer.receive_enabled}, " + f"send enabled: {weight_syncer.send_enabled}, port: {weight_syncer.port}" + ) - miner_scorer = MinerScorer(chain=chain, **config.miner_scorer.kwargs) + miner_scorer = MinerScorer(chain=chain, weight_syncer=weight_syncer, **config.miner_scorer.kwargs) asyncio.create_task(miner_scorer.start_loop()) + logger.debug(f"Started miner scorer with interval={miner_scorer.interval}") llm = LLM(**config.llm.kwargs) + logger.debug("Started LLM provider") deep_research = DeepResearchLangchain(websearch=websearch, **config.deep_research.kwargs) + logger.debug("Started Deep Researcher") pipeline = Pipeline( - config=config, websearch=websearch, miner_sampler=miner_sampler, llm=llm, @@ -62,6 +81,7 @@ async def main() -> None: **config.pipeline.kwargs, ) try: + logger.debug("Starting pipeline loop...") await pipeline.start_loop() except KeyboardInterrupt: logger.warning("Keyboard interrupt caught, exiting validator") @@ -71,6 +91,7 @@ async def main() -> None: await chain.shutdown() await logger_db.shutdown() await miner_scorer.shutdown() + await weight_syncer.shutdown() if __name__ == "__main__":