-
Notifications
You must be signed in to change notification settings - Fork 32
/
loss_func.py
113 lines (94 loc) · 3.87 KB
/
loss_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn as nn
import numpy as np
import torch
import torch.nn.functional as F
class NLLSurvLoss(nn.Module):
    """
    The negative log-likelihood loss function for the discrete time to event model (Zadeh and Schmid, 2020).

    Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py

    Parameters
    ----------
    alpha: float
        Weight of an extra uncensored-only term: the per-sample loss is
        (1 - alpha) * full_nll + alpha * uncensored_nll. 0.0 gives the plain
        negative log-likelihood.
    eps: float
        Numerical constant; lower bound to avoid taking logs of tiny numbers.
    reduction: str
        Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum']
    """

    def __init__(self, alpha=0.0, eps=1e-7, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.eps = eps
        self.reduction = reduction

    # Defined as ``forward`` (not ``__call__``) so nn.Module's dispatch and
    # hooks still run; ``loss_fn(h, y, t, c)`` calls keep working unchanged.
    def forward(self, h, y, t, c):
        """
        Compute the discrete-time survival negative log-likelihood.

        Parameters
        ----------
        h: (n_batches, n_classes)
            The neural network output discrete survival predictions such that hazards = sigmoid(h).
        y: (n_batches,)
            The true time bin index label. A trailing dimension is added here
            before it is used as a gather index, so a 1-D tensor is expected.
        t: any
            Accepted for interface compatibility but not used.
            NOTE(review): presumably the continuous event time — confirm with callers.
        c: (n_batches,)
            The censorship indicator (1 = censored, 0 = event observed).

        Returns
        -------
        torch.Tensor
            The reduced loss as produced by ``nll_loss``.
        """
        return nll_loss(h=h, y=y.unsqueeze(dim=1), c=c.unsqueeze(dim=1),
                        alpha=self.alpha, eps=self.eps,
                        reduction=self.reduction)
def nll_loss(h, y, c, alpha=0.0, eps=1e-7, reduction='mean'):
    """
    The negative log-likelihood loss function for the discrete time to event model (Zadeh and Schmid, 2020).

    Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py

    With hazards haz(k) = sigmoid(h[:, k]) and survival S(k) = prod_{j <= k} (1 - haz(j)),
    and the convention S(-1) = 1, the per-sample negative log-likelihood is

        c = 0 (event in bin y):         -log S(y - 1) - log haz(y)
        c = 1 (censored, alive in y):   -log S(y)

    Parameters
    ----------
    h: (n_batches, n_classes)
        The neural network output discrete survival predictions such that hazards = sigmoid(h).
    y: (n_batches, 1)
        The true time bin index label. It is cast to int64 and used as a gather
        index, so values must lie in [0, n_classes - 1].
    c: (n_batches, 1)
        The censoring status indicator: 1 = censored, 0 = event observed.
    alpha: float or None
        Weight of an extra uncensored-only term: the per-sample loss is
        (1 - alpha) * neg_l + alpha * uncensored_loss. None is treated like 0.0
        (plain negative log-likelihood).
    eps: float
        Numerical constant; lower bound to avoid taking logs of tiny numbers.
    reduction: str
        How to reduce the per-sample losses over the batch. Must be one of
        ['mean', 'sum', 'none']. 'none' returns the per-sample losses unreduced.

    Returns
    -------
    torch.Tensor
        A scalar for 'mean' / 'sum', otherwise the per-sample loss tensor.

    Raises
    ------
    ValueError
        If ``reduction`` is not one of the supported values.

    References
    ----------
    Zadeh, S.G. and Schmid, M., 2020. Bias in cross-entropy-based training of deep survival networks. IEEE transactions on pattern analysis and machine intelligence.
    """
    # gather() requires int64 indices; c is cast too so the (1 - c) / c masks are exact 0/1.
    y = y.type(torch.int64)
    c = c.type(torch.int64)

    hazards = torch.sigmoid(h)
    # S[:, k] = probability of surviving through bin k.
    S = torch.cumprod(1 - hazards, dim=1)

    # Prepend S(-1) = 1 (everyone is alive before the first bin), so that
    # S_padded[:, k] = S(k - 1): S_padded[:, y] is survival up to the bin before y
    # and S_padded[:, y + 1] is survival through bin y. The padding column is built
    # from S so it shares S's dtype and device (not the int64 c tensor's).
    S_padded = torch.cat([torch.ones_like(S[:, :1]), S], dim=1)

    s_prev = torch.gather(S_padded, dim=1, index=y).clamp(min=eps)
    h_this = torch.gather(hazards, dim=1, index=y).clamp(min=eps)
    s_this = torch.gather(S_padded, dim=1, index=y + 1).clamp(min=eps)

    # Event observed in bin y: survived the earlier bins, then the event happened in bin y.
    uncensored_loss = -(1 - c) * (torch.log(s_prev) + torch.log(h_this))
    # Censored in bin y: only known to have survived through bin y.
    censored_loss = -c * torch.log(s_this)
    neg_l = censored_loss + uncensored_loss

    # Previously alpha=None left `loss` unassigned (UnboundLocalError below).
    if alpha is None:
        loss = neg_l
    else:
        loss = (1 - alpha) * neg_l + alpha * uncensored_loss

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError("Bad input for reduction: {}".format(reduction))