/
interval_trigger.py
107 lines (84 loc) · 3.89 KB
/
interval_trigger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import warnings
class IntervalTrigger(object):

    """Trigger based on a fixed interval.

    This trigger accepts iterations divided by a given interval. There are two
    ways to specify the interval: per iterations and epochs. `Iteration` means
    the number of updates, while `epoch` means the number of sweeps over the
    training dataset. Fractional values are allowed if the interval is a
    number of epochs; the trigger uses the `iteration` and `epoch_detail`
    attributes defined by the updater.

    For the description of triggers, see :func:`~chainer.training.get_trigger`.

    Args:
        period (int or float): Length of the interval. Must be an integer if
            unit is ``'iteration'``.
        unit (str): Unit of the length specified by ``period``. It must be
            either ``'iteration'`` or ``'epoch'``.

    Raises:
        ValueError: If ``unit`` is neither ``'iteration'`` nor ``'epoch'``.

    """

    def __init__(self, period, unit):
        # Validate explicitly instead of with ``assert``: assertions are
        # stripped under ``python -O``, and an invalid unit would then
        # silently fall through to the iteration branch of ``__call__``.
        if unit not in ('epoch', 'iteration'):
            raise ValueError(
                "Trigger unit must be either 'epoch' or 'iteration', "
                'got {!r}.'.format(unit))
        self.period = period
        self.unit = unit

        # Updater values observed at the previous call. A negative value
        # marks the stored value as unknown (``serialize`` sets it when a
        # loaded snapshot did not contain it).
        self._previous_iteration = 0
        self._previous_epoch_detail = 0.

        # count is kept for backward compatibility
        self.count = 0

    def __call__(self, trainer):
        """Decides whether the extension should be called on this iteration.

        The trigger fires when ``value // period`` changes between the
        previous call and this one, where ``value`` is
        ``updater.epoch_detail`` for the ``'epoch'`` unit and
        ``updater.iteration`` for the ``'iteration'`` unit.

        Args:
            trainer (Trainer): Trainer object that this trigger is associated
                with. The updater associated with this trainer is used to
                determine if the trigger should fire.

        Returns:
            bool: True if the corresponding extension should be invoked in this
            iteration.

        """
        updater = trainer.updater
        if self.unit == 'epoch':
            epoch_detail = updater.epoch_detail
            previous_epoch_detail = self._previous_epoch_detail

            # if previous_epoch_detail is invalid value,
            # use the value of updater.
            if previous_epoch_detail < 0:
                previous_epoch_detail = updater.previous_epoch_detail

            # count is kept for backward compatibility
            self.count = epoch_detail // self.period

            fire = previous_epoch_detail // self.period != \
                epoch_detail // self.period
        else:
            iteration = updater.iteration
            previous_iteration = self._previous_iteration

            # if previous_iteration is invalid value,
            # guess it from current iteration.
            if previous_iteration < 0:
                previous_iteration = iteration - 1

            fire = previous_iteration // self.period != \
                iteration // self.period

        # save current values (both are recorded regardless of the unit;
        # epoch_detail only when the updater provides it)
        self._previous_iteration = updater.iteration
        if hasattr(updater, 'epoch_detail'):
            self._previous_epoch_detail = updater.epoch_detail

        return fire

    def serialize(self, serializer):
        """Serializes or deserializes the previously observed updater state.

        ``serializer`` is called as ``serializer(key, current_value)`` for the
        keys ``'previous_iteration'`` and ``'previous_epoch_detail'``, and its
        return value is stored back on the trigger. A ``KeyError`` raised by
        the serializer is treated as "value not present in the snapshot". In
        that case a warning is emitted and the value is marked invalid (set
        to a negative number), so that ``__call__`` falls back to a guess.

        Args:
            serializer: Serializer or deserializer object (presumably a
                Chainer serializer; only its call interface is used here).

        """
        try:
            self._previous_iteration = serializer(
                'previous_iteration', self._previous_iteration)
        except KeyError:
            warnings.warn(
                'The previous value of iteration is not saved. '
                'IntervalTrigger guesses it using current iteration. '
                'If this trigger is not called at every iteration, '
                'it may not work correctly.')
            # set a negative value for invalid
            self._previous_iteration = -1

        try:
            self._previous_epoch_detail = serializer(
                'previous_epoch_detail', self._previous_epoch_detail)
        except KeyError:
            warnings.warn(
                'The previous value of epoch_detail is not saved. '
                'IntervalTrigger uses the value of '
                'trainer.updater.previous_epoch_detail. '
                'If this trigger is not called at every iteration, '
                'it may not work correctly.')
            # set a negative value for invalid
            self._previous_epoch_detail = -1.