In [16]:
import torch
import numpy as np
from deepreevent import *

In [17]:
# Fix the seed (42) through the project helper before any data handling or
# model construction takes place.
# NOTE(review): set_seed arrives via the `deepreevent` star import; which RNGs it
# seeds (Python / NumPy / torch) is not visible here — confirm before relying on
# it for reproducibility of later random steps.
set_seed(42)

In [18]:
# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [19]:
# Load the readmission data in its sequential form (sequential=True).
# The four returned objects are consumed further down: len(x[0]) becomes the
# model input size, int(max(t)) becomes the output size, and all four are handed
# to the tensor-preparation step.
# NOTE(review): presumably x = covariates, t = times, e = recurrent-event
# indicators, d = terminal-event indicators — confirm against the loader.
x, t, e, d = _load_readmission_dataset(sequential=True)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [20]:
# Build the tensors for this dataset. The result is kept under one name because
# the same object is also passed to _prepare_dataloaders.
readmission_tensor = _prepare_rt_tensors(x, t, e, d)
# Every entry of the dictionary becomes a notebook-level variable. Names used by
# later cells but never assigned explicitly (x_test, t_test, e_test, d_test,
# t_train, d_train) appear to be created only by this line.
# NOTE(review): this relies on locals() being the cell/module namespace; the same
# statement inside a function would not reliably create variables.
locals().update(readmission_tensor) # create variables from dictionary

In [21]:
# Build the train / validation / test loaders from the tensor dictionary, using a
# batch size of 32.
# NOTE(review): whether the helper shuffles or moves data to a device is not
# visible here — confirm in _prepare_dataloaders.
train_dataloader, val_dataloader, test_dataloader = _prepare_dataloaders(readmission_tensor, batch_size=32)

In [22]:
# --- Training configuration ---
num_epochs = 10000            # epoch budget handed to the training call
patience = 10                 # epochs to wait for improvement before stopping
best_val_loss = float('inf')  # starting value: any finite loss is lower
loss_function = recurrent_terminal_loss

# --- Model dimensions ---
hidden_size = 6               # number of units in the RNN layer
input_size = len(x[0])        # length of the first entry of x
output_size = int(max(t))     # largest value found in t, as an integer

# --- Model and optimizer ---
# NOTE(review): the trailing positional arguments (2, 0) are presumably the number
# of recurrent layers and the dropout rate — confirm against SimpleRNN's signature.
model = SimpleRNN(input_size, hidden_size, output_size, 2, 0).to(device)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-8, lr=1e-4)

In [23]:
# Fit the model using the training and validation loaders, the chosen loss and
# optimizer, and the epoch budget / patience set above. With print_every=10 the
# helper prints the training and validation loss every 10 epochs (see output).
# NOTE(review): early stopping driven by `patience` is presumed to happen inside
# the helper; nothing is returned or captured here.
train_validate_rt_model(model, train_dataloader, val_dataloader, loss_function, optimizer, num_epochs, patience, print_every=10)

Epoch: 0   Training loss: 4.7314   Validation loss: 4.6558


Epoch: 10   Training loss: 4.6460   Validation loss: 4.5538
Epoch: 20   Training loss: 4.5211   Validation loss: 4.4522
Epoch: 30   Training loss: 4.4058   Validation loss: 4.3501
Epoch: 40   Training loss: 4.3129   Validation loss: 4.2480
Epoch: 50   Training loss: 4.2165   Validation loss: 4.1460
Epoch: 60   Training loss: 4.1086   Validation loss: 4.0459
Epoch: 70   Training loss: 4.0297   Validation loss: 3.9490
Epoch: 80   Training loss: 3.9518   Validation loss: 3.8570
Epoch: 90   Training loss: 3.8715   Validation loss: 3.7716
Epoch: 100   Training loss: 3.7981   Validation loss: 3.6942
Epoch: 110   Training loss: 3.7192   Validation loss: 3.6250
Epoch: 120   Training loss: 3.6783   Validation loss: 3.5645
Epoch: 130   Training loss: 3.6387   Validation loss: 3.5119
Epoch: 140   Training loss: 3.5785   Validation loss: 3.4657
Epoch: 150   Training loss: 3.5345   Validation loss: 3.4261
Epoch: 160   Training loss: 3.5015   Validation loss: 3.3923
Epoch: 170   Training loss: 3.480

In [24]:
# Score the held-out split. This is inference only, so switch the network to
# evaluation mode and disable autograd: no computation graph is built and the
# resulting tensors do not carry gradient history.
model.eval()
with torch.no_grad():
    test_predictions = model(x_test).squeeze(-1)

# Index 0 of the third axis feeds the survival metric and index 1 the
# recurrent-event metric; slicing then squeezing drops that axis again.
survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)

In [25]:
# Evaluate the survival predictions on the test split. The helper receives the
# training-split d / t in addition to the test-split d / t and returns a dict
# keyed by horizon ('Time 3' ... 'Time 6' in the output below).
# NOTE(review): the training arguments are presumably used for censoring
# weights — confirm in survival_cindex.
survival_cindex(survival_predictions, d_train, t_train, d_test, t_test)

{'Time 3': 0.9697368421052632,
 'Time 4': 0.9749497470720346,
 'Time 5': 0.9749497470720346,
 'Time 6': 0.9749497470720346}

