/
classification_summary.py
134 lines (101 loc) · 4.82 KB
/
classification_summary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from __future__ import division
import six
import chainer
from chainer import cuda
from chainer import function
from chainer.utils import type_check
def _fbeta_score(precision, recall, beta):
beta_square = beta * beta
return ((1 + beta_square) * precision * recall /
(beta_square * precision + recall)).astype(precision.dtype)
class ClassificationSummary(function.Function):

    """Chainer function computing per-class precision, recall, F-beta, support.

    Args:
        label_num (int or None): Number of classes. When ``None``, the number
            is derived in :meth:`forward` as ``max(t) + 1``.
        beta (float): F-beta parameter forwarded to ``_fbeta_score``.
        ignore_label (int): Instances whose ground-truth label equals this
            value are excluded from every count.
    """

    def __init__(self, label_num, beta, ignore_label):
        self.label_num, self.beta, self.ignore_label = (
            label_num, beta, ignore_label)

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 2)
        x_type, t_type = in_types

        # Scores must be floating point; labels must be signed integers.
        type_check.expect(x_type.dtype.kind == 'f')
        type_check.expect(t_type.dtype.kind == 'i')

        # x is (batch, classes, *t.shape[1:], optional trailing 1-axes).
        label_ndim = type_check.eval(t_type.ndim)
        type_check.expect(
            x_type.ndim >= t_type.ndim,
            x_type.shape[0] == t_type.shape[0],
            x_type.shape[2:label_ndim + 1] == t_type.shape[1:]
        )
        # Axes of x beyond those matched against t must be singleton.
        for axis in six.moves.range(label_ndim + 1,
                                    type_check.eval(x_type.ndim)):
            type_check.expect(x_type.shape[axis] == 1)

    def forward(self, inputs):
        xp = cuda.get_array_module(*inputs)
        scores, labels = inputs
        # numpy.bincount requires int32 on Windows
        labels = labels.astype('i', copy=False)

        n_classes = self.label_num
        if n_classes is None:
            n_classes = xp.amax(labels) + 1
        elif chainer.is_debug():
            assert (labels < n_classes).all()

        # Ignored instances are routed to the extra bin ``n_classes`` so that
        # truncating the counts to ``[:n_classes]`` discards them.
        ignored = (labels == self.ignore_label).ravel()
        predicted = xp.where(ignored, n_classes,
                             scores.argmax(axis=1).ravel())
        actual = xp.where(ignored, n_classes, labels.ravel())
        # Correct predictions keep their class id; everything else overflows.
        hits = xp.where(predicted == actual, actual, n_classes)

        def per_class_count(values):
            # Occurrences of each class id, with the overflow bin dropped.
            return xp.bincount(values, minlength=n_classes + 1)[:n_classes]

        support = per_class_count(actual)
        relevant = per_class_count(predicted)
        tp = per_class_count(hits)

        precision = tp / relevant
        recall = tp / support
        fbeta = _fbeta_score(precision, recall, self.beta)
        return precision, recall, fbeta, support
