In [23]:
from src.data import _load_readmission_dataset, _load_simData
from src.helper_functions import _prepare_rt_tensors, _prepare_dataloaders
from src.models import *
from src.losses import *    
from src.training import *
from src.metrics import *

In [24]:
# Pick the compute device: use CUDA when PyTorch reports a usable GPU,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [25]:
# Load the readmission dataset in its sequential form.
# NOTE(review): judging by how they are used further down, x looks like the
# covariates, t the time index, e the recurrent-event indicator and d the
# terminal-event indicator -- confirm against src.data.
x, t, e, d = _load_readmission_dataset(
    sequential=True,
)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [26]:
# Convert the raw arrays into the dictionary of tensors used by the rest of
# the notebook (later cells read names such as x_test, t_train, d_val, e_test
# from it).
readmission_tensor = _prepare_rt_tensors(x, t, e, d)

# Publish every dictionary entry as a notebook-level variable.
# globals() is used instead of locals(): at notebook top level both refer to
# the same namespace, but writes through locals() are not reflected in real
# variables inside a function, so globals() stays correct if this code is
# ever wrapped in one.
globals().update(readmission_tensor)

In [27]:
# Wrap the prepared tensors in train / validation / test data loaders,
# using mini-batches of 32.
train_dataloader, val_dataloader, test_dataloader = _prepare_dataloaders(
    readmission_tensor,
    batch_size=32,
)

In [28]:
# --- Model and training configuration -------------------------------------

# Network dimensions derived from the loaded data.
input_size = len(x[0])
output_size = int(max(t))
hidden_size = 8   # Number of units in the RNN layer

# Training-loop settings.
num_epochs = 10000
patience = 3  # Number of epochs to wait for improvement before stopping
# NOTE(review): best_val_loss and wait are not passed to any call visible in
# this notebook; presumably leftovers from an inlined training loop.
best_val_loss = float('inf')
wait = 0
loss_function = recurrent_terminal_loss

# Instantiate the model and move it to the selected device.
# NOTE(review): the trailing (2, 0.1) arguments are presumably the number of
# layers and the dropout rate -- confirm against SimpleLSTM's signature.
model = SimpleLSTM(input_size, hidden_size, output_size, 2, 0.1).to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-7,
)

In [29]:
# Fit the model, logging the losses every 10 epochs; `patience` is handed to
# the training routine for early stopping.
train_validate_rt_model(
    model,
    train_dataloader,
    val_dataloader,
    loss_function,
    optimizer,
    num_epochs,
    patience,
    print_every=10,
)

Epoch: 0   Training loss: 5.8839   Validation loss: 5.6622


Epoch: 10   Training loss: 5.7795   Validation loss: 5.5669
Epoch: 20   Training loss: 5.6569   Validation loss: 5.4793
Epoch: 30   Training loss: 5.5834   Validation loss: 5.3979
Epoch: 40   Training loss: 5.5035   Validation loss: 5.3194
Epoch: 50   Training loss: 5.4012   Validation loss: 5.2430
Epoch: 60   Training loss: 5.3176   Validation loss: 5.1664
Epoch: 70   Training loss: 5.2344   Validation loss: 5.0883
Epoch: 80   Training loss: 5.1366   Validation loss: 5.0086
Epoch: 90   Training loss: 5.0462   Validation loss: 4.9257
Epoch: 100   Training loss: 4.9762   Validation loss: 4.8381
Epoch: 110   Training loss: 4.8446   Validation loss: 4.7436
Epoch: 120   Training loss: 4.7577   Validation loss: 4.6406
Epoch: 130   Training loss: 4.6301   Validation loss: 4.5323
Epoch: 140   Training loss: 4.5123   Validation loss: 4.4172
Epoch: 150   Training loss: 4.3960   Validation loss: 4.3038
Epoch: 160   Training loss: 4.2634   Validation loss: 4.1967
Epoch: 170   Training loss: 4.201

In [30]:
# Score the held-out test set.
# Switch the network to inference mode first so that layers which behave
# differently during training (e.g. dropout) are deterministic, and disable
# autograd because no gradients are needed for evaluation.
model.eval()
with torch.no_grad():
    test_predictions = model(x_test).squeeze(-1)

# The last axis holds the two output heads: index 0 is used as the survival
# (terminal-event) prediction and index 1 as the recurrent-event prediction.
survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)

In [31]:
# Quantile concordance indices for the survival head on the test split.
# NOTE(review): indicator/time pairs are passed in train, test, val order --
# confirm this matches calculate_survival_metrics' signature.
calculate_survival_metrics(
    survival_predictions,
    d_train, t_train,
    d_test, t_test,
    d_val, t_val,
)

{'25th Quantile CI': 0.9565789473684212,
 '50th Quantile CI': 0.9662366156188291,
 '75th Quantile CI': 0.9662366156188291}

