Skip to content

Commit

Permalink
Merge 4bce43e into 400c07d
Browse files Browse the repository at this point in the history
  • Loading branch information
toslunar committed Jul 19, 2019
2 parents 400c07d + 4bce43e commit bd74f1a
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 112 deletions.
172 changes: 109 additions & 63 deletions chainer/functions/normalization/decorrelated_batch_normalization.py
@@ -1,11 +1,52 @@
import numpy

from chainer import backend
from chainer import function_node
from chainer.utils import argument
from chainer.utils import type_check


# {numpy: True, cupy: False}
_xp_supports_batch_eigh = {}


# routines for batched matrices
def _eigh(a, xp):
if xp not in _xp_supports_batch_eigh:
try:
xp.linalg.eigh(xp.ones((2, 2, 2), xp.float32))
except ValueError:
_xp_supports_batch_eigh[xp] = False
else:
_xp_supports_batch_eigh[xp] = True
if _xp_supports_batch_eigh[xp]:
return xp.linalg.eigh(a)
ws = []
vs = []
for ai in a:
w, v = xp.linalg.eigh(ai)
ws.append(w)
vs.append(v)
return xp.stack(ws), xp.stack(vs)


def _matmul(a, b, xp):
if hasattr(xp, 'matmul'): # numpy.matmul is supported from version 1.10.0
return xp.matmul(a, b)
else:
return xp.einsum('bij,bjk->bik', a, b)


def _diag(a, xp):
s0, s1 = a.shape
ret = xp.zeros((s0, s1, s1), a.dtype)
arange_s1 = numpy.arange(s1)
ret[:, arange_s1, arange_s1] = a
return ret


def _calc_axis_and_m(x_shape, batch_size, groups):
m = batch_size * groups
m = batch_size
spatial_ndim = len(x_shape) - 2
spatial_axis = tuple(range(2, 2 + spatial_ndim))
for i in spatial_axis:
Expand Down Expand Up @@ -47,34 +88,38 @@ def forward(self, inputs):
C = c // g
spatial_axis, m = _calc_axis_and_m(x_shape, b, g)

if g > 1:
x = x.reshape((b * g, C) + x.shape[2:])
x_hat = x.transpose((1, 0) + spatial_axis).reshape(C, -1)

mean = x_hat.mean(axis=1)
x_hat = x_hat - mean[:, None]
# (g, C, m)
x_hat = x.transpose((1, 0) + spatial_axis).reshape(g, C, m)
mean = x_hat.mean(axis=2, keepdims=True)
x_hat = x_hat - mean
self.eps = x.dtype.type(self.eps)

eps_matrix = self.eps * xp.eye(C, dtype=x.dtype)
cov = x_hat.dot(x_hat.T) / x.dtype.type(m) + eps_matrix
self.eigvals, self.eigvectors = xp.linalg.eigh(cov)
U = xp.diag(self.eigvals ** -0.5).dot(self.eigvectors.T)
self.y_hat_pca = U.dot(x_hat) # PCA whitening
y_hat = self.eigvectors.dot(self.y_hat_pca) # ZCA whitening

