-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4565 from jinjiren/master
Add `InverseShift` extension
- Loading branch information
Showing
4 changed files
with
187 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from __future__ import division | ||
|
||
import numpy | ||
|
||
from chainer.training import extension | ||
|
||
|
||
class InverseShift(extension.Extension):

    """Trainer extension to shift an optimizer attribute.

    The new value is computed according to the formula below:
    new_attr = init_attr * (1 + gamma * iter) ^ (- power), which is
    compatible with the ``inv`` learning rate policy in Caffe.

    The typical use is to decrease the learning rate during the training.
    This extension is also called before the training loop starts by default.

    Args:
        attr (str): Name of the attribute to shift.
        gamma (float): Parameter used to compute the new value. Refer to the
            formula above. Note that gamma is assumed to be nonnegative.
        power (float): Parameter used to compute the new value. Refer to the
            formula above.
        init (float): Initial value of the attribute. If it is ``None``, the
            extension extracts the attribute at the first call and uses it as
            the initial value.
        target (float): Target value of the attribute. If the attribute
            reaches this value, the shift stops.
        optimizer (~chainer.Optimizer): Target optimizer to adjust the
            attribute. If it is ``None``, the main optimizer of the updater is
            used.

    """

    def __init__(self, attr, gamma, power,
                 init=None, target=None, optimizer=None):
        self._attr = attr
        if gamma < 0:
            raise ValueError('InverseShift does not support negative gamma')
        self._gamma = gamma
        self._power = power
        self._init = init
        self._target = target
        self._optimizer = optimizer
        # Number of times the extension has been invoked (the shift step).
        self._t = 0
        # Last value written to the attribute; restored on resume.
        self._last_value = None

    def initialize(self, trainer):
        """Set the attribute to its initial (or snapshot-restored) value."""
        optimizer = self._get_optimizer(trainer)
        # Ensure that _init is set: fall back to the attribute's current
        # value on the optimizer when the user did not supply one.
        if self._init is None:
            self._init = getattr(optimizer, self._attr)

        if self._last_value is not None:  # resuming from a snapshot
            self._update_value(optimizer, self._last_value)
        else:
            self._update_value(optimizer, self._init)

    def __call__(self, trainer):
        """Advance the step counter and write the newly shifted value."""
        self._t += 1

        optimizer = self._get_optimizer(trainer)
        value = self._init * (1 + self._gamma * self._t) ** (-self._power)
        if self._target is not None:
            # NOTE: the ratio test below assumes a nonzero target; a zero
            # target would raise ZeroDivisionError here.
            if self._power < 0:
                # almost same as value = min(value, self._target), but this
                # line supports negative values, too
                if value / self._target > 1:
                    value = self._target
            else:
                # ditto
                if value / self._target < 1:
                    value = self._target
        self._update_value(optimizer, value)

    def serialize(self, serializer):
        """Save or load the internal state (step count and last value)."""
        self._t = serializer('_t', self._t)
        self._last_value = serializer('_last_value', self._last_value)
        if isinstance(self._last_value, numpy.ndarray):
            # ``numpy.asscalar`` is deprecated and removed in NumPy >= 1.23;
            # ``ndarray.item()`` is the supported, equivalent replacement.
            self._last_value = self._last_value.item()

    def _get_optimizer(self, trainer):
        # Fall back to the updater's main optimizer when none was given.
        return self._optimizer or trainer.updater.get_optimizer('main')

    def _update_value(self, optimizer, value):
        # Write the shifted value and remember it for serialization/resume.
        setattr(optimizer, self._attr, value)
        self._last_value = value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
