Merge pull request #4266 from anaruse/add_axis_option_to_batch_normalization

Add axis option to batch normalization
toslunar committed Apr 2, 2018
2 parents 90f50b0 + 3b8c207 commit 8511bda
Showing 3 changed files with 287 additions and 46 deletions.
167 changes: 124 additions & 43 deletions chainer/functions/normalization/batch_normalization.py
@@ -1,3 +1,4 @@
import collections
import warnings

import numpy
@@ -15,12 +16,24 @@
libcudnn = cuda.cuda.cudnn


def _compute_axis(x_ndim, param_ndim=1, axis=None):
if axis is None:
axis = (0,) + tuple(range(param_ndim + 1, x_ndim))
return axis


def _compute_key_axis(x_ndim, param_ndim=1, axis=None):
axis = _compute_axis(x_ndim, param_ndim, axis)
key_axis = tuple([i for i in range(x_ndim) if i not in axis])
return key_axis

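A standalone sketch of what these two helpers compute, assuming the common case of an NCHW input (x.ndim == 4) and a 1-d gamma (param_ndim == 1); the calls and expected outputs below are illustrative only:

def _compute_axis(x_ndim, param_ndim=1, axis=None):
    # Default: reduce over the batch axis and every axis after the parameter axes.
    if axis is None:
        axis = (0,) + tuple(range(param_ndim + 1, x_ndim))
    return axis

def _compute_key_axis(x_ndim, param_ndim=1, axis=None):
    # The "key" axes are the ones kept, i.e. the axes gamma/beta are aligned with.
    axis = _compute_axis(x_ndim, param_ndim, axis)
    return tuple([i for i in range(x_ndim) if i not in axis])

print(_compute_axis(4, 1))              # (0, 2, 3): reduce over batch and spatial axes
print(_compute_key_axis(4, 1))          # (1,): the channel axis is kept
print(_compute_axis(4, 1, (0, 1)))      # (0, 1): an explicit axis passes through
print(_compute_key_axis(4, 1, (0, 1)))  # (2, 3)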

class BatchNormalization(function_node.FunctionNode):

mean = None
inv_std = None

def __init__(self, eps=2e-5, mean=None, var=None, decay=0.9):
def __init__(self, eps=2e-5, mean=None, var=None, decay=0.9, axis=None):
self.running_mean = mean
self.running_var = var

@@ -35,20 +48,40 @@ def __init__(self, eps=2e-5, mean=None, var=None, decay=0.9):
'cuDNN does not allow an eps value '
'less than {}.'.format(libcudnn.CUDNN_BN_MIN_EPSILON))
self.decay = decay
if isinstance(axis, collections.Sequence):
for i in range(1, len(axis)):
if axis[i - 1] >= axis[i]:
msg = 'numbers in axis must be sorted in ascending order'
raise RuntimeError(msg)
elif isinstance(axis, int):
axis = axis,
elif axis is not None:
raise RuntimeError('axis must be int, tuple of int or None')
self.axis = axis

def check_type_forward(self, in_types):
type_check.expect(in_types.size() == 3)
x_type, gamma_type, beta_type = in_types
M = type_check.eval(gamma_type.ndim)
type_check.expect(
x_type.dtype.kind == 'f',
x_type.ndim >= gamma_type.ndim + 1,
x_type.shape[1:1 + M] == gamma_type.shape,
# TODO(beam2d): Check shape
gamma_type.dtype == x_type.dtype,
beta_type.dtype == x_type.dtype,
gamma_type.shape == beta_type.shape,
)
_x_ndim = type_check.eval(x_type.ndim)
_gamma_ndim = type_check.eval(gamma_type.ndim)
_axis = _compute_axis(_x_ndim, _gamma_ndim, self.axis)
type_check.expect(
x_type.ndim >= len(_axis),
)
_key_axis = _compute_key_axis(_x_ndim, _gamma_ndim, _axis)
type_check.expect(
gamma_type.ndim == len(_key_axis),
)
for i in range(len(_key_axis)):
type_check.expect(
x_type.shape[_key_axis[i]] == gamma_type.shape[i],
)

def forward(self, inputs):
self.retain_inputs((0, 1))
@@ -70,17 +103,21 @@ def forward(self, inputs):
if self.running_mean is None:
self.running_mean = xp.zeros_like(gamma)
self.running_var = xp.zeros_like(gamma)
self.mode = _BNMode(x, gamma)

self.axis = _compute_axis(x.ndim, gamma.ndim, self.axis)
self.key_axis = _compute_key_axis(x.ndim, gamma.ndim, self.axis)

# TODO(niboshi): Refactor calculation of expander and axis into a
# function and call it just before they are used.

# expander inserts singleton dimensions to gamma and beta so that they
# can be broadcasted with x.
head_ndim = gamma.ndim + 1
expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
expander = [None for _ in range(x.ndim)]
for i in self.key_axis:
expander[i] = slice(None)
self.expander = expander
self.axis = (0,) + tuple(range(head_ndim, x.ndim))

self.mode = _BNMode(x, gamma, self.key_axis)
self.use_cudnn = self.mode.can_use_cudnn(xp)
self.use_ideep = self.mode.can_use_ideep()

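To make the broadcasting concrete, here is a minimal sketch of the expander construction above, assuming NumPy and the default NCHW layout where key_axis resolves to (1,):

import numpy

x = numpy.zeros((8, 3, 32, 32), dtype=numpy.float32)   # NCHW input
gamma = numpy.ones(3, dtype=numpy.float32)              # one scale per channel
key_axis = (1,)                                          # channel axis is kept

# Same construction as above: None inserts a singleton dimension,
# slice(None) keeps the key axis as-is.
expander = [None for _ in range(x.ndim)]
for i in key_axis:
    expander[i] = slice(None)

print(gamma[tuple(expander)].shape)   # (1, 3, 1, 1): broadcastable against x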
@@ -134,9 +171,10 @@ def forward(self, inputs):
beta = cuda.cupy.ascontiguousarray(beta)
dtype = x.dtype
handle = cudnn.get_handle()
x_desc = cudnn.create_tensor_descriptor(_as4darray(x))
derivedBnDesc = cudnn.create_uninitialized_tensor_descriptor()
x_desc = cudnn.create_tensor_descriptor(
_as4darray(x, self.key_axis))
cudnn_mode = self.mode.get_cudnn_mode()
derivedBnDesc = cudnn.create_uninitialized_tensor_descriptor()
libcudnn.deriveBNTensorDescriptor(derivedBnDesc.value,
x_desc.value, cudnn_mode)
dtype_param = _get_dtype_of_tensor_descriptor(derivedBnDesc)
@@ -220,14 +258,14 @@ def backward(self, indexes, grad_outputs):

f = BatchNormalizationGrad(
self.eps, self.use_cudnn, self.mode, self.expander, self.axis,
self.mean, var, self.inv_std)
self.mean, var, self.inv_std, self.key_axis)
return f(x, gamma, gy)


class BatchNormalizationGrad(function.Function):

def __init__(
self, eps, use_cudnn, mode, expander, axis, mean, var, inv_std):
def __init__(self, eps, use_cudnn, mode, expander, axis, mean, var,
inv_std, key_axis):
self.eps = eps
self.use_cudnn = use_cudnn
self.use_ideep = mode.can_use_ideep()
@@ -237,6 +275,7 @@ def __init__(
self.mean = mean
self.var = var # Only used in iDeep implementation
self.inv_std = inv_std
self.key_axis = key_axis

def forward(self, inputs):
self.retain_inputs((0, 1, 2))
@@ -272,13 +311,14 @@ def forward(self, inputs):

elif self.use_cudnn:
# TODO(niboshi): Refactor cuDNN part into a separate method
cudnn_mode = self.mode.get_cudnn_mode()
x = cuda.cupy.ascontiguousarray(x)
gamma = cuda.cupy.ascontiguousarray(gamma)
gy = cuda.cupy.ascontiguousarray(gy)
dtype = x.dtype
handle = cudnn.get_handle()
x_desc = cudnn.create_tensor_descriptor(_as4darray(x))
x_desc = cudnn.create_tensor_descriptor(
_as4darray(x, self.key_axis))
cudnn_mode = self.mode.get_cudnn_mode()
derivedBnDesc = cudnn.create_uninitialized_tensor_descriptor()
libcudnn.deriveBNTensorDescriptor(derivedBnDesc.value,
x_desc.value, cudnn_mode)
@@ -368,17 +408,24 @@ class FixedBatchNormalization(function_node.FunctionNode):
inv_std = None
inv_var = None

def __init__(self, eps=2e-5):
def __init__(self, eps=2e-5, axis=None):
self.eps = eps
if isinstance(axis, collections.Sequence):
for i in range(1, len(axis)):
if axis[i - 1] >= axis[i]:
msg = 'numbers in axis must be sorted in ascending order'
raise RuntimeError(msg)
elif isinstance(axis, int):
axis = axis,
elif axis is not None:
raise RuntimeError('axis must be int, tuple of int or None')
self.axis = axis

def check_type_forward(self, in_types):
type_check.expect(in_types.size() == 5)
x_type, gamma_type, beta_type, mean_type, var_type = in_types
M = type_check.eval(gamma_type.ndim)
type_check.expect(
x_type.dtype.kind == 'f',
x_type.ndim >= gamma_type.ndim + 1,
x_type.shape[1:1 + M] == gamma_type.shape,
# TODO(beam2d): Check shape
gamma_type.dtype == x_type.dtype,
beta_type.dtype == x_type.dtype,
@@ -388,20 +435,37 @@ def check_type_forward(self, in_types):
var_type.dtype == x_type.dtype,
var_type.shape == gamma_type.shape,
)
_x_ndim = type_check.eval(x_type.ndim)
_gamma_ndim = type_check.eval(gamma_type.ndim)
_axis = _compute_axis(_x_ndim, _gamma_ndim, self.axis)
type_check.expect(
x_type.ndim >= len(_axis),
)
_key_axis = _compute_key_axis(_x_ndim, _gamma_ndim, _axis)
type_check.expect(
gamma_type.ndim == len(_key_axis),
)
for i in range(len(_key_axis)):
type_check.expect(
x_type.shape[_key_axis[i]] == gamma_type.shape[i],
)

def forward(self, inputs):
self.retain_inputs((0, 1, 3, 4))
x, gamma, beta, mean, var = inputs
xp = cuda.get_array_module(x)

self.axis = _compute_axis(x.ndim, gamma.ndim, self.axis)
self.key_axis = _compute_key_axis(x.ndim, gamma.ndim, self.axis)

# expander inserts singleton dimensions to gamma and beta so that they
# can be broadcasted with x.
head_ndim = gamma.ndim + 1
expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
expander = [None for _ in range(x.ndim)]
for i, j in enumerate(self.key_axis):
expander[j] = slice(gamma.shape[i])
self.expander = expander
self.axis = (0,) + tuple(range(head_ndim, x.ndim))

mode = _BNMode(x, gamma)
mode = _BNMode(x, gamma, self.key_axis)
if mode.can_use_ideep():
# TODO(niboshi): Refactor iDeep part into a separate method
expand_dim = False
@@ -436,9 +500,10 @@ def forward(self, inputs):
beta = cuda.cupy.ascontiguousarray(beta)
dtype = x.dtype
handle = cudnn.get_handle()
x_desc = cudnn.create_tensor_descriptor(_as4darray(x))
derivedBnDesc = cudnn.create_uninitialized_tensor_descriptor()
x_desc = cudnn.create_tensor_descriptor(
_as4darray(x, self.key_axis))
cudnn_mode = mode.get_cudnn_mode()
derivedBnDesc = cudnn.create_uninitialized_tensor_descriptor()
libcudnn.deriveBNTensorDescriptor(derivedBnDesc.value,
x_desc.value, cudnn_mode)
dtype_param = _get_dtype_of_tensor_descriptor(derivedBnDesc)
@@ -550,26 +615,24 @@ def backward(self, inputs, grad_outputs):

class _BNMode(object):

def __init__(self, x, gamma):
def __init__(self, x, gamma, key_axis):
is_gamma_1d = gamma.ndim == 1
# cuDNN only supports these tensor dimensions because they are
# the most commonly used. If there is a need to support other
# dimensions with cuDNN, we could consider reshaping the input
# into a 2-dim array with channels as second dim and m=<product
# of all dimensions except the 2nd dimension> as the first
# dimension.
self.is_for_conv2d = x.ndim == 4 and is_gamma_1d
self.is_for_linear = x.ndim == 2 and is_gamma_1d
self.is_for_conv2d = is_gamma_1d and x.ndim == 4 and key_axis[0] == 1
self.is_for_linear = is_gamma_1d and key_axis[0] == x.ndim - 1
self.cudnn_dim_ok = self.is_for_conv2d or self.is_for_linear
# self.cudnn_dtype_ok = x.dtype != numpy.float16
self.cudnn_dtype_ok = self.is_for_conv2d or (x.dtype != numpy.float16)
self.ideep_ok = is_gamma_1d and intel64.inputs_all_ready((x,))

def get_cudnn_mode(self):
assert self.cudnn_dim_ok
if self.is_for_conv2d:
return libcudnn.CUDNN_BATCHNORM_SPATIAL
return libcudnn.CUDNN_BATCHNORM_PER_ACTIVATION
return libcudnn.CUDNN_BATCHNORM_SPATIAL

def can_use_ideep(self):
return self.ideep_ok and intel64.should_use_ideep('>=auto')
@@ -583,13 +646,14 @@ def can_use_cudnn(self, xp):
self.cudnn_dtype_ok)

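The dimension checks in _BNMode reduce to a small decision rule. The sketch below mirrors those checks for illustration only; the function name is made up, and the dtype and backend availability checks are omitted:

def cudnn_bn_mode(x_ndim, gamma_ndim, key_axis):
    # Spatial mode covers the NCHW case (4-d input, channel axis kept);
    # per-activation mode covers inputs whose kept axis is the last one;
    # anything else falls back to the generic (non-cuDNN) implementation.
    is_gamma_1d = gamma_ndim == 1
    if is_gamma_1d and x_ndim == 4 and key_axis[0] == 1:
        return 'CUDNN_BATCHNORM_SPATIAL'
    if is_gamma_1d and key_axis[0] == x_ndim - 1:
        return 'CUDNN_BATCHNORM_PER_ACTIVATION'
    return None

print(cudnn_bn_mode(4, 1, (1,)))    # CUDNN_BATCHNORM_SPATIAL
print(cudnn_bn_mode(2, 1, (1,)))    # CUDNN_BATCHNORM_PER_ACTIVATION
print(cudnn_bn_mode(4, 1, (3,)))    # CUDNN_BATCHNORM_PER_ACTIVATION
print(cudnn_bn_mode(4, 2, (2, 3)))  # None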

def _as4darray(arr):
if arr.ndim == 0:
return arr.reshape(1, 1, 1, 1)
elif arr.ndim == 4:
def _as4darray(arr, key_axis):
if arr.ndim == 4 and key_axis[0] == 1:
return arr
elif key_axis[0] == arr.ndim - 1:
return arr.reshape(numpy.prod(arr.shape[0:-1]), -1, 1, 1)
else:
return arr.reshape(arr.shape[0], -1, 1, 1)
msg = 'Unexpected combination of array shape and key_axis'
raise RuntimeError(msg)

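A shape-level sketch of the reshape performed by _as4darray, assuming NumPy arrays with arbitrarily chosen shapes:

import numpy

def _as4darray(arr, key_axis):
    # cuDNN tensor descriptors are 4-d: an NCHW array with the channel
    # axis kept is used as-is; if the kept axis is the last one, all
    # leading axes are folded into a single "batch" dimension.
    if arr.ndim == 4 and key_axis[0] == 1:
        return arr
    elif key_axis[0] == arr.ndim - 1:
        return arr.reshape(numpy.prod(arr.shape[0:-1]), -1, 1, 1)
    raise RuntimeError('Unexpected combination of array shape and key_axis')

print(_as4darray(numpy.zeros((8, 3, 32, 32)), (1,)).shape)  # (8, 3, 32, 32): untouched
print(_as4darray(numpy.zeros((8, 16, 10)), (2,)).shape)     # (128, 10, 1, 1): last axis becomes channels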

def _get_mode(x, gamma):
@@ -701,6 +765,14 @@ def batch_normalization(x, gamma, beta, **kwargs):
be ``None``.
decay (float): Decay rate of moving average. It is used during
training.
axis (int or tuple of int): Axis over which normalization is
performed. When axis is ``None``, it is determined from the input
dimensions. For example, if ``x.ndim`` is 4, axis becomes (0, 2, 3)
and normalization is performed over the 0th, 2nd and 3rd axes of
the input. If ``x.ndim`` is 2, axis becomes (0,) and normalization
is performed over the 0th axis of the input. When a tuple of int is
given to this option, its numbers must be sorted in ascending
order. For example, (0, 2) is OK, but (2, 0) is not.
See: `Batch Normalization: Accelerating Deep Network Training by Reducing\
Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_
@@ -712,15 +784,15 @@ def batch_normalization(x, gamma, beta, **kwargs):
argument.check_unexpected_kwargs(
kwargs, train='train argument is not supported anymore. '
'Use chainer.using_config')
eps, running_mean, running_var, decay = argument.parse_kwargs(
eps, running_mean, running_var, decay, axis = argument.parse_kwargs(
kwargs, ('eps', 2e-5), ('running_mean', None),
('running_var', None), ('decay', 0.9))
('running_var', None), ('decay', 0.9), ('axis', None))

return BatchNormalization(eps, running_mean, running_var, decay).apply(
(x, gamma, beta))[0]
return BatchNormalization(eps, running_mean, running_var, decay,
axis).apply((x, gamma, beta))[0]

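A usage sketch of the new keyword through the function API, assuming a Chainer build that includes this change; the array shapes are arbitrary:

import numpy
import chainer.functions as F

x = numpy.random.randn(8, 3, 32, 32).astype(numpy.float32)
gamma = numpy.ones(3, dtype=numpy.float32)
beta = numpy.zeros(3, dtype=numpy.float32)

# axis=None resolves to (0, 2, 3) for a 4-d input: per-channel statistics
# over the batch and spatial axes, i.e. the previous default behaviour.
y = F.batch_normalization(x, gamma, beta)

# The same normalization, written out explicitly with the new option.
y = F.batch_normalization(x, gamma, beta, axis=(0, 2, 3))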

def fixed_batch_normalization(x, gamma, beta, mean, var, eps=2e-5):
def fixed_batch_normalization(x, gamma, beta, mean, var, eps=2e-5, axis=None):
"""Batch normalization function with fixed statistics.
This is a variant of batch normalization, where the mean and variance
Expand All @@ -735,10 +807,19 @@ def fixed_batch_normalization(x, gamma, beta, mean, var, eps=2e-5):
mean (Variable): Shifting parameter of input.
var (Variable): Square of scaling parameter of input.
eps (float): Epsilon value for numerical stability.
axis (int or tuple of int): Axis over which normalization is
performed. When axis is ``None``, it is determined from the input
dimensions. For example, if ``x.ndim`` is 4, axis becomes (0, 2, 3)
and normalization is performed over the 0th, 2nd and 3rd axes of
the input. If ``x.ndim`` is 2, axis becomes (0,) and normalization
is performed over the 0th axis of the input. When a tuple of int is
given to this option, its numbers must be sorted in ascending
order. For example, (0, 2) is OK, but (2, 0) is not.
.. seealso::
:func:`functions.batch_normalization`,
:class:`links.BatchNormalization`
"""
return FixedBatchNormalization(eps).apply((x, gamma, beta, mean, var))[0]
return FixedBatchNormalization(eps, axis).apply((x, gamma, beta, mean,
var))[0]
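The fixed-statistics variant accepts the same option; a sketch assuming precomputed per-channel mean and variance:

import numpy
import chainer.functions as F

x = numpy.random.randn(8, 3, 32, 32).astype(numpy.float32)
gamma = numpy.ones(3, dtype=numpy.float32)
beta = numpy.zeros(3, dtype=numpy.float32)
mean = numpy.zeros(3, dtype=numpy.float32)
var = numpy.ones(3, dtype=numpy.float32)

# Normalize with the given statistics; axis=(0, 2, 3) matches the default
# resolution for a 4-d input and could equally be left as None.
y = F.fixed_batch_normalization(x, gamma, beta, mean, var, axis=(0, 2, 3))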
15 changes: 12 additions & 3 deletions chainer/links/normalization/batch_normalization.py
@@ -59,12 +59,20 @@ class BatchNormalization(link.Link):
decay (float): Decay rate of moving average. It is used on training.
~BatchNormalization.eps (float): Epsilon value for numerical stability.
This value is added to the batch variances.
axis (int or tuple of int): Axis over which normalization is
performed. When axis is ``None``, it is determined from the input
dimensions. For example, if ``x.ndim`` is 4, axis becomes (0, 2, 3)
and normalization is performed over the 0th, 2nd and 3rd axes of
the input. If ``x.ndim`` is 2, axis becomes (0,) and normalization
is performed over the 0th axis of the input. When a tuple of int is
given to this option, its numbers must be sorted in ascending
order. For example, (0, 2) is OK, but (2, 0) is not.
"""

def __init__(self, size, decay=0.9, eps=2e-5, dtype=numpy.float32,
use_gamma=True, use_beta=True,
initial_gamma=None, initial_beta=None):
initial_gamma=None, initial_beta=None, axis=None):
super(BatchNormalization, self).__init__()
self.avg_mean = numpy.zeros(size, dtype=dtype)
self.register_persistent('avg_mean')
@@ -74,6 +82,7 @@ def __init__(self, size, decay=0.9, eps=2e-5, dtype=numpy.float32,
self.register_persistent('N')
self.decay = decay
self.eps = eps
self.axis = axis

with self.init_scope():
if use_gamma:
@@ -141,13 +150,13 @@ def __call__(self, x, **kwargs):

ret = functions.batch_normalization(
x, gamma, beta, eps=self.eps, running_mean=self.avg_mean,
running_var=self.avg_var, decay=decay)
running_var=self.avg_var, decay=decay, axis=self.axis)
else:
# Use running average statistics or fine-tuned statistics.
mean = variable.Variable(self.avg_mean)
var = variable.Variable(self.avg_var)
ret = functions.fixed_batch_normalization(
x, gamma, beta, mean, var, self.eps)
x, gamma, beta, mean, var, self.eps, axis=self.axis)
return ret

def start_finetuning(self):
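At the link level the new argument is stored and forwarded to the function calls above; a usage sketch assuming the default NCHW layout:

import numpy
import chainer
import chainer.links as L

bn = L.BatchNormalization(3, axis=(0, 2, 3))   # one gamma/beta per channel

x = numpy.random.randn(8, 3, 32, 32).astype(numpy.float32)
with chainer.using_config('train', True):
    y = bn(x)          # updates bn.avg_mean / bn.avg_var in place
print(y.shape)          # (8, 3, 32, 32)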
