gradient_lars.py
from chainer import backend
from chainer import cuda


class GradientLARS(object):
"""Optimizer/UpdateRule hook function for layer wise adaptive rate scaling.
See: `Large Batch Training of Convolutional Networks \
<https://arxiv.org/abs/1708.03888>`_.
See: `Convergence Analysis of Gradient Descent Algorithms \
with Proportional Updates \
<https://arxiv.org/abs/1801.03137>`_.
This hook function scales all gradient arrays to fit to the weight norm.
In <https://arxiv.org/abs/1708.03888>,
.. math::
v_{t+1} &= m * v_t + \\gamma * \\lambda *
(\\nabla L(w_t) + \\beta w_t), \\\\
w_{t+1} &= w_{t} - v_{t+1},
where
- :math:`\\gamma` : learning_rate
- :math:`m` : momentum
- :math:`\\beta` : weight_decay
- :math:`\\eta` : lars_coeeficient
- :math:`\\lambda`: local_lr \
:math:`=\\eta * \
\\frac{\\|w_t\\|}{\\|\\nabla L(w_t)\\| + \\beta * \\|w_t\\|}`.
As :math:`lr` in chainer.optimizers.SGD or chainer.optimizers.MomentumSGD
corresponds to :math:`\\gamma * \\eta`, we define :math:`clip\\_rate` as
:math:`\\frac{\\|w_t\\|}{\\|\\nabla L(w_t)\\| + \\beta * \\|w_t\\|}`
and reformulate the aforementioned formula as:
:math:`v_{t+1} \
= m * v_t + lr * clip\\_rate * (\\nabla L(w_t) + \\beta w_t)`
and implement in this way. So you do not set lars_coeeficient.
    Args:
        threshold (float): If the weight norm is greater than the threshold,
            this function scales the gradient array to fit the weight norm.
            (See <https://arxiv.org/abs/1801.03137>)
        weight_decay (float): Coefficient for the weight decay.
        eps (float): Small value for numerical stability.
            (See <https://arxiv.org/abs/1801.03137>)

    Attributes:
        ~optimizer_hooks.GradientLARS.threshold (float): If the weight norm
            is greater than the threshold, this function scales the
            gradient array to fit the weight norm.
            (See <https://arxiv.org/abs/1801.03137>)
        ~optimizer_hooks.GradientLARS.weight_decay (float): Coefficient
            for the weight decay.
        ~optimizer_hooks.GradientLARS.eps (float): Small value for
            numerical stability.
            (See <https://arxiv.org/abs/1801.03137>)
        ~optimizer_hooks.GradientLARS.timing (string): Specifies
            when this hook should be called by the
            Optimizer/UpdateRule. Valid values are 'pre'
            (before any updates) and 'post' (after any updates).
        ~optimizer_hooks.GradientLARS.call_for_each_param (bool): Specifies
            whether this hook is called for each parameter (``True``)
            or only once (``False``) by an optimizer to which this hook
            is registered. Users are not expected to change this value
            from its default, which is ``True``.

    """
    name = 'GradientLARS'
    call_for_each_param = True
    timing = 'pre'

    def __init__(self, threshold=1e-2, weight_decay=0.0, eps=1e-9):
        self.threshold = threshold
        self.weight_decay = weight_decay
        self.eps = eps
    def __call__(self, rule, param):
        p, g = param.data, param.grad
        if p is None or g is None:
            return

        xp = backend.get_array_module(p)

        # weight norm
        p_norm = xp.linalg.norm(p)
        # grad norm
        g_norm = xp.linalg.norm(g)
        # local learning rate (clip_rate in the docstring above)
        local_rate = p_norm / (self.eps + g_norm + self.weight_decay * p_norm)
        # Apply the local rate only when the weight norm exceeds the
        # threshold; otherwise leave the gradient scale unchanged.
        rate = xp.where(p_norm > self.threshold, local_rate, 1.0)
        with cuda.get_device_from_array(p) as dev:
            if int(dev) == -1:
                # CPU path: add the weight-decay term, then scale.
                g += self.weight_decay * p
                g *= rate
            else:
                # GPU path: fuse both operations into one elementwise kernel.
                kernel = cuda.elementwise(
                    'T p, T rate, T weight_decay',
                    'T g',
                    'g += weight_decay * p; g *= rate;',
                    'lars')
                kernel(p, rate, self.weight_decay, g)
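

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original module):
# register the hook on a MomentumSGD optimizer so that each parameter's
# gradient is rescaled layer-wise before the update. The toy model and the
# hyperparameter values below are arbitrary examples, not prescribed values.
if __name__ == '__main__':
    import numpy

    import chainer.functions as F
    import chainer.links as L
    from chainer import optimizers

    model = L.Linear(3, 2)  # a toy single-layer model
    optimizer = optimizers.MomentumSGD(lr=0.1, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(GradientLARS(threshold=1e-2, weight_decay=1e-4))

    x = numpy.random.rand(4, 3).astype(numpy.float32)
    loss = F.sum(model(x) ** 2)
    model.cleargrads()
    loss.backward()
    optimizer.update()  # the hook rescales each gradient before this update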