#!/usr/bin/env python3

import random
import torch as th
import cherry as ch

from cherry._utils import _istensorable, _min_size

"""
TODO: replay.astype(dtype) + init dtype
"""

class Transition(object):
    """
    **Description**

    Represents a (s, a, r, s', d) tuple.

    All attributes (including the ones in infos) are accessible via
    `transition.name_of_attr`.
    (e.g. `transition.log_prob` if `log_prob` is in `infos`.)

    **Arguments**

    * **state** (tensor) - Originating state.
    * **action** (tensor) - Executed action.
    * **reward** (tensor) - Observed reward.
    * **next_state** (tensor) - Resulting state.
    * **done** (tensor) - Is `next_state` a terminal (absorbing) state?
    * **infos** (dict, *optional*, default=None) - Additional information on
      the transition.

    **Example**

    ~~~python
    for transition in replay:
        print(transition.state)
    ~~~
    """

    def __init__(self,
                 state,
                 action,
                 reward,
                 next_state,
                 done,
                 device=None,
                 **infos):
        self.__fields = ['state', 'action', 'reward', 'next_state', 'done']
        values = [state, action, reward, next_state, done]
        for key, val in zip(self.__fields, values):
            setattr(self, key, val)
        info_keys = infos.keys()
        self.__fields += info_keys
        for key in info_keys:
            setattr(self, key, infos[key])
        self.device = device

    def __str__(self):
        string = 'Transition(' + ', '.join(self.__fields)
        if self.device is not None:
            string += ', device=\'' + str(self.device) + '\''
        string += ')'
        return string

    def __repr__(self):
        return str(self)

    def cpu(self):
        return self.to('cpu')

    def cuda(self, device=0, *args, **kwargs):
        return self.to('cuda:' + str(device), *args, **kwargs)
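
    # _apply builds a new Transition by applying `fn` to every tensor field
    # (including tensor-valued infos) and copying non-tensor fields as-is.
    # It backs the device/dtype conversion helpers below (to, half, double).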
    def _apply(self, fn, device=None):
        if device is None:
            device = self.device
        new_transition = {'device': device}
        for field in self.__fields:
            value = getattr(self, field)
            if isinstance(value, th.Tensor):
                new_transition[field] = fn(value)
            else:
                new_transition[field] = value
        return Transition(**new_transition)

    def to(self, *args, **kwargs):
        """
        **Description**

        Moves the constituents of the transition to the desired device,
        and casts them to the desired format.

        Note: This returns a new Transition whose tensors have been moved
        and cast; the original transition is left unchanged.

        **Arguments**

        * **device** (device, *optional*, default=None) - The device to move the data to.
        * **dtype** (dtype, *optional*, default=None) - The torch.dtype format to cast to.
        * **non_blocking** (bool, *optional*, default=False) - Whether to perform the move asynchronously.

        **Example**

        ~~~python
        sars = Transition(state, action, reward, next_state, done)
        sars = sars.to('cuda')
        ~~~
        """
        device, dtype, non_blocking, *_ = th._C._nn._parse_to(*args, **kwargs)
        return self._apply(
            lambda t: t.to(device, dtype if t.is_floating_point() else None, non_blocking),
            device,
        )

    def half(self):
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)

    def double(self):
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)


class ExperienceReplay(list):
    """
    [[Source]](https://github.com/seba-1511/cherry/blob/master/cherry/experience_replay.py)

    **Description**

    Experience replay buffer to store, retrieve, and sample past transitions.

    `ExperienceReplay` behaves like a list of transitions.
    It also supports accessing specific properties, such as states, actions,
    rewards, next_states, and infos.
    The first four are returned as tensors, while `infos` is returned as
    a list of dicts.
    The properties of infos can be accessed directly via their dictionary
    key -- see Examples below.
    In this case, if the values of the infos are tensors, they will be returned
    as a concatenated Tensor.
    Otherwise, they default to a list of values.

    **Arguments**

    * **storage** (list, *optional*, default=None) - A list of Transitions.
    * **device** (torch.device, *optional*, default=None) - The device on
      which appended transitions are stored.

    **References**

    1. Lin, Long-Ji. 1992. “Self-Improving Reactive Agents Based on Reinforcement Learning, Planning and Teaching.” Machine Learning 8 (3): 293–321.

    **Example**

    ~~~python
    replay = ch.ExperienceReplay()  # Instantiate a new replay
    replay.append(state,  # Add experience to the replay
                  action,
                  reward,
                  next_state,
                  done,
                  density=action_density,
                  log_prob=action_density.log_prob(action),
                  )

    replay.state()  # Tensor of states
    replay.action()  # Tensor of actions
    replay.density()  # list of action_density
    replay.log_prob()  # Tensor of log_probabilities

    new_replay = replay[-10:]  # Last 10 transitions in new_replay

    # Sample some previous experience
    batch = replay.sample(32, contiguous=True)
    ~~~
    """

    def __init__(self, storage=None, device=None):
        list.__init__(self)
        if storage is None:
            storage = []
        self._storage = storage
        self.device = device
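
    # _access_property collects the attribute `name` from every stored
    # transition and concatenates the values into a single tensor whose
    # first dimension indexes transitions.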
    def _access_property(self, name):
        try:
            values = [getattr(sars, name) for sars in self._storage]
        except AttributeError:
            msg = 'Attribute ' + name + ' not in replay.'
            raise AttributeError(msg)
        true_size = _min_size(values[0])
        return th.cat(values, dim=0).view(len(values), *true_size)

    def __getslice__(self, i, j):
        return self.__getitem__(slice(i, j))

    def __len__(self):
        return len(self._storage)

    def __str__(self):
        string = 'ExperienceReplay(' + str(len(self))
        if self.device is not None:
            string += ', device=\'' + str(self.device) + '\''
        string += ')'
        return string

    def __repr__(self):
        return str(self)

    def __add__(self, other):
        storage = self._storage + other._storage
        return ExperienceReplay(storage)

    def __iadd__(self, other):
        self._storage += other._storage
        return self

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]
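
    # Unknown attributes (e.g. replay.state, replay.log_prob) resolve to a
    # zero-argument callable that gathers that field across all stored
    # transitions via _access_property, which is why the docstring examples
    # call replay.state(), replay.log_prob(), etc.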
    def __getattr__(self, attr):
        return lambda: self._access_property(attr)

    def __getitem__(self, key):
        value = self._storage[key]
        if isinstance(key, slice):
            return ExperienceReplay(value)
        return value

    def save(self, path):
        """
        **Description**

        Serializes and saves the ExperienceReplay into the given path.

        **Arguments**

        * **path** (str) - File path.

        **Example**

        ~~~python
        replay.save('my_replay_file.pt')
        ~~~
        """
        th.save(self._storage, path)

    def load(self, path):
        """
        **Description**

        Loads data from a serialized ExperienceReplay.

        **Arguments**

        * **path** (str) - File path of serialized ExperienceReplay.

        **Example**

        ~~~python
        replay.load('my_replay_file.pt')
        ~~~
        """
        self._storage = th.load(path)

    def append(self,
               state=None,
               action=None,
               reward=None,
               next_state=None,
               done=None,
               **infos):
        """
        **Description**

        Appends new data to the ExperienceReplay.

        **Arguments**

        * **state** (tensor/ndarray/list) - Originating state.
        * **action** (tensor/ndarray/list) - Executed action.
        * **reward** (tensor/ndarray/list) - Observed reward.
        * **next_state** (tensor/ndarray/list) - Resulting state.
        * **done** (tensor/bool) - Is `next_state` a terminal (absorbing)
          state?
        * **infos** (dict, *optional*, default=None) - Additional information
          on the transition.

        **Example**

        ~~~python
        replay.append(state,
                      action,
                      reward,
                      next_state,
                      done,
                      density=density,
                      log_prob=density.log_prob(action),
                      )
        ~~~
        """
        for key in infos:
            if _istensorable(infos[key]):
                infos[key] = ch.totensor(infos[key])
        sars = Transition(ch.totensor(state),
                          ch.totensor(action),
                          ch.totensor(reward),
                          ch.totensor(next_state),
                          ch.totensor(done),
                          **infos)
        self._storage.append(sars.to(self.device))

    def sample(self, size=1, contiguous=False, episodes=False):
        """
        **Description**

        Samples from the ExperienceReplay.

        **Arguments**

        * **size** (int, *optional*, default=1) - The number of samples.
        * **contiguous** (bool, *optional*, default=False) - Whether to sample
          contiguous transitions.
        * **episodes** (bool, *optional*, default=False) - Sample full
          episodes, instead of transitions.

        **Return**

        * ExperienceReplay - New ExperienceReplay containing the sampled
          transitions.
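
        **Example**

        ~~~python
        # Assuming `replay` has been filled with transitions whose episode
        # boundaries are marked by the `done` flag:
        batch = replay.sample(32)                   # 32 random transitions
        batch = replay.sample(32, contiguous=True)  # 32 consecutive transitions
        episodes = replay.sample(5, episodes=True)  # 5 complete episodes
        ~~~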
""" | |
if len(self) < 1 or size < 1: | |
return ExperienceReplay() | |
indices = [] | |
if episodes: | |
if size > 1 and not contiguous: | |
replay = ExperienceReplay() | |
return sum([self.sample(1, episodes=True) for _ in range(size)], replay) | |
else: # Sample 'size' contiguous episodes | |
num_episodes = self.done().sum().int().item() | |
end = random.randint(size-1, num_episodes-size) | |
# Find idx of the end-th done | |
count = 0 | |
dones = self.done() | |
for idx, d in reversed(list(enumerate(dones))): | |
if bool(d): | |
if count >= end: | |
end_idx = idx | |
break | |
count += 1 | |
# Fill indices | |
indices.insert(0, end_idx) | |
count = 0 | |
for idx in reversed(range(0, end_idx)): | |
if bool(dones[idx]): | |
count += 1 | |
if count >= size: | |
break | |
indices.insert(0, idx) | |
else: | |
length = len(self) - 1 | |
if contiguous: | |
start = random.randint(0, length - size) | |
indices = list(range(start, start + size)) | |
else: | |
indices = [random.randint(0, length) for _ in range(size)] | |
# Fill the sample | |
storage = [self[idx] for idx in indices] | |
return ExperienceReplay(storage) | |

    def empty(self):
        """
        **Description**

        Removes all data from an ExperienceReplay.

        **Example**

        ~~~python
        replay.empty()
        ~~~
        """
        self._storage = []

    def cpu(self):
        return self.to('cpu')

    def cuda(self, device=0, *args, **kwargs):
        return self.to('cuda:' + str(device), *args, **kwargs)

    def to(self, *args, **kwargs):
        """
        **Description**

        Calls `.to()` on all transitions of the experience replay, moving them
        to the desired device and casting them to the desired format.

        Note: This returns a new ExperienceReplay; the transitions in the
        original replay are left unchanged.

        **Arguments**

        * **device** (device, *optional*, default=None) - The device to move the data to.
        * **dtype** (dtype, *optional*, default=None) - The torch.dtype format to cast to.
        * **non_blocking** (bool, *optional*, default=False) - Whether to perform the move asynchronously.

        **Example**

        ~~~python
        replay = replay.to('cuda:1')
        policy.to('cuda:1')
        for sars in replay:
            cuda_action = policy(sars.state).sample()
        ~~~
        """
        device, dtype, non_blocking, *_ = th._C._nn._parse_to(*args, **kwargs)
        storage = [sars.to(*args, **kwargs) for sars in self._storage]
        return ExperienceReplay(storage, device=device)

    def half(self):
        storage = [sars.half() for sars in self._storage]
        return ExperienceReplay(storage, device=self.device)

    def double(self):
        storage = [sars.double() for sars in self._storage]
        return ExperienceReplay(storage, device=self.device)
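

if __name__ == '__main__':
    # Minimal usage sketch of the classes above. This block is illustrative
    # only: the tensor shapes, episode length, and the `step` info key are
    # arbitrary choices for the example, not part of the library's API.
    replay = ExperienceReplay()
    for step in range(100):
        state = th.randn(4)
        action = th.randn(2)
        reward = random.random()
        next_state = th.randn(4)
        done = (step % 20 == 19)  # mark every 20th transition as terminal
        replay.append(state, action, reward, next_state, done,
                      step=th.tensor([step]))
    print(replay)                # ExperienceReplay(100)
    print(replay.state().shape)  # all states stacked along the first dimension
    batch = replay.sample(16, contiguous=True)
    print(len(batch))            # 16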