Skip to content

Commit

Permalink
Merge pull request #8445 from toslunar/bn-uninit-persistent
Browse files Browse the repository at this point in the history
Register uninitialized persistents
  • Loading branch information
takagi committed Nov 14, 2019
2 parents 7bae43e + 108f1df commit d348175
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
11 changes: 6 additions & 5 deletions chainer/links/normalization/batch_normalization.py
Expand Up @@ -190,8 +190,6 @@ class BatchNormalization(link.Link):

gamma = None
beta = None
avg_mean = None
avg_var = None

def __init__(self, size=None, decay=0.9, eps=2e-5, dtype=None,
use_gamma=True, use_beta=True,
Expand Down Expand Up @@ -229,16 +227,19 @@ def __init__(self, size=None, decay=0.9, eps=2e-5, dtype=None,
beta_initializer.dtype = self._highprec_dtype
self.beta = variable.Parameter(beta_initializer)

if size is not None:
if size is None:
self.avg_mean = None
self.avg_var = None
else:
self._initialize_params(size)
self.register_persistent('avg_mean')
self.register_persistent('avg_var')

def _initialize_params(self, shape):
self.avg_mean = self._init_array(self._initial_avg_mean, 0, shape)
self._initial_avg_mean = None
self.register_persistent('avg_mean')
self.avg_var = self._init_array(self._initial_avg_var, 1, shape)
self._initial_avg_var = None
self.register_persistent('avg_var')
if self.gamma is not None:
self.gamma.initialize(shape)
if self.beta is not None:
Expand Down
Expand Up @@ -656,4 +656,48 @@ def test_lazy_initialization_with_non_zero_current_cuda_device(self):
assert backend.GpuDevice.from_array(bn.avg_var) == device


@testing.parameterize(*testing.product({
'x_shape,bn_kwargs': [
((4, 3), {'axis': (0,)}),
((4, 3), {'size': (3,)}),
],
}))
class TestSerialize(unittest.TestCase):

def create_link(self):
return links.BatchNormalization(**self.bn_kwargs)

def train_link(self, bn):
x = numpy.random.rand(*self.x_shape).astype(numpy.float32)
bn(x)
x = numpy.random.rand(*self.x_shape).astype(numpy.float32)
bn(x, finetune=True)
# has non-trivial values to be stored
assert bn.avg_mean is not None
assert bn.N == 1

def create_serializer_pair(self):
target = {}
return (
chainer.serializers.DictionarySerializer(target),
chainer.serializers.NpzDeserializer(target),
)

def test_serialize(self):
ser, de = self.create_serializer_pair()

link1 = self.create_link()
self.train_link(link1)
link1.serialize(ser)

link2 = self.create_link()
link2.serialize(de)

testing.assert_allclose(link2.avg_mean, link1.avg_mean)
testing.assert_allclose(link2.avg_var, link1.avg_var)
testing.assert_allclose(link2.beta.array, link1.beta.array)
testing.assert_allclose(link2.gamma.array, link1.gamma.array)
assert link2.N == link1.N


testing.run_module(__name__, __file__)

0 comments on commit d348175

Please sign in to comment.