Merge pull request #4565 from jinjiren/master
Add `InverseShift` extension
bkvogel committed Jun 11, 2018
2 parents 249c998 + cbfa2aa commit 149d747
Showing 4 changed files with 187 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainer/training/extensions/__init__.py
@@ -5,6 +5,7 @@
from chainer.training.extensions.evaluator import Evaluator # NOQA
from chainer.training.extensions.exponential_shift import ExponentialShift # NOQA
from chainer.training.extensions.fail_on_nonnumber import FailOnNonNumber # NOQA
from chainer.training.extensions.inverse_shift import InverseShift # NOQA
from chainer.training.extensions.linear_shift import LinearShift # NOQA
from chainer.training.extensions.log_report import LogReport # NOQA
from chainer.training.extensions.micro_average import MicroAverage # NOQA
88 changes: 88 additions & 0 deletions chainer/training/extensions/inverse_shift.py
@@ -0,0 +1,88 @@
from __future__ import division

import numpy

from chainer.training import extension


class InverseShift(extension.Extension):

"""Trainer extension to shift an optimizer attribute.
The new value is computed according to the fomula below:
new_attr = init_attr * (1 + gamma * iter) ^ (- power), which is compatible
to the ``inv`` learning rate policy in Caffe.
The typical use is to decrease the learning rate during the training.
This extension is also called before the training loop starts by default.
Args:
attr (str): Name of the attribute to shift.
gamma (float): Parameter used to compute the new value. Refer to the
fomula above. Note that gamma is assumed to be nonegative.
power (float): Parameter used to compute the new value. Refer to the
fomula above.
init (float): Initial value of the attribute. If it is ``None``, the
extension extracts the attribute at the first call and uses it as
the initial value.
target (float): Target value of the attribute. If the attribute reaches
this value, the shift stops.
optimizer (~chainer.Optimizer): Target optimizer to adjust the
attribute. If it is ``None``, the main optimizer of the updater is
used.
"""

    def __init__(self, attr, gamma, power,
                 init=None, target=None, optimizer=None):
        self._attr = attr
        if gamma < 0:
            raise ValueError('InverseShift does not support negative gamma')
        self._gamma = gamma
        self._power = power
        self._init = init
        self._target = target
        self._optimizer = optimizer
        self._t = 0
        self._last_value = None

    def initialize(self, trainer):
        optimizer = self._get_optimizer(trainer)
        # ensure that _init is set
        if self._init is None:
            self._init = getattr(optimizer, self._attr)

        if self._last_value is not None:  # resuming from a snapshot
            self._update_value(optimizer, self._last_value)
        else:
            self._update_value(optimizer, self._init)

    def __call__(self, trainer):
        self._t += 1

        optimizer = self._get_optimizer(trainer)
        value = self._init * (1 + self._gamma * self._t) ** (-self._power)
        if self._target is not None:
            if self._power < 0:
                # almost same as value = min(value, self._target), but this
                # line supports negative values, too
                if value / self._target > 1:
                    value = self._target
            else:
                # ditto
                if value / self._target < 1:
                    value = self._target
        self._update_value(optimizer, value)

    def serialize(self, serializer):
        self._t = serializer('_t', self._t)
        self._last_value = serializer('_last_value', self._last_value)
        if isinstance(self._last_value, numpy.ndarray):
            self._last_value = numpy.asscalar(self._last_value)

    def _get_optimizer(self, trainer):
        return self._optimizer or trainer.updater.get_optimizer('main')

    def _update_value(self, optimizer, value):
        setattr(optimizer, self._attr, value)
        self._last_value = value
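
For reference, each call recomputes the attribute as ``init * (1 + gamma * t) ** (-power)`` and, once ``target`` is given, the shift stops at that value (an upper bound when ``power`` is negative, a lower bound otherwise). The minimal sketch below drives the extension by hand on a plain SGD optimizer to print the resulting schedule; the hyperparameter values and the manual ``initialize``/``__call__`` invocations are illustrative assumptions, since in normal use the trainer triggers the extension itself (e.g. via ``trainer.extend``).

# Illustrative sketch only: decay lr = 0.1 * (1 + 0.5 * t) ** -1, clamped at
# the target 0.04. The values and manual calls are assumptions for demonstration.
from chainer import optimizers
from chainer.training import extensions

optimizer = optimizers.SGD(lr=0.1)
shift = extensions.InverseShift('lr', gamma=0.5, power=1.0, target=0.04,
                                optimizer=optimizer)

shift.initialize(None)      # the trainer argument is unused when an optimizer is given
print(0, optimizer.lr)      # 0.1, the initial value taken from the optimizer
for t in range(1, 5):
    shift(None)             # normally called by the trainer at each trigger
    print(t, optimizer.lr)  # ~0.0667, 0.05, 0.04, then clamped at 0.04
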
1 change: 1 addition & 0 deletions docs/source/reference/training.rst
@@ -99,6 +99,7 @@ The typical use case is to change the learning rate of the optimizer over time.
:nosignatures:

chainer.training.extensions.ExponentialShift
chainer.training.extensions.InverseShift
chainer.training.extensions.LinearShift
chainer.training.extensions.PolynomialShift

97 changes: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
import unittest

import mock

from chainer import testing
from chainer.training import extensions
from chainer.training import util


@testing.parameterize(
    {'init': 3.0, 'gamma': 1.0, 'power': 1.0, 'target': None,
     'expect': [3.0, 1.5, 1.0]},
    {'init': 3.0, 'gamma': 1.0, 'power': 1.0, 'target': 1.8,
     'expect': [3.0, 1.8, 1.8]},
    {'init': -3.0, 'gamma': 1.0, 'power': 1.0, 'target': -1.8,
     'expect': [-3.0, -1.8, -1.8]},
    {'init': 3.0, 'gamma': 1.0, 'power': -2.0, 'target': None,
     'expect': [3.0, 12.0, 27.0]},
    {'init': 3.0, 'gamma': 1.0, 'power': -2.0, 'target': 4.0,
     'expect': [3.0, 4.0, 4.0]},
    {'init': -3.0, 'gamma': 1.0, 'power': -2.0, 'target': -4.0,
     'expect': [-3.0, -4.0, -4.0]},
)
class TestInverseShift(unittest.TestCase):

    def setUp(self):
        self.optimizer = mock.MagicMock()
        self.extension = extensions.InverseShift(
            'x', self.gamma, self.power, self.init, self.target,
            self.optimizer)

        self.interval = 4
        self.expect = [e for e in self.expect for _ in range(self.interval)]
        self.trigger = util.get_trigger((self.interval, 'iteration'))

        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer.updater.get_optimizer.return_value = self.optimizer

    def _run_trainer(self, extension, expect, optimizer=None):
        if optimizer is None:
            optimizer = self.optimizer
        extension.initialize(self.trainer)

        actual = []
        for _ in expect:
            self.trainer.updater.update()
            actual.append(optimizer.x)
            if self.trigger(self.trainer):
                extension(self.trainer)

        self.assertEqual(actual, expect)

    def test_basic(self):
        self.optimizer.x = 0
        extension = extensions.InverseShift(
            'x', self.gamma, self.power, init=self.init, target=self.target)
        self._run_trainer(extension, self.expect)

    def test_without_init(self):
        self.optimizer.x = self.init
        extension = extensions.InverseShift(
            'x', self.gamma, self.power, target=self.target)
        self._run_trainer(extension, self.expect)

    def test_with_optimizer(self):
        optimizer = mock.Mock()
        optimizer.x = 0
        extension = extensions.InverseShift(
            'x', self.gamma, self.power, init=self.init, target=self.target,
            optimizer=optimizer)
        self._run_trainer(extension, self.expect, optimizer)

    def test_resume(self):
        new_optimizer = mock.Mock()
        new_extension = extensions.InverseShift(
            'x', self.gamma, self.power, self.init, self.target, new_optimizer)

        self.trainer.extend(self.extension)
        self.trainer.run()

        new_trainer = testing.get_trainer_with_mock_updater((3, 'iteration'))
        new_trainer.extend(new_extension)
        testing.save_and_load_npz(self.trainer, new_trainer)

        new_extension.initialize(new_trainer)
        self.assertEqual(new_optimizer.x, self.optimizer.x)
        self.assertIsInstance(new_optimizer.x, float)


class TestInverseShiftInvalidArgument(unittest.TestCase):

    def test_negative_rate(self):
        with self.assertRaises(ValueError):
            extensions.InverseShift('x', -1.0, 1.0)


testing.run_module(__name__, __file__)
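
As a sanity check, the ``expect`` lists in the parameterized cases above follow directly from the update formula; a tiny standalone sketch (no Chainer imports needed) reproduces the first case:

# Reproduce the first parameterized case: init=3.0, gamma=1.0, power=1.0.
# t=0 corresponds to initialize(); each later value is one __call__().
init, gamma, power = 3.0, 1.0, 1.0
print([init * (1 + gamma * t) ** (-power) for t in range(3)])  # [3.0, 1.5, 1.0]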