tests/chainer_tests/training_tests/extensions_tests/test_inverse_shift.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import unittest | ||
|
||
import mock | ||
|
||
from chainer import testing | ||
from chainer.training import extensions | ||
from chainer.training import util | ||
|
||
|
||
@testing.parameterize(
    # Each entry drives one parameterized test case; ``expect`` lists the
    # attribute values observed after 0, 1 and 2 triggered shifts.
    # Positive power: the value decays as init / (1 + gamma * t) ** power.
    {'init': 3.0, 'gamma': 1.0, 'power': 1.0, 'target': None,
     'expect': [3.0, 1.5, 1.0]},
    # Decay stops once the target is reached.
    {'init': 3.0, 'gamma': 1.0, 'power': 1.0, 'target': 1.8,
     'expect': [3.0, 1.8, 1.8]},
    # Negative values are supported by the target-clamping logic too.
    {'init': -3.0, 'gamma': 1.0, 'power': 1.0, 'target': -1.8,
     'expect': [-3.0, -1.8, -1.8]},
    # Negative power: the value grows as init * (1 + gamma * t) ** |power|.
    {'init': 3.0, 'gamma': 1.0, 'power': -2.0, 'target': None,
     'expect': [3.0, 12.0, 27.0]},
    # Growth stops once the target is reached.
    {'init': 3.0, 'gamma': 1.0, 'power': -2.0, 'target': 4.0,
     'expect': [3.0, 4.0, 4.0]},
    # Same with negative values.
    {'init': -3.0, 'gamma': 1.0, 'power': -2.0, 'target': -4.0,
     'expect': [-3.0, -4.0, -4.0]},
)
class TestInverseShift(unittest.TestCase):

    """Tests of the ``InverseShift`` trainer extension."""

    def setUp(self):
        # A mock optimizer records the attribute assignments made by the
        # extension; ``x`` plays the role of e.g. the learning rate.
        self.optimizer = mock.MagicMock()
        self.extension = extensions.InverseShift(
            'x', self.gamma, self.power, self.init, self.target,
            self.optimizer)

        # The extension fires every ``interval`` iterations, so each
        # expected value is observed ``interval`` times in a row.
        self.interval = 4
        self.expect = [e for e in self.expect for _ in range(self.interval)]
        self.trigger = util.get_trigger((self.interval, 'iteration'))

        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer.updater.get_optimizer.return_value = self.optimizer

    def _run_trainer(self, extension, expect, optimizer=None):
        # Drive the mock updater by hand, recording the attribute value at
        # every iteration and firing the extension when the trigger fires.
        if optimizer is None:
            optimizer = self.optimizer
        extension.initialize(self.trainer)

        actual = []
        for _ in expect:
            self.trainer.updater.update()
            actual.append(optimizer.x)
            if self.trigger(self.trainer):
                extension(self.trainer)

        self.assertEqual(actual, expect)

    def test_basic(self):
        # ``init`` is given explicitly, so the optimizer's own value (0)
        # must be overwritten at initialization.
        self.optimizer.x = 0
        extension = extensions.InverseShift(
            'x', self.gamma, self.power, init=self.init, target=self.target)
        self._run_trainer(extension, self.expect)

    def test_without_init(self):
        # Without ``init`` the extension reads the initial value from the
        # optimizer attribute itself.
        self.optimizer.x = self.init
        extension = extensions.InverseShift(
            'x', self.gamma, self.power, target=self.target)
        self._run_trainer(extension, self.expect)

    def test_with_optimizer(self):
        # An explicitly passed optimizer takes precedence over the
        # updater's main optimizer.
        optimizer = mock.Mock()
        optimizer.x = 0
        extension = extensions.InverseShift(
            'x', self.gamma, self.power, init=self.init, target=self.target,
            optimizer=optimizer)
        self._run_trainer(extension, self.expect, optimizer)

    def test_resume(self):
        # Saving and loading the trainer must restore the last shifted
        # value (as a plain Python float) on a freshly built extension.
        new_optimizer = mock.Mock()
        new_extension = extensions.InverseShift(
            'x', self.gamma, self.power, self.init, self.target, new_optimizer)

        self.trainer.extend(self.extension)
        self.trainer.run()

        new_trainer = testing.get_trainer_with_mock_updater((3, 'iteration'))
        new_trainer.extend(new_extension)
        testing.save_and_load_npz(self.trainer, new_trainer)

        new_extension.initialize(new_trainer)
        self.assertEqual(new_optimizer.x, self.optimizer.x)
        self.assertIsInstance(new_optimizer.x, float)
|
||
|
||
class TestInverseShiftInvalidArgument(unittest.TestCase):

    """Tests that invalid constructor arguments are rejected."""

    def test_negative_rate(self):
        # A negative gamma must be rejected at construction time.
        self.assertRaises(
            ValueError, extensions.InverseShift, 'x', -1.0, 1.0)
|
||
|
||
testing.run_module(__name__, __file__) |