Compare AdamW reg parameter for a fixed architecture that overfits with ADAM. 

## Init

In [15]:
import torch
import torch.nn as nn
import math
import ukko 
import importlib
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.loss.weibull import neg_log_likelihood, log_hazard, survival_function
from torchsurv.metrics.brier_score import BrierScore
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc
#from torchsurv.stats.kaplan_meier import KaplanMeierEstimator
# for cuda cleanups:
import gc

print("Libraries loaded")

Libraries loaded


### Custom function definitions

In [24]:
def plot_losses(train_losses, val_losses, title: str = "Cox") -> None:

    #train_losses = torch.stack(train_losses) / train_losses[0]
    #val_losses = torch.stack(val_losses) / val_losses[0]
    train_losses = np.array(train_losses)
    val_losses = np.array(val_losses)
    train_losses = train_losses / train_losses[0]
    val_losses = val_losses / val_losses[0]

    plt.plot(train_losses, label="training")
    plt.plot(val_losses, label="validation")
    plt.legend()
    plt.xlabel("Epochs")
    plt.ylabel("Normalized loss")
    plt.title(title)
    plt.yscale("log")
    plt.show()


class ukkosurv_dataset(Dataset):
    """ "Custom dataset for ukko-torcsurv use in df format"""

    # defining values in the constructor
    def __init__(self, df: pd.DataFrame):
        #self.df = df
        df_x, data_3d = ukko.utils.convert_to_3d_df(df.iloc[:,3:].fillna(-1))
        df_y = df_train.iloc[:,:3]
        
        self.df_y = df_y        # Dataframe with survival data, e.g. OSS_status, OSS_days
        self.data_3d = data_3d  # numpy array with 3D feature data: patients, features, time 


    # Getting data size/length
    def __len__(self):
        return len(self.data_3d)

    # Getting the data samples
    def __getitem__(self, idx):
        y = self.df_y.iloc[idx,:]
        # Targets
        event = torch.tensor(y["OSS_status"]).bool()
        time = torch.tensor(y["OSS_days"]).float()
        # Predictors
        x = torch.tensor(self.data_3d[idx,:,:]).float()
        return x, (event, time)

