### Script for running custom 5-fold cross-validation for the VAE-MLP model
### running each of the best VAE hyperparameter sets (6 in `best_vae_hyperparams`) and the best MLP hyperparameter sets (12 in `best_ten_MLP_params`) for each fold
### selecting the best-SSIM VAE and the best test-AUC MLP model for each fold


In [None]:
# torch for data loading and for model building
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, random_split
from torchsummary import summary
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from sklearn.model_selection import StratifiedKFold


# Data preprocessing
import pandas as pd
import json
import os
import sys
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, balanced_accuracy_score, confusion_matrix
from scipy.spatial.distance import cdist
import nibabel as nib
import pickle


# Visualisation
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
from PIL import Image
from pylab import rcParams

# Maths
import math
import numpy as np
from sklearn import metrics

# Other
from time import perf_counter
from collections import Counter,OrderedDict
import random
import warnings
import time
import wandb

# import functions from other files
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from models.VAE_2D_model import VAE_2D
from models.MLP_model import MLP_MIL_model_simple, MLP_MIL_model2
from utils.datasets import Load_Latent_Vectors, LoadImages, prepare_VAE_MLP_joint_data
from utils.utility_code import plot_results, parameter_count, get_single_scan_file_list, weights_init, get_class_distribution, plot_MLP_results, error_analysis
from utils.train_and_test_functions import train, test, train_VAE_model, evaluate_VAE, mixup_patient_data, mixup_batch, process_batch_with_noise, calibration_curve_and_distribution
from utils.loss_functions import loss_function, kl_annealing



warnings.filterwarnings("ignore")

In [1]:
# Device index stored under the 'device' key of every configuration below.
device = 0
# NOTE: despite the variable name, every entry here is an MLP configuration
# (see the 'model_type': 'MLP_MIL_model2' key), not a VAE one.
vae_params = [
    {'num_epochs': 200, 'threshold': 0.5, 'num_synthetic': 25, 'oversample': 1.5, 'batch_size': 128, 'lr': 0.006, 'weight_decay': 0.14, 'accumulation_steps': 4, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 128, 'patch_dropout': 0.35, 'patient_dropout': 0.3, 'alpha': 0.8, 'mixup': False, 'attention_indicator': True, 'max_node_slices': 15, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.509, 'num_synthetic': 25, 'oversample': 2, 'batch_size': 100, 'lr': 0.00621, 'weight_decay': 0.116, 'accumulation_steps': 4, 'patch_hidden_dim': 1536, 'patient_hidden_dim': 128, 'patch_dropout': 0.3, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': True, 'attention_indicator': True, 'max_node_slices': 15, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.448, 'num_synthetic': 25, 'oversample': 2, 'batch_size': 100, 'lr': 0.00743, 'weight_decay': 0.157, 'accumulation_steps': 5, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 96, 'patch_dropout': 0.4, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': False, 'attention_indicator': True, 'max_node_slices': 15, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.425, 'num_synthetic': 25, 'oversample': 1.25, 'batch_size': 150, 'lr': 0.00696, 'weight_decay': 0.0871, 'accumulation_steps': 4, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 128, 'patch_dropout': 0.4, 'patient_dropout': 0.3, 'alpha': 0.8, 'mixup': False, 'attention_indicator': False, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.429, 'num_synthetic': 20, 'oversample': 1.5, 'batch_size': 64, 'lr': 0.00662, 'weight_decay': 0.179, 'accumulation_steps': 5, 'patch_hidden_dim': 1536, 'patient_hidden_dim': 96, 'patch_dropout': 0.4, 'patient_dropout': 0.3, 'alpha': 0.7, 'mixup': False, 'attention_indicator': False, 'max_node_slices': 15, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.417, 'num_synthetic': 20, 'oversample': 1.25, 'batch_size': 64, 'lr': 0.00632, 'weight_decay': 0.179, 'accumulation_steps': 3, 'patch_hidden_dim': 512, 'patient_hidden_dim': 128, 'patch_dropout': 0.3, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': True, 'attention_indicator': False, 'max_node_slices': 20, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.401, 'num_synthetic': 15, 'oversample': 1.25, 'batch_size': 256, 'lr': 0.00749, 'weight_decay': 0.16, 'accumulation_steps': 5, 'patch_hidden_dim': 2560, 'patient_hidden_dim': 96, 'patch_dropout': 0.4, 'patient_dropout': 0.4, 'alpha': 0.8, 'mixup': True, 'attention_indicator': False, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.431, 'num_synthetic': 15, 'oversample': 2, 'batch_size': 150, 'lr': 0.00492, 'weight_decay': 0.199, 'accumulation_steps': 5, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 46, 'patch_dropout': 0.4, 'patient_dropout': 0.3, 'alpha': 0.7, 'mixup': False, 'attention_indicator': True, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.419, 'num_synthetic': 25, 'oversample': 1, 'batch_size': 64, 'lr': 0.00628, 'weight_decay': 0.187, 'accumulation_steps': 5, 'patch_hidden_dim': 1536, 'patient_hidden_dim': 64, 'patch_dropout': 0.4, 'patient_dropout': 0.2, 'alpha': 0.7, 'mixup': True, 'attention_indicator': False, 'max_node_slices': 20, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.434, 'num_synthetic': 20, 'oversample': 1.5, 'batch_size': 64, 'lr': 0.00771, 'weight_decay': 0.194, 'accumulation_steps': 5, 'patch_hidden_dim': 2560, 'patient_hidden_dim': 46, 'patch_dropout': 0.4, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': False, 'attention_indicator': False, 'max_node_slices': 20, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.409, 'num_synthetic': 25, 'oversample': 1.5, 'batch_size': 150, 'lr': 0.00857, 'weight_decay': 0.172, 'accumulation_steps': 3, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 96, 'patch_dropout': 0.4, 'patient_dropout': 0.3, 'alpha': 0.8, 'mixup': False, 'attention_indicator': False, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.407, 'num_synthetic': 30, 'oversample': 1.25, 'batch_size': 128, 'lr': 0.00459, 'weight_decay': 0.176, 'accumulation_steps': 3, 'patch_hidden_dim': 2560, 'patient_hidden_dim': 128, 'patch_dropout': 0.3, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': False, 'attention_indicator': True, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border', 'patient'], 'device': device},

]
# Per-key mean across all configurations, in the key order of the first entry.
# The two non-numeric keys are left out. Booleans sum as 0/1, so the 'mixup' and
# 'attention_indicator' averages are the fraction of configurations with the flag set.
excluded_keys = ('clinical_data_options', 'model_type')
vae_params_avg = {
    name: sum(config[name] for config in vae_params) / len(vae_params)
    for name in vae_params[0]
    if name not in excluded_keys
}
print(vae_params_avg)

