-
Notifications
You must be signed in to change notification settings - Fork 226
/
prioritized.py
120 lines (100 loc) · 4.51 KB
/
prioritized.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA
import collections
import numpy as np
from chainerrl.misc.prioritized import PrioritizedBuffer
from chainerrl.replay_buffers.replay_buffer import ReplayBuffer # NOQA
class PriorityWeightError(object):
    """Priority and importance-weight arithmetic for proportional replay.

    ``alpha`` controls how strongly errors are turned into priorities and
    ``beta`` controls how strongly importance-sampling weights correct for
    the resulting non-uniform sampling.  ``beta`` starts at ``beta0`` and is
    increased by ``(1 - beta0) / betasteps`` on every call to
    ``weights_from_probabilities`` until it reaches 1.

    When ``normalize_by_max`` is ``False`` the weights are computed from
    ``len(self.memory)``; this class does not set ``memory`` itself, so that
    attribute must be supplied by whatever object this class is mixed into.

    Args:
        alpha (float): Exponent of errors to compute probabilities to sample.
            Must be non-negative.
        beta0 (float): Initial value of beta, within ``[0, 1]``.
        betasteps (float or None): Steps to anneal beta to 1.  ``None``
            keeps beta fixed at ``beta0``.
        eps (float): Added to every (clipped) error so that a step is
            revisited even after its error becomes near zero.
        normalize_by_max (bool or str): Method to normalize weights.
            ``'batch'`` or ``True``: divide by the maximum weight in the
            sampled batch.  ``'memory'``: divide by the maximum weight in
            the memory.  ``False``: do not normalize.
        error_min (float or None): Lower clipping bound for errors;
            ``None`` disables it.
        error_max (float or None): Upper clipping bound for errors;
            ``None`` disables it.
    """

    def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max,
                 error_min, error_max):
        assert 0.0 <= alpha
        assert 0.0 <= beta0 <= 1.0
        self.alpha = alpha
        self.beta = beta0
        # Per-call increment of beta; zero means "never anneal".
        self.beta_add = (
            0 if betasteps is None else (1.0 - beta0) / betasteps)
        self.eps = eps
        # ``True`` is accepted as an alias of per-batch normalization.
        mode = 'batch' if normalize_by_max is True else normalize_by_max
        assert mode in (False, 'batch', 'memory')
        self.normalize_by_max = mode
        self.error_min = error_min
        self.error_max = error_max

    def priority_from_errors(self, errors):
        """Convert errors into priorities.

        Each error is clipped to ``[error_min, error_max]`` (a bound that is
        ``None`` is skipped), offset by ``eps`` and raised to ``alpha``.

        Args:
            errors (iterable of float): Errors to convert.

        Returns:
            list: One priority per input error, in the same order.
        """
        lower, upper = self.error_min, self.error_max
        offset, exponent = self.eps, self.alpha
        priorities = []
        for err in errors:
            if lower is not None:
                err = max(lower, err)
            if upper is not None:
                err = min(upper, err)
            priorities.append((err + offset) ** exponent)
        return priorities

    def weights_from_probabilities(self, probabilities, min_probability):
        """Compute importance-sampling weights and advance beta one step.

        Args:
            probabilities (sequence of float): Sampling probabilities of the
                drawn items.
            min_probability (float): Minimum probability supplied by the
                caller.  Used only in ``'memory'`` mode; in ``'batch'`` mode
                it is replaced by the minimum of ``probabilities``.

        Returns:
            list: One weight per probability, in the same order.
        """
        mode = self.normalize_by_max
        if mode == 'batch':
            # discard the caller-supplied min and use the batch min instead
            min_probability = np.min(probabilities)
        exponent = -self.beta
        weights = []
        if mode:
            # The smallest probability yields the largest weight, so dividing
            # by it scales that largest weight to 1.
            for prob in probabilities:
                weights.append((prob / min_probability) ** exponent)
        else:
            # Unnormalized weights, scaled by the current memory size.
            for prob in probabilities:
                weights.append((len(self.memory) * prob) ** exponent)
        # Anneal beta towards 1, never exceeding it.
        self.beta = min(1.0, self.beta + self.beta_add)
        return weights
class PrioritizedReplayBuffer(ReplayBuffer, PriorityWeightError):
    """Replay buffer with stochastic, proportional prioritization.

    See https://arxiv.org/pdf/1511.05952.pdf Section 3.3.

    Args:
        capacity (int): capacity in terms of number of transitions
        alpha (float): Exponent of errors to compute probabilities to sample
        beta0 (float): Initial value of beta
        betasteps (int): Steps to anneal beta to 1
        eps (float): To revisit a step after its error becomes near zero
        normalize_by_max (bool or str): Method to normalize weights.
            ``'batch'`` or ``True`` (default): divide by the maximum weight
            in the sampled batch.  ``'memory'``: divide by the maximum
            weight in the memory.  ``False``: do not normalize.
        error_min (float): Lower bound errors are clipped to before being
            turned into priorities.
        error_max (float): Upper bound errors are clipped to before being
            turned into priorities.
        num_steps (int): Maximum length of each per-key deque kept in
            ``last_n_transitions``.  Must be positive.
    """

    def __init__(self, capacity=None,
                 alpha=0.6, beta0=0.4, betasteps=2e5, eps=0.01,
                 normalize_by_max=True, error_min=0,
                 error_max=1, num_steps=1):
        self.capacity = capacity
        assert num_steps > 0
        self.num_steps = num_steps
        self.memory = PrioritizedBuffer(capacity=capacity)

        def _new_window():
            # Bounded deque holding at most ``num_steps`` recent entries.
            return collections.deque([], maxlen=num_steps)

        self.last_n_transitions = collections.defaultdict(_new_window)
        PriorityWeightError.__init__(
            self, alpha, beta0, betasteps, eps, normalize_by_max,
            error_min=error_min, error_max=error_max)

    def sample(self, n):
        """Draw ``n`` items and attach an importance weight to each.

        The weight is stored under the ``'weight'`` key of the first element
        of every sampled item.

        Args:
            n (int): Number of items to draw; must not exceed the current
                memory size.

        Returns:
            The sampled items as returned by the underlying memory.
        """
        assert len(self.memory) >= n
        batch, probs, lowest_prob = self.memory.sample(n)
        importance = self.weights_from_probabilities(probs, lowest_prob)
        for item, weight in zip(batch, importance):
            item[0]['weight'] = weight
        return batch

    def update_errors(self, errors):
        """Update the priorities of the last sampled items from ``errors``.

        Args:
            errors (iterable of float): New errors, converted to priorities
                with ``priority_from_errors``.
        """
        new_priorities = self.priority_from_errors(errors)
        self.memory.set_last_priority(new_priorities)