# DRSA

This notebook implements the DRSA (Deep Recurrent Survival Analysis) model for heart-failure patients, using functional covariates.

In [1]:
import warnings

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import scale
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from lifelines.utils import concordance_index
from tqdm import tqdm

from utils.drsa_utils import *
from utils.common import *

In [2]:
# Load the data: `df` is the pool later split into training/validation,
# `test` is the held-out set used only for the final evaluation.
train_path = '../../data/main_process_preprocessed_data.csv'
test_path = '../../data/main_process_preprocessed_data_test.csv'
df = pd.read_csv(train_path)
test = pd.read_csv(test_path)

In [3]:
# Select the compute device: the first CUDA GPU when one is present, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

## Prepare data

In [4]:
# Discretise follow-up time into fixed-width steps over an (arbitrary) horizon.
T_max = 365*10      # horizon in days (10 years)
step_days = 60      # nominal width of one time step in days
n_times = int(T_max/step_days)
# time-step index of each event/censoring time (truncated towards zero, as int() does)
df['discretised_time_event'] = (df.time_event/T_max*n_times).astype(int)
# time-step indices 0..n_times-1 shaped (time, 1, 1) so they can be appended as a feature
times = np.arange(n_times).reshape(-1,1,1)

# Labels outside 0..n_times-1 have no corresponding model output step; flag them
# instead of letting them reach the loss unnoticed.
out_of_range = (df.discretised_time_event < 0) | (df.discretised_time_event >= n_times)
if out_of_range.any():
    warnings.warn(f'{int(out_of_range.sum())} discretised event times fall outside '
                  f'[0, {n_times - 1}]; consider increasing T_max or censoring at the horizon')

print('(Arbitrary) maximum survival time:',np.round(T_max/365,1),'years')
print('Number of time steps:',n_times)
print('Size time step:',np.round(T_max/n_times/30,1),'months')

(Arbitrary) maximum survival time: 10.0 years
Number of time steps: 60
Size time step: 2.0 months


In [None]:
# Binary indicator for male sex (1 when sex == 'M', otherwise 0) in both data sets.
for frame in (df, test):
    frame['sexM'] = (frame['sex'] == 'M').astype(int)

In [None]:
# Covariates fed to the network: sex, age and the *_PC1 / *_PC2 columns
# (presumably principal-component scores of the functional covariates).
features = [
    'sexM', 'age_in',
    'ACE_PC1', 'ACE_PC2',
    'beta_PC1', 'beta_PC2',
    'aldosteronics_PC1', 'aldosteronics_PC2',
    'hospitalisation_PC1', 'hospitalisation_PC2',
]

X = df[features]
X_test = test[features]
y = df['discretised_time_event']    # training target: time-step index
true_times = df['time_event']       # time on the original scale
status = df['status']               # event indicator

In [None]:
# Standardise covariates with the mean/std of the full training pool (df, before
# the train/validation split) and apply the same transform to the test set.
# Rebuilt from the source frames instead of modified in place, so re-running the
# cell never rescales already-scaled data or writes into a sliced frame.
mean = df[features].mean()
std = df[features].std()
# a constant column has std 0; dividing by it would yield NaN/inf, so leave it unscaled
std = std.mask(std == 0, 1.0)
X = (df[features] - mean) / std
X_test = (test[features] - mean) / std

In [None]:
# Hold out 15% of the training pool as a validation set, used for early stopping.
(X_train, X_valid,
 y_train, y_valid,
 status_train, status_valid) = train_test_split(
    X, y, status,
    test_size=0.15,
    random_state=47,
)

In [None]:
# Sanity checks: the split partitions X, and every set carries the same covariates.
assert len(X_train) + len(X_valid) == len(X)
assert X_train.shape[1] == X_valid.shape[1] == X_test.shape[1] == len(features)
X_train.shape,X_valid.shape,X_test.shape

((2702, 10), (477, 10), (1362, 10))

## Reshape datasets for the LSTM layers (time × patient × features)

