Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions apex/validator/miner_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(
self,
chain: AsyncChain,
sample_mode: Literal["random", "sequential"] = "sequential",
sample_size: int = 50,
discriminator_sample_size: int = 50,
generator_sample_size: int = 1,
logger_db: LoggerDB | None = None,
available_uids: Sequence[int] | None = None,
available_addresses: Sequence[str] | None = None,
Expand All @@ -52,7 +53,8 @@ def __init__(
sample_mode: Sampling mode, available modes:
- random: Samples random uids.
- sequential: Samples all uids sequentially.
sample_size: Amount of miners to be samples in one call.
discriminator_sample_size: Amount of miners to be sampled for discriminator queries.
generator_sample_size: Amount of miners to be sampled for generator queries.
logger_db: Optional logger DB object.
available_uids: List of available UIDs. If None, use all UIDs.
available_addresses: List of available addresses for given UIDs. If None, use metagraph addresses.
Expand All @@ -61,7 +63,8 @@ def __init__(
"""
self._chain = chain
self._sample_mode = sample_mode
self._sample_size = sample_size
self._discriminator_sample_size = discriminator_sample_size
self._generator_sample_size = generator_sample_size
self._logger_db = logger_db
self._available_uids = available_uids
self._available_addresses = available_addresses
Expand All @@ -74,7 +77,7 @@ def __init__(
self._sample_lock = asyncio.Lock()

@async_cache(_TTL_UIDS_RESYNC)
async def _get_all_miners(self) -> list[MinerInfo]:
async def _get_all_miners(self, sample_size: int) -> list[MinerInfo]:
meta = await self._chain.metagraph()
miners: list[MinerInfo] = []
for idx in range(meta.n.item()):
Expand All @@ -101,24 +104,24 @@ async def _get_all_miners(self) -> list[MinerInfo]:
miners_test.append(miner_info)
miners = miners_test

if self._sample_size > len(miners):
if sample_size > len(miners):
logger.warning(
f"Sample size is larger than amount of miners: {self._sample_size} > {len(miners)}. "
f"Sample size is larger than amount of miners: {sample_size} > {len(miners)}. "
f"Setting sample size to {len(miners)}"
)
self._sample_size = len(miners)
sample_size = len(miners)
return miners

async def _sample_miners(self) -> list[MinerInfo]:
miners = await self._get_all_miners()
async def _sample_miners(self, sample_size: int) -> list[MinerInfo]:
miners = await self._get_all_miners(sample_size=sample_size)

miners_sample: list[MinerInfo] = []
if self._sample_mode == "random":
miners_sample = random.sample(miners, self._sample_size)
miners_sample = random.sample(miners, sample_size)

elif self._sample_mode == "sequential":
async with self._sample_lock:
while len(miners_sample) < self._sample_size:
while len(miners_sample) < (sample_size):
if not self._epoch_deque:
# Get shuffled deque of miners.
self._epoch_deque = deque(random.sample(miners, len(miners)))
Expand All @@ -127,9 +130,7 @@ async def _sample_miners(self) -> list[MinerInfo]:
else:
raise ValueError(f"Unknown sampling mode: {self._sample_mode}")

logger.debug(
f"Sampled uids (sample size = {self._sample_size}): {sorted([miner.uid for miner in miners_sample])}"
)
logger.debug(f"Sampled uids (sample size = {sample_size}): {sorted([miner.uid for miner in miners_sample])}")
return miners_sample

async def query_miners(
Expand Down Expand Up @@ -157,7 +158,7 @@ async def query_miners(

async def query_generators(self, query: str) -> MinerGeneratorResults:
"""Query the miners for the query."""
miner_information = await self._sample_miners()
miner_information = await self._sample_miners(sample_size=self._generator_sample_size)
body = {"step": "generator", "query": query}

hotkeys: list[str] = []
Expand All @@ -177,7 +178,7 @@ async def query_discriminators(
ground_truth: int,
) -> MinerDiscriminatorResults:
"""Query the miners for the query."""
miner_information = await self._sample_miners()
miner_information = await self._sample_miners(sample_size=self._discriminator_sample_size)
# Flip the coin for the generator.
if ground_truth and generator_results:
selected_generator: tuple[str, str] = random.choice(
Expand Down
20 changes: 10 additions & 10 deletions tests/validator/test_miner_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_miner_info_hash() -> None:
@pytest.mark.asyncio
async def test_get_all_miners(miner_sampler: MinerSampler, mock_metagraph: MockMetagraph) -> None:
"""Tests that all miners are returned."""
miners = await miner_sampler._get_all_miners()
miners = await miner_sampler._get_all_miners(sample_size=3)
assert len(miners) == 3
uids = {m.uid for m in miners}
assert uids == {1, 3, 5}
Expand All @@ -110,7 +110,7 @@ async def test_get_all_miners_with_available_uids(mock_chain: MockAsyncChain) ->
available_uids=[1, 5, 10],
validator_min_stake=16000,
)
miners = await sampler._get_all_miners()
miners = await sampler._get_all_miners(sample_size=3)
assert len(miners) == 2
uids = {m.uid for m in miners}
assert uids == {1, 5}
Expand All @@ -125,7 +125,7 @@ async def test_get_all_miners_with_available_uids_and_addresses(mock_chain: Mock
available_addresses=["http://localhost:1234", "http://localhost:5678"],
validator_min_stake=16000,
)
miners = await sampler._get_all_miners()
miners = await sampler._get_all_miners(sample_size=3)
assert len(miners) == 2
miner1 = next(m for m in miners if m.uid == 1)
miner3 = next(m for m in miners if m.uid == 3)
Expand All @@ -137,7 +137,7 @@ async def test_get_all_miners_with_available_uids_and_addresses(mock_chain: Mock
async def test_sample_miners_random(miner_sampler: MinerSampler) -> None:
"""Tests that a random sample of miners is returned."""
miner_sampler._sample_mode = "random"
miner_sampler._sample_size = 2
miner_sampler._discriminator_sample_size = 2

with patch(
"random.sample",
Expand All @@ -146,10 +146,10 @@ async def test_sample_miners_random(miner_sampler: MinerSampler) -> None:
MinerInfo(hotkey="key3", uid=3, address="http://3.3.3.3:8002"),
],
) as mock_random_sample:
miners = await miner_sampler._sample_miners()
miners = await miner_sampler._sample_miners(sample_size=2)
assert len(miners) == 2
mock_random_sample.assert_called_once()
all_miners = await miner_sampler._get_all_miners()
all_miners = await miner_sampler._get_all_miners(sample_size=2)
arg_uids = {m.uid for m in mock_random_sample.call_args[0][0]}
all_uids = {m.uid for m in all_miners}
assert arg_uids == all_uids
Expand All @@ -160,9 +160,9 @@ async def test_sample_miners_random(miner_sampler: MinerSampler) -> None:
async def test_sample_miners_sequential(monkeypatch: MagicMock, miner_sampler: MinerSampler) -> None:
"""Tests that a sequential sample of miners is returned."""
miner_sampler._sample_mode = "sequential"
miner_sampler._sample_size = 2
miner_sampler._discriminator_sample_size = 2

all_miners = await miner_sampler._get_all_miners()
all_miners = await miner_sampler._get_all_miners(sample_size=2)
all_miners.sort(key=lambda m: m.uid)
monkeypatch.setattr(miner_sampler, "_get_all_miners", AsyncMock(return_value=all_miners))

Expand All @@ -171,7 +171,7 @@ async def test_sample_miners_sequential(monkeypatch: MagicMock, miner_sampler: M
"random.sample",
return_value=[MinerInfo(uid=1, address="", hotkey="1"), MinerInfo(uid=5, address="", hotkey="5")],
):
miners1 = await miner_sampler._sample_miners()
miners1 = await miner_sampler._sample_miners(sample_size=2)

assert len(miners1) == 2
assert {m.uid for m in miners1} == {all_miners[0].uid, all_miners[2].uid}
Expand All @@ -181,7 +181,7 @@ async def test_sample_miners_sequential(monkeypatch: MagicMock, miner_sampler: M
"random.sample",
return_value=[MinerInfo(uid=3, address="", hotkey="3"), MinerInfo(uid=5, address="", hotkey="5")],
):
miners2 = await miner_sampler._sample_miners()
miners2 = await miner_sampler._sample_miners(sample_size=2)

assert len(miners2) == 2
assert {m.uid for m in miners2} == {all_miners[1].uid, all_miners[2].uid}
Expand Down