In [1]:
from src.data import _load_readmission_dataset, _load_simData
from src.helper_functions import _prepare_rt_tensors
from src.models import *
from src.losses import *
from src.training import *
from src.metrics import *

In [2]:
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
x, t, e, d = _load_readmission_dataset(sequential=True)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [4]:
# x, t, e, d = _load_simData(sequential=True)

In [5]:
readmission_tensor = _prepare_rt_tensors(x, t, e, d)
locals().update(readmission_tensor) # create variables from dictionary

  x = torch.tensor(x, dtype=torch.float32)


In [6]:
# model and training parameters

input_size = len(x[0]) 
output_size = int(max(t))
hidden_size = 8   # Number of units in the RNN layer

model = SimpleGRU(input_size, hidden_size, output_size, 2)

# Instantiate the model

num_epochs = 10000
patience = 3  # Number of epochs to wait for improvement before stopping
best_val_loss = float('inf')
wait = 0
loss_function = recurrent_terminal_loss

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(),  lr=1e-4, weight_decay=1e-7)

In [7]:
#prepare data loaders

train_dataset = torch.utils.data.TensorDataset(x_train, t_train, e_train, d_train)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)

val_dataset = torch.utils.data.TensorDataset(x_val, t_val, e_val, d_val)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32)

In [8]:
train_validate_rt_model(model, train_dataloader, val_dataloader, loss_function, optimizer, num_epochs, patience, print_every=10)

Epoch: 0   Training loss: 15.1134   Validation loss: 13.9719


Epoch: 10   Training loss: 14.4475   Validation loss: 13.4327
Epoch: 20   Training loss: 13.9421   Validation loss: 12.9642
Epoch: 30   Training loss: 13.3933   Validation loss: 12.5372
Epoch: 40   Training loss: 12.9289   Validation loss: 12.1228
Epoch: 50   Training loss: 12.4089   Validation loss: 11.7020
Epoch: 60   Training loss: 11.9593   Validation loss: 11.2677
Epoch: 70   Training loss: 11.4280   Validation loss: 10.8268
Epoch: 80   Training loss: 10.9830   Validation loss: 10.3950
Epoch: 90   Training loss: 10.4318   Validation loss: 9.9897
Epoch: 100   Training loss: 10.0720   Validation loss: 9.6172
Epoch: 110   Training loss: 9.7550   Validation loss: 9.2856
Epoch: 120   Training loss: 9.4719   Validation loss: 8.9915
Epoch: 130   Training loss: 9.2031   Validation loss: 8.7314
Epoch: 140   Training loss: 8.9787   Validation loss: 8.4996
Epoch: 150   Training loss: 8.6064   Validation loss: 8.2930
Epoch: 160   Training loss: 8.6190   Validation loss: 8.1096
Epoch: 170   Tr

In [9]:
test_predictions = model(x_test).squeeze(-1)

survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)

In [10]:
calculate_survival_metrics(survival_predictions, d_train, t_train, d_test, t_test, d_val, t_val)

{'Brier Score': array([0.01919344, 0.02401563, 0.04047371, 0.03178264, 0.03209171,
        0.02752203, 0.02841207, 0.02636668, 0.02722039, 0.02905341,
        0.02462574, 0.0253025 , 0.01848155, 0.01853445, 0.01476412,
        0.00860042]),
 '25th Quantile CI': 0.9598488074374354,
 '50th Quantile CI': 0.957478908064154,
 '75th Quantile CI': 0.9514659842100678}

In [17]:
import torch
recurrent_cindex(recurrent_predictions, e_test, t_test, 12)

[0.48411552346570397,
 0.5084164588528678,
 0.5002613695765813,
 0.4859154929577465,
 0.5352668213457077,
 0.5566622399291722,
 0.5678713089466726,
 0.5915102389078498,
 0.6328157784305497,
 0.6627955205309001,
 0.6602277348515657,
 0.6526859504132232,
 0.607158446093409,
 0.5935469900389779,
 0.5512007249660172,
 0.510643330179754,
 0.5016908212560387,
 0.5016908212560387,
 0.5016908212560387,
 0.5016908212560387,
 0.5016908212560387,
 0.5016908212560387,
 0.5016908212560387,
 0.5016908212560387]

# Experiements

