Merge pull request #301 from muupan/replicate-prioritized-replay
Mimic the details of prioritized experience replay
toslunar committed Aug 31, 2018
2 parents db159d9 + 4d32a21 commit eb18687
Showing 2 changed files with 31 additions and 9 deletions.
30 changes: 23 additions & 7 deletions chainerrl/replay_buffer.py
@@ -182,7 +182,8 @@ class PriorityWeightError(object):
             ``False``: do not normalize.
     """
 
-    def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max):
+    def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max,
+                 error_min, error_max):
         assert 0.0 <= alpha
         assert 0.0 <= beta0 <= 1.0
         self.alpha = alpha
@@ -196,9 +197,19 @@ def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max):
             normalize_by_max = 'batch'
         assert normalize_by_max in [False, 'batch', 'memory']
         self.normalize_by_max = normalize_by_max
+        self.error_min = error_min
+        self.error_max = error_max
 
     def priority_from_errors(self, errors):
-        return [d ** self.alpha + self.eps for d in errors]
+
+        def _clip_error(error):
+            if self.error_min is not None:
+                error = max(self.error_min, error)
+            if self.error_max is not None:
+                error = min(self.error_max, error)
+            return error
+
+        return [(_clip_error(d) + self.eps) ** self.alpha for d in errors]
 
     def weights_from_probabilities(self, probabilities, min_probability):
         if self.normalize_by_max == 'batch':
@@ -227,11 +238,12 @@ class PrioritizedReplayBuffer(ReplayBuffer, PriorityWeightError):
     """
 
     def __init__(self, capacity=None,
-                 alpha=0.6, beta0=0.4, betasteps=2e5, eps=1e-8,
-                 normalize_by_max=True):
+                 alpha=0.6, beta0=0.4, betasteps=2e5, eps=0.01,
+                 normalize_by_max=True, error_min=0, error_max=1):
         self.memory = PrioritizedBuffer(capacity=capacity)
         PriorityWeightError.__init__(
-            self, alpha, beta0, betasteps, eps, normalize_by_max)
+            self, alpha, beta0, betasteps, eps, normalize_by_max,
+            error_min=error_min, error_max=error_max)
 
     def sample(self, n):
         assert len(self.memory) >= n
@@ -336,7 +348,10 @@ def __init__(self, capacity=None,
                  default_priority_func=None,
                  uniform_ratio=0,
                  wait_priority_after_sampling=True,
-                 return_sample_weights=True):
+                 return_sample_weights=True,
+                 error_min=None,
+                 error_max=None,
+                 ):
         self.current_episode = []
         self.episodic_memory = PrioritizedBuffer(
             capacity=None,
@@ -347,7 +362,8 @@ def __init__(self, capacity=None,
         self.uniform_ratio = uniform_ratio
         self.return_sample_weights = return_sample_weights
         PriorityWeightError.__init__(
-            self, alpha, beta0, betasteps, eps, normalize_by_max)
+            self, alpha, beta0, betasteps, eps, normalize_by_max,
+            error_min=error_min, error_max=error_max)
 
     def sample_episodes(self, n_episodes, max_len=None):
         """Sample n unique samples from this replay buffer"""
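
Note: the priority rule changed from ``d ** alpha + eps`` to
``(clip(d) + eps) ** alpha``: each TD-error is clipped into
[error_min, error_max] before the eps offset and the exponent are applied.
Below is a minimal standalone sketch of the new rule, assuming the defaults
introduced above (alpha=0.6, eps=0.01, error_min=0, error_max=1); the helper
name ``clipped_priorities`` is ours, not part of chainerrl:

    def clipped_priorities(errors, alpha=0.6, eps=0.01,
                           error_min=0, error_max=1):
        """Standalone mirror of PriorityWeightError.priority_from_errors."""
        def clip(e):
            if error_min is not None:
                e = max(error_min, e)
            if error_max is not None:
                e = min(error_max, e)
            return e
        return [(clip(e) + eps) ** alpha for e in errors]

    # TD-errors 5 and 10 both clip to 1 and receive identical priorities,
    # which is what the updated test below asserts via equal sample weights.
    print(clipped_priorities([0.5, 5, 10]))
    # -> [0.6676..., 1.0059..., 1.0059...]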
10 changes: 8 additions & 2 deletions tests/test_replay_buffer.py
@@ -205,7 +205,8 @@ def test_append_and_sample(self):
         capacity = self.capacity
         rbuf = replay_buffer.PrioritizedReplayBuffer(
             capacity,
-            normalize_by_max=self.normalize_by_max)
+            normalize_by_max=self.normalize_by_max,
+            error_max=5)
 
         self.assertEqual(len(rbuf), 0)
 
@@ -241,9 +242,14 @@ def test_append_and_sample(self):
         # Weights should be different for different TD-errors
         s3 = rbuf.sample(2)
         self.assertNotAlmostEqual(s3[0]['weight'], s3[1]['weight'])
-        rbuf.update_errors([3.14, 3.14])
 
+        # Weights should be equal for different but clipped TD-errors
+        rbuf.update_errors([5, 10])
+        s3 = rbuf.sample(2)
+        self.assertAlmostEqual(s3[0]['weight'], s3[1]['weight'])
+
         # Weights should be equal for the same TD-errors
+        rbuf.update_errors([3.14, 3.14])
         s4 = rbuf.sample(2)
         self.assertAlmostEqual(s4[0]['weight'], s4[1]['weight'])
 
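
Note: a hypothetical end-to-end usage sketch distilled from the test above.
Only the ``error_max`` constructor argument and the sample/update_errors
protocol come from this diff; the transition fields are illustrative
placeholders:

    from chainerrl import replay_buffer

    rbuf = replay_buffer.PrioritizedReplayBuffer(capacity=100, error_max=5)
    for i in range(2):
        rbuf.append(state=i, action=0, reward=1.0, next_state=i + 1,
                    next_action=0, is_state_terminal=False)

    batch = rbuf.sample(2)       # both entries; each dict has a 'weight' key
    rbuf.update_errors([5, 10])  # both clip to error_max=5 -> equal priorities
    batch = rbuf.sample(2)       # weights are now equal, as the test asserts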
