Add PPO agent #126

Merged on Nov 13, 2017 (93 commits)
Commits
c3737aa
.
toslunar Jul 24, 2017
3d823c1
debug
toslunar Jul 24, 2017
94a4cc7
..
toslunar Jul 24, 2017
ba87a4f
debug
toslunar Jul 24, 2017
c11fe88
use pessimistic loss for value function
toslunar Jul 27, 2017
c37ff02
flake8
toslunar Jul 27, 2017
5438831
fix default lr
toslunar Jul 27, 2017
164fcf5
add option not to clip vf
toslunar Jul 27, 2017
eddf28e
example ppo
toslunar Jul 28, 2017
a1cceaa
test train_ppo
toslunar Jul 28, 2017
6981ebd
add docstring
toslunar Jul 28, 2017
c31a76c
add test_ppo
toslunar Jul 31, 2017
c02b419
weaken test
toslunar Jul 31, 2017
6fb6547
Add gpu support
toslunar Jul 31, 2017
10ce5e2
debug
toslunar Aug 1, 2017
ca80279
debug gpu codes
toslunar Aug 1, 2017
aa30724
flake8
toslunar Aug 1, 2017
999a8f5
add statistics
toslunar Aug 2, 2017
7bf31c1
add A3CFFGaussian
toslunar Aug 2, 2017
406291c
gaussian policy with state-independent variance
toslunar Aug 3, 2017
7197d63
.
toslunar Aug 3, 2017
7da765b
debug
toslunar Aug 8, 2017
7320202
cp
toslunar Aug 13, 2017
0217b2f
ppo
toslunar Aug 13, 2017
b138589
Use parameters for Atari
toslunar Aug 13, 2017
5ba74c2
Add the argument phi: feature extractor func
toslunar Aug 13, 2017
02a2625
flake8
toslunar Aug 13, 2017
983e7d7
Fix citation
toslunar Aug 14, 2017
2cd2911
Debug
toslunar Aug 14, 2017
f1ffefd
Add ale/ppo
toslunar Aug 14, 2017
f1cd1e2
.
toslunar Aug 14, 2017
bd58b70
make average_v scalar
toslunar Aug 15, 2017
e9ac663
fix hyperparameters
toslunar Aug 15, 2017
9f8b112
Stop initializing empty matrices
toslunar Aug 15, 2017
fcd579d
Add comments
toslunar Aug 16, 2017
da33ed8
Remove unused codes
toslunar Aug 16, 2017
a259206
Use chainer.Parameter
toslunar Aug 16, 2017
5f4dcd1
Remove unused import
toslunar Aug 16, 2017
8049efd
Use bound_mean for Gaussian policy
toslunar Aug 16, 2017
8bc4193
Refactor
toslunar Aug 16, 2017
445ca1e
Debug
toslunar Aug 16, 2017
7503e5e
Add tests: PPO with continuous action
toslunar Aug 16, 2017
58ba8bd
Support var_type='spherical'
toslunar Aug 16, 2017
7b13f53
Add param annealing
toslunar Aug 16, 2017
38081fb
Add comments on the difference and remove TODO
toslunar Aug 16, 2017
b7895e9
Avoid 0-dim array issues
toslunar Aug 16, 2017
9bbd22d
Use env's member variable at cond
toslunar Aug 16, 2017
c9a85a8
Bugfix
toslunar Aug 21, 2017
384734f
to_cpu stats
toslunar Aug 21, 2017
092f461
to_cpu V-value
toslunar Aug 21, 2017
031cf77
Bugfix
toslunar Aug 21, 2017
365d782
Rename _F_clip to _elementwise_clip
toslunar Aug 21, 2017
a5dc30f
Disable train on act (regardless of eval)
toslunar Aug 21, 2017
0b29fb8
Use logging.basicConfig
toslunar Aug 21, 2017
af36c35
Revert "Use logging.basicConfig"
toslunar Aug 21, 2017
e8cad95
Remove unused arg
toslunar Aug 21, 2017
4d0e81e
Python 3 features
toslunar Aug 21, 2017
605c50b
Fix import order
toslunar Aug 21, 2017
8f02c43
Change default clip_eps_vf
toslunar Aug 28, 2017
6449be2
Add test parameter
toslunar Aug 28, 2017
5a642bd
hyperparameter
toslunar Aug 29, 2017
f0031a6
Change hyperparams
toslunar Sep 5, 2017
b109871
Decay clip_eps
toslunar Sep 26, 2017
ca1e1b9
standardize advantages on update
toslunar Oct 13, 2017
23b6f93
add test param
toslunar Oct 13, 2017
910b0e7
add example arg
toslunar Oct 13, 2017
e150b9d
coeff -> coef
toslunar Oct 13, 2017
85f70b4
specify non-mujoco env
toslunar Oct 13, 2017
70582ca
fix
toslunar Oct 19, 2017
e81d219
tmp
toslunar Oct 23, 2017
c15f463
add EmpiricalNormalization
toslunar Oct 23, 2017
42dbd17
add obs_filter
toslunar Oct 23, 2017
e3f0ba8
flake8
toslunar Oct 23, 2017
634628b
Merge branch 'master' into ppo-agent
toslunar Oct 23, 2017
22e157e
fix 'add obs_filter'
toslunar Oct 23, 2017
78a8c30
add bound_mean option
toslunar Oct 24, 2017
8da29e9
fix EmpiricalNormalization
toslunar Oct 25, 2017
d26d0c1
move tests
toslunar Oct 26, 2017
c6986ba
split tests
toslunar Oct 26, 2017
a5b911a
fix
toslunar Oct 26, 2017
622cab1
Add tests
toslunar Oct 26, 2017
c18e660
Hide internals
toslunar Oct 26, 2017
d64c9a0
misc
toslunar Oct 26, 2017
d0b3714
Add tests with Chainer
toslunar Oct 26, 2017
67ac092
bugfix
toslunar Oct 26, 2017
fdc0c92
misc
toslunar Oct 26, 2017
f31626c
fix
toslunar Oct 27, 2017
3cd14a5
keep Python 2 support
toslunar Oct 27, 2017
3e35647
Add docstring
toslunar Oct 30, 2017
2b6d64c
Remove unused variable
toslunar Oct 30, 2017
f3a147d
bugfix
toslunar Nov 9, 2017
5fe19fb
add docstrings
toslunar Nov 9, 2017
877fa8f
misc
toslunar Nov 9, 2017
Files changed
1 change: 1 addition & 0 deletions chainerrl/agents/__init__.py
@@ -10,6 +10,7 @@
from chainerrl.agents.pal import PAL # NOQA
from chainerrl.agents.pcl import PCL # NOQA
from chainerrl.agents.pgt import PGT # NOQA
from chainerrl.agents.ppo import PPO # NOQA
from chainerrl.agents.reinforce import REINFORCE # NOQA
from chainerrl.agents.residual_dqn import ResidualDQN # NOQA
from chainerrl.agents.sarsa import SARSA # NOQA
292 changes: 292 additions & 0 deletions chainerrl/agents/ppo.py
@@ -0,0 +1,292 @@
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases()

import copy

import chainer
from chainer import cuda
import chainer.functions as F

from chainerrl import agent
from chainerrl.misc.batch_states import batch_states


def _elementwise_clip(x, x_min, x_max):
"""Elementwise clipping

    Note: this helper is needed because chainer.functions.clip only supports
    clipping to a constant (scalar) interval.
"""
return F.minimum(F.maximum(x, x_min), x_max)


class PPO(agent.AttributeSavingMixin, agent.Agent):
"""Proximal Policy Optimization

See https://arxiv.org/abs/1707.06347

Args:
model (A3CModel): Model to train. Recurrent models are not supported.
state s |-> (pi(s, _), v(s))
optimizer (chainer.Optimizer): Optimizer used to train the model
        gpu (int): GPU device id. The CPU is used if this is None or negative.
gamma (float): Discount factor [0, 1]
lambd (float): Lambda-return factor [0, 1]
phi (callable): Feature extractor function
value_func_coef (float): Weight coefficient for loss of
value function (0, inf)
        entropy_coef (float): Weight coefficient for the entropy bonus [0, inf)
        update_interval (int): Model update interval in steps
minibatch_size (int): Minibatch size
epochs (int): Training epochs in an update
        clip_eps (float): Epsilon for pessimistic clipping of the likelihood
            ratio used to update the policy
        clip_eps_vf (float): Epsilon for pessimistic clipping of the value
            estimate used to update the value function. If ``None``, the value
            function is not clipped on updates.
standardize_advantages (bool): Use standardized advantages on updates
average_v_decay (float): Decay rate of average V, only used for
recording statistics
average_loss_decay (float): Decay rate of average loss, only used for
recording statistics
"""

saved_attributes = ['model', 'optimizer']

def __init__(self, model, optimizer,
gpu=None,
gamma=0.99,
lambd=0.95,
phi=lambda x: x,
value_func_coef=1.0,
entropy_coef=0.01,
update_interval=2048,
minibatch_size=64,
epochs=10,
clip_eps=0.2,
clip_eps_vf=None,
standardize_advantages=True,
average_v_decay=0.999, average_loss_decay=0.99,
):
self.model = model

if gpu is not None and gpu >= 0:
cuda.get_device_from_id(gpu).use()
self.model.to_gpu(device=gpu)

self.optimizer = optimizer
self.gamma = gamma
self.lambd = lambd
self.phi = phi
self.value_func_coef = value_func_coef
self.entropy_coef = entropy_coef
self.update_interval = update_interval
self.minibatch_size = minibatch_size
self.epochs = epochs
self.clip_eps = clip_eps
self.clip_eps_vf = clip_eps_vf
self.standardize_advantages = standardize_advantages

self.average_v = 0
self.average_v_decay = average_v_decay
self.average_loss_policy = 0
self.average_loss_value_func = 0
self.average_loss_entropy = 0
self.average_loss_decay = average_loss_decay

self.xp = self.model.xp
self.last_state = None

self.memory = []
self.last_episode = []

def _act(self, state):
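        # Sample an action without building a computational graph and with the
        # 'train' config disabled; the action and V(s) are returned as CPU arrays.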
xp = self.xp
with chainer.using_config('train', False):
b_state = batch_states([state], xp, self.phi)
with chainer.no_backprop_mode():
action_distrib, v = self.model(b_state)
action = action_distrib.sample()
return cuda.to_cpu(action.data)[0], cuda.to_cpu(v.data)[0]

def _train(self):
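        # Trigger an update once the stored transitions (flushed episodes plus
        # the ongoing one) reach update_interval, then clear the buffer.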
if len(self.memory) + len(self.last_episode) >= self.update_interval:
self._flush_last_episode()
self.update()
self.memory = []

def _flush_last_episode(self):
if self.last_episode:
self._compute_teacher()
self.memory.extend(self.last_episode)
self.last_episode = []

def _compute_teacher(self):
"""Estimate state values and advantages of self.last_episode

TD(lambda) estimation
"""

adv = 0.0
for transition in reversed(self.last_episode):
td_err = (
transition['reward']
+ (self.gamma * transition['nonterminal']
* transition['next_v_pred'])
- transition['v_pred']
)
adv = td_err + self.gamma * self.lambd * adv
transition['adv'] = adv
transition['v_teacher'] = adv + transition['v_pred']

def _lossfun(self,
distribs, vs_pred, log_probs,
vs_pred_old, target_log_probs,
advs, vs_teacher):
prob_ratio = F.exp(log_probs - target_log_probs)
ent = distribs.entropy

prob_ratio = F.expand_dims(prob_ratio, axis=-1)
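        # Clipped surrogate objective of PPO:
        #   L^CLIP = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)]
        # where r_t is the likelihood ratio pi(a_t|s_t) / pi_old(a_t|s_t).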
loss_policy = - F.mean(F.minimum(
prob_ratio * advs,
F.clip(prob_ratio, 1-self.clip_eps, 1+self.clip_eps) * advs))

if self.clip_eps_vf is None:
loss_value_func = F.mean_squared_error(vs_pred, vs_teacher)
else:
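            # Pessimistic (clipped) value loss: take the elementwise worse of
            # the plain squared error and the squared error of the prediction
            # clipped to within clip_eps_vf of the old value prediction.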
loss_value_func = F.mean(F.maximum(
F.square(vs_pred - vs_teacher),
F.square(_elementwise_clip(vs_pred,
vs_pred_old - self.clip_eps_vf,
vs_pred_old + self.clip_eps_vf)
- vs_teacher)
))

loss_entropy = -F.mean(ent)

# Update stats
self.average_loss_policy += (
(1 - self.average_loss_decay) *
(cuda.to_cpu(loss_policy.data) - self.average_loss_policy))
self.average_loss_value_func += (
(1 - self.average_loss_decay) *
(cuda.to_cpu(loss_value_func.data) - self.average_loss_value_func))
self.average_loss_entropy += (
(1 - self.average_loss_decay) *
(cuda.to_cpu(loss_entropy.data) - self.average_loss_entropy))

return (
loss_policy
+ self.value_func_coef * loss_value_func
+ self.entropy_coef * loss_entropy
)

def update(self):
xp = self.xp

if self.standardize_advantages:
all_advs = xp.array([b['adv'] for b in self.memory])
mean_advs = xp.mean(all_advs)
std_advs = xp.std(all_advs)

target_model = copy.deepcopy(self.model)
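        # The deep copy is a frozen snapshot of the pre-update policy; it plays
        # the role of pi_old in the likelihood ratio inside _lossfun.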
dataset_iter = chainer.iterators.SerialIterator(
self.memory, self.minibatch_size)
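        # SerialIterator shuffles the on-policy batch (shuffle=True by default)
        # and re-iterates over it for `epochs` passes of minibatch updates.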

dataset_iter.reset()
while dataset_iter.epoch < self.epochs:
batch = dataset_iter.__next__()
states = batch_states([b['state'] for b in batch], xp, self.phi)
actions = xp.array([b['action'] for b in batch])
distribs, vs_pred = self.model(states)
with chainer.no_backprop_mode():
target_distribs, _ = target_model(states)

advs = xp.array([b['adv'] for b in batch], dtype=xp.float32)
if self.standardize_advantages:
advs = (advs - mean_advs) / std_advs

self.optimizer.update(
self._lossfun,
distribs, vs_pred, distribs.log_prob(actions),
vs_pred_old=xp.array(
[b['v_pred'] for b in batch], dtype=xp.float32),
target_log_probs=target_distribs.log_prob(actions),
advs=advs,
vs_teacher=xp.array(
[b['v_teacher'] for b in batch], dtype=xp.float32),
)

def act_and_train(self, state, reward):
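        # If the model owns an obs_filter (e.g. the EmpiricalNormalization link
        # added in this PR), feed it the raw observation so that its running
        # mean/variance statistics stay up to date.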
if hasattr(self.model, 'obs_filter'):
xp = self.xp
b_state = batch_states([state], xp, self.phi)
self.model.obs_filter.experience(b_state)

action, v = self._act(state)

# Update stats
self.average_v += (
(1 - self.average_v_decay) *
(v[0] - self.average_v))
Review comment (Member): v[0] is still an ndarray when v is a cupy.ndarray. You can use float(v) to make it a scalar when its size is 1.

Reply (Author, Member): Fixed issues on 0-dim cupy.ndarrays, using cuda.to_cpu(variable_1_dim.data)[0].
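For readers following the thread, a small illustration of the 0-dim behaviour discussed above (not part of the diff; assumes cupy is installed):

import cupy as cp
from chainer import cuda

v = cp.array([1.5], dtype=cp.float32)
print(type(v[0]))          # 0-dim cupy.ndarray, still on the GPU
print(float(v))            # 1.5 -- works because v has exactly one element
print(cuda.to_cpu(v)[0])   # numpy.float32 scalar on the CPU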


if self.last_state is not None:
self.last_episode.append({
'state': self.last_state,
'action': self.last_action,
'reward': reward,
'v_pred': self.last_v,
'next_state': state,
'next_v_pred': v,
'nonterminal': 1.0})
self.last_state = state
self.last_action = action
self.last_v = v

self._train()
return action

def act(self, state):
action, v = self._act(state)

# Update stats
self.average_v += (
(1 - self.average_v_decay) *
(v[0] - self.average_v))
Review comment (Member): ditto (the same v[0] issue as above).


return action

def stop_episode_and_train(self, state, reward, done=False):
_, v = self._act(state)

assert self.last_state is not None
self.last_episode.append({
'state': self.last_state,
'action': self.last_action,
'reward': reward,
'v_pred': self.last_v,
'next_state': state,
'next_v_pred': v,
'nonterminal': 0.0 if done else 1.0})

self.last_state = None
del self.last_action
del self.last_v

self._flush_last_episode()
self.stop_episode()

def stop_episode(self):
pass

def get_statistics(self):
return [
('average_v', self.average_v),
('average_loss_policy', self.average_loss_policy),
('average_loss_value_func', self.average_loss_value_func),
('average_loss_entropy', self.average_loss_entropy),
]
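
For reference, a minimal usage sketch of the new agent (illustrative only, not part of this PR; make_model() and the gym-style env are placeholders, and the hyperparameters simply echo the constructor defaults):

import chainer
from chainerrl.agents import PPO

# `model` must be an A3CModel-style link: state -> (action distribution, V(s)).
# make_model() is a hypothetical helper; the A3CFFGaussian model added in this
# PR is one possible choice for continuous control.
model = make_model()
opt = chainer.optimizers.Adam(alpha=3e-4)
opt.setup(model)

agent = PPO(
    model, opt,
    gpu=None,                      # set a device id to run on GPU
    gamma=0.99, lambd=0.95,        # discount factor and GAE lambda
    update_interval=2048,          # transitions collected per update
    minibatch_size=64, epochs=10,  # minibatch SGD passes per update
    clip_eps=0.2,                  # likelihood-ratio clipping
    clip_eps_vf=None,              # no pessimistic value clipping
    standardize_advantages=True,
)

# ChainerRL-style interaction loop with a gym-like `env` (placeholder).
obs = env.reset()
reward = 0.0
for step in range(10 ** 5):
    action = agent.act_and_train(obs, reward)
    obs, reward, done, _ = env.step(action)
    if done:
        agent.stop_episode_and_train(obs, reward, done=True)
        obs = env.reset()
        reward = 0.0

agent.get_statistics() then reports the averaged V estimate and loss terms tracked above.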
3 changes: 2 additions & 1 deletion chainerrl/envs/abc.py
@@ -62,7 +62,8 @@ def reset(self):
return self.observe()

def step(self, action):
-        if isinstance(action, np.ndarray):
+        if isinstance(self.action_space, spaces.Box):
+            assert isinstance(action, np.ndarray)
action = np.clip(action,
self.action_space.low,
self.action_space.high)
1 change: 1 addition & 0 deletions chainerrl/links/__init__.py
@@ -1,5 +1,6 @@
from chainerrl.links.dqn_head import NatureDQNHead # NOQA
from chainerrl.links.dqn_head import NIPSDQNHead # NOQA
from chainerrl.links.empirical_normalization import EmpiricalNormalization # NOQA
from chainerrl.links.mlp import MLP # NOQA
from chainerrl.links.mlp_bn import MLPBN # NOQA
from chainerrl.links.sequence import Sequence # NOQA