-
Notifications
You must be signed in to change notification settings - Fork 25
/
ranking_losses.py
246 lines (194 loc) · 10.1 KB
/
ranking_losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import torch
class RankSort(torch.autograd.Function):
    """Ranking and sorting errors with hand-written gradients.

    The forward pass computes, for every foreground example, a ranking error
    (share of background examples ranked above it) and a sorting error (how
    far the ordering among foregrounds is from the ordering of their target
    values). The gradient w.r.t. ``logits`` is assembled during the forward
    pass and replayed in ``backward``.
    """

    @staticmethod
    def forward(ctx, logits, targets, delta_RS=0.50, eps=1e-10):
        """Compute the mean ranking error and mean sorting error.

        Args:
            ctx: autograd context; receives the precomputed logit gradients.
            logits: tensor of classification scores.
            targets: tensor with the same shape as ``logits``; entries ``> 0``
                mark foreground examples (their value is used as the sorting
                target), entries ``== 0`` mark background examples.
            delta_RS: half-width of the piecewise-linear step approximation.
                When ``<= 0`` a hard step (``x >= 0``) is used instead.
            eps: threshold below which an error mass is treated as zero and
                produces no gradient.

        Returns:
            Tuple ``(ranking_error.mean(), sorting_error.mean())``.

        Raises:
            RuntimeError: if there is no foreground example, because
                ``torch.min`` is called on the (empty) foreground logits.
        """
        # Allocate every buffer on the device of the inputs instead of
        # hard-coding CUDA, so the loss also runs on CPU and on non-default
        # GPUs.
        device = logits.device
        classification_grads = torch.zeros(logits.shape, device=device)

        # Filter fg logits
        fg_labels = (targets > 0.)
        fg_logits = logits[fg_labels]
        fg_targets = targets[fg_labels]
        fg_num = len(fg_logits)

        # Do not use bg with scores less than minimum fg logit
        # since changing its score does not have an effect on precision
        threshold_logit = torch.min(fg_logits) - delta_RS
        relevant_bg_labels = ((targets == 0) & (logits >= threshold_logit))
        relevant_bg_logits = logits[relevant_bg_labels]

        relevant_bg_grad = torch.zeros(len(relevant_bg_logits), device=device)
        sorting_error = torch.zeros(fg_num, device=device)
        ranking_error = torch.zeros(fg_num, device=device)
        fg_grad = torch.zeros(fg_num, device=device)

        # Loop invariants: complement of the targets and the step width.
        fg_target_complement = 1 - fg_targets
        step_width = 2 * delta_RS

        # Sort the fg logits
        order = torch.argsort(fg_logits)

        # Loops over each positive following the order
        for ii in order:
            # Difference Transforms (x_ij)
            fg_relations = fg_logits - fg_logits[ii]
            bg_relations = relevant_bg_logits - fg_logits[ii]

            if delta_RS > 0:
                fg_relations = torch.clamp(
                    fg_relations / step_width + 0.5, min=0, max=1)
                bg_relations = torch.clamp(
                    bg_relations / step_width + 0.5, min=0, max=1)
            else:
                fg_relations = (fg_relations >= 0).float()
                bg_relations = (bg_relations >= 0).float()

            # Rank of ii among pos and false positive number
            # (bg with larger scores)
            rank_pos = torch.sum(fg_relations)
            FP_num = torch.sum(bg_relations)

            # Rank of ii among all examples
            rank = rank_pos + FP_num

            # Ranking error of example ii.
            # target_ranking_error is always 0. (Eq. 7)
            ranking_error[ii] = FP_num / rank

            # Current sorting error of example ii. (Eq. 7)
            current_sorting_error = (
                torch.sum(fg_relations * fg_target_complement) / rank_pos)

            # Find examples in the target sorted order for example ii
            iou_relations = (fg_targets >= fg_targets[ii])
            target_sorted_order = iou_relations * fg_relations

            # The rank of ii among positives in sorted order
            rank_pos_target = torch.sum(target_sorted_order)

            # Compute target sorting error. (Eq. 8)
            # Since target ranking error is 0, this is also total target error
            target_sorting_error = (
                torch.sum(target_sorted_order * fg_target_complement)
                / rank_pos_target)

            # Compute sorting error on example ii
            sorting_error[ii] = current_sorting_error - target_sorting_error

            # Identity Update for Ranking Error
            if FP_num > eps:
                # For ii the update is the ranking error
                fg_grad[ii] -= ranking_error[ii]
                # For negatives, distribute error via ranking pmf
                # (i.e. bg_relations/FP_num)
                relevant_bg_grad += (
                    bg_relations * (ranking_error[ii] / FP_num))

            # Find the positives that are misranked (the cause of the error)
            # These are the ones with smaller IoU but larger logits
            missorted_examples = (~iou_relations) * fg_relations

            # Denominator of sorting pmf
            sorting_pmf_denom = torch.sum(missorted_examples)

            # Identity Update for Sorting Error
            if sorting_pmf_denom > eps:
                # For ii the update is the sorting error
                fg_grad[ii] -= sorting_error[ii]
                # For positives, distribute error via sorting pmf
                # (i.e. missorted_examples/sorting_pmf_denom)
                fg_grad += (
                    missorted_examples
                    * (sorting_error[ii] / sorting_pmf_denom))

        # Normalize gradients by number of positives
        classification_grads[fg_labels] = (fg_grad / fg_num)
        classification_grads[relevant_bg_labels] = (relevant_bg_grad / fg_num)

        ctx.save_for_backward(classification_grads)

        return ranking_error.mean(), sorting_error.mean()

    @staticmethod
    def backward(ctx, out_grad1, out_grad2):
        """Return the stored logit gradient scaled by ``out_grad1``.

        The saved gradient already combines the ranking and sorting updates,
        so only the upstream gradient of the first output is applied;
        ``out_grad2`` is ignored. No gradient is produced for ``targets``,
        ``delta_RS`` or ``eps``.
        """
        g1, = ctx.saved_tensors
        return g1 * out_grad1, None, None, None
