In [None]:
#export
import torch

In [None]:
# default_exp loss_functions

# loss_functions

> This module contains implementations of both [traditional survival analysis functions](https://square.github.io/pysurvival/math.html) and the loss functions associated with uncensored data, as defined in the [original DRSA paper](https://arxiv.org/pdf/1809.02403.pdf).

In [None]:
#hide
from nbdev.showdoc import *

## Survival Analysis Functions

Following the notation used in the [DRSA paper](https://arxiv.org/pdf/1809.02403.pdf), we define the following:

* Let $z$ be the true occurrence time for the event of interest.

* Let $t$ be the time that a given data point was observed.

* For each observation, there are $L$ time slices, i.e., $0 < t_1 < t_2 < \dots < t_L$, at which we either observe the event (uncensored) or do not (censored).

* Let $V_l = (t_{l-1}, t_l]$, for $l = 1, 2, \dots, L$, denote the disjoint intervals between consecutive time slices, with $t_0 = 0$.
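To make the interval notation concrete, here is a minimal sketch (not part of this module) that maps event times $z$ to their 1-based interval index $l$ with `torch.bucketize`. The boundary values and event times are hypothetical, and reading $l = L + 1$ as "after the last slice $t_L$" is an assumption made for illustration.

```python
import torch

# Hypothetical slice boundaries t_1 < t_2 < t_3 (t_0 = 0 is implicit).
boundaries = torch.tensor([1.0, 2.0, 3.0])

# Hypothetical true event times z.
z = torch.tensor([0.5, 1.0, 1.5, 3.0, 4.0])

# With right=False (the default), bucketize returns i such that
# boundaries[i-1] < z <= boundaries[i], which matches the right-closed
# intervals V_l = (t_{l-1}, t_l]. Adding 1 turns i into the 1-based l;
# l = L + 1 means the event falls after the last slice t_L.
l = torch.bucketize(z, boundaries) + 1
print(l)  # tensor([1, 1, 2, 3, 4])
```

Note that $z = 1.0$ lands in $V_1 = (0, 1]$ and $z = 3.0$ in $V_3 = (2, 3]$, because each interval includes its right endpoint.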

### Discrete Survival Function

Though it is given its own name in survival analysis, the survival function is simply $S(t) = 1 - F(t)$, where $F$ is the CDF of the event time $z$. In the discrete, empirical case, the survival function is computed as follows (this is equation (1) in the paper):

$$ S(t_l) = Pr(z > t_l) = \sum_{j > l}Pr(z\in V_j) $$
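The code below works with conditional hazard rates $h_l = Pr(z\in V_l \mid z > t_{l-1})$ rather than with interval probabilities. Conditioning on survival past $t_{l-1}$ rewrites the sum above as a product; this is a sketch of the standard step, assuming $z > 0$ so that $S(t_0) = 1$:

$$ S(t_l) = Pr(z > t_{l-1})\,Pr(z > t_l \mid z > t_{l-1}) = S(t_{l-1})\,(1 - h_l) = \prod_{i=1}^{l}(1 - h_i) $$

This product form is what `survival_rate` computes by multiplying $1 - h_l$ along the time dimension.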

In [None]:
#export

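def survival_rate(h):
    """
    Given the predicted conditional hazard rate, this function estimates
    the survival rate.

    input:
        h: torch.tensor (e.g. of shape (batch, L, 1)), the predicted
            conditional hazard rate h_l at each observed time step
    output:
        s: torch.tensor (e.g. of shape (batch, 1)), the estimated
            survival rate S(t_L) = prod_l (1 - h_l) at the last time step
    """
    # Surviving past t_L means surviving every interval V_1, ..., V_L,
    # so multiply (1 - h_l) over the time dimension (dim=1).
    s = torch.prod(1 - h, dim=1)
    return s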
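`survival_rate` returns only $S(t_L)$, because the product collapses the time dimension. If the whole survival curve $S(t_1), \dots, S(t_L)$ is needed, one option is a running product with `torch.cumprod`. This is a sketch, not part of this module, and `survival_curve` is a hypothetical name; its last entry agrees with `torch.prod`:

```python
import torch

def survival_curve(h):
    """
    Hypothetical helper: survival rate S(t_l) at every time step l,
    not only the last one.

    h: tensor of shape (batch, L, 1) holding conditional hazards h_l.
    Returns a tensor of the same shape whose l-th entry along dim 1 is
    prod_{i <= l} (1 - h_i).
    """
    # Running product of (1 - h_i) along the time dimension.
    return torch.cumprod(1 - h, dim=1)

h = torch.tensor([[[0.5], [0.5], [0.2]]])
curve = survival_curve(h)
print(curve.squeeze())           # tensor([0.5000, 0.2500, 0.2000])
# The last entry matches the single value from the product over dim=1.
print(torch.prod(1 - h, dim=1))  # tensor([[0.2000]])
```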

In [None]:
#hide

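def survival_rate(ts):
    """
    Survival (rate) function: S(t_L) = prod_l (1 - h_l), i.e. 1 - CDF
    evaluated at the last time step.
    """
    return torch.prod(1-ts, dim=1)

def event_rate(ts):
    """
    Event rate W(t_L) = 1 - S(t_L), i.e. the CDF evaluated at the last
    time step.
    """
    return 1-survival_rate(ts)

def event_time(ts):
    """
    Event time probability p_L = h_L * S(t_{L-1}), i.e. the discrete
    PDF evaluated at the last time step.
    """
    return ts[:, -1, :] * survival_rate(ts[:, :-1, :])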
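The helpers above return only the values at the last step. As an illustrative sketch (the name `event_time_all` and the hazard values are hypothetical), the relations $p_l = h_l\,S(t_{l-1})$ and $W(t_L) = \sum_{l \le L} p_l = 1 - S(t_L)$ can be checked for every step at once:

```python
import torch

def event_time_all(h):
    """
    Hypothetical helper: discrete event-time probabilities p_l for every
    step l, using p_l = h_l * S(t_{l-1}) with S(t_0) = 1.

    h: tensor of shape (batch, L, 1) holding conditional hazards h_l.
    """
    surv = torch.cumprod(1 - h, dim=1)                      # S(t_1), ..., S(t_L)
    ones = torch.ones_like(surv[:, :1, :])                  # S(t_0) = 1
    surv_prev = torch.cat([ones, surv[:, :-1, :]], dim=1)   # S(t_0), ..., S(t_{L-1})
    return h * surv_prev

h = torch.tensor([[[0.1], [0.3], [0.5]]])
p = event_time_all(h)

# The last entry is what event_time computes: h_L * S(t_{L-1}).
p_last = h[:, -1, :] * torch.prod(1 - h[:, :-1, :], dim=1)
# Summing p_l up to L gives the event rate W(t_L) = 1 - S(t_L).
w_last = 1 - torch.prod(1 - h, dim=1)
print(torch.allclose(p[:, -1, :], p_last))   # True
print(torch.allclose(p.sum(dim=1), w_last))  # True
```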

<!-- As defined in the [DRSA paper](https://arxiv.org/pdf/1809.02403.pdf)

* Discrete event rate (empirical CDF):

$$ W(t_l) = Pr(z \leq t_l) = \sum_{j\leq l}Pr(z\in V_j) $$

* Discrete Survival function:

$$ S(t_l) = Pr(z > t_l) = \sum_{j > l}Pr(z\in V_j) $$

* Discrete event time probability function (empirical PDF):

$$p_l = Pr(z\in V_l) = W(t_l) - W(t_{l-1}) = S(t_{l-1}) - S(t_{l})$$

* Discrete conditional hazard rate:

$$h_l = Pr(z\in V_l | z > t_{l-1}) = \frac{Pr(z\in V_l)}{Pr(z>t_{l-1})} = \frac{p_l}{S(t_{l-1})}$$ -->

In [None]:
ts1 = torch.tensor([[0.001],
                    [0.5],
                    [0.55],
                    [0.15],
                    [0.15],
                    [0.15],
                    [0.15],
                    [0.9]], requires_grad=True)

ts2 = torch.tensor([[0.001],
                    [0.005],
                    [0.1],
                    [0.11],
                    [0.12],
                    [0.15],
                    [0.15],
                    [0.9]], requires_grad=True)

ts = torch.stack([ts1, ts2], dim=0)
print(ts.shape)
print(f"Survival rate:\n {survival_rate(ts)}")
# print(f"Event rate:\n {event_rate(ts)}")
# print(f"Event time:\n {event_time(ts)}")

torch.Size([2, 8, 1])
Survival rate:
 tensor([[0.0117],
        [0.0506]], grad_fn=<ProdBackward1>)
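Because `ts1` and `ts2` are created with `requires_grad=True`, the survival rate is differentiable with respect to the hazards, hence the `grad_fn=<ProdBackward1>` in the output. As a hedged sanity check (a sketch with hypothetical hazard values, not part of the module), autograd's gradient can be compared with the closed form $\partial S(t_L) / \partial h_i = -\prod_{j \neq i}(1 - h_j) = -S(t_L) / (1 - h_i)$, which assumes every $h_i < 1$:

```python
import torch

# Hypothetical hazards for one sequence of L = 3 steps (all h_i < 1).
h = torch.tensor([[[0.1], [0.3], [0.5]]], requires_grad=True)

s = torch.prod(1 - h, dim=1)  # S(t_L), shape (1, 1)
s.sum().backward()

# Closed form: dS/dh_i = -prod_{j != i} (1 - h_j) = -S / (1 - h_i).
# s (1, 1) broadcasts against h (1, 3, 1) to give shape (1, 3, 1).
expected = -s.detach() / (1 - h.detach())
print(torch.allclose(h.grad, expected))  # True
```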
