metrics.py
import pickle as pkl

import torch
import torch.nn as nn


def filterOutput(outputs, labels, label_embeds, PAD_token_ind):
    """Drop positions whose label is the padding token.

    Returns the filtered labels, label embeddings, and decoder outputs.
    """
    mask = (labels != PAD_token_ind).squeeze()
    labels, label_embeds, outputs = labels[mask], label_embeds[mask], outputs[mask]
    return labels, label_embeds, outputs


def maskedLoss(label_embeds, outputs, criterion):
    """Average an element-wise loss over the number of unpadded tokens."""
    num_tokens = label_embeds.shape[0]
    loss = criterion(label_embeds, outputs)
    loss = torch.sum(loss) / num_tokens
    return loss
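

# Minimal usage sketch (an assumption for illustration, not part of this module):
# with an element-wise criterion such as nn.MSELoss(reduction='none'), the two
# helpers above are typically chained so padded positions never reach the loss:
#
#   criterion = nn.MSELoss(reduction='none')
#   labels, label_embeds, outputs = filterOutput(outputs, labels, label_embeds,
#                                                PAD_token_ind)
#   loss = maskedLoss(label_embeds, outputs, criterion)
#
# The division inside maskedLoss then averages over unpadded tokens only.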

# Earlier, un-normalized variant kept for reference:
# def getIndicesFromEmbedding(output_embeds, txt_embeds):
#     similarity = output_embeds.mm(txt_embeds.t())
#     output_indices = similarity.argmax(1)
#     return output_indices


def getIndicesFromEmbedding(output_embeds, txt_embeds):
    """Map each output embedding to the index of the most similar vocabulary embedding."""
    # Small epsilon on the embeddings' device/dtype to avoid division by zero.
    eps = torch.tensor([1e-6], dtype=txt_embeds.dtype, device=txt_embeds.device)
    txt_embeds_norm = torch.max(torch.norm(txt_embeds, p=2, dim=1).detach(), eps)
    # Scale the dot-product scores by each vocabulary embedding's norm; the argmax
    # is then equivalent to a cosine-similarity lookup.
    similarity = output_embeds.mm(txt_embeds.t())
    similarity = similarity.div(txt_embeds_norm)
    output_indices = similarity.argmax(1)
    return output_indices
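

# Equivalent formulation using torch.nn.functional.normalize, shown only to make
# the cosine-similarity interpretation explicit (not used by this module):
#
#   import torch.nn.functional as F
#   sims = output_embeds @ F.normalize(txt_embeds, p=2, dim=1).t()
#   indices = sims.argmax(1)
#
# Normalizing output_embeds as well would give true cosine similarity, but it does
# not change the argmax because each row is only scaled by a positive constant.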


def word_accuracy(output_embeds, txt_embeds, labels):
    """
    Input:
        output_embeds - flattened decoder output with padding tokens removed
        txt_embeds - text embeddings for the entire vocabulary
        labels - flattened labels
    Output:
        word-level accuracy (as a percentage)

    TODO(Jay): Address this issue:
        Because the embedding of SOS equals the embedding of EOS, an EOS position
        is predicted as index 2 (SOS), since SOS comes first in the vocabulary.
    """
    num_words = output_embeds.shape[0]
    output_indices = getIndicesFromEmbedding(output_embeds, txt_embeds)
    labels = labels.squeeze()
    output_indices = output_indices.squeeze()
    assert output_indices.shape == labels.shape
    correct = torch.sum(output_indices == labels).float()
    acc = 100 * (correct / num_words)
    return acc

# Unfinished batch-level accuracy variant, kept for reference:
# def accuracy(output_embeds, txt_embeds, labels, PAD_token_ind):
#     batch_size = output_embeds.shape[0]
#     for i in range(batch_size):
#         output_embeds
#         similarity = output_embeds.mm(txt_embeds.t())
#         output_indices = similarity.argmax(1)


if __name__ == "__main__":
    # Sanity check: embeddings looked up from the correct indices should score at or
    # near 100% (see the SOS/EOS caveat in word_accuracy), while a shuffled sequence
    # should score much lower.
    labels = torch.tensor([2, 45, 65, 71, 32, 3])
    correct_outputs = [2, 45, 65, 71, 32, 3]
    wrong_outputs = [3, 32, 71, 65, 45, 2]

    with open('./data/txt_embed.pkl', 'rb') as f:
        txt_embeds = torch.from_numpy(pkl.load(f))

    correct_acc = word_accuracy(txt_embeds[correct_outputs], txt_embeds, labels)
    wrong_acc = word_accuracy(txt_embeds[wrong_outputs], txt_embeds, labels)
    print('Accuracy for correct outputs: {}'.format(correct_acc))
    print('Accuracy for wrong outputs: {}'.format(wrong_acc))