/
classifier.py
120 lines (99 loc) · 4.47 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from chainer.functions.evaluation import accuracy
from chainer.functions.loss import softmax_cross_entropy
from chainer import link
from chainer import reporter
class Classifier(link.Chain):

    """A simple classifier model.

    This chain wraps another link (the predictor). Given features and a
    ground truth label, it runs the predictor on the features and computes
    the loss and, optionally, the accuracy of the prediction.

    Args:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function, called as ``lossfun(y, t)``.
        accfun (function): Function that computes accuracy, called as
            ``accfun(y, t)``.
        label_key (int or str): Key that selects the label variable from
            the call arguments. An ``int`` indexes the positional
            arguments; a ``str`` names a keyword argument.

    Attributes:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function.
        accfun (function): Function that computes accuracy.
        y (~chainer.Variable): Prediction for the last minibatch.
        loss (~chainer.Variable): Loss value for the last minibatch.
        accuracy (~chainer.Variable): Accuracy for the last minibatch.
        compute_accuracy (bool): If ``True``, compute accuracy on the
            forward computation. The default value is ``True``.

    .. note::
        This link uses :func:`chainer.softmax_cross_entropy` with
        default arguments as a loss function (specified by ``lossfun``),
        if users do not explicitly change it. In particular, the loss
        function does not support double backpropagation.
        If you need second or higher order differentiation, you need to
        turn it on with ``enable_double_backprop=True``:

          >>> import chainer.functions as F
          >>> import chainer.links as L
          >>>
          >>> def lossfun(x, t):
          ...     return F.softmax_cross_entropy(
          ...         x, t, enable_double_backprop=True)
          >>>
          >>> predictor = L.Linear(10)
          >>> model = L.Classifier(predictor, lossfun=lossfun)

    """

    compute_accuracy = True

    def __init__(self, predictor,
                 lossfun=softmax_cross_entropy.softmax_cross_entropy,
                 accfun=accuracy.accuracy,
                 label_key=-1):
        # Reject unsupported key types before any state is set up.
        if not isinstance(label_key, (int, str)):
            raise TypeError('label_key must be int or str, but is %s' %
                            type(label_key))

        super(Classifier, self).__init__()

        self.label_key = label_key
        self.lossfun = lossfun
        self.accfun = accfun

        # Results of the most recent forward pass; empty until first call.
        self.y = None
        self.loss = None
        self.accuracy = None

        # The predictor is assigned inside init_scope, unlike the plain
        # attributes above.
        with self.init_scope():
            self.predictor = predictor

    def __call__(self, *args, **kwargs):
        """Computes the loss value for an input and label pair.

        It also computes accuracy and stores it to the attribute.

        Args:
            args (list of ~chainer.Variable): Input minibatch.
            kwargs (dict of ~chainer.Variable): Input minibatch.

        When ``label_key`` is ``int``, the corresponding element in
        ``args`` is treated as the ground truth labels. When it is
        ``str``, the element of ``kwargs`` with that name is used.
        All elements of ``args`` and ``kwargs`` except the ground truth
        labels are features. The features are fed to the predictor and
        the result is compared with the ground truth labels.

        Returns:
            ~chainer.Variable: Loss value.

        Raises:
            ValueError: If an ``int`` label key is out of bounds for
                ``args``, or a ``str`` label key is missing from
                ``kwargs``.

        """
        key = self.label_key

        if isinstance(key, int):
            n_args = len(args)
            if key < -n_args or key >= n_args:
                raise ValueError('Label key %d is out of bounds' % key)
            t = args[key]
            # Convert to a non-negative position so that one slice
            # expression removes the label for every valid key,
            # including -1.
            pos = key + n_args if key < 0 else key
            args = args[:pos] + args[pos + 1:]
        elif isinstance(key, str):
            if key not in kwargs:
                raise ValueError('Label key "%s" is not found' % key)
            # Take the label out so only features reach the predictor.
            t = kwargs.pop(key)

        # Clear the previous results before running the predictor
        # (presumably so the old minibatch's variables are not kept
        # alive during the new forward pass -- confirm).
        self.y = self.loss = self.accuracy = None

        self.y = self.predictor(*args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        reporter.report({'loss': self.loss}, self)

        if self.compute_accuracy:
            self.accuracy = self.accfun(self.y, t)
            reporter.report({'accuracy': self.accuracy}, self)

        return self.loss