/
focal_loss.py
105 lines (89 loc) · 3.63 KB
/
focal_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# pyre-strict
import torch
from torch.nn import functional as F
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example, given as raw logits
                (a sigmoid is applied inside this function).
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
                Any negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    Raises:
        ValueError: If ``reduction`` is not one of 'none', 'mean' or 'sum'.
    """
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the predicted probability of the labelled class for each element:
    # p where the target is 1, (1 - p) where the target is 0.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Down-weight well-classified elements (p_t close to 1).
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        # alpha for positives, (1 - alpha) for negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    elif reduction != "none":
        # Reject typos instead of silently returning the unreduced tensor.
        raise ValueError(
            "Invalid value for arg 'reduction': '{}'. "
            "Supported reduction modes: 'none', 'mean', 'sum'".format(reduction)
        )

    return loss
# TorchScript-compiled variant of `sigmoid_focal_loss`; the function is
# scripted once, when this module is imported.
# pyre-fixme[9]: sigmoid_focal_loss_jit has type `ScriptModule`; used as
# `ScriptFunction[..., typing.Any]`.
sigmoid_focal_loss_jit: "torch.jit.ScriptModule" = torch.jit.script(sigmoid_focal_loss)
def sigmoid_focal_loss_star(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 1,
    reduction: str = "none",
) -> torch.Tensor:
    """
    FL* described in RetinaNet paper Appendix: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example, given as raw logits.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
                Any negative value disables the weighting.
        gamma: Gamma parameter described in FL*. Default = 1 (no weighting).
               Must be non-zero, because the loss is divided by gamma.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    Raises:
        ValueError: If ``reduction`` is not one of 'none', 'mean' or 'sum'.
    """
    inputs = inputs.float()
    targets = targets.float()
    # (2 * targets - 1) maps labels {0, 1} to signs {-1, +1}, so the logit is
    # flipped for negative examples before being scaled by gamma.
    shifted_inputs = gamma * (inputs * (2 * targets - 1))
    loss = -(F.logsigmoid(shifted_inputs)) / gamma

    if alpha >= 0:
        # alpha for positives, (1 - alpha) for negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        # Out-of-place multiply, matching sigmoid_focal_loss.
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    elif reduction != "none":
        # Reject typos instead of silently returning the unreduced tensor.
        raise ValueError(
            "Invalid value for arg 'reduction': '{}'. "
            "Supported reduction modes: 'none', 'mean', 'sum'".format(reduction)
        )

    return loss
# TorchScript-compiled variant of `sigmoid_focal_loss_star`; the function is
# scripted once, when this module is imported.
# pyre-fixme[9]: sigmoid_focal_loss_star_jit has type `ScriptModule`; used as
# `ScriptFunction[..., typing.Any]`.
sigmoid_focal_loss_star_jit: "torch.jit.ScriptModule" = torch.jit.script(
sigmoid_focal_loss_star
)