-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
group_normalization.py
80 lines (61 loc) · 2.92 KB
/
group_normalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import warnings
from chainer import backend
from chainer.backends import cuda
from chainer.functions.array import broadcast
from chainer.functions.array import reshape
from chainer.functions.normalization import batch_normalization
def group_normalization(x, groups, gamma, beta, eps=1e-5):
    """Group normalization function.

    This function implements a "group normalization"
    which divides the channels into groups and computes within each group
    the mean and variance, then normalizes by these statistics,
    scales and shifts them.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`): Batch tensors.
            First dimension of this value must be the size of minibatch and
            second dimension must be the number of channels.
            Moreover, this value must have one or more following dimensions,
            such as height and width.
        groups (int):
            The number of channel groups.
            This value must be a positive divisor of the number of channels.
            Any integral type (e.g. a NumPy integer) is accepted.
        gamma (:class:`~chainer.Variable` or :ref:`ndarray`):
            Scaling parameter. It must contain exactly ``channels``
            elements; it is reshaped to ``(1, channels, 1, ..., 1)`` and
            broadcast to the shape of ``x``.
        beta (:class:`~chainer.Variable` or :ref:`ndarray`):
            Shifting parameter, with the same size requirement as ``gamma``.
        eps (float): Epsilon value for numerical stability of normalization.

    Returns:
        ~chainer.Variable: The output variable which has the same shape
        as :math:`x`.

    Raises:
        ValueError: If ``x`` has two or fewer dimensions, if ``groups`` is
            not positive, or if ``groups`` does not divide the number of
            channels.
        TypeError: If ``groups`` is not an integer.

    See: `Group Normalization <https://arxiv.org/abs/1803.08494>`_

    """
    import numbers

    if x.ndim <= 2:
        raise ValueError('Input dimension must be greater than 2, '
                         'including batch size dimension '
                         '(first dimension).')

    if not isinstance(groups, numbers.Integral):
        raise TypeError('Argument: \'groups\' type must be (int).')
    # Normalize NumPy integers and similar to a plain Python int.
    groups = int(groups)
    # Without this check, groups == 0 raises ZeroDivisionError in the
    # divisor test below, and a negative divisor (e.g. 4 % -2 == 0) slips
    # through to reshape with a confusing error.
    if groups <= 0:
        raise ValueError('Argument: \'groups\' must be a positive integer.')

    batch_size, channels = x.shape[:2]
    original_shape = x.shape

    if channels % groups != 0:
        raise ValueError('Argument: \'groups\' must be a divisor '
                         'of the number of channels.')

    xp = backend.get_array_module(x)

    # By doing this reshaping, calling batch_normalization function becomes
    # equivalent to Group Normalization: its per-channel statistics (axis 1)
    # become per-sample, per-group statistics.
    # The leading 1 and the trailing redundant axis are added in order to
    # utilize iDeep/cuDNN.
    x = reshape.reshape(x, (1, batch_size * groups, -1, 1))

    # Identity scale/shift for the grouped normalization; the learnable
    # per-channel gamma/beta are applied after reshaping back.
    # NOTE(review): these are always float32 regardless of x.dtype;
    # presumably intentional (e.g. for half-precision inputs) -- confirm
    # before changing.
    with cuda.get_device_from_array(x.array):
        dummy_gamma = xp.ones(batch_size * groups).astype(xp.float32)
        dummy_beta = xp.zeros(batch_size * groups).astype(xp.float32)

    # Every warning emitted inside batch_normalization is silenced here.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        x = batch_normalization.batch_normalization(
            x, dummy_gamma, dummy_beta, eps=eps)

    x = reshape.reshape(x, original_shape)

    # Per-channel affine transform: gamma/beta become (1, channels, 1, ...)
    # and are explicitly broadcast to x's shape before the elementwise ops.
    target_shape = [1, channels] + [1] * (x.ndim - 2)
    gamma_broadcast = broadcast.broadcast_to(
        reshape.reshape(gamma, target_shape), x.shape)
    beta_broadcast = broadcast.broadcast_to(
        reshape.reshape(beta, target_shape), x.shape)
    return x * gamma_broadcast + beta_broadcast