In [None]:
def to_sequences(frame, time_index):
    """Repeat every row of `frame` over all time steps and append the step index.

    Parameters
    ----------
    frame : pandas.DataFrame
        One row per patient, one column per covariate.
    time_index : numpy.ndarray
        Time-step indices; reshaped internally to (n_steps, 1, 1).

    Returns
    -------
    numpy.ndarray
        Array of shape (n_steps, n_rows, n_covariates + 1), laid out as
        (time, id, features), with the time-step index as the last feature.
    """
    n_steps = len(time_index)
    values = frame.to_numpy()
    # the same covariates at every time step: (n_steps, n_rows, n_covariates)
    covariates = np.repeat(values[np.newaxis, :, :], n_steps, axis=0)
    # the step index, identical for every row: (n_steps, n_rows, 1)
    steps = np.broadcast_to(np.reshape(time_index, (n_steps, 1, 1)),
                            (n_steps, values.shape[0], 1))
    return np.concatenate([covariates, steps], axis=2)


x_training = to_sequences(X_train, times)
x_validation = to_sequences(X_valid, times)
x_test = to_sequences(X_test, times)

In [None]:
# Shapes as (time steps, patients, covariates + appended time index).
tuple(arr.shape for arr in (x_training, x_validation, x_test))

((60, 2702, 11), (60, 477, 11), (60, 1362, 11))

In [None]:
# Validation inputs, labels and event indicators as tensors on `device` (used for early stopping).
X_valid = torch.from_numpy(x_validation.astype('float32')).to(device)
y_valid = torch.from_numpy(y_valid.values).to(device)
status_valid = torch.from_numpy(status_valid.values.astype('float32')).to(device)

# Test inputs as a float32 tensor on `device` (used for the final evaluation).
X_test = torch.from_numpy(x_test.astype('float32')).to(device)

## Prepare for training

In [None]:
# Early-stopping patience: training stops once the number of consecutive epochs
# without a lower validation loss exceeds this value.
max_epochs_no_improvement = 150

# Shuffled mini-batch loader over the training sequences.
# NOTE(review): y_train / status_train are pandas Series whose index is a
# shuffled subset of df.index; confirm SequenceDataset indexes them
# positionally (e.g. via .values / .iloc) rather than by label.
trainds = SequenceDataset(x_training, y_train, status_train)
params = dict(batch_size=50, shuffle=True, num_workers=4)
train_dl = DataLoader(trainds, **params)

## Train

In [None]:
# set criterion: DRSA_Loss comes from the star-imported utils modules; below it is
# called as criterion(outputs, labels, status) with outputs laid out as (id, time)
criterion = DRSA_Loss()

In [None]:
# Seed torch so weight initialisation (and the DataLoader shuffling, which draws
# from the same global generator) is reproducible on Restart & Run All;
# 47 matches the train/validation split seed.
torch.manual_seed(47)

# build the network: input size = covariates + the appended time index
net = DRSA(len(features)+1).to(device)

# set optimizer
optimizer = optim.Adam(net.parameters(), lr=1e-4)

In [None]:
# Early-stopping training loop. `best_net` keeps a copy of the weights with the
# lowest validation loss; training stops once the number of consecutive epochs
# without improvement exceeds `max_epochs_no_improvement`.
max_epochs = 5000  # hard upper bound on the number of passes over the data
train_losses, valid_losses = [], []
best_net = DRSA(len(features)+1).to(device)
best_loss = float('inf')  # any finite validation loss counts as an improvement
epochs_no_improvement = 0

