In [1]:
import torch
import numpy as np
from deepreevent import *

In [2]:
# Fix the random seed so that repeated runs of the notebook are comparable.
# NOTE(review): set_seed is pulled in by the `deepreevent` star import; which
# generators it seeds (torch / numpy / random) is not visible here — confirm.
set_seed(42)

In [3]:
# Run on the CUDA GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Load the readmission dataset, requesting its sequential representation.
# NOTE(review): the four returned objects are presumably covariates (x), times (t),
# event indicators (e) and a fourth per-subject array (d) — confirm against
# _load_readmission_dataset. The helper is underscore-prefixed, i.e. nominally
# private to its package.
x, t, e, d = _load_readmission_dataset(sequential=True)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [5]:
# Convert the loaded arrays into tensors. The result is used as a dictionary:
# it is fed to dict.update() below and handed to _prepare_dataloaders later.
readmission_tensor = _prepare_rt_tensors(x, t, e, d)
# Create one notebook-level variable per dictionary entry (presumably this is
# where x_test, t_test, e_test and d_test used by later cells come from).
# NOTE(review): this only works because locals() is the module namespace at
# cell level; it would have no effect inside a function, and it silently
# overwrites any existing variable that shares a key name.
locals().update(readmission_tensor)

  x = torch.tensor(x, dtype=torch.float32)


In [6]:
# --- Model dimensions ---------------------------------------------------------
hidden_size = 8                  # number of units in the RNN layer
input_size = len(x[0])           # length of the first entry of x
output_size = int(max(t))        # largest value found in t, as an int

# --- Training-loop settings ---------------------------------------------------
num_epochs = 10000               # upper bound on the number of epochs
patience = 3                     # epochs to wait for improvement before stopping
best_val_loss = float("inf")     # NOTE(review): not passed to the training call
wait = 0                         # in this notebook; possibly leftover state.
loss_function = recurrent_loss

# --- Model and optimizer ------------------------------------------------------
# Build the network and place it on the selected device in a single step.
# NOTE(review): the trailing positional arguments (1, 0.1) are presumably the
# number of layers and the dropout rate — confirm against SimpleGRU.
model = SimpleGRU(input_size, hidden_size, output_size, 1, 0.1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-7)

In [7]:
# Prepare the data loaders: split the tensor dictionary into train, validation
# and test loaders, requesting batches of 32.

train_dataloader, val_dataloader, test_dataloader = _prepare_dataloaders(readmission_tensor, batch_size=32)

In [8]:
# Train the model on the training loader while monitoring the validation loader.
# num_epochs is only an upper bound: the log below ends with "Early stopping",
# governed by `patience`. print_every=10 yields one log line every 10 epochs.
train_validate_recurrent_model(model, train_dataloader, val_dataloader, loss_function, optimizer, num_epochs, patience, print_every=10)

Epoch: 0   Training loss: 3.0064   Validation loss: 2.8554


Epoch: 10   Training loss: 2.9716   Validation loss: 2.8452
Epoch: 20   Training loss: 2.9328   Validation loss: 2.8378
Epoch: 30   Training loss: 2.8771   Validation loss: 2.8335
Epoch: 40   Training loss: 2.8732   Validation loss: 2.8312
Epoch: 50   Training loss: 2.8580   Validation loss: 2.8306
Epoch: 60   Training loss: 2.8378   Validation loss: 2.8302
Early stopping


In [9]:
# Score the held-out test set.
# Inference must run in evaluation mode and without autograd: otherwise any
# dropout layer in the model stays active (non-deterministic predictions) and a
# gradient graph is built and kept alive for no purpose.
model.eval()
with torch.no_grad():
    test_predictions = model(x_test).squeeze(-1)

# Drop a further trailing singleton dimension, if present (no-op otherwise).
# NOTE(review): the double squeeze is kept from the original; confirm the
# model's output shape actually needs it.
recurrent_predictions = test_predictions.squeeze(-1)

In [10]:
# Evaluate the test-set predictions with recurrent_cindex. As the last expression
# of the cell, its return value is displayed: a dict with one score per key
# 'Time 1' ... 'Time 6'.
# NOTE(review): the exact meaning of `horizons` and `tolerance` is defined by
# recurrent_cindex and is not visible here — confirm before changing them.
recurrent_cindex(recurrent_predictions, e_test, t_test, max_time = 6, horizons = "all", tolerance = 0.1)

{'Time 1': 0.7904290429042904,
 'Time 2': 0.5367063492063492,
 'Time 3': 0.4127382146439318,
 'Time 4': 0.4967391304347826,
 'Time 5': 0.5266524520255863,
 'Time 6': 0.5266524520255863}

In [11]:
from sklearn.utils import resample

# Bootstrap estimate of the variability of the recurrent C-index on the test set.
nsamples = 100  # Number of bootstrap samples
survival_cis = []  # not filled in this cell; kept in case other cells rely on it
recurrent_cis = []  # one metric dict per successful bootstrap replicate
failed_samples = []  # (sample index, error) for replicates whose metric failed

# Inference only: evaluation mode keeps any dropout layer inactive so the
# replicates differ only through resampling, not through random masking.
model.eval()

for sample in range(nsamples):

    # Resample the test data with replacement; passing the arrays together keeps
    # covariates, times, events and d aligned row by row.
    x_test_resampled, t_test_resampled, e_test_resampled, d_test_resampled = resample(
        x_test, t_test, e_test, d_test, replace=True
    )

    # Make predictions on the resampled test data (no gradients needed).
    with torch.no_grad():
        test_predictions = model(x_test_resampled).squeeze(-1)

    recurrent_predictions = test_predictions.squeeze(-1)

    try:
        # Calculate metrics for the resampled test data.
        # NOTE(review): unlike the single evaluation on the full test set, this
        # call relies on recurrent_cindex's defaults for horizons / tolerance —
        # confirm that this is intended.
        recurrent_cis.append(recurrent_cindex(recurrent_predictions, e_test_resampled, t_test_resampled, 6))
    except Exception as err:
        # Best effort: a replicate for which the metric cannot be computed is
        # skipped, but recorded. Catching Exception (not a bare except) lets
        # KeyboardInterrupt / SystemExit still stop the loop.
        failed_samples.append((sample, err))
        continue

if failed_samples:
    _, last_error = failed_samples[-1]
    print(
        f"Skipped {len(failed_samples)} of {nsamples} bootstrap samples; "
        f"last error: {type(last_error).__name__}: {last_error}"
    )

In [12]:
def compute_metric_ci(cis, time_key, rounding = 2):
    """Summarise one metric across a collection of result dictionaries.

    Args:
        cis: Sequence of dict-like results, each holding a numeric value
            under ``time_key``.
        time_key: Key selecting which value to extract from every result.
        rounding: Number of decimal places kept in the returned figures.

    Returns:
        A ``(mean, std)`` tuple of the extracted values, both rounded to
        ``rounding`` decimals. The standard deviation is numpy's default
        (population) one.
    """
    values = np.array([result[time_key] for result in cis])
    mean_value = np.mean(values)
    std_value = np.std(values)
    return round(mean_value, rounding), round(std_value, rounding)

In [13]:
# Mean and standard deviation of the bootstrap C-index at each of the six horizons.
horizon_keys = [f"Time {horizon}" for horizon in range(1, 7)]
horizon_stats = [compute_metric_ci(recurrent_cis, key) for key in horizon_keys]

# Keep the individual per-horizon variables available under their usual names.
(
    (recurrent_cis_1_mean, recurrent_cis_1_std),
    (recurrent_cis_2_mean, recurrent_cis_2_std),
    (recurrent_cis_3_mean, recurrent_cis_3_std),
    (recurrent_cis_4_mean, recurrent_cis_4_std),
    (recurrent_cis_5_mean, recurrent_cis_5_std),
    (recurrent_cis_6_mean, recurrent_cis_6_std),
) = horizon_stats

# Report one line per horizon.
for horizon_key, (ci_mean, ci_std) in zip(horizon_keys, horizon_stats):
    print(f"{horizon_key}: Mean = {ci_mean}, Std = {ci_std}")

Time 1: Mean = 0.82, Std = 0.03
Time 2: Mean = 0.51, Std = 0.07
Time 3: Mean = 0.44, Std = 0.06
Time 4: Mean = 0.46, Std = 0.06
Time 5: Mean = 0.48, Std = 0.07
Time 6: Mean = 0.48, Std = 0.07
