Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
717c178
Add wandb dependency and update uv.lock revision to 2
bkb2135 Aug 19, 2025
2bf92df
Add logger_wandb configuration to mainnet and testnet YAML examples
bkb2135 Aug 19, 2025
0f826f1
Add LoggerWandb class for asynchronous logging to Weights & Biases
bkb2135 Aug 19, 2025
557294f
Add generator_times to MinerGeneratorResults and generator_time to Mi…
bkb2135 Aug 19, 2025
5f9c6cd
Update MinerResults to include execution time; enhance query_miners f…
bkb2135 Aug 19, 2025
1f6161e
Refactor logging to use LoggerWandb in validator and pipeline modules
bkb2135 Aug 19, 2025
f0796d7
Refactor LoggerWandb initialization and logging method; update MinerS…
bkb2135 Aug 19, 2025
70b5773
Enhance LoggerWandb initialization with project-specific configuratio…
bkb2135 Aug 19, 2025
e744450
Merge branch 'release/v3.0.3' into features/wandb-logging
bkb2135 Aug 19, 2025
299b9e1
Regenerate Lock
bkb2135 Aug 19, 2025
2fc194a
Add ruff to packages
bkb2135 Aug 19, 2025
17d7ab3
Refactor logger_wandb module to improve imports and maintain consiste…
bkb2135 Aug 19, 2025
622f73c
Add timing to tests
bkb2135 Aug 19, 2025
feba89f
Wrap query miners instead of altering signature
bkb2135 Aug 19, 2025
080b89c
Enhance error handling in LoggerWandb initialization by adding a try-…
bkb2135 Aug 19, 2025
465c253
Update LoggerWandb initialization to accept additional configuration …
bkb2135 Aug 19, 2025
e7c6e8e
Add wandb/ to .gitignore to exclude W&B files from version control
bkb2135 Aug 19, 2025
db80808
Enhance LoggerWandb with detailed logging for W&B run initialization …
bkb2135 Aug 19, 2025
4674d5c
Bump Project Version
bkb2135 Aug 26, 2025
006c8da
Merge branch 'release/v3.0.4' into features/wandb-logging
bkb2135 Aug 27, 2025
a0c8650
Update uv.lock to revision 2; bump zstandard version to 0.24.0 and ad…
bkb2135 Aug 27, 2025
766d2ce
Change wandb log level
bkb2135 Aug 28, 2025
d24a1e8
Set default log level to info
bkb2135 Aug 28, 2025
fb6935b
Remove unused logging import from validator.py
bkb2135 Aug 28, 2025
823a6c2
Merge pull request #806 from macrocosm-os/features/wandb-logging
bkb2135 Aug 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,5 @@ cython_debug/

# VS Code
.vscode

wandb/
2 changes: 1 addition & 1 deletion apex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def setup_logger(log_file_path: str | Path | None = None, level: str = "INFO") -
return logger


setup_logger(log_file_path="logs.log", level="DEBUG")
setup_logger(log_file_path="logs.log", level="INFO")
1 change: 1 addition & 0 deletions apex/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Config(BaseModel):
chain: ConfigClass = Field(default_factory=ConfigClass)
websearch: ConfigClass = Field(default_factory=ConfigClass)
logger_db: ConfigClass = Field(default_factory=ConfigClass)
logger_wandb: ConfigClass = Field(default_factory=ConfigClass)
weight_syncer: ConfigClass = Field(default_factory=ConfigClass)
miner_sampler: ConfigClass = Field(default_factory=ConfigClass)
miner_scorer: ConfigClass = Field(default_factory=ConfigClass)
Expand Down
2 changes: 2 additions & 0 deletions apex/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ class MinerGeneratorResults(BaseModel):
query: str
generator_hotkeys: list[str]
generator_results: list[str]
generator_times: list[float]


class MinerDiscriminatorResults(BaseModel):
query: str
generator_hotkey: str
generator_result: str
generator_time: float
generator_score: float
discriminator_hotkeys: list[str]
discriminator_results: list[str]
Expand Down
69 changes: 69 additions & 0 deletions apex/validator/logger_wandb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from collections.abc import Mapping
from typing import Any

import wandb
from loguru import logger

from apex import __version__
from apex.common.async_chain import AsyncChain
from apex.common.models import MinerDiscriminatorResults


def approximate_tokens(text: str) -> int:
"""Count the number of tokens in a text."""
return len(text) // 4