for epoch in range(max_epochs):
    net.train()
    running_loss = 0.0  # sum of the mini-batch losses of this epoch

    # `status_batch` (not `status`) so the global event-indicator Series used by
    # the train/validation split is not overwritten by the last batch
    for inputs, labels, status_batch in train_dl:
        # batches arrive as (id, time, features); the network expects (time, id, features)
        inputs = inputs.transpose(0, 1)
        inputs = inputs.to(device)
        labels = labels.to(device)
        status_batch = status_batch.to(device)

        optimizer.zero_grad()
        # network output rearranged to rows: id; columns: time
        outputs = net(inputs).reshape(n_times, -1).transpose(0, 1)
        loss = criterion(outputs, labels, status_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # validation loss; no_grad avoids building an autograd graph that is never used
    net.eval()
    with torch.no_grad():
        valid_pred = net(X_valid).reshape(n_times, -1).transpose(0, 1)
        valid_loss = criterion(valid_pred, y_valid, status_valid).item()

    train_losses.append(running_loss)
    valid_losses.append(valid_loss)

    # early stopping bookkeeping
    if valid_loss < best_loss:
        best_loss = valid_loss
        best_net.load_state_dict(net.state_dict())
        epochs_no_improvement = 0
    else:
        epochs_no_improvement += 1

    if epochs_no_improvement > max_epochs_no_improvement:
        break

    if epoch % 25 == 0:
        print('Epoch: ',epoch,'   Training Loss: ',np.round(running_loss,2),'   Current Validation Loss: ',np.round(valid_loss,2),'   Best Validation Loss: ',np.round(best_loss,2))

print('Finished Training')

Epoch:  0    Training Loss:  20905.15    Current Validation Loss:  3036.52    Best Validation Loss:  3036.52


In [None]:
# Loss curves side by side. The training value is the sum of the mini-batch
# losses of an epoch, while the validation value is a single loss over the whole
# validation set, so the two are on different scales and are plotted separately.
fig, (ax_train, ax_valid) = plt.subplots(1, 2, figsize=(10, 5))

ax_train.plot(train_losses, color='orange', label='training')
ax_train.set(title='Survival loss', xlabel='epochs', ylabel='loss (sum over batches)')
ax_train.legend()

ax_valid.plot(valid_losses, color='blue', label='validation')
ax_valid.set(title='Survival loss', xlabel='epochs', ylabel='loss')
ax_valid.legend()

plt.show()

## Evaluate the concordance index on the test set

Note: predictions are converted back to the original time scale and compared with the true event times, so that the resulting concordance index is comparable with the other models.

In [None]:
# Per-step hazards for the test set from the early-stopped weights, laid out as
# (id, time). No backward pass is needed, so run without tracking gradients.
best_net.eval()
with torch.no_grad():
    prediction_test_set = best_net(X_test).reshape(n_times,-1).transpose(0,1)

In [None]:
# Turn per-step hazards h_j into the probability that the event falls in step j:
#   p_0 = h_0,   p_j = h_j * prod_{k<j} (1 - h_k)
hazards = prediction_test_set.detach()
survival = (1 - hazards).cumprod(dim=1)
# survival up to the *previous* step; 1 for the first step
survival_prev = torch.cat([torch.ones_like(survival[:, :1]), survival[:, :-1]], dim=1)
# .cpu() so the NumPy conversion also works when the model ran on the GPU
prediction_test = (hazards * survival_prev).cpu().numpy()

In [None]:
# Expected survival time per test patient, on the original time scale, so it can be
# compared with test.time_event.
# NOTE(review): this passes the raw per-step hazards `prediction_test_set`, while the
# previous cell derives event-time probabilities `prediction_test` that are otherwise
# unused. Confirm which of the two compute_expected_survival_time expects.
expected_survival_times = compute_expected_survival_time(prediction_test_set,n_times,T_max)

In [None]:
# Concordance index on the test set. lifelines treats larger predicted scores as
# longer survival, which matches passing expected survival times directly.
C = concordance_index(
    event_times=test.time_event,
    predicted_scores=expected_survival_times,
    event_observed=test.status,
)

print('Concordance Index on test set:', np.round(C * 100, 2), '%')

In [None]:
# Persist the early-stopped weights. Only the state_dict is saved, so reloading
# requires rebuilding the network first, e.g. DRSA(len(features)+1).
weights_path = '../../data/DRSA_weights.pt'
torch.save(best_net.state_dict(), weights_path)