
Commit

Merge pull request #506 from prabhatnagarajan/replay_module
Splits Replay Buffers into separate files in a replay_buffers module
prabhatnagarajan committed Jul 16, 2019
2 parents b73156d + 57d1203 commit 2c335e2
Showing 9 changed files with 413 additions and 308 deletions.
9 changes: 9 additions & 0 deletions chainerrl/__init__.py
@@ -17,6 +17,7 @@
from chainerrl import q_functions # NOQA
from chainerrl import recurrent # NOQA
from chainerrl import replay_buffer # NOQA
from chainerrl import replay_buffers # NOQA
from chainerrl import v_function # NOQA
from chainerrl import v_functions # NOQA
from chainerrl import wrappers # NOQA
@@ -48,5 +49,13 @@
q_function.FCBNQuadraticStateQFunction = \
q_functions.FCBNQuadraticStateQFunction

replay_buffer.ReplayBuffer = replay_buffers.ReplayBuffer
replay_buffer.PriorityWeightError = replay_buffers.PriorityWeightError
replay_buffer.PrioritizedReplayBuffer = \
replay_buffers.PrioritizedReplayBuffer
replay_buffer.EpisodicReplayBuffer = replay_buffers.EpisodicReplayBuffer
replay_buffer.PrioritizedEpisodicReplayBuffer = \
replay_buffers.PrioritizedEpisodicReplayBuffer

v_function.SingleModelVFunction = v_functions.SingleModelVFunction
v_function.FCVFunction = v_functions.FCVFunction
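
A minimal backward-compatibility sketch (not part of this commit): the aliases above keep the legacy chainerrl.replay_buffer.* names pointing at the classes now exported from chainerrl.replay_buffers, so existing user code can import from either path.

import chainerrl

# Both the legacy path and the new module path resolve to the same class.
old_style = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 5)
new_style = chainerrl.replay_buffers.ReplayBuffer(capacity=10 ** 5)
assert type(old_style) is type(new_style)
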
308 changes: 0 additions & 308 deletions chainerrl/replay_buffer.py
@@ -10,14 +10,10 @@
from abc import ABCMeta
from abc import abstractmethod
from abc import abstractproperty
import collections

import numpy as np
import six.moves.cPickle as pickle

from chainerrl.misc.batch_states import batch_states
from chainerrl.misc.collections import RandomAccessQueue
from chainerrl.misc.prioritized import PrioritizedBuffer


class AbstractReplayBuffer(with_metaclass(ABCMeta, object)):
@@ -136,172 +132,6 @@ def stop_current_episode(self, env_id=0):
raise NotImplementedError


class ReplayBuffer(AbstractReplayBuffer):

def __init__(self, capacity=None, num_steps=1):
self.capacity = capacity
assert num_steps > 0
self.num_steps = num_steps
self.memory = RandomAccessQueue(maxlen=capacity)
self.last_n_transitions = collections.defaultdict(
lambda: collections.deque([], maxlen=num_steps))

def append(self, state, action, reward, next_state=None, next_action=None,
is_state_terminal=False, env_id=0, **kwargs):
last_n_transitions = self.last_n_transitions[env_id]
experience = dict(
state=state,
action=action,
reward=reward,
next_state=next_state,
next_action=next_action,
is_state_terminal=is_state_terminal,
**kwargs
)
last_n_transitions.append(experience)
if is_state_terminal:
while last_n_transitions:
self.memory.append(list(last_n_transitions))
del last_n_transitions[0]
assert len(last_n_transitions) == 0
else:
if len(last_n_transitions) == self.num_steps:
self.memory.append(list(last_n_transitions))

def stop_current_episode(self, env_id=0):
last_n_transitions = self.last_n_transitions[env_id]
# if n-step transition hist is not full, add transition;
# if n-step hist is indeed full, transition has already been added;
if 0 < len(last_n_transitions) < self.num_steps:
self.memory.append(list(last_n_transitions))
# avoid duplicate entry
if 0 < len(last_n_transitions) <= self.num_steps:
del last_n_transitions[0]
while last_n_transitions:
self.memory.append(list(last_n_transitions))
del last_n_transitions[0]
assert len(last_n_transitions) == 0

def sample(self, num_experiences):
assert len(self.memory) >= num_experiences
return self.memory.sample(num_experiences)

def __len__(self):
return len(self.memory)

def save(self, filename):
with open(filename, 'wb') as f:
pickle.dump(self.memory, f)

def load(self, filename):
with open(filename, 'rb') as f:
self.memory = pickle.load(f)
if isinstance(self.memory, collections.deque):
# Load v0.2
self.memory = RandomAccessQueue(
self.memory, maxlen=self.memory.maxlen)
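
A minimal usage sketch of the ReplayBuffer above (not part of this commit; the values are toy placeholders). Each sampled entry is a list of up to num_steps transition dicts:

buf = ReplayBuffer(capacity=1000, num_steps=1)
buf.append(state=0, action=1, reward=1.0, next_state=1,
           is_state_terminal=False)
buf.append(state=1, action=0, reward=0.0, next_state=None,
           is_state_terminal=True)
batch = buf.sample(2)  # two entries, each a list of num_steps transition dicts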


class PriorityWeightError(object):
"""For proportional prioritization
alpha determines how much prioritization is used.
beta determines how much importance sampling weights are used. beta is
scheduled by ``beta0`` and ``betasteps``.
Args:
alpha (float): Exponent of errors to compute probabilities to sample
beta0 (float): Initial value of beta
betasteps (float): Steps to anneal beta to 1
eps (float): To revisit a step after its error becomes near zero
normalize_by_max (str): Method to normalize weights. ``'batch'`` or
``True`` (default): divide by the maximum weight in the sampled
batch. ``'memory'``: divide by the maximum weight in the memory.
``False``: do not normalize.
"""

def __init__(self, alpha, beta0, betasteps, eps, normalize_by_max,
error_min, error_max):
assert 0.0 <= alpha
assert 0.0 <= beta0 <= 1.0
self.alpha = alpha
self.beta = beta0
if betasteps is None:
self.beta_add = 0
else:
self.beta_add = (1.0 - beta0) / betasteps
self.eps = eps
if normalize_by_max is True:
normalize_by_max = 'batch'
assert normalize_by_max in [False, 'batch', 'memory']
self.normalize_by_max = normalize_by_max
self.error_min = error_min
self.error_max = error_max

def priority_from_errors(self, errors):

def _clip_error(error):
if self.error_min is not None:
error = max(self.error_min, error)
if self.error_max is not None:
error = min(self.error_max, error)
return error

return [(_clip_error(d) + self.eps) ** self.alpha for d in errors]

def weights_from_probabilities(self, probabilities, min_probability):
if self.normalize_by_max == 'batch':
# discard global min and compute batch min
min_probability = np.min(min_probability)
if self.normalize_by_max:
weights = [(p / min_probability) ** -self.beta
for p in probabilities]
else:
weights = [(len(self.memory) * p) ** -self.beta
for p in probabilities]
self.beta = min(1.0, self.beta + self.beta_add)
return weights
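
A small worked sketch of the proportional-prioritization arithmetic above (not part of this commit; the values are assumed): priorities are (clipped error + eps) ** alpha, and with 'batch' normalization the weights are (p / min(p)) ** -beta.

pwe = PriorityWeightError(alpha=0.6, beta0=0.4, betasteps=None, eps=0.01,
                          normalize_by_max='batch', error_min=0, error_max=1)
priorities = pwe.priority_from_errors([0.5, 2.0])  # 2.0 is clipped to error_max=1
weights = pwe.weights_from_probabilities([0.2, 0.8], min_probability=[0.2, 0.8])
# weights[0] == 1.0; weights[1] == (0.8 / 0.2) ** -0.4, roughly 0.57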


class PrioritizedReplayBuffer(ReplayBuffer, PriorityWeightError):
"""Stochastic Prioritization
https://arxiv.org/pdf/1511.05952.pdf Section 3.3
proportional prioritization
Args:
capacity (int)
alpha, beta0, betasteps, eps (float)
normalize_by_max (bool)
"""

def __init__(self, capacity=None,
alpha=0.6, beta0=0.4, betasteps=2e5, eps=0.01,
normalize_by_max=True, error_min=0,
error_max=1, num_steps=1):
self.capacity = capacity
assert num_steps > 0
self.num_steps = num_steps
self.memory = PrioritizedBuffer(capacity=capacity)
self.last_n_transitions = collections.defaultdict(
lambda: collections.deque([], maxlen=num_steps))
PriorityWeightError.__init__(
self, alpha, beta0, betasteps, eps, normalize_by_max,
error_min=error_min, error_max=error_max)

def sample(self, n):
assert len(self.memory) >= n
sampled, probabilities, min_prob = self.memory.sample(n)
weights = self.weights_from_probabilities(probabilities, min_prob)
for e, w in zip(sampled, weights):
e[0]['weight'] = w
return sampled

def update_errors(self, errors):
self.memory.set_last_priority(self.priority_from_errors(errors))
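
A sketch of the intended sampling loop for PrioritizedReplayBuffer (not part of this commit; sizes and errors are placeholders). sample() attaches an importance weight to each entry, and update_errors() should follow with the TD errors in the same order:

prb = PrioritizedReplayBuffer(capacity=10 ** 5)
# ... fill the buffer with prb.append(...) during interaction ...
batch = prb.sample(32)              # each entry carries entry[0]['weight']
td_errors = [1.0] * len(batch)      # placeholder errors from a learner
prb.update_errors(td_errors)        # refresh priorities of the sampled items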


def random_subseq(seq, subseq_len):
if len(seq) <= subseq_len:
return seq
@@ -310,144 +140,6 @@
return seq[i:i + subseq_len]


class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer):

def __init__(self, capacity=None):
self.current_episode = collections.defaultdict(list)
self.episodic_memory = RandomAccessQueue()
self.memory = RandomAccessQueue()
self.capacity = capacity

def append(self, state, action, reward, next_state=None, next_action=None,
is_state_terminal=False, env_id=0, **kwargs):
current_episode = self.current_episode[env_id]
experience = dict(state=state, action=action, reward=reward,
next_state=next_state, next_action=next_action,
is_state_terminal=is_state_terminal,
**kwargs)
current_episode.append(experience)
if is_state_terminal:
self.stop_current_episode(env_id=env_id)

def sample(self, n):
assert len(self.memory) >= n
return self.memory.sample(n)

def sample_episodes(self, n_episodes, max_len=None):
assert len(self.episodic_memory) >= n_episodes
episodes = self.episodic_memory.sample(n_episodes)
if max_len is not None:
return [random_subseq(ep, max_len) for ep in episodes]
else:
return episodes

def __len__(self):
return len(self.memory)

@property
def n_episodes(self):
return len(self.episodic_memory)

def save(self, filename):
with open(filename, 'wb') as f:
pickle.dump((self.memory, self.episodic_memory), f)

def load(self, filename):
with open(filename, 'rb') as f:
memory = pickle.load(f)
if isinstance(memory, tuple):
self.memory, self.episodic_memory = memory
else:
# Load v0.2
# FIXME: The code works with EpisodicReplayBuffer
# but not with PrioritizedEpisodicReplayBuffer
self.memory = RandomAccessQueue(memory)
self.episodic_memory = RandomAccessQueue()

# Recover episodic_memory with best effort.
episode = []
for item in self.memory:
episode.append(item)
if item['is_state_terminal']:
self.episodic_memory.append(episode)
episode = []

def stop_current_episode(self, env_id=0):
current_episode = self.current_episode[env_id]
if current_episode:
self.episodic_memory.append(current_episode)
for transition in current_episode:
self.memory.append([transition])
self.current_episode[env_id] = []
while self.capacity is not None and \
len(self.memory) > self.capacity:
discarded_episode = self.episodic_memory.popleft()
for _ in range(len(discarded_episode)):
self.memory.popleft()
assert not self.current_episode[env_id]
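
A minimal sketch of the episodic buffer above (not part of this commit; the transitions are toy values). A terminal transition closes the current episode, after which whole episodes or individual transitions can be sampled:

ebuf = EpisodicReplayBuffer(capacity=10 ** 4)
ebuf.append(state=0, action=0, reward=0.0, next_state=1)
ebuf.append(state=1, action=1, reward=1.0, next_state=None,
            is_state_terminal=True)  # closes the current episode
episodes = ebuf.sample_episodes(1, max_len=10)
transitions = ebuf.sample(2)         # per-transition sampling also works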


class PrioritizedEpisodicReplayBuffer (
EpisodicReplayBuffer, PriorityWeightError):

def __init__(self, capacity=None,
alpha=0.6, beta0=0.4, betasteps=2e5, eps=1e-8,
normalize_by_max=True,
default_priority_func=None,
uniform_ratio=0,
wait_priority_after_sampling=True,
return_sample_weights=True,
error_min=None,
error_max=None,
):
self.current_episode = collections.defaultdict(list)
self.episodic_memory = PrioritizedBuffer(
capacity=None,
wait_priority_after_sampling=wait_priority_after_sampling)
self.memory = RandomAccessQueue(maxlen=capacity)
self.capacity_left = capacity
self.default_priority_func = default_priority_func
self.uniform_ratio = uniform_ratio
self.return_sample_weights = return_sample_weights
PriorityWeightError.__init__(
self, alpha, beta0, betasteps, eps, normalize_by_max,
error_min=error_min, error_max=error_max)

def sample_episodes(self, n_episodes, max_len=None):
"""Sample n unique samples from this replay buffer"""
assert len(self.episodic_memory) >= n_episodes
episodes, probabilities, min_prob = self.episodic_memory.sample(
n_episodes, uniform_ratio=self.uniform_ratio)
if max_len is not None:
episodes = [random_subseq(ep, max_len) for ep in episodes]
if self.return_sample_weights:
weights = self.weights_from_probabilities(probabilities, min_prob)
return episodes, weights
else:
return episodes

def update_errors(self, errors):
self.episodic_memory.set_last_priority(
self.priority_from_errors(errors))

def stop_current_episode(self, env_id=0):
current_episode = self.current_episode[env_id]
if current_episode:
if self.default_priority_func is not None:
priority = self.default_priority_func(current_episode)
else:
priority = None
self.memory.extend(current_episode)
self.episodic_memory.append(current_episode, priority=priority)
if self.capacity_left is not None:
self.capacity_left -= len(current_episode)
self.current_episode[env_id] = []
while self.capacity_left is not None and self.capacity_left < 0:
discarded_episode = self.episodic_memory.popleft()
self.capacity_left += len(discarded_episode)
assert not self.current_episode[env_id]
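
A usage sketch for the prioritized episodic variant above (not part of this commit; sizes and errors are placeholders). With the default return_sample_weights=True, sample_episodes() returns (episodes, weights), and with the default wait_priority_after_sampling=True, update_errors() must be called before the next sample:

perb = PrioritizedEpisodicReplayBuffer(capacity=10 ** 4)
# ... fill the buffer with perb.append(...) and let episodes terminate ...
episodes, weights = perb.sample_episodes(4, max_len=8)
errors = [1.0] * len(episodes)  # placeholder per-episode errors
perb.update_errors(errors)      # set priorities for the sampled episodes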


def batch_experiences(experiences, xp, phi, gamma, batch_states=batch_states):
"""Takes a batch of k experiences each of which contains j
5 changes: 5 additions & 0 deletions chainerrl/replay_buffers/__init__.py
@@ -0,0 +1,5 @@
from chainerrl.replay_buffers.episodic import EpisodicReplayBuffer # NOQA
from chainerrl.replay_buffers.prioritized import PrioritizedReplayBuffer # NOQA
from chainerrl.replay_buffers.prioritized import PriorityWeightError # NOQA
from chainerrl.replay_buffers.prioritized_episodic import PrioritizedEpisodicReplayBuffer # NOQA
from chainerrl.replay_buffers.replay_buffer import ReplayBuffer # NOQA
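
Given the submodule paths listed above, the classes can also be imported directly from their new files (a usage note, not part of this commit):

from chainerrl.replay_buffers.replay_buffer import ReplayBuffer
from chainerrl.replay_buffers.prioritized import PrioritizedReplayBuffer
from chainerrl.replay_buffers.episodic import EpisodicReplayBuffer
from chainerrl.replay_buffers.prioritized_episodic import PrioritizedEpisodicReplayBuffer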