In [116]:
# Define training function
import copy
def train_model_simple(
            model,
            dataloader_train,
            dataloader_val,
            optimizer = None,
            n_epochs = 100,
            learning_rate = 0.01,
            device='cuda'
        ):

    dtype=torch.float32
    
    # Initialize optimizer if not provided
    #if optimizer is None:
    #    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Initiate empty list to store the loss on the train and validation sets
    train_losses = []
    val_losses = []
    best_epoch = 0
    best_loss = float('inf')
    best_model_state = None
    
    # Get device and move model
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #model = model.to(device=device, dtype=dtype)
    
    # training loop
    for epoch in range(n_epochs):
        epoch_loss = 0.0 #torch.tensor(0.0)
        #model = model.to(dtype=dtype, device=device) #
        model.train()
        #print(f"Model start   : {list(model.parameters())[3]}")
        for i, batch in enumerate(dataloader_train):
            x, (event, time) = batch
            #x = x.to(dtype=dtype, device=device, non_blocking=True)
            #event = event.to(device, non_blocking=True)
            #time = time.to(device, non_blocking=True)
            
            optimizer.zero_grad()
            log_hz, feature_weights, time_weights = model(x)  # shape = (batchsize, 1)
            loss = neg_partial_log_likelihood(log_hz, event, time, reduction="mean")#.to(dtype=dtype, device=device, non_blocking=True)
            #print(f"loss dtype: {loss.dtype}")
            #print(f"loss cuda:  {loss.is_cuda}")
            #print(f"weights dtype: {feature_weights.dtype, time_weights.dtype}")
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() #.detach().to("cpu").item()
            print(f"Epoch {epoch}, Batch {i}, Loss: {loss.item():.4f}")

            # check model paramters are updating
            #print(f"Model training: {model(x)[0][:3]}")
            #print(f"Model training: {list(model.parameters())[3]}")
            #for name, param in model.named_parameters():
            #  print(f"{name} mean: {param.data.mean().item()}")

            # Free up memory for unused tensors
            #del x, event, time, log_hz, feature_weights, time_weights, loss
            
            #gc.collect()
            #torch.cuda.empty_cache()
        
        # Reccord loss on train and test sets
        epoch_loss /= i + 1
        train_losses.append(epoch_loss)
        #model.eval()
        #print("Eval mode")
        with torch.no_grad():
            x, (event, time) = next(iter(dataloader_val))
            #x = x.to(device, non_blocking=True)
            #event = event.to(device, non_blocking=True)
            #time = time.to(device, non_blocking=True)
            log_hz, feature_weights, time_weights = model(x)
            #print(f"log_hz: {log_hz}")
            val_loss = neg_partial_log_likelihood(log_hz, event, time, reduction="mean").item() #.detach().to("cpu")
            print(f"Validation loss: {val_loss}")
            val_losses.append(val_loss)
            # Save best model based on validation loss
            if val_loss < best_loss:
                best_epoch = epoch
                best_loss = val_loss
                best_model_state = copy.deepcopy(model.to('cpu').state_dict())
        
            # check model paramters are updating
            #print(f"Model val     : {model(x)[0][:3]}")
            #print(f"Model val     : {list(model.parameters())[3]}")
            #for name, param in model.named_parameters():
            #  print(f"{name} mean: {param.data.mean().item()}")



        del x, event, time, log_hz, feature_weights, time_weights    
        torch.cuda.empty_cache()
        
        # Display progress
        if epoch % 1 == 0: #(n_epochs // 5) == 0:
            print(f"    Epoch: {epoch:03}, Training loss: {train_losses[-1]:0.2f}, Validation loss: {val_losses[-1]:0.2f}")
    
    # Load best model if validation was used
    if dataloader_val and best_model_state:
        model.load_state_dict(best_model_state)

    gc.collect()
    
    return model, train_losses, val_losses, best_loss, best_epoch #if dataloader_val else avg_train_loss.item()

## Load data

In [108]:
#Load tidy data
print("Loading tidy data")
df_xy = pd.read_csv("data/df_xy_synth_v1.csv")

# create train, validation and test datasets: IMPUTE nan: -1
df_train = df_xy.fillna(-1)
df_test = df_train.sample(n=1, random_state=42)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(n=400, random_state=42)
df_train = df_train.drop(df_val.index)

print(f"Train: {df_train.shape}")
print(f"Val  : {df_val.shape}")
print(f"Test : {df_test.shape}")

Loading tidy data
Train: (599, 273)
Val  : (400, 273)
Test : (1, 273)


# Ukko

### Dataloaders: train, val, test

In [None]:
importlib.reload(ukko)
importlib.reload(ukko.utils)

# Dataloader
BATCH_SIZE = 300#512
dataloader_train = DataLoader(
    ukkosurv_dataset(df_train), batch_size=BATCH_SIZE, shuffle=True
)
dataloader_val = DataLoader(
    #ukkosurv_dataset(df_val), batch_size=len(df_val), shuffle=False
    ukkosurv_dataset(df_train), batch_size=len(df_train), shuffle=False
)
dataloader_test = DataLoader(
    ukkosurv_dataset(df_test), batch_size=len(df_test), shuffle=False
)

### Define model

In [110]:
importlib.reload(ukko.core)

# Get feature and time dimensions
x, (event, time) = next(iter(dataloader_train))
num_features, num_timepoints = x.size(1), x.size(2)
print(f"Number of features: {num_features}, Number of timepoints: {num_timepoints}")

