Merge pull request #1012 from dwf/mean_only_batch_norm
Support for mean-only batch normalization.
vdumoulin committed Mar 1, 2016
2 parents 568e7a7 + 5f8d4a0 commit 47285eb
Showing 4 changed files with 198 additions and 90 deletions.
blocks/bricks/bn.py: 103 changes (73 additions, 30 deletions)
@@ -1,4 +1,5 @@
import collections
from functools import partial

import numpy
from picklable_itertools.extras import equizip
@@ -64,6 +65,8 @@ class BatchNormalization(RNGMixin, Feedforward):
shift_init : object, optional
Initialization object to use for the learned shift parameter
($\\beta$ in [BN]_). By default, uses constant initialization of 0.
mean_only : bool, optional
Perform "mean-only" batch normalization as described in [SK2016]_.
Notes
-----
@@ -106,11 +109,20 @@ class BatchNormalization(RNGMixin, Feedforward):
the `input_dim` should be omitted at construction, to be inferred from
the layer below.
.. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
accelerating deep network training by reducing internal covariate
shift*. ICML (2015), pp. 448-456.
.. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
normalization: a simple reparameterization to accelerate training
of deep neural networks*. arXiv 1602.07868.
"""
@lazy(allocation=['input_dim'])
def __init__(self, input_dim, broadcastable=None,
conserve_memory=True, epsilon=1e-4, scale_init=None,
shift_init=None, **kwargs):
shift_init=None, mean_only=False, **kwargs):
self.input_dim = input_dim
self.broadcastable = broadcastable
self.conserve_memory = conserve_memory
@@ -119,6 +131,7 @@ def __init__(self, input_dim, broadcastable=None,
else scale_init)
self.shift_init = (Constant(0) if shift_init is None
else shift_init)
self.mean_only = mean_only
self._training_mode = []
super(BatchNormalization, self).__init__(**kwargs)

@@ -143,10 +156,16 @@ def apply(self, input_, application_call):
# Give these quantities roles in the graph.
_add_role_and_annotate(mean, BATCH_NORM_OFFSET,
[self, application_call])
_add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
[self, application_call])
scale = _add_batch_axis(self.scale)
if self.mean_only:
scale = tensor.ones_like(self.shift)
stdev = tensor.ones_like(mean)
else:
scale = self.scale
# The annotation/role information is useless if it's a constant.
_add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
[self, application_call])
shift = _add_batch_axis(self.shift)
scale = _add_batch_axis(scale)
# Heavy lifting is done by the Theano utility function.
normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
mode=('low_mem'
@@ -166,16 +185,24 @@ def _compute_training_statistics(self, input_):
if b)
mean = input_.mean(axis=axes, keepdims=True)
assert mean.broadcastable[1:] == self.population_mean.broadcastable
stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
numpy.cast[theano.config.floatX](self.epsilon))
assert stdev.broadcastable[1:] == self.population_stdev.broadcastable
add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
if self.mean_only:
stdev = tensor.ones_like(mean)
else:
stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
numpy.cast[theano.config.floatX](self.epsilon))
assert (stdev.broadcastable[1:] ==
self.population_stdev.broadcastable)
add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
return mean, stdev
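
The minibatch statistics above reduce over the batch axis (plus any axes flagged broadcastable) with keepdims=True so they broadcast back against the input; epsilon sits inside the square root for numerical stability, and in mean-only mode the divisor degenerates to ones. An equivalent self-contained NumPy sketch (batch-axis-only reduction assumed):

import numpy

def training_statistics(x, epsilon=1e-4, mean_only=False):
    mean = x.mean(axis=0, keepdims=True)
    if mean_only:
        # Divisor is fixed at 1; no variance is computed at all.
        stdev = numpy.ones_like(mean)
    else:
        stdev = numpy.sqrt(x.var(axis=0, keepdims=True) + epsilon)
    return mean, stdev
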

def _prepare_population_statistics(self):
mean = _add_batch_axis(self.population_mean)
stdev = _add_batch_axis(self.population_stdev)
if self.mean_only:
stdev = tensor.ones_like(self.population_mean)
else:
stdev = self.population_stdev
stdev = _add_batch_axis(stdev)
return mean, stdev
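
At inference time the same transform runs with accumulated population statistics substituted for minibatch ones; in mean-only mode the population standard deviation is never consulted, hence the ones_like substitution above. Sketched with a hypothetical helper:

def inference_batch_norm(x, pop_mean, pop_stdev, gamma, beta,
                         mean_only=False):
    if mean_only:
        return x - pop_mean + beta  # no divisor, no learned scale
    return gamma * (x - pop_mean) / pop_stdev + beta
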

def _allocate(self):
@@ -190,37 +217,41 @@ def _allocate(self):
equizip(input_dim, broadcastable))
broadcastable = broadcastable

# "gamma", from the Ioffe & Szegedy manuscript.
self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
broadcastable=broadcastable)

# "beta", from the Ioffe & Szegedy manuscript.
self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
broadcastable=broadcastable)
add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
self.parameters.append(self.scale)
self.parameters.append(self.shift)

# These aren't technically parameters, in that they should not be
# learned using the same cost function as other model parameters.
self.population_mean = shared_floatx_zeros(var_dim,
name='population_mean',
broadcastable=broadcastable)
self.population_stdev = shared_floatx(numpy.ones(var_dim),
name='population_stdev',
broadcastable=broadcastable)
add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)

# Normally these would get annotated by an AnnotatingList, but they
# aren't in self.parameters.
add_annotation(self.population_mean, self)
add_annotation(self.population_stdev, self)

if not self.mean_only:
# "gamma", from the Ioffe & Szegedy manuscript.
self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
broadcastable=broadcastable)

add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
self.parameters.append(self.scale)

self.population_stdev = shared_floatx(numpy.ones(var_dim),
name='population_stdev',
broadcastable=broadcastable)
add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)
add_annotation(self.population_stdev, self)

def _initialize(self):
self.shift_init.initialize(self.shift, self.rng)
self.scale_init.initialize(self.scale, self.rng)
if not self.mean_only:
self.scale_init.initialize(self.scale, self.rng)
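
The net effect of the allocation and initialization changes is that a mean-only brick owns strictly fewer variables: the shift ("beta") is its only trainable parameter, and only the population mean is tracked between batches. A hedged sanity check, assuming Blocks' usual allocate() behavior and the variable names used in this diff:

from blocks.bricks.bn import BatchNormalization

bn = BatchNormalization(input_dim=100, mean_only=True)
bn.allocate()
# scale and population_stdev are never created in mean-only mode.
assert [p.name for p in bn.parameters] == ['batch_norm_shift']
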

# Needed for the Feedforward interface.
@property
@@ -295,6 +326,8 @@ class BatchNormalizedMLP(MLP):
----------
conserve_memory : bool, optional
See :class:`BatchNormalization`.
mean_only : bool, optional
See :class:`BatchNormalization`.
Notes
-----
@@ -312,10 +345,12 @@ class BatchNormalizedMLP(MLP):
"""
@lazy(allocation=['dims'])
def __init__(self, activations, dims, *args, **kwargs):
conserve_memory = kwargs.pop('conserve_memory', True)
self._conserve_memory = kwargs.pop('conserve_memory', True)
self._mean_only = kwargs.pop('mean_only', False)
activations = [
Sequence([
BatchNormalization(conserve_memory=conserve_memory).apply,
BatchNormalization(conserve_memory=self._conserve_memory,
mean_only=self._mean_only).apply,
act.apply
], name='batch_norm_activation_{}'.format(i))
for i, act in enumerate(activations)
@@ -326,16 +361,24 @@ def __init__(self, activations, dims, *args, **kwargs):
super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
**kwargs)
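
As the constructor shows, mean_only is popped from the keyword arguments and forwarded to every interleaved BatchNormalization brick. A usage sketch (the activation bricks and layer sizes are placeholders):

from blocks.bricks import Tanh
from blocks.bricks.bn import BatchNormalizedMLP

mlp = BatchNormalizedMLP([Tanh(), Tanh()], [784, 200, 10],
                         mean_only=True)
# Each activation is now preceded by a mean-only batch norm brick.
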

@property
def conserve_memory(self):
return self._conserve_memory
def _nested_brick_property_getter(self, property_name):
return getattr(self, '_' + property_name)

@conserve_memory.setter
def conserve_memory(self, value):
self._conserve_memory = value
def _nested_brick_property_setter(self, value, property_name):
setattr(self, '_' + property_name, value)
for act in self.activations:
assert isinstance(act.children[0], BatchNormalization)
act.children[0].conserve_memory = value
setattr(act.children[0], property_name, value)

conserve_memory = property(partial(_nested_brick_property_getter,
property_name='conserve_memory'),
partial(_nested_brick_property_setter,
property_name='conserve_memory'))

mean_only = property(partial(_nested_brick_property_getter,
property_name='mean_only'),
partial(_nested_brick_property_setter,
property_name='mean_only'))
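
This getter/setter pair generalizes the old hand-written conserve_memory property: property() is handed two functools.partial objects that close over the attribute name, and the setter pushes the new value down into each child BatchNormalization brick. A stripped-down, standalone sketch of the pattern (class and attribute names invented for illustration):

from functools import partial

class Parent(object):
    def __init__(self):
        self._verbose = False
        self.children = []

    def _getter(self, name):
        return getattr(self, '_' + name)

    def _setter(self, value, name):
        setattr(self, '_' + name, value)
        for child in self.children:
            # Propagate the new value to every nested object.
            setattr(child, name, value)

    verbose = property(partial(_getter, name='verbose'),
                       partial(_setter, name='verbose'))
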

def _push_allocation_config(self):
super(BatchNormalizedMLP, self)._push_allocation_config()
blocks/graph/bn.py: 5 changes (4 additions, 1 deletion)
@@ -266,4 +266,7 @@ def extract_pair(brick_attribute, metadata_key, app_call):

mean_pair = partial(extract_pair, 'population_mean', 'offset')
stdev_pair = partial(extract_pair, 'population_stdev', 'divisor')
return sum([[mean_pair(a), stdev_pair(a)] for a in train_app_calls], [])
return sum([[mean_pair(a), stdev_pair(a)]
if not a.application.brick.mean_only
else [mean_pair(a)]
for a in train_app_calls], [])
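
With mean-only bricks, only a (population_mean, minibatch_mean) pair comes back per application call, so code that turns these pairs into population-statistics updates needs no other changes. A hedged sketch of the usual consumption pattern, assuming Blocks' apply_batch_normalization / get_batch_normalization_updates API and an exponential moving average with a user-chosen alpha:

from blocks.graph import (ComputationGraph, apply_batch_normalization,
                          get_batch_normalization_updates)

cg = ComputationGraph([cost])  # cost: your training cost variable
train_cg = apply_batch_normalization(cg)
pairs = get_batch_normalization_updates(train_cg)
alpha = 0.05
extra_updates = [(pop, (1 - alpha) * pop + alpha * batch)
                 for pop, batch in pairs]
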
