forked from shogun-toolbox/shogun
-
Notifications
You must be signed in to change notification settings - Fork 1
/
ContingencyTableEvaluation.cpp
112 lines (102 loc) · 2.36 KB
/
ContingencyTableEvaluation.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Soeren Sonnenburg, Sergey Lisitsyn, Heiko Strathmann,
* Roman Votyakov, Viktor Gal
*/
#include <shogun/evaluation/ContingencyTableEvaluation.h>
#include <shogun/labels/BinaryLabels.h>
using namespace shogun;
/** Evaluate predicted labels against ground truth with the contingency-table
 * measure selected by m_type.
 *
 * Both label objects are cast with as<CBinaryLabels>(); the ground truth is
 * additionally checked with ensure_valid(). The TP/FP/TN/FN counts are then
 * recomputed via compute_scores() and the getter matching m_type is returned.
 *
 * @param predicted labels produced by a machine; must not be NULL
 * @param ground_truth reference labels; must not be NULL and must contain the
 *        same number of labels as predicted
 * @return value of the measure selected by m_type
 */
float64_t CContingencyTableEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
{
	// Both pointers are dereferenced below; report a clear error instead of
	// crashing on a null argument.
	REQUIRE(predicted != nullptr, "Predicted labels must not be NULL!\n");
	REQUIRE(ground_truth != nullptr, "Ground truth labels must not be NULL!\n");

	// The message has exactly two %d conversions, so exactly two integer
	// arguments follow it (an extra leading string argument would be consumed
	// by the first %d and shift the counts).
	REQUIRE(
	    predicted->get_num_labels() == ground_truth->get_num_labels(),
	    "Number of predicted labels (%d) must be "
	    "equal to number of ground truth labels (%d)!\n",
	    predicted->get_num_labels(), ground_truth->get_num_labels());

	auto predicted_binary = predicted->as<CBinaryLabels>();
	auto ground_truth_binary = ground_truth->as<CBinaryLabels>();
	ground_truth->ensure_valid();

	compute_scores(predicted_binary, ground_truth_binary);

	switch (m_type)
	{
		case ACCURACY:
			return get_accuracy();
		case ERROR_RATE:
			return get_error_rate();
		case BAL:
			return get_BAL();
		case WRACC:
			return get_WRACC();
		case F1:
			return get_F1();
		case CROSS_CORRELATION:
			return get_cross_correlation();
		case RECALL:
			return get_recall();
		case PRECISION:
			return get_precision();
		case SPECIFICITY:
			return get_specificity();
		case CUSTOM:
			return get_custom_score();
	}

	// Reached only for an m_type not handled above.
	SG_NOTIMPLEMENTED
	// Placeholder return in case SG_NOTIMPLEMENTED does not leave the
	// function (NOTE(review): presumably it raises an error — confirm).
	return 42;
}
/** Tell optimizers whether smaller or larger values of the selected measure
 * are better.
 *
 * ERROR_RATE and BAL are to be minimized. ACCURACY, WRACC, F1,
 * CROSS_CORRELATION, RECALL, PRECISION and SPECIFICITY are to be maximized.
 * CUSTOM defers to get_custom_direction(). Any other m_type hits
 * SG_NOTIMPLEMENTED.
 *
 * @return ED_MINIMIZE, ED_MAXIMIZE, or the custom measure's direction
 */
EEvaluationDirection CContingencyTableEvaluation::get_evaluation_direction() const
{
	switch (m_type)
	{
		// smaller is better
		case ERROR_RATE:
		case BAL:
			return ED_MINIMIZE;

		// larger is better
		case ACCURACY:
		case WRACC:
		case F1:
		case CROSS_CORRELATION:
		case RECALL:
		case PRECISION:
		case SPECIFICITY:
			return ED_MAXIMIZE;

		// user-supplied measure decides for itself
		case CUSTOM:
			return get_custom_direction();

		default:
			SG_NOTIMPLEMENTED
	}

	// Fallback if SG_NOTIMPLEMENTED returns control
	// (NOTE(review): presumably it raises an error — confirm).
	return ED_MINIMIZE;
}
/** Build the 2x2 contingency table from two binary label sets.
 *
 * A label is treated as positive when get_label() returns exactly 1 and as
 * negative otherwise. All four counters are reset before counting. m_N is set
 * to the number of predicted labels, and m_computed is set once the table is
 * filled.
 *
 * @param predicted predicted binary labels; its length bounds the loop
 * @param ground_truth reference binary labels read at the same indices
 *        (NOTE(review): assumes it holds at least as many labels as
 *        predicted — evaluate() checks equal lengths before calling this)
 */
void CContingencyTableEvaluation::compute_scores(CBinaryLabels* predicted, CBinaryLabels* ground_truth)
{
	m_TP = m_FP = m_TN = m_FN = 0.0;

	const int32_t num_labels = predicted->get_num_labels();
	m_N = num_labels;

	for (int32_t idx = 0; idx < num_labels; ++idx)
	{
		// ground truth is queried before the prediction
		const bool truly_positive = ground_truth->get_label(idx) == 1;
		const bool predicted_positive = predicted->get_label(idx) == 1;

		if (truly_positive && predicted_positive)
			m_TP += 1.0;
		else if (truly_positive)
			m_FN += 1.0;
		else if (predicted_positive)
			m_FP += 1.0;
		else
			m_TN += 1.0;
	}

	m_computed = true;
}