class aLRPLoss(torch.autograd.Function):
    """Ranking-based classification loss with hand-written gradients.

    The forward pass computes ``1 - mean(precision)`` over the foreground
    examples, and builds a logit gradient in which each foreground's update is
    weighted by the regression losses of the foregrounds ranked above it. The
    gradient is stored and replayed in ``backward``.
    """

    @staticmethod
    def forward(ctx, logits, targets, regression_losses, delta=1., eps=1e-5):
        """Compute the classification loss together with ranks and order.

        Args:
            ctx: autograd context; receives the precomputed logit gradients.
            logits: tensor of classification scores.
            targets: tensor with the same shape as ``logits``; ``== 1`` marks
                foreground and ``== 0`` marks background examples.
            regression_losses: tensor multiplied elementwise with the
                per-foreground relation vector, so it is presumably one value
                per foreground example in mask order -- confirm with caller.
            delta: half-width of the piecewise-linear step approximation.
            eps: a foreground only contributes gradients when its (soft)
                false-positive count exceeds this value.

        Returns:
            Tuple ``(cls_loss, rank, order)`` where ``rank`` holds the soft
            rank of every foreground and ``order`` is the argsort of the
            foreground logits.

        Raises:
            RuntimeError: if there is no foreground example, because
                ``torch.min`` is called on the (empty) foreground logits.
        """
        # Allocate every buffer on the device of the inputs instead of
        # hard-coding CUDA, so the loss also runs on CPU and on non-default
        # GPUs.
        device = logits.device
        classification_grads = torch.zeros(logits.shape, device=device)

        # Filter fg logits
        fg_labels = (targets == 1)
        fg_logits = logits[fg_labels]
        fg_num = len(fg_logits)

        # Do not use bg with scores less than minimum fg logit
        # since changing its score does not have an effect on precision
        threshold_logit = torch.min(fg_logits) - delta

        # Get valid bg logits
        relevant_bg_labels = ((targets == 0) & (logits >= threshold_logit))
        relevant_bg_logits = logits[relevant_bg_labels]

        relevant_bg_grad = torch.zeros(len(relevant_bg_logits), device=device)
        rank = torch.zeros(fg_num, device=device)
        prec = torch.zeros(fg_num, device=device)
        fg_grad = torch.zeros(fg_num, device=device)

        # Sort the fg logits
        order = torch.argsort(fg_logits)

        # Loops over each positive following the order
        for ii in order:
            # x_ij s as score differences with fgs
            fg_relations = fg_logits - fg_logits[ii]
            # Apply piecewise linear function and determine relations with fgs
            fg_relations = torch.clamp(
                fg_relations / (2 * delta) + 0.5, min=0, max=1)
            # Discard i=j in the summation in rank_pos
            fg_relations[ii] = 0

            # x_ij s as score differences with bgs
            bg_relations = relevant_bg_logits - fg_logits[ii]
            # Apply piecewise linear function and determine relations with bgs
            bg_relations = torch.clamp(
                bg_relations / (2 * delta) + 0.5, min=0, max=1)

            # Compute the rank of the example within fgs and number of bgs
            # with larger scores
            rank_pos = 1 + torch.sum(fg_relations)
            FP_num = torch.sum(bg_relations)

            # Store the total since it is normalizer also for aLRP
            # Regression error
            rank[ii] = rank_pos + FP_num

            # Compute precision for this example to compute classification
            # loss
            prec[ii] = rank_pos / rank[ii]

            # For stability, only compute grads when the false-positive mass
            # is above eps (this also avoids dividing by ~0 below)
            if FP_num > eps:
                fg_grad[ii] = -(
                    torch.sum(fg_relations * regression_losses) + FP_num
                ) / rank[ii]
                relevant_bg_grad += (bg_relations * (-fg_grad[ii] / FP_num))

        # fg gradient
        classification_grads[fg_labels] = fg_grad
        # bg gradient
        classification_grads[relevant_bg_labels] = relevant_bg_grad

        # Normalize gradients by number of positives
        classification_grads /= (fg_num)

        cls_loss = 1 - prec.mean()
        ctx.save_for_backward(classification_grads)

        return cls_loss, rank, order

    @staticmethod
    def backward(ctx, out_grad1, out_grad2, out_grad3):
        """Return the stored logit gradient scaled by ``out_grad1``.

        Only the upstream gradient of ``cls_loss`` is used; the gradients of
        ``rank`` and ``order`` are ignored, and no gradient is produced for
        the remaining forward arguments.
        """
        g1, = ctx.saved_tensors
        return g1 * out_grad1, None, None, None, None
