## Equivalent to Ren et al. (2019)

In [25]:
import torch
import numpy as np
from deepreevent import *

In [26]:
# Seed all randomness up front so model initialisation and resampling are repeatable.
# NOTE(review): set_seed comes from deepreevent; presumably it seeds torch/numpy/random — confirm.
SEED = 42
set_seed(SEED)

In [27]:
# Use a CUDA GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [28]:
# Load the readmission dataset in its sequential form. The four returned objects are
# passed straight to _prepare_rt_tensors in the next cell.
# NOTE(review): presumably covariates, times, event indicators and a fourth per-sample
# array (d) — confirm against deepreevent before relying on these meanings.
x, t, e, d = _load_readmission_dataset(
    sequential=True,
)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [29]:
# Convert the raw arrays into tensors. The result is a dict whose entries become
# notebook-level variables (x_train, t_val, d_test, ...) used by later cells.
readmission_tensor = _prepare_rt_tensors(x, t, e, d)

# Fail fast with a clear message here rather than a NameError several cells later.
_required_split_names = [
    f"{var}_{split}" for split in ("train", "val", "test") for var in ("x", "t", "e", "d")
]
_missing_split_names = [name for name in _required_split_names if name not in readmission_tensor]
if _missing_split_names:
    raise KeyError(f"_prepare_rt_tensors did not return: {_missing_split_names}")

# Expose every entry as a notebook-level variable. Use globals() explicitly: the
# Python docs say the dict returned by locals() should not be modified, and it only
# happens to alias globals() because this cell runs at module scope.
globals().update(readmission_tensor)

In [30]:
# --- Model and training configuration -------------------------------------------

# Architecture sizes derived from the loaded data.
input_size = len(x[0])        # NOTE(review): length of the first element of x — confirm this is the feature count
output_size = int(max(t))     # largest value in t (presumably one output per discrete time step)
hidden_size = 8               # number of units in the RNN layer

# Settings consumed by train_validate_survival_model.
num_epochs = 10000
patience = 3                  # epochs to wait for validation improvement before stopping
loss_function = survival_loss

# NOTE(review): these early-stopping trackers are not passed to the training helper
# in this notebook; kept in case other cells read them.
best_val_loss = float('inf')
wait = 0

# Instantiate the model and move it to the selected device in one step.
# NOTE(review): the trailing 1, 0.1 are presumably num_layers and dropout — confirm
# against SimpleGRU's signature.
model = SimpleGRU(input_size, hidden_size, output_size, 1, 0.1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-7)

In [31]:
# Wrap the pre-split tensors in datasets and batch them for training and validation.
# NOTE(review): the training loader does not shuffle, so mini-batches arrive in the
# same order every epoch — confirm this is intentional.
BATCH_SIZE = 32

train_dataset = torch.utils.data.TensorDataset(x_train, t_train, e_train, d_train)
val_dataset = torch.utils.data.TensorDataset(x_val, t_val, e_val, d_val)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [32]:
# Train the model, reporting training and validation loss every 10 epochs; `patience`
# presumably drives early stopping on the validation loss.
train_validate_survival_model(
    model,
    train_dataloader,
    val_dataloader,
    loss_function,
    optimizer,
    num_epochs,
    patience,
    print_every=10,
)

Epoch: 0   Training loss: 3.4248   Validation loss: 3.0438


Epoch: 10   Training loss: 3.2613   Validation loss: 2.9075
Epoch: 20   Training loss: 3.1152   Validation loss: 2.7739
Epoch: 30   Training loss: 2.9501   Validation loss: 2.6415
Epoch: 40   Training loss: 2.8180   Validation loss: 2.5106
Epoch: 50   Training loss: 2.6530   Validation loss: 2.3793
Epoch: 60   Training loss: 2.5068   Validation loss: 2.2487
Epoch: 70   Training loss: 2.3635   Validation loss: 2.1191
Epoch: 80   Training loss: 2.2181   Validation loss: 1.9917
Epoch: 90   Training loss: 2.0835   Validation loss: 1.8687
Epoch: 100   Training loss: 1.9493   Validation loss: 1.7500
Epoch: 110   Training loss: 1.8335   Validation loss: 1.6365
Epoch: 120   Training loss: 1.7110   Validation loss: 1.5290
Epoch: 130   Training loss: 1.6239   Validation loss: 1.4283
Epoch: 140   Training loss: 1.5221   Validation loss: 1.3344
Epoch: 150   Training loss: 1.4488   Validation loss: 1.2480
Epoch: 160   Training loss: 1.3450   Validation loss: 1.1689
Epoch: 170   Training loss: 1.289