def classification_summary(y, t, label_num=None, beta=1.0, ignore_label=-1):
    """Calculates Precision, Recall, F beta Score, and support.

    This function calculates the following quantities for each class.

    - Precision: :math:`\\frac{\\mathrm{tp}}{\\mathrm{tp} + \\mathrm{fp}}`
    - Recall: :math:`\\frac{\\mathrm{tp}}{\\mathrm{tp} + \\mathrm{fn}}`
    - F beta Score: The weighted harmonic mean of Precision and Recall,
      :math:`(1 + \\beta^2) P R / (\\beta^2 P + R)`.
    - Support: The number of instances of each ground truth label.

    Here, ``tp``, ``fp``, and ``fn`` stand for the number of true positives,
    false positives, and false negatives, respectively.

    ``label_num`` specifies the number of classes, that is,
    each value in ``t`` must be an integer in the range of
    ``[0, label_num)``.
    If ``label_num`` is ``None``, this function takes ``label_num`` to be
    the maximum value in ``t`` plus one.

    ``ignore_label`` determines which instances should be ignored.
    Specifically, instances with the given label are not taken
    into account for calculating the above quantities.
    By default, it is set to -1 so that all instances are taken
    into consideration, as labels are supposed to be non-negative integers.
    Setting ``ignore_label`` to a non-negative integer less than ``label_num``
    is illegal and yields undefined behavior. In the current implementation,
    it emits a ``RuntimeWarning`` and the ``ignore_label``-th entries in the
    output arrays do not contain correct quantities.

    A class that is never predicted (for precision) or never occurs in
    ``t`` (for recall) gets a ``0 / 0`` entry, which evaluates to NaN.

    Args:
        y (~chainer.Variable): Variable holding floating-point scores. Axis 0
            is the batch axis and axis 1 indexes classes; the predicted class
            is the argmax along axis 1. The axes after axis 1 must match
            ``t.shape[1:]``, and any further trailing axes must have size 1.
        t (~chainer.Variable): Variable holding integer ground truth labels
            with the same batch size as ``y``.
        label_num (int): The number of classes.
        beta (float): The F-beta parameter. It controls the relative weight
            of recall versus precision; values above 1 favour recall and
            values below 1 favour precision.
        ignore_label (int): Instances with this label are ignored.

    Returns:
        4-tuple of ~chainer.Variable of size ``(label_num,)``.
        Each element represents precision, recall, F beta score,
        and support of this minibatch.

    """
    return ClassificationSummary(label_num, beta, ignore_label)(y, t)
def precision(y, t, label_num=None, ignore_label=-1):
    """Calculates per-class precision.

    Args:
        y (~chainer.Variable): Variable holding a vector of scores.
        t (~chainer.Variable): Variable holding a vector of
            ground truth labels.
        label_num (int): The number of classes. If ``None``, it is taken
            to be the maximum value in ``t`` plus one.
        ignore_label (int): Instances with this label are ignored.

    Returns:
        2-tuple of ~chainer.Variable: per-class precision and per-class
        support.

    """
    summary = ClassificationSummary(label_num, 1.0, ignore_label)
    per_class_precision, _, _, support = summary(y, t)
    return per_class_precision, support
def recall(y, t, label_num=None, ignore_label=-1):
    """Calculates per-class recall.

    Args:
        y (~chainer.Variable): Variable holding a vector of scores.
        t (~chainer.Variable): Variable holding a vector of
            ground truth labels.
        label_num (int): The number of classes. If ``None``, it is taken
            to be the maximum value in ``t`` plus one.
        ignore_label (int): Instances with this label are ignored.

    Returns:
        2-tuple of ~chainer.Variable: per-class recall and per-class
        support.

    """
    summary = ClassificationSummary(label_num, 1.0, ignore_label)
    _, per_class_recall, _, support = summary(y, t)
    return per_class_recall, support
def fbeta_score(y, t, label_num=None, beta=1.0, ignore_label=-1):
    """Calculates the per-class F-beta score.

    Args:
        y (~chainer.Variable): Variable holding a vector of scores.
        t (~chainer.Variable): Variable holding a vector of
            ground truth labels.
        label_num (int): The number of classes. If ``None``, it is taken
            to be the maximum value in ``t`` plus one.
        beta (float): The F-beta parameter; values above 1 favour recall.
        ignore_label (int): Instances with this label are ignored.

    Returns:
        2-tuple of ~chainer.Variable: per-class F-beta score and per-class
        support.

    """
    summary = ClassificationSummary(label_num, beta, ignore_label)
    _, _, score, support = summary(y, t)
    return score, support
def f1_score(y, t, label_num=None, ignore_label=-1):
    """Calculates the per-class F1 score (the F-beta score with beta = 1).

    Args:
        y (~chainer.Variable): Variable holding a vector of scores.
        t (~chainer.Variable): Variable holding a vector of
            ground truth labels.
        label_num (int): The number of classes. If ``None``, it is taken
            to be the maximum value in ``t`` plus one.
        ignore_label (int): Instances with this label are ignored.

    Returns:
        2-tuple of ~chainer.Variable: per-class F1 score and per-class
        support.

    """
    summary = ClassificationSummary(label_num, 1.0, ignore_label)
    _, _, score, support = summary(y, t)
    return score, support