In [None]:
#export
import torch

In [None]:
# default_exp loss_functions

# loss_functions

> This module contains implementations of both [traditional survival analysis functions](https://square.github.io/pysurvival/math.html), as well as the loss functions associated with uncensored data, as defined in the [original DRSA paper](https://arxiv.org/pdf/1809.02403.pdf).

In [None]:
#hide
from nbdev.showdoc import *
import pytest

## Survival Analysis Functions

Following the notation used in the [DRSA paper](https://arxiv.org/pdf/1809.02403.pdf), we define the following:

* Let $z$ be the true occurrence time for the event of interest.

* Let $t$ be the time that a given data point was observed.

* For each observation, there exist $L$ time slices, ie $0 < t_1 < t_2 < \dots < t_L$, at which we either observe the event (uncensored) or do not (censored).

* Let $V_l = (t_{l-1}, t_l]$ be the set of all disjoint intervals with $l = 1, 2, \dots, L$.

In [None]:
#hide

def assert_correct_input_shape(h):
    """
    Validate that `h` is a 3-dimensional tensor.

    *input*:
    * `h`:
        - predicted conditional hazard rates; only the number of
          dimensions of `h.shape` is checked here.

    Raises `ValueError` if `h` does not have exactly 3 dimensions.
    """
    if len(h.shape) != 3:
        # The functions in this module reduce over dim=1 (the time steps),
        # so the message names that axis rather than claiming it has size 1.
        raise ValueError(f"h is of shape {h.shape}. It is expected that h is of shape (batch size, sequence length, 1), as this is most amenable to use in training neural nets with pytorch.")

def assert_correct_output_shape(q, batch_size):
    """
    Validate that `q` has shape `(batch_size, 1)`.

    Returns `None` when the shape matches and raises `ValueError` otherwise.
    """
    expected_shape = torch.Size([batch_size, 1])
    if q.shape == expected_shape:
        return
    raise ValueError(f"q is of shape {q.shape}. It is expected that q is of shape (batch_size, 1)")

### Discrete Survival function

Though it's given its own name in survival analysis, the survival function is simply calculated as $1 - \text{CDF}(z)$. In the discrete, empirical case, the survival function is estimated as follows (this is equation (5) in the paper).

$$ S(t_l) = Pr(z > t_l) = \sum_{j > l}Pr(z\in V_j) $$

In [None]:
#export

def survival_rate(h):
    """
    Estimate the survival rate from the predicted conditional hazard rates,
    as the product of the per-step survival probabilities `1 - h` taken
    along the time axis (dim 1).

    *input*:
    * `h`:
        - type: `torch.tensor`
        - predicted conditional hazard rate, at each observed time step.
        - note: `h` must be 3-dimensional; the examples in this notebook use
          `h.shape == (batch size, sequence length, 1)`.

    _output_:
    * `s`:
        - type: `torch.tensor`
        - estimated survival rate after the last time step in `h`.
        - note: `s.shape == (batch_size, 1)` for inputs shaped as above.

    Raises `ValueError` if `h` is not 3-dimensional.
    """
    assert_correct_input_shape(h)
    step_survival = 1 - h
    return torch.prod(step_survival, dim=1)

In [None]:
# example: a batch of two hazard-rate sequences, eight time steps each
h1 = torch.tensor(
    [[0.001], [0.5], [0.55], [0.15], [0.15], [0.15], [0.15], [0.9]],
    requires_grad=True,
)
h2 = torch.tensor(
    [[0.001], [0.005], [0.1], [0.11], [0.12], [0.15], [0.15], [0.9]],
    requires_grad=True,
)
# stacking along a new leading axis gives h.shape == (2, 8, 1)
h = torch.stack((h1, h2), dim=0)
survival_rate(h)

tensor([[0.0117],
        [0.0506]], grad_fn=<ProdBackward1>)

In [None]:
#hide

# survival rate tests

def test_survival_rate(h):
    """Check `survival_rate` for input validation, output shape and values."""
    # a 2-d slice (a single sequence without the batch axis) must be rejected
    with pytest.raises(ValueError):
        survival_rate(h[0, :, :])

    batch_size, _, _ = h.shape
    rates = survival_rate(h)

    # one survival rate per sequence in the batch
    assert_correct_output_shape(rates, batch_size)

    # values must match the ones displayed for the example batch
    expected = torch.tensor([[0.0117], [0.0506]])
    torch.testing.assert_allclose(rates, expected, rtol=1e-3, atol=1e-3)


test_survival_rate(h)

### Discrete Event Rate function

The event rate function is calculated as $\text{CDF}(z)$. In the discrete, empirical case, it is estimated as follows (this is equation (5) in the paper).

$$ W(t_l) = Pr(z \leq t_l) = \sum_{j\leq l}Pr(z\in V_j) $$

In [None]:
#export

def event_rate(h):
    """
    Estimate the event rate from the predicted conditional hazard rates,
    computed as one minus the survival rate.

    *input*:
    * `h`:
        - type: `torch.tensor`
        - predicted conditional hazard rate, at each observed time step.
        - note: `h` must be 3-dimensional; the examples in this notebook use
          `h.shape == (batch size, sequence length, 1)`.

    _output_:
    * `w`:
        - type: `torch.tensor`
        - estimated event rate (complement of `survival_rate(h)`).
        - note: `w.shape == (batch_size, 1)` for inputs shaped as above.

    Raises `ValueError` if `h` is not 3-dimensional.
    """
    assert_correct_input_shape(h)
    survived = survival_rate(h)
    return 1 - survived

In [None]:
# example: event rate (1 - survival rate) for each sequence in the batch `h`
event_rate(h)

tensor([[0.9883],
        [0.9494]], grad_fn=<RsubBackward1>)

In [None]:
#hide

# event rate tests

def test_event_rate(h):
    """Check `event_rate` for input validation, output shape and values."""
    # a 2-d slice (a single sequence without the batch axis) must be rejected
    with pytest.raises(ValueError):
        event_rate(h[0, :, :])

    batch_size, _, _ = h.shape
    rates = event_rate(h)

    # one event rate per sequence in the batch
    assert_correct_output_shape(rates, batch_size)

    # values must match the ones displayed for the example batch
    expected = torch.tensor([[0.9883], [0.9494]])
    torch.testing.assert_allclose(rates, expected, rtol=1e-3, atol=1e-3)


test_event_rate(h)

### Discrete Event Time Probability function

The event time probability function is calculated as $\text{PDF}(z)$. In the discrete, empirical case, it is estimated as follows (this is equation (6) in the paper).

$$p_l = Pr(z\in V_l) = W(t_l) - W(t_{l-1}) = S(t_{l-1}) - S(t_{l})$$

In [None]:
#export

def event_time(h):
    """
    Estimate the probability that the event occurs at the last time step:
    the hazard rate at the last step multiplied by the survival rate over
    all preceding steps.

    *input*:
    * `h`:
        - type: `torch.tensor`
        - predicted conditional hazard rate, at each observed time step.
        - note: `h` must be 3-dimensional; the examples in this notebook use
          `h.shape == (batch size, sequence length, 1)`.

    _output_:
    * `p`:
        - type: `torch.tensor`
        - estimated probability of the event at the last time step in `h`.
        - note: `p.shape == (batch_size, 1)` for inputs shaped as above.

    Raises `ValueError` if `h` is not 3-dimensional.
    """
    assert_correct_input_shape(h)
    final_hazard = h[:, -1, :]
    # survival over every step before the last one
    survived_until_final = survival_rate(h[:, :-1, :])
    return final_hazard * survived_until_final

In [None]:
# example: probability of the event at the last time step, per sequence in `h`
event_time(h)

tensor([[0.1056],
        [0.4556]], grad_fn=<MulBackward0>)

In [None]:
#hide

# event time tests

def test_event_time(h):
    """Check `event_time` for input validation, output shape and values."""
    # a 2-d slice (a single sequence without the batch axis) must be rejected
    with pytest.raises(ValueError):
        event_time(h[0, :, :])

    batch_size, _, _ = h.shape
    probs = event_time(h)

    # one event-time probability per sequence in the batch
    assert_correct_output_shape(probs, batch_size)

    # values must match the ones displayed for the example batch
    expected = torch.tensor([[0.1056], [0.4556]])
    torch.testing.assert_allclose(probs, expected, rtol=1e-3, atol=1e-3)


test_event_time(h)

### Discrete Conditional Hazard Rate

The conditional hazard rate is the quantity which will be predicted at each time step by a recurrent survival analysis model. In the discrete, empirical case, it is estimated as follows (this is equation (7) in the paper).

$$h_l = Pr(z\in V_l | z > t_{l-1}) = \frac{Pr(z\in V_l)}{Pr(z>t_{l-1})} = \frac{p_l}{S(t_{l-1})}$$

## Log Survival Analysis Functions

We additionally define the log of each of the traditional survival analysis functions, which proves useful for numerical stability, since we need to multiply many floating-point values together.

### Log Survival Function

In [None]:
#export

def log_survival_rate(h):
    """
    Given the predicted conditional hazard rate, this function estimates
    the log survival rate, as the sum over the time axis (dim 1) of
    `log(1 - h)`.

    *input*:
    * `h`:
        - type: `torch.tensor`
        - predicted conditional hazard rate, at each observed time step.
        - note: `h` must be 3-dimensional; the examples in this notebook use
          `h.shape == (batch size, sequence length, 1)`.

    _output_:
    * `s`:
        - type: `torch.tensor`
        - estimated log survival rate after the last time step in `h`.
        - note: `s.shape == (batch_size, 1)` for inputs shaped as above.

    Raises `ValueError` if `h` is not 3-dimensional.
    """
    assert_correct_input_shape(h)
    # log1p(-h) evaluates log(1 - h) without first rounding `1 - h` to the
    # nearest float, which keeps precision when the hazard rates are small.
    s = torch.log1p(-h).sum(dim=1)
    return s

In [None]:
# example: log survival rate for each sequence in the batch `h`
log_survival_rate(h)

tensor([[-4.4453],
        [-2.9834]], grad_fn=<SumBackward1>)

In [None]:
#hide

# log survival rate tests
def test_log_survival_rate(h):
    """Check `log_survival_rate` for input validation, output shape and values."""
    # a 2-d slice (a single sequence without the batch axis) must be rejected
    with pytest.raises(ValueError):
        log_survival_rate(h[0, :, :])

    batch_size, _, _ = h.shape
    log_rates = log_survival_rate(h)

    # one log survival rate per sequence in the batch
    assert_correct_output_shape(log_rates, batch_size)

    # must agree with the log of the linear-domain survival rate
    reference = survival_rate(h).log()
    torch.testing.assert_allclose(log_rates, reference, rtol=1e-3, atol=1e-3)


test_log_survival_rate(h)

### Log Event Rate Function

In [None]:
#export

def log_event_rate(h):
    """
    Estimate the log event rate from the predicted conditional hazard rates,
    computed as the log of `event_rate(h)`.

    *input*:
    * `h`:
        - type: `torch.tensor`
        - predicted conditional hazard rate, at each observed time step.
        - note: `h` must be 3-dimensional; the examples in this notebook use
          `h.shape == (batch size, sequence length, 1)`.

    _output_:
    * `w`:
        - type: `torch.tensor`
        - estimated log event rate.
        - note: `w.shape == (batch_size, 1)` for inputs shaped as above.

    Raises `ValueError` if `h` is not 3-dimensional.
    """
    assert_correct_input_shape(h)
    # the event rate is formed in the linear domain first, then logged
    linear_event_rate = event_rate(h)
    return linear_event_rate.log()

In [None]:
# example: log event rate for each sequence in the batch `h`
log_event_rate(h)

tensor([[-0.0118],
        [-0.0519]], grad_fn=<LogBackward>)

In [None]:
#hide

# log event rate tests

def test_log_event_rate(h):
    """Check `log_event_rate` for input validation, output shape and values."""
    # a 2-d slice (a single sequence without the batch axis) must be rejected
    with pytest.raises(ValueError):
        log_event_rate(h[0, :, :])

    batch_size, _, _ = h.shape
    log_rates = log_event_rate(h)

    # one log event rate per sequence in the batch
    assert_correct_output_shape(log_rates, batch_size)

    # must agree with the log of the linear-domain event rate
    reference = event_rate(h).log()
    torch.testing.assert_allclose(log_rates, reference, rtol=1e-3, atol=1e-3)


test_log_event_rate(h)

### Log Event Time Function

In [None]:
#export

def log_event_time(h):
    """
    Estimate the log probability that the event occurs at the last time
    step: the log hazard rate at the last step plus the log survival rate
    over all preceding steps.

    *input*:
    * `h`:
        - type: `torch.tensor`
        - predicted conditional hazard rate, at each observed time step.
        - note: `h` must be 3-dimensional; the examples in this notebook use
          `h.shape == (batch size, sequence length, 1)`.

    _output_:
    * `p`:
        - type: `torch.tensor`
        - estimated log probability of the event at the last time step in `h`.
        - note: `p.shape == (batch_size, 1)` for inputs shaped as above.

    Raises `ValueError` if `h` is not 3-dimensional.
    """
    assert_correct_input_shape(h)
    log_final_hazard = h[:, -1, :].log()
    # log survival over every step before the last one
    log_survived_until_final = log_survival_rate(h[:, :-1, :])
    return log_final_hazard + log_survived_until_final

In [None]:
# example: log probability of the event at the last time step, per sequence in `h`
log_event_time(h)

tensor([[-2.2481],
        [-0.7861]], grad_fn=<AddBackward0>)

In [None]:
#hide

# log event time tests

def test_log_event_time(h):
    """Check `log_event_time` for input validation, output shape and values."""
    # a 2-d slice (a single sequence without the batch axis) must be rejected
    with pytest.raises(ValueError):
        log_event_time(h[0, :, :])

    batch_size, _, _ = h.shape
    log_probs = log_event_time(h)

    # one log probability per sequence in the batch
    assert_correct_output_shape(log_probs, batch_size)

    # must agree with the log of the linear-domain event time probability
    reference = event_time(h).log()
    torch.testing.assert_allclose(log_probs, reference, rtol=1e-3, atol=1e-3)


test_log_event_time(h)

## Loss Functions

Now, we transform these generic survival analysis functions into loss functions that can be automatically differentiated by PyTorch, in order to train a Deep Recurrent Survival Analysis model.


We make a few notes below:

1. The functions below adhere to the common pattern used across all of [`PyTorch`'s loss functions](https://pytorch.org/docs/stable/nn.functional.html#loss-functions), which is to take two arguments named `input` and `target`. We note, however, that due to the nature of this survival data, the target is inherent to the data structure and thus unnecessary.

2. The original DRSA paper defines 3 loss functions, 2 of which are directed towards uncensored data, and 1 of which applies to censored data. This library's focus is on DRSA models using only uncensored data, so those are the only losses we'll be defining.

### Event Time Loss

In [None]:
#export

def event_time_loss(input, target=None):
    """
    Loss function applied to uncensored data in order to optimize the PDF
    of the true event time, z. It is the negated sum, over the batch, of
    `log_event_time(input)`.

    input:
    * `input`:
        - type: `torch.tensor`
        - predicted conditional hazard rate, at each observed time step.
        - note: `input` must be 3-dimensional; the examples in this notebook
          use `input.shape == (batch size, sequence length, 1)`.
    * `target`:
        - unused, only present to mimic pytorch loss functions

    output:
    * `evt_loss`:
        - type: `torch.tensor`
        - negative log-likelihood of the event time, summed over the batch.

    Raises `ValueError` if `input` is not 3-dimensional.
    """
    assert_correct_input_shape(input)
    log_probs = log_event_time(input)
    # sum over the batch axis, then drop the remaining singleton axis
    summed = log_probs.sum(dim=0).squeeze()
    return -summed

In [None]:
# example: event time loss summed over the two sequences in `h`
event_time_loss(h)

tensor(3.0342, grad_fn=<NegBackward>)

In [None]:
#hide

# event time loss tests
def test_event_time_loss(input, target=None):
    """Check that `event_time_loss` reproduces the value shown for the example batch."""
    expected = torch.tensor(3.0342)
    loss = event_time_loss(input)

    torch.testing.assert_allclose(loss, expected, rtol=1e-3, atol=1e-3)

test_event_time_loss(h)

### Event Rate Loss

In [None]:
#export

def event_rate_loss(input, target=None):
    """
    Loss function applied to uncensored data in order to optimize the CDF
    of the true event time, z. It is the negated sum, over the batch, of
    `log_event_rate(input)`.

    input:
    * `input`:
        - type: `torch.tensor`
        - predicted conditional hazard rate, at each observed time step.
        - note: `input` must be 3-dimensional; the examples in this notebook
          use `input.shape == (batch size, sequence length, 1)`.
    * `target`:
        - unused, only present to mimic pytorch loss functions

    output:
    * `evr_loss`:
        - type: `torch.tensor`
        - negative log of the event rate, summed over the batch.

    Raises `ValueError` if `input` is not 3-dimensional.
    """
    assert_correct_input_shape(input)
    log_rates = log_event_rate(input)
    # sum over the batch axis, then drop the remaining singleton axis
    summed = log_rates.sum(dim=0).squeeze()
    return -summed

In [None]:
# example: event rate loss summed over the two sequences in `h`
event_rate_loss(h)

tensor(0.0638, grad_fn=<NegBackward>)

In [None]:
#hide

# event rate loss tests
def test_event_rate_loss(input, target=None):
    """Check that `event_rate_loss` reproduces the value shown for the example batch."""
    expected = torch.tensor(0.0638)
    loss = event_rate_loss(input)

    torch.testing.assert_allclose(loss, expected, rtol=1e-3, atol=1e-3)

test_event_rate_loss(h)