Commit
Showing 9 changed files with 1,064 additions and 0 deletions.
@@ -0,0 +1,276 @@
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import *  # NOQA
from future import standard_library
standard_library.install_aliases()

from logging import getLogger

import chainer
from chainer import functions as F

from chainerrl import agent
from chainerrl.misc.batch_states import batch_states
from chainerrl.recurrent import RecurrentChainMixin

logger = getLogger(__name__)


class A2CModel(chainer.Link):
    """A2C model."""

    def pi_and_v(self, obs):
        """Evaluate the policy and the V-function.

        Args:
            obs (Variable or ndarray): Batched observations.

        Returns:
            Distribution and Variable
        """
        raise NotImplementedError()

    def __call__(self, obs):
        return self.pi_and_v(obs)


class A2CSeparateModel(chainer.Chain, A2CModel, RecurrentChainMixin):
    """A2C model that consists of a separate policy and V-function.

    Args:
        pi (Policy): Policy.
        v (VFunction): V-function.
    """

    def __init__(self, pi, v):
        super().__init__(pi=pi, v=v)

    def pi_and_v(self, obs):
        pout = self.pi(obs)
        vout = self.v(obs)
        return pout, vout


class A2C(agent.AttributeSavingMixin):
    """A2C: Advantage Actor-Critic.

    A2C is a synchronous, deterministic variant of Asynchronous Advantage
    Actor-Critic (A3C). See https://arxiv.org/abs/1708.05144

    Args:
        model (A2CModel): Model to train.
        optimizer (chainer.Optimizer): Optimizer used to train the model.
        gamma (float): Discount factor in [0, 1].
        num_processes (int): The number of processes.
        gpu (int): GPU device id. The GPU is used if this is neither None
            nor negative.
        update_steps (int): The number of steps per update (the rollout
            length).
        phi (callable): Feature extractor function.
        pi_loss_coef (float): Weight coefficient for the loss of the policy.
        v_loss_coef (float): Weight coefficient for the loss of the value
            function.
        entropy_coeff (float): Weight coefficient for the entropy bonus.
        use_gae (bool): Use generalized advantage estimation (GAE).
        tau (float): GAE parameter.
        average_actor_loss_decay (float): Decay rate of average actor loss.
            Used only to record statistics.
        average_entropy_decay (float): Decay rate of average entropy. Used
            only to record statistics.
        average_value_decay (float): Decay rate of average value. Used only
            to record statistics.
        act_deterministically (bool): If set to True, choose the most
            probable action in the act method.
        batch_states (callable): Method which makes a batch of observations.
            The default is `chainerrl.misc.batch_states.batch_states`.
    """

    process_idx = None
    saved_attributes = ['model', 'optimizer']

    def __init__(self, model, optimizer, gamma, num_processes,
                 gpu=None,
                 update_steps=5,
                 phi=lambda x: x,
                 pi_loss_coef=1.0,
                 v_loss_coef=0.5,
                 entropy_coeff=0.01,
                 use_gae=False,
                 tau=0.95,
                 act_deterministically=False,
                 average_actor_loss_decay=0.999,
                 average_entropy_decay=0.999,
                 average_value_decay=0.999,
                 batch_states=batch_states):

        assert isinstance(model, A2CModel)

        self.model = model
        self.gpu = gpu
        if gpu is not None and gpu >= 0:
            chainer.cuda.get_device(gpu).use()
            self.model.to_gpu(device=gpu)

        self.optimizer = optimizer

        self.update_steps = update_steps
        self.num_processes = num_processes

        self.gamma = gamma
        self.use_gae = use_gae
        self.tau = tau
        self.act_deterministically = act_deterministically
        self.phi = phi
        self.pi_loss_coef = pi_loss_coef
        self.v_loss_coef = v_loss_coef
        self.entropy_coeff = entropy_coeff

        self.average_actor_loss_decay = average_actor_loss_decay
        self.average_value_decay = average_value_decay
        self.average_entropy_decay = average_entropy_decay
        self.batch_states = batch_states

        self.xp = self.model.xp
        self.t = 0
        self.t_start = 0

        # Stats
        self.average_actor_loss = 0
        self.average_value = 0
        self.average_entropy = 0

    def _flush_storage(self, obs_shape, action):
        obs_shape = obs_shape[1:]
        action_shape = action.shape[1:]

        # Rollout storage: time-major buffers of shape
        # (update_steps [+ 1 for bootstrapping], num_processes, ...).
        self.states = self.xp.zeros(
            [self.update_steps + 1, self.num_processes] + list(obs_shape),
            dtype='f')
        self.actions = self.xp.zeros(
            [self.update_steps, self.num_processes] + list(action_shape),
            dtype=action.dtype)
        self.rewards = self.xp.zeros(
            (self.update_steps, self.num_processes, 1), dtype='f')
        self.value_preds = self.xp.zeros(
            (self.update_steps + 1, self.num_processes, 1), dtype='f')
        self.returns = self.xp.zeros(
            (self.update_steps + 1, self.num_processes, 1), dtype='f')
        self.masks = self.xp.ones(
            (self.update_steps, self.num_processes, 1), dtype='f')

        self.obs_shape = obs_shape
        self.action_shape = action_shape

    def _compute_returns(self, next_value):
        if self.use_gae:
            self.value_preds[-1] = next_value
            gae = 0
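            # Generalized advantage estimation (GAE):
            #   delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
            #   A_t     = delta_t + gamma * tau * mask_t * A_{t+1}
            # and the value target is return_t = A_t + V(s_t). mask_t is 0
            # at episode boundaries, so nothing bootstraps across episodes.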
            for i in reversed(range(self.update_steps)):
                delta = self.rewards[i] + \
                    self.gamma * self.value_preds[i + 1] * self.masks[i] - \
                    self.value_preds[i]
                gae = delta + self.gamma * self.tau * self.masks[i] * gae
                self.returns[i] = gae + self.value_preds[i]
        else:
            self.returns[-1] = next_value
            for i in reversed(range(self.update_steps)):
                self.returns[i] = self.rewards[i] + \
                    self.gamma * self.returns[i + 1] * self.masks[i]

    def update(self):
        # The value of the last state is used only to bootstrap the
        # returns, so no computational graph is built here.
        with chainer.no_backprop_mode():
            _, next_value = self.model.pi_and_v(self.states[-1])
            next_value = next_value.data
|
||
self._compute_returns(next_value) | ||
pout, values = \ | ||
self.model.pi_and_v(chainer.Variable( | ||
self.states[:-1].reshape([-1] + list(self.obs_shape)))) | ||
|
||
actions = chainer.Variable( | ||
self.actions.reshape([-1] + list(self.action_shape))) | ||
dist_entropy = F.mean(pout.entropy) | ||
action_log_probs = pout.log_prob(actions) | ||
|
||
values = values.reshape(self.update_steps, self.num_processes, 1) | ||
action_log_probs = action_log_probs.reshape( | ||
self.update_steps, self.num_processes, 1) | ||
advantages = chainer.Variable(self.returns[:-1]) - values | ||
value_loss = F.mean(advantages * advantages) | ||
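        # The advantage is detached (advantages.data) below so the
        # policy-gradient term does not backpropagate into the critic.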
        action_loss = \
            - F.mean(chainer.Variable(advantages.data) * action_log_probs)

        self.model.cleargrads()

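        # Total objective: weighted value loss plus weighted policy loss,
        # minus an entropy bonus that encourages exploration.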
        (value_loss * self.v_loss_coef +
         action_loss * self.pi_loss_coef -
         dist_entropy * self.entropy_coeff).backward()

        self.optimizer.update()
        # Carry the last state over as the first state of the next rollout.
        self.states[0] = self.states[-1]

        self.t_start = self.t

        # Update stats
        self.average_actor_loss += (
            (1 - self.average_actor_loss_decay) *
            (float(action_loss.data) - self.average_actor_loss))
        self.average_value += (
            (1 - self.average_value_decay) *
            (float(value_loss.data) - self.average_value))
        self.average_entropy += (
            (1 - self.average_entropy_decay) *
            (float(dist_entropy.data) - self.average_entropy))

    def act_and_train(self, state, reward, done):
        statevar = self.batch_states([state], self.xp, self.phi)[0]

        if self.t == 0:
            pout, _ = self.model.pi_and_v(statevar[0:1])
            action = pout.sample().data
            self._flush_storage(state.shape, action)

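        # Rollout bookkeeping: the reward and done mask produced by the
        # previous action land at index t - t_start - 1, while the fresh
        # observation is stored at t - t_start.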
        self.rewards[self.t - self.t_start -
                     1] = self.xp.array(reward, dtype=self.xp.float32)
        self.states[self.t - self.t_start] = statevar
        self.masks[self.t - self.t_start - 1] \
            = self.xp.array([[0.0] if done_ else [1.0] for done_ in done])

        if self.t - self.t_start == self.update_steps:
            self.update()

        with chainer.no_backprop_mode():
            pout, value = self.model.pi_and_v(statevar)

        action = pout.sample().data

        self.actions[self.t - self.t_start] \
            = action.reshape([-1] + list(self.action_shape))
        self.value_preds[self.t - self.t_start] = value.data

        self.t += 1

        return chainer.cuda.to_cpu(action)

    def act(self, obs):
        with chainer.no_backprop_mode():
            statevar = self.batch_states([obs], self.xp, self.phi)
            pout, _ = self.model.pi_and_v(statevar)
            if self.act_deterministically:
                return chainer.cuda.to_cpu(pout.most_probable.data)[0]
            else:
                return chainer.cuda.to_cpu(pout.sample().data)[0]

    def stop_episode_and_train(self, state, reward, done=False):
        pass

    def stop_episode(self):
        pass

    def get_statistics(self):
        return [
            ('average_actor', self.average_actor_loss),
            ('average_value', self.average_value),
            ('average_entropy', self.average_entropy),
        ]
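For orientation, a minimal wiring sketch of the agent above (not part of this commit). The observation/action sizes, hidden sizes, and hyperparameters are illustrative; FCSoftmaxPolicy is assumed to be available from chainerrl.policies (check the exact constructor against your chainerrl version), and the V-function is hand-rolled from plain chainer links rather than guessed from chainerrl's module layout. A driving loop against a batch of environments is sketched after the next file.

# --- sketch, not part of the diff ---
import chainer
import chainer.functions as F
import chainer.links as L
from chainerrl import policies


class TinyVFunction(chainer.Chain):
    """Maps a batch of observations to a (batch, 1) value estimate."""

    def __init__(self, obs_size, n_hidden=64):
        super(TinyVFunction, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(obs_size, n_hidden)
            self.l2 = L.Linear(n_hidden, 1)

    def __call__(self, x):
        return self.l2(F.relu(self.l1(x)))


obs_size, n_actions, n_processes = 4, 2, 8  # e.g. CartPole-v0, 8 workers
model = A2CSeparateModel(
    pi=policies.FCSoftmaxPolicy(obs_size, n_actions,
                                n_hidden_layers=2, n_hidden_channels=64),
    v=TinyVFunction(obs_size))
optimizer = chainer.optimizers.RMSprop(lr=7e-4)
optimizer.setup(model)

agent = A2C(model, optimizer, gamma=0.99, num_processes=n_processes,
            update_steps=5, use_gae=True)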
@@ -0,0 +1,106 @@
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from builtins import *  # NOQA
from future import standard_library
standard_library.install_aliases()

from multiprocessing import Pipe
from multiprocessing import Process

from cached_property import cached_property
import numpy as np

from chainerrl import env


def worker(remote, env_fn_wrapper):
    env = env_fn_wrapper
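    # Despite its name, env_fn_wrapper is used directly as an environment
    # object here. The subprocess then serves a small command protocol over
    # the pipe:
    #   ('step', action)     -> (ob, reward, done, info); auto-resets on done
    #   ('reset', None)      -> ob
    #   ('seed', seed)       -> result of env.seed(seed)
    #   ('get_spaces', None) -> (action_space, observation_space)
    #   ('spec', None)       -> env.spec
    #   ('close', None)      -> closes the pipe and exits the loop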
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.action_space, env.observation_space))
        elif cmd == 'spec':
            remote.send(env.spec)
        elif cmd == 'seed':
            remote.send(env.seed(data))
        else:
            raise NotImplementedError


class VectorEnv(env.Env):

    def __init__(self, env_fns):
        """env_fns: list of gym environments to run in subprocesses."""
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = \
            [Process(target=worker, args=(work_remote, env_fn))
             for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
        for p in self.ps:
            p.start()

        self.remotes[0].send(('get_spaces', None))
        self.action_space, self.observation_space = self.remotes[0].recv()

    @cached_property
    def spec(self):
        self.remotes[0].send(('spec', None))
        spec = self.remotes[0].recv()
        return spec

    def step(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        results = [remote.recv() for remote in self.remotes]
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()

    def seed(self, seeds=None):
        if seeds is not None:
            if isinstance(seeds, int):
                seeds = [seeds] * self.num_envs
            elif isinstance(seeds, list):
                if len(seeds) != self.num_envs:
                    raise ValueError(
                        "The length of seeds must equal num_envs ({})"
                        .format(self.num_envs))
            else:
                raise TypeError(
                    "Seeds of type {} are not supported.".format(type(seeds)))
        else:
            seeds = [None] * self.num_envs

        for remote, seed in zip(self.remotes, seeds):
            remote.send(('seed', seed))
        results = [remote.recv() for remote in self.remotes]
        return results

    @property
    def num_envs(self):
        return len(self.remotes)
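Finally, a sketch of driving the A2C agent from the first file with this VectorEnv (again not part of the commit; gym's CartPole-v0 and the step budget are placeholders). Note that worker() uses each element of env_fns directly as an environment object, so constructed environments are passed in and pickled into the subprocesses.

# --- sketch, not part of the diff ---
import gym
import numpy as np

n_processes = 8
envs = VectorEnv([gym.make('CartPole-v0') for _ in range(n_processes)])
envs.seed(0)  # an int seed is broadcast, seeding every env identically

obs = envs.reset().astype(np.float32)  # chainer links expect float32
reward = np.zeros((n_processes, 1), dtype=np.float32)
done = np.zeros(n_processes, dtype=bool)
for _ in range(1000):
    # `agent` is the A2C instance from the sketch after the first file.
    action = agent.act_and_train(obs, reward, done)
    obs, r, done, _ = envs.step(action)
    obs = obs.astype(np.float32)
    reward = r.reshape(n_processes, 1).astype(np.float32)

envs.close()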