In [None]:
# default_exp loss_functions

# `loss_functions`

> This module contains implementations of the loss functions associated with uncensored data, as defined in the [original DRSA paper](https://arxiv.org/pdf/1809.02403.pdf).

In [None]:
#hide
from nbdev.showdoc import *
import torch

As defined in the [DRSA paper](https://arxiv.org/pdf/1809.02403.pdf), with $z$ the event time and $V_l = (t_{l-1}, t_l]$ the $l$-th discrete time interval:

* Discrete event rate (empirical CDF):

$$ W(t_l) = Pr(z \leq t_l) = \sum_{j\leq l}Pr(z\in V_j) $$

* Discrete Survival function:

$$ S(t_l) = Pr(z > t_l) = \sum_{j > l}Pr(z\in V_j) $$

* Discrete event time probability function (empirical PMF):

$$p_l = Pr(z\in V_l) = W(t_l) - W(t_{l-1}) = S(t_{l-1}) - S(t_{l})$$

* Discrete conditional hazard rate:

$$h_l = Pr(z\in V_l | z > t_{l-1}) = \frac{Pr(z\in V_l)}{Pr(z>t_{l-1})} = \frac{p_l}{S(t_{l-1})}$$

In [None]:
#export

def survival_rate(ts):
    """Discrete survival function S(t_l) from per-step conditional hazard rates.

    Computes S(t_l) = prod_j (1 - h_j) by multiplying ``1 - ts`` along dim 1,
    which is treated as the time-step axis.

    Args:
        ts: tensor of conditional hazard rates h_j with time steps on dim 1,
            e.g. shape (batch, seq_len, 1) as in the example cell
            (presumably values in [0, 1] -- not checked here).

    Returns:
        Tensor with dim 1 reduced away, e.g. (batch, 1) for a
        (batch, seq_len, 1) input.
    """
    return torch.prod(1 - ts, dim=1)


def event_rate(ts):
    """Discrete event rate (CDF) W(t_l) = 1 - S(t_l).

    Args:
        ts: hazard-rate tensor, same layout as accepted by ``survival_rate``.

    Returns:
        ``1 - survival_rate(ts)``, with the same shape as ``survival_rate(ts)``.
    """
    return 1 - survival_rate(ts)


def event_time(ts):
    """Discrete event time probability p_l = h_l * S(t_{l-1}).

    The last step along dim 1 is taken as the event step l. Its hazard
    ``ts[:, -1]`` is multiplied by the survival probability up to the previous
    step, i.e. the product of ``1 - h_j`` over all earlier steps. With a
    single time step, that product is over an empty range and equals 1, so
    the result is just ``ts[:, -1]``.

    Args:
        ts: hazard-rate tensor with time steps on dim 1, either
            (batch, seq_len) or (batch, seq_len, ...) such as
            (batch, seq_len, 1).

    Returns:
        Tensor with dim 1 removed, e.g. (batch, 1) for (batch, seq_len, 1)
        input or (batch,) for (batch, seq_len) input.
    """
    # Index only dim 1 (no trailing ``:``) so that 2-D inputs work as well as
    # 3-D ones, matching survival_rate; 3-D results are unchanged.
    return ts[:, -1] * survival_rate(ts[:, :-1])

In [None]:
ts1 = torch.tensor([[0.001],
                    [0.5],
                    [0.55],
                    [0.15],
                    [0.15],
                    [0.15],
                    [0.15],
                    [0.9]], requires_grad=True)

ts2 = torch.tensor([[0.001],
                    [0.005],
                    [0.1],
                    [0.11],
                    [0.12],
                    [0.15],
                    [0.15],
                    [0.9]], requires_grad=True)

# Stack the two sequences into a (batch=2, seq_len=8, 1) tensor of hazard rates.
ts = torch.stack([ts1, ts2], dim=0)
print(ts.shape)
print(f"Survival rate:\n {survival_rate(ts)}")
print(f"Event rate:\n {event_rate(ts)}")
print(f"Event time:\n {event_time(ts)}")

# Pin the printed values so nbdev's test run catches regressions
# (atol matches the 4-decimal printed output).
assert ts.shape == (2, 8, 1)
assert torch.allclose(survival_rate(ts), torch.tensor([[0.0117], [0.0506]]), atol=1e-4)
assert torch.allclose(event_rate(ts), torch.tensor([[0.9883], [0.9494]]), atol=1e-4)
assert torch.allclose(event_time(ts), torch.tensor([[0.1056], [0.4556]]), atol=1e-4)
# W(t_l) = 1 - S(t_l) must hold up to float error.
assert torch.allclose(event_rate(ts) + survival_rate(ts), torch.ones(2, 1))

torch.Size([2, 8, 1])
Survival rate:
 tensor([[0.0117],
        [0.0506]], grad_fn=<ProdBackward1>)
Event rate:
 tensor([[0.9883],
        [0.9494]], grad_fn=<RsubBackward1>)
Event time:
 tensor([[0.1056],
        [0.4556]], grad_fn=<MulBackward0>)
