diff --git a/apex/common/async_chain.py b/apex/common/async_chain.py index a87d96a88..d4e8a2a61 100644 --- a/apex/common/async_chain.py +++ b/apex/common/async_chain.py @@ -16,7 +16,6 @@ def __init__(self, coldkey: str, hotkey: str, netuid: int, network: list[str] | if isinstance(network, str): network = [network] self._network: list[str] = network - self._coldkey = coldkey self._hotkey = hotkey self._netuid = netuid diff --git a/apex/common/config.py b/apex/common/config.py index fc3fff308..9cfcad02b 100644 --- a/apex/common/config.py +++ b/apex/common/config.py @@ -15,6 +15,7 @@ class Config(BaseModel): chain: ConfigClass = Field(default_factory=ConfigClass) websearch: ConfigClass = Field(default_factory=ConfigClass) logger_db: ConfigClass = Field(default_factory=ConfigClass) + weight_syncer: ConfigClass = Field(default_factory=ConfigClass) miner_sampler: ConfigClass = Field(default_factory=ConfigClass) miner_scorer: ConfigClass = Field(default_factory=ConfigClass) llm: ConfigClass = Field(default_factory=ConfigClass) diff --git a/apex/common/constants.py b/apex/common/constants.py index 6315e4485..9f7698ab4 100644 --- a/apex/common/constants.py +++ b/apex/common/constants.py @@ -2,6 +2,14 @@ TEMPERATURE: float = 0.1 WEBPAGE_MAXSIZE: int = 500 VALIDATOR_REFERENCE_LABEL = "Validator" +VALIDATOR_VERIFIED_HOTKEYS = { + "5CGLCBndTR1BvQZzn429ckT8GyxduzyjMgt4K1UVTYa8gKfb": "167.99.236.79:8001", # Macrocosmos. + "5CUbyC2Ez7tWYYmnFSSwjqkw26dFNo9cXH8YmcxBSfxi2XSG": None, # Yuma. + "5C8Em1kDZi5rxgDN4zZtfoT7dUqJ7FFbTzS3yTP5GPgVUsn1": None, # RoundTable21. + "5HmkM6X1D3W3CuCSPuHhrbYyZNBy2aGAiZy9NczoJmtY25H7": None, # Crucible. + "5GeR3cDuuFKJ7p66wKGjY65MWjWnYqffq571ZMV4gKMnJqK5": None, # OTF. + "5D1saVvssckE1XoPwPzdHrqYZtvBJ3vESsrPNxZ4zAxbKGs1": None, # Rizzo. +} _ENGLISH_WORDS: tuple[str, ...] | None = None diff --git a/apex/common/epistula.py b/apex/common/epistula.py index 67b16d32b..8d5aface7 100644 --- a/apex/common/epistula.py +++ b/apex/common/epistula.py @@ -1,14 +1,17 @@ -# import json +import json import time from hashlib import sha256 from math import ceil -from typing import Any +from typing import Annotated, Any from uuid import uuid4 -# from fastapi import HTTPException, Request -# from loguru import logger +from fastapi import HTTPException, Request +from loguru import logger from substrateinterface import Keypair +from apex.common.async_chain import AsyncChain +from apex.common.constants import VALIDATOR_VERIFIED_HOTKEYS + async def generate_header( hotkey: Keypair, @@ -38,66 +41,80 @@ async def generate_header( return headers -# def verify_signature( -# signature: str, body: bytes, timestamp: int, uuid: str, signed_for: str, signed_by: str, now: float -# ) -> Annotated[str, "Error Message"] | None: -# if not isinstance(signature, str): -# return "Invalid Signature" -# timestamp = int(timestamp) -# if not isinstance(timestamp, int): -# return "Invalid Timestamp" -# if not isinstance(signed_by, str): -# return "Invalid Sender key" -# if not isinstance(signed_for, str): -# return "Invalid receiver key" -# if not isinstance(uuid, str): -# return "Invalid uuid" -# if not isinstance(body, bytes): -# return "Body is not of type bytes" -# allowed_delta_ms = 8000 -# keypair = Keypair(ss58_address=signed_by) -# if timestamp + allowed_delta_ms < now: -# return "Request is too stale" -# message = f"{sha256(body).hexdigest()}.{uuid}.{timestamp}.{signed_for}" -# verified = keypair.verify(message, signature) -# if not verified: -# return "Signature Mismatch" -# return None -# -# -# async def verify_weight_signature(request: Request): -# signed_by = request.headers.get("Epistula-Signed-By") -# signed_for = request.headers.get("Epistula-Signed-For") -# if not signed_by or not signed_for: -# raise HTTPException(400, "Missing Epistula-Signed-* headers") -# -# if signed_for != shared_settings.WALLET.hotkey.ss58_address: -# logger.error("Bad Request, message is not intended for self") -# raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self") -# validator_hotkeys = [shared_settings.METAGRAPH.hotkeys[uid] for uid in WHITELISTED_VALIDATORS_UIDS] -# if signed_by not in validator_hotkeys: -# logger.error(f"Signer not the expected ss58 address: {signed_by}") -# raise HTTPException(status_code=401, detail="Signer not the expected ss58 address") -# -# now = time.time() -# body: bytes = await request.body() -# try: -# payload = json.loads(body) -# except json.JSONDecodeError: -# raise HTTPException(400, "Invalid JSON body") -# -# if payload.get("uid") != get_uid_from_hotkey(signed_by): -# raise HTTPException(400, "Invalid uid in body") -# -# err = verify_signature( -# request.headers.get("Epistula-Request-Signature"), -# body, -# request.headers.get("Epistula-Timestamp"), -# request.headers.get("Epistula-Uuid"), -# signed_for, -# signed_by, -# now, -# ) -# if err: -# logger.error(err) -# raise HTTPException(status_code=400, detail=err) +def verify_signature( + signature: str | None, + body: bytes, + timestamp: str | None, + uuid: str | None, + signed_for: str, + signed_by: str, + now: float, +) -> Annotated[str, "Error Message"] | None: + if not isinstance(signature, str): + return "Invalid Signature" + if not isinstance(timestamp, str) or not timestamp.isdigit(): + return "Invalid Timestamp" + timestamp_as_int = int(timestamp) + if not isinstance(signed_by, str): + return "Invalid Sender key" + if not isinstance(signed_for, str): + return "Invalid receiver key" + if not isinstance(uuid, str): + return "Invalid uuid" + if not isinstance(body, bytes): + return "Body is not of type bytes" + allowed_delta_ms = 8000 + keypair = Keypair(ss58_address=signed_by) + if timestamp_as_int + allowed_delta_ms < now: + return "Request is too stale" + message = f"{sha256(body).hexdigest()}.{uuid}.{timestamp}.{signed_for}" + verified = keypair.verify(message, signature) + if not verified: + return "Signature Mismatch" + return None + + +async def verify_validator_signature(request: Request, chain: AsyncChain, min_stake: float = 1024) -> None: + signed_by = request.headers.get("Epistula-Signed-By") + signed_for = request.headers.get("Epistula-Signed-For") + if not signed_by or not signed_for: + logger.error("Missing Epistula-Signed-* headers") + raise HTTPException(400, "Missing Epistula-Signed-* headers") + + wallet = chain.wallet + if signed_for != wallet.hotkey.ss58_address: + logger.error("Bad Request, message is not intended for self") + raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self") + + is_validator = True + if min_stake > 0: + metagraph = await chain.metagraph() + try: + caller_uid = metagraph.hotkeys.index(signed_by) + except ValueError as exc: + raise HTTPException(status_code=401, detail="Signer is not in metagraph") from exc + is_validator = metagraph.stake[caller_uid] > min_stake + + if signed_by not in VALIDATOR_VERIFIED_HOTKEYS and not is_validator: + logger.error(f"Signer not the expected ss58 address: {signed_by}") + raise HTTPException(status_code=401, detail="Signer not the expected ss58 address") + + now = time.time() + body: bytes = await request.body() + try: + json.loads(body) + except json.JSONDecodeError as exc: + raise HTTPException(400, "Invalid JSON body") from exc + + err = verify_signature( + request.headers.get("Epistula-Request-Signature"), + body, + request.headers.get("Epistula-Timestamp"), + request.headers.get("Epistula-Uuid"), + signed_for, + signed_by, + now, + ) + if err: + logger.error(err) + raise HTTPException(status_code=400, detail=err) diff --git a/apex/validator/miner_scorer.py b/apex/validator/miner_scorer.py index 6539e6689..b24490cbb 100644 --- a/apex/validator/miner_scorer.py +++ b/apex/validator/miner_scorer.py @@ -12,6 +12,7 @@ from apex.common.async_chain import AsyncChain from apex.common.constants import VALIDATOR_REFERENCE_LABEL +from apex.validator.weight_syncer import WeightSyncer # Scoring moving average in hours. Set to be: immunity_period - post_reg_threshold. SCORE_MA_WINDOW_HOURS = 23.75 @@ -19,12 +20,19 @@ class MinerScorer: - def __init__(self, chain: AsyncChain, interval: float = SCORE_INTERVAL_DEFAULT, debug: bool = False): + def __init__( + self, + chain: AsyncChain, + weight_syncer: WeightSyncer | None = None, + interval: float = SCORE_INTERVAL_DEFAULT, + debug: bool = False, + ): self.chain = chain self.interval = interval - self._running = True self._debug = debug + self._weight_syncer = weight_syncer self._debug_rewards_path = Path("debug_rewards.jsonl") + self._running = True async def start_loop(self) -> None: self._running = True @@ -108,12 +116,20 @@ async def set_scores(self) -> bool: with self._debug_rewards_path.open("a+") as fh: record_str: str = json.dumps(record) fh.write(f"{record_str}\n") - # TODO: Flush the db only on set_weights_result is True. + + if self._weight_syncer is not None: + try: + hkey_agg_rewards = await self._weight_syncer.compute_weighted_rewards(hkey_agg_rewards) + except BaseException as exc: + logger.error(f"Failed to compute weighted average rewards over the network, skipping: {exc}") + if hkey_agg_rewards: rewards_array = np.array(list(hkey_agg_rewards.values())) logger.debug(f"Setting weights, reward mean={rewards_array.mean():.4f} min={rewards_array.min():.4f}") else: logger.warning(f"Setting empty rewards: {hkey_agg_rewards}") + + # TODO: Flush the db only on set_weights_result is True. set_weights_result = await self.chain.set_weights(hkey_agg_rewards) # 4. Flush all deletions in a single commit. diff --git a/apex/validator/weight_syncer.py b/apex/validator/weight_syncer.py index a9a5576c7..05b85599c 100644 --- a/apex/validator/weight_syncer.py +++ b/apex/validator/weight_syncer.py @@ -1,130 +1,217 @@ -# import asyncio -# import json -# -# import aiohttp -# import httpx -# import numpy as np -# from loguru import logger -# -# from apex.common.async_chain import AsyncChain -# from apex.common.epistula import generate_header, verify_weight_signature - - -# class WeightSyncer: -# def __init__(self, chain: AsyncChain, weight_dict: dict[int, list[float]]): -# self.chain = chain -# self.wallet = self.chain.wallet -# self.current_hotkey = self.wallet.hotkey.ss58_address -# self.latest_weights: str = {} -# metagraph = await self.chain.metagraph -# self.uid = metagraph.hotkeys.index(self.current_hotkey) -# self.validator_uids = np.where(np.array(metagraph.validator_permit))[0].tolist() -# -# self.weight_matrix = np.zeros((len(self.validator_uids), metagraph.n.item())) -# self.stake_matrix = np.array([metagraph.S[uid] for uid in self.validator_uids]) -# -# self.validator_hotkeys = np.array([metagraph.hotkeys[uid] for uid in self.validator_uids]) -# self.validator_addresses = np.array( -# [ -# f"{metagraph.axons[uid].ip}:{metagraph.axons[uid].port}" -# for uid in self.validator_uids -# if uid < metagraph.n.item() -# ] -# ) -# -# self.weight_dict = weight_dict -# -# self.request_tracker = np.zeros(len(self.validator_uids)) -# -# @router.post("/receive_weight_matrix") -# async def receive_weight_matrix( -# request: Request, -# verification_data: dict = Depends(verify_weight_signature), -# weight_dict=Depends(get_weight_dict), -# ): -# """Endpoint to receive weight matrix updates from validators.""" -# await verify_weight_signature(request) -# -# body = await request.json() -# if not isinstance(body, dict) or "weights" not in body: -# raise HTTPException(status_code=400, detail="Invalid request body format") -# -# try: -# uid = body["uid"] -# weights = list(body["weights"]) -# weight_dict[uid] = weights -# return {"status": "success", "message": "Weight matrix updated successfully"} -# except Exception as e: -# logger.error(f"Error processing weight matrix: {e}") -# raise HTTPException(status_code=500, detail="Error processing weight matrix") -# -# async def send_rewards(self, rewards: dict[str, float], validator_address: str, validator_hotkey: str): -# try: -# async with aiohttp.ClientSession() as session: -# headers = await generate_header( -# self.chain.wallet.hotkey, body=json.dumps(body).encode("utf-8"), signed_for=validator_hotkey -# ) -# async with session.post( -# endpoint + "/v1/chat/completions", -# headers=headers, -# json=body, -# ) as resp: -# result = await resp.text() -# except BaseException: -# # Error during miner query, return empty string. -# return "" -# return str(result) -# -# try: -# vali_url = f"http://{validator_address}/receive_weight_matrix" -# timeout = httpx.Timeout(timeout=40.0) -# async with httpx.AsyncClient( -# timeout=timeout, -# event_hooks={"request": [create_header_hook(self.wallet.hotkey, validator_hotkey)]}, -# ) as client: -# response = await client.post( -# url=vali_url, -# json={"weights": weight_matrix.tolist(), "uid": self.uid}, -# headers={"Content-Type": "application/json"}, -# ) -# if response.status_code != 200: -# raise Exception( -# f"Status code {response.status_code} response for validator {validator_hotkey} - {vali_url}: " -# f"{response.status_code} for uids {len(weight_matrix)}" -# ) -# logger.debug(f"Successfully forwarded response to uid {validator_hotkey} - {vali_url}") -# except httpx.ConnectError as e: -# logger.warning( -# f"Couldn't connect to validator {validator_hotkey} {vali_url} for weight setting. Exception: {e}" -# ) -# except Exception as e: -# logger.warning( -# f"Error while forwarding weight matrix to validator {validator_hotkey} {vali_url}. Exception: {e}" -# ) -# -# async def get_augmented_weights(self, weights: np.ndarray, uid: int) -> np.ndarray: -# """Get the augmented weights for the given uid, sends the weights to the validators.""" -# await self.send_weight_matrixes(weights) -# -# await self.process_weight_dict() -# -# return np.average(self.weight_matrix, axis=0, weights=self.stake_matrix * self.request_tracker) -# -# async def send_weight_matrixes(self, weight_matrix: np.ndarray): -# tasks = [ -# self.send_weights(weight_matrix, validator_address, validator_hotkey) -# for validator_address, validator_hotkey in zip( -# self.validator_addresses, self.validator_hotkeys, strict=False -# ) -# ] -# -# await asyncio.gather(*tasks) -# -# async def process_weight_dict(self): -# for uid, weights in self.weight_dict.items(): -# if uid in self.validator_uids: -# validator_index = self.validator_uids.index(uid) -# self.weight_matrix[validator_index] = weights -# self.request_tracker[validator_index] = 1 -# else: -# logger.warning(f"UID {uid} is not a validator, skipping") +import asyncio +import os +import time +from typing import cast + +import httpx +import netaddr +import requests +import uvicorn +from bittensor.core.async_subtensor import AsyncMetagraph +from bittensor.core.extrinsics.asyncex.serving import serve_extrinsic +from fastapi import APIRouter, FastAPI, HTTPException, Request +from loguru import logger +from pydantic import BaseModel + +from apex.common.async_chain import AsyncChain +from apex.common.constants import VALIDATOR_VERIFIED_HOTKEYS +from apex.common.epistula import generate_header, verify_validator_signature + + +class ValidatorInfo(BaseModel): + uid: int + hotkey: str + address: str + stake: float + + +class WeightSyncer: + REWARD_EXPIRATION_SEC: float = 60 * 60 + + def __init__( + self, + chain: AsyncChain, + min_alpha_stake: float = 100_000, + verified_hotkeys: dict[str, str | None] | None = None, + enable_receive: bool = True, + enable_send: bool = True, + port: int = 8001, + ) -> None: + """Validator weight synchronizer.""" + self.chain = chain + self.min_alpha_stake = min_alpha_stake + self.verified_hotkeys = verified_hotkeys or VALIDATOR_VERIFIED_HOTKEYS + self.wallet = self.chain.wallet + self.current_hotkey = self.wallet.hotkey.ss58_address + self.receive_enabled = enable_receive + self.send_enabled = enable_send + self.port = port + self.server: uvicorn.Server | None = None + self.server_task: asyncio.Task[None] | None = None + self.hotkey_rewards: dict[str, float] | None = None + self.last_update_time: float = 0 + + async def start(self) -> None: + if not self.send_enabled: + logger.warning("Weight synchronization API is disabled for incoming reward requests") + return + + try: + app = FastAPI() + app.include_router(self.get_router()) + + config = uvicorn.Config( + app=app, + host="0.0.0.0", + port=self.port, + log_level="info", + workers=1, + reload=False, + loop="asyncio", + ) + self.server = uvicorn.Server(config) + self.server_task = asyncio.create_task(self.server.serve()) + logger.info(f"Started weight synchronization API on port {self.port}. pid={os.getpid()} self_id={id(self)}") + + # Announce the axon on the network. + external_ip = requests.get("https://checkip.amazonaws.com").text.strip() + netaddr.IPAddress(external_ip) + sub = await self.chain.subtensor() + serve_success = await serve_extrinsic( + subtensor=sub, + wallet=self.chain.wallet, + ip=external_ip, + port=self.port, + protocol=4, + netuid=self.chain.netuid, + ) + if serve_success: + logger.success(f"Serving weight syncer axon on subtensor at {external_ip}:{self.port}") + else: + logger.error("Failed to serve weight syncer axon on subtensor") + except BaseException as e: + logger.warning(f"Failed to announce weight syncer axon on subtensor: {e}") + + async def shutdown(self) -> None: + if self.server is not None: + self.server.should_exit = True + if self.server_task is not None: + await self.server_task + + def get_router(self) -> APIRouter: + """Creates and returns a FastAPI router with the endpoints for this class.""" + router = APIRouter() + + @router.post("/v1/get_rewards") + async def get_rewards_endpoint(request: Request) -> dict[str, float]: + await verify_validator_signature(request=request, chain=self.chain, min_stake=self.min_alpha_stake) + + outdated = time.time() - self.last_update_time + if (outdated := time.time() - self.last_update_time) > self.REWARD_EXPIRATION_SEC: + logger.warning(f"Rewards expired: {outdated:.2f}s - {self.last_update_time}") + raise HTTPException(status_code=503, detail="Rewards expired") + if self.hotkey_rewards is None: + logger.warning("Rewards not available") + raise HTTPException(status_code=503, detail="Rewards not available") + if not self.send_enabled: + logger.warning("API is disabled") + raise HTTPException(status_code=405, detail="API is disabled") + return self.hotkey_rewards + + return router + + async def compute_weighted_rewards(self, hotkey_rewards: dict[str, float]) -> dict[str, float]: + """Computes weighted rewards by fetching rewards from other validators and averaging them by stake.""" + self.hotkey_rewards = hotkey_rewards + self.last_update_time = time.time() + # logger.debug(f"Updating rewards at: {self.last_update_time}") + import os + + logger.debug(f"Updating rewards at: {self.last_update_time} pid={os.getpid()} self_id={id(self)}") + if not self.receive_enabled: + logger.warning("Rewards weight averaging is disable, using raw rewards") + return hotkey_rewards + + metagraph = await self.chain.metagraph() + + try: + own_uid = metagraph.hotkeys.index(self.current_hotkey) + except ValueError: + logger.error(f"Could not find own hotkey {self.current_hotkey} in metagraph, returning raw rewards") + return hotkey_rewards + + validator_rewards_tasks: dict[int, asyncio.Task[dict[str, float]]] = {} + for uid in metagraph.uids: + if uid == own_uid: + continue + + stake = metagraph.stake[uid] + hotkey = metagraph.hotkeys[uid] + is_verified = hotkey in self.verified_hotkeys + is_validator = metagraph.validator_permit[uid] + + if (stake >= self.min_alpha_stake and is_validator) or is_verified: + validator_rewards_tasks[uid] = asyncio.create_task(self.receive_rewards(metagraph, uid)) + + results = await asyncio.gather(*validator_rewards_tasks.values(), return_exceptions=True) + + validator_rewards: dict[int, dict[str, float]] = {} + for uid, result in zip(validator_rewards_tasks, results, strict=True): + if isinstance(result, BaseException) or not result: + logger.warning(f"Cannot receive rewards from uid {uid}: {result}") + continue + validator_rewards[uid] = result + logger.debug(f"Received rewards from validator {uid} with stake {metagraph.stake[uid]}") + + all_validator_uids = [own_uid] + list(validator_rewards.keys()) + total_stake = sum(metagraph.stake[uid] for uid in all_validator_uids) + + if total_stake == 0: + logger.warning("Total stake of responding validators is zero, returning original rewards") + return hotkey_rewards + + own_stake = metagraph.stake[own_uid] + + weighted_rewards: dict[str, float] = {} + for miner_hkey in hotkey_rewards: + own_reward = hotkey_rewards.get(miner_hkey, 0.0) + total_weighted_reward = own_reward * own_stake + + for uid, rewards in validator_rewards.items(): + validator_reward = rewards.get(miner_hkey, 0.0) + total_weighted_reward += validator_reward * metagraph.stake[uid] + + weighted_rewards[miner_hkey] = total_weighted_reward / total_stake + + logger.debug( + f"Averaged rewards over {len(all_validator_uids)} validators. " + f"Self stake: {100 * own_stake / total_stake:.2f}%" + ) + return weighted_rewards + + async def receive_rewards(self, metagraph: AsyncMetagraph, uid: int) -> dict[str, float]: + """Receive rewards from the given validator uid.""" + try: + target_hotkey = metagraph.hotkeys[uid] + if (address := VALIDATOR_VERIFIED_HOTKEYS.get(target_hotkey)) is None: + axon = metagraph.axons[uid] + address = f"{axon.ip}:{axon.port}" + + async with httpx.AsyncClient() as client: + body = b"{}" + headers = await generate_header( + hotkey=self.chain.wallet.hotkey, + body=body, + signed_for=target_hotkey, + ) + resp = await client.post( + f"http://{address}/v1/get_rewards", + headers=headers, + content=body, + ) + resp.raise_for_status() + return cast(dict[str, float], resp.json()) + + except BaseException as exc: + logger.warning(f"Cannot receive rewards from uid {uid}: {exc}") + return {} diff --git a/config/mainnet.yaml.example b/config/mainnet.yaml.example index 3d8bffaf6..bff750080 100644 --- a/config/mainnet.yaml.example +++ b/config/mainnet.yaml.example @@ -1,8 +1,8 @@ chain: kwargs: netuid: 1 - coldkey: "validator" - hotkey: "default" + coldkey: "YOUR_COLDKEY" + hotkey: "YOUR_HOTKEY" network: - finney # - ws://LOCAL_SUBTENSOR_FALLBACK_1 @@ -27,3 +27,12 @@ deep_research: research_model: "Qwen/Qwen3-235B-A22B-Instruct-2507" compression_model: "deepseek-ai/DeepSeek-V3-0324" final_model: "deepseek-ai/DeepSeek-V3-0324" + +weight_syncer: + kwargs: + # Change the port if necessary. + port: 8001 + # When enabled, performs weight synchronization across validators, drastically improves vTrust. + enable_receive: True + # When enabled, allows other validators to request your rewards, slightly improves vTrust. + enable_send: True diff --git a/config/testnet.yaml.example b/config/testnet.yaml.example index 9b24697c8..adea06149 100644 --- a/config/testnet.yaml.example +++ b/config/testnet.yaml.example @@ -34,3 +34,12 @@ miner_sampler: # For testing purposes one can specify available pool of uids. # available_uids: [1, 2] # available_addresses: ["http://0.0.0.0:8081", "http://0.0.0.0:8082"] + +weight_syncer: + kwargs: + # Change the port if necessary. + port: 8001 + # When enabled, performs weight synchronization across validators, drastically improves vTrust. + enable_receive: True + # When enabled, allows other validators to request your rewards, slightly improves vTrust. + enable_send: True diff --git a/pyproject.toml b/pyproject.toml index 31ed9eca4..832ad83e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "bittensor>=9.7.0", "rouge>=1.0.1", "substrate-interface>=1.7.11", + "types-netaddr>=1.3.0.20240530", ] @@ -82,10 +83,8 @@ exclude = [ "^tests/", "^venv/", '^\.venv/', - # TODO: Enable once fixed. "scripts/", - "apex/services/", - "apex/validator/", + "drafts/", ] [[tool.mypy.overrides]] diff --git a/tests/validator/test_weight_syncer.py b/tests/validator/test_weight_syncer.py new file mode 100644 index 000000000..26355565b --- /dev/null +++ b/tests/validator/test_weight_syncer.py @@ -0,0 +1,140 @@ +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from apex.validator.weight_syncer import WeightSyncer +from tests.common.mock_async_chain import DummyMetagraph + + +class UidList(list): + """A list that has a .tolist() method for compatibility with torch tensors.""" + + def tolist(self): + return self + + +@pytest.fixture +def mock_axon(): + """Returns a function to create a mock axon.""" + + def _mock_axon(ip, port, is_serving=True): + axon = MagicMock() + axon.ip = ip + axon.port = port + axon.is_serving = is_serving + return axon + + return _mock_axon + + +@pytest.fixture +def mock_metagraph(mock_axon): + """Returns a mock metagraph based on DummyMetagraph.""" + metagraph = DummyMetagraph( + hotkeys=["hotkey0_self", "hotkey1_validator", "hotkey2_validator"], + ) + # Overwrite uids with our special UidList to provide .tolist() + metagraph.uids = UidList([0, 1, 2]) + metagraph.stake = [np.float32(1000.0), np.float32(2000.0), np.float32(500.0)] + metagraph.validator_permit = [True, True, False] + metagraph.axons = [ + mock_axon("1.1.1.1", 8000), + mock_axon("2.2.2.2", 8001), + mock_axon("3.3.3.3", 8002), + ] + return metagraph + + +@pytest.fixture +def mock_chain(mock_metagraph): + """Returns a mock chain with a mock metagraph.""" + chain = MagicMock() + chain.wallet.hotkey.ss58_address = "hotkey0_self" + chain.metagraph = AsyncMock(return_value=mock_metagraph) + return chain + + +@pytest.fixture +def weight_syncer(mock_chain): + """Returns a WeightSyncer instance with a mock chain.""" + return WeightSyncer(chain=mock_chain, min_alpha_stake=1000) + + +@pytest.mark.asyncio +async def test_compute_weighted_rewards_happy_path(weight_syncer, mock_metagraph): + """Test that weighted rewards are computed correctly in the ideal case.""" + local_rewards = {"miner1": 0.9, "miner2": 0.1} + validator1_rewards = {"miner1": 0.85, "miner2": 0.82, "miner3": 0.7} + + with patch.object(weight_syncer, "receive_rewards", new_callable=AsyncMock) as mock_receive: + mock_receive.side_effect = [validator1_rewards, {}] # UID 2 has low stake + + weighted_rewards = await weight_syncer.compute_weighted_rewards(local_rewards) + + # self (1000) + validator1 (2000) = 3000 total stake + # miner1: (0.9 * 1000 + 0.85 * 2000) / 3000 = 0.8666 + # miner2: (0.1 * 1000 + 0.82 * 2000) / 3000 = 0.58 + assert mock_receive.call_count == 1 + assert mock_receive.call_args.args[1] == 1 # Called for UID 1 + assert pytest.approx(weighted_rewards["miner1"], 0.001) == 0.8666 + assert pytest.approx(weighted_rewards["miner2"], 0.001) == 0.58 + assert "miner3" not in weighted_rewards + + +@pytest.mark.asyncio +async def test_compute_weighted_rewards_self_not_in_metagraph(weight_syncer, mock_metagraph): + """Test that local rewards are returned if the validator's hotkey is not in the metagraph.""" + mock_metagraph.hotkeys = ["other1", "other2", "other3"] + local_rewards = {"miner1": 0.9} + weighted_rewards = await weight_syncer.compute_weighted_rewards(local_rewards) + assert weighted_rewards == local_rewards + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient") +async def test_receive_rewards_success(mock_async_client, weight_syncer, mock_metagraph): + """Test successfully receiving rewards from another validator.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"miner1": 0.9} + mock_async_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response) + + rewards = await weight_syncer.receive_rewards(mock_metagraph, 1) + assert rewards == {"miner1": 0.9} + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient") +async def test_receive_rewards_http_error(mock_async_client, weight_syncer, mock_metagraph): + """Test that an empty dict is returned on HTTP error.""" + mock_async_client.return_value.__aenter__.return_value.post.side_effect = Exception("HTTP Error") + rewards = await weight_syncer.receive_rewards(mock_metagraph, 1) + assert rewards == {} + + +@patch("apex.validator.weight_syncer.verify_validator_signature", new_callable=AsyncMock) +def test_get_rewards_endpoint(mock_verify_signature, weight_syncer): + """Test the FastAPI endpoint for serving rewards.""" + app = FastAPI() + app.include_router(weight_syncer.get_router()) + client = TestClient(app) + + # Case 1: No rewards set yet + response = client.post("/v1/get_rewards") + assert response.status_code == 503 + + # Case 2: Rewards are set and not expired + weight_syncer.hotkey_rewards = {"miner1": 0.95} + weight_syncer.last_update_time = time.time() + response = client.post("/v1/get_rewards") + assert response.status_code == 200 + assert response.json() == {"miner1": 0.95} + + # Case 3: Rewards are expired + weight_syncer.last_update_time = time.time() - WeightSyncer.REWARD_EXPIRATION_SEC - 1 + response = client.post("/v1/get_rewards") + assert response.status_code == 503 diff --git a/uv.lock b/uv.lock index d2db4f48f..f81505ed0 100644 --- a/uv.lock +++ b/uv.lock @@ -141,7 +141,7 @@ wheels = [ [[package]] name = "apex" -version = "3.0.0" +version = "3.0.1" source = { virtual = "." } dependencies = [ { name = "aiohttp" }, @@ -168,6 +168,7 @@ dependencies = [ { name = "rouge" }, { name = "substrate-interface" }, { name = "tavily-python" }, + { name = "types-netaddr" }, ] [package.optional-dependencies] @@ -235,6 +236,7 @@ requires-dist = [ { name = "substrate-interface", specifier = ">=1.7.11" }, { name = "tavily-python", specifier = ">=0.7.10" }, { name = "types-cachetools", marker = "extra == 'dev'", specifier = ">=6.0.0.20250525" }, + { name = "types-netaddr", specifier = ">=1.3.0.20240530" }, { name = "types-pyyaml", marker = "extra == 'dev'", specifier = ">=6.0.12.20250516" }, ] provides-extras = ["dev"] @@ -2865,6 +2867,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/47/8c/4ab0a17ece30fe608270b89cf066387051862899fff9f54ab12511fc7fdd/types_cachetools-6.0.0.20250525-py3-none-any.whl", hash = "sha256:1de8f0fe4bdcb187a48d2026c1e3672830f67943ad2bf3486abe031b632f1252", size = 8938 }, ] +[[package]] +name = "types-netaddr" +version = "1.3.0.20240530" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/10/b4bad18b6918eec74db1087d1376cc094fdef5fa1261aa905b6cd94b9408/types-netaddr-1.3.0.20240530.tar.gz", hash = "sha256:742c2ec1f202b666f544223e2616b34f1f13df80c91e5aeaaa93a72e4d0774ea", size = 10459 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/de/1eaa9f59ea51eb8b64c04f2be5af8c91940072f1c5b0cfa5870a7665a028/types_netaddr-1.3.0.20240530-py3-none-any.whl", hash = "sha256:354998d018e326da4f1d9b005fc91137b7c2c473aaf03c4ef64bf83c6861b440", size = 14335 }, +] + [[package]] name = "types-pyyaml" version = "6.0.12.20250516" diff --git a/validator.py b/validator.py index 8cace967c..177c58ff7 100644 --- a/validator.py +++ b/validator.py @@ -14,6 +14,7 @@ from apex.validator.miner_sampler import MinerSampler from apex.validator.miner_scorer import MinerScorer from apex.validator.pipeline import Pipeline +from apex.validator.weight_syncer import WeightSyncer async def read_args() -> argparse.Namespace: @@ -54,7 +55,14 @@ async def main() -> None: miner_sampler = MinerSampler(chain=chain, logger_db=logger_db, **config.miner_sampler.kwargs) logger.debug("Started miner sampler") - miner_scorer = MinerScorer(chain=chain, **config.miner_scorer.kwargs) + weight_syncer = WeightSyncer(chain=chain, **config.weight_syncer.kwargs) + await weight_syncer.start() + logger.debug( + f"Started weight synchronizer, receive enabled: {weight_syncer.receive_enabled}, " + f"send enabled: {weight_syncer.send_enabled}, port: {weight_syncer.port}" + ) + + miner_scorer = MinerScorer(chain=chain, weight_syncer=weight_syncer, **config.miner_scorer.kwargs) asyncio.create_task(miner_scorer.start_loop()) logger.debug(f"Started miner scorer with interval={miner_scorer.interval}")