{'num_epochs': 200.0, 'threshold': 0.4357499999999999, 'num_synthetic': 22.5, 'oversample': 1.5, 'batch_size': 118.16666666666667, 'lr': 0.006591666666666666, 'weight_decay': 0.16217499999999999, 'accumulation_steps': 4.25, 'patch_hidden_dim': 1920.0, 'patient_hidden_dim': 98.33333333333333, 'patch_dropout': 0.3708333333333333, 'patient_dropout': 0.25833333333333336, 'alpha': 0.775, 'mixup': 0.3333333333333333, 'attention_indicator': 0.4166666666666667, 'max_node_slices': 20.416666666666668, 'device': 0.0}


In [None]:
# Candidate VAE hyperparameter sets; main(best_hyperparams, i, fold_idx) picks one by index `i`.
best_vae_hyperparams = [
    {'base': 20, 'latent_size': 20, 'annealing': 1, 'ssim_indicator': 1, 'alpha': 0.4, 'beta': 1, 'lr': 0.000544, 'batch_size': 1536, 'ssim_scalar': 3, 'recon_scale_factor': 4500, 'weight_decay': 0.0535, 'accumulation_steps': 2},
    {'base': 24, 'latent_size': 20, 'annealing': 1, 'ssim_indicator': 1, 'alpha': 0.5, 'beta': 1, 'lr': 0.000792, 'batch_size': 1024, 'ssim_scalar': 3, 'recon_scale_factor': 4500, 'weight_decay': 0.0302, 'accumulation_steps': 2},
    {'base': 24, 'latent_size': 16, 'annealing': 1, 'ssim_indicator': 1, 'alpha': 0.4, 'beta': 1, 'lr': 0.000746, 'batch_size': 512, 'ssim_scalar': 3, 'recon_scale_factor': 4500, 'weight_decay': 0.0292, 'accumulation_steps': 2},
    {'base': 24, 'latent_size': 16, 'annealing': 1, 'ssim_indicator': 1, 'alpha': 0.5, 'beta': 1, 'lr': 0.000646, 'batch_size': 1280, 'ssim_scalar': 2, 'recon_scale_factor': 3000, 'weight_decay': 0.0364, 'accumulation_steps': 2},
    {'base': 16, 'latent_size': 24, 'annealing': 1, 'ssim_indicator': 1, 'alpha': 0.5, 'beta': 1, 'lr': 0.000427, 'batch_size': 512, 'ssim_scalar': 3, 'recon_scale_factor': 4000, 'weight_decay': 0.0469, 'accumulation_steps': 3},
    {'base': 20, 'latent_size': 20, 'annealing': 1, 'ssim_indicator': 1, 'alpha': 0.6, 'beta': 1, 'lr': 0.000886, 'batch_size': 512, 'ssim_scalar': 3, 'recon_scale_factor': 3000, 'weight_decay': 0.0135, 'accumulation_steps': 3},
]

