-
Notifications
You must be signed in to change notification settings - Fork 11
/
losses.py
96 lines (79 loc) · 2.8 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import torch.nn.functional as F
class KLLoss(nn.Module):
    """KL-divergence loss between two probability distributions.

    ``nn.KLDivLoss`` expects its first argument in log space; this wrapper
    takes probabilities (e.g. softmax outputs) for ``output`` and applies the
    log itself. ``target`` is passed through unchanged (probability space).

    Args:
        reduction: forwarded to ``nn.KLDivLoss``. The default ``'mean'``
            keeps the historical behaviour; PyTorch documents that
            ``'batchmean'`` is the variant matching the mathematical
            definition of KL divergence.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.KLLoss = nn.KLDivLoss(reduction=reduction)

    def forward(self, output, target):
        """Return the KL divergence of ``output`` from ``target``.

        Args:
            output: predicted probabilities, shape (N, *).
            target: target probabilities, shape (N, *).

        Returns:
            The reduced loss produced by ``nn.KLDivLoss``.
        """
        # Invert the softmax. An exact zero would give log(0) = -inf, making
        # the loss inf/NaN and the gradient NaN. Flooring at the dtype's
        # smallest normal number only changes zeros and subnormals, so the
        # result is unchanged for all ordinary probability values.
        tiny = torch.finfo(output.dtype).tiny
        log_output = torch.log(output.clamp_min(tiny))
        # Measures how the output distribution diverges from the target.
        return self.KLLoss(log_output, target)
class CELoss(nn.Module):
    """Cross-entropy loss for a model whose output is already a probability vector.

    ``nn.CrossEntropyLoss`` applies log_softmax to its input. Feeding it
    ``log(probs)`` reproduces NLL on the log-probabilities whenever ``probs``
    sums to 1 along the last axis, since log_softmax(log p) = log p.

    Args:
        ignore_index: target value excluded from the loss (default -1).
    """

    def __init__(self, ignore_index=-1):
        super().__init__()
        self.CELoss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, output, target):
        """
        Args:
            output: class probabilities, shape (N, *, C).
            target: class indices, shape (N, *); cast to int64 here.

        Returns:
            The loss from ``nn.CrossEntropyLoss`` over all flattened positions.
        """
        # Invert the softmax. An exact zero would give log(0) = -inf: the
        # loss becomes inf if it is the target class, and otherwise the
        # backward pass produces NaN (0 * 1/0 through log). Flooring at the
        # dtype's smallest normal number only changes zeros/subnormals.
        tiny = torch.finfo(output.dtype).tiny
        log_probs = torch.log(output.clamp_min(tiny))
        log_probs = log_probs.reshape(-1, log_probs.shape[-1])  # (N*..., C)
        target = target.reshape(-1).long()  # (N*...)
        return self.CELoss(log_probs, target)
class CELossSame(nn.Module):
    """Sum of cross-entropy losses of three probability heads against one shared target.

    ``outputs[0]``, ``outputs[1]`` and ``outputs[2]`` are each scored with the
    same ``target``. The original local names (img / txt / sen) suggest
    image, text and sentence heads; that is not verifiable from this file.

    Args:
        ignore_index: target value excluded from every term (default -1).
    """

    def __init__(self, ignore_index=-1):
        super().__init__()
        self.CELoss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    @staticmethod
    def _flat_log_probs(probs):
        """Return log(probs) flattened to (N*..., C), with exact zeros floored to avoid -inf."""
        # log(0) = -inf would produce an inf loss or NaN gradients; the
        # dtype's smallest normal value only replaces zeros/subnormals.
        tiny = torch.finfo(probs.dtype).tiny
        log_probs = torch.log(probs.clamp_min(tiny))
        return log_probs.reshape(-1, log_probs.shape[-1])

    def forward(self, outputs, target):
        """
        Args:
            outputs: indexable holding at least three probability tensors,
                each of shape (N, *, C).
            target: class indices, shape (N, *); cast to int64 here.

        Returns:
            The sum of the three per-head cross-entropy losses.
        """
        target = target.reshape(-1).long()  # (N*...)
        # Explicit indexing keeps the original contract: exactly heads 0-2
        # are used, and fewer than three heads raises IndexError.
        heads = (outputs[0], outputs[1], outputs[2])
        return sum(self.CELoss(self._flat_log_probs(p), target) for p in heads)
class CELossShift(nn.Module):
    """Cross-entropy between predictions and targets shifted by one step along dim 1.

    The prediction at position t is scored against the target at position
    t + 1: the last prediction and the first target are dropped so the two
    line up. Scoring is delegated to ``CELoss``, which takes probabilities.

    Args:
        ignore_index: target value excluded from the loss (default -1).
    """

    def __init__(self, ignore_index=-1):
        super().__init__()
        self.CELoss = CELoss(ignore_index=ignore_index)

    def forward(self, output, target):
        """
        Args:
            output: probabilities, shape (N, T, C); the slicing below treats
                dim 1 as the step axis.
            target: class indices, shape (N, T).
        """
        predictions = output[:, :-1, :]  # every step but the last -> (N, T-1, C)
        next_targets = target[:, 1:]     # every step but the first -> (N, T-1)
        return self.CELoss(predictions, next_targets)
class CELossTotal(nn.Module):
    """Shifted cross-entropy on ``output[0]`` plus plain cross-entropy on ``output[1]``.

    NOTE(review): ``ignore_index`` is forwarded only to the shifted term; the
    second term is built with ``CELoss()`` and so uses that class's default
    ignore_index. This may be deliberate (a valid label must not be
    ignored) -- confirm before changing.
    """

    def __init__(self, ignore_index=-1):
        super().__init__()
        self.CELoss = CELoss()
        self.CELossShift = CELossShift(ignore_index=ignore_index)

    def forward(self, output, target):
        """Return the shifted loss on (output[0], target[0]) plus the loss on (output[1], target[1])."""
        shifted_loss = self.CELossShift(output[0], target[0])
        direct_loss = self.CELoss(output[1], target[1])
        return shifted_loss + direct_loss
class CELossTotalEval(nn.Module):
    """Evaluation loss: shifted CE on ``output[0]`` plus plain CE on ``output[1]`` and ``output[2]``.

    Both ``output[1]`` and ``output[2]`` are scored against the same
    ``target[1]``.

    NOTE(review): ``ignore_index`` is forwarded only to the shifted term; the
    two plain terms use ``CELoss()``'s default ignore_index -- confirm intended.
    """

    def __init__(self, ignore_index=-1):
        super().__init__()
        self.CELoss = CELoss()
        self.CELossShift = CELossShift(ignore_index=ignore_index)

    def forward(self, output, target):
        """Return the sum of the three terms described in the class docstring."""
        shifted_loss = self.CELossShift(output[0], target[0])
        first_direct = self.CELoss(output[1], target[1])
        second_direct = self.CELoss(output[2], target[1])
        return shifted_loss + first_direct + second_direct
class CELossTransfer(nn.Module):
    """Loss that keeps only the shifted cross-entropy on ``output[0]`` vs ``target[0]``.

    A second term, ``self.CELoss(output[1], target[1])``, was commented out
    in this class. ``self.CELoss`` is still constructed, though unused, so
    the module keeps the same attributes.
    """

    def __init__(self, ignore_index=-1):
        super().__init__()
        self.CELoss = CELoss()
        self.CELossShift = CELossShift(ignore_index=ignore_index)

    def forward(self, output, target):
        """Return the shifted cross-entropy of output[0] against target[0]."""
        sequence_output, sequence_target = output[0], target[0]
        return self.CELossShift(sequence_output, sequence_target)