Skip to content

Commit

Permalink
Merge 7c72643 into a6c3c50
Browse files Browse the repository at this point in the history
  • Loading branch information
muupan committed Oct 25, 2019
2 parents a6c3c50 + 7c72643 commit b8d1346
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion chainerrl/replay_buffers/prioritized.py
Expand Up @@ -65,7 +65,7 @@ def _clip_error(error):
def weights_from_probabilities(self, probabilities, min_probability):
if self.normalize_by_max == 'batch':
# discard global min and compute batch min
min_probability = np.min(min_probability)
min_probability = np.min(probabilities)
if self.normalize_by_max:
weights = [(p / min_probability) ** -self.beta
for p in probabilities]
Expand Down
46 changes: 46 additions & 0 deletions tests/test_replay_buffer.py
Expand Up @@ -337,6 +337,52 @@ def test_append_and_sample(self):
s4 = rbuf.sample(2)
self.assertAlmostEqual(s4[0][0]['weight'], s4[1][0]['weight'])

def test_normalize_by_max(self):
    """Verify how sampled weights are normalized for each mode.

    The buffer is filled with 100 transitions and each one is given an
    error equal to its 'state' value, so every transition has a distinct
    error and state 0 has the smallest one (hence the largest weight).
    Minibatches of size 1..100 are then sampled and the largest weight
    in each is checked against ``self.normalize_by_max``:

    * ``'batch'``: the largest weight in the minibatch must be ~1.
    * ``'memory'``: the largest weight must be ~1 only when the
      minibatch contains state 0, and strictly below 1 otherwise.

    Any other value of ``self.normalize_by_max`` only gets the
    uniqueness check on the weights.
    """
    buf = replay_buffer.PrioritizedReplayBuffer(
        self.capacity,
        normalize_by_max=self.normalize_by_max,
        error_max=1000,
        num_steps=self.num_steps,
    )

    # Fill the buffer with 100 transitions whose states are 0..99.
    for idx in range(100):
        buf.append(
            state=idx,
            action=1,
            reward=2,
            next_state=idx + 1,
            next_action=1,
            is_state_terminal=False,
        )
    assert len(buf) == 100

    def _use_state_as_error(target, batch):
        # The 'state' value doubles as the error, so state 0 ends up with
        # the smallest error and therefore the largest weight.
        target.update_errors([item[0]['state'] for item in batch])

    # Give every stored transition its own distinct error first.
    _use_state_as_error(buf, buf.sample(100))

    # Then sample minibatches of growing size and inspect the weights.
    for batch_size in range(1, 101):
        batch = buf.sample(batch_size)
        weights = [item[0]['weight'] for item in batch]

        # No two transitions in a minibatch may share a weight.
        self.assertEqual(len(set(weights)), len(batch))

        largest = max(weights)
        mode = self.normalize_by_max
        if mode == 'batch':
            # Normalized within the minibatch: its maximum must be 1.
            self.assertAlmostEqual(largest, 1)
        elif mode == 'memory':
            # Normalized over the whole memory: the maximum reaches 1
            # only if the least-error transition (state 0) was sampled.
            holds_least_error = any(
                item[0]['state'] == 0 for item in batch)
            if holds_least_error:
                self.assertAlmostEqual(largest, 1)
            else:
                self.assertLess(largest, 1)

        # Restore the state-based errors before the next round.
        _use_state_as_error(buf, batch)

def test_capacity(self):
capacity = self.capacity
if capacity is None:
Expand Down

0 comments on commit b8d1346

Please sign in to comment.