-
Notifications
You must be signed in to change notification settings - Fork 10
/
losses.py
135 lines (113 loc) · 4.51 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from __future__ import absolute_import
from __future__ import division
import sys
import torch
from torch import nn
from torch.nn import functional as F
def DeepSupervision(criterion, xs, y):
    """Average one criterion over several predictions that share a target.

    Args:
    - criterion: loss function, called as ``criterion(prediction, y)``
    - xs: tuple (any sized iterable) of predictions
    - y: ground truth shared by every prediction

    Returns:
        The arithmetic mean of ``criterion(x, y)`` over every ``x`` in ``xs``.
        An empty ``xs`` raises ZeroDivisionError.
    """
    # Start the fold from a float zero so the running total is a float
    # (or a tensor, once the first loss is added).
    total = sum((criterion(pred, y) for pred in xs), 0.)
    return total / len(xs)
class MaskLoss(nn.Module):
    """L2, L1 or binary cross entropy loss between a predicted map and a target map.

    The target map is bilinearly resized to the prediction's spatial size. Both
    maps are then flattened per sample, and the chosen criterion averages over
    all elements.

    Args:
    - mode (str): 'l2' (nn.MSELoss), 'l1' (nn.L1Loss) or 'ce' (nn.BCELoss).
      Default 'l2'.

    Raises:
        KeyError: if ``mode`` is not one of the supported values.
    """

    def __init__(self, mode='l2'):
        super(MaskLoss, self).__init__()
        if mode == 'l2':
            self.loss = nn.MSELoss()
        elif mode == 'l1':
            self.loss = nn.L1Loss()
        elif mode == 'ce':
            # nn.BCELoss needs values in [0, 1]; inputs are presumably sigmoid
            # outputs -- TODO confirm at call sites.
            self.loss = nn.BCELoss()
        else:
            # An unknown mode used to leave ``self.loss`` unset, so the error
            # only appeared later as an AttributeError inside forward().
            # Fail fast instead, like TripletLoss does for unknown distances.
            raise KeyError('Unsupported mode: {}'.format(mode))

    def forward(self, inputs, targets):
        """
        Args:
        - inputs: prediction spatial map with shape (batch_size, 1, h, w)
        - targets: ground truth map with shape (batch_size, 1, h1, w1); it is
          bilinearly resized to (h, w) before the comparison

        Returns:
            Scalar loss averaged over all elements.
        """
        b, _, h, w = inputs.size()
        targets = F.interpolate(targets, (h, w), mode='bilinear', align_corners=True)
        # reshape (rather than view) also accepts non-contiguous predictions.
        inputs = inputs.reshape(b, -1)
        targets = targets.reshape(b, -1)
        return self.loss(inputs, targets)
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.

    Equation: y = (1 - epsilon) * y + epsilon / K, where K = num_classes.

    Args:
    - num_classes (int): number of classes K.
    - epsilon (float): smoothing weight. Default 0.1.
    - use_gpu (bool): kept for backward compatibility. The smoothed targets are
      always built on the same device as ``inputs``.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
        - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
        - targets: integer class labels with shape (batch_size,)

        Returns:
            Scalar loss: the smoothed cross entropy of each sample, averaged
            over the batch.
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot matrix directly on log_probs' device. The old path
        # (CPU tensor followed by an unconditional .cuda()) failed for CPU
        # inputs with use_gpu=True and for inputs on a non-current GPU. The
        # default float dtype is kept, as in the original.
        index = targets.detach().to(device=log_probs.device, dtype=torch.long).unsqueeze(1)
        one_hot = torch.zeros(log_probs.size(), device=log_probs.device).scatter_(1, index, 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        # mean over the batch, then sum over classes
        # == mean over samples of the per-sample cross entropy
        loss = (-smoothed * log_probs).mean(0).sum()
        return loss
class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
    - margin (float): margin for triplet.
    - distance (str): 'euclidean', or 'cosine' (negative cosine similarity).
      The historical misspelling 'consine' is still accepted.
    - use_gpu (bool): kept for backward compatibility. Labels are moved to the
      device of ``inputs``.

    Raises:
        KeyError: if ``distance`` is not supported.
    """

    # Accepted spellings for the cosine distance ('consine' kept for old callers).
    _COSINE_ALIASES = ('cosine', 'consine')

    def __init__(self, margin=0.3, distance='euclidean', use_gpu=True):
        super(TripletLoss, self).__init__()
        if distance != 'euclidean' and distance not in self._COSINE_ALIASES:
            raise KeyError('Unsupported distance: {}'.format(distance))
        self.distance = distance
        self.margin = margin
        self.use_gpu = use_gpu
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Args:
        - inputs: feature matrix with shape (batch_size, feat_dim)
        - targets: ground truth labels with shape (batch_size,)

        Returns:
            Scalar margin-ranking loss averaged over anchors. An anchor with
            no negative in the batch contributes zero.
        """
        # Compute pairwise distance, replace by the official when merged
        if self.distance == 'euclidean':
            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, for all pairs at once.
            sq_norms = inputs.pow(2).sum(dim=1, keepdim=True)
            dist = sq_norms + sq_norms.t()
            # Keyword beta/alpha: the positional addmm_(beta, alpha, m1, m2)
            # overload is deprecated.
            dist = dist.addmm(inputs, inputs.t(), beta=1, alpha=-2)
            dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        elif self.distance in self._COSINE_ALIASES:
            # F.normalize clamps the norm, so all-zero rows give 0 instead of NaN.
            normed = F.normalize(inputs, p=2, dim=1)
            dist = -torch.mm(normed, normed.t())
        else:
            raise KeyError('Unsupported distance: {}'.format(self.distance))

        targets = targets.to(inputs.device)
        # same_id[i][j] is True when samples i and j share a label.
        same_id = targets.unsqueeze(0).eq(targets.unsqueeze(1))
        # For each anchor, find the hardest positive and negative.
        # Masked entries are filled with -inf/+inf so they never win the max/min.
        dist_ap = dist.masked_fill(~same_id, float('-inf')).max(dim=1)[0]
        dist_an = dist.masked_fill(same_id, float('inf')).min(dim=1)[0]
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss