-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
exponential_shift.py
83 lines (66 loc) · 2.95 KB
/
exponential_shift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from __future__ import division
import numpy
from chainer.training import extension
class ExponentialShift(extension.Extension):

    """Trainer extension to exponentially shift an optimizer attribute.

    This extension exponentially increases or decreases the specified attribute
    of the optimizer. The typical use case is an exponential decay of the
    learning rate.

    This extension is also called before the training loop starts by default.

    Args:
        attr (str): Name of the attribute to shift.
        rate (float): Rate of the exponential shift. This value is multiplied
            to the attribute at each call. It must not be negative.
        init (float): Initial value of the attribute. If it is ``None``, the
            extension extracts the attribute at the first call and uses it as
            the initial value.
        target (float): Target value of the attribute. If the attribute reaches
            this value, the shift stops. A target of zero is allowed: a
            growing attribute (``rate > 1``) is pinned to zero, while a
            decaying one is left to shrink towards it.
        optimizer (~chainer.Optimizer): Target optimizer to adjust the
            attribute. If it is ``None``, the main optimizer of the updater is
            used.

    Raises:
        ValueError: If ``rate`` is negative.

    """

    def __init__(self, attr, rate, init=None, target=None, optimizer=None):
        if rate < 0:
            raise ValueError(
                'ExponentialShift does not support negative rate')
        self._attr = attr
        self._rate = rate
        self._init = init
        self._target = target
        self._optimizer = optimizer
        # Number of times the extension has been called so far.
        self._t = 0
        # Last value written to the optimizer; restored by ``serialize``.
        self._last_value = None

    def initialize(self, trainer):
        """Set the attribute to its starting (or resumed) value.

        Args:
            trainer: Trainer object; only used to look up the default
                optimizer when none was given to the constructor.

        """
        optimizer = self._get_optimizer(trainer)
        # ensure that _init is set
        if self._init is None:
            self._init = getattr(optimizer, self._attr)

        if self._last_value is not None:
            # resuming from a snapshot: continue from the restored value
            start = self._last_value
        else:
            start = self._init
        self._update_value(optimizer, start)

    def __call__(self, trainer):
        """Advance the shift by one step and write the new value.

        Args:
            trainer: Trainer object; only used to look up the default
                optimizer when none was given to the constructor.

        """
        self._t += 1
        optimizer = self._get_optimizer(trainer)
        # The value is always derived from the initial value and the number
        # of calls, not from the attribute's current value.
        value = self._init * (self._rate ** self._t)
        self._update_value(optimizer, self._clip_to_target(value))

    def serialize(self, serializer):
        """Save or restore the call counter and the last written value."""
        self._t = serializer('_t', self._t)
        self._last_value = serializer('_last_value', self._last_value)
        # A deserializer may hand the scalar back wrapped in an ndarray.
        if isinstance(self._last_value, numpy.ndarray):
            self._last_value = self._last_value.item()

    def _clip_to_target(self, value):
        """Return ``value``, or the target once ``value`` has passed it."""
        target = self._target
        if target is None:
            return value

        growing = self._rate > 1
        if target == 0:
            # The ratio test below would divide by zero. Comparing magnitudes
            # instead: a growing value is already at or beyond a zero target,
            # whereas a non-growing one can never overshoot it.
            return target if growing else value

        # almost same as min(value, target) / max(value, target), but the
        # ratio test supports negative values, too
        ratio = value / target
        passed = ratio > 1 if growing else ratio < 1
        return target if passed else value

    def _get_optimizer(self, trainer):
        """Return the explicitly given optimizer, else the updater's main."""
        # Explicit ``None`` check: only a missing optimizer triggers the
        # fallback, never an optimizer object that happens to be falsy.
        if self._optimizer is not None:
            return self._optimizer
        return trainer.updater.get_optimizer('main')

    def _update_value(self, optimizer, value):
        """Write ``value`` to the optimizer attribute and remember it."""
        setattr(optimizer, self._attr, value)
        self._last_value = value