# ghm_loss.py
# Forked from open-mmlab/mmdetection.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import LOSSES
def _expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), label_channels)
return bin_labels, bin_label_weights
@LOSSES.register_module
class GHMC(nn.Module):
    """Classification loss that re-weights samples by gradient density.

    Every valid element is assigned to one of ``bins`` equal-width intervals
    of [0, 1] according to its gradient length ``|sigmoid(pred) - target|``.
    Elements in a crowded interval are down-weighted (weight is
    ``tot / count``), the weights are averaged over the number of non-empty
    intervals, and the result is used as the per-element weight of a
    sigmoid binary cross-entropy.

    Args:
        bins (int): Number of gradient-length intervals. Must be >= 1.
        momentum (float): Moving-average factor for the per-bin counts.
            Values <= 0 disable the moving average; must be < 1.
        use_sigmoid (bool): Only ``True`` is supported; ``forward`` raises
            ``NotImplementedError`` otherwise.
        loss_weight (float): Scalar multiplier applied to the final loss.

    Raises:
        ValueError: If ``bins < 1`` or ``momentum >= 1``.
    """

    def __init__(
            self,
            bins=10,
            momentum=0,
            use_sigmoid=True,
            loss_weight=1.0):
        super(GHMC, self).__init__()
        # Fail early with a clear message: bins <= 0 would otherwise die
        # with ZeroDivisionError/IndexError while building the edges.
        if bins < 1:
            raise ValueError(
                'bins must be a positive integer, got {}'.format(bins))
        # momentum == 1 never accumulates any count (division by zero in
        # forward); momentum > 1 produces negative weights.
        if momentum >= 1:
            raise ValueError(
                'momentum must be smaller than 1, got {}'.format(momentum))
        self.bins = bins
        self.momentum = momentum
        self.edges = [float(x) / bins for x in range(bins + 1)]
        # The bins are half-open [lo, hi); nudge the last edge so that a
        # gradient length of exactly 1 still falls into the final bin.
        self.edges[-1] += 1e-6
        if momentum > 0:
            # Running per-bin counts, only needed for the moving average.
            self.acc_sum = [0.0 for _ in range(bins)]
        self.use_sigmoid = use_sigmoid
        self.loss_weight = loss_weight

    def forward(self, pred, target, label_weight, *args, **kwargs):
        """Compute the density-weighted binary cross-entropy.

        Args:
            pred [batch_num, class_num]:
                The direct prediction of classification fc layer.
            target [batch_num, class_num]:
                Binary class target for each sample. If its number of
                dimensions differs from ``pred``, it is treated as integer
                class labels and expanded with ``_expand_binary_labels``.
            label_weight [batch_num, class_num]:
                the value is 1 if the sample is valid and 0 if ignored.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor: weighted BCE summed over all elements, divided
            by the number of valid elements, times ``loss_weight``.

        Raises:
            NotImplementedError: If ``use_sigmoid`` is false.
        """
        if not self.use_sigmoid:
            raise NotImplementedError
        # the target should be binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_binary_labels(
                target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length; detached so the weights act as constants
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        # Clamp to 1 so an all-ignored batch does not divide by zero.
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    # Moving average of the count; only refreshed for
                    # bins that are populated in this batch.
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n

        # Ignored elements keep weight 0 and so contribute nothing.
        loss = F.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot
        return loss * self.loss_weight
@LOSSES.register_module
class GHMR(nn.Module):
    """Regression loss that re-weights elements by gradient density.

    The base loss is ``sqrt(diff^2 + mu^2) - mu`` with
    ``diff = pred - target``. Its gradient length
    ``|diff| / sqrt(mu^2 + diff^2)`` is bucketed into ``bins`` intervals;
    elements in a crowded interval are down-weighted (weight is
    ``tot / count``) and the weights are averaged over the number of
    non-empty intervals.

    Args:
        mu (float): Smoothing parameter of the base loss.
        bins (int): Number of gradient-length intervals. Must be >= 1.
        momentum (float): Moving-average factor for the per-bin counts.
            Values <= 0 disable the moving average; must be < 1.
        loss_weight (float): Scalar multiplier applied to the final loss.

    Raises:
        ValueError: If ``bins < 1`` or ``momentum >= 1``.
    """

    def __init__(
            self,
            mu=0.02,
            bins=10,
            momentum=0,
            loss_weight=1.0):
        super(GHMR, self).__init__()
        # Fail early with a clear message: bins <= 0 would otherwise die
        # with ZeroDivisionError/IndexError while building the edges.
        if bins < 1:
            raise ValueError(
                'bins must be a positive integer, got {}'.format(bins))
        # momentum == 1 never accumulates any count (division by zero in
        # forward); momentum > 1 produces negative weights.
        if momentum >= 1:
            raise ValueError(
                'momentum must be smaller than 1, got {}'.format(momentum))
        self.mu = mu
        self.bins = bins
        self.edges = [float(x) / bins for x in range(bins + 1)]
        # Make the last bin open-ended so gradient lengths at (or rounded
        # up to) 1 are still captured.
        self.edges[-1] = 1e3
        self.momentum = momentum
        if momentum > 0:
            # Running per-bin counts, only needed for the moving average.
            self.acc_sum = [0.0 for _ in range(bins)]
        self.loss_weight = loss_weight

    def forward(self, pred, target, label_weight, avg_factor=None):
        """Compute the density-weighted regression loss.

        Args:
            pred [batch_num, 4 (* class_num)]:
                The prediction of box regression layer. Channel number can
                be 4 or (4 * class_num) depending on whether it is
                class-agnostic.
            target [batch_num, 4 (* class_num)]:
                The target regression values with the same size of pred.
            label_weight [batch_num, 4 (* class_num)]:
                The weight of each sample, 0 if ignored.
            avg_factor: Accepted but not used by this implementation.

        Returns:
            Scalar tensor: weighted loss summed over all elements, divided
            by ``max(sum(label_weight), 1)``, times ``loss_weight``.
        """
        mu = self.mu
        edges = self.edges
        mmt = self.momentum

        # ASL1 loss
        diff = pred - target
        loss = torch.sqrt(diff * diff + mu * mu) - mu

        # gradient length; detached so the weights act as constants
        g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
        weights = torch.zeros_like(g)

        valid = label_weight > 0
        # NOTE(review): this sums the weights while GHMC counts the valid
        # entries; the two agree only for 0/1 weights. Kept as-is to
        # preserve existing numerics - confirm which one is intended.
        tot = max(label_weight.float().sum().item(), 1.0)
        n = 0  # n: valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                n += 1
                if mmt > 0:
                    # Moving average of the count; only refreshed for
                    # bins that are populated in this batch.
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
        if n > 0:
            weights /= n

        # Ignored elements keep weight 0 and so contribute nothing.
        loss = loss * weights
        loss = loss.sum() / tot
        return loss * self.loss_weight