In [32]:
# Quantile concordance indices for the recurrent-event head on the test split.
# NOTE(review): the meaning of the literal 12 is not visible in this notebook;
# document it or promote it to a named constant.
recurrent_cindex(
    recurrent_predictions,
    e_test,
    t_test,
    12,
)

{'25th Quantile CI': 0.7277529459569281,
 '50th Quantile CI': 0.5647342995169082,
 '75th Quantile CI': 0.5647342995169082}

In [36]:
# Repeat the full prepare -> train -> evaluate pipeline several times and
# collect the concordance indices of every run, so their spread can be
# summarised afterwards.
nsamples = 10
survival_cis = []
recurrent_cis = []

# Settings that are identical for every run are defined once, outside the loop.
hidden_size = 8   # Number of units in the RNN layer
num_epochs = 10000
patience = 3  # Number of epochs to wait for improvement before stopping
loss_function = recurrent_terminal_loss

for sample in range(nsamples):
    print("sample: ", sample)

    # Rebuild the tensors and data loaders for this run.
    readmission_tensor = _prepare_rt_tensors(x, t, e, d)
    # Publish the dictionary entries (x_test, t_train, ...) as notebook-level
    # variables. globals() is explicit and, unlike locals(), keeps working if
    # this loop is ever moved into a function.
    globals().update(readmission_tensor)

    train_dataloader, val_dataloader, test_dataloader = _prepare_dataloaders(readmission_tensor, batch_size=32)

    # Kept inside the loop, after the namespace update, exactly as in the
    # original so the values are derived from the same x and t as before.
    input_size = len(x[0])
    output_size = int(max(t))

    # Fresh model and optimizer for every run.
    # NOTE(review): the trailing (2, 0.1) arguments are presumably the number
    # of layers and the dropout rate -- confirm against SimpleLSTM.
    model = SimpleLSTM(input_size, hidden_size, output_size, 2, 0.1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-7)
    train_validate_rt_model(model, train_dataloader, val_dataloader, loss_function, optimizer, num_epochs, patience, print_every=1000)

    # Evaluate in inference mode: layers such as dropout become deterministic
    # and no autograd graph is built for the test-set forward pass.
    model.eval()
    with torch.no_grad():
        test_predictions = model(x_test).squeeze(-1)

    # Last axis: index 0 -> survival head, index 1 -> recurrent-event head.
    survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
    recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)

    survival_cis.append(calculate_survival_metrics(survival_predictions, d_train, t_train, d_test, t_test, d_val, t_val))
    recurrent_cis.append(recurrent_cindex(recurrent_predictions, e_test, t_test, 12))

Epoch: 0   Training loss: 5.2783   Validation loss: 4.9569


Early stopping
Epoch: 0   Training loss: 5.1055   Validation loss: 5.1536
Early stopping
Epoch: 0   Training loss: 5.6743   Validation loss: 5.6215
Epoch: 1000   Training loss: 3.2161   Validation loss: 2.8989
Early stopping
Epoch: 0   Training loss: 4.9109   Validation loss: 4.7191
Early stopping
Epoch: 0   Training loss: 5.8840   Validation loss: 5.8265
Early stopping
Epoch: 0   Training loss: 5.1185   Validation loss: 4.9886
Early stopping
Epoch: 0   Training loss: 4.9664   Validation loss: 4.7394
Early stopping
Epoch: 0   Training loss: 4.7094   Validation loss: 4.5858
Epoch: 1000   Training loss: 3.1888   Validation loss: 2.8866
Early stopping
Epoch: 0   Training loss: 4.8085   Validation loss: 4.5682
Early stopping
Epoch: 0   Training loss: 6.0655   Validation loss: 5.8072
Early stopping


In [46]:
# Gather the survival c-index of every repeated run, one list per quantile.
survival_cis_25th = [run_metrics['25th Quantile CI'] for run_metrics in survival_cis]
survival_cis_50th = [run_metrics['50th Quantile CI'] for run_metrics in survival_cis]
survival_cis_75th = [run_metrics['75th Quantile CI'] for run_metrics in survival_cis]

# Report the run-to-run standard deviation (np.std default, i.e. ddof=0)
# for each quantile, in the same order as before.
for std_message, ci_values in (
    ("Standard deviation for the c-index of 25th quantile:", survival_cis_25th),
    ("Standard deviation for the c-index of 50th quantile:", survival_cis_50th),
    ("Standard deviation for the c-index of 75th quantile:", survival_cis_75th),
):
    print(std_message, np.std(np.asarray(ci_values)))

Standard deviation for the c-index of 25th quantile: 0.01306281436420088
Standard deviation for the c-index of 50th quantile: 0.01483553475574478
Standard deviation for the c-index of 75th quantile: 0.022048538422807778
