Skip to content

Commit

Permalink
Merge branch 'master' into replicate-prioritized-replay
Browse files Browse the repository at this point in the history
  • Loading branch information
muupan committed Aug 31, 2018
2 parents 7214d4a + db159d9 commit 4d32a21
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 28 deletions.
43 changes: 32 additions & 11 deletions chainerrl/misc/prioritized.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def __init__(self, capacity=None, wait_priority_after_sampling=True,
initial_max_priority=1.0):
self.capacity = capacity
self.data = collections.deque()
self.data_priority = SumTreeQueue()
self.priority_sums = SumTreeQueue()
self.priority_mins = MinTreeQueue()
self.max_priority = initial_max_priority
self.wait_priority_after_sampling = wait_priority_after_sampling
self.flag_wait_priority = False
Expand All @@ -34,38 +35,44 @@ def append(self, value, priority=None):
priority = self.max_priority

self.data.append(value)
self.data_priority.append(priority)
self.priority_sums.append(priority)
self.priority_mins.append(priority)

def popleft(self):
    """Remove and return the oldest stored value.

    The matching entries in both priority trees are removed as well so
    that they stay index-aligned with the data queue.
    """
    assert len(self) > 0
    # Drop the stale pre-merge `self.data_priority.popleft()` line that the
    # scrape left interleaved with its two replacements below.
    self.priority_sums.popleft()
    self.priority_mins.popleft()
    return self.data.popleft()

