Skip to content

Commit

Permalink
Merge pull request chainer#4505 from toslunar/fix-bn-eps-twice
Browse files Browse the repository at this point in the history
Remove eps from batch normalization statistics
  • Loading branch information
okuta committed Mar 23, 2018
1 parent 3230942 commit e6477b6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
3 changes: 1 addition & 2 deletions chainer/functions/normalization/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ def forward(self, inputs):
beta = beta[expander]
self.mean = x.mean(axis=self.axis)
var = x.var(axis=self.axis)
var += self.eps
self.inv_std = var ** (-0.5)
self.inv_std = (var + self.eps) ** (-0.5)
y = _apply_bn_fwd(xp, x, self.mean[expander],
self.inv_std[expander], gamma, beta)
# Update running statistics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def test_backward_gpu_without_cudnn(self):


@testing.parameterize(
{'nx': 10, 'ny': 10},
{'nx': 10, 'ny': 10, 'eps': 2e-5},
{'nx': 10, 'ny': 10, 'eps': 1e-1},
# TODO(Kenta Oono)
# Pass the case below (this test does not pass when nx != ny).
# {'nx': 10, 'ny': 15}
Expand All @@ -131,7 +132,7 @@ class TestPopulationStatistics(unittest.TestCase):
def setUp(self):
self.decay = 0.9
self.size = 3
self.link = links.BatchNormalization(self.size, self.decay)
self.link = links.BatchNormalization(self.size, self.decay, self.eps)
self.x = numpy.random.uniform(
-1, 1, (self.nx, self.size)).astype(numpy.float32)
self.y = numpy.random.uniform(
Expand All @@ -151,12 +152,10 @@ def check_statistics(self, x, y):
testing.assert_allclose(mean, self.link.avg_mean)
testing.assert_allclose(unbiased_var, self.link.avg_var)

@condition.retry(3)
def test_statistics_cpu(self):
self.check_statistics(self.x, self.y)

@attr.gpu
@condition.retry(3)
def test_statistics_gpu(self):
self.link.to_gpu()
self.check_statistics(cuda.to_gpu(self.x), cuda.to_gpu(self.y))
Expand Down Expand Up @@ -185,12 +184,10 @@ def check_statistics2(self, x, y):
testing.assert_allclose(mean, self.link.avg_mean)
testing.assert_allclose(unbiased_var, self.link.avg_var)

@condition.retry(3)
def test_statistics2_cpu(self):
self.check_statistics2(self.x, self.y)

@attr.gpu
@condition.retry(3)
def test_statistics2_gpu(self):
self.link.to_gpu()
self.check_statistics2(
Expand Down

0 comments on commit e6477b6

Please sign in to comment.