Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.log
requirements.txt
**/*.ipynb
debug_rewards.jsonl
Expand Down
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,22 @@ repos:
hooks:
- id: ruff
args: [--fix]
stages: [pre-commit]
- id: ruff-format
stages: [pre-commit]

- repo: local
hooks:
- id: mypy
name: mypy
entry: mypy .
language: system
pass_filenames: false
stages: [pre-commit]

- id: pytest
name: pytest
entry: pytest tests/ --verbose --failed-first --exitfirst --disable-warnings
language: system
pass_filenames: false
stages: [pre-commit]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Subnet 1 is the most intelligent inference model on Bittensor. As the first agen

3. **Install the project and its development dependencies:**
```bash
uv venv && uv python install 3.11 && uv python pin 3.11 && uv venv --python=3.11 && uv pip install -e '.[dev]'
uv venv --python=3.11 && uv pip install '.[dev]'
```

4. **Activate python environment:**
Expand Down
15 changes: 13 additions & 2 deletions apex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
__version__ = version("apex")


def _version_to_int(version_str: str) -> int:
version_split = version_str.split(".") + ["0", "0"] # in case a version doesn't have third element, e.g. 3.0
major = int(version_split[0])
minor = int(version_split[1])
patch = int(version_split[2])
return (10000 * major) + (100 * minor) + patch


# Integer encoding of the package version (10000 * major + 100 * minor + patch);
# passed as `version_key` when setting weights on chain.
__spec_version__ = _version_to_int(__version__)


def setup_logger(log_file_path: str | Path | None = None, level: str = "INFO") -> Any:
"""Set up the loguru logger with optional file logging and specified log level.

Expand All @@ -28,9 +39,9 @@ def setup_logger(log_file_path: str | Path | None = None, level: str = "INFO") -
# Add file handler if a path is provided.
if log_file_path:
file_log_format = "{time:YYYY-MM-DD HH:mm:ss} [{file}:{line}] {message}"
logger.add(str(log_file_path), level=level, format=file_log_format, rotation="10 MB", retention="7 days")
logger.add(str(log_file_path), level=level, format=file_log_format, rotation="5 MB", retention="3 days")

return logger


setup_logger(level="DEBUG")
setup_logger(log_file_path="logs.log", level="DEBUG")
38 changes: 19 additions & 19 deletions apex/common/async_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from bittensor.core.metagraph import AsyncMetagraph
from loguru import logger

from apex import __spec_version__
from apex.common.utils import async_cache

_METAGRAPH_TTL: int = 10 * 60
Expand Down Expand Up @@ -111,36 +112,35 @@ def network(self) -> list[str]:
return self._network

async def set_weights(self, rewards: dict[str, float]) -> bool:
metagraph = await self.metagraph()
subtensor = await self.subtensor()
weights: dict[int, float] = {}
try:
metagraph = await self.metagraph()
subtensor = await self.subtensor()
weights: dict[int, float] = {}

for hotkey, reward in rewards.items():
try:
idx = metagraph.hotkeys.index(hotkey)
except ValueError:
# Hotkey not found in the metagraph (e.g., deregistered). Skip it.
continue
for hotkey, reward in rewards.items():
try:
idx = metagraph.hotkeys.index(hotkey)
except ValueError:
# Hotkey not found in the metagraph (e.g., deregistered). Skip it.
continue

uid = metagraph.uids[idx]
weights[uid] = reward
uid = metagraph.uids[idx]
weights[uid] = reward

# Set the weights.
try:
result = await subtensor.set_weights(
success, err = await subtensor.set_weights(
wallet=self._wallet,
netuid=self._netuid,
uids=list(weights.keys()),
weights=list(weights.values()),
version_key=__spec_version__,
wait_for_inclusion=True,
wait_for_finalization=True,
)
if not result:
logger.error(f"Error setting weights: {result}")
return False
return True
if not success:
logger.error(f"Error during weight set: {err}")
return bool(success)
except BaseException as exc:
logger.exception(f"Error setting weights: {exc}")
logger.exception(f"Error during weight set: {exc}")
return False

async def mask_network(self) -> list[str]:
Expand Down
29 changes: 18 additions & 11 deletions apex/validator/miner_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import random
import time
from collections import deque
from collections.abc import Coroutine, Sequence
from typing import Any, Literal

Expand Down Expand Up @@ -69,7 +70,8 @@ def __init__(
if self._available_uids and self._available_addresses:
equal_length = len(self._available_uids) == len(self._available_addresses)
assert equal_length, "Test UIDs and addresses must be the same length."
self._remaining_epoch_miners: set[MinerInfo] = set()
self._epoch_deque: deque[MinerInfo] = deque()
self._sample_lock = asyncio.Lock()

@async_cache(_TTL_UIDS_RESYNC)
async def _get_all_miners(self) -> list[MinerInfo]:
Expand Down Expand Up @@ -110,20 +112,24 @@ async def _get_all_miners(self) -> list[MinerInfo]:
async def _sample_miners(self) -> list[MinerInfo]:
miners = await self._get_all_miners()

miners_sample: list[MinerInfo] = []
if self._sample_mode == "random":
miners_sample = random.sample(miners, self._sample_size)

elif self._sample_mode == "sequential":
if len(self._remaining_epoch_miners) < self._sample_size:
self._remaining_epoch_miners = set(miners)
logger.debug(f"Starting new miner sampling epoch, miners amount: {len(self._remaining_epoch_miners)}")
indices_sample = sorted(random.sample(range(len(self._remaining_epoch_miners)), self._sample_size))
miners_sample = [miners[i] for i in indices_sample]
self._remaining_epoch_miners -= set(miners_sample)
async with self._sample_lock:
while len(miners_sample) < self._sample_size:
if not self._epoch_deque:
# Get shuffled deque of miners.
self._epoch_deque = deque(random.sample(miners, len(miners)))
miners_sample.append(self._epoch_deque.popleft())

else:
raise ValueError(f"Unknown sampling mode: {self._sample_mode}")

logger.debug(
f"Sampled uids (sample size = {self._sample_size}): {sorted([miner.uid for miner in miners_sample])}"
)
return miners_sample

async def query_miners(self, body: dict[str, Any], endpoint: str, hotkey: str | None = None) -> str:
Expand All @@ -134,7 +140,7 @@ async def query_miners(self, body: dict[str, Any], endpoint: str, hotkey: str |
self._chain.wallet.hotkey, body=json.dumps(body).encode("utf-8"), signed_for=hotkey
)
async with session.post(
endpoint + "/v1/chat/completions",
f"{endpoint}/v1/chat/completions",
headers=headers,
json=body,
) as resp:
Expand All @@ -151,9 +157,10 @@ async def query_generators(self, query: str) -> MinerGeneratorResults:

hotkeys: list[str] = []
tasks: list[Coroutine[str, str, Any]] = []

logger.debug(f"Querying {len(miner_information)} miner generators")
for miner_info in miner_information:
hotkeys.append(miner_info.hotkey)
logger.debug(f"Querying miner generator at {miner_info.address} with uid: {miner_info.uid}")
tasks.append(self.query_miners(body=body, endpoint=miner_info.address, hotkey=miner_info.hotkey))
generator_results = await asyncio.gather(*tasks)
return MinerGeneratorResults(query=query, generator_hotkeys=hotkeys, generator_results=generator_results)
Expand Down Expand Up @@ -217,15 +224,15 @@ async def query_discriminators(
choice_content = "None"
parsed_discriminator_results.append(choice_content)

# Apply scoring logic based on selected generator type
# Apply scoring logic based on selected generator type.
if choice_content == str(ground_truth):
discriminator_score = score_per_miner
else:
discriminator_score = 0.0

discriminator_results_float.append(discriminator_score)

# Generator result is 1 minus sum of discriminator results
# Generator result is 1 minus sum of discriminator results.
generator_result_float = 1.0 - sum(discriminator_results_float)
miner_discriminator_results = MinerDiscriminatorResults(
query=query,
Expand Down
15 changes: 14 additions & 1 deletion apex/validator/miner_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

import aiosqlite
import numpy as np
from loguru import logger

from apex.common.async_chain import AsyncChain
Expand All @@ -29,8 +30,11 @@ async def start_loop(self) -> None:
self._running = True
while self._running:
await asyncio.sleep(self.interval)
logger.debug("Attempting to set weights")
success = await self.set_scores()
if not success:
if success:
logger.info("Successfully set weights")
else:
logger.error("Failed to set weights")

async def shutdown(self) -> None:
Expand All @@ -50,6 +54,7 @@ async def set_scores(self) -> bool:
expose each one as plain python objects so that downstream code can work with them,
and remove rows that are older than the time window.
"""
logger.debug("Retrieving miner's performance history")
async with self._db() as conn: # type: aiosqlite.Connection
# Calculate the cutoff timestamp (current time - window hours).
cutoff_timestamp = int(time.time() - SCORE_MA_WINDOW_HOURS * 3600)
Expand All @@ -70,6 +75,7 @@ async def set_scores(self) -> bool:
return False

# 2. Iterate over the in-memory list so that the caller can process freely.
logger.debug("Pre-processing miner's rewards")
hkey_agg_rewards: dict[str, float] = {}
for generator_hotkey, generator_score, disc_hotkeys_json, disc_scores_json in rows:
# Deserialize JSON columns.
Expand All @@ -88,6 +94,7 @@ async def set_scores(self) -> bool:
hkey_agg_rewards[hotkey] = float(hkey_agg_rewards.get(hotkey, 0.0)) + float(reward)

# 3. Delete rows that are older than the time window.
logger.debug("Cleaning up miner's outdated history")
await conn.execute(
"DELETE FROM discriminator_results WHERE timestamp < ?",
(cutoff_timestamp,),
Expand All @@ -102,8 +109,14 @@ async def set_scores(self) -> bool:
record_str: str = json.dumps(record)
fh.write(f"{record_str}\n")
# TODO: Flush the db only on set_weights_result is True.
if hkey_agg_rewards:
rewards_array = np.array(list(hkey_agg_rewards.values()))
logger.debug(f"Setting weights, reward mean={rewards_array.mean():.4f} min={rewards_array.min():.4f}")
else:
logger.warning(f"Setting empty rewards: {hkey_agg_rewards}")
set_weights_result = await self.chain.set_weights(hkey_agg_rewards)

# 4. Flush all deletions in a single commit.
logger.debug("Updating rewards DB")
await conn.commit()
return set_weights_result
34 changes: 20 additions & 14 deletions apex/validator/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def __init__(
llm: LLMBase,
deep_research: DeepResearchBase,
logger_apex: LoggerApex | None = None,
num_consumers: int = 10,
timeout_consumer: float = 60,
timeout_producer: float = 6,
num_consumers: int = 5,
timeout_consumer: float = 180,
timeout_producer: float = 36,
queue_size: int = 10_000,
redundancy_rate: float = 0.1, # The rate that references are generated in addition to generator steps
redundancy_rate: float = 0.05, # The rate that references are generated in addition to generator steps
reference_rate: float = 0.5, # The rate that references are generated as opposed to generator steps
):
self.config = config
Expand Down Expand Up @@ -81,21 +81,27 @@ async def run_single(self, task: QueryTask) -> str:
logger.debug("Generating task query")
query = await generate_query(llm=self.llm, websearch=self.websearch)

reference = None
tool_history: list[dict[str, str]] = []
if random.random() < self.reference_rate:
try:
generator_results = None
ground_truth = 0
logger.debug(f"Generating task reference for query: {query[:20]}..")
reference, tool_history = await generate_reference(llm=self.deep_research, query=query)
except BaseException as exc:
logger.exception(f"Failed to generate reference: {exc}")

if reference is None:
ground_truth = 1
logger.debug(f"Querying generators with query: {query[:20]}..")
generator_results = await self.miner_registry.query_generators(query=query)
if random.random() < self.redundancy_rate:
logger.debug(f"Generating redundant task reference for query: {query[:20]}..")
reference, tool_history = await generate_reference(llm=self.deep_research, query=query)
else:
reference = None
tool_history = []
else:
generator_results = None
ground_truth = 0
logger.debug(f"Generating task reference for query: {query[:20]}..")
reference, tool_history = await generate_reference(llm=self.deep_research, query=query)
try:
logger.debug(f"Generating redundant task reference for query: {query[:20]}..")
reference, tool_history = await generate_reference(llm=self.deep_research, query=query)
except BaseException as exc:
logger.warning(f"Failed to generate redundant reference: {exc}")

discriminator_results = await self.miner_registry.query_discriminators(
query=query, generator_results=generator_results, reference=reference, ground_truth=ground_truth
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "apex"
version = "3.0.0"
version = "3.0.1"
description = "Bittensor Subnet 1: Apex"
readme = "README.md"
requires-python = "~=3.11"
Expand Down
6 changes: 4 additions & 2 deletions tests/common/mock_async_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,20 @@ async def set_weights(
netuid: int,
uids: Iterable[int],
weights: Iterable[float],
version_key: int,
wait_for_inclusion: bool,
wait_for_finalization: bool,
) -> bool:
) -> tuple[bool, str | None]:
self.last_set_weights = {
"wallet": wallet,
"netuid": netuid,
"uids": list(uids),
"weights": list(weights),
"version_key": version_key,
"wait_for_inclusion": wait_for_inclusion,
"wait_for_finalization": wait_for_finalization,
}
return self.weights_result
return self.weights_result, ""


def patch_wallet(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down
2 changes: 2 additions & 0 deletions tests/common/test_async_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from apex import __spec_version__
from apex.common.async_chain import AsyncChain # noqa: E402
from tests.common.mock_async_chain import DummyMetagraph, DummySubtensor, patch_subtensor, patch_wallet

Expand Down Expand Up @@ -121,6 +122,7 @@ async def test_set_weights_happy_path(monkeypatch):
assert stub.last_set_weights is not None
assert stub.last_set_weights["uids"] == [2]
assert stub.last_set_weights["weights"] == [0.7]
assert stub.last_set_weights["version_key"] == __spec_version__


@pytest.mark.asyncio
Expand Down
Loading