/
classification_evaluator.py
65 lines (53 loc) · 2.99 KB
/
classification_evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
from common.evaluators.evaluator import Evaluator
class ClassificationEvaluator(Evaluator):
    """Evaluator for single-label and multi-label text classification models.

    Runs the wrapped model over the evaluation split held by ``data_loader``
    and reports accuracy, micro-averaged precision/recall/F1, and the average
    cross-entropy loss per example.
    """

    def __init__(self, dataset_cls, model, embedding, data_loader, batch_size, device, keep_results=False):
        super().__init__(dataset_cls, model, embedding, data_loader, batch_size, device, keep_results)
        # ignore_lengths: when True, the model is called with the raw text
        # tensor only (no sequence-length tensor). is_multilabel: when True,
        # per-label sigmoid + BCE is used instead of softmax cross-entropy.
        # Callers toggle these flags after construction.
        self.ignore_lengths = False
        self.is_multilabel = False

    def get_scores(self):
        """Evaluate the model on the full data loader.

        Returns:
            tuple: ``([accuracy, precision, recall, f1, avg_loss],
            ['accuracy', 'precision', 'recall', 'f1', 'cross_entropy_loss'])``
        """
        self.model.eval()
        self.data_loader.init_epoch()
        total_loss = 0

        # Temporal averaging: temporarily swap in the EMA parameters and
        # remember the originals so they can be restored afterwards.
        if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
            old_params = self.model.get_params()
            self.model.load_ema_params()

        predicted_labels, target_labels = list(), list()
        # Gradients are not needed for evaluation; no_grad avoids building
        # the autograd graph for every batch.
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.data_loader):
                if hasattr(self.model, 'tar') and self.model.tar:
                    # Models with temporal activation regularization also
                    # return their RNN hidden states.
                    if self.ignore_lengths:
                        scores, rnn_outs = self.model(batch.text)
                    else:
                        scores, rnn_outs = self.model(batch.text[0], lengths=batch.text[1])
                else:
                    if self.ignore_lengths:
                        scores = self.model(batch.text)
                    else:
                        scores = self.model(batch.text[0], lengths=batch.text[1])

                if self.is_multilabel:
                    # Independent per-label sigmoid, thresholded at 0.5.
                    # torch.sigmoid replaces the deprecated F.sigmoid.
                    scores_rounded = torch.sigmoid(scores).round().long()
                    predicted_labels.extend(scores_rounded.cpu().detach().numpy())
                    target_labels.extend(batch.label.cpu().detach().numpy())
                    # reduction='sum' replaces the deprecated size_average=False.
                    total_loss += F.binary_cross_entropy_with_logits(
                        scores, batch.label.float(), reduction='sum').item()
                else:
                    # Single-label: argmax over class scores; labels arrive
                    # one-hot encoded, so argmax recovers the class index.
                    predicted_labels.extend(torch.argmax(scores, dim=1).cpu().detach().numpy())
                    target_labels.extend(torch.argmax(batch.label, dim=1).cpu().detach().numpy())
                    total_loss += F.cross_entropy(
                        scores, torch.argmax(batch.label, dim=1), reduction='sum').item()

                if hasattr(self.model, 'tar') and self.model.tar:
                    # Temporal activation regularization penalty. .item()
                    # keeps total_loss a Python float — the original added a
                    # tensor here, silently turning the accumulator into one.
                    total_loss += (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean().item()

        predicted_labels = np.array(predicted_labels)
        target_labels = np.array(target_labels)

        accuracy = metrics.accuracy_score(target_labels, predicted_labels)
        precision = metrics.precision_score(target_labels, predicted_labels, average='micro')
        recall = metrics.recall_score(target_labels, predicted_labels, average='micro')
        f1 = metrics.f1_score(target_labels, predicted_labels, average='micro')
        avg_loss = total_loss / len(self.data_loader.dataset.examples)

        # Temporal averaging: restore the original (non-EMA) parameters.
        if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
            self.model.load_params(old_params)

        return [accuracy, precision, recall, f1, avg_loss], \
               ['accuracy', 'precision', 'recall', 'f1', 'cross_entropy_loss']