In [1]:
from src.data import _load_readmission_dataset
from src.helper_functions import _prepare_rt_tensors
from src.models import *
from src.losses import *
from src.training import *
from src.metrics import *

In [2]:
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
x, t, e, d = _load_readmission_dataset(sequential=True)

  data['max_time'] = data.groupby('id')['t.stop'].transform(max)


In [4]:
(torch.tensor(t) == 0).sum() ## another problem. 

tensor(0)

In [5]:
readmission_tensor = _prepare_rt_tensors(x, t, e, d)
locals().update(readmission_tensor) # create variables from dictionary

  x = torch.tensor(x, dtype=torch.float32)


In [6]:
# model and training parameters

input_size = len(x[0]) 
output_size = int(max(t))
hidden_size = 8   # Number of units in the RNN layer

model = SimpleRNN(input_size, hidden_size, output_size, 2)

# Instantiate the model

num_epochs = 3000
patience = 3  # Number of epochs to wait for improvement before stopping
best_val_loss = float('inf')
wait = 0
loss_function = recurrent_terminal_loss

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(),  lr=1e-4)

In [7]:
#prepare data loaders

train_dataset = torch.utils.data.TensorDataset(x_train, t_train, e_train, d_train)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)

val_dataset = torch.utils.data.TensorDataset(x_val, t_val, e_val, d_val)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32)

In [8]:
train_validate_rt_model(model, train_dataloader, val_dataloader, loss_function, optimizer, num_epochs, patience, print_every=10)

Epoch: 0   Training loss: 4.9917   Validation loss: 4.7852


Epoch: 10   Training loss: 4.8317   Validation loss: 4.6499
Epoch: 20   Training loss: 4.6845   Validation loss: 4.5214
Epoch: 30   Training loss: 4.5485   Validation loss: 4.3997
Epoch: 40   Training loss: 4.4232   Validation loss: 4.2852
Epoch: 50   Training loss: 4.3091   Validation loss: 4.1791
Epoch: 60   Training loss: 4.2069   Validation loss: 4.0828
Epoch: 70   Training loss: 4.1170   Validation loss: 3.9968
Epoch: 80   Training loss: 4.0389   Validation loss: 3.9211
Epoch: 90   Training loss: 3.9713   Validation loss: 3.8546
Epoch: 100   Training loss: 3.9127   Validation loss: 3.7963
Epoch: 110   Training loss: 3.8618   Validation loss: 3.7451
Epoch: 120   Training loss: 3.8173   Validation loss: 3.6998
Epoch: 130   Training loss: 3.7781   Validation loss: 3.6594
Epoch: 140   Training loss: 3.7434   Validation loss: 3.6232
Epoch: 150   Training loss: 3.7125   Validation loss: 3.5906
Epoch: 160   Training loss: 3.6847   Validation loss: 3.5610
Epoch: 170   Training loss: 3.659

In [9]:
test_predictions = model(x_test).squeeze(-1)

survival_predictions = test_predictions[:, :, 0:1].squeeze(-1)
recurrent_predictions = test_predictions[:, :, 1:2].squeeze(-1)

In [10]:
# calculate_survival_metrics(survival_predictions, d_train, t_train, d_test, t_test, d_val, t_val)

# Experiements

In [11]:
(d_train == 0).sum()

tensor(253)

In [12]:
model(x_train).squeeze(-1)[:, :, 0:1].squeeze(-1)

tensor([[0.5164, 0.1291, 0.0471, 0.0295, 0.0217, 0.0184],
        [0.0741, 0.0287, 0.0213, 0.0194, 0.0167, 0.0146],
        [0.1041, 0.0386, 0.0267, 0.0219, 0.0178, 0.0155],
        ...,
        [0.0311, 0.0124, 0.0101, 0.0094, 0.0084, 0.0079],
        [0.1123, 0.0360, 0.0260, 0.0203, 0.0133, 0.0095],
        [0.0487, 0.0153, 0.0106, 0.0099, 0.0094, 0.0093]],
       grad_fn=<SqueezeBackward1>)

In [20]:
indices = torch.nonzero(d_train == 1, as_tuple=True)[0]

# Get the first occurrence
dead_index = indices[0].item()
dead_index

7

In [21]:
train_predictions = model(x_train).squeeze(-1)[:, :, 0:1].squeeze(-1)
h = train_predictions[dead_index]
t = t_train[dead_index]
print(h, t)

h[t-1]

tensor([0.1123, 0.0360, 0.0260, 0.0203, 0.0133, 0.0095],
       grad_fn=<SelectBackward0>) tensor(5)


tensor(0.0133, grad_fn=<SelectBackward0>)

In [15]:
out_risk = recurrent_predictions

current_time = 20

expected_number_of_events = torch.zeros(out_risk.size(0))

for i in range(out_risk.size(0)):
    clamped_time = min(current_time, t_test[i].item())
    expected_number_of_events[i] = out_risk[i, 0:clamped_time].sum()

print(expected_number_of_events)


tensor([2.2433, 0.7087, 1.3932, 2.5085, 1.4916, 2.6302, 1.4803, 2.0606, 0.6579,
        0.9303, 0.8624, 2.4814, 2.2433, 1.4750, 2.8014, 0.5470, 0.4347, 2.0606,
        1.8625, 2.6302, 1.3541, 1.4002, 1.8625, 1.3940, 2.6306, 2.0962, 2.9550,
        1.6421, 1.2247, 2.0962, 1.4750, 0.8424, 0.8012, 1.4637, 0.6287, 2.0962,
        1.4803, 0.4626, 1.6421, 1.8625, 0.9303, 1.2525, 0.7703, 0.8021, 1.2525,
        1.8700, 0.4592, 1.4750, 0.6294, 1.5826, 0.8624, 2.0967, 0.8624, 0.9757,
        1.3598, 1.4637, 0.8624, 1.5826, 1.2525, 1.3541, 1.2525, 0.7877, 0.7945,
        0.9303, 1.8700, 1.4637, 0.7765, 2.2433, 1.8625, 3.0073, 1.3541, 0.4511,
        1.4257, 0.8659, 1.4750, 0.4511, 1.9632, 1.2525, 1.5826, 0.9256, 0.7765],
       grad_fn=<CopySlices>)


In [16]:
mask = e_test < current_time
observed_number_of_events = mask.sum(dim=1)

print(observed_number_of_events[4])
print(expected_number_of_events[4])

tensor(5)
tensor(1.4916, grad_fn=<SelectBackward0>)


In [17]:
h = out_risk[2]
e = e_test[2]
t = t_test[2]

mask = torch.ones_like(h[0:t], dtype=torch.bool)
# Set mask to False for time points where events occurred
mask[e[e < t]] = False
print(mask)
# Calculate the negative log-likelihood for both the event occurrences and non-occurrences
-1 * (torch.sum(torch.log(h[0:t][e[e < t]])) + torch.sum(torch.log(1 - h[0:t][mask])))


tensor([ True, False])


tensor(1.9443, grad_fn=<MulBackward0>)

In [18]:
print("h", h)
print("e, t", e, t)
print(h[0:t][e[e < t]])

h tensor([0.7703, 0.6229, 0.5700, 0.6658, 0.7384, 0.7658],
       grad_fn=<SelectBackward0>)
e, t tensor([  1,   2,   2,   2, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100, 100, 100, 100, 100, 100]) tensor(2)
tensor([0.6229], grad_fn=<IndexBackward0>)


In [19]:
import pandas as pd
data = pd.read_csv('data/readmission.csv', delimiter=';')
torch.tensor(0.0)

tensor(0.)