Merge pull request #4237 from tohmae/add-lars
Support Layer-wise Adaptive Rate Scaling (LARS)
Showing 4 changed files with 168 additions and 0 deletions.
chainer/optimizer_hooks/__init__.py
@@ -1,5 +1,6 @@
 from chainer.optimizer_hooks.gradient_clipping import GradientClipping  # NOQA
 from chainer.optimizer_hooks.gradient_hard_clipping import GradientHardClipping  # NOQA
+from chainer.optimizer_hooks.gradient_lars import GradientLARS  # NOQA
 from chainer.optimizer_hooks.gradient_noise import GradientNoise  # NOQA
 from chainer.optimizer_hooks.lasso import Lasso  # NOQA
 from chainer.optimizer_hooks.weight_decay import WeightDecay  # NOQA
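
With this export in place, the hook can be pulled straight from the package namespace. A minimal sketch, using only the import and the defaults defined in the new file below:

    from chainer import optimizer_hooks

    hook = optimizer_hooks.GradientLARS()  # defaults: threshold=1e-2, weight_decay=0.0, eps=1e-9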
chainer/optimizer_hooks/gradient_lars.py
@@ -0,0 +1,97 @@
from chainer import cuda


class GradientLARS(object):

    """Optimizer/UpdateRule hook function for layer-wise adaptive rate scaling.

    See: `Large Batch Training of Convolutional Networks \
    <https://arxiv.org/abs/1708.03888>`_.

    See: `Convergence Analysis of Gradient Descent Algorithms \
    with Proportional Updates \
    <https://arxiv.org/abs/1801.03137>`_.

    This hook function scales all gradient arrays to fit the weight norm.

    In <https://arxiv.org/abs/1708.03888>, the weight update is

    .. math::

        v_{t+1} &= m * v_t + \\gamma * \\lambda *
                   (\\nabla L(w_t) + \\beta w_t), \\\\
        w_{t+1} &= w_{t} - v_{t+1},

    where

    - :math:`\\gamma` : learning_rate
    - :math:`m` : momentum
    - :math:`\\beta` : weight_decay
    - :math:`\\eta` : lars_coefficient
    - :math:`\\lambda` : local_lr \
      :math:`=\\eta * \\frac{\|w_t\|}{\|\\nabla L(w_t)\| + \\beta * \|w_t\|}`.

    As :math:`lr` in chainer.optimizers.SGD or chainer.optimizers.MomentumSGD
    corresponds to :math:`\\gamma * \\eta`, we define :math:`clip\_rate` as
    :math:`\\frac{\|w_t\|}{\|\\nabla L(w_t)\| + \\beta * \|w_t\|}`
    and reformulate the update above as
    :math:`v_{t+1} = m * v_t + lr * clip\_rate * (\\nabla L(w_t) + \\beta w_t)`.
    This is the form the hook implements, so there is no separate
    lars_coefficient to set.

    Args:
        threshold (float): If the weight norm exceeds this threshold, the
            gradient array is rescaled by the local rate; otherwise it is
            left unscaled.
            (See <https://arxiv.org/abs/1801.03137>)
        weight_decay (float): Coefficient for the weight decay.
        eps (float): Small value for numerical stability.
            (See <https://arxiv.org/abs/1801.03137>)

    Attributes:
        ~optimizer_hooks.GradientLARS.threshold (float): If the weight norm
            exceeds this threshold, the gradient array is rescaled by the
            local rate; otherwise it is left unscaled.
            (See <https://arxiv.org/abs/1801.03137>)
        ~optimizer_hooks.GradientLARS.weight_decay (float): Coefficient
            for the weight decay.
        ~optimizer_hooks.GradientLARS.eps (float): Small value for
            numerical stability.
            (See <https://arxiv.org/abs/1801.03137>)
        ~optimizer_hooks.GradientLARS.timing (string): Specifies
            when this hook should be called by the
            Optimizer/UpdateRule. Valid values are 'pre'
            (before any updates) and 'post' (after any updates).

    """
    name = 'GradientLARS'
    call_for_each_param = True
    timing = 'pre'

    def __init__(self, threshold=1e-2, weight_decay=0.0, eps=1e-9):
        self.threshold = threshold
        self.weight_decay = weight_decay
        self.eps = eps

    def __call__(self, rule, param):
        p, g = param.data, param.grad
        if p is None or g is None:
            return

        xp = cuda.get_array_module(p)

        # weight norm
        p_norm = xp.linalg.norm(p)
        # grad norm
        g_norm = xp.linalg.norm(g)
        local_rate = p_norm / (self.eps + g_norm + self.weight_decay * p_norm)
        rate = xp.where(p_norm > self.threshold, local_rate, 1.0)
        with cuda.get_device_from_array(p) as dev:
            if int(dev) == -1:
                # CPU path: apply weight decay, then scale by the rate.
                g += self.weight_decay * p
                g *= rate
            else:
                # GPU path: fuse both steps into one elementwise kernel.
                kernel = cuda.elementwise(
                    'T p, T rate, T weight_decay',
                    'T g',
                    'g += weight_decay * p; g *= rate;',
                    'lars')
                kernel(p, rate, self.weight_decay, g)
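
For context, a minimal usage sketch of the new hook; the model and hyperparameter values here are illustrative placeholders, not part of this PR, but the wiring mirrors the test below:

    import chainer
    from chainer import optimizer_hooks, optimizers

    model = chainer.links.Linear(3, 2)  # placeholder; any Link/Chain works
    opt = optimizers.MomentumSGD(lr=0.1, momentum=0.9)
    opt.setup(model)
    # timing='pre' and call_for_each_param=True mean the hook rescales each
    # parameter's gradient (folding in weight decay) before the SGD update.
    opt.add_hook(optimizer_hooks.GradientLARS(threshold=1e-2,
                                              weight_decay=5e-4,
                                              eps=1e-9))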
tests/chainer_tests/optimizer_hooks_tests/test_gradient_lars.py (69 additions, 0 deletions)
@@ -0,0 +1,69 @@
import unittest

import numpy as np

import chainer
from chainer import cuda
import chainer.initializers as I
from chainer import optimizer_hooks
from chainer import optimizers
from chainer import testing
from chainer.testing import attr


class SimpleLink(chainer.Link):

    def __init__(self, w, g):
        super(SimpleLink, self).__init__()
        with self.init_scope():
            self.param = chainer.Parameter(I.Zero(), w.shape)
            self.param.data = w
            self.param.grad = g


class TestGradientLARS(unittest.TestCase):

    def setUp(self):
        # The first link's weight norm exceeds the threshold, so LARS
        # rescales its gradient; the second link's weights are scaled down
        # so its norm falls below the threshold and the rate stays 1.0.
        self.target = chainer.ChainList(
            SimpleLink(np.arange(6).astype(np.float32).reshape(2, 3),
                       np.arange(3, -3, -1).astype(np.float32).reshape(2, 3)),
            SimpleLink(np.arange(6).astype(np.float32).reshape(2, 3) * 0.0001,
                       np.arange(3, -3, -1).astype(np.float32).reshape(2, 3))
        )

    def check_LARS(self):
        w0 = self.target[0].param.data
        g0 = self.target[0].param.grad
        w1 = self.target[1].param.data
        g1 = self.target[1].param.grad
        xp = cuda.get_array_module(w0)
        threshold = 1e-2
        weight_decay = 0.2
        eps = 1e-9

        p0_norm = xp.linalg.norm(w0)
        g0_norm = xp.linalg.norm(g0)
        clip_rate = p0_norm / (eps + g0_norm + weight_decay * p0_norm)
        expect0 = w0 - clip_rate * (g0 + weight_decay * w0)
        expect1 = w1 - 1.0 * (g1 + weight_decay * w1)

        opt = optimizers.SGD(lr=1)
        opt.setup(self.target)
        opt.add_hook(optimizer_hooks.GradientLARS(threshold=threshold,
                                                  weight_decay=weight_decay,
                                                  eps=eps))
        opt.update()

        testing.assert_allclose(expect0, w0)
        testing.assert_allclose(expect1, w1)

    def test_LARS_cpu(self):
        self.check_LARS()

    @attr.gpu
    def test_LARS_gpu(self):
        self.target.to_gpu()
        self.check_LARS()


testing.run_module(__name__, __file__)
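
A note on expect1 above: the second link's weights are scaled by 0.0001 precisely so that their norm falls below the hook's threshold, in which case the hook leaves the gradient unscaled (rate = 1.0). A standalone check of the numbers, just for illustration:

    import numpy as np

    w1 = np.arange(6).astype(np.float32).reshape(2, 3) * 0.0001
    print(np.linalg.norm(w1))  # ~7.4e-4, below threshold=1e-2, so rate stays 1.0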