Merge pull request #4237 from tohmae/add-lars
Support Layer-wise Adaptive Rate Scaling (LARS)
delta2323 committed Apr 21, 2018
2 parents c7f4717 + f1f3195 commit c091509
Showing 4 changed files with 168 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainer/optimizer_hooks/__init__.py
@@ -1,5 +1,6 @@
from chainer.optimizer_hooks.gradient_clipping import GradientClipping # NOQA
from chainer.optimizer_hooks.gradient_hard_clipping import GradientHardClipping # NOQA
from chainer.optimizer_hooks.gradient_lars import GradientLARS # NOQA
from chainer.optimizer_hooks.gradient_noise import GradientNoise # NOQA
from chainer.optimizer_hooks.lasso import Lasso # NOQA
from chainer.optimizer_hooks.weight_decay import WeightDecay # NOQA
97 changes: 97 additions & 0 deletions chainer/optimizer_hooks/gradient_lars.py
@@ -0,0 +1,97 @@
from chainer import cuda


class GradientLARS(object):

"""Optimizer/UpdateRule hook function for layer wise adaptive rate scaling.
See: `Large Batch Training of Convolutional Networks \
<https://arxiv.org/abs/1708.03888>`_.
See: `Convergence Analysis of Gradient Descent Algorithms \
with Proportional Updates \
<https://arxiv.org/abs/1801.03137>`_.
This hook function scales all gradient arrays to fit to the weight norm.
In <https://arxiv.org/abs/1708.03888>,
.. math::
v_{t+1} &= m * v_t + \\gamma * \\lambda *
(\\nabla L(w_t) + \\beta w_t), \\\\
w_{t+1} &= w_{t} - v_{t+1},
where
- :math:`\\gamma` : learning_rate
- :math:`m` : momentum
- :math:`\\beta` : weight_decay
- :math:`\\eta` : lars_coeeficient
- :math:`\\lambda`: local_lr \
:math:`=\\eta * \\frac{\|w_t\|}{\|\\nabla L(w_t)\| + \\beta * \|w_t\|}`.
As :math:`lr` in chainer.optimizers.SGD or chainer.optimizers.MomentumSGD
corresponds to :math:`\\gamma * \\eta`, we define :math:`clip\_rate` as
:math:`\\frac{\|w_t\|}{\|\\nabla L(w_t)\| + \\beta * \|w_t\|}`
and reformulate the aforementioned formula as:
:math:`v_{t+1} = m * v_t + lr * clip\_rate * (\\nabla L(w_t) + \\beta w_t)`
and implement in this way. So you do not set lars_coeeficient.
Args:
threashold (float): If weight norm is more than threshold,
this function scales all gradient arrays to fit weight norm.
(See <https://arxiv.org/abs/1801.03137>)
weight_decay (float): Coefficient for the weight decay.
eps (float): Small value for the numerical stability.
(See <https://arxiv.org/abs/1801.03137>)
Attributes:
~optimizer_hooks.GradientLARS.threashold (float): If weight norm is
more than threshold, this function scales all
gradient arrays to fit weight norm.
(See <https://arxiv.org/abs/1801.03137>)
~optimizer_hooks.GradientLARS.weight_decay (float): Coefficient
for the weight decay.
~optimizer_hooks.GradientLARS.eps (float): Small value for the
numerical stability.
(See <https://arxiv.org/abs/1801.03137>)
~optimizer_hooks.GradientLARS.timing (string): Specifies
when this hook should be called by the
Optimizer/UpdateRule. Valid values are 'pre'
(before any updates) and 'post' (after any updates).
"""
    name = 'GradientLARS'
    call_for_each_param = True
    timing = 'pre'

    def __init__(self, threshold=1e-2, weight_decay=0.0, eps=1e-9):
        self.threshold = threshold
        self.weight_decay = weight_decay
        self.eps = eps

    def __call__(self, rule, param):
        p, g = param.data, param.grad
        if p is None or g is None:
            return

        xp = cuda.get_array_module(p)

        # weight norm
        p_norm = xp.linalg.norm(p)
        # grad norm
        g_norm = xp.linalg.norm(g)
        # Local learning rate (clip_rate in the docstring); applied only
        # when the weight norm exceeds the threshold.
        local_rate = p_norm / (self.eps + g_norm + self.weight_decay * p_norm)
        rate = xp.where(p_norm > self.threshold, local_rate, 1.0)
        with cuda.get_device_from_array(p) as dev:
            if int(dev) == -1:
                # CPU: apply weight decay, then rescale the gradient in place.
                g += self.weight_decay * p
                g *= rate
            else:
                # GPU: fuse the same in-place update into one
                # elementwise kernel.
                kernel = cuda.elementwise(
                    'T p, T rate, T weight_decay',
                    'T g',
                    'g += weight_decay * p; g *= rate;',
                    'lars')
                kernel(p, rate, self.weight_decay, g)
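
A minimal usage sketch (not part of this diff) showing how the new hook attaches to an optimizer; the model and hyperparameter values below are illustrative assumptions, while GradientLARS and its arguments come from this commit.

import chainer.links as L
from chainer import optimizer_hooks, optimizers

# Placeholder model; any chainer.Link with parameters works.
model = L.Linear(784, 10)

# The optimizer's lr already plays the role of gamma * eta, so the hook
# only contributes the per-parameter clip_rate.
opt = optimizers.MomentumSGD(lr=0.1, momentum=0.9)
opt.setup(model)
opt.add_hook(optimizer_hooks.GradientLARS(threshold=1e-2,
                                          weight_decay=1e-4,
                                          eps=1e-9))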
1 change: 1 addition & 0 deletions docs/source/reference/optimizers.rst
@@ -39,3 +39,4 @@ Hook functions
chainer.optimizer_hooks.GradientClipping
chainer.optimizer_hooks.GradientHardClipping
chainer.optimizer_hooks.GradientNoise
chainer.optimizer_hooks.GradientLARS
69 changes: 69 additions & 0 deletions tests/chainer_tests/optimizer_hooks_tests/test_gradient_lars.py
@@ -0,0 +1,69 @@
import unittest

import numpy as np

import chainer
from chainer import cuda
import chainer.initializers as I
from chainer import optimizer_hooks
from chainer import optimizers
from chainer import testing
from chainer.testing import attr


class SimpleLink(chainer.Link):

    def __init__(self, w, g):
        super(SimpleLink, self).__init__()
        with self.init_scope():
            self.param = chainer.Parameter(I.Zero(), w.shape)
            self.param.data = w
            self.param.grad = g


class TestGradientLARS(unittest.TestCase):

    def setUp(self):
        self.target = chainer.ChainList(
            SimpleLink(np.arange(6).astype(np.float32).reshape(2, 3),
                       np.arange(3, -3, -1).astype(np.float32).reshape(2, 3)),
            SimpleLink(np.arange(6).astype(np.float32).reshape(2, 3) * 0.0001,
                       np.arange(3, -3, -1).astype(np.float32).reshape(2, 3))
        )

    def check_LARS(self):
        w0 = self.target[0].param.data
        g0 = self.target[0].param.grad
        w1 = self.target[1].param.data
        g1 = self.target[1].param.grad
        xp = cuda.get_array_module(w0)
        threshold = 1e-2
        weight_decay = 0.2
        eps = 1e-9

        p0_norm = xp.linalg.norm(w0)
        g0_norm = xp.linalg.norm(g0)
        clip_rate = p0_norm / (eps + g0_norm + weight_decay * p0_norm)
        expect0 = w0 - clip_rate * (g0 + weight_decay * w0)
        expect1 = w1 - 1.0 * (g1 + weight_decay * w1)

        opt = optimizers.SGD(lr=1)
        opt.setup(self.target)
        opt.add_hook(optimizer_hooks.GradientLARS(threshold=threshold,
                                                  weight_decay=weight_decay,
                                                  eps=eps))
        opt.update()

        testing.assert_allclose(expect0, w0)
        testing.assert_allclose(expect1, w1)

    def test_LARS_cpu(self):
        self.check_LARS()

    @attr.gpu
    def test_LARS_gpu(self):
        self.target.to_gpu()
        self.check_LARS()


testing.run_module(__name__, __file__)
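
For reference, the two branches the test exercises can be reproduced by hand with a short NumPy sketch (not part of this diff): the first link's weight norm exceeds the threshold, so its gradient is scaled by clip_rate, while the second link's tiny weights fall back to a rate of 1.0.

import numpy as np

w0 = np.arange(6).astype(np.float32).reshape(2, 3)
g0 = np.arange(3, -3, -1).astype(np.float32).reshape(2, 3)
threshold, weight_decay, eps = 1e-2, 0.2, 1e-9

p0_norm = np.linalg.norm(w0)   # ~7.416, above the threshold
g0_norm = np.linalg.norm(g0)   # ~4.359
clip_rate = p0_norm / (eps + g0_norm + weight_decay * p0_norm)
print(clip_rate)               # ~1.27, so layer 0's gradient is rescaled

w1 = w0 * 0.0001                       # norm ~7.4e-4, below the threshold
print(np.linalg.norm(w1) > threshold)  # False, so layer 1 keeps rate 1.0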
