import chainer
from chainer import configuration
from chainer.functions.normalization import batch_normalization
from chainer.functions.normalization import batch_renormalization
from chainer.links.normalization.batch_normalization import BatchNormalization


class BatchRenormalization(BatchNormalization):

"""Batch renormalization layer on outputs of linear or convolution functions.
This link wraps the :func:`~chainer.functions.batch_renormalization` and
:func:`~chainer.functions.fixed_batch_renormalization` functions.
This is an extension of batch normalization, which ensures that the
training and inference models generate the same outputs that depend on
individual examples rather than the entire minibatch.
See: `Batch Renormalization: Towards Reducing Minibatch Dependence in
Batch-Normalized Models <https://arxiv.org/abs/1702.03275>`_
.. seealso::
:func:`~chainer.functions.batch_renormalization`,
:func:`~chainer.functions.fixed_batch_renormalization`
:func:`~chainer.functions.batch_normalization`,
"""
    def __init__(self, size, rmax=1, dmax=0, decay=0.9, eps=2e-5,
                 dtype=None, use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None,
                 initial_avg_mean=None, initial_avg_var=None):
        # Pass the running-average initializers by keyword so they cannot
        # silently shift onto BatchNormalization's ``axis`` parameter.
        super(BatchRenormalization, self).__init__(
            size, decay, eps, dtype, use_gamma, use_beta,
            initial_gamma, initial_beta,
            initial_avg_mean=initial_avg_mean,
            initial_avg_var=initial_avg_var)
        self.rmax = rmax  # maximum allowed correction of variance
        self.dmax = dmax  # maximum allowed correction of mean
        self.r = None
        self.d = None
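        # For reference, the corrections applied inside
        # batch_renormalization (as defined in the paper):
        #     r = clip(sigma_B / sigma, [1 / rmax, rmax])
        #     d = clip((mu_B - mu) / sigma, [-dmax, dmax])
        # where (mu_B, sigma_B) are the minibatch statistics, (mu, sigma)
        # the running averages, and the normalized output is
        #     y = gamma * ((x - mu_B) / sigma_B * r + d) + beta.
        # Both r and d are treated as constants in the backward pass.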

    def forward(self, x, finetune=False):
        if self.gamma is not None:
            gamma = self.gamma
        else:
            # No learned scale: substitute a constant vector of ones on the
            # link's device.
            with chainer.using_device(self.device):
                gamma = self.xp.ones(self.avg_mean.shape, dtype=x.dtype)
        if self.beta is not None:
            beta = self.beta
        else:
            # No learned shift: substitute a constant vector of zeros.
            with chainer.using_device(self.device):
                beta = self.xp.zeros(self.avg_mean.shape, dtype=x.dtype)

        if configuration.config.train:
            if finetune:
                # In fine-tuning mode, keep a plain arithmetic mean of the
                # statistics: decay = 1 - 1/N weights all N batches equally.
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            avg_mean = self.avg_mean
            avg_var = self.avg_var

            update_statistics = True
            if chainer.config.in_recomputing:
                # Do not update statistics when an extra forward pass is run
                # only to recompute activations.
                if finetune:
                    self.N -= 1  # Revert the count.
                avg_mean = self._prev_avg_mean
                avg_var = self._prev_avg_var
                update_statistics = False
            elif chainer.config._will_recompute:
                # Save the statistics so the later recomputation pass can
                # restore them instead of updating them twice.
                self._prev_avg_mean = avg_mean.copy()
                self._prev_avg_var = avg_var.copy()

            ret = batch_renormalization.batch_renormalization(
                x, gamma, beta, self.rmax, self.dmax,
                self.eps, avg_mean, avg_var, decay,
                update_statistics=update_statistics)
        else:
            # Evaluation mode: normalize with the running averages (or the
            # fine-tuned statistics). No r/d correction is needed here,
            # since these statistics are the population estimates themselves.
            mean = self.avg_mean
            var = self.avg_var
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
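

# A minimal usage sketch (illustration only, not part of the link itself);
# it assumes NumPy and a working Chainer installation, and the shapes and
# hyperparameter values are arbitrary:
if __name__ == '__main__':
    import numpy

    brn = BatchRenormalization(3, rmax=3, dmax=5)
    x = numpy.random.randn(8, 3).astype(numpy.float32)

    # Training mode: normalizes with minibatch statistics corrected by the
    # clipped r/d terms, and updates the running averages.
    with chainer.using_config('train', True):
        y_train = brn(x)

    # Test mode: normalizes with the accumulated running averages.
    with chainer.using_config('train', False):
        y_test = brn(x)

    print(y_train.shape, y_test.shape)  # -> (8, 3) (8, 3)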