# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01a_losses.ipynb (unless otherwise specified).
__all__ = ['BaseLoss', 'CrossEntropyLossFlat', 'FocalLossFlat', 'BCEWithLogitsLossFlat', 'BCELossFlat', 'MSELossFlat',
           'L1LossFlat', 'LabelSmoothingCrossEntropy', 'LabelSmoothingCrossEntropyFlat']
# Cell
from .imports import *
from .torch_imports import *
from .torch_core import *
from .layers import *
# Cell
class BaseLoss():
    "Same as `loss_cls`, but flattens input and target."
    activation=decodes=noops
    def __init__(self, loss_cls, *args, axis=-1, flatten=True, floatify=False, is_2d=True, **kwargs):
        store_attr("axis,flatten,floatify,is_2d")
        self.func = loss_cls(*args,**kwargs)
        functools.update_wrapper(self, self.func)
    def __repr__(self): return f"FlattenedLoss of {self.func}"
    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v
    def _contiguous(self,x):
        return TensorBase(x.transpose(self.axis,-1).contiguous()) if isinstance(x,torch.Tensor) else x
    def __call__(self, inp, targ, **kwargs):
        inp,targ = map(self._contiguous, (inp,targ))
        if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
        if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
        if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
        return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
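# Illustrative usage sketch (not part of the exported notebook; names and shapes below are arbitrary):
# with `flatten=True` and `is_2d=False`, `BaseLoss` views input and target as 1d before calling the
# wrapped loss, so the result matches applying the loss to the unflattened tensors directly.
_base_demo = BaseLoss(nn.MSELoss, flatten=True, is_2d=False, floatify=True)
_base_inp, _base_targ = torch.randn(8, 4), torch.randn(8, 4)
assert torch.allclose(_base_demo(_base_inp, _base_targ), F.mse_loss(_base_inp, _base_targ))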
# Cell
@delegates()
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    y_int = True
    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
    def decodes(self, x): return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)
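# Illustrative usage sketch (not part of the exported notebook): for segmentation-style output of shape
# (bs, classes, h, w), passing `axis=1` moves the class axis last and flattens, so `nn.CrossEntropyLoss`
# sees (bs*h*w, classes) logits against (bs*h*w,) integer targets; `activation`/`decodes` recover
# probabilities and predicted classes along that same axis.
_ce_demo = CrossEntropyLossFlat(axis=1)
_ce_out = torch.randn(2, 5, 4, 4)               # 5-class logits over a 4x4 map
_ce_tgt = torch.randint(0, 5, (2, 4, 4))
_ce_loss = _ce_demo(_ce_out, _ce_tgt)           # scalar, reduction='mean'
_ce_pred = _ce_demo.decodes(_ce_out)            # (2, 4, 4) predicted classes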
# Cell
class FocalLossFlat(CrossEntropyLossFlat):
    """
    Same as CrossEntropyLossFlat but with a focal parameter, `gamma`. Focal loss was introduced by Lin et al.
    (https://arxiv.org/pdf/1708.02002.pdf). Note that the class weighting factor from the paper, alpha, can be
    implemented through the PyTorch `weight` argument of `nn.CrossEntropyLoss`.
    """
    y_int = True
    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, gamma=2, axis=-1, **kwargs):
        self.gamma = gamma
        self.reduce = kwargs.pop('reduction') if 'reduction' in kwargs else 'mean'
        super().__init__(*args, reduction='none', axis=axis, **kwargs)
    def __call__(self, inp, targ, **kwargs):
        ce_loss = super().__call__(inp, targ, **kwargs)
        pt = torch.exp(-ce_loss)
        fl_loss = (1-pt)**self.gamma * ce_loss
        return fl_loss.mean() if self.reduce == 'mean' else fl_loss.sum() if self.reduce == 'sum' else fl_loss
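# Illustrative check (not part of the exported notebook): the focal term (1-pt)**gamma down-weights
# well-classified examples (pt close to 1); with `gamma=0` the term is 1 and the loss reduces to plain
# CrossEntropyLossFlat.
_fl_demo = FocalLossFlat(gamma=0)
_fl_ref  = CrossEntropyLossFlat()
_fl_out, _fl_tgt = torch.randn(16, 10), torch.randint(0, 10, (16,))
assert torch.allclose(_fl_demo(_fl_out, _fl_tgt), _fl_ref(_fl_out, _fl_tgt))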
# Cell
@delegates()
class BCEWithLogitsLossFlat(BaseLoss):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    @use_kwargs_dict(keep=True, weight=None, reduction='mean', pos_weight=None)
    def __init__(self, *args, axis=-1, floatify=True, thresh=0.5, **kwargs):
        if kwargs.get('pos_weight', None) is not None and kwargs.get('flatten', None) is True:
            raise ValueError("`flatten` must be False when using `pos_weight` to avoid a RuntimeError due to shape mismatch")
        if kwargs.get('pos_weight', None) is not None: kwargs['flatten'] = False
        super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
        self.thresh = thresh
    def decodes(self, x): return x>self.thresh
    def activation(self, x): return torch.sigmoid(x)
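# Illustrative usage sketch (not part of the exported notebook): for multi-label targets, integer 0/1
# targets are floatified before `nn.BCEWithLogitsLoss` is applied; `activation` maps logits to
# probabilities and `decodes` thresholds them at `thresh`.
_bce_demo = BCEWithLogitsLossFlat(thresh=0.5)
_bce_out = torch.randn(4, 3)
_bce_tgt = torch.randint(0, 2, (4, 3))                          # 0/1 multi-label targets
_bce_loss = _bce_demo(_bce_out, _bce_tgt)
_bce_pred = _bce_demo.decodes(_bce_demo.activation(_bce_out))   # boolean predictions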
# Cell
@use_kwargs_dict(weight=None, reduction='mean')
def BCELossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Same as `nn.BCELoss`, but flattens input and target."
    return BaseLoss(nn.BCELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
# Cell
@use_kwargs_dict(reduction='mean')
def MSELossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Same as `nn.MSELoss`, but flattens input and target."
    return BaseLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
# Cell
@use_kwargs_dict(reduction='mean')
def L1LossFlat(*args, axis=-1, floatify=True, **kwargs):
    "Same as `nn.L1Loss`, but flattens input and target."
    return BaseLoss(nn.L1Loss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
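# Illustrative usage sketch (not part of the exported notebook): the functional wrappers behave like
# their nn.* counterparts but flatten both tensors and floatify integer targets, so a (bs, 1) regression
# head can be scored directly against a (bs,) integer target.
_mse_demo = MSELossFlat()
_mse_out, _mse_tgt = torch.randn(6, 1), torch.randint(0, 10, (6,))
_mse_loss = _mse_demo(_mse_out, _mse_tgt)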
# Cell
class LabelSmoothingCrossEntropy(Module):
    y_int = True
    def __init__(self, eps:float=0.1, weight=None, reduction='mean'):
        store_attr()
    def forward(self, output, target):
        c = output.size()[1]
        log_preds = F.log_softmax(output, dim=1)
        if self.reduction=='sum': loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=1) # sum over classes; we divide by `c` at the return line, hence sum rather than mean
            if self.reduction=='mean': loss = loss.mean()
        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target.long(), weight=self.weight, reduction=self.reduction)
    def activation(self, out): return F.softmax(out, dim=-1)
    def decodes(self, out): return out.argmax(dim=-1)
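# Illustrative check (not part of the exported notebook): label smoothing mixes (1-eps) of the usual
# cross entropy with eps of a cross entropy against a uniform distribution over the `c` classes
# (the `loss*self.eps/c` term); with `eps=0` it is exactly `F.cross_entropy`.
_lsce_demo = LabelSmoothingCrossEntropy(eps=0.)
_ls_out, _ls_tgt = torch.randn(16, 10), torch.randint(0, 10, (16,))
assert torch.allclose(_lsce_demo(_ls_out, _ls_tgt), F.cross_entropy(_ls_out, _ls_tgt))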
# Cell
@delegates()
class LabelSmoothingCrossEntropyFlat(BaseLoss):
    "Same as `LabelSmoothingCrossEntropy`, but flattens input and target."
    y_int = True
    @use_kwargs_dict(keep=True, eps=0.1, reduction='mean')
    def __init__(self, *args, axis=-1, **kwargs): super().__init__(LabelSmoothingCrossEntropy, *args, axis=axis, **kwargs)
    def activation(self, out): return F.softmax(out, dim=-1)
    def decodes(self, out): return out.argmax(dim=-1)
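# Illustrative usage sketch (not part of the exported notebook): the flat version handles outputs with
# extra axes (e.g. sequence models) by moving `axis` last and flattening, as CrossEntropyLossFlat does,
# while applying label smoothing.
_lsf_demo = LabelSmoothingCrossEntropyFlat(eps=0.1, axis=-1)
_lsf_out = torch.randn(2, 7, 10)                # (bs, seq_len, vocab)
_lsf_tgt = torch.randint(0, 10, (2, 7))
_lsf_loss = _lsf_demo(_lsf_out, _lsf_tgt)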