---
title: "Unified Focal Loss Framework"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{unifiedFocalLossDemo}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo=TRUE, comment="#>", collapse=TRUE, warning=FALSE, message=FALSE, fig.cap="")
```
```r
library(geodl)
library(torch)
```
```r
device <- torch_device("cuda")
```
## Background
**geodl** provides a modified version of the **unified focal loss** proposed in the following study:
Yeung, M., Sala, E., Schönlieb, C.B. and Rundo, L., 2022. Unified focal loss: Generalising dice and cross entropy-based losses to handle class imbalanced medical image segmentation. *Computerized Medical Imaging and Graphics*, 95, p.102026.
The table below describes the loss parameterization. Our implementation is different from the originally proposed implementation because we do not implement symmetric and asymmetric forms. Instead, we allow the user to define class weights for both the distribution- and region-based loss components.
![Modified unified focal loss parameterization](figure/unifiedFocalLossTable.PNG){width=80%}
The *lambda* parameter controls the relative weighting between the distribution- and region-based losses. When *lambda* = 1, only the distribution-based loss is used. When *lambda* = 0, only the region-based loss is used. Values between 0 and 1 yield a loss metric that incorporates both the distribution- and region-based loss components. A *lambda* of 0.5 yields equal weighting, values larger than 0.5 put more weight on the distribution-based loss, and values lower than 0.5 put more weight on the region-based loss.
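As a rough illustration of how *lambda* blends the two components, the sketch below assumes the combined loss is a simple convex combination of the two terms, which is consistent with the behavior described above. The `combineLosses()` helper and the numeric inputs are hypothetical and are not part of **geodl**; the package's internal computation may differ.

```r
# Hypothetical helper: blend a distribution-based and a region-based loss.
# Assumes a convex combination controlled by lambda.
combineLosses <- function(distLoss, regLoss, lambda) {
  stopifnot(lambda >= 0, lambda <= 1)
  lambda * distLoss + (1 - lambda) * regLoss
}

distLoss <- 1.6  # made-up distribution-based (CE-like) loss
regLoss <- 0.2   # made-up region-based (Dice-like) loss

combineLosses(distLoss, regLoss, lambda = 1)    # distribution-based loss only
combineLosses(distLoss, regLoss, lambda = 0)    # region-based loss only
combineLosses(distLoss, regLoss, lambda = 0.5)  # equal weighting
```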
The *gamma* parameter controls the weight applied to difficult-to-predict samples, defined as samples or classes with low predicted rescaled logits relative to their correct class. For the distribution-based loss, focal corrections are implemented sample-by-sample. For the region-based loss, corrections are implemented class-by-class during macro-averaging. *gamma* must be larger than 0 and less than or equal to 1. When *gamma* = 1, no focal correction is applied. Lower values result in a larger focal correction.
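To make the role of *gamma* more concrete, the sketch below assumes the focal forms given in Yeung et al. (2022): each sample's CE term is scaled by (1 - p)^(1 - *gamma*), where p is the predicted probability of the correct class, and each class's region-based loss is raised to the power *gamma*. The helpers `focalCE()` and `focalRegion()` and all input values are hypothetical; **geodl** may implement the correction differently.

```r
# Hypothetical helpers; exponents follow the focal forms in Yeung et al. (2022).
focalCE <- function(pTrue, gamma) {
  # pTrue: predicted probability of the correct class for each sample
  (1 - pTrue)^(1 - gamma) * (-log(pTrue))
}

focalRegion <- function(classLoss, gamma) {
  # classLoss: per-class region-based loss (e.g., 1 - Tversky index)
  classLoss^gamma
}

pTrue <- c(0.9, 0.5, 0.1)  # easy, moderate, and difficult samples (made up)
focalCE(pTrue, gamma = 1)    # no focal correction: plain -log(pTrue)
focalCE(pTrue, gamma = 0.5)  # easy samples are down-weighted the most

classLoss <- c(0.1, 0.4, 0.8)  # made-up per-class losses
focalRegion(classLoss, gamma = 1)    # unchanged
focalRegion(classLoss, gamma = 0.5)  # focal correction applied per class
```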
The *delta* parameter controls the relative weights of false negative (FN) and false positive (FP) errors and should be between 0 and 1. A *delta* of 0.5 places equal weight on FN and FP errors. Values larger than 0.5 place more weight on FN errors than on FP errors, while values smaller than 0.5 place more weight on FP errors than on FN errors.
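The effect of *delta* can be seen in the standard Tversky index, TP / (TP + *delta* × FN + (1 - *delta*) × FP). The sketch below uses a hypothetical `tverskyIndex()` helper with made-up counts for a single class. In practice the counts would be soft sums of predicted probabilities, and where the smoothing term is added here is an assumption rather than a description of **geodl**'s internals.

```r
# Hypothetical helper: Tversky index from per-class TP, FN, and FP totals.
# The placement of the smoothing term is an assumption.
tverskyIndex <- function(tp, fn, fp, delta, smooth = 1e-8) {
  (tp + smooth) / (tp + delta * fn + (1 - delta) * fp + smooth)
}

tp <- 80; fn <- 15; fp <- 5  # made-up totals for one class

# delta = 0.5 reproduces the Dice coefficient, 2TP / (2TP + FN + FP)
tverskyIndex(tp, fn, fp, delta = 0.5)
2 * tp / (2 * tp + fn + fp)

# delta = 0.6 penalizes FN errors more heavily than FP errors
1 - tverskyIndex(tp, fn, fp, delta = 0.6)
```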
The *clsWghtsDist* parameter controls the relative weights of classes in the distribution-based loss and is applied sample-by-sample. The *clsWghtsReg* parameter controls the relative weights of classes in the region-based loss and is applied to each class when calculating a macro average. By default, all classes are weighted equally. If you want to implement class weights, you must provide a vector of class weights equal in length to the number of classes being differentiated.
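As one plausible reading of class-weighted macro averaging, the sketch below takes a weighted mean of per-class losses. The `weightedMacroLoss()` helper, the normalization by the sum of the weights, and the loss values are all assumptions for illustration; they are not taken from **geodl**'s source.

```r
# Hypothetical helper: class-weighted macro average of per-class losses.
# Normalizing by the sum of the weights is an assumption.
weightedMacroLoss <- function(classLoss, clsWghts = rep(1, length(classLoss))) {
  # One weight per class is required
  stopifnot(length(clsWghts) == length(classLoss))
  weighted.mean(classLoss, w = clsWghts)
}

classLoss <- c(0.10, 0.25, 0.40, 0.15, 0.60)  # made-up losses for 5 classes

weightedMacroLoss(classLoss)                               # equal weights
weightedMacroLoss(classLoss, clsWghts = c(1, 1, 1, 1, 3))  # emphasize class 5
```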
Lastly, the *useLogCosH* parameter determines whether or not to apply a log cosh transformation to the region-based loss. If it is set to TRUE, this transformation is applied.
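For reference, the log cosh transformation is simply log(cosh(x)). The short sketch below applies it to made-up region-based loss values using a hypothetical `logCosH()` helper. For small inputs the result is close to x^2 / 2, so the transformation compresses small losses more strongly than large ones.

```r
# Hypothetical helper: log-cosh transformation of a loss value.
logCosH <- function(loss) log(cosh(loss))

regLoss <- c(0.05, 0.2, 0.5, 1)  # made-up region-based losses

logCosH(regLoss)  # transformed losses
regLoss^2 / 2     # small-value approximation, for comparison
```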
Using different parameterizations, users can define a variety of loss metrics including cross entropy (CE) loss, weighted CE loss, focal CE loss, focal weighted CE loss, Dice loss, focal Dice loss, Tversky loss, and focal Tversky loss.
## Examples
We will now demonstrate how different loss metrics can be obtained using different parameterizations. We first use the *rast()* function from **terra** to load example data representing reference class codes (refC) and predicted class logits (predL).
```r
refC <- terra::rast("C:/myFiles/data/metricCheck/multiclass_reference.tif")
predL <- terra::rast("C:/myFiles/data/metricCheck/multiclass_logits.tif")
```
The SpatRaster objects are then converted to **torch** tensors with the correct shape and data type. We simulate a mini-batch of two samples by concatenating two copies of the tensors.
```r
predL <- terra::as.array(predL)
refC <- terra::as.array(refC)
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
```
### Example 1: Dice Loss
The **Dice loss** is obtained by setting the *lambda* parameter to 0, the *gamma* parameter to 1, and the *delta* parameter to 0.5. This results in only the region-based loss being considered, no focal correction being applied, and equal weighting between FN and FP errors.
```r
myDiceLoss <- defineUnifiedFocalLoss(nCls=5,
                                     lambda=0, # Only use region-based loss
                                     gamma=1,
                                     delta=0.5, # Equal weights for FP and FN
                                     smooth=1e-8,
                                     zeroStart=TRUE,
                                     clsWghtsDist=1,
                                     clsWghtsReg=1,
                                     useLogCosH=FALSE,
                                     device=device)

myDiceLoss(pred=pred,
           target=target)
#> torch_tensor
#> 0.17861
#> [ CUDAFloatType{} ]
```
### Example 2: Tversky Loss
The **Tversky loss** can be obtained using the same settings as those used for the **Dice loss** except that the *delta* parameter must be set to a value other than 0.5 so that different weights are applied to FN and FP errors. In the example, we use a *delta* of 0.6, which places more weight on FN errors relative to FP errors. Setting *gamma* to a value lower than 1 results in a **focal Tversky loss**.
Note that we regenerate the tensors so that the computational graphs are re-initialized.
```r
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
myTverskyLoss <- defineUnifiedFocalLoss(nCls=5,
                                        lambda=0, # Only use region-based loss
                                        gamma=1,
                                        delta=0.6, # FN weighted higher than FP
                                        smooth=1e-8,
                                        zeroStart=TRUE,
                                        clsWghtsDist=1,
                                        clsWghtsReg=1,
                                        useLogCosH=FALSE,
                                        device=device)

myTverskyLoss(pred=pred,
              target=target)
#> torch_tensor
#> 0.177986
#> [ CUDAFloatType{} ]
```
### Example 3: Cross Entropy (CE) Loss
The **cross entropy** (**CE**) loss is obtained by setting *lambda* to 1, so that only the distribution-based loss is considered, and setting *gamma* to 1, so that no focal correction is applied. Setting *gamma* to a value lower than 1 results in a **focal CE loss**.
```r
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
myCELoss <- defineUnifiedFocalLoss(nCls=5,
                                   lambda=1, # Only use distribution-based loss
                                   gamma=1,
                                   delta=0.5,
                                   smooth=1e-8,
                                   zeroStart=TRUE,
                                   clsWghtsDist=1,
                                   clsWghtsReg=1,
                                   useLogCosH=FALSE,
                                   device=device)

myCELoss(pred=pred,
         target=target)
#> torch_tensor
#> 1.58689
#> [ CUDAFloatType{} ]
```
### Example 4: Combo-Loss
A **combo-loss** can be obtained by setting *lambda* to a value between 0 and 1. In the example, we have used 0.5, which results in equal weights being applied to the distribution- and region-based losses. We also apply a focal correction using a *gamma* of 0.8 and weight FN errors higher than FP errors by using a *delta* of 0.6. The result is a combination of the **focal CE** and **focal Tversky** losses.
```r
target <- torch::torch_tensor(refC, dtype=torch::torch_long(), device=device)
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32(), device=device)
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))
target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)
target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)
myComboLoss <- defineUnifiedFocalLoss(nCls=5,
                                      lambda=0.5, # Use both distribution- and region-based losses
                                      gamma=0.8, # Apply a focal adjustment
                                      delta=0.6, # Weight FN higher than FP
                                      smooth=1e-8,
                                      zeroStart=TRUE,
                                      clsWghtsDist=1,
                                      clsWghtsReg=1,
                                      useLogCosH=FALSE,
                                      device=device)

myComboLoss(pred=pred,
            target=target)
#> torch_tensor
#> 0.9143
#> [ CUDAFloatType{1} ]
```