-
Notifications
You must be signed in to change notification settings - Fork 30
/
losses.py
executable file
·210 lines (166 loc) · 8.21 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class PaCoLoss(nn.Module):
    """Parametric contrastive (PaCo) loss.

    Each of the first ``batch_size = features.shape[0] - K`` rows of ``features``
    is an anchor. Its softmax runs over two groups of entries:

    * the classifier logits ``sup_logits`` (divided by ``supt`` and, after
      ``cal_weight_for_classes`` has been called, shifted by ``log(prior)``).
      The positives here are the (optionally label-smoothed) one-hot targets,
      weighted by ``beta``;
    * feature similarities against all ``features`` rows (divided by
      ``temperature``), excluding the anchor itself. Same-label rows are the
      positives, weighted by ``alpha``, and the whole group is scaled by
      ``gamma`` inside the denominator.

    The loss is the negative weighted mean log-probability of the positives,
    scaled by ``temperature / base_temperature`` and averaged over anchors.

    Args:
        alpha: weight of contrastive positives in the target mask.
        beta: weight of classifier (one-hot) positives in the target mask.
        gamma: multiplier on contrastive entries in the softmax denominator.
        supt: divisor applied to ``sup_logits``.
        temperature: divisor applied to feature dot products.
        base_temperature: defaults to ``temperature``, giving a loss scale of 1.
        K: number of trailing ``features`` rows used only as contrast entries.
        num_classes: number of columns of ``sup_logits``.
        smooth: label-smoothing factor for the one-hot targets.
    """

    def __init__(self, alpha, beta=1.0, gamma=1.0, supt=1.0, temperature=1.0,
                 base_temperature=None, K=8192, num_classes=1000, smooth=0.0):
        super().__init__()
        self.temperature = temperature
        self.beta = beta
        self.gamma = gamma
        self.supt = supt
        self.base_temperature = temperature if base_temperature is None else base_temperature
        self.K = K
        self.alpha = alpha
        self.num_classes = num_classes
        self.smooth = smooth
        # Optional (1, num_classes) class prior; set by cal_weight_for_classes().
        self.weight = None

    def cal_weight_for_classes(self, cls_num_list, device=None):
        """Store the empirical class prior ``n_c / sum(n)`` used to adjust logits.

        Args:
            cls_num_list: per-class sample counts (``num_classes`` entries).
            device: where to keep the prior. Defaults to CUDA when it is
                available and CPU otherwise; the previous hard-coded CUDA
                crashed on CPU-only hosts. ``forward`` moves it as needed anyway.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        counts = torch.as_tensor(cls_num_list, dtype=torch.float32).view(1, self.num_classes)
        self.weight = (counts / counts.sum()).to(device)

    def forward(self, features, labels=None, sup_logits=None):
        """Compute the PaCo loss.

        Args:
            features: ``(batch_size + K, D)`` tensor. The first ``batch_size``
                rows are anchors; every row is a contrast entry. (Presumably
                L2-normalised upstream; this is not checked here.)
            labels: ``batch_size + K`` integer labels aligned with ``features``.
            sup_logits: ``(batch_size, num_classes)`` classifier logits for the
                anchors.

        Returns:
            Scalar loss tensor (mean over anchors).

        Raises:
            ValueError: if ``labels`` or ``sup_logits`` is not given.
        """
        if labels is None or sup_logits is None:
            raise ValueError("PaCoLoss requires both `labels` and `sup_logits`")
        # Use the exact device of the features so a CUDA index such as cuda:1 is kept.
        device = features.device
        batch_size = features.shape[0] - self.K
        labels = labels.contiguous().view(-1, 1).to(device)
        anchor_labels = labels[:batch_size]
        mask = torch.eq(anchor_labels, labels.T).float()

        # Feature/feature similarities for every anchor against every row.
        anchor_dot_contrast = torch.matmul(features[:batch_size], features.T) / self.temperature

        # Prepend the classifier logits, logit-adjusted by the prior when one is set.
        if self.weight is not None:
            sup_logits = sup_logits + torch.log(self.weight.to(device) + 1e-9)
        anchor_dot_contrast = torch.cat((sup_logits / self.supt, anchor_dot_contrast), dim=1)

        # Row-max subtraction for numerical stability (cancels in the log-softmax).
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Exclude each anchor's similarity with itself (column i for anchor i).
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=device).view(-1, 1),
            0,
        )
        mask = mask * logits_mask

        # Classifier positives: one-hot targets, optionally label-smoothed.
        one_hot_label = F.one_hot(anchor_labels.view(-1), num_classes=self.num_classes).to(torch.float32)
        if self.smooth != 0:
            # Skipped when smooth == 0: the result would be unchanged, and this
            # avoids a division by zero when num_classes == 1.
            one_hot_label = (self.smooth / (self.num_classes - 1)) * (1 - one_hot_label) \
                + (1 - self.smooth) * one_hot_label
        mask = torch.cat((one_hot_label * self.beta, mask * self.alpha), dim=1)

        # Denominator: every classifier logit, plus non-self contrast entries scaled by gamma.
        logits_mask = torch.cat(
            (torch.ones(batch_size, self.num_classes, device=device), self.gamma * logits_mask),
            dim=1,
        )
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)

        # Weighted mean log-likelihood over each anchor's positives.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        return loss.mean()
class MultiTaskLoss(nn.Module):
    """Balanced-softmax cross-entropy plus a class-reweighted supervised-contrastive term.

    ``forward`` returns
    ``CE(sup_logits + log(prior), y) + alpha * mean_i(w[y_i] * supcon_i)``, where:

    * ``prior`` (``self.weight``) is the empirical class frequency;
    * ``w`` (``self.class_weight``) is an effective-number class weight;
    * ``supcon_i`` is the supervised-contrastive loss of anchor ``i`` against all
      ``features`` rows except itself.

    ``cal_weight_for_classes`` must be called before ``forward``.

    ``beta``, ``gamma`` and ``supt`` are stored but are not read by ``forward``;
    they are accepted for constructor parity with the other losses in this file.
    """

    def __init__(self, alpha, beta=1.0, gamma=1.0, supt=1.0, temperature=1.0,
                 base_temperature=None, K=8192, num_classes=1000):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = temperature if base_temperature is None else base_temperature
        self.K = K
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.supt = supt
        self.num_classes = num_classes
        # Decay factor of the "effective number of samples" reweighting.
        self.effective_num_beta = 0.999
        # Both are (1, num_classes); set by cal_weight_for_classes().
        self.weight = None
        self.class_weight = None

    def cal_weight_for_classes(self, cls_num_list, device=None):
        """Compute the class prior and the effective-number class weights.

        Args:
            cls_num_list: per-class sample counts (``num_classes`` entries).
            device: where to store the tensors. Defaults to CUDA when it is
                available and CPU otherwise (previously hard-coded to CUDA).
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        counts = torch.as_tensor(cls_num_list, dtype=torch.float32).view(1, self.num_classes)
        self.weight = (counts / counts.sum()).to(device)

        if self.effective_num_beta != 0:
            # Effective number per class: (1 - b^n) / (1 - b). Weight = mean(eff) / eff,
            # so rarer classes get larger weights. This is computed on the flat
            # counts: the previous code applied builtin sum()/len() to a (1, C)
            # array, which made every weight exactly 1.
            n = counts.view(-1).double()
            effective_num = (1.0 - torch.pow(self.effective_num_beta, n)) / (1.0 - self.effective_num_beta)
            per_cls_weights = effective_num.mean() / effective_num
        else:
            # b == 0 gives an effective number of 1 for every class, i.e. uniform weights.
            per_cls_weights = torch.ones(self.num_classes, dtype=torch.float64)
        self.class_weight = per_cls_weights.float().view(1, self.num_classes).to(device)

    def forward(self, features, labels=None, sup_logits=None, mask=None, epoch=None):
        """Compute the combined loss.

        Args:
            features: ``(batch_size + K, D)`` tensor; the first
                ``batch_size = features.shape[0] - K`` rows are anchors.
            labels: ``batch_size + K`` integer labels aligned with ``features``.
            sup_logits: ``(batch_size, num_classes)`` classifier logits.
            mask: ignored; it is always recomputed from ``labels``. Kept for
                interface compatibility.
            epoch: unused; kept for interface compatibility.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``labels`` or ``sup_logits`` is not given.
            RuntimeError: if ``cal_weight_for_classes`` has not been called.
        """
        if labels is None or sup_logits is None:
            raise ValueError("MultiTaskLoss requires both `labels` and `sup_logits`")
        if self.weight is None or self.class_weight is None:
            raise RuntimeError("call cal_weight_for_classes() before computing MultiTaskLoss")
        device = features.device
        batch_size = features.shape[0] - self.K
        labels = labels.contiguous().view(-1, 1).to(device)
        # Flat (batch_size,) targets. Using .view(-1) rather than .squeeze() keeps
        # a 1-D target even when batch_size == 1.
        anchor_targets = labels[:batch_size].view(-1)
        mask = torch.eq(labels[:batch_size], labels.T).float()

        # Similarities, with row-max subtraction for numerical stability.
        anchor_dot_contrast = torch.matmul(features[:batch_size], features.T) / self.temperature
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Exclude each anchor's similarity with itself.
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=device).view(-1, 1),
            0,
        )
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # Reweight each anchor's contrastive loss by its class weight.
        rew = self.class_weight.to(device).view(-1)[anchor_targets]
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos * rew
        loss = loss.mean()

        # Balanced softmax: cross-entropy on prior-adjusted logits.
        adjusted_logits = sup_logits + torch.log(self.weight.to(sup_logits.device) + 1e-9)
        loss_balancesoftmax = F.cross_entropy(adjusted_logits, anchor_targets.to(sup_logits.device))
        return loss_balancesoftmax + self.alpha * loss
