119 changes: 85 additions & 34 deletions apex/validator/logger_db.py
@@ -5,77 +5,128 @@
from typing import Literal, TypedDict

import aiosqlite
from loguru import logger

from apex.common.models import MinerDiscriminatorResults


class _DiscriminatorItem(TypedDict):
class DiscriminatorItem(TypedDict):
"""Item placed on the queue that represents one discriminator result row."""

kind: Literal["discriminator"]
data: MinerDiscriminatorResults


class LoggerDB:
_COMMIT_FREQ = 60
_COMMIT_CHANGES = 1000

def __init__(self, db_path: Path | str = "results.db"):
self.db_path = Path(db_path)
self._queue: asyncio.Queue[_DiscriminatorItem | object] = asyncio.Queue(maxsize=10_000)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._queue: asyncio.Queue[DiscriminatorItem | object] = asyncio.Queue(maxsize=10_000)
self._SHUTDOWN = object()
self._closing = asyncio.Event()
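
The constructor pairs a bounded queue (backpressure when the writer falls behind) with a private sentinel object, the same shutdown idiom `shutdown()` uses further down. A minimal standalone sketch of that pattern — illustrative only, not project code:

```python
import asyncio

# A private object() instance as a shutdown sentinel: identity comparison
# with "is" cannot collide with any real queue item.
SHUTDOWN = object()

async def consumer(queue: asyncio.Queue) -> None:
    while True:
        item = await queue.get()
        queue.task_done()
        if item is SHUTDOWN:
            break
        print("processing", item)

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue(maxsize=10)
    task = asyncio.create_task(consumer(queue))
    await queue.put("row-1")
    await queue.join()         # block until every queued item is task_done()
    await queue.put(SHUTDOWN)  # only then ask the consumer to exit
    await task

asyncio.run(main())
```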

async def start_loop(self) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.executescript(
"""
PRAGMA journal_mode = WAL;
PRAGMA journal_mode=WAL;
PRAGMA wal_autocheckpoint=1000;
PRAGMA synchronous=NORMAL;
PRAGMA busy_timeout=5000;

-- Results coming back from a discriminator step.
-- Store *all* fields from MinerDiscriminatorResults; list fields are serialized as JSON.
CREATE TABLE IF NOT EXISTS discriminator_results (
query TEXT,
generator_hotkey TEXT,
generator_result TEXT,
generator_score REAL,
query TEXT,
generator_hotkey TEXT,
generator_result TEXT,
generator_score REAL,
discriminator_hotkeys TEXT, -- JSON array of strings
discriminator_results TEXT, -- JSON array of strings
discriminator_scores TEXT, -- JSON array of floats
timestamp INTEGER, -- Unix timestamp when row was added
processed INTEGER DEFAULT 0,
discriminator_scores TEXT, -- JSON array of floats
timestamp INTEGER, -- Unix timestamp when row was added
processed INTEGER DEFAULT 0,
PRIMARY KEY (query, generator_hotkey)
);

CREATE INDEX IF NOT EXISTS idx_discriminator_processed
ON discriminator_results(processed);
"""
)
await db.commit()

last_commit = time.monotonic()
last_changes = db.total_changes
while True:
item = await self._queue.get()
try:
item = await asyncio.wait_for(self._queue.get(), timeout=20.0)
except TimeoutError:
if db.total_changes != last_changes:
await db.commit()
last_commit = time.monotonic()
last_changes = db.total_changes
continue

if item is self._SHUTDOWN:
self._queue.task_done()
await db.commit()
await db.execute("PRAGMA wal_checkpoint(TRUNCATE);")
break

if isinstance(item, dict) and item.get("kind") == "discriminator":
row: MinerDiscriminatorResults = item["data"]

await db.execute(
"INSERT OR REPLACE INTO discriminator_results VALUES (?,?,?,?,?,?,?,?,0)",
(
row.query,
row.generator_hotkey,
row.generator_result,
row.generator_score,
json.dumps(row.discriminator_hotkeys),
json.dumps(row.discriminator_results),
json.dumps(row.discriminator_scores),
int(time.time()), # Current Unix timestamp
),
)

# flush every 1 000 rows or on demand
if self._queue.empty() or db.total_changes % 1000 == 0:
try:
await self.add_entry(db, item=item)
except Exception:
await db.rollback()
finally:
self._queue.task_done()

commit_changes = (db.total_changes - last_changes) >= self._COMMIT_CHANGES
commit_timer = time.monotonic() - last_commit >= self._COMMIT_FREQ
if commit_changes or commit_timer:
logger.debug(f"Commiting scores to the {self.db_path}")
await db.commit()
await db.execute("PRAGMA wal_checkpoint(FULL);")
self._queue.task_done()
last_commit = time.monotonic()
last_changes = db.total_changes
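
The branch above commits on whichever trigger fires first: a change-count threshold or a wall-clock deadline. Distilled into a hypothetical helper (names and defaults mirror `_COMMIT_CHANGES` / `_COMMIT_FREQ` above, but this function does not exist in the codebase):

```python
def should_commit(total_changes: int, last_changes: int,
                  now: float, last_commit: float,
                  max_changes: int = 1000, max_age_s: float = 60.0) -> bool:
    # Commit when enough row changes have accumulated,
    # or enough time has elapsed since the previous commit.
    return (total_changes - last_changes) >= max_changes \
        or (now - last_commit) >= max_age_s
```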

async def add_entry(self, db: aiosqlite.Connection, item: DiscriminatorItem | object) -> None:
if isinstance(item, dict) and item.get("kind") == "discriminator":
row: MinerDiscriminatorResults = item["data"]

await db.execute(
"""
INSERT INTO discriminator_results (
query, generator_hotkey, generator_result, generator_score,
discriminator_hotkeys, discriminator_results, discriminator_scores, timestamp
) VALUES (?,?,?,?,?,?,?,?)
ON CONFLICT(query, generator_hotkey) DO UPDATE SET
generator_result = excluded.generator_result,
generator_score = excluded.generator_score,
discriminator_hotkeys = excluded.discriminator_hotkeys,
discriminator_results = excluded.discriminator_results,
discriminator_scores = excluded.discriminator_scores,
timestamp = excluded.timestamp
""",
(
row.query,
row.generator_hotkey,
row.generator_result,
row.generator_score,
json.dumps(row.discriminator_hotkeys),
json.dumps(row.discriminator_results),
json.dumps(row.discriminator_scores),
int(time.time()),
),
)

async def log(self, row: MinerDiscriminatorResults) -> None:
if self._closing.is_set():
logger.error("Database is shutting down")
return

await self._queue.put({"kind": "discriminator", "data": row})

async def shutdown(self) -> None:
self._closing.set()
await self._queue.join()
await self._queue.put(self._SHUTDOWN)
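
Taken together, a `LoggerDB` lifecycle looks roughly like the sketch below. Hedged: the keyword constructor for `MinerDiscriminatorResults` is inferred from the INSERT tuple above, and all values are made up.

```python
import asyncio

from apex.common.models import MinerDiscriminatorResults
from apex.validator.logger_db import LoggerDB

async def main() -> None:
    logger_db = LoggerDB("results.db")
    writer = asyncio.create_task(logger_db.start_loop())

    # Field names mirror the INSERT tuple above; the exact constructor
    # signature of MinerDiscriminatorResults is an assumption here.
    await logger_db.log(
        MinerDiscriminatorResults(
            query="example query",
            generator_hotkey="hk-gen",
            generator_result="example answer",
            generator_score=0.8,
            discriminator_hotkeys=["hk-d1", "hk-d2"],
            discriminator_results=["0", "1"],
            discriminator_scores=[1.0, 0.0],
        )
    )

    # Drains the queue, commits, truncates the WAL, and stops the loop.
    await logger_db.shutdown()
    await writer

asyncio.run(main())
```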
58 changes: 48 additions & 10 deletions apex/validator/miner_scorer.py
@@ -1,7 +1,7 @@
import asyncio
import json
import time
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Iterable
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
@@ -37,13 +37,13 @@ def __init__(
async def start_loop(self) -> None:
self._running = True
while self._running:
await asyncio.sleep(self.interval)
logger.debug("Attempting to set weights")
success = await self.set_scores()
if success:
logger.info("Successfully set weights")
else:
logger.error("Failed to set weights")
await asyncio.sleep(self.interval)

async def shutdown(self) -> None:
self._running = False
@@ -52,9 +52,37 @@ async def shutdown(self) -> None:
@asynccontextmanager
async def _db() -> AsyncGenerator[aiosqlite.Connection, None]:
async with aiosqlite.connect("results.db") as conn:
await conn.execute("PRAGMA foreign_keys = ON")
await conn.execute("PRAGMA journal_mode=WAL")
await conn.execute("PRAGMA synchronous=NORMAL")
await conn.execute("PRAGMA busy_timeout=15000")
await conn.execute("PRAGMA foreign_keys=ON")
yield conn
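
Note these pragmas are applied per connection — only `journal_mode=WAL` persists in the database file; `synchronous`, `busy_timeout`, and `foreign_keys` reset on each new connection. A quick sketch to verify they took effect (plain aiosqlite, assuming `results.db` exists):

```python
import asyncio

import aiosqlite

async def show_pragmas(path: str = "results.db") -> None:
    async with aiosqlite.connect(path) as conn:
        await conn.execute("PRAGMA busy_timeout=15000")
        # journal_mode persists once set; reading it back confirms WAL.
        async with conn.execute("PRAGMA journal_mode") as cur:
            row = await cur.fetchone()
        print("journal_mode:", row[0])  # expect 'wal' after _db() has run

asyncio.run(show_pragmas())
```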

async def _delete_expired(self, conn: aiosqlite.Connection, cutoff_ts: int, batch_size: int = 5000) -> int:
total = 0
while True:
cur = await conn.execute(
"""
DELETE FROM discriminator_results
WHERE rowid IN (
SELECT rowid FROM discriminator_results
WHERE timestamp < ?
LIMIT ?
)
""",
(cutoff_ts, batch_size),
)
n = cur.rowcount if cur.rowcount is not None else 0
total += max(n, 0)

# Commit each batch to release the write lock early.
await conn.commit()
logger.debug(f"Deleted batch: {n} (total deleeted: {total})")

if n < batch_size:
break
return total
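
Deleting in LIMITed batches and committing after each one keeps the write lock short-lived, so the logger's writer loop is never starved for long; a single unbatched DELETE would hold the lock for the whole scan. The same pattern with the stdlib sqlite3 driver — a sketch, assuming the same schema, handy for exercising the query in isolation:

```python
import sqlite3

def delete_expired(conn: sqlite3.Connection, cutoff_ts: int,
                   batch_size: int = 5000) -> int:
    """Delete rows older than cutoff_ts in small batches, committing
    after each batch so other writers can interleave."""
    total = 0
    while True:
        cur = conn.execute(
            "DELETE FROM discriminator_results WHERE rowid IN ("
            "  SELECT rowid FROM discriminator_results"
            "  WHERE timestamp < ? LIMIT ?)",
            (cutoff_ts, batch_size),
        )
        conn.commit()  # release the write lock between batches
        total += cur.rowcount
        if cur.rowcount < batch_size:
            return total
```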

async def set_scores(self) -> bool:
"""Set weights based on the current miner scores.

@@ -77,15 +105,16 @@ async def set_scores(self) -> bool:
""",
(cutoff_timestamp,),
) as cursor:
rows = await cursor.fetchall()
rows: Iterable[aiosqlite.Row] = await cursor.fetchall()
except BaseException as exc:
logger.exception(f"Exception during DB fetch: {exc}")
return False

# 2. Iterate over the in-memory list so that the caller can process freely.
logger.debug("Pre-processing miner's rewards")
hkey_agg_rewards: dict[str, float] = {}
rows_count = 0
for generator_hotkey, generator_score, disc_hotkeys_json, disc_scores_json in rows:
rows_count += 1
# Deserialize JSON columns.
disc_hotkeys = json.loads(disc_hotkeys_json)
disc_scores = json.loads(disc_scores_json)
@@ -101,12 +130,16 @@
for hotkey, reward in reward_dict.items():
hkey_agg_rewards[hotkey] = float(hkey_agg_rewards.get(hotkey, 0.0)) + float(reward)

logger.debug(f"Fetched {rows_count} rows for scoring")
logger.debug(f"Total hotkeys to score: {len(hkey_agg_rewards)}")

# 3. Delete rows that are older than the time window.
logger.debug("Cleaning up miner's outdated history")
await conn.execute(
"DELETE FROM discriminator_results WHERE timestamp < ?",
(cutoff_timestamp,),
)
logger.debug("Cleaning up expired miner's history")
try:
deleted_total = await asyncio.wait_for(self._delete_expired(conn, cutoff_timestamp), timeout=15)
logger.debug(f"Expired rows cleanup done: {deleted_total} rows")
except TimeoutError:
logger.warning("Timed out deleting expired rows; will retry next loop")

if self._debug:
record: dict[str, str | dict[str, float]] = {
@@ -118,13 +151,18 @@
fh.write(f"{record_str}\n")

if self._weight_syncer is not None:
logger.debug("Attempting to perform weight synchronization")
try:
hkey_agg_rewards = await self._weight_syncer.compute_weighted_rewards(hkey_agg_rewards)
logger.debug(f"Total hotkeys to score after weight sync: {len(hkey_agg_rewards)}")
except BaseException as exc:
logger.error(f"Failed to compute weighted average rewards over the network, skipping: {exc}")

if hkey_agg_rewards:
rewards_array = np.array(list(hkey_agg_rewards.values()))
if rewards_array.min() < 0:
logger.warning(f"Negative reward detected: {rewards_array.min():.4f}, assigning zero value instead")
hkey_agg_rewards = {hkey: max(reward, 0) for hkey, reward in hkey_agg_rewards.items()}
logger.debug(
f"Setting weights to {len(hkey_agg_rewards)} hotkeys; "
f"reward mean={rewards_array.mean():.4f} min={rewards_array.min():.4f}"
13 changes: 11 additions & 2 deletions tests/validator/test_miner_scorer.py
Expand Up @@ -5,7 +5,7 @@
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, call, patch

import aiosqlite
import pytest
@@ -408,7 +408,16 @@ async def test_db_context_manager(self) -> None:
assert conn is mock_conn

mock_connect.assert_called_once_with("results.db")
mock_conn.execute.assert_called_once_with("PRAGMA foreign_keys = ON")

mock_conn.execute.assert_has_calls(
[
call("PRAGMA journal_mode=WAL"),
call("PRAGMA synchronous=NORMAL"),
call("PRAGMA busy_timeout=15000"),
call("PRAGMA foreign_keys=ON"),
]
)
assert mock_conn.execute.call_count == 4
finally:
Path(db_path).unlink()
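
Worth noting why the test pairs `assert_has_calls` with an exact count: `assert_has_calls` only requires the listed calls to appear consecutively somewhere in the call history, so extra calls before or after would still pass. A small illustration:

```python
from unittest.mock import MagicMock, call

m = MagicMock()
m("PRAGMA journal_mode=WAL")
m("PRAGMA synchronous=NORMAL")
m("something unexpected")

# Passes: the two PRAGMA calls appear in order; the stray call is ignored.
m.assert_has_calls([
    call("PRAGMA journal_mode=WAL"),
    call("PRAGMA synchronous=NORMAL"),
])

# An exact call-count assertion is what flags the stray call.
assert m.call_count == 3  # would be 2 without the unexpected call
```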
