# Recurrent DSM on PBC Dataset

The longitudinal PBC dataset comes from the Mayo Clinic trial in primary biliary cirrhosis (PBC) of the liver, conducted between 1974 and 1984 (see https://stat.ethz.ch/R-manual/R-devel/library/survival/html/pbc.html).

In this notebook, we apply Recurrent Deep Survival Machines (RDSM) to survival prediction on the PBC data.

### Load the PBC Dataset

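auton_survival includes a helper function that loads the dataset. With `sequential=True` the records are grouped by patient: `x` is a list of NumPy arrays of features (covariates), one array per patient with one row per visit; `t` holds the corresponding event/censoring times; and `e` is the event indicator (1 if the event was observed, 0 if censored).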
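To make the sequential layout concrete, here is a toy sketch with made-up numbers (not PBC data): two hypothetical patients with three and two visits, each visit carrying two features. The names `x_toy`, `t_toy` and `e_toy` are illustrative only; they mimic the parallel per-patient lists returned by the loader.

```python
import numpy as np

# Hypothetical toy data in the same layout as the sequential loader output:
# one array per patient, one row (or value) per recorded visit.
x_toy = [
    np.array([[0.1, 1.2], [0.3, 1.0], [0.4, 0.9]]),  # patient 0: 3 visits, 2 features
    np.array([[-0.5, 0.2], [-0.4, 0.1]]),            # patient 1: 2 visits, 2 features
]
t_toy = [np.array([5.0, 3.5, 1.0]), np.array([8.0, 6.0])]  # event/censoring time per visit
e_toy = [np.array([1, 1, 1]), np.array([0, 0])]            # 1 = event observed, 0 = censored

for pid, (xs, ts, es) in enumerate(zip(x_toy, t_toy, e_toy)):
    # Every per-patient array must have exactly one entry per visit.
    assert xs.shape[0] == ts.shape[0] == es.shape[0]
    print(f"patient {pid}: {xs.shape[0]} visits, {xs.shape[1]} features, final status {es[-1]}")
```

The real PBC arrays follow the same pattern, with 25 feature columns per visit.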

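```python
import sys

# Make the local auton_survival checkout importable (the notebook sits one level below the repo root).
sys.path.append('../')
from auton_survival import datasets

# sequential=True returns one array per patient, with one row per visit.
x, t, e = datasets._load_pbc_dataset(sequential=True)
```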

```python
x
```

```
[array([[ 0.99436038, -0.99436038,  0.3725034 , -0.3725034 , -2.73741663,
          3.24173962, -0.97913787,  1.0425497 , -1.43799282,  1.54166667,
         -1.6047952 ,  3.28449064, -0.49195313, -0.22660837, -0.39802969,
         -0.67758009,  1.00051427,  2.0158767 , -0.46946083, -1.5706465 ,
          0.28561281,  0.19548755, -0.45602232,  0.81313216,  0.2480577 ],
        [ 0.99436038, -0.99436038,  0.3725034 , -0.3725034 , -2.73741663,
          3.24173962, -0.97913787,  1.0425497 , -1.43799282,  1.54166667,
         -1.6047952 ,  3.28449064, -0.49195313, -0.22660837, -0.39802969,
         -0.67758009,  1.00051427,  3.28188992,  0.        , -0.89457526,
          0.19553238, -1.48526305, -0.529101  ,  0.13676811,  0.2480577 ]]),
 array([[ 9.94360375e-01, -9.94360375e-01,  3.72503399e-01,
         -3.72503399e-01,  3.65307930e-01, -3.08476349e-01,
         -9.79137874e-01,  1.04254970e+00, -1.43799282e+00,
          1.54166667e+00,  6.23132470e-01, -3.04461211e-01,
         -4.9195
```

### Compute horizons at which we evaluate the performance of RDSM

Survival predictions are issued at specific time horizons. Here we evaluate how well RDSM predicts at the 25th, 50th and 75th quantiles of the event times (taken from the last record of each patient who experienced the event), as is standard practice in survival analysis.

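```python
import numpy as np

horizons = [0.25, 0.5, 0.75]
# Quantiles of the last recorded time of every patient whose final record is an observed event
times = np.quantile([t_[-1] for t_, e_ in zip(t, e) if e_[-1] == 1], horizons).tolist()
```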
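The horizon computation can be wrapped in a small helper and checked on toy data. `event_time_horizons` is a hypothetical name that reproduces the one-liner above. Note that `np.quantile` interpolates linearly by default, which is why the toy horizons fall between the two observed event times.

```python
import numpy as np

def event_time_horizons(t, e, quantiles):
    """Quantiles of each patient's last recorded time, among patients whose last record is an event."""
    last_event_times = [t_[-1] for t_, e_ in zip(t, e) if e_[-1] == 1]
    return np.quantile(last_event_times, quantiles).tolist()

# Hypothetical toy data: patients 0 and 2 end in an event, patient 1 is censored.
t_toy = [np.array([5.0, 3.0, 1.0]), np.array([4.0, 2.0]), np.array([6.0, 0.5])]
e_toy = [np.array([1, 1, 1]), np.array([0, 0]), np.array([1, 1])]

print(event_time_horizons(t_toy, e_toy, [0.25, 0.5, 0.75]))  # [0.625, 0.75, 0.875]
```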

### Splitting the data into train, test and validation sets

We train RDSM on 60% of the patients, use a 20% validation set for model selection, and report performance on the remaining 20% held-out test set. The split is made at the patient level, so all visits of a patient land in the same set.

```python
n = len(x)

# Alternative (not used below): a sequential 70/10/20 split by position in the list.
# tr_size = int(n*0.70)
# vl_size = int(n*0.10)
# te_size = int(n*0.20)
#
# x_train, x_test, x_val = np.array(x[:tr_size], dtype = object), np.array(x[-te_size:], dtype = object), np.array(x[tr_size:tr_size+vl_size], dtype = object)
# t_train, t_test, t_val = np.array(t[:tr_size], dtype = object), np.array(t[-te_size:], dtype = object), np.array(t[tr_size:tr_size+vl_size], dtype = object)
# e_train, e_test, e_val = np.array(e[:tr_size], dtype = object), np.array(e[-te_size:], dtype = object), np.array(e[tr_size:tr_size+vl_size], dtype = object)
```


```python
from sklearn.model_selection import train_test_split

# First split: train_test_split returns (first part, second part) for each array, and
# test_size=0.60 puts 60% of the patients in the second part, which we keep as the training set.
# The remaining 40% becomes a temporary set that is split again below.
x_temp, x_train, t_temp, t_train, e_temp, e_train = train_test_split(x, t, e, test_size=0.60, random_state=42)

# Second split: halve the temporary set into validation and test sets (20% of the original data each)
x_val, x_test, t_val, t_test, e_val, e_test = train_test_split(x_temp, t_temp, e_temp, test_size=0.5, random_state=42)

# Store the ragged per-patient sequences as 1-D object arrays
x_train = np.array(x_train, dtype=object)
x_val = np.array(x_val, dtype=object)
x_test = np.array(x_test, dtype=object)

t_train = np.array(t_train, dtype=object)
t_val = np.array(t_val, dtype=object)
t_test = np.array(t_test, dtype=object)

e_train = np.array(e_train, dtype=object)
e_val = np.array(e_val, dtype=object)
e_test = np.array(e_test, dtype=object)
```
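Converting with `np.array(..., dtype=object)` yields a 1-D array of sequences as long as the patients in a split do not all have the same number of visits. If they did, NumPy would stack them into a 3-D array instead. A minimal sketch of a more defensive conversion follows; `to_object_array` is a hypothetical helper, not part of auton_survival.

```python
import numpy as np

def to_object_array(seqs):
    """Pack a list of per-patient arrays into a 1-D object array, whatever their shapes."""
    out = np.empty(len(seqs), dtype=object)
    for i, s in enumerate(seqs):
        out[i] = s  # store each patient's array as a single element
    return out

# With equal-length sequences, np.array(..., dtype=object) stacks them into a 3-D array
# instead of a 1-D array of sequences.
same_len = [np.zeros((2, 3)), np.ones((2, 3))]
print(np.array(same_len, dtype=object).shape)  # (2, 2, 3)
print(to_object_array(same_len).shape)         # (2,)
```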


### Setting the parameter grid

Let's set up the parameter grid for hyperparameter tuning. We tune the number of underlying survival distributions $K$ (3, 4 or 6), the distribution family (Log-Normal or Weibull), the learning rate of the Adam optimizer ($1\times10^{-3}$ or $1\times10^{-4}$), the number of hidden nodes per layer (50 or 100), the number of layers (3, 2 or 1) and the type of recurrent cell (LSTM, GRU or RNN).

```python
from sklearn.model_selection import ParameterGrid
```

```python
param_grid = {'k' : [3, 4, 6],
              'distribution' : ['LogNormal', 'Weibull'],
              'learning_rate' : [1e-4, 1e-3],
              'hidden': [50, 100],
              'layers': [3, 2, 1],
              'typ': ['LSTM', 'GRU', 'RNN'],
             }
params = ParameterGrid(param_grid)
```
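The grid is a full Cartesian product, so the training loop fits one model per combination: 3 × 2 × 2 × 2 × 3 × 3 = 216 configurations. A standard-library sketch of the same enumeration follows; it covers the same combinations as `ParameterGrid`, although `ParameterGrid` may iterate over them in a different order.

```python
from itertools import product

# Same grid as above, enumerated with the standard library to count configurations.
param_grid = {'k': [3, 4, 6],
              'distribution': ['LogNormal', 'Weibull'],
              'learning_rate': [1e-4, 1e-3],
              'hidden': [50, 100],
              'layers': [3, 2, 1],
              'typ': ['LSTM', 'GRU', 'RNN']}

# One dict per combination, keyed like the grid itself
configs = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]
print(len(configs))  # 216
```

With a single training epoch per configuration this is cheap, but the cost scales linearly with both the grid size and the number of training iterations.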

### Model Training and Selection

```python
from auton_survival.models.dsm import DeepRecurrentSurvivalMachines
```

```python
models = []
for param in params:
    model = DeepRecurrentSurvivalMachines(k=param['k'],
                                          distribution=param['distribution'],
                                          hidden=param['hidden'],
                                          typ=param['typ'],
                                          layers=param['layers'])
    # Train the model; iters=1 runs a single training epoch, so increase it for a meaningful tuning run.
    model.fit(x_train, t_train, e_train, iters=1, learning_rate=param['learning_rate'])
    # Score each candidate by its negative log-likelihood (NLL) on the validation set.
    models.append((model.compute_nll(x_val, t_val, e_val), model))

# Keep the model with the lowest validation NLL; key= avoids comparing model objects on ties.
best_nll, model = min(models, key=lambda pair: pair[0])
```
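Selecting with a plain `min` over `[nll, model]` pairs compares the model objects whenever two NLLs tie, which raises a `TypeError`, and a diverged run that returns `nan` makes the result depend on list order. A minimal sketch of a more robust selection step, assuming `compute_nll` returns a plain float; `select_best` is a hypothetical helper, and skipping non-finite scores is our own choice rather than auton_survival behaviour.

```python
import math

def select_best(scored_models):
    """Return the (score, model) pair with the lowest finite validation NLL."""
    finite = [(score, m) for score, m in scored_models if math.isfinite(score)]
    if not finite:
        raise ValueError("no candidate produced a finite validation NLL")
    # key= compares scores only, so models never need to be orderable.
    return min(finite, key=lambda pair: pair[0])

# Hypothetical candidates: plain objects stand in for fitted models.
a, b, c = object(), object(), object()
best_score, best_model = select_best([(2.5, a), (float("nan"), b), (1.7, c)])
print(best_score, best_model is c)  # 1.7 True
```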


### Inference

```python
# Predicted risk and survival probabilities at each evaluation horizon
out_risk = model.predict_risk(x_test, times)
out_survival = model.predict_survival(x_test, times)
```
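The evaluation flattens the test labels into one row per visit, which suggests that for this recurrent model `predict_risk` and `predict_survival` return one row per visit and one column per horizon. Assuming that layout (patients in order, visits in their original order), the rows can be regrouped per patient, for instance to inspect how one patient's predicted risk evolves over follow-up. `split_by_patient` is a hypothetical helper, shown here on made-up numbers.

```python
import numpy as np

def split_by_patient(flat_predictions, sequences):
    """Split visit-level prediction rows back into one block per patient (assumed row order)."""
    lengths = [len(s) for s in sequences]
    assert sum(lengths) == len(flat_predictions), "expected one prediction row per visit"
    # Boundaries between patients are the running totals of their visit counts.
    return np.split(flat_predictions, np.cumsum(lengths)[:-1])

# Hypothetical example: 2 patients with 3 and 2 visits, predictions at 3 horizons.
t_toy = [np.array([5.0, 3.5, 1.0]), np.array([8.0, 6.0])]
flat = np.arange(15, dtype=float).reshape(5, 3)  # 5 visits x 3 horizons
per_patient = split_by_patient(flat, t_toy)
print([p.shape for p in per_patient])  # [(3, 3), (2, 3)]
```

On the real data this would be called as `split_by_patient(out_risk, t_test)`, provided the row-order assumption holds.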

### Evaluation

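We evaluate RDSM's discriminative ability with the time-dependent concordance index and the cumulative/dynamic AUC, and its overall predictive accuracy with the Brier score. All metrics are computed on the concatenated temporal data, that is, with one row per visit across all test patients.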
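For reference, `sksurv.metrics.brier_score` implements the inverse-probability-of-censoring weighted (IPCW) Brier score of Graf et al. At a horizon $\tau$, with predicted survival $\hat S(\tau \mid x_i)$ and the censoring survival function $\hat G$ estimated by Kaplan-Meier from the training labels, it averages over the $n$ test rows (here, visits):

$$
\mathrm{BS}(\tau) = \frac{1}{n}\sum_{i=1}^{n}\left[\frac{\hat S(\tau \mid x_i)^2\,\mathbb{1}(T_i \le \tau,\ E_i = 1)}{\hat G(T_i)} + \frac{\bigl(1 - \hat S(\tau \mid x_i)\bigr)^2\,\mathbb{1}(T_i > \tau)}{\hat G(\tau)}\right]
$$

Rows censored before $\tau$ contribute no squared error directly; the $1/\hat G$ weights compensate for them. Lower values are better.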

```python
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc
```

```python
def to_structured(e_seq, t_seq):
    # Flatten per-patient sequences into one sksurv structured array with one row per visit.
    return np.array([(e_ij, t_ij) for e_i, t_i in zip(e_seq, t_seq) for e_ij, t_ij in zip(e_i, t_i)],
                    dtype=[('e', bool), ('t', float)])

et_train = to_structured(e_train, t_train)  # used to estimate the censoring distribution
et_test = to_structured(e_test, t_test)

# Time-dependent concordance (IPCW) and cumulative/dynamic AUC at each horizon
cis = [concordance_index_ipcw(et_train, et_test, out_risk[:, i], times[i])[0] for i in range(len(times))]
roc_auc = [cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], times[i])[0][0] for i in range(len(times))]
# brier_score returns (times, scores); keep the scores, one per horizon
brs = brier_score(et_train, et_test, out_survival, times)[1]

for i, horizon in enumerate(horizons):
    print(f"For {horizon} quantile,")
    print("TD Concordance Index:", cis[i])
    print("Brier Score:", brs[i])
    print("ROC AUC ", roc_auc[i], "\n")
```

```
For 0.25 quantile,
TD Concordance Index: 0.6703919532144654
Brier Score: 0.008237089508192411
ROC AUC  0.6683866663300119 

For 0.5 quantile,
TD Concordance Index: 0.6222677952894667
Brier Score: 0.0320714225538293
ROC AUC  0.6235024593635348 

For 0.75 quantile,
TD Concordance Index: 0.5695961585521813
Brier Score: 0.07434723192686321
ROC AUC  0.5665689418746567 
```
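To help interpret the TD concordance values, here is a simplified, unweighted sketch of concordance truncated at a horizon $\tau$: among pairs where row $i$ has an observed event at or before $\tau$ and row $j$ has a later time, it counts how often $i$ received the higher predicted risk (0.5 means random ordering). `concordance_index_ipcw` additionally reweights pairs by the inverse probability of censoring (Uno's estimator), so this sketch will not reproduce its values exactly; `truncated_concordance` is a hypothetical name.

```python
def truncated_concordance(event, time, risk, tau):
    """Unweighted concordance over pairs (i, j) where i has an observed event at time <= tau
    and j has a later time. Ties in predicted risk count as half-concordant."""
    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        if not event[i] or time[i] > tau:
            continue  # i must have an observed event within the horizon
        for j in range(len(time)):
            if time[j] > time[i]:  # j is known to be event-free when i has the event
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable if comparable else float("nan")

# Hypothetical toy data: earlier events received higher risk, so every comparable pair is concordant.
event = [1, 1, 0, 1]
time = [1.0, 2.0, 3.0, 4.0]
risk = [0.9, 0.7, 0.5, 0.1]
print(truncated_concordance(event, time, risk, tau=4.0))  # 1.0
```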

