/
decorrelated_batch_normalization.py
211 lines (169 loc) · 7.49 KB
/
decorrelated_batch_normalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import functools
import warnings
import numpy
import chainer
from chainer import configuration
from chainer import functions
from chainer import link
import chainer.serializer as serializer_mod
from chainer.utils import argument
class DecorrelatedBatchNormalization(link.Link):
    """Decorrelated batch normalization layer.

    This link wraps the
    :func:`~chainer.functions.decorrelated_batch_normalization` and
    :func:`~chainer.functions.fixed_decorrelated_batch_normalization`
    functions. It works on outputs of linear or convolution functions.

    It runs in three modes: training mode, fine-tuning mode, and testing mode.

    In training mode, it normalizes the input by *batch statistics*. It also
    maintains approximated population statistics by moving averages, which can
    be used for instant evaluation in testing mode.

    In fine-tuning mode, it accumulates the input to compute *population
    statistics*. In order to correctly compute the population statistics, a
    user must use this mode to feed mini-batches running through whole training
    dataset.

    In testing mode, it uses pre-computed population statistics to normalize
    the input variable. The population statistics is approximated if it is
    computed by training mode, or accurate if it is correctly computed by
    fine-tuning mode.

    Args:
        size (int or tuple of ints): Size (or shape) of channel
            dimensions.
        groups (int): Number of groups to use for group whitening.
        decay (float): Decay rate of moving average
            which is used during training.
        eps (float): Epsilon value for numerical stability.
        dtype (numpy.dtype): Type to use in computing.

    See: `Decorrelated Batch Normalization <https://arxiv.org/abs/1804.08450>`_

    .. seealso::
       :func:`~chainer.functions.decorrelated_batch_normalization`,
       :func:`~chainer.functions.fixed_decorrelated_batch_normalization`

    Attributes:
        avg_mean (:ref:`ndarray`): Population mean.
        avg_projection (:ref:`ndarray`): Population
            projection.
        groups (int): Number of groups to use for group whitening.
        N (int): Count of batches given for fine-tuning.
        decay (float): Decay rate of moving average
            which is used during training.
        ~DecorrelatedBatchNormalization.eps (float): Epsilon value for
            numerical stability. This value is added to the batch variances.
    """
    def __init__(self, size, groups=16, decay=0.9, eps=2e-5,
                 dtype=numpy.float32):
        super(DecorrelatedBatchNormalization, self).__init__()
        # Number of channels per group; whitening is done within each group.
        C = size // groups
        self.avg_mean = numpy.zeros((groups, C), dtype=dtype)
        self.register_persistent('avg_mean')
        # Initialize the running projection to the identity matrix for each
        # group, so normalization is a no-op before any statistics are seen.
        avg_projection = numpy.zeros((groups, C, C), dtype=dtype)
        arange_C = numpy.arange(C)
        avg_projection[:, arange_C, arange_C] = 1
        self.avg_projection = avg_projection
        self.register_persistent('avg_projection')
        # Mini-batch counter used only in fine-tuning mode; persisted so
        # fine-tuning can resume across snapshots.
        self.N = 0
        self.register_persistent('N')
        self.decay = decay
        self.eps = eps
        self.groups = groups
    def serialize(self, serializer):
        # When loading a snapshot, wrap the deserializer so that statistics
        # saved by an old implementation (Issue #7706: arrays lacking the
        # leading group axis) are converted to the current layout on the fly.
        # Saving (Serializer) is passed through unmodified.
        if isinstance(serializer, serializer_mod.Deserializer):
            serializer = _PatchedDeserializer(serializer, {
                'avg_mean': functools.partial(
                    fix_avg_mean, groups=self.groups),
                'avg_projection': functools.partial(
                    fix_avg_projection, groups=self.groups),
            })
        super(DecorrelatedBatchNormalization, self).serialize(serializer)
    def forward(self, x, **kwargs):
        """forward(self, x, *, finetune=False)

        Invokes the forward propagation of DecorrelatedBatchNormalization.

        In training mode, the DecorrelatedBatchNormalization computes moving
        averages of the mean and projection for evaluation during training,
        and normalizes the input using batch statistics.

        Args:
            x (:class:`~chainer.Variable`): Input variable.
            finetune (bool): If it is in the training mode and ``finetune`` is
                ``True``, DecorrelatedBatchNormalization runs in fine-tuning
                mode; it accumulates the input array to compute population
                statistics for normalization, and normalizes the input using
                batch statistics.
        """
        finetune, = argument.parse_kwargs(kwargs, ('finetune', False))
        if configuration.config.train:
            if finetune:
                # Fine-tuning: decay = 1 - 1/N yields the exact cumulative
                # average of the statistics over all N batches seen so far.
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay
            avg_mean = self.avg_mean
            avg_projection = self.avg_projection
            if configuration.config.in_recomputing:
                # Do not update statistics when extra forward computation is
                # called.
                if finetune:
                    # Undo the increment above so recomputation (e.g. for
                    # memory-saving backprop) does not double-count the batch.
                    self.N -= 1
                avg_mean = None
                avg_projection = None
            ret = functions.decorrelated_batch_normalization(
                x, groups=self.groups, eps=self.eps,
                running_mean=avg_mean, running_projection=avg_projection,
                decay=decay)
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = self.avg_mean
            projection = self.avg_projection
            ret = functions.fixed_decorrelated_batch_normalization(
                x, mean, projection, groups=self.groups)
        return ret
    def start_finetuning(self):
        """Resets the population count for collecting population statistics.

        This method can be skipped if it is the first time to use the
        fine-tuning mode. Otherwise, this method should be called before
        starting the fine-tuning mode again.
        """
        # Resetting N makes the next fine-tuning pass start a fresh
        # cumulative average (decay = 1 - 1/N restarts from N = 1).
        self.N = 0
