From 76ff6cd413406f18d4ae77d5ce6b23f15cfea70b Mon Sep 17 00:00:00 2001 From: muupan Date: Mon, 18 Jun 2018 17:12:35 +0900 Subject: [PATCH] Mimic the paper epsilon is added to (absolute) errors, not priorities default epsilon is 0.01 TD errors are clipped by [-1, 1] --- chainerrl/replay_buffer.py | 30 +++++++++++++++++++++++------- tests/test_replay_buffer.py | 9 +++++++-- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/chainerrl/replay_buffer.py b/chainerrl/replay_buffer.py index 2f66265ff..1eb8d0454 100644 --- a/chainerrl/replay_buffer.py +++ b/chainerrl/replay_buffer.py @@ -176,7 +176,8 @@ class PriorityWeightError(object): of a batch. """ - def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max): + def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max, + error_min, error_max): assert 0.0 <= alpha assert 0.0 <= beta0 <= 1.0 self.alpha = alpha @@ -187,9 +188,19 @@ def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max): self.beta_add = (1.0 - beta0) / betasteps self.eps = eps self.normalize_by_max = normalize_by_max + self.error_min = error_min + self.error_max = error_max def priority_from_errors(self, errors): - return [d ** self.alpha + self.eps for d in errors] + + def _clip_error(error): + if self.error_min is not None: + error = max(self.error_min, error) + if self.error_max is not None: + error = min(self.error_max, error) + return error + + return [(_clip_error(d) + self.eps) ** self.alpha for d in errors] def weights_from_probabilities(self, probabilities): tmp = [p for p in probabilities if p is not None] @@ -217,11 +228,12 @@ class PrioritizedReplayBuffer(ReplayBuffer, PriorityWeightError): """ def __init__(self, capacity=None, - alpha=0.6, beta0=0.4, betasteps=2e5, eps=1e-8, - normalize_by_max=True): + alpha=0.6, beta0=0.4, betasteps=2e5, eps=0.01, + normalize_by_max=True, error_min=0, error_max=1): self.memory = PrioritizedBuffer(capacity=capacity) PriorityWeightError.__init__( - self, alpha, beta0, 
betasteps, eps, normalize_by_max) + self, alpha, beta0, betasteps, eps, normalize_by_max, + error_min=error_min, error_max=error_max) def sample(self, n): assert len(self.memory) >= n @@ -326,7 +338,10 @@ def __init__(self, capacity=None, default_priority_func=None, uniform_ratio=0, wait_priority_after_sampling=True, - return_sample_weights=True): + return_sample_weights=True, + error_min=None, + error_max=None, + ): self.current_episode = [] self.episodic_memory = PrioritizedBuffer( capacity=None, @@ -337,7 +352,8 @@ def __init__(self, capacity=None, self.uniform_ratio = uniform_ratio self.return_sample_weights = return_sample_weights PriorityWeightError.__init__( - self, alpha, beta0, betasteps, eps, normalize_by_max) + self, alpha, beta0, betasteps, eps, normalize_by_max, + error_min=error_min, error_max=error_max) def sample_episodes(self, n_episodes, max_len=None): """Sample n unique samples from this replay buffer""" diff --git a/tests/test_replay_buffer.py b/tests/test_replay_buffer.py index 414380fd6..1e84c951c 100644 --- a/tests/test_replay_buffer.py +++ b/tests/test_replay_buffer.py @@ -202,7 +202,7 @@ class TestPrioritizedReplayBuffer(unittest.TestCase): def test_append_and_sample(self): capacity = self.capacity - rbuf = replay_buffer.PrioritizedReplayBuffer(capacity) + rbuf = replay_buffer.PrioritizedReplayBuffer(capacity, error_max=5) self.assertEqual(len(rbuf), 0) @@ -238,9 +238,14 @@ def test_append_and_sample(self): # Weights should be different for different TD-errors s3 = rbuf.sample(2) self.assertNotAlmostEqual(s3[0]['weight'], s3[1]['weight']) - rbuf.update_errors([3.14, 3.14]) + + # Weights should be equal for different but clipped TD-errors + rbuf.update_errors([5, 10]) + s3 = rbuf.sample(2) + self.assertAlmostEqual(s3[0]['weight'], s3[1]['weight']) # Weights should be equal for the same TD-errors + rbuf.update_errors([3.14, 3.14]) s4 = rbuf.sample(2) self.assertAlmostEqual(s4[0]['weight'], s4[1]['weight'])