In [12]:
max_time = 10
out_risk = recurrent_predictions
event_times = e_test
cindices = []
for current_time in range(2, max_time, 1):
    expected_number_of_events = torch.zeros(out_risk.size(0))
    for i in range(out_risk.size(0)):
        clamped_time = min(current_time, t_test[i].item())
        expected_number_of_events[i] = out_risk[i, 0:clamped_time].sum()
    mask = event_times < current_time
    observed_number_of_events = mask.sum(dim=1)
    concordant_pairs = 0
    total_pairs = 0

    n = len(expected_number_of_events)
    concordant_pairs = 0
    permissible_pairs = 0
    tied_risk_pairs = 0

    for i in range(n):
        for j in range(n):
            if i != j:
                if observed_number_of_events[i] < observed_number_of_events[j]:
                    permissible_pairs += 1
                    if expected_number_of_events[i] > expected_number_of_events[j]:
                        concordant_pairs += 1
                    elif expected_number_of_events[i] == expected_number_of_events[j]:
                        tied_risk_pairs += 1
                        # print(tied_risk_pairs)

    cindices.append((concordant_pairs + 0.5 * tied_risk_pairs) / permissible_pairs)

    # for i in range(0, len(expected_number_of_events)):
    #     for j in range(0, len(expected_number_of_events)):
    #         if i == j: 
    #             continue
    #         if (expected_number_of_events[i] > expected_number_of_events[j]) and (observed_number_of_events[i] > observed_number_of_events[j]):
    #             concordant_pairs += 1
    #         total_pairs += 1  
    # cindices.append(concordant_pairs / total_pairs)

print(cindices)


[0.48411552346570397, 0.5084164588528678, 0.5002613695765813, 0.4859154929577465, 0.5352668213457077, 0.5566622399291722, 0.5678713089466726, 0.5915102389078498]


In [13]:
expected_number_of_events


tensor([1.8188, 1.1345, 2.7816, 1.7172, 2.0340, 1.9411, 1.6031, 1.6037, 0.6998,
        1.6031, 0.4780, 1.7606, 1.8188, 1.6423, 2.0340, 1.0943, 0.6595, 1.6037,
        1.7172, 1.9411, 1.7416, 1.6783, 1.7172, 1.9950, 1.8552, 1.6423, 1.8188,
        2.4933, 1.0767, 1.6423, 1.6423, 1.6972, 0.4361, 1.6037, 0.4908, 1.6423,
        1.6031, 0.2635, 2.4933, 1.7172, 1.6031, 1.7172, 1.1955, 0.9972, 1.7172,
        1.7606, 0.2626, 1.6423, 0.8502, 1.8188, 1.8792, 1.7819, 1.4214, 1.6423,
        1.8287, 1.6037, 0.9547, 1.8188, 1.5352, 1.7416, 1.5352, 0.4661, 1.0714,
        1.6031, 1.7606, 1.6037, 0.8076, 1.8188, 1.7172, 2.5002, 1.7416, 0.6256,
        1.6496, 1.6474, 1.6423, 0.6256, 3.0829, 1.7172, 1.8188, 1.6037, 1.1714],
       grad_fn=<CopySlices>)

In [14]:
(d_train == 0).sum()

tensor(253)

In [15]:
model(x_train).squeeze(-1)[:, :, 0:1].squeeze(-1)

tensor([[0.1813, 0.0900, 0.0484,  ..., 0.0168, 0.0168, 0.0168],
        [0.0292, 0.0119, 0.0106,  ..., 0.0105, 0.0105, 0.0105],
        [0.0414, 0.0153, 0.0124,  ..., 0.0118, 0.0118, 0.0118],
        ...,
        [0.0223, 0.0109, 0.0102,  ..., 0.0101, 0.0101, 0.0101],
        [0.0415, 0.0126, 0.0103,  ..., 0.0100, 0.0100, 0.0100],
        [0.0357, 0.0127, 0.0108,  ..., 0.0106, 0.0106, 0.0106]],
       grad_fn=<SqueezeBackward1>)

In [16]:
indices = torch.nonzero(d_train == 1, as_tuple=True)[0]

# Get the first occurrence
dead_index = indices[50].item()
dead_index

IndexError: index 50 is out of bounds for dimension 0 with size 29

In [None]:
train_predictions = model(x_train).squeeze(-1)[:, :, 0:1].squeeze(-1)
h = train_predictions[dead_index]
t = t_train[dead_index]
print(h, t)

h[t-1]

tensor([0.3354, 0.0651, 0.0542, 0.0346, 0.0192, 0.0094, 0.0049, 0.0042, 0.0071,
        0.0150, 0.0225, 0.0205], grad_fn=<SelectBackward0>) tensor(1)


tensor(0.3354, grad_fn=<SelectBackward0>)

In [None]:
recurrent_predictions[2]

tensor([0.9104, 0.8103, 0.7153, 0.5897, 0.5036, 0.4383, 0.3696, 0.3012, 0.2769,
        0.3584, 0.6184, 0.8763], grad_fn=<SelectBackward0>)

In [None]:
out_risk = recurrent_predictions

current_time = 20

expected_number_of_events = torch.zeros(out_risk.size(0))

for i in range(out_risk.size(0)):
    clamped_time = min(current_time, t_test[i].item())
    expected_number_of_events[i] = out_risk[i, 0:clamped_time].sum()

In [None]:
mask = e_test < current_time
observed_number_of_events = mask.sum(dim=1)

print(observed_number_of_events[6])
print(expected_number_of_events[6])

tensor(2)
tensor(5.4015, grad_fn=<SelectBackward0>)


In [None]:
times = np.arange(1, max(t_test).long(), 1)
times

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])