-
Notifications
You must be signed in to change notification settings - Fork 24
/
confusion.go
237 lines (207 loc) · 6.94 KB
/
confusion.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package confusion
import (
"fmt"
"math"
"github.com/emer/etable/etensor"
"github.com/emer/etable/simat"
"github.com/goki/gi/gi"
"github.com/goki/ki/ki"
"github.com/goki/ki/kit"
)
// Matrix computes the confusion matrix, with rows representing
// the ground truth correct class, and columns representing the
// actual answer produced. Correct answers are along the diagonal.
//
// Typical flow as implemented by the methods in this file: Init (or
// InitFromLabels) sizes the tensors, Incr accumulates one observation at
// a time into Sum and N, Probs normalizes Sum into Prob, and SumTFPN,
// ScoreClass and ScoreMatrix derive the per-class and overall scores.
type Matrix struct {
// [view: no-inline] normalized probability of confusion: Row = ground truth class, Col = actual response for that class.
// Shaped n x n by Init; filled by Probs as Sum divided by the row's count in N.
Prob etensor.Float64 `view:"no-inline" desc:"normalized probability of confusion: Row = ground truth class, Col = actual response for that class."`
// [view: no-inline] incremental sums
// Shaped n x n by Init; Incr adds 1 to cell [class, resp] per observation.
Sum etensor.Float64 `view:"no-inline" desc:"incremental sums"`
// [view: no-inline] counts per ground truth (rows)
// One entry per class; Incr adds 1 to the entry of the ground truth class.
N etensor.Float64 `view:"no-inline" desc:"counts per ground truth (rows)"`
// [view: no-inline] visualization using SimMat
// Init points Vis.Mat at Prob; SetLabels assigns its Rows and Cols labels.
Vis simat.SimMat `view:"no-inline" desc:"visualization using SimMat"`
// [view: no-inline] true pos/neg, false pos/neg for each class, generated from the confusion matrix
// Shaped n x 4 by Init; SumTFPN writes cells 0..3 of a class row as TP, FP, FN, TN.
TFPN etensor.Float64 `view:"no-inline" desc:"true pos/neg, false pos/neg for each class, generated from the confusion matrix"`
// [view: no-inline] precision, recall and F1 score by class
// Shaped n x 3 by Init; ScoreClass writes cells 0..2 of a class row as precision, recall, F1.
ClassScores etensor.Float64 `view:"no-inline" desc:"precision, recall and F1 score by class"`
// [view: no-inline] micro F1, macro F1 and weighted F1 scores for entire matrix ignoring class
// Three values written by ScoreMatrix at indexes 0 (micro), 1 (macro) and 2 (weighted).
MatrixScores etensor.Float64 `view:"no-inline" desc:"micro F1, macro F1 and weighted F1 scores for entire matrix ignoring class"`
}
// KiT_Matrix holds the result of adding Matrix to the kit.Types registry,
// with MatrixProps passed along as the properties for the type.
// NOTE(review): presumably this is what makes the ToolBar entries in
// MatrixProps available to the GUI view of a Matrix -- confirm against kit docs.
var KiT_Matrix = kit.Types.AddType(&Matrix{}, MatrixProps)
// Init sizes every tensor of the Matrix for n classes, hooks the
// visualization up to Prob, and finally zeroes all data via Reset.
//
// NOTE(review): the name lists given for TFPN, ClassScores and MatrixScores
// have more entries than the shape has dimensions and read as per-cell
// labels; confirm how SetShape treats them. The MatrixScores labels also
// differ from what ScoreMatrix stores there (micro, macro and weighted F1).
func (cm *Matrix) Init(n int) {
	const (
		tfpnCols  = 4 // TP, FP, FN, TN
		scoreCols = 3 // precision, recall, F1
	)

	// The visualization always displays the normalized probabilities.
	cm.Vis.Mat = &cm.Prob

	// Square class-by-class tensors.
	cm.Sum.SetShape([]int{n, n}, nil, []string{"N", "N"})
	cm.Prob.SetShape([]int{n, n}, nil, []string{"N", "N"})

	// Per-class count of ground truth observations.
	cm.N.SetShape([]int{n}, nil, []string{"N"})

	// Derived statistics.
	cm.TFPN.SetShape([]int{n, tfpnCols}, nil, []string{"TP", "FP", "FN", "TN"})
	cm.ClassScores.SetShape([]int{n, scoreCols}, nil, []string{"Precision", "Recall", "F1"})
	cm.MatrixScores.SetShape([]int{scoreCols}, nil, []string{"Precision", "Recall", "F1"})

	cm.Reset()
}
// Reset zeroes every tensor held by the Matrix: the accumulated data
// (Sum, N), the normalized probabilities (Prob), and all derived
// statistics (TFPN, ClassScores, MatrixScores). Shapes are left as they are.
func (cm *Matrix) Reset() {
	// Accumulators filled by Incr.
	cm.Sum.SetZeros()
	cm.N.SetZeros()

	// Values derived from the accumulators.
	cm.Prob.SetZeros()
	cm.TFPN.SetZeros()
	cm.ClassScores.SetZeros()
	cm.MatrixScores.SetZeros()
}
// SetLabels uses lbls as both the row and the column labels of Vis.
// The same slice is assigned to both; it is not copied.
func (cm *Matrix) SetLabels(lbls []string) {
	cm.Vis.Rows, cm.Vis.Cols = lbls, lbls
}
// InitFromLabels initializes the Matrix from a list of class labels:
// it calls Init with len(lbls), installs the labels via SetLabels, and
// records fontSize under the "font-size" metadata key of Prob.
// Any fontSize <= 0 is replaced by the default of 12.
func (cm *Matrix) InitFromLabels(lbls []string, fontSize int) {
	const defaultFontSize = 12
	if fontSize <= 0 {
		fontSize = defaultFontSize
	}

	cm.Init(len(lbls))
	cm.SetLabels(lbls)
	cm.Prob.SetMetaData("font-size", fmt.Sprintf("%d", fontSize))
}
// Incr records one observation: the cell of Sum at [class, resp] and the
// count of N at class are each increased by one, where class is the
// ground truth and resp is the response produced.
// The call is silently ignored when either index is negative or is not
// below the first dimension of Sum.
func (cm *Matrix) Incr(class, resp int) {
	if class < 0 || resp < 0 {
		return
	}
	if ncat := cm.Sum.Dim(0); class >= ncat || resp >= ncat {
		return
	}

	cell := []int{class, resp}
	cm.Sum.Set(cell, cm.Sum.Value(cell)+1)
	cm.N.Set1D(class, cm.N.Value1D(class)+1)
}
// Probs fills Prob from the accumulated data: each cell of a row is the
// corresponding Sum cell divided by that row's count in N.
// Rows whose count is zero are skipped, so their Prob cells keep whatever
// values they already held.
func (cm *Matrix) Probs() {
	ncls := cm.N.Len()
	for row := 0; row < ncls; row++ {
		total := cm.N.Value1D(row)
		if total == 0 {
			continue // nothing observed for this ground truth class
		}
		for col := 0; col < ncls; col++ {
			cell := []int{row, col}
			cm.Prob.Set(cell, cm.Sum.Value(cell)/total)
		}
	}
}
// SumTFPN tallies the true positive, false positive, false negative and
// true negative counts for the given class from the accumulated Sum matrix
// and stores them in row class of TFPN, in that column order (cells 0..3).
//
// Rows of Sum are the ground truth class and columns are the response, so
// for the class of interest:
//   - TP is the diagonal cell (truth == class, response == class)
//   - FN is the rest of the class row (truth == class, response != class)
//   - FP is the rest of the class column (truth != class, response == class)
//   - TN is every remaining cell
//
// The counts are accumulated locally and written to TFPN once at the end,
// rather than storing partial values on every iteration.
func (cm *Matrix) SumTFPN(class int) {
	var tp, fp, fn, tn float64

	n := cm.N.Len()
	for c := 0; c < n; c++ {
		for r := 0; r < n; r++ {
			v := cm.Sum.FloatValRowCell(r, c)
			switch {
			case r == class && c == class: // true positive
				tp = v
			case r == class: // false negative: class was the truth, something else was answered
				fn += v
			case c == class: // false positive: class was answered, something else was the truth
				fp += v
			default: // true negative
				tn += v
			}
		}
	}

	cm.TFPN.SetFloatRowCell(class, 0, tp)
	cm.TFPN.SetFloatRowCell(class, 1, fp)
	cm.TFPN.SetFloatRowCell(class, 2, fn)
	cm.TFPN.SetFloatRowCell(class, 3, tn)
}
// ScoreClass computes precision, recall and F1 for the given class from the
// TP, FP and FN counts held in row class of TFPN (cells 0, 1 and 2), and
// writes the three scores to cells 0, 1 and 2 of row class of ClassScores.
//
// Denominators are not guarded: when they are zero the floating-point
// division yields NaN, which ScoreMatrix tests for when averaging.
func (cm *Matrix) ScoreClass(class int) {
	var (
		tp = cm.TFPN.FloatValRowCell(class, 0)
		fp = cm.TFPN.FloatValRowCell(class, 1)
		fn = cm.TFPN.FloatValRowCell(class, 2)
	)

	var (
		precision = tp / (tp + fp)
		recall    = tp / (tp + fn)               // also known as the true positive rate
		f1        = 2 * tp / ((2 * tp) + fp + fn) // equals 2 x (Precision x Recall) / (Precision + Recall)
	)

	cm.ClassScores.SetFloatRowCell(class, 0, precision)
	cm.ClassScores.SetFloatRowCell(class, 1, recall)
	cm.ClassScores.SetFloatRowCell(class, 2, f1)
}
// ScoreMatrix computes three overall F1 scores and stores them in
// MatrixScores:
//   - index 0, micro F1: F1 of the TP, FP and FN counts summed over all
//     rows of TFPN, ignoring class
//   - index 1, macro F1: sum of the per-class F1 values in ClassScores
//     divided by the number of classes
//   - index 2, weighted F1: sum of per-class F1 times the class count in N,
//     divided by the total of N
//
// Per-class values that are NaN (classes without instances) contribute
// nothing to the macro and weighted sums, but the divisors still cover all
// classes. Zero divisors are not guarded against.
func (cm *Matrix) ScoreMatrix() {
	n := cm.N.Len()

	// Micro F1: pool the counts first, then score once.
	var tp, fp, fn float64
	for i := 0; i < n; i++ {
		tp += cm.TFPN.FloatValRowCell(i, 0)
		fp += cm.TFPN.FloatValRowCell(i, 1)
		fn += cm.TFPN.FloatValRowCell(i, 2)
	}
	micro := 2 * tp / ((2 * tp) + fp + fn) // equals 2 x (Precision x Recall) / (Precision + Recall)
	cm.MatrixScores.SetFloat1D(0, micro)

	// Macro and weighted F1 both walk the per-class F1 column, so they are
	// gathered in a single pass with independent accumulators.
	var macroSum, weightedSum, totalN float64
	for i := 0; i < n; i++ {
		classF1 := cm.ClassScores.FloatValRowCell(i, 2)
		count := cm.N.FloatVal1D(i)
		totalN += count

		if !math.IsNaN(classF1) {
			macroSum += classF1
		}
		if weighted := classF1 * count; !math.IsNaN(weighted) {
			weightedSum += weighted
		}
	}
	cm.MatrixScores.SetFloat1D(1, macroSum/float64(n))
	cm.MatrixScores.SetFloat1D(2, weightedSum/totalN)
}
// SaveCSV writes the Prob tensor to the named file through
// etensor.SaveCSV, using a comma as the delimiter.
// Nothing is returned; whatever etensor.SaveCSV reports is not inspected here.
func (cm *Matrix) SaveCSV(filename gi.FileName) {
	const delim = ','
	etensor.SaveCSV(&cm.Prob, filename, delim)
}
// OpenCSV loads the Prob tensor from the named file through
// etensor.OpenCSV, using a comma as the delimiter.
// Nothing is returned; whatever etensor.OpenCSV reports is not inspected here.
func (cm *Matrix) OpenCSV(filename gi.FileName) {
	const delim = ','
	etensor.OpenCSV(&cm.Prob, filename, delim)
}
// MatrixProps are the properties registered for Matrix through KiT_Matrix.
// The "ToolBar" entry lists two actions named after the SaveCSV and OpenCSV
// methods, each taking a single file-name argument restricted to the
// ".csv" extension.
var MatrixProps = ki.Props{
	"ToolBar": ki.PropSlice{
		{"SaveCSV", ki.Props{
			"label": "Save CSV...",
			"icon":  "file-save",
			"desc":  "Save CSV-formatted confusion probabilities (Probs)",
			"Args": ki.PropSlice{
				{"CSV File Name", ki.Props{
					"ext": ".csv",
				}},
			},
		}},
		{"OpenCSV", ki.Props{
			"label": "Open CSV...",
			"icon":  "file-open",
			"desc":  "Open CSV-formatted confusion probabilities (Probs)",
			"Args": ki.PropSlice{
				// Labelled like the SaveCSV argument: the file holds the
				// CSV of confusion probabilities, not weights.
				{"CSV File Name", ki.Props{
					"ext": ".csv",
				}},
			},
		}},
	},
}