In [19]:
import numpy as np
import torch

from src.data import _load_readmission_dataset, _load_simData
from src.helper_functions import _prepare_rt_tensors, _prepare_dataloaders
from src.models import *
from src.losses import *
from src.training import *
from src.metrics import *

In [20]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [21]:
# Load the readmission data in sequential form. The four returned objects are
# passed together to _prepare_rt_tensors in the next cell.
# NOTE(review): x / t / e are presumably covariates, event times and event
# indicators; the role of d is not visible here - confirm in src.data.
# NOTE(review): the line echoed under this cell comes from src.data's
# `data.groupby('id')['t.stop'].transform(max)`; pandas likely warns about the
# builtin `max` there - passing the string 'max' instead should silence it.
x, t, e, d = _load_readmission_dataset(sequential=True)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [22]:
# Seed the global RNGs so the data split (and the model initialisation in the
# next cell) is reproducible on Restart & Run All.
# NOTE(review): assumes _prepare_rt_tensors draws from the numpy/torch global
# RNGs; the resampling loop further down calls it repeatedly to get new
# samples, which suggests the split is random - confirm in src.helper_functions.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

readmission_tensor = _prepare_rt_tensors(x, t, e, d)
# Expose the dictionary entries (presumably x_test, e_test, t_test, ...) as
# notebook globals. globals() states the intent explicitly; writing through
# locals() is only reliable because notebook cells run at module scope.
globals().update(readmission_tensor)

In [23]:
# model and training parameters

input_size = len(x[0])
output_size = int(max(t))  # presumably one output per time step up to the largest observed time
hidden_size = 8   # Number of units in the RNN layer

# Instantiate the model.
# NOTE(review): the positional 1 and 0.1 are presumably (num_layers, dropout);
# confirm against SimpleGRU in src.models.
model = SimpleGRU(input_size, hidden_size, output_size, 1, 0.1)

num_epochs = 10000  # upper bound; early stopping normally ends training sooner
patience = 3  # Number of epochs to wait for improvement before stopping
loss_function = recurrent_loss

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-7)

In [24]:
# Build the train / validation / test data loaders from the prepared tensor dictionary.
train_dataloader, val_dataloader, test_dataloader = _prepare_dataloaders(
    readmission_tensor,
    batch_size=32,
)

In [25]:
# Train the model, validating each epoch; training stops early once the
# validation loss has not improved for `patience` epochs. Progress is printed
# every 10 epochs.
train_validate_recurrent_model(
    model,
    train_dataloader,
    val_dataloader,
    loss_function,
    optimizer,
    num_epochs,
    patience,
    print_every=10,
)

Epoch: 0   Training loss: 2.8469   Validation loss: 2.9233


Epoch: 10   Training loss: 2.8227   Validation loss: 2.9104
Epoch: 20   Training loss: 2.8366   Validation loss: 2.9007
Epoch: 30   Training loss: 2.8105   Validation loss: 2.8935
Epoch: 40   Training loss: 2.8154   Validation loss: 2.8880
Epoch: 50   Training loss: 2.8016   Validation loss: 2.8821
Epoch: 60   Training loss: 2.8151   Validation loss: 2.8765
Epoch: 70   Training loss: 2.8103   Validation loss: 2.8725
Epoch: 80   Training loss: 2.8007   Validation loss: 2.8685
Epoch: 90   Training loss: 2.8108   Validation loss: 2.8653
Epoch: 100   Training loss: 2.8007   Validation loss: 2.8610
Epoch: 110   Training loss: 2.7812   Validation loss: 2.8565
Epoch: 120   Training loss: 2.7923   Validation loss: 2.8535
Epoch: 130   Training loss: 2.7732   Validation loss: 2.8491
Epoch: 140   Training loss: 2.7819   Validation loss: 2.8442
Epoch: 150   Training loss: 2.7707   Validation loss: 2.8402
Epoch: 160   Training loss: 2.7957   Validation loss: 2.8340
Epoch: 170   Training loss: 2.779

In [26]:
# Score the held-out test set with the trained model.
# model.eval() disables dropout so the predictions are deterministic, and
# torch.no_grad() avoids building an autograd graph for pure inference.
# NOTE(review): assumes x_test already lives on `device` - confirm in _prepare_rt_tensors.
model.eval()
with torch.no_grad():
    test_predictions = model(x_test).squeeze(-1)

# The second squeeze only changes the shape if the trailing dimension is still 1.
recurrent_predictions = test_predictions.squeeze(-1)

