diff --git a/baal/active/active_loop.py b/baal/active/active_loop.py index f07db47d..cc59dcf2 100644 --- a/baal/active/active_loop.py +++ b/baal/active/active_loop.py @@ -79,7 +79,10 @@ def step(self, pool=None) -> bool: indices = None if len(pool) > 0: - probs = self.get_probabilities(pool, **self.kwargs) + if isinstance(self.heuristic, heuristics.Random): + probs = np.random.uniform(low=0, high=1, size=(len(pool), 1)) + else: + probs = self.get_probabilities(pool, **self.kwargs) if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0): to_label, uncertainty = self.heuristic.get_ranks(probs) if indices is not None: diff --git a/tests/active/active_loop_test.py b/tests/active/active_loop_test.py index 527e030c..3a652339 100644 --- a/tests/active/active_loop_test.py +++ b/tests/active/active_loop_test.py @@ -1,6 +1,7 @@ import os import pickle import warnings +from unittest.mock import patch import numpy as np import pytest @@ -140,5 +141,24 @@ def test_deprecation(): assert issubclass(w[-1].category, DeprecationWarning) assert "ndata_to_label" in str(w[-1].message) + +@pytest.mark.parametrize('heur,num_get_probs', [(heuristics.Random(), 0), + (heuristics.BALD(), 1), + (heuristics.Entropy(), 1), + (heuristics.Variance(reduction='sum'), 1) + ]) +def test_get_probs(heur, num_get_probs): + dataset = ActiveLearningDataset(MyDataset(), make_unlabelled=lambda x: -1) + active_loop = ActiveLearningLoop(dataset, + get_probs_iter, + heur, + query_size=5, + dummy_param=1) + dataset.label_randomly(10) + with patch.object(active_loop, "get_probabilities") as mock_probs: + active_loop.step() + assert mock_probs.call_count == num_get_probs + + if __name__ == '__main__': pytest.main()