diff --git a/apex/validator/miner_sampler.py b/apex/validator/miner_sampler.py index c4fc10433..b6f192aa2 100644 --- a/apex/validator/miner_sampler.py +++ b/apex/validator/miner_sampler.py @@ -39,7 +39,8 @@ def __init__( self, chain: AsyncChain, sample_mode: Literal["random", "sequential"] = "sequential", - sample_size: int = 50, + discriminator_sample_size: int = 50, + generator_sample_size: int = 1, logger_db: LoggerDB | None = None, available_uids: Sequence[int] | None = None, available_addresses: Sequence[str] | None = None, @@ -52,7 +53,8 @@ def __init__( sample_mode: Sampling mode, available modes: - random: Samples random uids. - sequential: Samples all uids sequentially. - sample_size: Amount of miners to be samples in one call. + discriminator_sample_size: Amount of miners to be sampled for discriminator queries. + generator_sample_size: Amount of miners to be sampled for generator queries. logger_db: Optional logger DB object. available_uids: List of available UIDs. If None, use all UIDs. available_addresses: List of available addresses for given UIDs. If None, use metagraph addresses. @@ -61,7 +63,8 @@ def __init__( """ self._chain = chain self._sample_mode = sample_mode - self._sample_size = sample_size + self._discriminator_sample_size = discriminator_sample_size + self._generator_sample_size = generator_sample_size self._logger_db = logger_db self._available_uids = available_uids self._available_addresses = available_addresses @@ -74,7 +77,7 @@ def __init__( self._sample_lock = asyncio.Lock() @async_cache(_TTL_UIDS_RESYNC) - async def _get_all_miners(self) -> list[MinerInfo]: + async def _get_all_miners(self, sample_size: int) -> list[MinerInfo]: meta = await self._chain.metagraph() miners: list[MinerInfo] = [] for idx in range(meta.n.item()): @@ -101,24 +104,24 @@ async def _get_all_miners(self) -> list[MinerInfo]: miners_test.append(miner_info) miners = miners_test - if self._sample_size > len(miners): + if sample_size > len(miners): logger.warning( - f"Sample size is larger than amount of miners: {self._sample_size} > {len(miners)}. " + f"Sample size is larger than amount of miners: {sample_size} > {len(miners)}. " f"Setting sample size to {len(miners)}" ) - self._sample_size = len(miners) + sample_size = len(miners) return miners - async def _sample_miners(self) -> list[MinerInfo]: - miners = await self._get_all_miners() + async def _sample_miners(self, sample_size: int) -> list[MinerInfo]: + miners = await self._get_all_miners(sample_size=sample_size) miners_sample: list[MinerInfo] = [] if self._sample_mode == "random": - miners_sample = random.sample(miners, self._sample_size) + miners_sample = random.sample(miners, sample_size) elif self._sample_mode == "sequential": async with self._sample_lock: - while len(miners_sample) < self._sample_size: + while len(miners_sample) < (sample_size): if not self._epoch_deque: # Get shuffled deque of miners. self._epoch_deque = deque(random.sample(miners, len(miners))) @@ -127,9 +130,7 @@ async def _sample_miners(self) -> list[MinerInfo]: else: raise ValueError(f"Unknown sampling mode: {self._sample_mode}") - logger.debug( - f"Sampled uids (sample size = {self._sample_size}): {sorted([miner.uid for miner in miners_sample])}" - ) + logger.debug(f"Sampled uids (sample size = {sample_size}): {sorted([miner.uid for miner in miners_sample])}") return miners_sample async def query_miners( @@ -157,7 +158,7 @@ async def query_miners( async def query_generators(self, query: str) -> MinerGeneratorResults: """Query the miners for the query.""" - miner_information = await self._sample_miners() + miner_information = await self._sample_miners(sample_size=self._generator_sample_size) body = {"step": "generator", "query": query} hotkeys: list[str] = [] @@ -177,7 +178,7 @@ async def query_discriminators( ground_truth: int, ) -> MinerDiscriminatorResults: """Query the miners for the query.""" - miner_information = await self._sample_miners() + miner_information = await self._sample_miners(sample_size=self._discriminator_sample_size) # Flip the coin for the generator. if ground_truth and generator_results: selected_generator: tuple[str, str] = random.choice( diff --git a/tests/validator/test_miner_sampler.py b/tests/validator/test_miner_sampler.py index 8a4b2fcb1..2195abbcc 100644 --- a/tests/validator/test_miner_sampler.py +++ b/tests/validator/test_miner_sampler.py @@ -91,7 +91,7 @@ def test_miner_info_hash() -> None: @pytest.mark.asyncio async def test_get_all_miners(miner_sampler: MinerSampler, mock_metagraph: MockMetagraph) -> None: """Tests that all miners are returned.""" - miners = await miner_sampler._get_all_miners() + miners = await miner_sampler._get_all_miners(sample_size=3) assert len(miners) == 3 uids = {m.uid for m in miners} assert uids == {1, 3, 5} @@ -110,7 +110,7 @@ async def test_get_all_miners_with_available_uids(mock_chain: MockAsyncChain) -> available_uids=[1, 5, 10], validator_min_stake=16000, ) - miners = await sampler._get_all_miners() + miners = await sampler._get_all_miners(sample_size=3) assert len(miners) == 2 uids = {m.uid for m in miners} assert uids == {1, 5} @@ -125,7 +125,7 @@ async def test_get_all_miners_with_available_uids_and_addresses(mock_chain: Mock available_addresses=["http://localhost:1234", "http://localhost:5678"], validator_min_stake=16000, ) - miners = await sampler._get_all_miners() + miners = await sampler._get_all_miners(sample_size=3) assert len(miners) == 2 miner1 = next(m for m in miners if m.uid == 1) miner3 = next(m for m in miners if m.uid == 3) @@ -137,7 +137,7 @@ async def test_get_all_miners_with_available_uids_and_addresses(mock_chain: Mock async def test_sample_miners_random(miner_sampler: MinerSampler) -> None: """Tests that a random sample of miners is returned.""" miner_sampler._sample_mode = "random" - miner_sampler._sample_size = 2 + miner_sampler._discriminator_sample_size = 2 with patch( "random.sample", @@ -146,10 +146,10 @@ async def test_sample_miners_random(miner_sampler: MinerSampler) -> None: MinerInfo(hotkey="key3", uid=3, address="http://3.3.3.3:8002"), ], ) as mock_random_sample: - miners = await miner_sampler._sample_miners() + miners = await miner_sampler._sample_miners(sample_size=2) assert len(miners) == 2 mock_random_sample.assert_called_once() - all_miners = await miner_sampler._get_all_miners() + all_miners = await miner_sampler._get_all_miners(sample_size=2) arg_uids = {m.uid for m in mock_random_sample.call_args[0][0]} all_uids = {m.uid for m in all_miners} assert arg_uids == all_uids @@ -160,9 +160,9 @@ async def test_sample_miners_random(miner_sampler: MinerSampler) -> None: async def test_sample_miners_sequential(monkeypatch: MagicMock, miner_sampler: MinerSampler) -> None: """Tests that a sequential sample of miners is returned.""" miner_sampler._sample_mode = "sequential" - miner_sampler._sample_size = 2 + miner_sampler._discriminator_sample_size = 2 - all_miners = await miner_sampler._get_all_miners() + all_miners = await miner_sampler._get_all_miners(sample_size=2) all_miners.sort(key=lambda m: m.uid) monkeypatch.setattr(miner_sampler, "_get_all_miners", AsyncMock(return_value=all_miners)) @@ -171,7 +171,7 @@ async def test_sample_miners_sequential(monkeypatch: MagicMock, miner_sampler: M "random.sample", return_value=[MinerInfo(uid=1, address="", hotkey="1"), MinerInfo(uid=5, address="", hotkey="5")], ): - miners1 = await miner_sampler._sample_miners() + miners1 = await miner_sampler._sample_miners(sample_size=2) assert len(miners1) == 2 assert {m.uid for m in miners1} == {all_miners[0].uid, all_miners[2].uid} @@ -181,7 +181,7 @@ async def test_sample_miners_sequential(monkeypatch: MagicMock, miner_sampler: M "random.sample", return_value=[MinerInfo(uid=3, address="", hotkey="3"), MinerInfo(uid=5, address="", hotkey="5")], ): - miners2 = await miner_sampler._sample_miners() + miners2 = await miner_sampler._sample_miners(sample_size=2) assert len(miners2) == 2 assert {m.uid for m in miners2} == {all_miners[1].uid, all_miners[2].uid}