class LoggerWandb:
    """Best-effort logger that sends discriminator results to Weights & Biases.

    If the project or API key is missing, or if W&B initialization raises,
    ``self.run`` stays ``None`` and every ``log`` call becomes a no-op.
    """

    def __init__(
        self,
        async_chain: AsyncChain,
        project: str = "apex-gan-arena",
        api_key: str | None = None,
    ):
        """Log in to W&B and start a run.

        Args:
            async_chain: Supplies the wallet hotkey address and netuid that are
                stored in the run config.
            project: W&B project name. An empty value disables logging.
            api_key: W&B API key. ``None`` or an empty value disables logging.
        """
        self.run: Any | None = None
        if not (project and api_key):
            logger.warning("W&B API key not provided, skipping logging to W&B")
            return

        try:
            # Authenticate with W&B, then initialize the run.
            wandb.login(key=api_key)
            self.run = wandb.init(
                entity="macrocosmos",
                project=project,
                config={
                    "hotkey": async_chain.wallet.hotkey.ss58_address,
                    "netuid": async_chain.netuid,
                    "version": __version__,
                },
            )
            logger.info(f"Initialized W&B run: {self.run.id}")
        except Exception as e:
            # Initialization failures are logged and swallowed; the logger just stays disabled.
            logger.error(f"Failed to initialize W&B run: {e}")

    async def log(
        self,
        reference: str | None = None,
        discriminator_results: MinerDiscriminatorResults | None = None,
        tool_history: list[dict[str, str]] | None = None,
    ) -> None:
        """Log one event to W&B.

        Does nothing when no run is active or when ``discriminator_results``
        is not provided.

        Args:
            reference: Reference answer for the query, if any.
            discriminator_results: Results to log; dumped to a dict via
                ``model_dump()``.
            tool_history: Tool-call history to attach to the event, if any.
        """
        if not self.run or not discriminator_results:
            return

        event: dict[str, Any] = discriminator_results.model_dump()
        # Add the reference BEFORE processing so its token estimate is computed
        # from the real text instead of a missing key.
        event["reference"] = reference
        event["tool_history"] = tool_history
        self.run.log(self.process_event(event))

    def process_event(self, event: Mapping[str, Any]) -> dict[str, Any]:
        """Return a copy of ``event`` with approximate token counts added.

        The generated text is read from ``"generation"`` when present,
        otherwise from ``"generator_result"`` (the field name on
        ``MinerDiscriminatorResults`` as far as this change shows). Missing or
        ``None`` texts count as zero tokens.

        Args:
            event: Event fields to log; not mutated.

        Returns:
            A new dict with ``generator_tokens`` and ``reference_tokens`` added.
        """
        # ``or ""`` also covers keys that are present but hold None.
        reference = event.get("reference") or ""
        generation = event.get("generation") or event.get("generator_result") or ""

        processed_event: dict[str, Any] = dict(event)
        processed_event["generator_tokens"] = approximate_tokens(generation)
        processed_event["reference_tokens"] = approximate_tokens(reference)
        return processed_event
