/
minmax_value_trigger.py
105 lines (76 loc) · 3.57 KB
/
minmax_value_trigger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from chainer import reporter
from chainer.training import util
class BestValueTrigger(object):

    """Trigger invoked when specific value becomes best.

    On every call the trigger records the observed value of ``key`` (when it
    is present in the trainer's observation). Whenever the interval trigger
    fires, the mean of the values recorded since the previous comparison is
    compared against the best mean seen so far using ``compare``; the trigger
    fires if the new mean is better (or if there is no best value yet).

    Args:
        key (str): Key of value.
        compare (function): Compare function which takes current best value and
            new value and returns whether new value is better than current
            best.
        trigger: Trigger that decides the comparison interval between current
            best value and new value. This must be a tuple in the form of
            ``<int>, 'epoch'`` or ``<int>, 'iteration'`` which is passed to
            :class:`~chainer.training.triggers.IntervalTrigger`.

    """

    def __init__(self, key, compare, trigger=(1, 'epoch')):
        self._key = key
        # Best mean value seen so far; None until the first comparison.
        self._best_value = None
        self._interval_trigger = util.get_trigger(trigger)
        self._init_summary()
        self._compare = compare

    def __call__(self, trainer):
        """Decides whether the extension should be called on this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with. The ``observation`` of this trainer
                is used to determine if the trigger should fire.

        Returns:
            bool: ``True`` if the corresponding extension should be invoked in
            this iteration. ``False`` is returned when the comparison interval
            has not elapsed, when the value did not improve, or when ``key``
            was not observed at all during the elapsed interval.

        """
        observation = trainer.observation
        summary = self._summary
        key = self._key
        if key in observation:
            summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = summary.compute_mean()
        if key not in stats:
            # The key was not observed on any call during this interval, so
            # there is no value to compare. Do not fire rather than raising
            # KeyError. The summary only ever holds ``key``, so it is already
            # empty and needs no reset.
            return False
        value = float(stats[key])  # copy to CPU
        # Start accumulating afresh for the next comparison interval.
        self._init_summary()

        if self._best_value is None or self._compare(self._best_value, value):
            self._best_value = value
            return True
        return False

    def _init_summary(self):
        # Fresh accumulator for the values observed in the current interval.
        self._summary = reporter.DictSummary()
class MaxValueTrigger(BestValueTrigger):

    """Trigger invoked when the value of a specific key reaches a new maximum.

    For example, you can use this trigger to take a snapshot on the epoch in
    which the validation accuracy is at its maximum.

    Args:
        key (str): Key of value. The trigger fires when the value associated
            with this key becomes maximum.
        trigger: Trigger that decides the comparison interval between current
            best value and new value. This must be a tuple in the form of
            ``<int>, 'epoch'`` or ``<int>, 'iteration'`` which is passed to
            :class:`~chainer.training.triggers.IntervalTrigger`.

    """

    def __init__(self, key, trigger=(1, 'epoch')):
        def _exceeds(current_best, candidate):
            # Only a strictly larger value counts as an improvement.
            return candidate > current_best

        super(MaxValueTrigger, self).__init__(key, _exceeds, trigger)
class MinValueTrigger(BestValueTrigger):

    """Trigger invoked when the value of a specific key reaches a new minimum.

    For example, you can use this trigger to take a snapshot on the epoch in
    which the validation loss is at its minimum.

    Args:
        key (str): Key of value. The trigger fires when the value associated
            with this key becomes minimum.
        trigger: Trigger that decides the comparison interval between current
            best value and new value. This must be a tuple in the form of
            ``<int>, 'epoch'`` or ``<int>, 'iteration'`` which is passed to
            :class:`~chainer.training.triggers.IntervalTrigger`.

    """

    def __init__(self, key, trigger=(1, 'epoch')):
        def _undercuts(current_best, candidate):
            # Only a strictly smaller value counts as an improvement.
            return candidate < current_best

        super(MinValueTrigger, self).__init__(key, _undercuts, trigger)