# NOTE: despite the .npy extension this file is written and read with pickle, not numpy.
file_path = r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\fold_data.npy"
# Set to True to (re)generate the stratified folds and overwrite file_path;
# with False the previously saved folds are loaded instead.
first_time = False
if first_time:
    cohort1 = pd.read_excel(r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1.xlsx")
    IMAGE_DIR = r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1_2D_slices"

    # os.path.join replaces the former '\mri' literals, whose '\m' is an invalid
    # escape sequence (SyntaxWarning on recent Python versions).
    patient_files_list = (os.listdir(os.path.join(IMAGE_DIR, 'mri'))
                          + os.listdir(os.path.join(IMAGE_DIR, 'mri_aug')))
    # The patient id is the first 10 characters of each file name; dict.fromkeys
    # removes duplicates while keeping first-seen order.
    patient_ids = list(dict.fromkeys(f[:10] for f in patient_files_list))

    # Binary label per patient: 0 when NodeLabel is the string '0', otherwise 1.
    # .item() raises unless exactly one cohort row matches the patient id.
    # NOTE(review): the comparison is against the string '0' - confirm the
    # NodeLabel column is read as text rather than as a number.
    pure_labels = []
    for patient_id in patient_ids:
        N = cohort1[cohort1['shortpatpseudoid'] == patient_id]['NodeLabel'].item()
        pure_labels.append(0 if N == '0' else 1)

    # Stratified 5-fold split with a fixed seed so the folds are repeatable
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)
    # Each entry: [train_ids, test_ids, train_labels, test_labels]
    fold_data = []
    for fold_idx, (train_index, test_index) in enumerate(skf.split(patient_ids, pure_labels)):
        train_ids = [patient_ids[i] for i in train_index]
        test_ids = [patient_ids[i] for i in test_index]
        train_labels = [pure_labels[i] for i in train_index]
        test_labels = [pure_labels[i] for i in test_index]

        fold_data.append([train_ids, test_ids, train_labels, test_labels])

    with open(file_path, 'wb') as f:
        pickle.dump(fold_data, f)
    print(f"Fold data saved to {file_path}")

else:
    # SECURITY NOTE: pickle.load can execute arbitrary code from a tampered file;
    # only load a fold file that this script produced.
    with open(file_path, 'rb') as f:
        fold_data = pickle.load(f)
    print(f"Fold data loaded from {file_path}")

In [None]:
# Global run counter; main() increments it before each run, so the next run is numbered 26.
Run = 25


def main(best_hyperparams, i, fold_idx, cooldown_seconds=300):
    """Train one VAE configuration on one cross-validation fold and evaluate it.

    NOTE: a later cell in this file defines another ``main`` with a different
    signature; once that cell has run, the name ``main`` refers to that one.

    Args:
        best_hyperparams: list of VAE hyperparameter dicts; each must contain the
            keys listed in ``hyperparam_keys`` below.
        i: index of the configuration to use within ``best_hyperparams``.
        fold_idx: index of the fold to use within the global ``fold_data``.
        cooldown_seconds: pause after the run, in seconds. The default of 300
            matches the previously hard-coded 5 minute pause; pass 0 to skip it.

    Returns:
        None. The global ``Run`` counter is incremented and progress is printed;
        ``save_results_path`` is handed to ``train_VAE_model`` (presumably the
        location where it writes the checkpoint - confirm in that helper).
    """
    global Run
    IMAGE_DIR = r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1_2D_slices"
    results_path = r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results"

    Run += 1
    print("Run:", Run)

    save_results_path = rf"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_fold_{fold_idx}_run_{Run}.pt"

    # Check if GPU available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # NOTE: the seed is taken from the wall clock, so every run gets a different
    # seed and results are not reproducible from run to run.
    torch.manual_seed(int(time.time()))
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False

    time_start = perf_counter()

    cohort1 = pd.read_excel(r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1.xlsx")
    # Raw strings give exactly the same '\mri//<file>' entries as before, without
    # relying on the invalid '\m' escape sequence.
    all_files_list = ([r'\mri//' + f for f in os.listdir(os.path.join(IMAGE_DIR, 'mri'))]
                      + [r'\mri_aug//' + f for f in os.listdir(os.path.join(IMAGE_DIR, 'mri_aug'))])
    all_files_list.sort()

    # only one scan per patient to avoid repeated scans when saving latent vectors.
    # Built once, before training, and reused for the evaluation step below.
    all_files_list2 = get_single_scan_file_list(all_files_list, IMAGE_DIR, cohort1)
    all_files_list2.sort()

    epochs = 200

    # Copy only the keys the VAE pipeline uses, in a fixed order
    hyperparam_keys = ('base', 'latent_size', 'annealing', 'ssim_indicator', 'alpha', 'beta', 'lr',
                       'batch_size', 'ssim_scalar', 'recon_scale_factor', 'weight_decay',
                       'accumulation_steps')
    hyperparams = {key: best_hyperparams[i][key] for key in hyperparam_keys}
    base = hyperparams['base']
    latent_size = hyperparams['latent_size']
    batch_size = hyperparams['batch_size']
    print("Using Hyperparams:", hyperparams)
    print("latent size:", latent_size*base)

    # Only train_images, test_images and train_test_split_dict are used here; the
    # other names are kept to document the order of the returned values.
    patient_slices_dict, patient_labels_dict, patient_file_names_dict, short_long_axes_dict, mlp_train_ids, test_ids, mlp_train_labels, test_labels, train_images, test_images, train_test_split_dict, mask_sizes_dict = prepare_VAE_MLP_joint_data(first_time_train_test_split=False, cross_val=True, fold_data=fold_data[fold_idx])

    train_dataset = LoadImages(main_dir=IMAGE_DIR + '/', files_list=train_images)
    test_dataset = LoadImages(main_dir=IMAGE_DIR + '/', files_list=test_images)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

    vae_model = VAE_2D(hyperparams)
    # Count the parameters of the model that is actually trained rather than
    # constructing a second, throwaway VAE_2D just for the count.
    print('parameter count:', parameter_count(vae_model))
    vae_model = vae_model.to(device)
    vae_model.apply(weights_init)

    vae_model, test_loss, test_ssim, train_loss, train_ssim, VAE_metrics = train_VAE_model(vae_model, epochs, train_loader, test_loader, hyperparams, device, results_path, save_results_path, sample_shape = (12, latent_size*base, 1, 1), train_test_split_dict = train_test_split_dict, wandb_sweep=False, Run=Run, fold_idx=fold_idx)

    print('Hyperparameters:', hyperparams)

    # Run the full evaluation only for models whose test SSIM exceeds 0.6
    if test_ssim > 0.6:
        VAE_metrics, mus = evaluate_VAE(vae_model, test_loss, results_path, batch_size, device, IMAGE_DIR,
                                        all_files_list2, feature_length=latent_size*base, Run=Run, wandb_sweep=False, fold_idx=fold_idx)

    print("Time to train:", perf_counter() - time_start)
    if cooldown_seconds > 0:
        # With the default of 300 s this prints "Cooling down for 5 mins..."
        print(f"Cooling down for {cooldown_seconds / 60:g} mins...")
        time.sleep(cooldown_seconds)


In [None]:
# To run VAE cross-val
# for fold_idx in range(5):
#     for i in range(6):
#         main(best_vae_hyperparams, i, fold_idx)


In [None]:
# wandb is already imported at the top of the file; it is imported again in this cell.
import wandb
wandb.login()

# Starts a run in the cross-validation project.
# NOTE(review): main() below calls wandb.init(..., reinit=True) again for every
# fold/run - confirm this initial run is actually wanted.
wandb.init(project="Cross_validation_VAE_MLP4")
# Rebinds the module-level `device` to a torch.device (CUDA when available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# MLP hyperparameter sets; main() selects one with index `i`.
# NOTE: the list holds 12 configurations despite the name "best_ten".
best_ten_MLP_params = [
    {'num_epochs': 200, 'threshold': 0.5, 'num_synthetic': 25, 'oversample': 1.5, 'batch_size': 128, 'lr': 0.006, 'weight_decay': 0.14, 'accumulation_steps': 4, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 128, 'patch_dropout': 0.35, 'patient_dropout': 0.3, 'alpha': 0.8, 'mixup': False, 'attention_indicator': True, 'max_node_slices': 15, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.509, 'num_synthetic': 25, 'oversample': 2, 'batch_size': 100, 'lr': 0.00621, 'weight_decay': 0.116, 'accumulation_steps': 4, 'patch_hidden_dim': 1536, 'patient_hidden_dim': 128, 'patch_dropout': 0.3, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': True, 'attention_indicator': True, 'max_node_slices': 15, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device}, 
    {'num_epochs': 200, 'threshold': 0.448, 'num_synthetic': 25, 'oversample': 2, 'batch_size': 100, 'lr': 0.00743, 'weight_decay': 0.157, 'accumulation_steps': 5, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 96, 'patch_dropout': 0.4, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': False, 'attention_indicator': True, 'max_node_slices': 15, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.425, 'num_synthetic': 25, 'oversample': 1.25, 'batch_size': 150, 'lr': 0.00696, 'weight_decay': 0.0871, 'accumulation_steps': 4, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 128, 'patch_dropout': 0.4, 'patient_dropout': 0.3, 'alpha': 0.8, 'mixup': False, 'attention_indicator': False, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.429, 'num_synthetic': 20, 'oversample': 1.5, 'batch_size': 64, 'lr': 0.00662, 'weight_decay': 0.179, 'accumulation_steps': 5, 'patch_hidden_dim': 1536, 'patient_hidden_dim': 96, 'patch_dropout': 0.4, 'patient_dropout': 0.3, 'alpha': 0.7, 'mixup': False, 'attention_indicator': False, 'max_node_slices': 15, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.417, 'num_synthetic': 20, 'oversample': 1.25, 'batch_size': 64, 'lr': 0.00632, 'weight_decay': 0.179, 'accumulation_steps': 3, 'patch_hidden_dim': 512, 'patient_hidden_dim': 128, 'patch_dropout': 0.3, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': True, 'attention_indicator': False, 'max_node_slices': 20, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.401, 'num_synthetic': 15, 'oversample': 1.25, 'batch_size': 256, 'lr': 0.00749, 'weight_decay': 0.16, 'accumulation_steps': 5, 'patch_hidden_dim': 2560, 'patient_hidden_dim': 96, 'patch_dropout': 0.4, 'patient_dropout': 0.4, 'alpha': 0.8, 'mixup': True, 'attention_indicator': False, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.431, 'num_synthetic': 15, 'oversample': 2, 'batch_size': 150, 'lr': 0.00492, 'weight_decay': 0.199, 'accumulation_steps': 5, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 46, 'patch_dropout': 0.4, 'patient_dropout': 0.3, 'alpha': 0.7, 'mixup': False, 'attention_indicator': True, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.419, 'num_synthetic': 25, 'oversample': 1, 'batch_size': 64, 'lr': 0.00628, 'weight_decay': 0.187, 'accumulation_steps': 5, 'patch_hidden_dim': 1536, 'patient_hidden_dim': 64, 'patch_dropout': 0.4, 'patient_dropout': 0.2, 'alpha': 0.7, 'mixup': True, 'attention_indicator': False, 'max_node_slices': 20, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.434, 'num_synthetic': 20, 'oversample': 1.5, 'batch_size': 64, 'lr': 0.00771, 'weight_decay': 0.194, 'accumulation_steps': 5, 'patch_hidden_dim': 2560, 'patient_hidden_dim': 46, 'patch_dropout': 0.4, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': False, 'attention_indicator': False, 'max_node_slices': 20, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'border'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.409, 'num_synthetic': 25, 'oversample': 1.5, 'batch_size': 150, 'lr': 0.00857, 'weight_decay': 0.172, 'accumulation_steps': 3, 'patch_hidden_dim': 2048, 'patient_hidden_dim': 96, 'patch_dropout': 0.4, 'patient_dropout': 0.3, 'alpha': 0.8, 'mixup': False, 'attention_indicator': False, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'patient'], 'device': device},
    {'num_epochs': 200, 'threshold': 0.407, 'num_synthetic': 30, 'oversample': 1.25, 'batch_size': 128, 'lr': 0.00459, 'weight_decay': 0.176, 'accumulation_steps': 3, 'patch_hidden_dim': 2560, 'patient_hidden_dim': 128, 'patch_dropout': 0.3, 'patient_dropout': 0.2, 'alpha': 0.8, 'mixup': False, 'attention_indicator': True, 'max_node_slices': 25, 'model_type': 'MLP_MIL_model2', 'clinical_data_options': ['T_stage', 'size', 'patient'], 'device': device},

]

# One latent-vector file per fold, indexed by fold_idx (main() reads it with np.load).
latent_vector_paths = [
                       r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\latent_vectors_fold_0_run_5.npy",
                       r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\latent_vectors_fold_1_run_11.npy",
                       r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\latent_vectors_fold_2_run_14.npy",
                       r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\latent_vectors_fold_3_run_23.npy",
                       r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\latent_vectors_fold_4_run_26.npy"
                        ]   

# One VAE checkpoint per fold, indexed by fold_idx. main() loads it with torch.load
# and reads its 'train_test_split' and 'hyperparams' entries.
VAE_params_paths = [
                    r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\VAE_fold_0_run_5_ssim_0.727695107460022.pt",
                    r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\VAE_fold_1_run_11_ssim_0.7619633674621582.pt",
                    r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\VAE_fold_2_run_14_ssim_0.7635628581047058.pt",
                    r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\VAE_fold_3_run_23_ssim_0.77414470911026.pt",
                    r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\VAE_fold_4_run_26_ssim_0.7822268009185791.pt"
                    ]

                   
# Global run counter for the MLP stage; main() increments it at the start of each run.
Run = 0
# Module-level lists that main() declares as globals.
best_test_preds, best_test_probs= [], []
def main(best_hyperparams, i, fold_idx, latent_vector_paths, VAE_params_paths):
    global Run, best_test_preds, best_test_probs
    if fold_idx < 3:
        best_score = {'TP': 7, 'FP': 7, 'Train_Sensitivity': 0.6}
    if fold_idx >= 3:
        best_score = {'TP': 6, 'FP': 7, 'Train_Sensitivity': 0.6}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    Run += 1
    print('Run:', Run)
    num_logged = 0
    
    run = wandb.init(project="Cross_validation_VAE_MLP4", name=f"fold_{fold_idx}_run_{Run}", reinit=True)
    wandb.config.update(best_hyperparams[i])
    
    # Hyperparams    
    num_epochs = 200
    
    num_synthetic = best_hyperparams[i]['num_synthetic']
    oversample = best_hyperparams[i]['oversample']
    max_node_slices = best_hyperparams[i]['max_node_slices']
    threshold = best_hyperparams[i]['threshold']
    batch_size = best_hyperparams[i]['batch_size']
    lr = best_hyperparams[i]['lr']
    weight_decay = best_hyperparams[i]['weight_decay']
    accumulation_steps = best_hyperparams[i]['accumulation_steps']
    patch_hidden_dim = best_hyperparams[i]['patch_hidden_dim']
    patient_hidden_dim = best_hyperparams[i]['patient_hidden_dim']
    patch_dropout = best_hyperparams[i]['patch_dropout']
    patient_dropout = best_hyperparams[i]['patient_dropout']
    alpha = best_hyperparams[i]['alpha']
    mixup = best_hyperparams[i]['mixup']
    attention_indicator = best_hyperparams[i]['attention_indicator']
    model_type = best_hyperparams[i]['model_type']
    clinical_data_options = best_hyperparams[i]['clinical_data_options']
    
    clinical_length = 0
    if "size" in clinical_data_options:
        clinical_length += 3
    if "border" in clinical_data_options:
        clinical_length += 2

    hyperparams = {'num_epochs': num_epochs, 'threshold': threshold, 'num_synthetic': num_synthetic, 'oversample': oversample,
                   'batch_size': batch_size, 'lr': lr, 'weight_decay': weight_decay, 'accumulation_steps': accumulation_steps,
                   'patch_hidden_dim': patch_hidden_dim, 'patient_hidden_dim': patient_hidden_dim,
                   'patch_dropout': patch_dropout, 'patient_dropout': patient_dropout, 'alpha': alpha, 'mixup': mixup,
                   'attention_indicator': attention_indicator, 'max_node_slices': max_node_slices, 'model_type': model_type,
                   'clinical_data_options': clinical_data_options, 'device': device}

    print(hyperparams)
    print('Device:', device)

    time_start = perf_counter()

    results_path = r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\MLP_Results"
    save_results_path = rf"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE_MLP_cross_validation\MLP_Results\MLP_Fold_{fold_idx}_Run_{Run}.pt"

    # Load the dataset
    IMAGE_DIR = r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1_2D_slices"
    cohort1 = pd.read_excel(r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1.xlsx")
    latent_vectors = np.load(latent_vector_paths[fold_idx])

    all_files_list = ['\mri' + '//' + f for f in os.listdir(IMAGE_DIR + '\mri')] + ['\mri_aug' + '//' + f  for f in os.listdir(IMAGE_DIR + '\mri_aug')]
    all_files_list.sort()
    all_files_list = get_single_scan_file_list(all_files_list, IMAGE_DIR, cohort1)

    VAE_params_path = VAE_params_paths[fold_idx]
    checkpoint = torch.load(VAE_params_path)
    train_test_split_dict = checkpoint['train_test_split']
    train_ids = train_test_split_dict['train']
    test_ids = train_test_split_dict['test']
    patient_slices_dict, patient_labels_dict, patient_file_names_dict, short_long_axes_dict, mlp_train_ids, test_ids, mlp_train_labels, test_labels, train_images, test_images, train_test_split_dict, mask_sizes = prepare_VAE_MLP_joint_data(first_time_train_test_split=False, train_ids=train_ids, test_ids=test_ids, num_synthetic=num_synthetic, oversample_ratio=oversample)
    
    #train_dataset = torch.load(r'C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\MLP_Results\{}.pth'.format(dataset_version))
    train_dataset = Load_Latent_Vectors(patient_slices_dict, latent_vectors, patient_labels_dict, mlp_train_ids, cohort1, all_files_list, short_long_axes_dict, mask_sizes, clinical_data_options, max_nodes=max_node_slices)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    #test_dataset = torch.load(r'C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\MLP_Results\test_dataset.pth')
    test_dataset = Load_Latent_Vectors(patient_slices_dict, latent_vectors, patient_labels_dict, test_ids, cohort1, all_files_list, short_long_axes_dict, mask_sizes, clinical_data_options, max_nodes=max_node_slices)
    test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=False)
    
    vae_hyperparams = checkpoint['hyperparams']
    base = vae_hyperparams['base']
    latent_size = vae_hyperparams['latent_size']
    vae_features_length = base*latent_size
    # Instantiate the model
    if model_type == 'MLP_MIL_model_simple':
        model = MLP_MIL_model_simple(patch_input_dim=vae_features_length+clinical_length, hyperparams=hyperparams)
    if model_type == 'MLP_MIL_model2':
        model = MLP_MIL_model2(patch_input_dim=vae_features_length+clinical_length, hyperparams=hyperparams)

    model.apply(weights_init)
    model.to(device)


    criterion = nn.BCELoss()
    optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, mode='min', factor=0.5, patience=40,
                                                              verbose=True, threshold=0.001, threshold_mode='abs')

    train_losses, test_losses = [], []
    train_AUCs, test_AUCs = [], []
    train_sensitivitys, test_sensitivitys = [], []
    batches_mixed = 0
    early_stopping = 0
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        test_loss = 0
        all_train_labels = []
        all_train_preds = []
        all_train_probs = []
        steps = 0
        optimiser.zero_grad()


        for features, label, clinical_data, number_of_nodes in train_dataloader:
            if mixup:
                random_int = np.random.randint(1, 6) # random int between 1 and 3
            else:
                random_int = 0
            if random_int == 1: # 1/6 chance
                #print('Mixup')
                batches_mixed += 1

                features_without_clinical = features[:, :, :-clinical_length]
                features_mixed = mixup_batch(features_without_clinical) # mixup 50% of batch
                if clinical_length > 0:
                    features = torch.cat((features_mixed, features[:, :, -clinical_length:]), dim=2) # add clinical data back
            if random_int == 2: # 1/6 chance
                #print('Noise')
                batches_mixed += 1
                features_without_clinical = features[:, :, :-clinical_length]
                features_noise = process_batch_with_noise(features_without_clinical) # add noise to 50% of batch
                if clinical_length > 0:
                    features = torch.cat((features_noise, features[:, :, -clinical_length:]), dim=2) # add clinical data back

            features, label = features.to(device), label.to(device)
            clinical_data, number_of_nodes = clinical_data.to(device), number_of_nodes.to(device)
            steps += 1
            # Forward pass
            #output = model(features.squeeze(0))  # Remove batch dimension
            output, max_vals, attentions, classifications = model(features, clinical_data, number_of_nodes, label)
            output = output.squeeze(1)

            # binary threshold classifications
            # classifications = torch.where(classifications > 0.5, torch.tensor([1.]).to(device), torch.tensor([0.]).to(device))
            # print('Classifications:', classifications)


            #print(torch.mean(label.float()))
            # if label == 1:
            #     weight = torch.tensor([2.0]).to(device)
            # if label == 0:
            #     weight = torch.tensor([1.0]).to(device)

            loss = criterion(output, label.float()) #*weight
            train_loss += loss.item()

            # # Backward pass and optimization
            # optimiser.zero_grad()
            # loss.backward()
            # optimiser.step()
            # Backward pass and optimization
            loss.backward()
            if steps % accumulation_steps == 1:
                optimiser.step()
                optimiser.zero_grad()

            # Apply threshold to determine predicted class
            #predicted_probs = F.softmax(output, dim=1)[:, 1]  # Probability of class 1 (positive)
            #predicted_probs = torch.sigmoid(output)

            predicted_probs = output
            classifications_class = (classifications >= threshold).long()
            #predicted_probs = 0.6*max_vals + 0.35*classifications_class + 0.05*attentions
            predicted_class = (predicted_probs >= threshold).long()


            # Store predictions and labels
            all_train_labels.extend(label.cpu().numpy())
            all_train_preds.extend(predicted_class.cpu().numpy())
            all_train_probs.extend(predicted_probs.tolist())
            # random_int = np.random.randint(1, 20)
            # if random_int == 1:
            #     rdn_idx  = np.random.randint(0, len(features))
            #     print(rdn_idx, 'label (train)', label[rdn_idx].item(), 'output', output[rdn_idx].item(), 'predicted class', predicted_class[rdn_idx].item(), 'max', max_vals[rdn_idx].item(), 'attention', attentions[rdn_idx].item(), 'classification', classifications[rdn_idx].item(), 'class binary', classifications_class[rdn_idx].item(), 'number of nodes', number_of_nodes[rdn_idx].item()) # 'reweighted prediction', predicted_probs[rdn_idx].item())

        optimiser.step()
        optimiser.zero_grad()
        lr_scheduler.step(train_loss/len(train_dataloader))
        if epoch % 5 == 0 or epoch + 20 > num_epochs-1:
            print('Learning rate:', optimiser.param_groups[0]['lr'])
        train_losses.append(train_loss/len(train_dataloader))
        train_accuracy = accuracy_score(all_train_labels, all_train_preds)
        train_auc = roc_auc_score(all_train_labels, all_train_preds)
        train_AUCs.append(train_auc)
        train_bal_accuracy = balanced_accuracy_score(all_train_labels, all_train_preds)
        train_confusion_matrix = confusion_matrix(all_train_labels, all_train_preds)
        tn, fp, fn, tp = confusion_matrix(all_train_labels, all_train_preds).ravel()
        # Compute sensitivity (recall) and specificity
        train_sensitivity = tp / (tp + fn)
        train_specificity = tn / (tn + fp)
        train_sensitivitys.append(train_sensitivity)


        print(f'Epoch [{epoch+1}/{num_epochs}], Train: Loss: {train_loss/len(train_dataloader):.4f}, Accuracy: {train_accuracy:.4f}, Balanced Accuracy: {train_bal_accuracy:.4f}, AUC: {train_auc:.4f}, Sensitivity: {train_sensitivity:.4f}, Specificity: {train_specificity:.4f}')
        print(f'Train Confusion Matrix:')
        print(train_confusion_matrix)

        # Evaluation phase
        model.eval()
        test_loss = 0
        all_test_labels = []
        all_test_preds = []
        all_test_probs = []
        with torch.no_grad():
            for features, label, clinical_data, number_of_nodes in test_dataloader:
                features, label = features.to(device), label.to(device)
                clinical_data, number_of_nodes = clinical_data.to(device), number_of_nodes.to(device)
                #output = model(features.squeeze(0))  # Remove batch dimension
                output, max_vals, attentions, classifications = model(features, clinical_data, number_of_nodes, label)
                output = output.squeeze(1)

                #output = output.squeeze(0)
                loss = criterion(output, label.float())
                test_loss += loss.item()


                # Store predictions and labels
                #predicted_probs = F.softmax(output, dim=1)[:, 1]  # Probability of class 1 (positive)
                #predicted_probs = torch.sigmoid(output)
                predicted_probs = output
                classifications_class = (classifications >= threshold).long()
                #predicted_probs = 0.6*max_vals + 0.35*classifications_class + 0.05*attentions
                predicted_class = (predicted_probs >= threshold).type(torch.long)
                all_test_labels.extend(label.cpu().numpy())
                all_test_preds.extend(predicted_class.cpu().numpy())
                all_test_probs.extend(predicted_probs.cpu().numpy())


        test_losses.append(test_loss/len(test_dataloader))
        test_accuracy = accuracy_score(all_test_labels, all_test_preds)
        test_auc = roc_auc_score(all_test_labels, all_test_preds)
        test_AUCs.append(test_auc)
        test_bal_accuracy = balanced_accuracy_score(all_test_labels, all_test_preds)
        test_confusion_matrix = confusion_matrix(all_test_labels, all_test_preds)
        tn, fp, fn, tp = confusion_matrix(all_test_labels, all_test_preds).ravel()
        # Compute sensitivity (recall) and specificity
        test_sensitivity = tp / (tp + fn)
        test_specificity = tn / (tn + fp)
        test_metric = (2*test_sensitivity + test_specificity)/3
        test_sensitivitys.append(test_sensitivity)
        #if epoch % 5 == 0 or epoch + 20 > num_epochs-1:
        print(f'Test: Loss: {test_loss/len(test_dataloader):.4f}, Accuracy: {test_accuracy:.4f}, Balanced Accuracy: {test_bal_accuracy:.4f}, AUC: {test_auc:.4f}, Sensitivity: {test_sensitivity:.4f}, Specificity: {test_specificity:.4f}, Metric:, {test_metric:.4f}')
        print('Test Confusion Matrix:')
        print(test_confusion_matrix)
        # Wait for GPU to cool down for 10 seconds
        time.sleep(10)

        if epoch == 0:
            test_labels = np.array(all_test_labels)

        if tp >= 6 and fp <= 8:
            # error analysis
            best_test_preds.append(all_test_preds)
            best_test_probs.append(all_test_probs)
            print('number of preds logged:', len(best_test_probs))
            error_analysis(np.array(best_test_probs), test_labels, results_path, threshold, fold_idx)
            num_logged += 1

            if tp > best_score['TP'] or (tp >= best_score['TP'] and fp < best_score['FP']) or (tp >= best_score['TP'] and fp <= best_score['FP'] and train_sensitivity > best_score['Train_Sensitivity']):
                best_score['TP'] = tp
                best_score['FP'] = fp
                best_score['Train_Sensitivity'] = train_sensitivity
                print('Saving model with TP:', tp, 'and FP:', fp, 'at epoch:', epoch)
                training_results = {'train_losses': train_losses, 'test_losses': test_losses, 'train_AUCs': train_AUCs, 
                                    'test_AUCs': test_AUCs, 'train_sensitivitys': train_sensitivitys, 'test_sensitivitys': test_sensitivitys, 'all test labels': all_test_labels, 'all test probs': all_test_probs}
                torch.save({"state_dict": model.state_dict(), "training_results": training_results,
                            "hyperparams": hyperparams, "train_test_split": train_test_split_dict}, save_results_path)
                calibration_curve_and_distribution(all_train_labels, all_train_probs, 'Train', results_path, f'saved_result_fold_{fold_idx}_run_{Run}', save=True)
                calibration_curve_and_distribution(all_test_labels, all_test_probs, 'Test', results_path, f'saved_result_fold_{fold_idx}_run_{Run}', save=True)

                # plot results at this stage (updating until the best run)
                plot_MLP_results(training_results, hyperparams, results_path=results_path, filename=f'MLP_training_metrics_{fold_idx}_run_{Run}.png')



        if epoch == num_epochs-1:
            calibration_curve_and_distribution(all_train_labels, all_train_probs, 'Train', results_path, Run)
            calibration_curve_and_distribution(all_test_labels, all_test_probs, 'Test', results_path, Run)

        print('Max Test AUC:', np.max(test_AUCs))
        # log to wandb
        wandb.log(
            {
                "Test Loss": test_loss/len(test_dataloader),
                "Test Metric": test_metric,
                "Test Accuracy": test_accuracy,
                "Test AUC": test_auc,
                "Test Sensitivity": test_sensitivity,
                "Test Specificity": test_specificity,
                "Test TP": tp,
                "Test FP": fp,
                "Train Loss": train_loss/len(train_dataloader),
                "Train Accuracy": train_accuracy,
                "Train AUC": train_auc,
                "Train Sensitivity": train_sensitivity,
                "Train Specificity": train_specificity,
                "Max Test AUC": np.max(test_AUCs),
            }
        )

        # # Early stopping
        # if epoch > 75 and test_auc < 0.7:
        #     early_stopping+=1
        #     if early_stopping > 25 and test_auc < 0.6 and num_logged == 0:
        #         print('Early stopping')
        #         break
        #     if early_stopping > 50 and test_auc < 0.65 and num_logged == 0:
        #         print('Early stopping')
        #         break
        # 
        #     if early_stopping > 75 and test_auc < 0.7 and num_logged <= 1:
        #         print('Early stopping')
        #         break


    # save test preds and probs
    np.save(results_path + '//best_test_preds.npy', np.array(best_test_preds))
    np.save(results_path + '//best_test_probs.npy', np.array(best_test_probs))

    print('Batches mixed:', batches_mixed, 'out of', len(train_dataloader)*num_epochs, 'percentage:', batches_mixed/(len(train_dataloader)*num_epochs))
    print(hyperparams)
    print('Time taken:', perf_counter() - time_start)
    # Wait for GPU to cool down after each model run
    print(f"Cooling down for 5 mins...")
    time.sleep(60*5)

In [None]:
# Driver cell: launch repeated training runs of `main` for each selected fold
# and report the total elapsed wall-clock time once all runs have finished.
time_start = perf_counter()  # reference point for the overall timing below
folds = [1]  # fold indices to train on

# Enumerate every (run, fold) combination up front: 12 runs, and within each
# run every fold in `folds` -- the same visiting order as a run-outer /
# fold-inner nested loop.
run_fold_pairs = [(run, fold) for run in range(12) for fold in folds]

for i, fold_idx in run_fold_pairs:
    print('starting fold:', fold_idx, 'run:', i)
    # NOTE(review): `i` is presumably the run index used by `main` to pick a
    # hyperparameter set from `best_ten_MLP_params` -- confirm against `main`.
    main(best_ten_MLP_params, i, fold_idx, latent_vector_paths, VAE_params_paths)

print('Total time taken:', perf_counter() - time_start)