In [14]:
from src.data import _load_readmission_dataset, _load_simData
from src.helper_functions import _prepare_rt_tensors, _prepare_dataloaders
from src.models import *
from src.losses import *    
from src.training import *
from src.metrics import *

In [15]:
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
x, t, e, d = _load_readmission_dataset(sequential=True)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [17]:
readmission_tensor = _prepare_rt_tensors(x, t, e, d)
locals().update(readmission_tensor) # create variables from dictionary

In [18]:
train_dataloader, val_dataloader, test_dataloader = _prepare_dataloaders(readmission_tensor, batch_size=32)

In [19]:
# model and training parameters

input_size = len(x[0]) 
output_size = int(max(t))
hidden_size = 8   # Number of units in the RNN layer

model = SimpleLSTM(input_size, hidden_size, output_size, 2, 0.1)

# Instantiate the model

num_epochs = 10000
patience = 3  # Number of epochs to wait for improvement before stopping
best_val_loss = float('inf')
wait = 0
loss_function = recurrent_terminal_loss

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(),  lr=1e-4, weight_decay=1e-7)

In [20]:
train_validate_rt_model(model, train_dataloader, val_dataloader, loss_function, optimizer, num_epochs, patience, print_every=10)

Epoch: 0   Training loss: 4.9938   Validation loss: 4.8714


Epoch: 10   Training loss: 4.9355   Validation loss: 4.8159
Epoch: 20   Training loss: 4.8856   Validation loss: 4.7648
Epoch: 30   Training loss: 4.8238   Validation loss: 4.7148
Epoch: 40   Training loss: 4.7744   Validation loss: 4.6646
Epoch: 50   Training loss: 4.7145   Validation loss: 4.6111
Epoch: 60   Training loss: 4.6543   Validation loss: 4.5528
Epoch: 70   Training loss: 4.5865   Validation loss: 4.4871
Epoch: 80   Training loss: 4.4944   Validation loss: 4.4135
Epoch: 90   Training loss: 4.4205   Validation loss: 4.3323
Epoch: 100   Training loss: 4.3231   Validation loss: 4.2469
Epoch: 110   Training loss: 4.2409   Validation loss: 4.1606
Epoch: 120   Training loss: 4.1673   Validation loss: 4.0761
Epoch: 130   Training loss: 4.0778   Validation loss: 3.9956
Epoch: 140   Training loss: 3.9861   Validation loss: 3.9211
Epoch: 150   Training loss: 3.9096   Validation loss: 3.8530
Epoch: 160   Training loss: 3.8543   Validation loss: 3.7893
Epoch: 170   Training loss: 3.858

In [21]:
test_predictions = model(x_test).squeeze(-1)

survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)

In [22]:
calculate_survival_metrics(survival_predictions, d_train, t_train, d_test, t_test, d_val, t_val)

[0.9539473684210527, 0.9618800498922262, 0.9618800498922262]


{'Time 2': 0.9539473684210527,
 'Time 3': 0.9618800498922262,
 'Time 4': 0.9618800498922262}

In [23]:
recurrent_cindex(recurrent_predictions, e_test, t_test, 6)

1


In [11]:
nsamples = 10
survival_cis = []
recurrent_cis = []
for sample in range(0, nsamples): 
    print("sample: ",sample)
    readmission_tensor = _prepare_rt_tensors(x, t, e, d)
    locals().update(readmission_tensor) # create variables from dictionary

    train_dataloader, val_dataloader, test_dataloader = _prepare_dataloaders(readmission_tensor, batch_size=32)

    input_size = len(x[0]) 
    output_size = int(max(t))
    hidden_size = 8   # Number of units in the RNN layer

    model = SimpleLSTM(input_size, hidden_size, output_size, 2, 0.1)

    # Instantiate the model

    num_epochs = 10000
    patience = 3  # Number of epochs to wait for improvement before stopping
    best_val_loss = float('inf')
    wait = 0
    loss_function = recurrent_terminal_loss

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),  lr=1e-4, weight_decay=1e-7)
    train_validate_rt_model(model, train_dataloader, val_dataloader, loss_function, optimizer, num_epochs, patience, print_every=1000)
    test_predictions = model(x_test).squeeze(-1)

    survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
    recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)
    survival_cis.append(calculate_survival_metrics(survival_predictions, d_train, t_train, d_test, t_test, d_val, t_val))
    recurrent_cis.append(recurrent_cindex(recurrent_predictions, e_test, t_test, 12))

sample:  0
Epoch: 0   Training loss: 5.2466   Validation loss: 5.1963


Early stopping
sample:  1
Epoch: 0   Training loss: 5.4680   Validation loss: 5.2982
Early stopping
sample:  2
Epoch: 0   Training loss: 5.5915   Validation loss: 5.4352


KeyboardInterrupt: 

In [None]:
survival_cis_25th = [survival_ci['25th Quantile CI'] for  survival_ci in survival_cis]
print("Standard deviation for the c-index of 25th quantile:", np.std(np.array(survival_cis_25th)))

survival_cis_50th = [survival_ci['50th Quantile CI'] for survival_ci in survival_cis]
print("Standard deviation for the c-index of 50th quantile:", np.std(np.array(survival_cis_50th)))

survival_cis_75th = [survival_ci['75th Quantile CI'] for survival_ci in survival_cis]
print("Standard deviation for the c-index of 75th quantile:", np.std(np.array(survival_cis_75th)))

Standard deviation for the c-index of 25th quantile: 0.01306281436420088
Standard deviation for the c-index of 50th quantile: 0.01483553475574478
Standard deviation for the c-index of 75th quantile: 0.022048538422807778


In [None]:
recurrent_cis_25th = [recurrent_ci['25th Quantile CI'] for recurrent_ci in recurrent_cis]
print("Standard deviation for the c-index of 25th quantile:", np.std(np.array(recurrent_cis_25th)))
recurrent_cis_50th = [recurrent_ci['50th Quantile CI'] for recurrent_ci in recurrent_cis]
print("Standard deviation for the c-index of 50th quantile:", np.std(np.array(recurrent_cis_50th)))
recurrent_cis_75th = [recurrent_ci['75th Quantile CI'] for recurrent_ci in recurrent_cis]
print("Standard deviation for the c-index of 75th quantile:", np.std(np.array(recurrent_cis_75th)))


Standard deviation for the c-index of 25th quantile: 0.009576180015140427
Standard deviation for the c-index of 50th quantile: 0.009046369381791142
Standard deviation for the c-index of 75th quantile: 0.009046369381791142