In [33]:
# Predict on the full test set.
# eval() switches off training-only layers (e.g. dropout, if SimpleGRU applies it) and
# no_grad() avoids building an autograd graph — both are correct for inference and the
# returned tensors no longer require grad.
model.eval()
with torch.no_grad():
    test_predictions = model(x_test).squeeze(-1)

# Second squeeze kept from the original; it only has an effect if the remaining last
# dimension is of size 1.
recurrent_predictions = test_predictions.squeeze(-1)

In [34]:
from sklearn.utils import resample

nsamples = 100          # Number of bootstrap samples
survival_cis = []       # one metrics result per successful bootstrap resample
recurrent_cis = []      # NOTE(review): never filled here — kept in case later cells read it
skipped_resamples = []  # (resample index, exception) for resamples whose metrics failed

# Inference only: disable training-only layers and autograd for every resample.
model.eval()
with torch.no_grad():
    for sample in range(nsamples):
        # Bootstrap the test set with replacement; resample draws the same row indices
        # for all four inputs, keeping them aligned.
        x_test_resampled, t_test_resampled, e_test_resampled, d_test_resampled = resample(
            x_test, t_test, e_test, d_test, replace=True
        )

        # Use a separate name so the full-test-set `test_predictions` is not overwritten.
        resampled_predictions = model(x_test_resampled).squeeze(-1)
        survival_predictions = resampled_predictions.squeeze(-1)
        try:
            survival_cis.append(
                calculate_survival_metrics(
                    survival_predictions, d_train, t_train, d_test_resampled, t_test_resampled
                )
            )
        except Exception as err:
            # Presumably raised when every resampled subject is censored. Record and skip
            # rather than silently discarding it (the former bare `except:` also swallowed
            # KeyboardInterrupt and genuine bugs).
            skipped_resamples.append((sample, err))

if skipped_resamples:
    print(
        f"Skipped {len(skipped_resamples)} of {nsamples} bootstrap resamples; "
        f"first error: {skipped_resamples[0][1]!r}"
    )

In [35]:
def compute_metric_ci(cis, time_key, rounding=2):
    """Summarise one metric across bootstrap resamples as ``(mean, std)``.

    Despite the name, no confidence interval is computed: the spread returned is the
    population standard deviation (``np.std`` with its default ``ddof=0``).

    Args:
        cis: Sequence of per-resample metric mappings (e.g. the results collected
            in ``survival_cis``).
        time_key: Key to read from each mapping, e.g. ``"Time 3"``.
        rounding: Number of decimal places to round both results to.

    Returns:
        ``(mean, std)`` rounded to ``rounding`` decimals, or ``(nan, nan)`` when ``cis``
        is empty (e.g. every bootstrap resample was skipped).

    Raises:
        KeyError: if any mapping lacks ``time_key``.
    """
    values = [metrics[time_key] for metrics in cis]
    if not values:
        # np.mean/np.std would also yield nan here, but with RuntimeWarnings.
        return float("nan"), float("nan")
    return round(np.mean(values), rounding), round(np.std(values), rounding)

In [36]:
# Summarise the bootstrap distribution of the survival metric at each evaluated horizon.
_horizon_summaries = {
    horizon: compute_metric_ci(survival_cis, f"Time {horizon}") for horizon in (3, 4, 5, 6)
}

# Keep the per-horizon variables available to later cells.
survival_cis_3_mean, survival_cis_3_std = _horizon_summaries[3]
survival_cis_4_mean, survival_cis_4_std = _horizon_summaries[4]
survival_cis_5_mean, survival_cis_5_std = _horizon_summaries[5]
survival_cis_6_mean, survival_cis_6_std = _horizon_summaries[6]

for _horizon, (_metric_mean, _metric_std) in _horizon_summaries.items():
    print(f"Time {_horizon}: Mean = {_metric_mean}, Std = {_metric_std}")

Time 3: Mean = 0.96, Std = 0.02
Time 4: Mean = 0.97, Std = 0.01
Time 5: Mean = 0.97, Std = 0.01
Time 6: Mean = 0.97, Std = 0.01
