# Minimal Torch Reimplementation of the DeepSurv Paper
DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network
## Results
| Experiment                  | Reproduced C-idx | Paper's C-idx |
|-----------------------------|------------------|---------------|
| Linear data, Linear CPH     | 77.66% | 77.37% |
| Linear data, DeepSurv       | 77.11% | 77.40% |
| Non-linear data, Linear CPH | 48.70% | 50.70% |
| Non-linear data, DeepSurv   | 59.58% | 64.90% |

In [1]:
"""
Implementation of the DeepSurv paper (Katzman et al., 2018)
"""
from typing import Callable, Tuple, List

import numpy as np

import torch
import torch.nn as nn
import matplotlib.pyplot as plt


def generate_dataset(
    h_func: Callable[[np.ndarray], np.ndarray],
    n_samples: int = 6000,
    censor_perc: float = 90,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate a simulated dataset as described in section 4.2

    Parameters
    ----------
    h_func: Callable
        Takes as input np.ndarray of covariates with shape (n_samples, 10),
        returns the log-risk h(x) with shape (n_samples,)
    n_samples: int
        Number of samples to generate
    censor_perc: float
        Censoring percentile (anything above this percentile is censored), between 0 and 100

    Returns
    -------
    data_train: np.ndarray
        Training data, first column is T, second column is observed, rest are covariates
    data_val: np.ndarray
        Validation data, first column is T, second column is observed, rest are covariates
    data_test: np.ndarray
        Test data, first column is T, second column is observed, rest are covariates

    Notes
    -----
    The split is 2/3 train, 1/6 validation and the remainder test, which gives
    4000 / 1000 / 1000 rows for the default ``n_samples=6000``.
    """
    # Keep the order of the random draws (exponential, then uniform) so that
    # seeded runs stay reproducible.
    exp_times = np.random.exponential(5, n_samples)
    covariates = np.random.uniform(-1, 1, (n_samples, 10))
    h_x = h_func(covariates)
    survival_times = exp_times / np.exp(h_x)

    # Administrative censoring: every time beyond the chosen percentile is
    # marked as unobserved and clipped to the censoring time.
    perc_censored = np.percentile(survival_times, censor_perc)
    censored = survival_times > perc_censored
    is_observed = np.ones(survival_times.shape)
    is_observed[censored] = 0
    survival_times[censored] = perc_censored

    data = np.hstack(
        (survival_times.reshape(-1, 1), is_observed.reshape(-1, 1), covariates)
    )

    # Derive the split points from n_samples instead of hard-coding them, so
    # that datasets of any size yield non-degenerate splits.
    train_end = (2 * n_samples) // 3
    val_end = train_end + n_samples // 6
    data_train, data_val, data_test = (
        data[:train_end],
        data[train_end:val_end],
        data[val_end:],
    )
    return data_train, data_val, data_test


class PartialLikelihood(nn.Module):
    """
    Partial likelihood loss function for Cox's proportional hazards model
    """

    def forward(
        self,
        survival_times: torch.Tensor,
        model_h: torch.Tensor,
        is_observed: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the negative log partial likelihood for a given set of model
        outputs and observed data

        Parameters
        ----------
        survival_times : torch.Tensor
            Survival times for each sample, shape = (batch_size, N_samples)
        model_h : torch.Tensor
            Model output (log-risk) for each sample, shape = (batch_size, N_samples)
        is_observed : torch.Tensor
            Dichotomous variable for each sample (1 if observed, 0 if censored)
            shape = (batch_size, N_samples)

        Returns
        -------
        torch.Tensor
            Negative log partial likelihood, averaged over the batch dimension
            and divided by N_samples
        """
        # Sort by survival times (ascending), then flip so that the sample with
        # the longest survival time comes first.
        argsort_indices = torch.argsort(survival_times, dim=1)
        desc_obs = torch.flip(torch.gather(is_observed, 1, argsort_indices), dims=(1,))
        desc_h = torch.flip(torch.gather(model_h, 1, argsort_indices), dims=(1,))

        # With times in descending order, the running sum up to position i
        # covers exactly the risk set of sample i (everyone surviving at least
        # as long). logcumsumexp evaluates log(cumsum(exp(h))) without
        # overflowing for large h, unlike the naive exp -> cumsum -> log chain.
        log_risk_set = torch.logcumsumexp(desc_h, dim=1)

        # Only observed events contribute a term to the partial likelihood.
        log_lik = torch.sum((desc_h - log_risk_set) * desc_obs, dim=1)
        return -1 * torch.mean(log_lik) / survival_times.shape[1]


def fit_linear_cph(
    covariates: torch.Tensor,
    survival_times: torch.Tensor,
    is_observed: torch.Tensor,
    epochs: int = 1000,
    plot: bool = False,
) -> Tuple[torch.Tensor, List[float]]:
    """
    Fit a linear Cox's proportional hazards model

    Parameters
    ----------
    covariates : torch.Tensor
        Covariates, shape = (N_samples, N_covariates)
    survival_times : torch.Tensor
        Survival times for each sample, shape = (1, N_samples)
    is_observed : torch.Tensor
        Dichotomous variable for each sample (1 if observed, 0 if censored)
        shape = (1, N_samples)
    epochs : int, optional
        Number of epochs to train for, by default 1000
    plot : bool, optional
        Whether to plot training loss, by default False

    Returns
    -------
    beta: torch.Tensor
        Model parameters, shape = (1, N_covariates)
    losses: list
        Training loss at each epoch
    """
    # Size the coefficient vector from the data rather than assuming 10
    # covariates, so the model works for any feature count.
    n_covariates = covariates.shape[1]
    beta = nn.Parameter(torch.normal(0, 1, (1, n_covariates)))
    criterion = PartialLikelihood()
    optimizer = torch.optim.Adam([beta], lr=1e-2)

    losses = []
    for _ in range(epochs):
        # Linear log-risk h(x) = beta . x, laid out as (1, N_samples)
        model_h = beta @ covariates.T
        loss = criterion(survival_times, model_h, is_observed)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    if plot:
        plt.plot(losses)
        plt.show()
    return beta, losses


def c_index_eval(
    survival_times: torch.Tensor, model_pred: torch.Tensor, is_observed: torch.Tensor
) -> torch.Tensor:
    """
    Compute the concordance index for a given set of survival times, model
    predictions, and observed data

    Parameters
    ----------
    survival_times : torch.Tensor
        Survival times for each sample, shape = (1, N_samples)
    model_pred : torch.Tensor
        Predicted risk score for each sample (higher = shorter expected
        survival), shape = (1, N_samples). Only the ranking matters, so either
        the log-risk or its exponential may be passed.
    is_observed : torch.Tensor
        Dichotomous variable for each sample (1 if observed, 0 if censored)
        shape = (1, N_samples)

    Returns
    -------
    torch.Tensor
        Concordance index (NaN when there are no comparable pairs)

    Notes
    -----
    The inputs are flattened after sorting, so a leading dimension other than
    1 is not supported. Pairs with tied predictions count as discordant.
    """
    order = torch.argsort(survival_times, dim=1)
    events = torch.gather(is_observed, 1, order).reshape(-1)
    # The C-index is rank-based, so predictions are compared directly.
    # Exponentiating first is unnecessary and can overflow to inf, which turns
    # distinct large scores into false ties.
    risk = torch.gather(model_pred, 1, order).reshape(-1, 1)

    # lower_risk[i, j] is True when sample j has a strictly lower predicted
    # risk than sample i; the strict upper triangle keeps only j that come
    # after i in survival-time order.
    lower_risk = risk.T < risk
    concordant = torch.sum(torch.triu(lower_risk.int(), diagonal=1), dim=1) * events

    # Sample i (if observed) is comparable with every sample sorted after it.
    n_samples = events.shape[0]
    comparable = (
        torch.arange(n_samples - 1, -1, -1, device=events.device) * events
    )
    return torch.sum(concordant) / torch.sum(comparable)


def plot_heatmap_2d_function(
    func: Callable[[np.array, np.array], np.array],
    x_range: Tuple[float, float] = (-1, 1),
    y_range: Tuple[float, float] = (-1, 1),
    num_points: int = 100,
    title: str = "",
) -> None:
    """
    Render a two-argument function as a heatmap over a rectangular grid

    Parameters
    ----------
    func : Callable
        Function to visualise. It is called once with two 2D arrays (the x and
        y coordinates of every grid point) and must return the array of
        function values for those points
    x_range : Tuple[float, float], by default (-1, 1)
        Lower and upper bound of the x axis
    y_range : Tuple[float, float], by default (-1, 1)
        Lower and upper bound of the y axis
    num_points : int, by default 100
        Grid resolution along each axis
    title : str, by default ""
        Text placed in front of " 2D heatmap" in the plot title
    """
    x_lo, x_hi = x_range[0], x_range[1]
    y_lo, y_hi = y_range[0], y_range[1]

    # Evaluate the function on a regular num_points x num_points grid
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_lo, x_hi, num_points),
        np.linspace(y_lo, y_hi, num_points),
    )
    values = func(grid_x, grid_y)

    # Draw the grid as an image; origin="lower" puts y_lo at the bottom
    plt.figure(figsize=(8, 6))
    plt.imshow(
        values,
        extent=[x_lo, x_hi, y_lo, y_hi],
        origin="lower",
        cmap="viridis",
    )
    plt.colorbar(label="Function Value")
    plt.xlabel("X-axis")
    plt.ylabel("Y-axis")
    plt.title(f"{title} 2D heatmap")
    plt.show()


