loss.go
package nn

import (
	"github.com/nullbull/gotch/ts"
)
type lossFnOptions struct {
	ClassWeights []float64
	Reduction    int64 // 0: "none", 1: "mean", 2: "sum"
	IgnoreIndex  int64
	PosWeight    int64 // weight attributed to the positive class; used in BCELoss
}
type LossFnOption func(*lossFnOptions)

func WithLossFnWeights(vals []float64) LossFnOption {
	return func(o *lossFnOptions) {
		o.ClassWeights = vals
	}
}

func WithLossFnReduction(val int64) LossFnOption {
	return func(o *lossFnOptions) {
		o.Reduction = val
	}
}

func WithLossFnIgnoreIndex(val int64) LossFnOption {
	return func(o *lossFnOptions) {
		o.IgnoreIndex = val
	}
}

func WithLossFnPosWeight(val int64) LossFnOption {
	return func(o *lossFnOptions) {
		o.PosWeight = val
	}
}
func defaultLossFnOptions() *lossFnOptions {
	return &lossFnOptions{
		ClassWeights: nil,
		Reduction:    1,    // "mean"
		IgnoreIndex:  -100, // matches the PyTorch default ignore_index
		PosWeight:    -1,   // negative means no positive-class weighting
	}
}
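
// Example (a minimal sketch of the functional-options pattern used above;
// the option values are illustrative, not defaults mandated by this file):
//
//	opts := defaultLossFnOptions()
//	for _, o := range []LossFnOption{
//		WithLossFnWeights([]float64{0.3, 0.7}), // per-class weights
//		WithLossFnReduction(2),                 // "sum"
//	} {
//		o(opts)
//	}
//	// opts now holds ClassWeights [0.3 0.7] and Reduction 2, plus the
//	// untouched defaults IgnoreIndex -100 and PosWeight -1.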
// CrossEntropyLoss computes the cross-entropy loss between logits and target.
// Ref. https://github.com/pytorch/pytorch/blob/15be189f0de4addf4f68d18022500f67617ab05d/torch/nn/functional.py#L2012
//
// - logits: tensor of shape [B, C, H, W] corresponding to the raw output of the model.
// - target: ground-truth tensor of shape [B, 1, H, W].
//
// Class weights, reduction, and ignore index can be set via LossFnOption.
func CrossEntropyLoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tensor {
	options := defaultLossFnOptions()
	for _, o := range opts {
		o(options)
	}

	// Move any class weights to the logits' device and dtype; an empty
	// tensor signals "no weighting" to the underlying NLL-loss call.
	var ws *ts.Tensor
	device := logits.MustDevice()
	dtype := logits.DType()
	if len(options.ClassWeights) > 0 {
		ws = ts.MustOfSlice(options.ClassWeights).MustTotype(dtype, true).MustTo(device, true)
	} else {
		ws = ts.NewTensor()
	}
	reduction := options.Reduction
	ignoreIndex := options.IgnoreIndex

	// Log-softmax over the last dimension, then negative log-likelihood.
	logSm := logits.MustLogSoftmax(-1, dtype, false)
	loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
	ws.MustDrop()

	return loss
}
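
// Example (a hedged sketch, assuming 2D logits of shape [B, C] and integer
// class targets of shape [B]; the tensor values are illustrative only):
//
//	logits := ts.MustOfSlice([]float64{2.0, 0.5, 0.3, 0.1, 1.5, 2.2}).MustView([]int64{2, 3}, true)
//	target := ts.MustOfSlice([]int64{0, 2})
//	loss := CrossEntropyLoss(logits, target, WithLossFnReduction(1)) // mean reduction
//	defer loss.MustDrop()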
// BCELoss computes a binary cross-entropy loss on logits.
//
// - logits: tensor of shape [B, C, H, W] corresponding to the raw output of the model.
// - target: ground-truth tensor of shape [B, 1, H, W].
// - posWeight (set via WithLossFnPosWeight): weight attributed to the positive class.
//   This is especially useful for an imbalanced dataset.
func BCELoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tensor {
	options := defaultLossFnOptions()
	for _, o := range opts {
		o(options)
	}

	var ws *ts.Tensor
	device := logits.MustDevice()
	dtype := logits.DType()
	if len(options.ClassWeights) > 0 {
		ws = ts.MustOfSlice(options.ClassWeights).MustTotype(dtype, true).MustTo(device, true)
	} else {
		ws = ts.NewTensor()
	}
	reduction := options.Reduction

	// A negative PosWeight (the default) disables positive-class weighting.
	var posWeight *ts.Tensor
	if options.PosWeight >= 0 {
		posWeight = ts.MustOfSlice([]int64{options.PosWeight})
	} else {
		posWeight = ts.NewTensor()
	}

	loss := logits.MustSqueeze(false).MustBinaryCrossEntropyWithLogits(target, ws, posWeight, reduction, true)
	return loss
}
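
// Example (a hedged sketch, assuming one logit per sample and float targets
// in {0, 1}; shapes and values are illustrative only):
//
//	logits := ts.MustOfSlice([]float64{0.8, -1.2, 2.5})
//	target := ts.MustOfSlice([]float64{1.0, 0.0, 1.0})
//	// Weight the positive class 3x, e.g. for a roughly 1:3 imbalanced dataset.
//	loss := BCELoss(logits, target, WithLossFnPosWeight(3))
//	defer loss.MustDrop()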
// MSELoss computes the mean squared error (squared L2 norm) between logits and labels.
//
// - reductionOpt: either 0 ("none"), 1 ("mean"), or 2 ("sum"). Default: 1 ("mean").
func MSELoss(logits, labels *ts.Tensor, reductionOpt ...int64) *ts.Tensor {
	reduction := int64(1) // "mean"
	if len(reductionOpt) > 0 {
		reduction = reductionOpt[0]
	}

	out := logits.MustMseLoss(labels, reduction, false)
	return out
}
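
// Example (a hedged sketch; values are illustrative only):
//
//	preds := ts.MustOfSlice([]float64{2.5, 0.0, 2.0})
//	labels := ts.MustOfSlice([]float64{3.0, -0.5, 2.0})
//	mean := MSELoss(preds, labels)   // default: mean reduction
//	sum := MSELoss(preds, labels, 2) // sum reduction
//	defer mean.MustDrop()
//	defer sum.MustDrop()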