-
Notifications
You must be signed in to change notification settings - Fork 8
/
evaluate.py
104 lines (86 loc) · 4.15 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
import statistics as s
# Emotion classes for the EmoContext-style dataset. 'others' is the
# catch-all non-emotion class and is kept LAST so emotion-only metrics
# can exclude it by slicing off the final index.
EMOS = ['happy', 'angry', 'sad', 'others']
# Label string -> integer class index.
EMOS_DIC = {'happy': 0,
            'angry': 1,
            'sad': 2,
            'others': 3}
NUM_EMO = len(EMOS)


def load_dev_labels(data_path='data/dev.txt'):
    """Read gold emotion labels from a tab-separated dataset file.

    The file is expected to contain a header row followed by rows of the
    form ``id<TAB>turn1<TAB>turn2<TAB>turn3<TAB>label`` — i.e. the label
    sits at column ``CONV_PAD_LEN + 1``.

    Args:
        data_path: Path to the TSV file (default ``'data/dev.txt'``).

    Returns:
        1-D numpy integer array of class indices (see EMOS_DIC).

    Raises:
        KeyError: if a row's label is not one of EMOS.
        IndexError: if a row has fewer than CONV_PAD_LEN + 2 columns.
    """
    CONV_PAD_LEN = 3  # number of conversation turns preceding the label
    target_list = []
    # Context manager guarantees the file is closed even if parsing fails
    # (the original used open/close and could leak the handle).
    with open(data_path, 'r', encoding='utf8') as f_data:
        data_lines = f_data.readlines()
    # Skip the first line: it holds the column names, not data.
    for text in data_lines[1:]:
        tokens = text.split('\t')
        emo = tokens[CONV_PAD_LEN + 1].strip()
        target_list.append(EMOS_DIC[emo])
    return np.asarray(target_list)
def to_categorical(vec):
    """Convert a 1-D vector of class indices into a one-hot matrix.

    Args:
        vec: 1-D numpy array of integer class indices in [0, NUM_EMO).

    Returns:
        Float array of shape (len(vec), NUM_EMO) with a single 1 per row.
    """
    one_hot = np.zeros((vec.shape[0], NUM_EMO))
    # Walk the output rows in lockstep with the labels and flip one cell each.
    for row, label in zip(one_hot, vec):
        row[label] = 1
    return one_hot
def get_metrics(ground, predictions):
    """Print and return classification metrics for emotion prediction.

    Both arguments are 1-D integer label vectors (class indices per
    EMOS_DIC) — NOT probability rows; each is one-hot encoded internally.
    The macro and micro scores exclude the catch-all 'others' class.

    Args:
        ground: 1-D array of ground-truth class indices.
        predictions: 1-D array of predicted class indices.

    Returns:
        Tuple (accuracy, microPrecision, microRecall, microF1).
    """
    # One-hot both vectors, shape (n_samples, NUM_EMO), so per-class
    # TP/FP/FN fall out of element-wise arithmetic along axis 0.
    discretePredictions = to_categorical(predictions)
    ground = to_categorical(ground)
    truePositives = np.sum(discretePredictions * ground, axis=0)
    falsePositives = np.sum(np.clip(discretePredictions - ground, 0, 1), axis=0)
    falseNegatives = np.sum(np.clip(ground - discretePredictions, 0, 1), axis=0)
    print("True Positives per class : ", truePositives)
    print("False Positives per class : ", falsePositives)
    print("False Negatives per class : ", falseNegatives)

    # ----- Macro level: average P/R over the emotion classes only -----
    macroPrecision = 0
    macroRecall = 0
    f1_list = []
    # 'others' is the LAST class (index NUM_EMO - 1), so range(NUM_EMO - 1)
    # covers exactly happy/angry/sad.
    for c in range(NUM_EMO - 1):
        denom = truePositives[c] + falsePositives[c]
        # Guard 0/0: an unpredicted class yields precision 0, not nan.
        precision = truePositives[c] / denom if denom > 0 else 0
        macroPrecision += precision
        denom = truePositives[c] + falseNegatives[c]
        recall = truePositives[c] / denom if denom > 0 else 0
        macroRecall += recall
        f1 = (2 * recall * precision) / (precision + recall) if (precision + recall) > 0 else 0
        f1_list.append(f1)
        print("Class %s : Precision : %.3f, Recall : %.3f, F1 : %.3f" % (EMOS[c], precision, recall, f1))
    # NOTE(review): statistics.harmonic_mean's handling of a 0 value varies
    # by Python version (StatisticsError vs returning 0) — assumes all
    # per-class F1 scores are positive; confirm for the target runtime.
    print('Harmonic Mean: ',
          s.harmonic_mean(f1_list))
    macroPrecision /= NUM_EMO - 1
    macroRecall /= NUM_EMO - 1
    macroF1 = (2 * macroRecall * macroPrecision) / (macroPrecision + macroRecall) \
        if (macroPrecision + macroRecall) > 0 else 0
    print("Ignoring the Others class, Macro Precision : %.4f, Macro Recall : %.4f, Macro F1 : %.4f" % (
        macroPrecision, macroRecall, macroF1))

    # ----- Micro level: pool TP/FP/FN over the emotion classes -----
    # BUG FIX: the original sliced [1:], which dropped 'happy' (index 0)
    # and KEPT 'others'. With 'others' mapped to the last index by
    # EMOS_DIC, excluding it means slicing off the final class instead.
    truePositives = truePositives[:-1].sum()
    falsePositives = falsePositives[:-1].sum()
    falseNegatives = falseNegatives[:-1].sum()
    print("Ignoring the Others class, Micro TP : %d, FP : %d, FN : %d"
          % (truePositives, falsePositives, falseNegatives))
    denom = truePositives + falsePositives
    microPrecision = truePositives / denom if denom > 0 else 0
    denom = truePositives + falseNegatives
    microRecall = truePositives / denom if denom > 0 else 0
    microF1 = (2 * microRecall * microPrecision) / (microPrecision + microRecall)\
        if (microPrecision + microRecall) > 0 else 0

    # Accuracy over ALL classes: recover the label vector from the one-hot
    # ground matrix and compare against the raw predicted labels.
    ground = ground.argmax(axis=1)
    accuracy = np.mean(predictions == ground)
    print("Accuracy : %.4f, Micro Precision : %.4f, Micro Recall : %.4f, Micro F1 : %.4f" % (
        accuracy, microPrecision, microRecall, microF1))
    return accuracy, microPrecision, microRecall, microF1