In [27]:
# Concordance index of the recurrent-event predictions on the test set; the
# result is a dict keyed by '25th' / '50th' / '75th Quantile CI'.
# NOTE(review): the meaning of the trailing 12 is not visible here (possibly an
# evaluation horizon) - confirm against recurrent_cindex in src.metrics.
recurrent_cindex(recurrent_predictions, e_test, t_test, 12)

{'25th Quantile CI': 0.7082486793986185,
 '50th Quantile CI': 0.5458937198067633,
 '75th Quantile CI': 0.5458937198067633}

In [28]:
# Repeat split -> train -> evaluate over several resamples to estimate how much
# the test c-index varies between runs.
# NOTE(review): this loop trains SimpleLSTM (positional 2, 0.1) while the single
# run above used SimpleGRU (1, 0.1) - confirm the architecture change is intended.
nsamples = 10
survival_cis = []  # NOTE(review): never filled; survival_predictions below is not scored yet
recurrent_cis = []

# Sizes and training settings are identical for every sample, so set them once.
input_size = len(x[0])
output_size = int(max(t))
hidden_size = 8   # Number of units in the RNN layer
num_epochs = 10000
# NOTE(review): in the recorded run every sample early-stopped before epoch 1000
# at a loss of ~5.7 (vs ~2.8 for the GRU run above); patience=3 may be too small
# for the c-index spread to be meaningful.
patience = 3  # Number of epochs to wait for improvement before stopping
loss_function = recurrent_loss

for sample in range(nsamples):
    print("sample: ", sample)
    # Seed per sample so each resample is reproducible yet differs from the others.
    np.random.seed(sample)
    torch.manual_seed(sample)

    readmission_tensor = _prepare_rt_tensors(x, t, e, d)
    # Cells run at module scope, so globals() is where x_test / e_test / t_test must land.
    globals().update(readmission_tensor)

    train_dataloader, val_dataloader, test_dataloader = _prepare_dataloaders(readmission_tensor, batch_size=32)

    # Instantiate a fresh model and optimizer for this sample.
    model = SimpleLSTM(input_size, hidden_size, output_size, 2, 0.1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-7)
    train_validate_recurrent_model(model, train_dataloader, val_dataloader, loss_function, optimizer, num_epochs, patience, print_every=1000)

    # Evaluate deterministically: dropout off, no autograd graph.
    model.eval()
    with torch.no_grad():
        test_predictions = model(x_test).squeeze(-1)

    # Channel 0 is kept as the survival output and channel 1 as the recurrent output.
    survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
    recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)
    recurrent_cis.append(recurrent_cindex(recurrent_predictions, e_test, t_test, 12))

sample:  0
Epoch: 0   Training loss: 5.7534   Validation loss: 5.7367
Early stopping
sample:  1
Epoch: 0   Training loss: 5.6968   Validation loss: 5.8596
Early stopping
sample:  2
Epoch: 0   Training loss: 5.7915   Validation loss: 6.0727
Early stopping
sample:  3
Epoch: 0   Training loss: 5.6991   Validation loss: 6.0624
Early stopping
sample:  4
Epoch: 0   Training loss: 5.6931   Validation loss: 5.7041
Early stopping
sample:  5
Epoch: 0   Training loss: 5.7578   Validation loss: 6.1278
Early stopping
sample:  6
Epoch: 0   Training loss: 5.7912   Validation loss: 6.1603
Early stopping
sample:  7
Epoch: 0   Training loss: 5.6752   Validation loss: 5.8329
Early stopping
sample:  8
Epoch: 0   Training loss: 5.6981   Validation loss: 5.8570
Early stopping
sample:  9
Epoch: 0   Training loss: 5.7793   Validation loss: 6.2008
Early stopping


In [None]:
# Summarise the per-sample c-indices collected by the resampling loop.
# NOTE(review): np.std uses ddof=0 (population std); pass ddof=1 for the sample
# standard deviation across runs.
assert recurrent_cis, "recurrent_cis is empty - run the resampling cell first"

recurrent_cis_by_quantile = {
    quantile: [recurrent_ci[f"{quantile} Quantile CI"] for recurrent_ci in recurrent_cis]
    for quantile in ("25th", "50th", "75th")
}
for quantile, quantile_cis in recurrent_cis_by_quantile.items():
    print(f"Mean c-index of {quantile} quantile:", np.mean(quantile_cis))
    print(f"Standard deviation for the c-index of {quantile} quantile:", np.std(np.array(quantile_cis)))

# Keep the per-quantile list names this cell has always defined.
recurrent_cis_25th = recurrent_cis_by_quantile["25th"]
recurrent_cis_50th = recurrent_cis_by_quantile["50th"]
recurrent_cis_75th = recurrent_cis_by_quantile["75th"]
