Merge pull request #100 from toslunar/chainerv2
Use Chainer v2
muupan committed Jun 7, 2017
2 parents dda3c17 + 3b5366a commit 6ccc61d
Showing 39 changed files with 356 additions and 222 deletions.
19 changes: 18 additions & 1 deletion chainerrl/agent.py
@@ -13,10 +13,22 @@

from chainer import serializers
from future.utils import with_metaclass
import numpy
import warnings

from chainerrl.misc.makedirs import makedirs


def load_npz_no_strict(filename, obj):
try:
serializers.load_npz(filename, obj)
except KeyError as e:
warnings.warn(repr(e))
with numpy.load(filename) as f:
d = serializers.NpzDeserializer(f, strict=False)
d.load(obj)


class Agent(with_metaclass(ABCMeta, object)):
"""Abstract agent class."""

@@ -107,7 +119,12 @@ def save(self, dirname):
def load(self, dirname):
"""Load internal states."""
for attr in self.saved_attributes:
serializers.load_npz(
"""Fix Chainer Issue #2772
In Chainer v2, a (stateful) optimizer cannot be loaded from
an npz saved before the first update.
"""
load_npz_no_strict(
os.path.join(dirname, '{}.npz'.format(attr)),
getattr(self, attr))

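Context for the change above: per Chainer issue #2772, a stateful optimizer saved before its first update cannot always be loaded back with strict key checking, so load_npz_no_strict falls back to a non-strict NpzDeserializer when load_npz raises KeyError. A minimal usage sketch; the Linear/Adam setup and the file name are illustrative, not part of this commit.

import chainer.links as L
from chainer import optimizers, serializers

from chainerrl.agent import load_npz_no_strict

model = L.Linear(4, 2)
optimizer = optimizers.Adam()
optimizer.setup(model)

# Saved before optimizer.update() has ever been called.
serializers.save_npz('optimizer.npz', optimizer)

# If strict loading raises KeyError (as it can for optimizers saved
# before their first update in Chainer v2), the helper retries with
# strict=False instead of failing.
load_npz_no_strict('optimizer.npz', optimizer)
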
8 changes: 4 additions & 4 deletions chainerrl/agents/a3c.py
@@ -226,7 +226,8 @@ def update(self, statevar):
target_link=self.shared_model, source_link=self.model)
# Update the globally shared model
if self.process_idx == 0:
norm = self.optimizer.compute_grads_norm()
norm = sum(np.sum(np.square(param.grad))
for param in self.optimizer.target.params())
logger.debug('grad norm:%s', norm)
self.optimizer.update()
if self.process_idx == 0:
@@ -255,13 +256,12 @@ def act_and_train(self, state, reward):

self.past_states[self.t] = statevar
pout, vout = self.model.pi_and_v(statevar)
action = pout.sample()
action.creator = None # Do not backprop through sampled actions
action = pout.sample().data # Do not backprop through sampled actions
self.past_action_log_prob[self.t] = pout.log_prob(action)
self.past_action_entropy[self.t] = pout.entropy
self.past_values[self.t] = vout
self.t += 1
action = action.data[0]
action = action[0]
if self.process_idx == 0:
logger.debug('t:%s r:%s a:%s pout:%s',
self.t, reward, action, pout)
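With Chainer v2 the optimizer used here no longer provides the compute_grads_norm() method, so a3c.py (and acer.py below) log the sum of squared parameter gradients computed directly from optimizer.target.params(); sampled actions are likewise detached by taking .data instead of clearing .creator. A self-contained version of the logging expression; the helper name and the None-gradient guard are additions of this sketch, not of the diff.

import numpy as np

def squared_grad_norm(optimizer):
    # Sum of squared gradients over all parameters of optimizer.target,
    # mirroring the debug-logging expression above. Parameters without
    # gradients are skipped, an extra guard the diff does not need.
    return sum(np.sum(np.square(p.grad))
               for p in optimizer.target.params()
               if p.grad is not None)
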
7 changes: 4 additions & 3 deletions chainerrl/agents/acer.py
@@ -194,10 +194,10 @@ def __init__(self, shared, pi, q):
def backprop_truncated(*variables):
backup = [v.creator for v in variables]
for v in variables:
v.creator = None
v.unchain()
yield
for v, backup_creator in zip(variables, backup):
v.creator = backup_creator
v.set_creator(backup_creator)


def compute_loss_with_kl_constraint(distrib, another_distrib, original_loss,
@@ -523,7 +523,8 @@ def update(self, t_start, t_stop, R, states, actions, rewards, values,
target_link=self.shared_model, source_link=self.model)
# Update the globally shared model
if self.process_idx == 0:
norm = self.optimizer.compute_grads_norm()
norm = sum(np.sum(np.square(param.grad))
for param in self.optimizer.target.params())
self.logger.debug('grad norm:%s', norm)
self.optimizer.update()

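backprop_truncated keeps its behaviour but now goes through the Variable API: unchain() detaches each variable from its creator and set_creator() restores the link afterwards. A self-contained sketch of the same pattern; the contextlib decorator, the toy graph, and the variable names are mine, not part of the diff.

import contextlib

import chainer
import chainer.functions as F
import numpy as np

@contextlib.contextmanager
def backprop_truncated(*variables):
    # Temporarily cut the graph at the given variables, then reattach.
    backup = [v.creator for v in variables]
    for v in variables:
        v.unchain()
    yield
    for v, creator in zip(variables, backup):
        v.set_creator(creator)

x = chainer.Variable(np.ones((1, 3), dtype=np.float32))
h = F.relu(x)
y = F.sum(h)
with backprop_truncated(h):
    y.backward()
print(x.grad)  # None: backprop stopped at h while inside the block
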
6 changes: 3 additions & 3 deletions chainerrl/agents/al.py
@@ -34,7 +34,7 @@ def _compute_y_and_t(self, exp_batch, gamma):
batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])

qout = self.q_function(batch_state, test=False)
qout = self.q_function(batch_state)

batch_actions = exp_batch['action']

@@ -43,13 +43,13 @@ def _compute_y_and_t(self, exp_batch, gamma):
# Compute target values

with chainer.no_backprop_mode():
target_qout = self.target_q_function(batch_state, test=True)
target_qout = self.target_q_function(batch_state)

batch_next_state = exp_batch['next_state']

with state_kept(self.target_q_function):
target_next_qout = self.target_q_function(
batch_next_state, test=True)
batch_next_state)
next_q_max = F.reshape(target_next_qout.max, (batch_size,))

batch_rewards = exp_batch['reward']
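The test=True/False keyword arguments disappear here and in the other agents because Chainer v2 replaced the per-call flag with a global train configuration; callers now request evaluation behaviour with chainer.using_config('train', False) and keep target computations out of the graph with chainer.no_backprop_mode(). A minimal sketch of the calling pattern; evaluate_target and its arguments are hypothetical stand-ins, not functions from this repository.

import chainer

def evaluate_target(target_q_function, batch_state):
    # Chainer v2: no test= keyword. Test-time behaviour (dropout off,
    # BatchNormalization using running statistics) comes from the global
    # config, and no_backprop_mode() avoids building a graph for the
    # target network's forward pass.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        return target_q_function(batch_state)

Because the flag is global, the choice of mode moves from each Q-function's signature to the call site, which is why every test= argument in these files could simply be dropped.
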
43 changes: 29 additions & 14 deletions chainerrl/agents/ddpg.py
@@ -24,6 +24,16 @@
from chainerrl.replay_buffer import ReplayUpdater


def disable_train(chain):
call_orig = chain.__call__

def call_test(self, x):
with chainer.using_config('train', False):
return call_orig(self, x)

chain.__call__ = call_test


class DDPGModel(chainer.Chain, RecurrentChainMixin):

def __init__(self, policy, q_func):
@@ -127,6 +137,8 @@ def __init__(self, model, actor_optimizer, critic_optimizer, replay_buffer,
self.last_state = None
self.last_action = None
self.target_model = copy.deepcopy(self.model)
disable_train(self.target_model['q_function'])
disable_train(self.target_model['policy'])
self.average_q = 0
self.average_actor_loss = 0.0
self.average_critic_loss = 0.0
@@ -172,25 +184,24 @@ def compute_critic_loss(self, batch):
with chainer.no_backprop_mode():
# Target policy observes s_{t+1}
next_actions = self.target_policy(
batch_next_state, test=True).sample()
batch_next_state).sample()

# Q(s_{t+1}, mu(a_{t+1})) is evaluated.
# This should not affect the internal state of Q.
with state_kept(self.target_q_function):
next_q = self.target_q_function(batch_next_state, next_actions,
test=True)
next_q = self.target_q_function(batch_next_state, next_actions)

# Target Q-function observes s_{t+1} and a_{t+1}
if isinstance(self.target_q_function, Recurrent):
self.target_q_function.update_state(
batch_next_state, batch_next_actions, test=True)
batch_next_state, batch_next_actions)

target_q = batch_rewards + self.gamma * \
(1.0 - batch_terminal) * F.reshape(next_q, (batchsize,))

# Estimated Q-function observes s_t and a_t
predict_q = F.reshape(
self.q_function(batch_state, batch_actions, test=False),
self.q_function(batch_state, batch_actions),
(batchsize,))

loss = F.mean_squared_error(target_q, predict_q)
@@ -218,16 +229,19 @@ def compute_actor_loss(self, batch):
batch_size = len(batch_action)

# Estimated policy observes s_t
onpolicy_actions = self.policy(batch_state, test=False).sample()
onpolicy_actions = self.policy(batch_state).sample()

# Q(s_t, mu(s_t)) is evaluated.
# This should not affect the internal state of Q.
with state_kept(self.q_function):
q = self.q_function(batch_state, onpolicy_actions, test=False)
q = self.q_function(batch_state, onpolicy_actions)

# Estimated Q-function observes s_t and a_t
if isinstance(self.q_function, Recurrent):
self.q_function.update_state(batch_state, batch_action, test=False)
self.q_function.update_state(batch_state, batch_action)

# Avoid the numpy #9165 bug (see also: chainer #2744)
q = q[:, :]

# Since we want to maximize Q, loss is negation of Q
loss = - F.sum(q) / batch_size
@@ -268,8 +282,8 @@ def update_from_episodes(self, episodes, errors_out=None):
# Since the target model is evaluated one-step ahead,
# its internal states need to be updated
self.target_q_function.update_state(
batches[0]['state'], batches[0]['action'], test=True)
self.target_policy(batches[0]['state'], test=True)
batches[0]['state'], batches[0]['action'])
self.target_policy(batches[0]['state'])

# Update critic through time
critic_loss = 0
@@ -317,10 +331,11 @@ def act_and_train(self, state, reward):

def act(self, state):

s = self.batch_states([state], self.xp, self.phi)
action = self.policy(s, test=True).sample()
# Q is not needed here, but log it just for information
q = self.q_function(s, action, test=True)
with chainer.using_config('train', False):
s = self.batch_states([state], self.xp, self.phi)
action = self.policy(s).sample()
# Q is not needed here, but log it just for information
q = self.q_function(s, action)

# Update stats
self.average_q *= self.average_q_decay
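Because the target networks can no longer be told test=True per call, ddpg.py patches their __call__ to run under train=False (disable_train above) and wraps act() in chainer.using_config('train', False); the q = q[:, :] re-slice works around numpy #9165 / chainer #2744, as the added comment notes. One way to express the same always-evaluate-in-test-mode idea as a standalone wrapper; this is an illustration of the intent, not the code added above.

import chainer

def always_test_mode(chain):
    # Return a callable that forwards to `chain` with train=False,
    # so layers such as BatchNormalization behave as at test time.
    def call_test(*args, **kwargs):
        with chainer.using_config('train', False):
            return chain(*args, **kwargs)
    return call_test

# Usage sketch: q_test = always_test_mode(target_q_function); q_test(s, a)
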
9 changes: 6 additions & 3 deletions chainerrl/agents/double_dqn.py
@@ -5,6 +5,8 @@
from future import standard_library
standard_library.install_aliases()

import chainer

from chainerrl.agents import dqn
from chainerrl.recurrent import state_kept

@@ -19,10 +21,11 @@ def _compute_target_values(self, exp_batch, gamma):

batch_next_state = exp_batch['next_state']

with state_kept(self.q_function):
next_qout = self.q_function(batch_next_state, test=True)
with chainer.using_config('train', False):
with state_kept(self.q_function):
next_qout = self.q_function(batch_next_state)

target_next_qout = self.target_q_function(batch_next_state, test=True)
target_next_qout = self.target_q_function(batch_next_state)

next_q_max = target_next_qout.evaluate_actions(
next_qout.greedy_actions)
8 changes: 4 additions & 4 deletions chainerrl/agents/double_pal.py
@@ -20,24 +20,24 @@ def _compute_y_and_t(self, exp_batch, gamma):
batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])

qout = self.q_function(batch_state, test=False)
qout = self.q_function(batch_state)

batch_actions = exp_batch['action']
batch_q = qout.evaluate_actions(batch_actions)

# Compute target values

with chainer.no_backprop_mode():
target_qout = self.target_q_function(batch_state, test=True)
target_qout = self.target_q_function(batch_state)

batch_next_state = exp_batch['next_state']

with state_kept(self.q_function):
next_qout = self.q_function(batch_next_state, test=False)
next_qout = self.q_function(batch_next_state)

with state_kept(self.target_q_function):
target_next_qout = self.target_q_function(
batch_next_state, test=True)
batch_next_state)
next_q_max = F.reshape(target_next_qout.evaluate_actions(
next_qout.greedy_actions), (batch_size,))

6 changes: 3 additions & 3 deletions chainerrl/agents/dpp.py
@@ -30,7 +30,7 @@ def _compute_target_values(self, exp_batch, gamma):

batch_next_state = exp_batch['next_state']

target_next_qout = self.target_q_function(batch_next_state, test=True)
target_next_qout = self.target_q_function(batch_next_state)
next_q_expect = self._l_operator(target_next_qout)

batch_rewards = exp_batch['reward']
@@ -44,7 +44,7 @@ def _compute_y_and_t(self, exp_batch, gamma):
batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])

qout = self.q_function(batch_state, test=False)
qout = self.q_function(batch_state)

batch_actions = exp_batch['action']
# Q(s_t,a_t)
Expand All @@ -53,7 +53,7 @@ def _compute_y_and_t(self, exp_batch, gamma):

with chainer.no_backprop_mode():
# Compute target values
target_qout = self.target_q_function(batch_state, test=True)
target_qout = self.target_q_function(batch_state)

# Q'(s_t,a_t)
target_q = F.reshape(target_qout.evaluate_actions(