In [11]:
import numpy as np
import torch

from src.data import _load_readmission_dataset, _load_simData
from src.helper_functions import _prepare_rt_tensors
from src.models import *
from src.losses import *
from src.training import *
from src.metrics import *

In [12]:
# set device and random seeds
# Seeding makes weight initialisation (and dropout, if the model uses it)
# repeatable under Restart & Run All. NOTE(review): the data loaders may also
# split/shuffle with numpy's global RNG — TODO confirm, hence np.random.seed.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
# Load the readmission data in its sequential form.
# NOTE(review): x, t, e, d are presumably covariates, times, event indicators
# and a fourth per-record array (ids/durations?) — confirm against
# _load_readmission_dataset.
(x, t, e, d) = _load_readmission_dataset(sequential=True)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [14]:
readmission_tensor = _prepare_rt_tensors(x, t, e, d)

# Bind the splits explicitly instead of `locals().update(...)`: names injected
# through locals() are invisible to readers and linters, and only work because
# locals() happens to be globals() at notebook top level. These are the keys
# consumed by later cells; any other entries remain available in the dict.
x_train, t_train, e_train, d_train = (
    readmission_tensor[key] for key in ("x_train", "t_train", "e_train", "d_train")
)
x_val, t_val, e_val, d_val = (
    readmission_tensor[key] for key in ("x_val", "t_val", "e_val", "d_val")
)
x_test, t_test, d_test = (
    readmission_tensor[key] for key in ("x_test", "t_test", "d_test")
)

{'x_train': tensor([[0., 1., 1.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 1.],
         ...,
         [1., 0., 0.,  ..., 1., 0., 0.],
         [1., 0., 1.,  ..., 0., 0., 1.],
         [1., 0., 0.,  ..., 0., 0., 0.]]),
 'x_val': tensor([[1., 0., 1., 0., 0., 1., 0., 0., 0., 0.],
         [1., 0., 1., 0., 0., 1., 0., 1., 0., 0.],
         [1., 0., 1., 0., 1., 0., 0., 1., 0., 0.],
         [1., 0., 0., 1., 0., 1., 0., 1., 0., 0.],
         [0., 1., 1., 0., 1., 0., 0., 1., 0., 0.],
         [0., 1., 1., 0., 1., 0., 0., 1., 0., 0.],
         [0., 1., 1., 0., 1., 0., 0., 1., 0., 0.],
         [1., 0., 1., 0., 0., 1., 0., 1., 0., 0.],
         [0., 1., 0., 1., 1., 0., 0., 1., 0., 0.],
         [1., 0., 0., 1., 1., 0., 0., 0., 1., 0.],
         [1., 0., 1., 0., 0., 1., 0., 1., 0., 0.],
         [0., 1., 1., 0., 0., 0., 1., 0., 0., 1.],
         [0., 1., 1., 0., 1., 0., 0., 1., 0., 0.],
         [0., 1., 0., 1., 1., 0., 0., 1., 0., 0.],
    

In [15]:
# model and training parameters

# Size the GRU input from the tensor actually fed to the model: its last
# dimension is the per-step feature count. `len(x[0])` only matches that when
# the first raw record happens to be a flat feature vector.
input_size = x_train.shape[-1]
# NOTE(review): presumably the number of discrete time bins the survival head
# predicts over — confirm that t holds integer-valued times.
output_size = int(max(t))
hidden_size = 8   # Number of units in the RNN layer

# Instantiate the model.
# NOTE(review): the trailing positional args (1, 0.1) look like num_layers and
# dropout — TODO confirm against the SimpleGRU signature.
model = SimpleGRU(input_size, hidden_size, output_size, 1, 0.1)
model = model.to(device)

num_epochs = 10000  # upper bound; `patience` lets the trainer stop earlier
patience = 3  # Number of epochs to wait for improvement before stopping
loss_function = survival_loss

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-7)

In [16]:
# prepare data loaders
BATCH_SIZE = 32

train_dataset = torch.utils.data.TensorDataset(x_train, t_train, e_train, d_train)
val_dataset = torch.utils.data.TensorDataset(x_val, t_val, e_val, d_val)

# Both loaders use DataLoader's default sequential sampler (no shuffling), so
# batches follow dataset row order.
# NOTE(review): keep it that way unless SimpleGRU is confirmed to treat rows
# independently — a 2-D batch passed to a GRU may be read as one sequence.
train_dataloader, val_dataloader = (
    torch.utils.data.DataLoader(split, batch_size=BATCH_SIZE)
    for split in (train_dataset, val_dataset)
)

In [17]:
# Fit the model, reporting training/validation loss every 10 epochs;
# `patience` controls when the trainer stops early.
train_validate_survival_model(
    model,
    train_dataloader,
    val_dataloader,
    loss_function,
    optimizer,
    num_epochs,
    patience,
    print_every=10,
)

Epoch: 0   Training loss: 2.3595   Validation loss: 2.0748


Epoch: 10   Training loss: 2.2177   Validation loss: 1.9561
Epoch: 20   Training loss: 2.0831   Validation loss: 1.8401
Epoch: 30   Training loss: 1.9484   Validation loss: 1.7211
Epoch: 40   Training loss: 1.8238   Validation loss: 1.6002
Epoch: 50   Training loss: 1.6755   Validation loss: 1.4779
Epoch: 60   Training loss: 1.5386   Validation loss: 1.3572
Epoch: 70   Training loss: 1.4335   Validation loss: 1.2415
Epoch: 80   Training loss: 1.3428   Validation loss: 1.1330
Epoch: 90   Training loss: 1.2353   Validation loss: 1.0332
Epoch: 100   Training loss: 1.1285   Validation loss: 0.9445
Epoch: 110   Training loss: 1.0716   Validation loss: 0.8655
Epoch: 120   Training loss: 1.0110   Validation loss: 0.7973
Epoch: 130   Training loss: 0.9455   Validation loss: 0.7375
Epoch: 140   Training loss: 0.8882   Validation loss: 0.6859
Epoch: 150   Training loss: 0.8499   Validation loss: 0.6413
Epoch: 160   Training loss: 0.8522   Validation loss: 0.6024
Epoch: 170   Training loss: 0.791

In [18]:
# Inference: disable training-only behaviour (e.g. dropout) and gradient
# tracking, and send the inputs to the same device as the model — otherwise
# this cell fails whenever the model was moved to CUDA. Predictions are brought
# back to the CPU for the metric code.
model.eval()
with torch.no_grad():
    test_predictions = model(x_test.to(device)).squeeze(-1).cpu()

# NOTE(review): this second squeeze(-1) is a no-op unless the remaining last
# dimension also has size 1; kept for parity with the original cell.
recurrent_predictions = test_predictions.squeeze(-1)

In [19]:
# Evaluate the test predictions (output: Brier scores and quantile concordance).
# NOTE(review): the train/validation times and indicators are presumably
# reference data (e.g. for censoring weights) — confirm argument order against
# calculate_survival_metrics.
calculate_survival_metrics(test_predictions, d_train, t_train, d_test, t_test, d_val, t_val)

{'Brier Score': array([0.04079208, 0.03520956, 0.03196074, 0.02618879]),
 '25th Quantile CI': 0.9592105263157895,
 '50th Quantile CI': 0.9618800498922264,
 '75th Quantile CI': 0.959701767028925}