Merge pull request #1120 from dwf/batch_norm_subclassable
Refactor BatchNormalization to make it more extensible
dmitriy-serdyuk committed Jun 18, 2016
2 parents 76676f9 + c043b0a commit e3a1efd
Showing 2 changed files with 43 additions and 26 deletions.
58 changes: 32 additions & 26 deletions blocks/bricks/bn.py
@@ -14,8 +14,7 @@
                      BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                      BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                      add_role)
-from ..utils import (shared_floatx_zeros, shared_floatx,
-                     shared_floatx_nans, is_shared_variable)
+from ..utils import shared_floatx, shared_floatx_nans, is_shared_variable
 from .base import lazy, application
 from .sequences import Sequence, Feedforward, MLP
 from .interfaces import RNGMixin
@@ -188,12 +187,10 @@ def __enter__(self):
     def __exit__(self, *exc_info):
         self._training_mode.pop()
 
-    def _compute_training_statistics(self, input_):
-        axes = (0,) + tuple((i + 1) for i, b in
-                            enumerate(self.population_mean.broadcastable)
-                            if b)
+    def _compute_training_statistics(self, input_, axes=None):
+        if axes is None:
+            axes = self.normalization_axes
         mean = input_.mean(axis=axes, keepdims=True)
-        assert mean.broadcastable[1:] == self.population_mean.broadcastable
         add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
         if self.mean_only:
             stdev = tensor.ones_like(mean)
@@ -202,11 +199,15 @@ def _compute_training_statistics(self, input_):
                    tensor.sqr(mean))
             eps = numpy.cast[theano.config.floatX](self.epsilon)
             stdev = tensor.sqrt(var + eps)
-            assert (stdev.broadcastable[1:] ==
-                    self.population_stdev.broadcastable)
         add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
         return mean, stdev
 
+    @property
+    def normalization_axes(self):
+        return (0,) + tuple((i + 1) for i, b in
+                            enumerate(self.population_mean.broadcastable)
+                            if b)
+
     def _prepare_population_statistics(self):
         mean = _add_batch_axis(self.population_mean)
         if self.mean_only:
@@ -237,17 +238,6 @@ def _allocate(self):
         else:
             self.shift = tensor.constant(0, dtype=theano.config.floatX)
 
-        # These aren't technically parameters, in that they should not be
-        # learned using the same cost function as other model parameters.
-        self.population_mean = shared_floatx_zeros(var_dim,
-                                                   name='population_mean',
-                                                   broadcastable=broadcastable)
-        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
-
-        # Normally these would get annotated by an AnnotatingList, but they
-        # aren't in self.parameters.
-        add_annotation(self.population_mean, self)
-
         if self.learn_scale and not self.mean_only:
             # "gamma", from the Ioffe & Szegedy manuscript.
             self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
@@ -258,12 +248,28 @@ def _allocate(self):
         else:
             self.scale = tensor.constant(1., dtype=theano.config.floatX)
 
-        if not self.mean_only:
-            self.population_stdev = shared_floatx(numpy.ones(var_dim),
-                                                  name='population_stdev',
-                                                  broadcastable=broadcastable)
-            add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)
-            add_annotation(self.population_stdev, self)
+        self._allocate_population_statistics(var_dim, broadcastable)
+
+    def _allocate_population_statistics(self, var_dim, broadcastable):
+        def _allocate_buffer(name, role, value):
+            # These aren't technically parameters, in that they should not be
+            # learned using the same cost function as other model parameters.
+            population_buffer = shared_floatx(value * numpy.ones(var_dim),
+                                              broadcastable=broadcastable,
+                                              name=name)
+            add_role(population_buffer, role)
+            # Normally these would get annotated by an AnnotatingList, but they
+            # aren't in self.parameters.
+            add_annotation(population_buffer, self)
+            return population_buffer
+
+        self.population_mean = _allocate_buffer('population_mean',
+                                                BATCH_NORM_POPULATION_MEAN,
+                                                numpy.zeros(var_dim))
+
+        self.population_stdev = _allocate_buffer('population_stdev',
+                                                 BATCH_NORM_POPULATION_STDEV,
+                                                 numpy.ones(var_dim))
 
     def _initialize(self):
         # We gate with is_shared_variable rather than relying on
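The factored-out normalization_axes property and the new axes keyword on _compute_training_statistics are the hooks that make the brick easier to subclass. As a rough sketch (not part of this commit; the subclass name and axis choice below are made up), a subclass could now pool minibatch statistics over different axes simply by overriding the default:

from blocks.bricks.bn import BatchNormalization


class TimePooledBatchNormalization(BatchNormalization):
    """Hypothetical subclass: pool minibatch statistics over batch and time.

    Purely illustrative; assumes 3-D input laid out as (batch, features, time).
    """

    def _compute_training_statistics(self, input_, axes=None):
        if axes is None:
            # Reduce over the batch axis (0) and the time axis (2) instead of
            # the default self.normalization_axes.
            axes = (0, 2)
        return super(TimePooledBatchNormalization,
                     self)._compute_training_statistics(input_, axes)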
11 changes: 11 additions & 0 deletions tests/bricks/test_bn.py
@@ -9,8 +9,10 @@
 from blocks.bricks.conv import (Convolutional, ConvolutionalSequence,
                                 MaxPooling, AveragePooling)
 from blocks.initialization import Constant
+from blocks.filter import VariableFilter
 from blocks.graph import (ComputationGraph, batch_normalization,
                           get_batch_normalization_updates)
+from blocks.roles import BATCH_NORM_MINIBATCH_ESTIMATE
 
 
 def random_unif(rng, dim, low=1, high=10):
@@ -408,3 +410,12 @@ def test_batch_normalized_mlp_learn_scale_propagated_at_alloc():
     assert all(act.children[0].learn_scale for act in mlp.activations)
     mlp.allocate()
     assert not any(act.children[0].learn_scale for act in mlp.activations)
+
+
+def test_batch_normalization_broadcastable_sanity():
+    bn = BatchNormalization((5, 3, 2), broadcastable=(False, True, False))
+    with bn:
+        cg = ComputationGraph([bn.apply(tensor.tensor4('abc'))])
+    vars = VariableFilter(roles=[BATCH_NORM_MINIBATCH_ESTIMATE])(cg)
+    assert all(v.broadcastable[1:] == bn.population_mean.broadcastable
+               for v in vars)
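Likewise, moving buffer allocation into _allocate_population_statistics gives subclasses one place to add their own non-parameter buffers. A minimal sketch (again hypothetical; population_count is not part of this change) could extend it like so:

from blocks.bricks.bn import BatchNormalization
from blocks.utils import shared_floatx


class CountingBatchNormalization(BatchNormalization):
    """Hypothetical subclass that also tracks how many updates were applied."""

    def _allocate_population_statistics(self, var_dim, broadcastable):
        super(CountingBatchNormalization,
              self)._allocate_population_statistics(var_dim, broadcastable)
        # Extra scalar buffer, allocated alongside population_mean/stdev.
        self.population_count = shared_floatx(0., name='population_count')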
