
Merge 695734a into 4151087
muupan committed Nov 7, 2019
2 parents 4151087 + 695734a, commit ecf2a65
Showing 12 changed files with 160 additions and 22 deletions.
32 changes: 32 additions & 0 deletions .pfnci/config.pbtxt
@@ -18,6 +18,38 @@ configs {
   }
 }
+configs {
+  key: "chainerrl.py3.gpu.slow"
+  value {
+    requirement {
+      cpu: 10
+      memory: 30
+      gpu: 1
+    }
+    time_limit {
+      seconds: 2400
+    }
+    environment_variables { key: "GPU" value: "1" }
+    environment_variables { key: "SLOW" value: "1" }
+    command: "bash .pfnci/script.sh py3.gpu"
+  }
+}
+configs {
+  key: "chainerrl.py3.cpu.slow"
+  value {
+    requirement {
+      cpu: 10
+      memory: 30
+    }
+    time_limit {
+      seconds: 2400
+    }
+    environment_variables { key: "SLOW" value: "1" }
+    command: "bash .pfnci/script.sh py3.cpu"
+  }
+}
 # ChainerRL CPU-only unit tests.
 # URL: https://ci.preferred.jp/chainerrl.py3.cpu/
 configs {
23 changes: 23 additions & 0 deletions .pfnci/hint.pbtxt
@@ -4,9 +4,32 @@
 # https://github.com/chainer/xpytest/blob/master/proto/test_case.proto

 # Slow tests take 60+ seconds.
+rules { name: "agents_tests/test_ddpg.py" xdist: 4 deadline: 480 }
+rules { name: "agents_tests/test_reinforce.py" xdist: 4 deadline: 480 }
+rules { name: "agents_tests/test_pcl.py" xdist: 4 deadline: 480 }
+rules { name: "agents_tests/test_a2c.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_al.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_categorical_dqn.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_double_categorical_dqn.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_double_dqn.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_double_iqn.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_double_pal.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_dpp.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_dqn.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_iqn.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_sarsa.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_ppo.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_pal.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_trpo.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_a3c.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_agents.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_residual_dqn.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_soft_actor_critic.py" xdist: 4 deadline: 240 }
+rules { name: "agents_tests/test_td3.py" xdist: 4 deadline: 240 }
+rules { name: "misc_tests/test_random.py" xdist: 4 deadline: 240 }
 # Slow tests take 10+ seconds.
 rules { name: "agents_tests/test_acer.py" }
 rules { name: "agents_tests/test_ale.py" }
7 changes: 6 additions & 1 deletion .pfnci/run.sh
@@ -25,6 +25,7 @@ TARGET="$1"
 : "${XPYTEST_NUM_THREADS:=$(nproc)}"
 : "${PYTHON=python3}"
 : "${CHAINER=}"
+: "${SLOW:=0}"

 # Use multi-process service to prevent GPU flakiness caused by running many
 # processes on a GPU. Specifically, it seems running more than 16 processes
@@ -39,7 +40,11 @@ fi
 ################################################################################

 main() {
-  marker='not slow'
+  if (( !SLOW )); then
+    marker='not slow'
+  else
+    marker='slow'
+  fi
   if (( !GPU )); then
     marker+=' and not gpu'
     bucket=1
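
Note (not part of the diff): the marker string built in main() is a pytest-style marker expression, so SLOW=1 selects tests tagged as slow and the default run excludes them. A minimal sketch of how a test opts in, assuming a pytest "slow" marker; the decorator spelling pytest.mark.slow and the test names below are illustrative, not taken from this commit:

    import pytest

    @pytest.mark.slow
    def test_agent_end_to_end_training():
        ...  # long-running case; collected only when the marker expression is 'slow'

    def test_agent_single_update():
        ...  # fast case; collected under the default 'not slow' expression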
1 change: 1 addition & 0 deletions .pfnci/script.sh
@@ -61,6 +61,7 @@ main() {
       ;;
   esac
   docker_args+=(--env="CHAINER=${chainer_version}")
+  docker_args+=(--env="SLOW=${SLOW:-0}")

   case "${TARGET}" in
     py2.* ) docker_args+=(--env="PYTHON=python");;
6 changes: 3 additions & 3 deletions chainerrl/agents/a3c.py
@@ -218,16 +218,16 @@ def update(self, statevar):
         total_loss = F.squeeze(pi_loss) + F.squeeze(v_loss)

         # Compute gradients using thread-specific model
-        self.model.zerograds()
+        self.model.cleargrads()
         total_loss.backward()
         # Copy the gradients to the globally shared model
-        self.shared_model.zerograds()
         copy_param.copy_grad(
             target_link=self.shared_model, source_link=self.model)
         # Update the globally shared model
         if self.process_idx == 0:
             norm = sum(np.sum(np.square(param.grad))
-                       for param in self.optimizer.target.params())
+                       for param in self.optimizer.target.params()
+                       if param.grad is not None)
             logger.debug('grad norm:%s', norm)
         self.optimizer.update()
         if self.process_idx == 0:
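
Background for the zerograds/cleargrads switch in this and the following agents (a sketch, not part of the diff): cleargrads() resets gradients to None instead of filling them with zeros, so any parameter the backward pass does not reach keeps grad is None, and the added "if param.grad is not None" filter keeps the grad-norm debug sum from failing on such parameters. A minimal Chainer illustration, assuming a throwaway Linear link:

    import chainer.functions as F
    import chainer.links as L
    import numpy as np

    link = L.Linear(1, 5)
    link.cleargrads()                   # unlike zerograds(), this sets grads to None
    assert link.W.grad is None and link.b.grad is None

    x = np.zeros((1, 1), dtype=np.float32)
    F.sum(link(x)).backward()           # gradients appear only where backprop reaches

    # Same pattern as the updated norm computation: skip parameters with no grad.
    norm = sum(np.sum(np.square(p.grad)) for p in link.params() if p.grad is not None)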
6 changes: 3 additions & 3 deletions chainerrl/agents/acer.py
@@ -501,16 +501,16 @@ def update(self, t_start, t_stop, R, states, actions, rewards, values,
             avg_action_distribs=avg_action_distribs)

         # Compute gradients using thread-specific model
-        self.model.zerograds()
+        self.model.cleargrads()
         F.squeeze(total_loss).backward()
         # Copy the gradients to the globally shared model
-        self.shared_model.zerograds()
         copy_param.copy_grad(
             target_link=self.shared_model, source_link=self.model)
         # Update the globally shared model
         if self.process_idx == 0:
             norm = sum(np.sum(np.square(param.grad))
-                       for param in self.optimizer.target.params())
+                       for param in self.optimizer.target.params()
+                       if param.grad is not None)
             self.logger.debug('grad norm:%s', norm)
         self.optimizer.update()

3 changes: 1 addition & 2 deletions chainerrl/agents/nsq.py
@@ -114,10 +114,9 @@ def update(self, statevar):
         # loss /= self.t - self.t_start

         # Compute gradients using thread-specific model
-        self.q_function.zerograds()
+        self.q_function.cleargrads()
         loss.backward()
         # Copy the gradients to the globally shared model
-        self.shared_q_function.zerograds()
         copy_param.copy_grad(self.shared_q_function, self.q_function)
         # Update the globally shared model
         self.optimizer.update()
6 changes: 3 additions & 3 deletions chainerrl/agents/pcl.py
@@ -247,17 +247,17 @@ def update(self, loss):
             (asfloat(loss) - self.average_loss))

         # Compute gradients using thread-specific model
-        self.model.zerograds()
+        self.model.cleargrads()
         F.squeeze(loss).backward()
         if self.train_async:
             # Copy the gradients to the globally shared model
-            self.shared_model.zerograds()
             copy_param.copy_grad(
                 target_link=self.shared_model, source_link=self.model)
         if self.process_idx == 0:
             xp = self.xp
             norm = sum(xp.sum(xp.square(param.grad))
-                       for param in self.optimizer.target.params())
+                       for param in self.optimizer.target.params()
+                       if param.grad is not None)
             self.logger.debug('grad norm:%s', norm)
         self.optimizer.update()

11 changes: 10 additions & 1 deletion chainerrl/misc/copy_param.py
@@ -58,7 +58,16 @@ def copy_grad(target_link, source_link):
     """Copy gradients of a link to another link."""
     target_params = dict(target_link.namedparams())
     for param_name, param in source_link.namedparams():
-        target_params[param_name].grad[...] = param.grad
+        if target_params[param_name].grad is None:
+            if param.grad is None:
+                pass
+            else:
+                target_params[param_name].grad = param.grad.copy()
+        else:
+            if param.grad is None:
+                target_params[param_name].grad = None
+            else:
+                target_params[param_name].grad[...] = param.grad


 def synchronize_parameters(src, dst, method, tau=None):
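
For orientation, a minimal usage sketch of the updated copy_grad (not from the repository; the two Linear links stand in for a thread-local model and the globally shared model used by the async agents above). With the None cases handled inside copy_grad, callers no longer need to zero the target's gradients before copying:

    import chainer.functions as F
    import chainer.links as L
    import numpy as np

    from chainerrl.misc import copy_param

    local_model = L.Linear(1, 5)     # stands in for the thread-specific model
    shared_model = L.Linear(1, 5)    # stands in for the globally shared model

    # Compute gradients on the local model only.
    local_model.cleargrads()
    x = np.random.normal(size=(1, 1)).astype(np.float32)
    F.sum(local_model(x)).backward()

    # Copy them onto the shared model; copy_grad overwrites or clears the
    # target's gradients itself, covering the None/non-None combinations.
    copy_param.copy_grad(target_link=shared_model, source_link=local_model)
    np.testing.assert_allclose(shared_model.W.grad, local_model.W.grad)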
2 changes: 1 addition & 1 deletion tests/agents_tests/basetest_agents.py
@@ -48,6 +48,6 @@ def test_train(self):
         train_agent(
             agent=agent,
             env=self.env,
-            steps=2000,
+            steps=200,
             outdir=tempfile.mkdtemp(),
             max_episode_len=10)
16 changes: 8 additions & 8 deletions tests/agents_tests/test_agents.py
@@ -39,8 +39,8 @@ def create_deterministic_policy_for_env(env):
     return policies.FCDeterministicPolicy(
         n_input_channels=ndim_obs,
         action_size=env.action_space.low.size,
-        n_hidden_channels=200,
-        n_hidden_layers=2,
+        n_hidden_channels=10,
+        n_hidden_layers=1,
         bound_action=False)


@@ -51,14 +51,14 @@ def create_state_q_function_for_env(env):
         return q_functions.FCStateQFunctionWithDiscreteAction(
             ndim_obs=ndim_obs,
             n_actions=env.action_space.n,
-            n_hidden_channels=200,
-            n_hidden_layers=2)
+            n_hidden_channels=10,
+            n_hidden_layers=1)
     elif isinstance(env.action_space, gym.spaces.Box):
         return q_functions.FCQuadraticStateQFunction(
             n_input_channels=ndim_obs,
             n_dim_action=env.action_space.low.size,
-            n_hidden_channels=200,
-            n_hidden_layers=2,
+            n_hidden_channels=10,
+            n_hidden_layers=1,
             action_space=env.action_space)
     else:
         raise NotImplementedError()
@@ -71,8 +71,8 @@ def create_state_action_q_function_for_env(env):
     return q_functions.FCSAQFunction(
         n_dim_obs=ndim_obs,
         n_dim_action=env.action_space.low.size,
-        n_hidden_channels=200,
-        n_hidden_layers=2)
+        n_hidden_channels=10,
+        n_hidden_layers=1)


 def create_v_function_for_env(env):
69 changes: 69 additions & 0 deletions tests/misc_tests/test_copy_param.py
@@ -7,6 +7,7 @@
 import unittest

 import chainer
+from chainer import functions as F
 from chainer import links as L
 import numpy as np

@@ -99,3 +100,71 @@ def test_soft_copy_param_type_check(self):

         with self.assertRaises(TypeError):
             copy_param.soft_copy_param(target_link=a, source_link=b, tau=0.1)
+
+    def test_copy_grad(self):
+
+        def set_random_grad(link):
+            link.cleargrads()
+            x = np.random.normal(size=(1, 1)).astype(np.float32)
+            y = link(x) * np.random.normal()
+            F.sum(y).backward()
+
+        # When source is not None and target is None
+        a = L.Linear(1, 5)
+        b = L.Linear(1, 5)
+        set_random_grad(a)
+        b.cleargrads()
+        assert a.W.grad is not None
+        assert a.b.grad is not None
+        assert b.W.grad is None
+        assert b.b.grad is None
+        copy_param.copy_grad(target_link=b, source_link=a)
+        np.testing.assert_almost_equal(a.W.grad, b.W.grad)
+        np.testing.assert_almost_equal(a.b.grad, b.b.grad)
+        assert a.W.grad is not b.W.grad
+        assert a.b.grad is not b.b.grad
+
+        # When both are not None
+        a = L.Linear(1, 5)
+        b = L.Linear(1, 5)
+        set_random_grad(a)
+        set_random_grad(b)
+        assert a.W.grad is not None
+        assert a.b.grad is not None
+        assert b.W.grad is not None
+        assert b.b.grad is not None
+        copy_param.copy_grad(target_link=b, source_link=a)
+        np.testing.assert_almost_equal(a.W.grad, b.W.grad)
+        np.testing.assert_almost_equal(a.b.grad, b.b.grad)
+        assert a.W.grad is not b.W.grad
+        assert a.b.grad is not b.b.grad
+
+        # When source is None and target is not None
+        a = L.Linear(1, 5)
+        b = L.Linear(1, 5)
+        a.cleargrads()
+        set_random_grad(b)
+        assert a.W.grad is None
+        assert a.b.grad is None
+        assert b.W.grad is not None
+        assert b.b.grad is not None
+        copy_param.copy_grad(target_link=b, source_link=a)
+        assert a.W.grad is None
+        assert a.b.grad is None
+        assert b.W.grad is None
+        assert b.b.grad is None
+
+        # When both are None
+        a = L.Linear(1, 5)
+        b = L.Linear(1, 5)
+        a.cleargrads()
+        b.cleargrads()
+        assert a.W.grad is None
+        assert a.b.grad is None
+        assert b.W.grad is None
+        assert b.b.grad is None
+        copy_param.copy_grad(target_link=b, source_link=a)
+        assert a.W.grad is None
+        assert a.b.grad is None
+        assert b.W.grad is None
+        assert b.b.grad is None