# Initialize model
# DualAttentionRegressor1(self, n_features, time_steps, d_model=128, n_heads=8, dropout=0.1, n_modules=1)
def initmodel():
  model = ukko.core.DualAttentionRegressor1(
    n_features=num_features,
    time_steps=num_timepoints,
    d_model=8,
    n_heads=4,
    n_kv_heads=2,
    dropout=0.2,
    n_modules=2
  )
  return model

model = initmodel()
model

Number of features: 10, Number of timepoints: 27


DualAttentionRegressor1(
  (modules_list): ModuleList(
    (0-1): 2 x DualAttentionModule(
      (input_projection): Linear(in_features=1, out_features=8, bias=True)
      (pos_encoder): PositionalEncoding()
      (feature_attention): GroupedQueryAttention(
        (W_q): Linear(in_features=8, out_features=8, bias=True)
        (W_k): Linear(in_features=8, out_features=4, bias=True)
        (W_v): Linear(in_features=8, out_features=4, bias=True)
        (W_o): Linear(in_features=8, out_features=8, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (feature_norm): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (feature_ff): FeedForward(
        (linear1): Linear(in_features=8, out_features=2048, bias=True)
        (linear2): Linear(in_features=2048, out_features=8, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (activation): LeakyReLU(negative_slope=0.01)
      )
      (feature_ff_norm): LayerNorm((8,), eps=1e-05, elementwise_affine

## Train model

In [46]:
for name, param in model.named_parameters():
    print(f"{name}: {param.data.mean().item():.6f}")

modules_list.0.input_projection.weight: 0.354489
modules_list.0.input_projection.bias: -0.367049
modules_list.0.feature_attention.W_q.weight: -0.035006
modules_list.0.feature_attention.W_q.bias: -0.037683
modules_list.0.feature_attention.W_k.weight: -0.037529
modules_list.0.feature_attention.W_k.bias: -0.033596
modules_list.0.feature_attention.W_v.weight: -0.038815
modules_list.0.feature_attention.W_v.bias: -0.015222
modules_list.0.feature_attention.W_o.weight: -0.011895
modules_list.0.feature_attention.W_o.bias: 0.157283
modules_list.0.feature_norm.weight: 1.000000
modules_list.0.feature_norm.bias: 0.000000
modules_list.0.feature_ff.linear1.weight: -0.000022
modules_list.0.feature_ff.linear1.bias: -0.000965
modules_list.0.feature_ff.linear2.weight: 0.000093
modules_list.0.feature_ff.linear2.bias: -0.003047
modules_list.0.feature_ff_norm.weight: 1.000000
modules_list.0.feature_ff_norm.bias: 0.000000
modules_list.0.time_attention.W_q.weight: 0.031950
modules_list.0.time_attention.W_q.bi

In [79]:
list(model.parameters())[3]

Parameter containing:
tensor([ 0.2570,  0.3088,  0.2686, -0.1744, -0.0968, -0.1702, -0.2498, -0.2376],
       requires_grad=True)

In [None]:
next(iter(model.named_parameters()))[1][:3]

tensor([[ 0.3991],
        [-0.6855],
        [-0.6416]], grad_fn=<SliceBackward0>)

In [51]:
model(x)[0][:3]

tensor([-0.3527, -0.2711, -0.3267], grad_fn=<SliceBackward0>)

In [105]:
next(iter(dataloader_test))

[tensor([[[ -1.0000,  -1.0000,  -1.0000,  ...,  -1.0000,   0.0000,  -1.0000],
          [ -1.0000,  -1.0000,  -1.0000,  ...,  -1.0000,  -1.0000,  -1.0000],
          [ -1.0000,  -1.0000,  -1.0000,  ...,  -1.0000,  -1.0000,   0.3000],
          ...,
          [ -1.0000,  -1.0000,  -1.0000,  ...,  -1.0000,  -1.0000,   2.7000],
          [ -1.0000,  -1.0000,  34.0000,  ...,  -1.0000,  -1.0000,  -1.0000],
          [ -1.0000,  -1.0000,  -1.0000,  ...,  10.0384,   4.1489,  14.7076]],
 
         [[ -1.0000,  -1.0000,  -1.0000,  ...,  -1.0000,   1.0000,  -1.0000],
          [ -1.0000,  -1.0000,  -1.0000,  ...,  -1.0000,  -1.0000,  -1.0000],
          [ -1.0000,  -1.0000,  -1.0000,  ...,  -1.0000,  -1.0000,  -1.0000],
          ...,
          [ -1.0000,  -1.0000,  -1.0000,  ...,  -1.0000,   7.0000,  -1.0000],
          [ -1.0000,  -1.0000,  -1.0000,  ...,  -1.0000,  -1.0000,  -1.0000],
          [110.3000,  -1.0000,  -1.0000,  ...,  35.3383,  10.8326,  16.3405]],
 
         [[ -1.0000,  -1.000

In [117]:
EPOCHS = 100
LEARNING_RATE = 1e-3

# Init optimizer for Cox
# AdamW is generally preferred over Adam for its weight decay regularization
optimizer_noreg = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
optimizer_reg = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)


# Initiate empty list to store the loss on the train and validation sets
train_losses = []
val_losses = []

# Get device and move model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = initmodel()
model = model.to(device)
# Train models
trained_model_reg, train_losses_reg, val_losses_reg, val_loss_reg, best_epoch_i_reg = train_model_simple(
    model=model,
    dataloader_train = dataloader_train,
    dataloader_val = dataloader_val, 
    optimizer=optimizer_reg,
    n_epochs=EPOCHS,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

# model = initmodel().to(device)
# trained_model_noreg, train_losses_noreg, val_losses_noreg, val_loss_noreg, best_epoch_i_noreg = train_model_simple(
#     model=model,
#     dataloader_train = dataloader_train,
#     dataloader_val = dataloader_val, 
#     optimizer=optimizer_noreg,
#     n_epochs=EPOCHS,
#     device='cuda' if torch.cuda.is_available() else 'cpu'
# )


# plot_losses(train_losses_reg, val_losses_reg, "Cox")
# print(f"  Best train and val loss: {min(train_losses_reg):0.3f}, {min(val_losses_reg):0.3f}")

# plot_losses(train_losses_noreg, val_losses_noreg, "Cox")
# print(f"  Best train and val loss: {min(train_losses_noreg):0.3f}, {min(val_losses_noreg):0.3f}")

Epoch 0, Batch 0, Loss: 7.9429
Validation loss: 7.938645362854004
    Epoch: 000, Training loss: 7.94, Validation loss: 7.94
Epoch 1, Batch 0, Loss: 7.9321
Validation loss: 7.929367542266846
    Epoch: 001, Training loss: 7.93, Validation loss: 7.93
Epoch 2, Batch 0, Loss: 7.9429
Validation loss: 7.925448894500732
    Epoch: 002, Training loss: 7.94, Validation loss: 7.93
Epoch 3, Batch 0, Loss: 7.9333
Validation loss: 7.919982433319092
    Epoch: 003, Training loss: 7.93, Validation loss: 7.92
Epoch 4, Batch 0, Loss: 7.9579
Validation loss: 7.92241096496582
    Epoch: 004, Training loss: 7.96, Validation loss: 7.92
Epoch 5, Batch 0, Loss: 7.9281
Validation loss: 7.913344383239746
    Epoch: 005, Training loss: 7.93, Validation loss: 7.91
Epoch 6, Batch 0, Loss: 7.9329
Validation loss: 7.931198596954346
    Epoch: 006, Training loss: 7.93, Validation loss: 7.93
Epoch 7, Batch 0, Loss: 7.9441
Validation loss: 7.916887283325195
    Epoch: 007, Training loss: 7.94, Validation loss: 7.92
E

KeyboardInterrupt: 