Commit
Showing 6 changed files with 728 additions and 0 deletions.
225 changes: 225 additions & 0 deletions
chainer/functions/normalization/batch_renormalization.py
import numpy

from chainer import configuration
from chainer import cuda
from chainer import function
from chainer.utils import type_check


def _as4darray(arr):
    if arr.ndim == 0:
        return arr.reshape(1, 1, 1, 1)
    elif arr.ndim == 4:
        return arr
    else:
        return arr.reshape(arr.shape[0], -1, 1, 1)


def _xhat(x, mean, std, expander):
    x_mu = x - mean[expander]
    x_mu /= std[expander]
    return x_mu

class BatchRenormalizationFunction(function.Function):

    def __init__(self, eps=2e-5, mean=None, var=None, decay=0.9,
                 rmax=1, dmax=0, freeze_running_statistics=False):
        self.running_mean = mean
        self.running_var = var
        self.rmax = rmax
        self.dmax = dmax
        self.r = None
        self.d = None
        self.freeze_running_statistics = freeze_running_statistics

        self.eps = eps
        self.mean_cache = None
        self.decay = decay

    def check_type_forward(self, in_types):
        n_in = type_check.eval(in_types.size())
        if n_in != 3 and n_in != 5:
            raise type_check.InvalidType(
                '%s or %s' % (in_types.size() == 3, in_types.size() == 5),
                '%s == %s' % (in_types.size(), n_in))
        x_type, gamma_type, beta_type = in_types[:3]
        M = type_check.eval(gamma_type.ndim)
        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim >= gamma_type.ndim + 1,
            x_type.shape[1:1 + M] == gamma_type.shape,
            # TODO(tkerola): Check shape
            gamma_type.dtype == x_type.dtype,
            beta_type.dtype == x_type.dtype,
            gamma_type.shape == beta_type.shape,
        )
        if len(in_types) == 5:
            mean_type, var_type = in_types[3:]
            type_check.expect(
                mean_type.dtype == x_type.dtype,
                mean_type.shape == gamma_type.shape,
                var_type.dtype == x_type.dtype,
                var_type.shape == gamma_type.shape,
            )

    def forward(self, inputs):
        xp = cuda.get_array_module(*inputs)
        x, gamma, beta = inputs[:3]

        # Note: If length of inputs is not 5, we must be in train mode.
        if len(inputs) != 5:
            assert configuration.config.train

        if configuration.config.train:
            if self.running_mean is None:
                self.running_mean = xp.zeros_like(gamma)
                self.running_var = xp.zeros_like(gamma)
            else:
                self.running_mean = xp.array(self.running_mean)
                self.running_var = xp.array(self.running_var)
        elif len(inputs) == 5:
            fixed_mean = inputs[3]
            fixed_var = inputs[4]

        head_ndim = gamma.ndim + 1
        expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)

        # NOTE(tommi): cuDNN is not used since it does not support
        # batch renormalization
        if configuration.config.train:
            axis = (0,) + tuple(range(head_ndim, x.ndim))
            mean = x.mean(axis=axis)
            var = x.var(axis=axis) + self.eps
        else:
            mean = fixed_mean
            var = fixed_var + self.eps
        self.std = xp.sqrt(var, dtype=var.dtype)

        if not self.freeze_running_statistics or self.r is None:
            if configuration.config.train:
                running_sigma = xp.sqrt(self.running_var + self.eps,
                                        dtype=self.running_mean.dtype)
                self.r = xp.clip(self.std / running_sigma,
                                 1.0 / self.rmax, self.rmax)
                self.d = xp.clip((mean - self.running_mean) / running_sigma,
                                 -self.dmax, self.dmax)

                # Update running statistics:
                m = x.size // gamma[expander].size
                self.running_mean *= self.decay
                adjust = m / max(m - 1., 1.)  # unbiased estimation
                temp_ar = xp.array(mean)
                temp_ar *= (1 - self.decay)
                self.running_mean += temp_ar
                del temp_ar
                self.running_var *= self.decay
                temp_ar = xp.array(var)
                temp_ar *= (1 - self.decay) * adjust
                self.running_var += temp_ar
                del temp_ar
            else:
                self.r = xp.ones_like(gamma)
                self.d = xp.zeros_like(gamma)

        if self.freeze_running_statistics:
            # Need to explicitly cast during gradient check, as r and d are
            # not updated during finite differences
            self.r = self.r.astype(gamma.dtype)
            self.d = self.d.astype(gamma.dtype)

        gamma = gamma[expander]
        beta = beta[expander]

        if xp is numpy:
            self.x_hat = _xhat(x, mean, self.std, expander)
            self.x_hat_renorm = self.x_hat * self.r[expander] + \
                self.d[expander]
            y = gamma * self.x_hat_renorm
            y += beta
        else:
            self.x_hat, self.x_hat_renorm, y = cuda.elementwise(
                'T x, T mean, T std, T gamma, T beta, T r, T d',
                'T x_hat, T x_hat_renorm, T y',
                '''
                x_hat = (x - mean) / std;
                x_hat_renorm = x_hat * r + d;
                y = gamma * x_hat_renorm + beta;
                ''',
                'bn_fwd')(x, mean[expander], self.std[expander], gamma,
                          beta, self.r[expander], self.d[expander])

        return y,

    def backward(self, inputs, grad_outputs):
        x, gamma = inputs[:2]
        gy = grad_outputs[0]
        head_ndim = gamma.ndim + 1
        expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
        m = gamma.dtype.type(x.size // gamma.size)
        axis = (0,) + tuple(range(head_ndim, x.ndim))
        xp = cuda.get_array_module(x)
        if len(inputs) == 5:
            # This case is unlikely to be used in practice and so does not
            # need to be optimized for performance.
            mean = inputs[3]
            var = inputs[4]
            std = xp.sqrt(var, dtype=var.dtype)
            gs = gamma / std
            gbeta = gy.sum(axis=axis)
            x_hat = _xhat(x, mean, std, expander)
            ggamma = (gy * x_hat).sum(axis=axis)
            gmean = -gs * gbeta
            gvar = -0.5 * gamma / var * ggamma
            gx = gs[expander] * gy
            return gx, ggamma, gbeta, gmean, gvar

        # Note: If length of inputs is not 5, we must be in train mode.
        assert configuration.config.train
        # NOTE(tommi): cuDNN is not used since it does not support
        # batch renormalization
        gbeta = gy.sum(axis=axis)
        ggamma = (gy * self.x_hat_renorm).sum(axis=axis)
        gsigma_batch = (gy * self.x_hat).sum(axis=axis)
        if xp is numpy:
            scale = (self.r * gamma / self.std)[expander]
            gx = scale * (gy - (self.x_hat * gsigma_batch[expander] +
                                gbeta[expander]) / m)
        else:
            inv_m = numpy.float32(1) / m
            gx = cuda.elementwise(
                'T gy, T x_hat, T gamma, T std, T gsigma_batch, T gbeta, \
                T inv_m, T r',
                'T gx',
                'gx = (r * gamma / std) * (gy - (x_hat * gsigma_batch + gbeta) * \
                inv_m)',
                'bn_bwd')(gy, self.x_hat, gamma[expander],
                          self.std[expander], gsigma_batch[expander],
                          gbeta[expander], inv_m, self.r[expander])
        return gx, ggamma, gbeta

def batch_renormalization(x, gamma, beta, rmax, dmax, eps=2e-5,
                          running_mean=None, running_var=None, decay=0.9):
    """Batch renormalization function.

    This is an extension of batch normalization, which ensures that the
    training and inference models generate the same outputs that depend on
    individual examples rather than the entire minibatch.

    See: `Batch Renormalization: Towards Reducing Minibatch Dependence in \
    Batch-Normalized Models <https://arxiv.org/abs/1702.03275>`_

    .. seealso:: :class:`links.BatchRenormalization`
    .. seealso:: :func:`functions.batch_normalization`
    """
    return BatchRenormalizationFunction(eps, running_mean, running_var,
                                        decay, rmax, dmax)(x, gamma, beta)


def fixed_batch_renormalization(x, gamma, beta, mean, var, eps=2e-5):
    with configuration.using_config('train', False):
        return BatchRenormalizationFunction(eps, None, None, 0.0)(
            x, gamma, beta, mean, var)
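The core of the forward pass above is the clipped r/d correction from the Batch Renormalization paper. The short NumPy sketch below is not part of the commit; it restates that step on a plain (batch, channels) array, and the shapes and the rmax/dmax values are made up for illustration.

import numpy as np

# Illustrative values only (not from the commit).
rng = np.random.RandomState(0)
x = rng.randn(32, 8).astype(np.float32)        # (batch, channels) activations
running_mean = np.zeros(8, dtype=np.float32)   # moving average of the mean
running_var = np.ones(8, dtype=np.float32)     # moving average of the variance
eps, rmax, dmax = 2e-5, 3.0, 5.0

# Minibatch statistics, as computed in BatchRenormalizationFunction.forward.
mean = x.mean(axis=0)
std = np.sqrt(x.var(axis=0) + eps)

# The renormalization correction: r and d are clipped ratios of minibatch
# statistics to running statistics and are treated as constants in backprop.
running_sigma = np.sqrt(running_var + eps)
r = np.clip(std / running_sigma, 1.0 / rmax, rmax)
d = np.clip((mean - running_mean) / running_sigma, -dmax, dmax)

x_hat = (x - mean) / std        # ordinary batch-normalized value
x_hat_renorm = x_hat * r + d    # value that is then scaled by gamma and shifted by beta
# y = gamma * x_hat_renorm + beta, with gamma and beta broadcast over axis 0

Because r and d are computed from the running statistics and kept out of the backward pass, training outputs drift smoothly toward the statistics that will be used at inference time, which is the point of the technique.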
84 changes: 84 additions & 0 deletions
import numpy

from chainer import configuration
from chainer import cuda
from chainer.functions.normalization import batch_renormalization
from chainer.links.normalization.batch_normalization import BatchNormalization
from chainer import variable


class BatchRenormalization(BatchNormalization):

    """Batch renormalization layer on outputs of linear or convolution functions.

    This link wraps the :func:`~chainer.functions.batch_renormalization` and
    :func:`~chainer.functions.fixed_batch_renormalization` functions.

    This is an extension of batch normalization, which ensures that the
    training and inference models generate the same outputs that depend on
    individual examples rather than the entire minibatch.

    See: `Batch Renormalization: Towards Reducing Minibatch Dependence in \
    Batch-Normalized Models <https://arxiv.org/abs/1702.03275>`_

    .. seealso::
       :func:`~chainer.functions.batch_renormalization`,
       :func:`~chainer.functions.fixed_batch_renormalization`,
       :func:`~chainer.functions.batch_normalization`
    """

    def __init__(self, size, rmax=1, dmax=0, decay=0.9, eps=2e-5,
                 dtype=numpy.float32, use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None,
                 freeze_running_statistics=False):
        super(BatchRenormalization, self).__init__(size, decay, eps, dtype,
                                                   use_gamma, use_beta,
                                                   initial_gamma, initial_beta)
        self.rmax = rmax  # maximum allowed correction of variance
        self.dmax = dmax  # maximum allowed correction of mean
        self.r = None
        self.d = None
        self.freeze_running_statistics = freeze_running_statistics

    def __call__(self, x, finetune=False):
        if hasattr(self, 'gamma'):
            gamma = self.gamma
        else:
            with cuda.get_device(self._device_id):
                gamma = variable.Variable(self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype))
        if hasattr(self, 'beta'):
            beta = self.beta
        else:
            with cuda.get_device(self._device_id):
                beta = variable.Variable(self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype))

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            func = batch_renormalization.BatchRenormalizationFunction(
                self.eps, self.avg_mean, self.avg_var, decay,
                self.rmax, self.dmax, self.freeze_running_statistics)
            if self.freeze_running_statistics:
                func.r = self.r
                func.d = self.d
            ret = func(x, gamma, beta)
            if self.freeze_running_statistics and self.r is None:
                self.r = func.r
                self.d = func.d

            self.avg_mean[:] = func.running_mean
            self.avg_var[:] = func.running_var
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = variable.Variable(self.avg_mean)
            var = variable.Variable(self.avg_var)
            ret = batch_renormalization.fixed_batch_renormalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
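A rough usage sketch of the link, also not part of the commit: it behaves like BatchNormalization at inference time and applies the clipped r/d correction during training. The import path below is an assumption, since the __init__.py changes that would export the class are among the changed files not shown in this excerpt, and the layer size and input data are illustrative.

import numpy as np
import chainer

# Hypothetical import path; the actual export depends on files not shown here.
from chainer.links import BatchRenormalization

brn = BatchRenormalization(16, rmax=3.0, dmax=5.0)   # size and limits chosen for illustration
x = np.random.randn(8, 16).astype(np.float32)

with chainer.using_config('train', True):
    y_train = brn(x)   # minibatch statistics plus r/d correction; updates avg_mean/avg_var

with chainer.using_config('train', False):
    y_test = brn(x)    # uses fixed_batch_renormalization with the running statistics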