def _sample_indices_and_probabilities(self, n, uniform_ratio):
total_priority = self.data_priority.sum()
total_priority = self.priority_sums.sum()
min_prob = self.priority_mins.min() / total_priority
indices = []
priorities = []
if uniform_ratio > 0:
# Mix uniform samples and prioritized samples
n_uniform = np.random.binomial(n, uniform_ratio)
un_indices, un_priorities = \
self.data_priority.uniform_sample(
self.priority_sums.uniform_sample(
n_uniform, remove=self.wait_priority_after_sampling)
indices.extend(un_indices)
priorities.extend(un_priorities)
n -= n_uniform
min_prob = uniform_ratio / len(self) \
+ (1 - uniform_ratio) * min_prob

pr_indices, pr_priorities = \
self.data_priority.prioritized_sample(
self.priority_sums.prioritized_sample(
n, remove=self.wait_priority_after_sampling)
indices.extend(pr_indices)
priorities.extend(pr_priorities)

return indices, [
probs = [
uniform_ratio / len(self)
+ (1 - uniform_ratio) * pri / total_priority
for pri in priorities
]
return indices, probs, min_prob

def sample(self, n, uniform_ratio=0):
"""Sample data along with their corresponding probabilities.
Expand All @@ -79,20 +86,22 @@ def sample(self, n, uniform_ratio=0):
"""
assert (not self.wait_priority_after_sampling or
not self.flag_wait_priority)
indices, probabilities = self._sample_indices_and_probabilities(
n, uniform_ratio=uniform_ratio)
indices, probabilities, min_prob = \
self._sample_indices_and_probabilities(
n, uniform_ratio=uniform_ratio)
sampled = [self.data[i] for i in indices]
self.sampled_indices = indices
self.flag_wait_priority = True
return sampled, probabilities
return sampled, probabilities, min_prob

def set_last_priority(self, priority):
    """Assign new priorities to the entries sampled most recently.

    Args:
        priority (sequence of float): Positive priorities, one per index
            returned by the preceding ``sample`` call.
    """
    assert (not self.wait_priority_after_sampling or
            self.flag_wait_priority)
    assert all([p > 0.0 for p in priority])
    assert len(self.sampled_indices) == len(priority)
    for i, p in zip(self.sampled_indices, priority):
        # Update both trees for each index so sum and min stay consistent
        # (the stale pre-merge `self.data_priority[i] = p` line is dropped).
        self.priority_sums[i] = p
        self.priority_mins[i] = p
        self.max_priority = max(self.max_priority, p)
    self.flag_wait_priority = False
    self.sampled_indices = []
Expand Down Expand Up @@ -277,6 +286,18 @@ def prioritized_sample(self, n, remove):
return ixs, vals


class MinTreeQueue(TreeQueue):
    """Queue that maintains a running minimum of its values.

    Delegates storage to ``TreeQueue`` with ``min`` as the aggregation
    operator, so the tree root carries the smallest stored value.
    """

    def __init__(self):
        super().__init__(op=min)

    def min(self):
        """Return the smallest stored value, or ``+inf`` when empty."""
        return np.inf if self.length == 0 else self.root[2]


# Deprecated
class SumTree (object):
"""Fast weighted sampling.
Expand Down
40 changes: 25 additions & 15 deletions chainerrl/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,20 @@ def stop_current_episode(self):
class PriorityWeightError(object):
"""For propotional prioritization
alpha determines how much prioritization is used.
beta determines how much importance sampling weights are used. beta is
scheduled by ``beta0`` and ``betasteps``.
Args:
alpha (float): A hyperparameter that determines how much
prioritization is used
beta0, betasteps (float): Schedule of beta. beta determines how much
importance sampling weights are used.
alpha (float): Exponent of errors to compute probabilities to sample
beta0 (float): Initial value of beta
betasteps (float): Steps to anneal beta to 1
eps (float): To revisit a step after its error becomes near zero
normalize_by_max (bool): normalize weights by maximum priority
of a batch.
normalize_by_max (str): Method to normalize weights. ``'batch'`` or
``True`` (default): divide by the maximum weight in the sampled
batch. ``'memory'``: divide by the maximum weight in the memory.
``False``: do not normalize.
"""

def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max,
Expand All @@ -187,6 +193,9 @@ def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max,
else:
self.beta_add = (1.0 - beta0) / betasteps
self.eps = eps
if normalize_by_max is True:
normalize_by_max = 'batch'
assert normalize_by_max in [False, 'batch', 'memory']
self.normalize_by_max = normalize_by_max
self.error_min = error_min
self.error_max = error_max
Expand All @@ -202,12 +211,13 @@ def _clip_error(error):

return [(_clip_error(d) + self.eps) ** self.alpha for d in errors]

def weights_from_probabilities(self, probabilities, min_probability):
    """Compute importance-sampling weights from sampling probabilities.

    Args:
        probabilities (list of float): Sampling probability of each
            sampled entry.
        min_probability (float): Minimum sampling probability over the
            whole memory.

    Returns:
        list of float: Importance-sampling weights, normalized according
        to ``self.normalize_by_max`` (``'batch'``, ``'memory'`` or False).
    """
    if self.normalize_by_max == 'batch':
        # Discard the global minimum and use the batch minimum instead.
        # Bug fix: `np.min(min_probability)` was a no-op on the scalar,
        # which made 'batch' behave exactly like 'memory'; the comment's
        # intent is the minimum over the sampled batch.
        min_probability = np.min(probabilities)
    if self.normalize_by_max:
        weights = [(p / min_probability) ** -self.beta
                   for p in probabilities]
    else:
        weights = [(len(self.memory) * p) ** -self.beta
                   for p in probabilities]
    # NOTE(review): tail reconstructed from context (folded in the diff);
    # anneal beta toward 1 per the betasteps schedule set in __init__.
    self.beta = min(1.0, self.beta + self.beta_add)
    return weights
Expand Down Expand Up @@ -237,8 +247,8 @@ def __init__(self, capacity=None,

def sample(self, n):
    """Sample ``n`` experiences and attach importance-sampling weights.

    Each sampled experience dict receives a ``'weight'`` key holding its
    importance-sampling weight.

    Args:
        n (int): Number of experiences to sample.

    Returns:
        list of dict: Sampled experiences with ``'weight'`` set.
    """
    assert len(self.memory) >= n
    # Post-merge version: memory.sample now also returns the global
    # minimum probability needed for weight normalization.
    sampled, probabilities, min_prob = self.memory.sample(n)
    weights = self.weights_from_probabilities(probabilities, min_prob)
    for e, w in zip(sampled, weights):
        e['weight'] = w
    return sampled
Expand Down Expand Up @@ -358,12 +368,12 @@ def __init__(self, capacity=None,
def sample_episodes(self, n_episodes, max_len=None):
    """Sample n unique episodes from this replay buffer.

    Args:
        n_episodes (int): Number of episodes to sample.
        max_len (int or None): If given, each episode is truncated to a
            random subsequence of at most this length.

    Returns:
        Episodes, or ``(episodes, weights)`` when
        ``self.return_sample_weights`` is true.
    """
    assert len(self.episodic_memory) >= n_episodes
    # Post-merge version: episodic_memory.sample now also returns the
    # global minimum probability for weight normalization.
    episodes, probabilities, min_prob = self.episodic_memory.sample(
        n_episodes, uniform_ratio=self.uniform_ratio)
    if max_len is not None:
        episodes = [random_subseq(ep, max_len) for ep in episodes]
    if self.return_sample_weights:
        weights = self.weights_from_probabilities(probabilities, min_prob)
        return episodes, weights
    else:
        return episodes
Expand Down
2 changes: 1 addition & 1 deletion tests/misc_tests/test_prioritized.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def priority(x, n):
return priority_init[x] / count_sampled[x]

for t in range(200):
sampled, probabilities = \
sampled, probabilities, _ = \
buf.sample(16, uniform_ratio=self.uniform_ratio)
priority_old = [priority(x, count_sampled[x]) for x in sampled]
if self.uniform_ratio == 0:
Expand Down
9 changes: 8 additions & 1 deletion tests/test_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,17 @@ def test_save_and_load(self):
@testing.parameterize(*testing.product(
{
'capacity': [100, None],
'normalize_by_max': ['batch', 'memory'],
}
))
class TestPrioritizedReplayBuffer(unittest.TestCase):

def test_append_and_sample(self):
capacity = self.capacity
rbuf = replay_buffer.PrioritizedReplayBuffer(capacity, error_max=5)
rbuf = replay_buffer.PrioritizedReplayBuffer(
capacity,
normalize_by_max=self.normalize_by_max,
error_max=5)

self.assertEqual(len(rbuf), 0)

Expand Down Expand Up @@ -322,13 +326,15 @@ def exp_return_of_episode(episode):
@testing.parameterize(*(
testing.product({
'capacity': [100],
'normalize_by_max': ['batch', 'memory'],
'wait_priority_after_sampling': [False],
'default_priority_func': [exp_return_of_episode],
'uniform_ratio': [0, 0.1, 1.0],
'return_sample_weights': [True, False],
}) +
testing.product({
'capacity': [100],
'normalize_by_max': ['batch', 'memory'],
'wait_priority_after_sampling': [True],
'default_priority_func': [None, exp_return_of_episode],
'uniform_ratio': [0, 0.1, 1.0],
Expand All @@ -340,6 +346,7 @@ class TestPrioritizedEpisodicReplayBuffer(unittest.TestCase):
def test_append_and_sample(self):
rbuf = replay_buffer.PrioritizedEpisodicReplayBuffer(
capacity=self.capacity,
normalize_by_max=self.normalize_by_max,
default_priority_func=self.default_priority_func,
uniform_ratio=self.uniform_ratio,
wait_priority_after_sampling=self.wait_priority_after_sampling,
Expand Down

0 comments on commit 4d32a21

Please sign in to comment.