In [26]:
# Evaluate the recurrent-event predictions on the test split for every horizon
# (horizons = "all") with max_time = 6; the result is a dict keyed 'Time 1' ...
# 'Time 6' (see output below).
# NOTE(review): the meaning of tolerance = 0.1 is not visible here, and 6 is
# hard-coded rather than derived from the data — confirm both against the helper.
recurrent_cindex(recurrent_predictions, e_test, t_test, max_time = 6, horizons = "all", tolerance = 0.1)

{'Time 1': 0.812981298129813,
 'Time 2': 0.6587301587301587,
 'Time 3': 0.4974924774322969,
 'Time 4': 0.5440217391304348,
 'Time 5': 0.576226012793177,
 'Time 6': 0.576226012793177}

In [27]:
from sklearn.utils import resample

nsamples = 200  # Number of bootstrap samples
survival_cis = []
recurrent_cis = []
n_failed = 0       # bootstrap draws on which a metric could not be computed
last_error = None  # most recent failure, kept so it can be reported

# Inference only: evaluation mode here, and no autograd graph inside the loop.
model.eval()

for _ in range(nsamples):

    # Draw a bootstrap replicate of the test split: sampling with replacement,
    # with the same rows selected from all four inputs.
    x_test_resampled, t_test_resampled, e_test_resampled, d_test_resampled = resample(
        x_test, t_test, e_test, d_test, replace=True
    )

    # Predict on the replicate. Dedicated names are used so the point-estimate
    # variables (test_predictions, survival_predictions, recurrent_predictions)
    # computed on the full test split are not overwritten.
    with torch.no_grad():
        predictions_resampled = model(x_test_resampled).squeeze(-1)

    survival_resampled = predictions_resampled[:, :, 0:1].squeeze(-1)
    recurrent_resampled = predictions_resampled[:, :, 1:2].squeeze(-1)

    # Compute both metrics before storing either, so the two result lists are
    # always built from the same set of replicates.
    # NOTE(review): the point-estimate cell calls recurrent_cindex with
    # horizons = "all" and tolerance = 0.1, whereas here the helper's defaults
    # apply — confirm that both evaluations are meant to use the same settings.
    try:
        survival_ci = survival_cindex(
            survival_resampled, d_train, t_train, d_test_resampled, t_test_resampled
        )
        recurrent_ci = recurrent_cindex(
            recurrent_resampled, e_test_resampled, t_test_resampled, 6
        )
    except Exception as exc:
        # A replicate can be unusable for a metric; skip it, but keep count.
        # `Exception` (not a bare except) lets KeyboardInterrupt stop the loop.
        n_failed += 1
        last_error = exc
        continue

    survival_cis.append(survival_ci)
    recurrent_cis.append(recurrent_ci)

# Make skipped replicates visible instead of silently shrinking the sample.
if n_failed:
    print(f"Skipped {n_failed} of {nsamples} bootstrap samples; last error: {last_error!r}")


In [28]:
def compute_metric_ci(cis, time_key, rounding = 2):
    """Summarise the values stored under one key across a list of result dicts.

    Args:
        cis: Sequence of mappings; each one must contain ``time_key``.
        time_key: Key whose values are pooled, e.g. ``"Time 3"``.
        rounding: Number of decimal places kept in both outputs.

    Returns:
        ``(mean, std)`` of the pooled values, each rounded to ``rounding``
        places. The spread is ``np.std`` with its default settings (population
        standard deviation, ``ddof=0``). Despite the function name, no
        confidence interval is computed.
    """
    pooled = np.asarray([entry[time_key] for entry in cis])
    return round(np.mean(pooled), rounding), round(np.std(pooled), rounding)

In [29]:
# Bootstrap summary of the survival metric: (mean, std) per evaluation horizon,
# computed in one pass over the horizon keys and kept in a dict for later use.
survival_ci_summary = {
    time_key: compute_metric_ci(survival_cis, time_key)
    for time_key in ("Time 3", "Time 4", "Time 5", "Time 6")
}

# One line per horizon, e.g. "Time 3: Mean = 0.97, Std = 0.02".
for time_key, (horizon_mean, horizon_std) in survival_ci_summary.items():
    print(f"{time_key}: Mean = {horizon_mean}, Std = {horizon_std}")

Time 3: Mean = 0.97, Std = 0.02
Time 4: Mean = 0.98, Std = 0.01
Time 5: Mean = 0.98, Std = 0.01
Time 6: Mean = 0.98, Std = 0.01


In [30]:
# Bootstrap summary of the recurrent-event metric: (mean, std) per evaluation
# horizon, computed in one pass over "Time 1" ... "Time 6" and kept in a dict.
recurrent_ci_summary = {
    f"Time {horizon}": compute_metric_ci(recurrent_cis, f"Time {horizon}")
    for horizon in range(1, 7)
}

# One line per horizon, e.g. "Time 1: Mean = 0.82, Std = 0.04".
for time_key, (horizon_mean, horizon_std) in recurrent_ci_summary.items():
    print(f"{time_key}: Mean = {horizon_mean}, Std = {horizon_std}")

Time 1: Mean = 0.82, Std = 0.04
Time 2: Mean = 0.69, Std = 0.05
Time 3: Mean = 0.5, Std = 0.06
Time 4: Mean = 0.48, Std = 0.06
Time 5: Mean = 0.51, Std = 0.06
Time 6: Mean = 0.51, Std = 0.06
