Skip to content

Commit

Permalink
Merge pull request #1062 from dwf/possibly_speedup_batch_norm
Browse files Browse the repository at this point in the history
Speed up BatchNormalization training graphs
  • Loading branch information
dwf committed Apr 20, 2016
2 parents 17373d4 + 1ef9ddb commit 0c9a029
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions blocks/bricks/bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,10 @@ def _compute_training_statistics(self, input_):
if self.mean_only:
stdev = tensor.ones_like(mean)
else:
stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
numpy.cast[theano.config.floatX](self.epsilon))
var = (tensor.sqr(input_).mean(axis=axes, keepdims=True) -
tensor.sqr(mean))
eps = numpy.cast[theano.config.floatX](self.epsilon)
stdev = tensor.sqrt(var + eps)
assert (stdev.broadcastable[1:] ==
self.population_stdev.broadcastable)
add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
Expand Down

0 comments on commit 0c9a029

Please sign in to comment.