y = y_hat.reshape((C, b * g,) + x.shape[2:]).transpose(
cov = _matmul(
x_hat, x_hat.transpose(0, 2, 1),
xp) / x.dtype.type(m) + eps_matrix
# (g, C), (g, C, C)
self.eigvals, self.eigvectors = _eigh(cov, xp)
U = _matmul(
_diag(self.eigvals ** -0.5, xp),
self.eigvectors.transpose(0, 2, 1),
xp)
self.y_hat_pca = _matmul(U, x_hat, xp) # PCA whitening
# ZCA whitening
y_hat = _matmul(self.eigvectors, self.y_hat_pca, xp)

y = y_hat.reshape((c, b) + x_shape[2:]).transpose(
(1, 0) + spatial_axis)
if self.groups > 1:
y = y.reshape((-1, c) + x.shape[2:])

# Update running statistics
if self.running_mean is not None:
mean = mean.squeeze(axis=2)
self.running_mean *= self.decay
self.running_mean += (1 - self.decay) * mean
if self.running_projection is not None:
adjust = m / max(m - 1., 1.) # unbiased estimation
self.running_projection *= self.decay
projection = self.eigvectors.dot(U)
projection = _matmul(self.eigvectors, U, xp)
self.running_projection += (1 - self.decay) * adjust * projection

return y,
Expand Down Expand Up @@ -104,38 +149,45 @@ def forward(self, inputs):
g = self.groups
C = c // g
spatial_axis, m = _calc_axis_and_m(gy_shape, b, g)
arange_C = numpy.arange(C)
diag_indices = slice(None), arange_C, arange_C

if g > 1:
gy = gy.reshape((b * g, C) + gy.shape[2:])
gy_hat = gy.transpose((1, 0) + spatial_axis).reshape(C, -1)
gy_hat = gy.transpose((1, 0) + spatial_axis).reshape(g, C, m)

eigvectors = self.eigvectors
eigvals = self.eigvals
y_hat_pca = self.y_hat_pca
gy_hat_pca = eigvectors.T.dot(gy_hat)
f = gy_hat_pca.mean(axis=1)

K = eigvals[:, None] - eigvals[None, :]
valid = K != 0
K[valid] = 1 / K[valid]
xp.fill_diagonal(K, 0)

V = xp.diag(eigvals)
V_sqrt = xp.diag(eigvals ** 0.5)
V_invsqrt = xp.diag(eigvals ** -0.5)

F_c = gy_hat_pca.dot(y_hat_pca.T) / gy.dtype.type(m)
M = xp.diag(xp.diag(F_c))

mat = K.T * (V.dot(F_c.T) + V_sqrt.dot(F_c).dot(V_sqrt))
S = mat + mat.T
R = gy_hat_pca - f[:, None] + (S - M).T.dot(y_hat_pca)
gx_hat = R.T.dot(V_invsqrt).dot(eigvectors.T).T

gx = gx_hat.reshape((C, b * g,) + gy.shape[2:]).transpose(
gy_hat_pca = _matmul(eigvectors.transpose(0, 2, 1), gy_hat, xp)
f = gy_hat_pca.mean(axis=2, keepdims=True)

K = eigvals[:, :, None] - eigvals[:, None, :]
valid = K != 0 # to avoid nan, use eig_i != eig_j instead of i != j
K[valid] = xp.reciprocal(K[valid])

V = _diag(eigvals, xp)
V_sqrt = _diag(eigvals ** 0.5, xp)
V_invsqrt = _diag(eigvals ** -0.5, xp)

F_c = _matmul(
gy_hat_pca, y_hat_pca.transpose(0, 2, 1),
xp) / gy.dtype.type(m)
M = xp.zeros_like(F_c)
M[diag_indices] = F_c[diag_indices]

mat = K.transpose(0, 2, 1) * (
_matmul(V, F_c.transpose(0, 2, 1), xp)
+ _matmul(_matmul(V_sqrt, F_c, xp), V_sqrt, xp)
)
S = mat + mat.transpose(0, 2, 1)
R = gy_hat_pca - f + _matmul(
(S - M).transpose(0, 2, 1), y_hat_pca, xp)
gx_hat = _matmul(
_matmul(R.transpose(0, 2, 1), V_invsqrt, xp),
eigvectors.transpose(0, 2, 1), xp
).transpose(0, 2, 1)

gx = gx_hat.reshape((c, b) + gy_shape[2:]).transpose(
(1, 0) + spatial_axis)
if g > 1:
gx = gx.reshape((-1, c, ) + gy.shape[2:])

self.retain_outputs(())
return gx,
Expand Down Expand Up @@ -166,22 +218,20 @@ def check_type_forward(self, in_types):
def forward(self, inputs):
self.retain_inputs((0, 1, 2))
x, mean, projection = inputs
xp = backend.get_array_module(x)
x_shape = x.shape
b, c = x_shape[:2]
g = self.groups
C = c // g
spatial_axis, m = _calc_axis_and_m(x_shape, b, g)

if g > 1:
x = x.reshape((b * g, C) + x.shape[2:])
x_hat = x.transpose((1, 0) + spatial_axis).reshape(C, -1)
x_hat = x.transpose((1, 0) + spatial_axis).reshape(g, C, m)
x_hat = x_hat - xp.expand_dims(mean, axis=2)

y_hat = projection.dot(x_hat - mean[:, None])
y_hat = _matmul(projection, x_hat, xp)

y = y_hat.reshape((C, b * g) + x.shape[2:]).transpose(
y = y_hat.reshape((c, b) + x_shape[2:]).transpose(
(1, 0) + spatial_axis)
if g > 1:
y = y.reshape((-1, c) + x.shape[2:])

return y,

Expand All @@ -200,25 +250,21 @@ def __init__(self, groups):
def forward(self, inputs):
self.retain_inputs(())
x, mean, projection, gy = inputs
xp = backend.get_array_module(x)
gy_shape = gy.shape
b, c = gy_shape[:2]
g = self.groups
C = c // g
spatial_axis, m = _calc_axis_and_m(gy_shape, b, g)

if g > 1:
gy = gy.reshape((b * g, C) + gy.shape[2:])
x = x.reshape((b * g, C) + x.shape[2:])
x_hat = x.transpose((1, 0) + spatial_axis).reshape(C, -1)
gy_hat = gy.transpose((1, 0) + spatial_axis).reshape(C, -1)
gy_hat_pca = projection.T.dot(gy_hat)
gx = gy_hat_pca.reshape(
(C, b * g) + gy.shape[2:]).transpose((1, 0) + spatial_axis)
if g > 1:
gx = gx.reshape((-1, c) + gy.shape[2:])
rhs = x_hat - mean[Ellipsis, None]
gprojection = (x_hat - rhs).T.dot(gy_hat)
gmean = -gx[:, 0]
gy_hat = gy.transpose((1, 0) + spatial_axis).reshape(g, C, m)
x_hat = x.transpose((1, 0) + spatial_axis).reshape(g, C, m)
gy_hat_pca = _matmul(projection.transpose(0, 2, 1), gy_hat, xp)
gx = gy_hat_pca.reshape((c, b) + gy_shape[2:]).transpose(
(1, 0) + spatial_axis)
rhs = x_hat - xp.expand_dims(mean, axis=2)
gprojection = _matmul((x_hat - rhs).transpose(0, 2, 1), gy_hat, xp)
gmean = -gy_hat_pca[..., 0]
self.retain_outputs(())
return gx, gmean, gprojection

Expand Down
82 changes: 80 additions & 2 deletions chainer/links/normalization/decorrelated_batch_normalization.py
@@ -1,8 +1,13 @@
import functools
import warnings

import numpy

import chainer
from chainer import configuration
from chainer import functions
from chainer import link
import chainer.serializer as serializer_mod
from chainer.utils import argument


Expand Down Expand Up @@ -62,16 +67,30 @@ class DecorrelatedBatchNormalization(link.Link):
def __init__(self, size, groups=16, decay=0.9, eps=2e-5,
dtype=numpy.float32):
super(DecorrelatedBatchNormalization, self).__init__()
self.avg_mean = numpy.zeros(size // groups, dtype=dtype)
C = size // groups
self.avg_mean = numpy.zeros((groups, C), dtype=dtype)
self.register_persistent('avg_mean')
self.avg_projection = numpy.eye(size // groups, dtype=dtype)
avg_projection = numpy.zeros((groups, C, C), dtype=dtype)
arange_C = numpy.arange(C)
avg_projection[:, arange_C, arange_C] = 1
self.avg_projection = avg_projection
self.register_persistent('avg_projection')
self.N = 0
self.register_persistent('N')
self.decay = decay
self.eps = eps
self.groups = groups

def serialize(self, serializer):
if isinstance(serializer, serializer_mod.Deserializer):
serializer = _PatchedDeserializer(serializer, {
'avg_mean': functools.partial(
fix_avg_mean, groups=self.groups),
'avg_projection': functools.partial(
fix_avg_projection, groups=self.groups),
})
super(DecorrelatedBatchNormalization, self).serialize(serializer)

def forward(self, x, **kwargs):
"""forward(self, x, *, finetune=False)
Expand Down Expand Up @@ -131,3 +150,62 @@ def start_finetuning(self):
"""
self.N = 0


class _PatchedDeserializer(serializer_mod.Deserializer):

def __init__(self, base, patches):
self.base = base
self.patches = patches

def __repr__(self):
return '_PatchedDeserializer({}, {})'.format(
repr(self.base), repr(self.patches))

def __call__(self, key, value):
if key not in self.patches:
return self.base(key, value)
arr = self.base(key, None)
arr = self.patches[key](arr)
if value is None:
return arr
chainer.backend.copyto(value, arr)
return value


def _warn_old_model():
msg = (
'Found moving statistics of old DecorrelatedBatchNormalization, whose '
'algorithm was different from the paper.')
warnings.warn(msg)


def fix_avg_mean(avg_mean, groups):
if avg_mean.ndim == 2: # OK
return avg_mean
elif avg_mean.ndim == 1: # Issue #7706
if groups != 1:
_warn_old_model()
return _broadcast_to(avg_mean, (groups,) + avg_mean.shape)
raise ValueError('unexpected shape of avg_mean')


def fix_avg_projection(avg_projection, groups):
if avg_projection.ndim == 3: # OK
return avg_projection
elif avg_projection.ndim == 2: # Issue #7706
if groups != 1:
_warn_old_model()
return _broadcast_to(
avg_projection, (groups,) + avg_projection.shape)
raise ValueError('unexpected shape of avg_projection')


def _broadcast_to(array, shape):
if hasattr(numpy, 'broadcast_to'):
return numpy.broadcast_to(array, shape)
else:
# numpy 1.9 doesn't support broadcast_to method
dummy = numpy.empty(shape)
bx, _ = numpy.broadcast_arrays(array, dummy)
return bx

0 comments on commit bd74f1a

Please sign in to comment.