class _PatchedDeserializer(serializer_mod.Deserializer):
    """Deserializer wrapper that post-processes selected entries.

    ``patches`` maps entry keys to callables; when a patched key is read,
    the raw stored array is loaded first, transformed by the callable, and
    only then copied into the destination array. Keys without a patch are
    delegated to the wrapped deserializer unchanged.
    """

    def __init__(self, base, patches):
        self.base = base
        self.patches = patches

    def __repr__(self):
        return '_PatchedDeserializer({}, {})'.format(
            repr(self.base), repr(self.patches))

    def __call__(self, key, value):
        patch = self.patches.get(key)
        if patch is None:
            # No transformation registered: plain pass-through load.
            return self.base(key, value)
        # Load the raw stored array (value=None returns it as-is), fix its
        # layout, then copy into the caller's array if one was given.
        fixed = patch(self.base(key, None))
        if value is None:
            return fixed
        chainer.backend.copyto(value, fixed)
        return value
def _warn_old_model():
msg = (
'Found moving statistics of old DecorrelatedBatchNormalization, whose '
'algorithm was different from the paper.')
warnings.warn(msg)
def fix_avg_mean(avg_mean, groups):
    """Convert a loaded ``avg_mean`` to the current ``(groups, C)`` layout.

    Snapshots written by the pre-fix implementation (Issue #7706) stored a
    1-D array without the group axis; such arrays are broadcast across the
    group dimension, with a warning when ``groups != 1``.
    """
    if avg_mean.ndim == 1:  # legacy layout (Issue #7706)
        if groups != 1:
            _warn_old_model()
        return _broadcast_to(avg_mean, (groups,) + avg_mean.shape)
    if avg_mean.ndim == 2:  # already in the current layout
        return avg_mean
    raise ValueError('unexpected shape of avg_mean')
def fix_avg_projection(avg_projection, groups):
    """Convert a loaded ``avg_projection`` to the ``(groups, C, C)`` layout.

    Snapshots written by the pre-fix implementation (Issue #7706) stored a
    2-D matrix without the group axis; such arrays are broadcast across the
    group dimension, with a warning when ``groups != 1``.
    """
    if avg_projection.ndim == 2:  # legacy layout (Issue #7706)
        if groups != 1:
            _warn_old_model()
        return _broadcast_to(
            avg_projection, (groups,) + avg_projection.shape)
    if avg_projection.ndim == 3:  # already in the current layout
        return avg_projection
    raise ValueError('unexpected shape of avg_projection')
def _broadcast_to(array, shape):
if hasattr(numpy, 'broadcast_to'):
return numpy.broadcast_to(array, shape)
else:
# numpy 1.9 doesn't support broadcast_to method
dummy = numpy.empty(shape)
bx, _ = numpy.broadcast_arrays(array, dummy)
return bx