/
batch_normalization.py
161 lines (131 loc) · 6.31 KB
/
batch_normalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import numpy
from chainer.backends import cuda
from chainer import configuration
from chainer import functions
from chainer import initializers
from chainer import link
from chainer.utils import argument
from chainer import variable
class BatchNormalization(link.Link):

    """Batch normalization layer on outputs of linear or convolution functions.

    This link is a thin wrapper around
    :func:`~chainer.functions.batch_normalization` and
    :func:`~chainer.functions.fixed_batch_normalization`.

    The link operates in one of three modes:

    * **Training mode** -- the input is normalized by *batch statistics*.
      Approximated population statistics are maintained as moving averages
      so that they can be used for instant evaluation in testing mode.
    * **Fine-tuning mode** -- the input is accumulated to compute
      *population statistics*. To obtain correct population statistics, the
      user must feed mini-batches covering the whole training dataset while
      in this mode.
    * **Testing mode** -- the input is normalized by the pre-computed
      population statistics. These are approximate when they come from
      training mode, and accurate when they were correctly computed in
      fine-tuning mode.

    Args:
        size (int or tuple of ints): Size (or shape) of channel
            dimensions.
        decay (float): Decay rate of moving average. It is used on training.
        eps (float): Epsilon value for numerical stability.
        dtype (numpy.dtype): Type to use in computing.
        use_gamma (bool): If ``True``, use scaling parameter. Otherwise, use
            unit(1) which makes no effect.
        use_beta (bool): If ``True``, use shifting parameter. Otherwise, use
            unit(0) which makes no effect.
        initial_gamma: Initial value of the scaling parameter. It is
            converted with ``initializers._get_initializer``. If ``None``,
            ``1`` is used. Only consulted when ``use_gamma`` is ``True``.
        initial_beta: Initial value of the shifting parameter. It is
            converted with ``initializers._get_initializer``. If ``None``,
            ``0`` is used. Only consulted when ``use_beta`` is ``True``.

    See: `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_

    .. seealso::
       :func:`~chainer.functions.batch_normalization`,
       :func:`~chainer.functions.fixed_batch_normalization`

    Attributes:
        gamma (~chainer.Variable): Scaling parameter.
        beta (~chainer.Variable): Shifting parameter.
        avg_mean (numpy.ndarray or cupy.ndarray): Population mean.
        avg_var (numpy.ndarray or cupy.ndarray): Population variance.
        N (int): Count of batches given for fine-tuning.
        decay (float): Decay rate of moving average. It is used on training.
        ~BatchNormalization.eps (float): Epsilon value for numerical
            stability. This value is added to the batch variances.

    """

    def __init__(self, size, decay=0.9, eps=2e-5, dtype=numpy.float32,
                 use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None):
        super(BatchNormalization, self).__init__()

        # Population statistics and the fine-tuning batch counter all start
        # at zero and are registered as persistent values of the link.
        self.avg_mean = numpy.zeros(size, dtype=dtype)
        self.avg_var = numpy.zeros(size, dtype=dtype)
        self.N = 0
        for persistent_name in ('avg_mean', 'avg_var', 'N'):
            self.register_persistent(persistent_name)

        self.decay = decay
        self.eps = eps

        # gamma and beta are built the same way and differ only in their
        # enabling flag and default initial value (1 for scale, 0 for shift).
        param_specs = (
            ('gamma', use_gamma, initial_gamma, 1),
            ('beta', use_beta, initial_beta, 0),
        )
        with self.init_scope():
            for param_name, enabled, initial, fallback in param_specs:
                if not enabled:
                    # Attribute is left undefined; __call__ substitutes a
                    # constant array in that case.
                    continue
                if initial is None:
                    initial = fallback
                initializer = initializers._get_initializer(initial)
                initializer.dtype = dtype
                setattr(self, param_name,
                        variable.Parameter(initializer, size))

    def __call__(self, x, **kwargs):
        """__call__(self, x, finetune=False)

        Invokes the forward propagation of BatchNormalization.

        In training mode, the BatchNormalization computes moving averages of
        mean and variance for evaluation during training, and normalizes the
        input using batch statistics.

        .. warning::

           ``test`` argument is not supported anymore since v2.
           Instead, use ``chainer.using_config('train', False)``.
           See :func:`chainer.using_config`.

        Args:
            x (Variable): Input variable.
            finetune (bool): If it is in the training mode and ``finetune`` is
                ``True``, BatchNormalization runs in fine-tuning mode; it
                accumulates the input array to compute population statistics
                for normalization, and normalizes the input using batch
                statistics.

        Returns:
            The result of :func:`~chainer.functions.batch_normalization`
            when ``chainer.config.train`` is true, otherwise the result of
            :func:`~chainer.functions.fixed_batch_normalization`.

        """
        # Reject the removed ``test`` keyword with an explanatory message.
        argument.check_unexpected_kwargs(
            kwargs, test='test argument is not supported anymore. '
            'Use chainer.using_config')
        finetune, = argument.parse_kwargs(kwargs, ('finetune', False))

        def _constant(factory_name):
            # Stand-in for a disabled parameter: an array shaped like the
            # statistics, with the input's dtype, created on this link's
            # device by ``xp.ones`` or ``xp.zeros``.
            with cuda.get_device_from_id(self._device_id):
                factory = getattr(self.xp, factory_name)
                return variable.Variable(
                    factory(self.avg_mean.shape, dtype=x.dtype))

        gamma = self.gamma if hasattr(self, 'gamma') else _constant('ones')
        beta = self.beta if hasattr(self, 'beta') else _constant('zeros')

        if not configuration.config.train:
            # Use running average statistics or fine-tuned statistics.
            fixed_mean = variable.Variable(self.avg_mean)
            fixed_var = variable.Variable(self.avg_var)
            return functions.fixed_batch_normalization(
                x, gamma, beta, fixed_mean, fixed_var, self.eps)

        if finetune:
            # Count this batch and derive the decay from the count.
            # NOTE(review): presumably this turns the running statistics
            # into a plain average over the N batches seen -- confirm
            # against how functions.batch_normalization applies ``decay``.
            self.N += 1
            decay = 1. - 1. / self.N
        else:
            decay = self.decay

        return functions.batch_normalization(
            x, gamma, beta, eps=self.eps, running_mean=self.avg_mean,
            running_var=self.avg_var, decay=decay)

    def start_finetuning(self):
        """Resets the population count for collecting population statistics.

        Calling this method is optional before the very first use of the
        fine-tuning mode. It should be called before every later restart of
        the fine-tuning mode.

        """
        self.N = 0