class MultiTaskBLoss(nn.Module):
    """Plain cross-entropy plus an unweighted supervised-contrastive term.

    ``forward`` returns ``CE(sup_logits, y) + alpha * mean_i(supcon_i)``, where
    ``supcon_i`` is the supervised-contrastive loss of anchor ``i`` against all
    ``features`` rows except itself.

    ``cal_weight_for_classes`` fills ``self.weight`` and ``self.class_weight``,
    but ``forward`` does not read either of them. ``beta``, ``gamma`` and
    ``supt`` are likewise stored but unused by ``forward``.
    """

    def __init__(self, alpha, beta=1.0, gamma=1.0, supt=1.0, temperature=1.0,
                 base_temperature=None, K=8192, num_classes=1000):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = temperature if base_temperature is None else base_temperature
        self.K = K
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.supt = supt
        self.num_classes = num_classes
        # Decay factor of the "effective number of samples" reweighting.
        self.effective_num_beta = 0.999

    def cal_weight_for_classes(self, cls_num_list, device=None):
        """Compute the class prior and the effective-number class weights.

        Args:
            cls_num_list: per-class sample counts (``num_classes`` entries).
            device: where to store the tensors. Defaults to CUDA when it is
                available and CPU otherwise (previously hard-coded to CUDA).
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        counts = torch.as_tensor(cls_num_list, dtype=torch.float32).view(1, self.num_classes)
        self.weight = (counts / counts.sum()).to(device)

        if self.effective_num_beta != 0:
            # Weight = mean(eff) / eff with eff = (1 - b^n) / (1 - b). This is
            # computed on the flat counts; the previous builtin sum()/len() over
            # a (1, C) array made every weight exactly 1.
            n = counts.view(-1).double()
            effective_num = (1.0 - torch.pow(self.effective_num_beta, n)) / (1.0 - self.effective_num_beta)
            per_cls_weights = effective_num.mean() / effective_num
        else:
            per_cls_weights = torch.ones(self.num_classes, dtype=torch.float64)
        self.class_weight = per_cls_weights.float().view(1, self.num_classes).to(device)

    def forward(self, features, labels=None, sup_logits=None, mask=None, epoch=None):
        """Compute the combined loss.

        Args:
            features: ``(batch_size + K, D)`` tensor; the first
                ``batch_size = features.shape[0] - K`` rows are anchors.
            labels: ``batch_size + K`` integer labels aligned with ``features``.
            sup_logits: ``(batch_size, num_classes)`` classifier logits.
            mask: ignored; it is always recomputed from ``labels``.
            epoch: unused; kept for interface compatibility.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``labels`` or ``sup_logits`` is not given.
        """
        if labels is None or sup_logits is None:
            raise ValueError("MultiTaskBLoss requires both `labels` and `sup_logits`")
        device = features.device
        batch_size = features.shape[0] - self.K
        labels = labels.contiguous().view(-1, 1).to(device)
        # .view(-1) keeps a 1-D cross-entropy target even when batch_size == 1.
        anchor_targets = labels[:batch_size].view(-1)
        mask = torch.eq(labels[:batch_size], labels.T).float()

        # Similarities, with row-max subtraction for numerical stability.
        anchor_dot_contrast = torch.matmul(features[:batch_size], features.T) / self.temperature
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Exclude each anchor's similarity with itself.
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size, device=device).view(-1, 1),
            0,
        )
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.mean()
        loss_ce = F.cross_entropy(sup_logits, anchor_targets.to(sup_logits.device))
        return loss_ce + self.alpha * loss