def deep_surv(
    covariates: torch.Tensor,
    survival_times: torch.Tensor,
    observed: torch.Tensor,
    x_val: torch.Tensor,
    survival_times_val: torch.Tensor,
    observed_val: torch.Tensor,
    epochs: int = 100,
    plot: bool = False,
) -> Tuple[nn.Module, List[float]]:
    """
    Deep survival model

    Parameters
    ----------
    covariates : torch.Tensor
        Covariates, shape = (N_samples, N_covariates)
    survival_times : torch.Tensor
        Survival times for each sample, shape = (1, N_samples)
    observed : torch.Tensor
        Dichotomous variable for each sample (1 if observed, 0 if censored)
        shape = (1, N_samples)
    x_val : torch.Tensor
        Validation covariates, shape = (N_samples, N_covariates)
    survival_times_val : torch.Tensor
        Validation survival times for each sample, shape = (1, N_samples)
    observed_val : torch.Tensor
        Validation dichotomous variable for each sample (1 if observed, 0 if censored)
        shape = (1, N_samples)
    epochs : int, by default 100
        Number of epochs to train for
    plot : bool, by default False
        Whether to plot training and validation c-index and loss curves

    Returns
    -------
    model: nn.Module
        Trained model
    losses: list
        Training partial-likelihood loss, recorded every 10th epoch
    """
    # Size the input layer from the data instead of assuming 10 covariates.
    n_covariates = covariates.shape[1]
    model = nn.Sequential(
        nn.Linear(n_covariates, 30),
        nn.ReLU(),
        nn.Linear(30, 50),
        nn.ReLU(),
        nn.Linear(50, 30),
        nn.ReLU(),
        nn.Linear(30, 1),
    )

    criterion = PartialLikelihood()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    tr_c_indices = []
    val_c_indices = []
    losses = []
    val_losses = []
    for epoch in range(epochs):
        # Full-batch step; transpose (N_samples, 1) -> (1, N_samples) to match
        # the layout expected by the loss.
        model_h = model(covariates).T
        loss = criterion(survival_times, model_h, observed)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            with torch.no_grad():
                # NOTE: training metrics use the outputs computed before this
                # epoch's optimizer step, validation metrics use the weights
                # after it.
                tr_c_index = c_index_eval(survival_times, model_h, observed)
                model_h_val = model(x_val).T
                val_c_index = c_index_eval(
                    survival_times_val, model_h_val, observed_val
                )
                loss_val = criterion(survival_times_val, model_h_val, observed_val)

                # Store plain floats so the histories can be plotted directly.
                losses.append(loss.item())
                val_losses.append(loss_val.item())
                tr_c_indices.append(tr_c_index.item())
                val_c_indices.append(val_c_index.item())

    if plot:
        plt.plot(tr_c_indices, label="Training")
        plt.plot(val_c_indices, label="Validation")
        plt.xlabel("Epochs (x10)")
        plt.ylabel("C-index")
        plt.legend()
        plt.title("DeepSurv Training C-index")
        plt.show()

        plt.plot(losses, label="Training")
        plt.plot(val_losses, label="Validation")
        plt.xlabel("Epochs (x10)")
        plt.ylabel("Partial Likelihood")
        plt.legend()
        plt.title("DeepSurv Training Partial Likelihoods")
        plt.show()
    return model, losses