27 changes: 21 additions & 6 deletions apex/validator/miner_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,32 @@ async def query_miners(
return ""
return str(result)

async def query_miners_with_times(
    self, body: dict[str, Any], endpoint: str, hotkey: str | None = None, timeout: float = TIMEOUT
) -> tuple[str, float]:
    """Call ``query_miners`` and also report how long the call took.

    Args:
        body: Request payload, forwarded unchanged to ``query_miners``.
        endpoint: Miner address, forwarded unchanged to ``query_miners``.
        hotkey: Miner hotkey, forwarded unchanged to ``query_miners``.
        timeout: Timeout, forwarded unchanged to ``query_miners``.

    Returns:
        A ``(result, elapsed_seconds)`` tuple, where ``result`` is whatever
        ``query_miners`` returned and ``elapsed_seconds`` is the time spent
        awaiting it.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or go negative) if the system wall clock is adjusted mid-request.
    start_time = time.perf_counter()
    result = await self.query_miners(body, endpoint, hotkey, timeout)
    return result, time.perf_counter() - start_time

async def query_generators(self, query: str) -> MinerGeneratorResults:
"""Query the miners for the query."""
miner_information = await self._sample_miners(sample_size=self._generator_sample_size)
body = {"step": "generator", "query": query}

hotkeys: list[str] = []
tasks: list[Coroutine[str, str, Any]] = []
tasks: list[Coroutine[tuple[str, float], str, Any]] = []

for miner_info in miner_information:
hotkeys.append(miner_info.hotkey)
tasks.append(self.query_miners(body=body, endpoint=miner_info.address, hotkey=miner_info.hotkey))
tasks.append(self.query_miners_with_times(body=body, endpoint=miner_info.address, hotkey=miner_info.hotkey))
generator_results = await asyncio.gather(*tasks)
return MinerGeneratorResults(query=query, generator_hotkeys=hotkeys, generator_results=generator_results)
return MinerGeneratorResults(
query=query,
generator_hotkeys=hotkeys,
generator_results=[result[0] for result in generator_results],
generator_times=[result[1] for result in generator_results],
)

async def query_discriminators(
self,
Expand All @@ -181,19 +194,20 @@ async def query_discriminators(
miner_information = await self._sample_miners(sample_size=self._discriminator_sample_size)
# Flip the coin for the generator.
if ground_truth and generator_results:
selected_generator: tuple[str, str] = random.choice(
selected_generator: tuple[str, str, float] = random.choice(
list(
zip(
generator_results.generator_hotkeys,
generator_results.generator_results,
generator_results.generator_times,
strict=False,
)
)
)
else:
if reference is None:
raise ValueError("Reference cannot be None when not using miner generator results")
selected_generator = (VALIDATOR_REFERENCE_LABEL, reference)
selected_generator = (VALIDATOR_REFERENCE_LABEL, reference, 0.0)

body = {
"step": "discriminator",
Expand All @@ -202,7 +216,7 @@ async def query_discriminators(
}

hotkeys: list[str] = []
tasks: list[Coroutine[str, str, Any]] = []
tasks: list[Coroutine[tuple[str, float], str, Any]] = []
for miner_info in miner_information:
hotkeys.append(miner_info.hotkey)
tasks.append(self.query_miners(body=body, endpoint=miner_info.address, hotkey=miner_info.hotkey))
Expand Down Expand Up @@ -244,6 +258,7 @@ async def query_discriminators(
generator_hotkey=selected_generator[0],
generator_result=selected_generator[1],
generator_score=generator_result_float,
generator_time=selected_generator[2],
discriminator_hotkeys=hotkeys,
discriminator_results=parsed_discriminator_results,
discriminator_scores=discriminator_results_float,
Expand Down
10 changes: 5 additions & 5 deletions apex/validator/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from apex.services.llm.llm_base import LLMBase
from apex.services.websearch.websearch_base import WebSearchBase
from apex.validator import generate_query, generate_reference
from apex.validator.logger_apex import LoggerApex
from apex.validator.logger_local import LoggerLocal
from apex.validator.logger_wandb import LoggerWandb
from apex.validator.miner_sampler import MinerSampler


Expand All @@ -23,7 +23,7 @@ def __init__(
miner_sampler: MinerSampler,
llm: LLMBase,
deep_research: DeepResearchBase,
logger_apex: LoggerApex | None = None,
logger_wandb: LoggerWandb | None = None,
num_consumers: int = 5,
timeout_consumer: float = 1200,
timeout_producer: float = 240,
Expand All @@ -36,7 +36,7 @@ def __init__(
self.miner_registry = miner_sampler
self.llm = llm
self.deep_research = deep_research
self.logger_apex = logger_apex
self.logger_wandb = logger_wandb
self.num_consumers = num_consumers
self.timeout_consumer = timeout_consumer
self.timeout_producer = timeout_producer
Expand Down Expand Up @@ -109,8 +109,8 @@ async def run_single(self, task: QueryTask) -> str:
query=query, generator_results=generator_results, reference=reference, ground_truth=ground_truth
)

if self.logger_apex:
await self.logger_apex.log(
if self.logger_wandb:
await self.logger_wandb.log(
reference=reference, discriminator_results=discriminator_results, tool_history=tool_history
)

Expand Down
5 changes: 5 additions & 0 deletions config/mainnet.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ websearch:
kwargs:
key: "TAVILY_API_KEY"

logger_wandb:
kwargs:
project: "apex-gan-arena"
api_key: "YOUR_WANDB_API_KEY"

llm:
kwargs:
key: "CHUTES_API_KEY"
Expand Down
5 changes: 5 additions & 0 deletions config/testnet.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ websearch:
kwargs:
key: "TAVILY_API_KEY"

logger_wandb:
kwargs:
project: "apex-gan-arena"
api_key: "YOUR_WANDB_API_KEY"

llm:
kwargs:
key: "CHUTES_API_KEY"
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "apex"
version = "3.0.3"
version = "3.0.4"
description = "Bittensor Subnet 1: Apex"
readme = "README.md"
requires-python = "~=3.11"
Expand Down Expand Up @@ -34,6 +34,8 @@ dependencies = [
"types-cachetools>=6.0.0.20250525",
"dotenv>=0.9.9",
"pytest-mock>=3.14.1",
"wandb>=0.21.1",
"ruff>=0.12.5",
]


Expand Down
39 changes: 29 additions & 10 deletions tests/validator/test_miner_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,12 @@ async def test_query_generators(monkeypatch: MonkeyPatch, miner_sampler: MinerSa
],
),
)
query_miners_mock: AsyncMock = AsyncMock(side_effect=["result1", "result2"])
monkeypatch.setattr(miner_sampler, "query_miners", AsyncMock(side_effect=query_miners_mock))
query_miners_with_times_mock: AsyncMock = AsyncMock(side_effect=[("result1", 0.1), ("result2", 0.2)])
monkeypatch.setattr(
miner_sampler,
"query_miners_with_times",
AsyncMock(side_effect=query_miners_with_times_mock),
)

query = "test query"
results = await miner_sampler.query_generators(query)
Expand All @@ -253,12 +257,13 @@ async def test_query_generators(monkeypatch: MonkeyPatch, miner_sampler: MinerSa
assert results.query == query
assert results.generator_hotkeys == ["key1", "key3"]
assert results.generator_results == ["result1", "result2"]
assert results.generator_times == [0.1, 0.2]

assert query_miners_mock.call_count == 2 # type: ignore
query_miners_mock.assert_any_call(
assert query_miners_with_times_mock.call_count == 2 # type: ignore
query_miners_with_times_mock.assert_any_call(
body={"step": "generator", "query": query}, endpoint="http://1.1.1.1:8000", hotkey="key1"
)
query_miners_mock.assert_any_call( # type: ignore
query_miners_with_times_mock.assert_any_call( # type: ignore
body={"step": "generator", "query": query}, endpoint="http://3.3.3.3:8002", hotkey="key3"
)

Expand All @@ -270,7 +275,7 @@ async def test_query_discriminators_selects_generator(
mock_random_choice: MagicMock, mock_random_random: MagicMock, monkeypatch: MonkeyPatch, miner_sampler: MinerSampler
) -> None:
"""Tests that a query to a discriminator is successful when a generator is selected."""
mock_random_choice.return_value = ("gen_key1", "gen_result1")
mock_random_choice.return_value = ("gen_key1", "gen_result1", 0.1)

monkeypatch.setattr(
miner_sampler,
Expand All @@ -295,7 +300,10 @@ async def test_query_discriminators_selects_generator(
)

generator_results = MinerGeneratorResults(
query="test query", generator_hotkeys=["gen_key1", "gen_key2"], generator_results=["gen_result1", "gen_result2"]
query="test query",
generator_hotkeys=["gen_key1", "gen_key2"],
generator_results=["gen_result1", "gen_result2"],
generator_times=[0.1, 0.2],
)
reference = "reference text"

Expand All @@ -308,6 +316,7 @@ async def test_query_discriminators_selects_generator(
assert results.discriminator_results == ["1", "0"]
assert results.discriminator_scores == [0.5, 0.0]
assert results.generator_score == 0.5
assert results.generator_time == 0.1


@pytest.mark.asyncio
Expand Down Expand Up @@ -339,7 +348,10 @@ async def test_query_discriminators_selects_reference(
)

generator_results = MinerGeneratorResults(
query="test query", generator_hotkeys=["gen_key1", "gen_key2"], generator_results=["gen_result1", "gen_result2"]
query="test query",
generator_hotkeys=["gen_key1", "gen_key2"],
generator_results=["gen_result1", "gen_result2"],
generator_times=[0.1, 0.2],
)
reference = "reference text"

Expand All @@ -349,6 +361,7 @@ async def test_query_discriminators_selects_reference(

assert results.generator_hotkey == "Validator"
assert results.generator_result == reference
assert results.generator_time == 0.0
assert results.discriminator_hotkeys == ["disc_key1", "disc_key2"]
assert results.discriminator_results == ["0", "1"]
assert results.discriminator_scores == [0.5, 0.0]
Expand Down Expand Up @@ -390,7 +403,10 @@ async def test_query_discriminators_response_parsing(
monkeypatch.setattr(miner_sampler, "query_miners", AsyncMock(return_value=miner_response))

generator_results = MinerGeneratorResults(
query="test query", generator_hotkeys=["gen_key1"], generator_results=["gen_result1"]
query="test query",
generator_hotkeys=["gen_key1"],
generator_results=["gen_result1"],
generator_times=[0.1],
)
reference = "reference text"

Expand Down Expand Up @@ -423,7 +439,10 @@ async def test_query_discriminators_with_db_log(monkeypatch: MonkeyPatch, miner_

with patch("random.random", return_value=0.6):
generator_results = MinerGeneratorResults(
query="test query", generator_hotkeys=["gen_key1"], generator_results=["gen_result1"]
query="test query",
generator_hotkeys=["gen_key1"],
generator_results=["gen_result1"],
generator_times=[0.1],
)
reference = "reference text"

Expand Down
Loading