In [None]:
# import sys
# from google.colab import drive

# drive.mount('/content/gdrive', force_remount=True)
# SRC_PATH = '/content/gdrive/MyDrive/MP FEB/Colab'
# sys.path.append(SRC_PATH)

# !pip install wandb -qqq
# import wandb
# wandb.login()

# wandb.init(
#     # Set the project where this run will be logged
#     project="baseline_poc", 
#     # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
#     #name="experiment 1"
#     # Track hyperparameters and run metadata
#     #config={
#       #"learning_rate": 0.02,
#       #"architecture": "CNN",
#       #"dataset": "CIFAR-100",
#       #"epochs": 10,}
#     )

In [None]:
import numpy as np
import pandas as pd
import torch

import matplotlib.pyplot as plt
%matplotlib inline

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU so
# that every `.to(DEVICE)` call in this notebook still works without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1. Import Dataset

In [None]:
from torch.utils.data import DataLoader
from poc_data import PocDataset, display_sample


# Number of cases requested from PocDataset and the mini-batch size.
NUM_CASES = 2 ** 9
BATCH_SIZE = 2 ** 8

# Augmented training dataset wrapped in a shuffling loader.
poc_raw_dataset = PocDataset(num_cases=NUM_CASES, augment=True)
data_loader = DataLoader(
    dataset=poc_raw_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
)

# Pull a single batch for inspection; its fifth element is not used here.
batch_data, batch_data_length, batch_event, batch_tte, _ = next(iter(data_loader))

# Print the shape of the data, length and event elements of that batch.
for batch_element in (batch_data, batch_data_length, batch_event):
  print(batch_element.shape)

### Data Exploration

In [None]:
# Take a look at the longitudinal data: display the first four samples of the batch.
for sample_index in range(4):
  display_sample(batch_data[sample_index],
                 batch_data_length[sample_index],
                 batch_event[sample_index])

### Test Sample

In [None]:
# Build a one-case dataset, fetch its only batch and display the sample.
test_poc_raw_dataset = PocDataset(num_cases=1)
test_data_loader = DataLoader(
    dataset=test_poc_raw_dataset,
    batch_size=1,
    pin_memory=True,
)
(
    test_batch_data,
    test_batch_data_length,
    test_batch_event,
    test_tte,
    test_meta,
) = next(iter(test_data_loader))
display_sample(test_batch_data[0], test_batch_data_length[0], test_batch_event[0])

# 2. Hyperparameters

In [None]:
from torch.nn import MSELoss
from tqdm import trange, tqdm

from torch.optim import Adam

from deepHit import Encoder, CauseSpecificSubnetwork, DeepHit
from baseline_losses import loss_1_batch, loss_2_batch

# --- Training schedule -------------------------------------------------------
NUM_EPOCHS = 1

# Learning rates handed to the Adam optimizers of the encoder and of the
# cause-specific sub-network respectively.
LEARNING_RATE_ENCODER = 0.001
LEARNING_RATE_CAUSESS = 0.0005

# --- Loss configuration ------------------------------------------------------
# Multipliers applied to loss_1_batch / loss_2_batch before they are summed.
LOSS_1_AMPLIFIER = 1
LOSS_2_AMPLIFIER = 1
SIGMA = 0.1  # forwarded to loss_2_batch

# --- Optional validation round inside the training loop ----------------------
RUN_VALIDATION_ROUND = False
VAL_NUM_CASES_RUNTIME = BATCH_SIZE  # used as both dataset size and batch size

# --- Model dimensions --------------------------------------------------------
input_size = 5
output_size = input_size
hidden_size_encoder = 512
context_size = 256
hidden_cause_size = 512
MAX_LENGTH = 2 * 12
NUM_CAUSES = 3

# 3. Defining The Model

In [None]:
# Build both sub-networks on the target device, then combine them into DeepHit.
encoder = Encoder(
    input_size,
    hidden_size_encoder,
    context_size,
).to(DEVICE)
causess = CauseSpecificSubnetwork(
    context_size,
    hidden_cause_size,
    input_size,
    MAX_LENGTH,
    NUM_CAUSES,
).to(DEVICE)
DHT = DeepHit(encoder, causess, DEVICE)

# Initialize one Adam optimizer per sub-network, each with its own learning rate.
optimizer_encoder = Adam(params=encoder.parameters(), lr=LEARNING_RATE_ENCODER)
optimizer_causess = Adam(params=causess.parameters(), lr=LEARNING_RATE_CAUSESS)

# 4. Training the model

In [None]:
# start training
# Later cells switch the model to eval mode via DHT.eval(). Put it back into
# training mode explicitly so that re-running this cell always trains in
# training mode, even when RUN_VALIDATION_ROUND is False.
DHT.train()

for epoch in trange(NUM_EPOCHS):

  # Running sum of the detached batch losses of this epoch (only consumed by
  # the commented-out wandb logging below).
  epoch_loss = 0

  for batch_number, data in enumerate(data_loader):

    # Clear the gradients of both sub-networks before the new backward pass.
    optimizer_encoder.zero_grad()
    optimizer_causess.zero_grad()

    # Unpack the batch (its fifth element is unused) and move it to DEVICE.
    batch_data, batch_data_length, batch_event, batch_tte, _ = data
    batch_data = batch_data.to(DEVICE)
    batch_data_length = batch_data_length.to(DEVICE)
    batch_event = batch_event.to(DEVICE)
    batch_tte = batch_tte.to(DEVICE)

    first_hitting_time_batch = DHT(batch_data, batch_data_length)

    loss1 = LOSS_1_AMPLIFIER*loss_1_batch(first_hitting_time_batch, batch_event, batch_tte, MAX_LENGTH, DEVICE)
    loss2 = LOSS_2_AMPLIFIER*loss_2_batch(first_hitting_time_batch, batch_event, batch_tte, NUM_CAUSES, MAX_LENGTH, SIGMA, DEVICE)

    # The total loss is the sum of the two amplified loss terms.
    batch_loss = loss1 + loss2
    batch_loss.backward()

    # Detach so the accumulated value does not keep the autograd graph alive.
    epoch_loss += batch_loss.detach()

    #wandb.log({"train_loss1": loss1.item(), "train_loss2": loss2.item()})

    optimizer_encoder.step()
    optimizer_causess.step()

  if RUN_VALIDATION_ROUND:
    # validating round
    DHT.eval()

    with torch.no_grad():
      # A new validation set is generated every epoch and scored as one batch.
      val_poc_raw_dataset = PocDataset(num_cases=VAL_NUM_CASES_RUNTIME)
      val_data_loader = DataLoader(val_poc_raw_dataset, batch_size=VAL_NUM_CASES_RUNTIME)
      val_batch_data, val_data_length, val_batch_event, val_batch_tte, _ = next(iter(val_data_loader))
      val_batch_data = val_batch_data.to(DEVICE)
      val_data_length = val_data_length.to(DEVICE)
      val_batch_event = val_batch_event.to(DEVICE)
      val_batch_tte = val_batch_tte.to(DEVICE)

      val_first_hitting_time_batch = DHT(val_batch_data, val_data_length)

      # NOTE(review): validation losses are divided by the number of cases
      # while the training losses above are not - confirm this is intended.
      val_loss1 = LOSS_1_AMPLIFIER*loss_1_batch(val_first_hitting_time_batch, val_batch_event, val_batch_tte, MAX_LENGTH, DEVICE)/VAL_NUM_CASES_RUNTIME
      val_loss2 = LOSS_2_AMPLIFIER*loss_2_batch(val_first_hitting_time_batch, val_batch_event, val_batch_tte, NUM_CAUSES, MAX_LENGTH, SIGMA, DEVICE)/VAL_NUM_CASES_RUNTIME

      #wandb.log({"val_loss1": val_loss1.item(), "val_loss2": val_loss2.item()})
      ##wandb.log({"train_epoch_loss" : epoch_loss.item(), "val_epoch_loss" : val_loss1.item() + val_loss2.item(),"epoch": epoch})

    # Back to training mode for the next epoch.
    DHT.train()
    # end validating round

  #PATH = "/content/gdrive/MyDrive/MP FEB/Colab/models/baseline_model_v2.pth"
  #torch.save(DHT.state_dict(), PATH)

#wandb.finish()

# 5. Testing the Model

### Predicting the test sample

In [None]:
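# NOTE: the model built in this notebook is named `DHT`; replace `DDHT` with `DHT` before un-commenting.
#PATH = "C:/Users/marij/Desktop/THESIS ECO/DDHT/DDHT_pytorch/models/model_v11.pth"
#torch.save(DDHT.state_dict(), PATH)
#DDHT.load_state_dict(torch.load(PATH))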

In [None]:
# Build a new one-case dataset and fetch its only batch through a DataLoader.
test_poc_raw_dataset = PocDataset(num_cases=1)
test_data_loader = DataLoader(
    dataset=test_poc_raw_dataset,
    batch_size=1,
    pin_memory=True,
)
(
    test_batch_data,
    test_batch_data_length,
    test_batch_event,
    test_batch_tte,
    test_batch_meta,
) = next(iter(test_data_loader))

display_sample(test_batch_data[0], test_batch_data_length[0], test_batch_event[0])
# For sample 0: the first five entries along axis 1, restricted to index 3 and
# onward along axis 2.
print(test_batch_data[0, :5, 3:])

In [None]:
from utils import plot_fht_and_cif_baseline

DHT.eval();

# Inference only: disable gradient tracking, as the validation round inside the
# training loop already does, so no autograd graph is built for this forward pass.
with torch.no_grad():
  test_first_hitting_time = DHT(test_batch_data.to(DEVICE), test_batch_data_length.to(DEVICE))

# argmax() without a dim returns the index into the flattened output; it is
# decoded below as event = index // MAX_LENGTH and time = index % MAX_LENGTH.
# This only identifies one sample because the batch holds a single case.
test_first_hitting_time_argmax = test_first_hitting_time.argmax().item()
model_event_prediction, model_tte_prediction = divmod(test_first_hitting_time_argmax, MAX_LENGTH)
# NOTE(review): the fixed-test-set cell reports `model_tte_prediction + 1` while
# this cell reports the raw index - confirm which time convention is intended.
print("the model predicts at time %d the event %d" % (model_tte_prediction, model_event_prediction))

plot_fht_and_cif_baseline(test_first_hitting_time[0], MAX_LENGTH)

### Scoring our model - random test set

In [None]:
VAL_NUM_CASES = 2**6

val_poc_raw_dataset = PocDataset(num_cases=VAL_NUM_CASES)
# The whole validation set must come back as one batch, because the scoring
# cell divides the losses by VAL_NUM_CASES. Tie the batch size to that constant
# instead of repeating the literal so the two cannot drift apart.
val_data_loader = torch.utils.data.DataLoader(val_poc_raw_dataset, batch_size=VAL_NUM_CASES)
val_batch_data, val_batch_data_length, val_batch_event, val_batch_tte, val_batch_meta = next(iter(val_data_loader))

# Move the elements needed for scoring to DEVICE (the meta element stays put).
val_batch_data = val_batch_data.to(DEVICE)
val_batch_data_length = val_batch_data_length.to(DEVICE)
val_batch_event = val_batch_event.to(DEVICE)
val_batch_tte = val_batch_tte.to(DEVICE)


In [None]:
DHT.eval()

# Scoring only: run the forward pass and both losses without gradient tracking,
# matching the validation round inside the training loop.
with torch.no_grad():
  val_first_hitting_time_batch = DHT(val_batch_data, val_batch_data_length)
  # Each amplified loss is divided by the number of validation cases.
  val_loss1 = LOSS_1_AMPLIFIER*loss_1_batch(val_first_hitting_time_batch, val_batch_event, val_batch_tte, MAX_LENGTH, DEVICE)/VAL_NUM_CASES
  val_loss2 = LOSS_2_AMPLIFIER*loss_2_batch(val_first_hitting_time_batch, val_batch_event, val_batch_tte, NUM_CAUSES, MAX_LENGTH, SIGMA, DEVICE)/VAL_NUM_CASES


print("val_loss1=", val_loss1.item())
print("val_loss2=", val_loss2.item())


### Scoring our model - fixed test set

In [None]:
from utils import plot_fht_and_cif

# One-case dataset built with test_set=True and repays=False.
# The original chained assignment also rebound `poc_raw_dataset` (the training
# dataset) and stored a dataset under a "loader" name; neither is done here.
test_poc_raw_dataset = PocDataset(num_cases=1, test_set=True, repays=False)

# The dataset is iterated directly, without a DataLoader, so the sample has no
# batch dimension; unsqueeze(0) adds one before the tensors go to the model.
test_batch_data, test_batch_data_length, test_batch_event, test_batch_tte, test_meta = next(iter(test_poc_raw_dataset))
test_batch_data = test_batch_data.unsqueeze(0).to(DEVICE)
test_batch_data_length = test_batch_data_length.unsqueeze(0).to(DEVICE)
test_batch_event = test_batch_event.unsqueeze(0).to(DEVICE)

DHT.eval()

# Inference only: no gradient tracking needed.
with torch.no_grad():
  test_first_hitting_time = DHT(test_batch_data, test_batch_data_length)
print("sample has length %d" % test_batch_data_length[0])
print("sample will experience event %d at time %d, but shows event %d" % (test_meta['ground_truth_event'], test_meta['age'], test_batch_event[0]))


# Decode the flattened argmax as event = index // MAX_LENGTH and
# time index = index % MAX_LENGTH.
test_first_hitting_time_argmax = test_first_hitting_time.argmax().item()
model_event_prediction, model_tte_prediction = divmod(test_first_hitting_time_argmax, MAX_LENGTH)
# NOTE(review): the time is reported as index + 1 here, whereas the random
# test-sample cell prints the raw index - confirm which convention is intended.
print("the model predicts the event %d at time %d" % (model_event_prediction, model_tte_prediction + 1))

# plot_fht_and_cif_baseline is imported in the earlier test-sample prediction cell.
plot_fht_and_cif_baseline(test_first_hitting_time[0], MAX_LENGTH)

In [None]:
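# NOTE: the model built in this notebook is named `DHT`; replace `DDHT` with `DHT` before un-commenting.
# PATH = "/content/gdrive/MyDrive/MP FEB/Colab/models/model_v10.pth"
# torch.save(DDHT.state_dict(), PATH)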