Commit

Merge pull request #1121 from dwf/num_stable_var_batch_norm
Improve stability of batch norm
dwf committed Jun 20, 2016
2 parents e3a1efd + 09bfd0d commit 6dd819f
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions blocks/bricks/bn.py
@@ -195,8 +195,13 @@ def _compute_training_statistics(self, input_, axes=None):
         if self.mean_only:
             stdev = tensor.ones_like(mean)
         else:
-            var = (tensor.sqr(input_).mean(axis=axes, keepdims=True) -
-                   tensor.sqr(mean))
+            # We already have the mean; this saves Theano the trouble of
+            # optimizing the graph produced by tensor.var().
+            # This two pass version is going to be slightly more stable than
+            # E[X^2] - E[X]^2, at least when n is small. It's also never
+            # going to be negative due to numerical error.
+            var = tensor.mean(tensor.sqr(input_ - mean),
+                              axis=axes, keepdims=True)
             eps = numpy.cast[theano.config.floatX](self.epsilon)
             stdev = tensor.sqrt(var + eps)
         add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
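
The stability claim in the new code comment can be illustrated outside Theano. The following is a minimal NumPy sketch, not code from this commit: the helper names `var_one_pass` and `var_two_pass` and the sample data are hypothetical, chosen so that the mean is large relative to the spread, which is where E[X^2] - E[X]^2 tends to lose precision in float32.

```python
import numpy as np


def var_one_pass(x, axis=0):
    """E[X^2] - E[X]^2, the form removed by this commit (hypothetical helper)."""
    mean = x.mean(axis=axis, keepdims=True)
    return np.square(x).mean(axis=axis, keepdims=True) - np.square(mean)


def var_two_pass(x, axis=0):
    """E[(X - E[X])^2], the form added by this commit (hypothetical helper)."""
    mean = x.mean(axis=axis, keepdims=True)
    return np.square(x - mean).mean(axis=axis, keepdims=True)


# Illustrative float32 minibatch: 4 examples, 1 feature, large mean, small spread.
x = np.array([[10000.0], [10000.1], [10000.2], [10000.3]], dtype=np.float32)

# Reference variance computed in float64 on the same (already rounded) values.
reference = np.var(x.astype(np.float64), axis=0, keepdims=True)

one_pass = var_one_pass(x)
two_pass = var_two_pass(x)

err_one = float(np.abs(one_pass - reference).max())
err_two = float(np.abs(two_pass - reference).max())

print("reference:", reference.ravel())
print("one pass: ", one_pass.ravel(), "abs error:", err_one)
print("two pass: ", two_pass.ravel(), "abs error:", err_two)
```

Because the two-pass form averages squared values, its result cannot be negative, so `tensor.sqrt(var + eps)` never receives a negative argument from rounding alone. The one-pass form subtracts two nearly equal large numbers, so its result can fall far from the true variance or below zero.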
