
Commit

Merge pull request #410 from muupan/copy-scalar-param
Make copy_param support scalar parameters
muupan committed Mar 11, 2019
2 parents c8d5cbb + f1f0df0 commit 3522f58
Showing 2 changed files with 44 additions and 10 deletions.
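
Background on the change (not part of the commit itself): NumPy's basic slice assignment x[:] = value requires the array to have at least one dimension, so it fails on 0-d (scalar) parameter arrays, whereas ellipsis assignment x[...] = value works for arrays of any rank. A minimal illustrative sketch of the difference:

import numpy as np

scalar = np.array(1)   # 0-dimensional array, like a scalar chainer.Parameter
# scalar[:] = 2        # raises IndexError: too many indices for a 0-d array
scalar[...] = 2        # ellipsis indexing assigns in place for any rank
assert scalar == 2

vector = np.ones(3)
vector[...] = 0        # also valid for ordinary arrays, so existing behaviour is unchanged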
20 changes: 10 additions & 10 deletions chainerrl/misc/copy_param.py
@@ -19,15 +19,15 @@ def copy_param(target_link, source_link):
                 'not initialized.\nPlease try to forward dummy input '
                 'beforehand to determine parameter shape of the model.'.format(
                     param_name))
-        target_params[param_name].array[:] = param.array
+        target_params[param_name].array[...] = param.array
 
     # Copy Batch Normalization's statistics
     target_links = dict(target_link.namedlinks())
     for link_name, link in source_link.namedlinks():
         if isinstance(link, L.BatchNormalization):
             target_bn = target_links[link_name]
-            target_bn.avg_mean[:] = link.avg_mean
-            target_bn.avg_var[:] = link.avg_var
+            target_bn.avg_mean[...] = link.avg_mean
+            target_bn.avg_var[...] = link.avg_var
 
 
 def soft_copy_param(target_link, source_link, tau):
@@ -40,25 +40,25 @@ def soft_copy_param(target_link, source_link, tau):
                 'not initialized.\nPlease try to forward dummy input '
                 'beforehand to determine parameter shape of the model.'.format(
                     param_name))
-        target_params[param_name].array[:] *= (1 - tau)
-        target_params[param_name].array[:] += tau * param.array
+        target_params[param_name].array[...] *= (1 - tau)
+        target_params[param_name].array[...] += tau * param.array
 
     # Soft-copy Batch Normalization's statistics
     target_links = dict(target_link.namedlinks())
     for link_name, link in source_link.namedlinks():
         if isinstance(link, L.BatchNormalization):
             target_bn = target_links[link_name]
-            target_bn.avg_mean[:] *= (1 - tau)
-            target_bn.avg_mean[:] += tau * link.avg_mean
-            target_bn.avg_var[:] *= (1 - tau)
-            target_bn.avg_var[:] += tau * link.avg_var
+            target_bn.avg_mean[...] *= (1 - tau)
+            target_bn.avg_mean[...] += tau * link.avg_mean
+            target_bn.avg_var[...] *= (1 - tau)
+            target_bn.avg_var[...] += tau * link.avg_var
 
 
 def copy_grad(target_link, source_link):
     """Copy gradients of a link to another link."""
     target_params = dict(target_link.namedparams())
     for param_name, param in source_link.namedparams():
-        target_params[param_name].grad[:] = param.grad
+        target_params[param_name].grad[...] = param.grad
 
 
 def synchronize_parameters(src, dst, method, tau=None):
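
Aside (illustrative, not part of the diff): soft_copy_param applies the soft update target <- (1 - tau) * target + tau * source elementwise, and with ellipsis indexing it now also works when a parameter is a 0-d array, e.g. a single learned scalar. A minimal sketch mirroring the new tests below; the chain and attribute names here are made up:

import chainer
import numpy as np
from chainerrl.misc import copy_param

online = chainer.Chain()
with online.init_scope():
    online.temperature = chainer.Parameter(np.array(1.0))   # hypothetical scalar parameter
target = chainer.Chain()
with target.init_scope():
    target.temperature = chainer.Parameter(np.array(0.0))

# target <- (1 - tau) * target + tau * online
copy_param.soft_copy_param(target_link=target, source_link=online, tau=0.01)
print(target.temperature.array)  # 0.01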
34 changes: 34 additions & 0 deletions tests/misc_tests/test_copy_param.py
@@ -32,6 +32,21 @@ def test_copy_param(self):
         self.assertEqual(a_out_new, b_out)
         self.assertEqual(b_out_new, b_out)
 
+    def test_copy_param_scalar(self):
+        a = chainer.Chain()
+        with a.init_scope():
+            a.p = chainer.Parameter(np.array(1))
+        b = chainer.Chain()
+        with b.init_scope():
+            b.p = chainer.Parameter(np.array(2))
+
+        self.assertNotEqual(a.p.array, b.p.array)
+
+        # Copy b's parameters to a
+        copy_param.copy_param(a, b)
+
+        self.assertEqual(a.p.array, b.p.array)
+
     def test_copy_param_type_check(self):
         a = L.Linear(None, 5)
         b = L.Linear(1, 5)
@@ -59,6 +74,25 @@ def test_soft_copy_param(self):
         np.testing.assert_almost_equal(a.W.array, np.full(a.W.shape, 0.595))
         np.testing.assert_almost_equal(b.W.array, np.full(b.W.shape, 1.0))
 
+    def test_soft_copy_param_scalar(self):
+        a = chainer.Chain()
+        with a.init_scope():
+            a.p = chainer.Parameter(np.array(0.5))
+        b = chainer.Chain()
+        with b.init_scope():
+            b.p = chainer.Parameter(np.array(1))
+
+        # a = (1 - tau) * a + tau * b
+        copy_param.soft_copy_param(target_link=a, source_link=b, tau=0.1)
+
+        np.testing.assert_almost_equal(a.p.array, 0.55)
+        np.testing.assert_almost_equal(b.p.array, 1.0)
+
+        copy_param.soft_copy_param(target_link=a, source_link=b, tau=0.1)
+
+        np.testing.assert_almost_equal(a.p.array, 0.595)
+        np.testing.assert_almost_equal(b.p.array, 1.0)
+
     def test_soft_copy_param_type_check(self):
         a = L.Linear(None, 5)
         b = L.Linear(1, 5)