class APLoss(torch.autograd.Function):
    """Interpolated average-precision loss with hand-written gradients.

    The forward pass computes ``1 - mean(interpolated precision)`` over the
    foreground examples and builds the logit gradient, which is stored and
    replayed in ``backward``.
    """

    @staticmethod
    def forward(ctx, logits, targets, delta=1.):
        """Compute the AP-based classification loss.

        Args:
            ctx: autograd context; receives the precomputed logit gradients.
            logits: tensor of classification scores.
            targets: tensor with the same shape as ``logits``; ``== 1`` marks
                foreground and ``== 0`` marks background examples.
            delta: half-width of the piecewise-linear step approximation.

        Returns:
            Scalar tensor ``1 - prec.mean()`` where ``prec`` holds the
            interpolated precision of every foreground example.

        Raises:
            RuntimeError: if there is no foreground example, because
                ``torch.min`` is called on the (empty) foreground logits.
        """
        # Allocate every buffer on the device of the inputs instead of
        # hard-coding CUDA, so the loss also runs on CPU and on non-default
        # GPUs.
        device = logits.device
        classification_grads = torch.zeros(logits.shape, device=device)

        # Filter fg logits
        fg_labels = (targets == 1)
        fg_logits = logits[fg_labels]
        fg_num = len(fg_logits)

        # Do not use bg with scores less than minimum fg logit
        # since changing its score does not have an effect on precision
        threshold_logit = torch.min(fg_logits) - delta

        # Get valid bg logits
        relevant_bg_labels = ((targets == 0) & (logits >= threshold_logit))
        relevant_bg_logits = logits[relevant_bg_labels]

        relevant_bg_grad = torch.zeros(len(relevant_bg_logits), device=device)
        rank = torch.zeros(fg_num, device=device)
        prec = torch.zeros(fg_num, device=device)
        fg_grad = torch.zeros(fg_num, device=device)

        # Running maximum of the precision seen so far (interpolation)
        max_prec = 0

        # Sort the fg logits
        order = torch.argsort(fg_logits)

        # Loops over each positive following the order
        for ii in order:
            # x_ij s as score differences with fgs
            fg_relations = fg_logits - fg_logits[ii]
            # Apply piecewise linear function and determine relations with fgs
            fg_relations = torch.clamp(
                fg_relations / (2 * delta) + 0.5, min=0, max=1)
            # Discard i=j in the summation in rank_pos
            fg_relations[ii] = 0

            # x_ij s as score differences with bgs
            bg_relations = relevant_bg_logits - fg_logits[ii]
            # Apply piecewise linear function and determine relations with bgs
            bg_relations = torch.clamp(
                bg_relations / (2 * delta) + 0.5, min=0, max=1)

            # Compute the rank of the example within fgs and number of bgs
            # with larger scores
            rank_pos = 1 + torch.sum(fg_relations)
            FP_num = torch.sum(bg_relations)

            # Rank of ii among all examples (normalizer of the precision)
            rank[ii] = rank_pos + FP_num

            # Compute precision for this example
            current_prec = rank_pos / rank[ii]

            # Compute interpolated AP and store gradients for relevant bg
            # examples
            if (max_prec <= current_prec):
                max_prec = current_prec
                relevant_bg_grad += (bg_relations / rank[ii])
            else:
                # Here current_prec < max_prec <= 1, so 1 - current_prec > 0
                relevant_bg_grad += (bg_relations / rank[ii]) * (
                    ((1 - max_prec) / (1 - current_prec)))

            # Store fg gradients
            fg_grad[ii] = -(1 - max_prec)
            prec[ii] = max_prec

        # fg gradient
        classification_grads[fg_labels] = fg_grad
        # bg gradient
        classification_grads[relevant_bg_labels] = relevant_bg_grad

        # Normalize gradients by number of positives
        classification_grads /= fg_num

        cls_loss = 1 - prec.mean()
        ctx.save_for_backward(classification_grads)

        return cls_loss

    @staticmethod
    def backward(ctx, out_grad1):
        """Return the stored logit gradient scaled by ``out_grad1``.

        No gradient is produced for ``targets`` or ``delta``.
        """
        g1, = ctx.saved_tensors
        return g1 * out_grad1, None, None