diff --git a/chainerrl/replay_buffers/prioritized.py b/chainerrl/replay_buffers/prioritized.py index 3519fe0c8..c77fd6f7f 100644 --- a/chainerrl/replay_buffers/prioritized.py +++ b/chainerrl/replay_buffers/prioritized.py @@ -65,7 +65,7 @@ def _clip_error(error): def weights_from_probabilities(self, probabilities, min_probability): if self.normalize_by_max == 'batch': # discard global min and compute batch min - min_probability = np.min(min_probability) + min_probability = np.min(probabilities) if self.normalize_by_max: weights = [(p / min_probability) ** -self.beta for p in probabilities] diff --git a/tests/test_replay_buffer.py b/tests/test_replay_buffer.py index 555d1165e..ea4cf9004 100644 --- a/tests/test_replay_buffer.py +++ b/tests/test_replay_buffer.py @@ -337,6 +337,52 @@ def test_append_and_sample(self): s4 = rbuf.sample(2) self.assertAlmostEqual(s4[0][0]['weight'], s4[1][0]['weight']) + def test_normalize_by_max(self): + + rbuf = replay_buffer.PrioritizedReplayBuffer( + self.capacity, + normalize_by_max=self.normalize_by_max, + error_max=1000, + num_steps=self.num_steps, + ) + + # Add 100 transitions + for i in range(100): + trans = dict(state=i, action=1, reward=2, next_state=i + 1, + next_action=1, is_state_terminal=False) + rbuf.append(**trans) + assert len(rbuf) == 100 + + def set_errors_based_on_state(rbuf, samples): + # Use the value of 'state' as an error, so that state 0 will have + # the smallest error, thus the largest weight + errors = [s[0]['state'] for s in samples] + rbuf.update_errors(errors) + + # Assign different errors to all the transitions first + samples = rbuf.sample(100) + set_errors_based_on_state(rbuf, samples) + + # Repeatedly check how weights are normalized + for i in range(100): + samples = rbuf.sample(i + 1) + # All the weights must be unique + self.assertEqual( + len(set(s[0]['weight'] for s in samples)), len(samples)) + # Now check the maximum weight in a minibatch + max_w = max([s[0]['weight'] for s in samples]) + if self.normalize_by_max == 'batch': + # Maximum weight in a minibatch must be 1 + self.assertAlmostEqual(max_w, 1) + elif self.normalize_by_max == 'memory': + # Maximum weight in a minibatch must be less than 1 unless + # the minibatch contains the transition of least error. + if any(s[0]['state'] == 0 for s in samples): + self.assertAlmostEqual(max_w, 1) + else: + self.assertLess(max_w, 1) + set_errors_based_on_state(rbuf, samples) + def test_capacity(self): capacity = self.capacity if capacity is None: