Compare the AdamW weight-decay (regularization) parameter on a fixed architecture that overfits when trained with plain Adam.

## Init

In [None]:
import torch
import torch.nn as nn
import math
import ukko 
import importlib
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.loss.weibull import neg_log_likelihood, log_hazard, survival_function
from torchsurv.metrics.brier_score import BrierScore
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc
#from torchsurv.stats.kaplan_meier import KaplanMeierEstimator
# for cuda cleanups:
import gc

print("Libraries loaded")

### Custom function definitions

In [None]:
def plot_losses(train_losses, val_losses, title: str = "Cox") -> None:
    """Plot training and validation loss curves, each scaled by its first value.

    Each curve is divided by its own epoch-0 loss, so both start at 1.0 and can
    be compared on a shared log-scaled y-axis.

    Args:
        train_losses: Per-epoch training losses (anything ``np.array`` accepts).
        val_losses: Per-epoch validation losses.
        title: Title shown above the plot.
    """
    # Normalise both curves before drawing anything.
    series = {}
    for label, losses in (("training", train_losses), ("validation", val_losses)):
        values = np.array(losses)
        series[label] = values / values[0]

    for label, values in series.items():
        plt.plot(values, label=label)
    plt.legend()
    plt.xlabel("Epochs")
    plt.ylabel("Normalized loss")
    plt.title(title)
    plt.yscale("log")
    plt.show()


class ukkosurv_dataset(Dataset):
    """Custom dataset for ukko-torchsurv use, built from a DataFrame.

    Column layout expected in ``df`` (as used by the indexing below):
      * columns 0-2: survival targets; ``__getitem__`` reads ``OSS_status`` and
        ``OSS_days`` from them by name;
      * columns 3+: features, converted to a 3D array by
        ``ukko.utils.convert_to_3d_df``.

    Each item is ``(x, (event, time))``:
      * ``x`` is a float32 tensor of ``data_3d[idx]``;
      * ``event`` is a bool tensor;
      * ``time`` is a float32 tensor.
    """

    # defining values in the constructor
    def __init__(self, df: pd.DataFrame):
        # Missing feature values are imputed with -1 before the 3D conversion.
        df_x, data_3d = ukko.utils.convert_to_3d_df(df.iloc[:, 3:].fillna(-1))
        # Targets come from the SAME frame as the features. This previously read
        # the global ``df_train``, so validation/test datasets got training labels.
        # NOTE(review): row i of df_y is assumed to correspond to data_3d[i];
        # confirm convert_to_3d_df preserves row order and row count.
        df_y = df.iloc[:, :3]

        self.df_y = df_y        # Dataframe with survival data, e.g. OSS_status, OSS_days
        self.data_3d = data_3d  # numpy array with 3D feature data: patients, features, time

    def __len__(self):
        """Return the number of samples (length of the first axis of ``data_3d``)."""
        return len(self.data_3d)

    def __getitem__(self, idx):
        """Return ``(x, (event, time))`` for sample ``idx``."""
        y = self.df_y.iloc[idx, :]
        # Targets
        event = torch.tensor(y["OSS_status"]).bool()
        time = torch.tensor(y["OSS_days"]).float()
        # Predictors
        x = torch.tensor(self.data_3d[idx, :, :]).float()
        return x, (event, time)

In [None]:
# Define training function
import copy


def _train_one_epoch(model, dataloader, optimizer, loss_fn, device, dtype):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss."""
    # Re-enable training behaviour (e.g. dropout) every epoch, because the
    # validation step leaves the model in eval mode.
    model.train()
    total_loss = 0.0
    n_batches = 0
    for x, (event, time) in dataloader:
        x = x.to(device=device, dtype=dtype, non_blocking=True)
        event = event.to(device, non_blocking=True)
        time = time.to(device, non_blocking=True)

        optimizer.zero_grad()
        log_hz, _feature_weights, _time_weights = model(x)  # shape = (batchsize, 1)
        loss = loss_fn(log_hz, event, time, reduction="mean")
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("dataloader_train yielded no batches")
    return total_loss / n_batches


def _evaluate_loss(model, dataloader, loss_fn, device, dtype):
    """Return the loss over every batch of ``dataloader``, computed in a single call.

    Predictions from all batches are concatenated before ``loss_fn`` is applied.
    A partial-likelihood loss such as the Cox one is therefore evaluated over all
    validation samples at once. With a single full-size batch this is identical to
    evaluating that one batch.
    """
    model.eval()
    log_hz_parts, event_parts, time_parts = [], [], []
    with torch.no_grad():
        for x, (event, time) in dataloader:
            x = x.to(device=device, dtype=dtype, non_blocking=True)
            log_hz, _feature_weights, _time_weights = model(x)
            log_hz_parts.append(log_hz)
            event_parts.append(event.to(device, non_blocking=True))
            time_parts.append(time.to(device, non_blocking=True))
        if not log_hz_parts:
            raise ValueError("dataloader_val yielded no batches")
        loss = loss_fn(
            torch.cat(log_hz_parts),
            torch.cat(event_parts),
            torch.cat(time_parts),
            reduction="mean",
        )
    return loss.item()


def train_model_simple(
            model,
            dataloader_train,
            dataloader_val,
            optimizer = None,
            n_epochs = 100,
            learning_rate = 0.01,
            device='cuda',
            loss_fn=None,
        ):
    """Train ``model`` and restore the weights with the lowest validation loss.

    Args:
        model: Module whose forward returns a 3-tuple; only the first element
            (``log_hz``) is fed to the loss.
        dataloader_train: Yields ``x, (event, time)`` batches.
        dataloader_val: Same batch format, or ``None`` to skip validation. In that
            case the final weights are returned, ``val_losses`` is empty,
            ``best_loss`` is ``inf`` and ``best_epoch`` is 0.
        optimizer: Optimizer over ``model``'s parameters. Defaults to
            ``torch.optim.Adam(model.parameters(), lr=learning_rate)``.
        n_epochs: Number of passes over ``dataloader_train``.
        learning_rate: Used only when ``optimizer`` is None.
        device: Device to train on. Falls back to CPU when a CUDA device is
            requested but CUDA is not available.
        loss_fn: ``loss_fn(log_hz, event, time, reduction="mean")`` returning a
            scalar tensor. Defaults to torchsurv's ``neg_partial_log_likelihood``.

    Returns:
        ``(model, train_losses, val_losses, best_loss, best_epoch)``.

    Raises:
        ValueError: If a dataloader yields no batches.
    """
    dtype = torch.float32
    if loss_fn is None:
        loss_fn = neg_partial_log_likelihood

    # Honour the requested device (it used to be silently overridden).
    device = torch.device(device)
    if device.type == "cuda" and not torch.cuda.is_available():
        device = torch.device("cpu")
    # Module.to() moves parameters in place, so an optimizer built beforehand
    # still refers to the same Parameter objects.
    model = model.to(device=device, dtype=dtype)

    # Initialize optimizer if not provided
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    train_losses = []
    val_losses = []
    best_epoch = 0
    best_loss = float('inf')
    best_model_state = None

    for epoch in range(n_epochs):
        train_losses.append(
            _train_one_epoch(model, dataloader_train, optimizer, loss_fn, device, dtype)
        )

        if dataloader_val is None:
            print(f"    Epoch: {epoch:03}, Training loss: {train_losses[-1]:0.2f}")
            continue

        val_loss = _evaluate_loss(model, dataloader_val, loss_fn, device, dtype)
        val_losses.append(val_loss)
        # Save best model based on validation loss
        if val_loss < best_loss:
            best_epoch = epoch
            best_loss = val_loss
            # Snapshot a CPU copy without moving the live model off its device.
            best_model_state = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }

        print(f"    Epoch: {epoch:03}, Training loss: {train_losses[-1]:0.2f}, Validation loss: {val_losses[-1]:0.2f}")

    # Load best model if validation was used
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # One cleanup at the end; per-batch gc.collect() only slowed training down.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

    return model, train_losses, val_losses, best_loss, best_epoch

## Load data

In [None]:
# Read the tidy (synthetic) dataset from disk
print("Loading tidy data")
df_xy = pd.read_csv("data/df_xy_synth_v1.csv")

# Build train / validation / test splits; missing values are imputed with -1 first.
# 200 rows are held out for testing, then 200 validation rows are drawn from the
# remainder; everything left over is training data.
df_train = df_xy.fillna(value=-1)
df_test = df_train.sample(n=200, random_state=42)
df_train = df_train.drop(index=df_test.index)
df_val = df_train.sample(n=200, random_state=42)
df_train = df_train.drop(index=df_val.index)

print(f"Train: {df_train.shape}")
print(f"Val  : {df_val.shape}")
print(f"Test : {df_test.shape}")

# Ukko

### Dataloaders: train, val, test

In [None]:
# Pick up edits to the ukko package and its utils submodule without restarting the kernel
importlib.reload(ukko)
importlib.reload(ukko.utils)

# Training batches are shuffled. Validation and test use a batch size equal to their
# row count, so each is normally served as one batch covering the whole split.
BATCH_SIZE = 600
dataloader_train = DataLoader(ukkosurv_dataset(df_train), batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(ukkosurv_dataset(df_val), batch_size=len(df_val), shuffle=False)
dataloader_test = DataLoader(ukkosurv_dataset(df_test), batch_size=len(df_test), shuffle=False)

### Define model

In [None]:
importlib.reload(ukko.core)

# Peek at one training batch to read the feature and time dimensions of the input
x, (event, time) = next(iter(dataloader_train))
num_features, num_timepoints = x.shape[1:3]
print(f"Number of features: {num_features}, Number of timepoints: {num_timepoints}")

# Initialize model
# DualAttentionRegressor1(self, n_features, time_steps, d_model=128, n_heads=8, dropout=0.1, n_modules=1)
# NOTE(review): the signature above may be outdated; n_kv_heads is also passed below.
def initmodel(**overrides):
    """Build a fresh, untrained ``ukko.core.DualAttentionRegressor1`` for this data.

    The input dimensions come from the module-level ``num_features`` and
    ``num_timepoints``. Any keyword argument overrides the matching default
    hyper-parameter, e.g. ``initmodel(dropout=0.0)``. Calling with no arguments
    gives the same configuration as before.

    Returns:
        A new ``DualAttentionRegressor1`` instance.
    """
    config = dict(
        n_features=num_features,
        time_steps=num_timepoints,
        d_model=8,
        n_heads=4,
        n_kv_heads=4,
        dropout=0.2,
        n_modules=2,
    )
    config.update(overrides)
    return ukko.core.DualAttentionRegressor1(**config)


model = initmodel()
model

## Train model

In [None]:
EPOCHS = 10
LEARNING_RATE = 1e-3
# Re-seeded before each run so both models start from identical initial weights
# and only the optimizer differs between the two runs.
SEED = 42

# Initiate empty list to store the loss on the train and validation sets
train_losses = []
val_losses = []

# Get device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Init optimizer for Cox
# AdamW is generally preferred over Adam for its weight decay regularization.
# Each optimizer MUST be built from the parameters of the model it trains. The
# previous version built both optimizers from an earlier model instance, so
# optimizer.step() never updated the models actually being trained.

# Regularized run: AdamW with decoupled weight decay
torch.manual_seed(SEED)
model = initmodel().to(device)
optimizer_reg = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-1)
trained_model_reg, train_losses_reg, val_losses_reg, val_loss_reg, best_epoch_i_reg = train_model_simple(
    model=model,
    dataloader_train=dataloader_train,
    dataloader_val=dataloader_val,
    optimizer=optimizer_reg,
    n_epochs=EPOCHS,
    device=device.type,
)

# Unregularized run: plain Adam, same initial weights
torch.manual_seed(SEED)
model = initmodel().to(device)
optimizer_noreg = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
trained_model_noreg, train_losses_noreg, val_losses_noreg, val_loss_noreg, best_epoch_i_noreg = train_model_simple(
    model=model,
    dataloader_train=dataloader_train,
    dataloader_val=dataloader_val,
    optimizer=optimizer_noreg,
    n_epochs=EPOCHS,
    device=device.type,
)


plot_losses(train_losses_reg, val_losses_reg, "Cox, AdamW (weight_decay=0.1)")
print(f"  AdamW best train and val loss: {min(train_losses_reg):0.3f}, {min(val_losses_reg):0.3f}")

plot_losses(train_losses_noreg, val_losses_noreg, "Cox, Adam (no weight decay)")
print(f"  Adam best train and val loss: {min(train_losses_noreg):0.3f}, {min(val_losses_noreg):0.3f}")