1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
*.log
requirements.txt
**/*.ipynb
debug/
debug_rewards.jsonl
results.db*
sn13_db.db*
119 changes: 85 additions & 34 deletions apex/validator/logger_db.py
@@ -5,77 +5,128 @@
from typing import Literal, TypedDict

import aiosqlite
from loguru import logger

from apex.common.models import MinerDiscriminatorResults


class _DiscriminatorItem(TypedDict):
class DiscriminatorItem(TypedDict):
"""Item placed on the queue that represents one discriminator result row."""

kind: Literal["discriminator"]
data: MinerDiscriminatorResults


class LoggerDB:
_COMMIT_FREQ = 60
_COMMIT_CHANGES = 1000

def __init__(self, db_path: Path | str = "results.db"):
self.db_path = Path(db_path)
self._queue: asyncio.Queue[_DiscriminatorItem | object] = asyncio.Queue(maxsize=10_000)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._queue: asyncio.Queue[DiscriminatorItem | object] = asyncio.Queue(maxsize=10_000)
self._SHUTDOWN = object()
self._closing = asyncio.Event()

async def start_loop(self) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.executescript(
"""
PRAGMA journal_mode = WAL;
PRAGMA journal_mode=WAL;
PRAGMA wal_autocheckpoint=1000;
PRAGMA synchronous=NORMAL;
PRAGMA busy_timeout=5000;

-- Results coming back from a discriminator step.
-- Store *all* fields from DiscriminatorQueryResults; list fields are serialized as JSON.
CREATE TABLE IF NOT EXISTS discriminator_results (
query TEXT,
generator_hotkey TEXT,
generator_result TEXT,
generator_score REAL,
query TEXT,
generator_hotkey TEXT,
generator_result TEXT,
generator_score REAL,
discriminator_hotkeys TEXT, -- JSON array of strings
discriminator_results TEXT, -- JSON array of strings
discriminator_scores TEXT, -- JSON array of floats
timestamp INTEGER, -- Unix timestamp when row was added
processed INTEGER DEFAULT 0,
discriminator_scores TEXT, -- JSON array of floats
timestamp INTEGER, -- Unix timestamp when row was added
processed INTEGER DEFAULT 0,
PRIMARY KEY (query, generator_hotkey)
);

CREATE INDEX IF NOT EXISTS idx_discriminator_processed
ON discriminator_results(processed);
"""
)
await db.commit()

last_commit = time.monotonic()
last_changes = db.total_changes
while True:
item = await self._queue.get()
try:
item = await asyncio.wait_for(self._queue.get(), timeout=20.0)
except TimeoutError:
if db.total_changes != last_changes:
await db.commit()
last_commit = time.monotonic()
last_changes = db.total_changes
continue

if item is self._SHUTDOWN:
self._queue.task_done()
await db.commit()
await db.execute("PRAGMA wal_checkpoint(TRUNCATE);")
break

if isinstance(item, dict) and item.get("kind") == "discriminator":
row: MinerDiscriminatorResults = item["data"]

await db.execute(
"INSERT OR REPLACE INTO discriminator_results VALUES (?,?,?,?,?,?,?,?,0)",
(
row.query,
row.generator_hotkey,
row.generator_result,
row.generator_score,
json.dumps(row.discriminator_hotkeys),
json.dumps(row.discriminator_results),
json.dumps(row.discriminator_scores),
int(time.time()), # Current Unix timestamp
),
)

# flush every 1 000 rows or on demand
if self._queue.empty() or db.total_changes % 1000 == 0:
try:
await self.add_entry(db, item=item)
except Exception:
await db.rollback()
finally:
self._queue.task_done()

commit_changes = (db.total_changes - last_changes) >= self._COMMIT_CHANGES
commit_timer = time.monotonic() - last_commit >= self._COMMIT_FREQ
if commit_changes or commit_timer:
logger.debug(f"Commiting scores to the {self.db_path}")
await db.commit()
await db.execute("PRAGMA wal_checkpoint(FULL);")
self._queue.task_done()
last_commit = time.monotonic()
last_changes = db.total_changes

async def add_entry(self, db: aiosqlite.Connection, item: DiscriminatorItem | object) -> None:
if isinstance(item, dict) and item.get("kind") == "discriminator":
row: MinerDiscriminatorResults = item["data"]

await db.execute(
"""
INSERT INTO discriminator_results (
query, generator_hotkey, generator_result, generator_score,
discriminator_hotkeys, discriminator_results, discriminator_scores, timestamp
) VALUES (?,?,?,?,?,?,?,?)
ON CONFLICT(query, generator_hotkey) DO UPDATE SET
generator_result = excluded.generator_result,
generator_score = excluded.generator_score,
discriminator_hotkeys = excluded.discriminator_hotkeys,
discriminator_results = excluded.discriminator_results,
discriminator_scores = excluded.discriminator_scores,
timestamp = excluded.timestamp
""",
(
row.query,
row.generator_hotkey,
row.generator_result,
row.generator_score,
json.dumps(row.discriminator_hotkeys),
json.dumps(row.discriminator_results),
json.dumps(row.discriminator_scores),
int(time.time()),
),
)

async def log(self, row: MinerDiscriminatorResults) -> None:
if self._closing.is_set():
logger.error("Database is shutting down")
return

await self._queue.put({"kind": "discriminator", "data": row})

async def shutdown(self) -> None:
self._closing.set()
await self._queue.join()
await self._queue.put(self._SHUTDOWN)
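
The writer loop now commits in batches (every _COMMIT_FREQ seconds or _COMMIT_CHANGES row changes) instead of once per insert, and add_entry upserts on the (query, generator_hotkey) primary key, so re-logging the same pair updates the existing row rather than failing. A minimal lifecycle sketch of the intended usage, assuming MinerDiscriminatorResults accepts the fields shown above as keyword arguments; the values are placeholders and the real model may carry more fields:

import asyncio

from apex.common.models import MinerDiscriminatorResults
from apex.validator.logger_db import LoggerDB


async def main() -> None:
    db = LoggerDB(db_path="results.db")
    # The loop owns the sqlite connection; run it as a background task.
    writer = asyncio.create_task(db.start_loop())

    # Producers only enqueue; the loop batches the actual commits.
    await db.log(
        MinerDiscriminatorResults(
            query="example query",
            generator_hotkey="hk_gen",
            generator_result="example answer",
            generator_score=1.0,
            discriminator_hotkeys=["hk_a", "hk_b"],
            discriminator_results=["0", "1"],
            discriminator_scores=[0.0, 1.0],
        )
    )

    # Drain the queue, send the shutdown sentinel, and wait for the final
    # commit and WAL checkpoint before the process exits.
    await db.shutdown()
    await writer


asyncio.run(main())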
37 changes: 37 additions & 0 deletions apex/validator/logger_local.py
@@ -0,0 +1,37 @@
import json
from datetime import datetime
from pathlib import Path

from apex.common.models import MinerDiscriminatorResults, MinerGeneratorResults


class LoggerLocal:
def __init__(self, filepath: str = "debug/logs.jsonl"):
self._debug_file_path = Path(filepath)
self._debug_file_path.parent.mkdir(exist_ok=True)

async def log(
self,
query: str,
ground_truth: int,
reference: str | None,
generator_results: MinerGeneratorResults | None,
discriminator_results: MinerDiscriminatorResults | None,
) -> None:
day = datetime.now().strftime("%Y-%m-%d")
filepath = Path(f"{self._debug_file_path.with_suffix('')}-{day}.jsonl")
record: dict[str, str | int | list[str] | list[float] | None] = {
"date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"query": query,
"ground_truth": ground_truth,
"reference": reference,
"generators": generator_results.generator_results if generator_results else [],
"generator_hotkeys": generator_results.generator_hotkeys if generator_results else [],
"discriminator_results": discriminator_results.discriminator_results if discriminator_results else [],
"discriminator_scores": discriminator_results.discriminator_scores if discriminator_results else [],
"discriminator_hotkeys": discriminator_results.discriminator_hotkeys if discriminator_results else [],
"generator_hotkey": discriminator_results.generator_hotkey if discriminator_results else "",
}

with filepath.open("a+") as fh:
fh.write(f"{json.dumps(record)}\n")
33 changes: 17 additions & 16 deletions apex/validator/miner_sampler.py
@@ -39,7 +39,8 @@ def __init__(
self,
chain: AsyncChain,
sample_mode: Literal["random", "sequential"] = "sequential",
sample_size: int = 50,
discriminator_sample_size: int = 50,
generator_sample_size: int = 1,
logger_db: LoggerDB | None = None,
available_uids: Sequence[int] | None = None,
available_addresses: Sequence[str] | None = None,
@@ -52,7 +53,8 @@
sample_mode: Sampling mode, available modes:
- random: Samples random uids.
- sequential: Samples all uids sequentially.
sample_size: Amount of miners to be sampled in one call.
discriminator_sample_size: Amount of miners to be sampled for discriminator queries.
generator_sample_size: Amount of miners to be sampled for generator queries.
logger_db: Optional logger DB object.
available_uids: List of available UIDs. If None, use all UIDs.
available_addresses: List of available addresses for given UIDs. If None, use metagraph addresses.
@@ -61,7 +63,8 @@
"""
self._chain = chain
self._sample_mode = sample_mode
self._sample_size = sample_size
self._discriminator_sample_size = discriminator_sample_size
self._generator_sample_size = generator_sample_size
self._logger_db = logger_db
self._available_uids = available_uids
self._available_addresses = available_addresses
@@ -74,7 +77,7 @@ def __init__(
self._sample_lock = asyncio.Lock()

@async_cache(_TTL_UIDS_RESYNC)
async def _get_all_miners(self) -> list[MinerInfo]:
async def _get_all_miners(self, sample_size: int) -> list[MinerInfo]:
meta = await self._chain.metagraph()
miners: list[MinerInfo] = []
for idx in range(meta.n.item()):
@@ -101,24 +104,24 @@ async def _get_all_miners(self) -> list[MinerInfo]:
miners_test.append(miner_info)
miners = miners_test

if self._sample_size > len(miners):
if sample_size > len(miners):
logger.warning(
f"Sample size is larger than amount of miners: {self._sample_size} > {len(miners)}. "
f"Sample size is larger than amount of miners: {sample_size} > {len(miners)}. "
f"Setting sample size to {len(miners)}"
)
self._sample_size = len(miners)
sample_size = len(miners)
return miners

async def _sample_miners(self) -> list[MinerInfo]:
miners = await self._get_all_miners()
async def _sample_miners(self, sample_size: int) -> list[MinerInfo]:
miners = await self._get_all_miners(sample_size=sample_size)

miners_sample: list[MinerInfo] = []
if self._sample_mode == "random":
miners_sample = random.sample(miners, self._sample_size)
miners_sample = random.sample(miners, sample_size)

elif self._sample_mode == "sequential":
async with self._sample_lock:
while len(miners_sample) < self._sample_size:
while len(miners_sample) < sample_size:
if not self._epoch_deque:
# Get shuffled deque of miners.
self._epoch_deque = deque(random.sample(miners, len(miners)))
@@ -127,9 +130,7 @@ async def _sample_miners(self) -> list[MinerInfo]:
else:
raise ValueError(f"Unknown sampling mode: {self._sample_mode}")

logger.debug(
f"Sampled uids (sample size = {self._sample_size}): {sorted([miner.uid for miner in miners_sample])}"
)
logger.debug(f"Sampled uids (sample size = {sample_size}): {sorted([miner.uid for miner in miners_sample])}")
return miners_sample

async def query_miners(
@@ -157,7 +158,7 @@ async def query_miners(

async def query_generators(self, query: str) -> MinerGeneratorResults:
"""Query the miners for the query."""
miner_information = await self._sample_miners()
miner_information = await self._sample_miners(sample_size=self._generator_sample_size)
body = {"step": "generator", "query": query}

hotkeys: list[str] = []
Expand All @@ -177,7 +178,7 @@ async def query_discriminators(
ground_truth: int,
) -> MinerDiscriminatorResults:
"""Query the miners for the query."""
miner_information = await self._sample_miners()
miner_information = await self._sample_miners(sample_size=self._discriminator_sample_size)
# Flip the coin for the generator.
if ground_truth and generator_results:
selected_generator: tuple[str, str] = random.choice(
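
The sampler now takes separate discriminator_sample_size and generator_sample_size values and threads the requested size through _sample_miners instead of reading a single shared attribute. In sequential mode it draws from a shuffled "epoch" deque that is refilled once exhausted, so every miner is visited once per epoch regardless of batch size. A self-contained sketch of that pattern, with plain ints standing in for the real MinerInfo objects (SequentialSampler is an illustrative name, not part of the codebase):

import random
from collections import deque


class SequentialSampler:
    """Visit every item once per shuffled epoch, in fixed-size batches."""

    def __init__(self, items: list[int]) -> None:
        self._items = items
        self._epoch: deque[int] = deque()

    def sample(self, sample_size: int) -> list[int]:
        batch: list[int] = []
        while len(batch) < sample_size:
            if not self._epoch:
                # Refill with a freshly shuffled copy of all items.
                self._epoch = deque(random.sample(self._items, len(self._items)))
            batch.append(self._epoch.popleft())
        return batch


sampler = SequentialSampler(list(range(10)))
print(sampler.sample(4))  # e.g. [7, 2, 9, 0]
print(sampler.sample(4))  # continues the current epoch before reshuffling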