In [2]:
# Linear log-risk from the paper's first simulation: h(x) = x_0 + 2 * x_1
simple_data_tr, simple_data_val, simple_data_test = generate_dataset(lambda x: x[:,0] + 2*x[:,1])

# Column 0 is the survival time, column 1 the observed flag, the rest covariates
x = torch.from_numpy(simple_data_tr[:,2:]).float()
T = torch.from_numpy(simple_data_tr[:,0]).reshape(1,-1).float()
observed = torch.from_numpy(simple_data_tr[:,1]).reshape(1,-1).float()

x_val = torch.from_numpy(simple_data_val[:,2:]).float()
T_val = torch.from_numpy(simple_data_val[:,0]).reshape(1,-1).float()
observed_val = torch.from_numpy(simple_data_val[:,1]).reshape(1,-1).float()

x_test = torch.from_numpy(simple_data_test[:,2:]).float()
T_test = torch.from_numpy(simple_data_test[:,0]).reshape(1,-1).float()
observed_test = torch.from_numpy(simple_data_test[:,1]).reshape(1,-1).float()

# Linear CPH
beta, _ = fit_linear_cph(x, T, observed, plot=False)
# The C-index depends only on the ranking of the predictions, so the raw
# log-risk is passed; exponentiating it first adds nothing and risks overflow.
with torch.no_grad():
    pred = beta@x_test.T
    linear_c_index = c_index_eval(T_test, pred, observed_test)

