In [31]:
from src.data import _load_readmission_dataset, _load_simData
from src.helper_functions import _prepare_rt_tensors, _prepare_dataloaders, set_seed
from src.models import *
from src.losses import *    
from src.training import *
from src.metrics import *

In [32]:
set_seed(42)

In [33]:
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [34]:
x, t, e, d = _load_readmission_dataset(sequential=True)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [35]:
readmission_tensor = _prepare_rt_tensors(x, t, e, d)
locals().update(readmission_tensor) # create variables from dictionary

In [36]:
train_dataloader, val_dataloader, test_dataloader = _prepare_dataloaders(readmission_tensor, batch_size=32)

In [37]:
# model and training parameters

input_size = len(x[0]) 
output_size = int(max(t))
hidden_size = 8   # Number of units in the RNN layer

model = SimpleLSTM(input_size, hidden_size, output_size, 2, 0.1)

# Instantiate the model

num_epochs = 10000
patience = 3  # Number of epochs to wait for improvement before stopping
best_val_loss = float('inf')
wait = 0
loss_function = recurrent_terminal_loss

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(),  lr=1e-4, weight_decay=1e-7)

In [38]:
train_validate_rt_model(model, train_dataloader, val_dataloader, loss_function, optimizer, num_epochs, patience, print_every=10)

Epoch: 0   Training loss: 5.3275   Validation loss: 5.2912


Epoch: 10   Training loss: 5.2624   Validation loss: 5.2202
Epoch: 20   Training loss: 5.1897   Validation loss: 5.1505
Epoch: 30   Training loss: 5.1294   Validation loss: 5.0754
Epoch: 40   Training loss: 5.0242   Validation loss: 4.9901
Epoch: 50   Training loss: 4.9167   Validation loss: 4.8929
Epoch: 60   Training loss: 4.8408   Validation loss: 4.7810
Epoch: 70   Training loss: 4.6890   Validation loss: 4.6572
Epoch: 80   Training loss: 4.5527   Validation loss: 4.5297
Epoch: 90   Training loss: 4.4456   Validation loss: 4.4060
Epoch: 100   Training loss: 4.3008   Validation loss: 4.2927
Epoch: 110   Training loss: 4.2049   Validation loss: 4.1894
Epoch: 120   Training loss: 4.1191   Validation loss: 4.0974
Epoch: 130   Training loss: 4.0681   Validation loss: 4.0131
Epoch: 140   Training loss: 3.9881   Validation loss: 3.9371
Epoch: 150   Training loss: 3.9175   Validation loss: 3.8708
Epoch: 160   Training loss: 3.8372   Validation loss: 3.8130
Epoch: 170   Training loss: 3.815

In [39]:
test_predictions = model(x_test).squeeze(-1)

survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)

In [40]:
calculate_survival_metrics(survival_predictions, d_train, t_train, d_test, t_test)

{'Time 3': 0.9381578947368422,
 'Time 4': 0.9226709583528018,
 'Time 5': 0.8899967154032816,
 'Time 6': 0.874748735360172}

In [41]:
recurrent_cindex(recurrent_predictions, e_test, t_test, max_time = 6, horizons = "all", tolerance = 0.1)

{'Time 1': 0.8267326732673267,
 'Time 2': 0.5223214285714286,
 'Time 3': 0.42076228686058176,
 'Time 4': 0.4842391304347826,
 'Time 5': 0.5207889125799574,
 'Time 6': 0.5207889125799574}

In [42]:
from sklearn.utils import resample

nsamples = 100  # Number of bootstrap samples
survival_cis = []
recurrent_cis = []

for sample in range(nsamples):
    
    # Resample the test data
    x_test_resampled, t_test_resampled, e_test_resampled, d_test_resampled = resample(x_test, t_test, e_test, d_test, replace=True)
    
    # Make predictions on the resampled test data
    test_predictions = model(x_test_resampled).squeeze(-1)
    
    survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
    recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)
    try: 
        # Calculate metrics for the resampled test data
        survival_cis.append(calculate_survival_metrics(survival_predictions, d_train, t_train, d_test_resampled, t_test_resampled))
        recurrent_cis.append(recurrent_cindex(recurrent_predictions, e_test_resampled, t_test_resampled, 6))
    except: 
        continue


In [43]:
def compute_metric_ci(cis, time_key, rounding = 2):
    cis_time = [cis[i][time_key] for i in range(len(cis))]
    survival_ci_mean = np.mean(cis_time)
    survival_ci_std = np.std(np.array(cis_time))
    return round(survival_ci_mean, rounding), round(survival_ci_std, rounding)

In [44]:
survival_cis_3_mean, survival_cis_3_std = compute_metric_ci(survival_cis, "Time 3")
survival_cis_4_mean, survival_cis_4_std = compute_metric_ci(survival_cis, "Time 4")
survival_cis_5_mean, survival_cis_5_std = compute_metric_ci(survival_cis, "Time 5")
survival_cis_6_mean, survival_cis_6_std = compute_metric_ci(survival_cis, "Time 6")

print(f"Time 3: Mean = {survival_cis_3_mean}, Std = {survival_cis_3_std}")
print(f"Time 4: Mean = {survival_cis_4_mean}, Std = {survival_cis_4_std}")
print(f"Time 5: Mean = {survival_cis_5_mean}, Std = {survival_cis_5_std}")
print(f"Time 6: Mean = {survival_cis_6_mean}, Std = {survival_cis_6_std}")

Time 3: Mean = 0.94, Std = 0.03
Time 4: Mean = 0.92, Std = 0.05
Time 5: Mean = 0.89, Std = 0.08
Time 6: Mean = 0.87, Std = 0.09


In [45]:
recurrent_cis_1_mean, recurrent_cis_1_std = compute_metric_ci(recurrent_cis, "Time 1")
recurrent_cis_2_mean, recurrent_cis_2_std = compute_metric_ci(recurrent_cis, "Time 2")
recurrent_cis_3_mean, recurrent_cis_3_std = compute_metric_ci(recurrent_cis, "Time 3")
recurrent_cis_4_mean, recurrent_cis_4_std = compute_metric_ci(recurrent_cis, "Time 4")
recurrent_cis_5_mean, recurrent_cis_5_std = compute_metric_ci(recurrent_cis, "Time 5")
recurrent_cis_6_mean, recurrent_cis_6_std = compute_metric_ci(recurrent_cis, "Time 6")

print(f"Time 1: Mean = {recurrent_cis_1_mean}, Std = {recurrent_cis_1_std}")
print(f"Time 2: Mean = {recurrent_cis_2_mean}, Std = {recurrent_cis_2_std}")
print(f"Time 3: Mean = {recurrent_cis_3_mean}, Std = {recurrent_cis_3_std}")
print(f"Time 4: Mean = {recurrent_cis_4_mean}, Std = {recurrent_cis_4_std}")
print(f"Time 5: Mean = {recurrent_cis_5_mean}, Std = {recurrent_cis_5_std}")
print(f"Time 6: Mean = {recurrent_cis_6_mean}, Std = {recurrent_cis_6_std}")

Time 1: Mean = 0.84, Std = 0.03
Time 2: Mean = 0.53, Std = 0.06
Time 3: Mean = 0.44, Std = 0.05
Time 4: Mean = 0.44, Std = 0.06
Time 5: Mean = 0.48, Std = 0.07
Time 6: Mean = 0.48, Std = 0.07