# DeepSurv
model,_ = deep_surv(x, T, observed, x_val, T_val, observed_val, plot=False)
with torch.no_grad():
    pred = model(x_test).T
    deep_surv_c_index = c_index_eval(T_test, pred, observed_test)

print("Linear Simulated Data")
print("Linear CPH C-idx:", linear_c_index)
print("DeepSurv C-idx:", deep_surv_c_index)

Linear Simulated Data
Linear CPH C-idx: tensor(0.7766)
DeepSurv C-idx: tensor(0.7711)


In [3]:
def non_linear_h(x):
    """
    Gaussian-shaped log-risk centred at the origin of the first two covariates

    Parameters
    ----------
    x : np.array
        Covariates, shape = (N_samples, N_covariates); only columns 0 and 1
        are used

    Returns
    -------
    np.array
        log(5) * exp(-(x_0^2 + x_1^2) / (2 * 0.5^2)) for each sample
    """
    peak_hazard_ratio = 5
    radius = 0.5
    squared_dist = x[:, 0] ** 2 + x[:, 1] ** 2
    bump = np.exp(-squared_dist / (2 * radius**2))
    return np.log(peak_hazard_ratio) * bump
simple_data_tr, simple_data_val, simple_data_test = generate_dataset(non_linear_h)

# Column 0 is the survival time, column 1 the observed flag, the rest covariates
x = torch.from_numpy(simple_data_tr[:,2:]).float()
T = torch.from_numpy(simple_data_tr[:,0]).reshape(1,-1).float()
observed = torch.from_numpy(simple_data_tr[:,1]).reshape(1,-1).float()

x_val = torch.from_numpy(simple_data_val[:,2:]).float()
T_val = torch.from_numpy(simple_data_val[:,0]).reshape(1,-1).float()
observed_val = torch.from_numpy(simple_data_val[:,1]).reshape(1,-1).float()

x_test = torch.from_numpy(simple_data_test[:,2:]).float()
T_test = torch.from_numpy(simple_data_test[:,0]).reshape(1,-1).float()
observed_test = torch.from_numpy(simple_data_test[:,1]).reshape(1,-1).float()

# Linear CPH
beta, _ = fit_linear_cph(x, T, observed, plot=False)
# The C-index depends only on the ranking of the predictions, so the raw
# log-risk is passed; exponentiating it first adds nothing and risks overflow.
with torch.no_grad():
    pred = beta@x_test.T
    linear_c_index = c_index_eval(T_test, pred, observed_test)

# DeepSurv
model,_ = deep_surv(x, T, observed, x_val, T_val, observed_val, plot=False)
with torch.no_grad():
    pred = model(x_test).T
    deep_surv_c_index = c_index_eval(T_test, pred, observed_test)

print("Non-linear Simulated Data")
print("Linear CPH C-idx:", linear_c_index)
print("DeepSurv C-idx:", deep_surv_c_index)

Non-linear Simulated Data
Linear CPH C-idx: tensor(0.4870)
DeepSurv C-idx: tensor(0.5958)
