### Hyperparameter sweep for VAE model


In [None]:
# torch for data loading and for model building
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, random_split
from torchsummary import summary
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from sklearn.model_selection import KFold


# Data preprocessing
import pandas as pd
import json
import os
import sys
import re
from sklearn.model_selection import train_test_split
#import nibabel as nib


# Visualisation
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
from PIL import Image
from pylab import rcParams

# Maths
import math
import numpy as np
from sklearn import metrics

# Other
from time import perf_counter
from collections import Counter,OrderedDict
import random
import warnings
import time
import wandb

# import functions from other files
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from models.VAE_2D_model import VAE_2D
from utils.datasets import LoadImages, prepare_VAE_MLP_joint_data
from utils.utility_code import plot_results, parameter_count, get_single_scan_file_list, weights_init
from utils.train_and_test_functions import train, test, train_VAE_model, evaluate_VAE
from utils.loss_functions import loss_function, kl_annealing

warnings.filterwarnings("ignore")

# Sweep definition: Bayesian search whose objective is to maximise the
# metric named "Test SSIM".
sweep_metric = {"goal": "maximize", "name": "Test SSIM"}

# Search space. Entries with "values" are discrete choices; entries with
# "min"/"max" are continuous ranges.
sweep_parameters = {
    # number of feature maps in convolutional layers
    "base": {"values": [16, 20, 24, 28]},
    # size of latent space
    "latent_size": {"values": [16, 20, 24, 28]},
    # annealing on KL divergence indicator (0 for NO or 1 for YES)
    "annealing": {"values": [1]},
    # SSIM indicator for loss function: 0 to not include, 1 to include
    "ssim_indicator": {"values": [1]},
    # If using SSIM/another metric in the loss function, this is the balance
    # between reconstruction loss (L1/MAE) and the other metric in
    # alpha*L1_loss + (1-alpha)*ssim
    "alpha": {"values": [0.4, 0.5, 0.6]},
    # multiplier for KL divergence, helps to disentangle latent space.
    # Idea from beta-VAE: https://openreview.net/pdf?id=Sy2fzU9gl
    "beta": {"values": [1]},
    # learning rate (smaller number for slower training)
    "lr": {"max": 1e-3, "min": 1e-5},
    # batch size (bigger for more stable training)
    "batch_size": {"values": [512, 1024, 1280, 1536]},
    # upweight SSIM by this x batch_size
    "ssim_scalar": {"values": [0, 0.5, 1, 2, 3]},
    # scale factor for reconstruction loss
    "recon_scale_factor": {"values": [2500, 3000, 4000, 4500]},
    # weight decay
    "weight_decay": {"max": 0.06, "min": 1e-4},
    # gradient accumulation steps
    "accumulation_steps": {"values": [2, 3, 4]},
}

sweep_configuration = {
    "method": "bayes",
    "name": "sweep2",
    "metric": sweep_metric,
    "parameters": sweep_parameters,
}

# Register the sweep with W&B under the named project; the returned id is
# what the agent needs in order to pull configurations for this sweep.
sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project="VAE_bayesian_sweep2",
)

wandb.login()

In [None]:
# Counts the sweep runs started in this process; used to name each run's outputs.
Run = 0


def main():
    """Train and evaluate one VAE configuration chosen by the W&B sweep agent.

    Intended to be passed as the ``function`` argument of ``wandb.agent``.
    Each call:

    1. starts a W&B run and increments the module-level ``Run`` counter,
    2. reads the hyperparameters for this run from ``wandb.config``,
    3. builds train/test ``DataLoader``s from the split returned by
       ``prepare_VAE_MLP_joint_data``,
    4. trains the model with ``train_VAE_model`` and evaluates it with
       ``evaluate_VAE`` on one scan per patient,
    5. sleeps for five minutes before returning.

    Takes no arguments and returns ``None``; all configuration comes from
    ``wandb.config`` and the hard-coded data/result paths below.
    """
    global Run
    IMAGE_DIR = r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1_2D_slices"
    results_path = r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results"

    wandb.init()

    Run += 1
    print("Run:", Run)

    # Per-run checkpoint path, derived from results_path so the two cannot drift apart.
    save_results_path = results_path + rf"\VAE_{Run}.pt"

    # Check if GPU available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # NOTE: the seed is taken from the wall clock and cuDNN determinism is
    # disabled, so every sweep run starts from a different random state and
    # runs are NOT reproducible.
    torch.manual_seed(int(time.time()))
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False

    time_start = perf_counter()

    # Relative paths of every original and augmented slice. Raw strings keep
    # the literal backslash without relying on the invalid "\m" escape; the
    # resulting strings are unchanged ("\mri//<file>", "\mri_aug//<file>").
    mri_files = os.listdir(IMAGE_DIR + r'\mri')
    mri_aug_files = os.listdir(IMAGE_DIR + r'\mri_aug')
    all_files_list = sorted(
        [r'\mri' + '//' + f for f in mri_files]
        + [r'\mri_aug' + '//' + f for f in mri_aug_files]
    )

    epochs = 200

    # Pull this run's hyperparameters out of the sweep-provided config.
    hyperparam_names = (
        'base',
        'latent_size',
        'annealing',
        'ssim_indicator',
        'alpha',
        'beta',
        'lr',
        'batch_size',
        'ssim_scalar',
        'recon_scale_factor',
        'weight_decay',
        'accumulation_steps',
    )
    hyperparams = {name: wandb.config[name] for name in hyperparam_names}
    base = hyperparams['base']
    latent_size = hyperparams['latent_size']
    batch_size = hyperparams['batch_size']
    # Length of the latent feature vector handed to training/evaluation.
    feature_length = latent_size * base

    print("Using Hyperparams:", hyperparams)
    print("latent size:", feature_length)

    # Only train_images, test_images and train_test_split_dict are used here;
    # the remaining names are kept to document the order of the returned tuple.
    (
        patient_slices_dict,
        patient_labels_dict,
        patient_file_names_dict,
        short_long_axes_dict,
        mlp_train_ids,
        test_ids,
        mlp_train_labels,
        test_labels,
        train_images,
        test_images,
        train_test_split_dict,
        mask_sizes_dict,
    ) = prepare_VAE_MLP_joint_data(first_time_train_test_split=True, test_proportion=0.35)

    train_dataset = LoadImages(main_dir=IMAGE_DIR + '/', files_list=train_images)
    test_dataset = LoadImages(main_dir=IMAGE_DIR + '/', files_list=test_images)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

    # Count parameters on the model that will actually be trained (while it is
    # still on the CPU) instead of building a second throwaway VAE_2D.
    vae_model = VAE_2D(hyperparams)
    print('parameter count:', parameter_count(vae_model))
    vae_model = vae_model.to(device)
    vae_model.apply(weights_init)

    vae_model, test_loss, test_ssim, train_loss, train_ssim, VAE_metrics = train_VAE_model(
        vae_model,
        epochs,
        train_loader,
        test_loader,
        hyperparams,
        device,
        results_path,
        save_results_path,
        sample_shape=(12, feature_length, 1, 1),
        train_test_split_dict=train_test_split_dict,
        wandb_sweep=True,
        Run=Run,
    )

    # if test_ssim > 0.72:
    #     plot_results(results_path, save_results_path, 'loss_graph_{}.jpg'.format(Run))

    print('Hyperparameters:', hyperparams)
    cohort1 = pd.read_excel(r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1.xlsx")

    # only one scan per patient to avoid repeated scans
    all_files_list2 = get_single_scan_file_list(all_files_list, IMAGE_DIR, cohort1)
    all_files_list2.sort()
    evaluate_VAE(
        vae_model,
        test_loss,
        results_path,
        batch_size,
        device,
        IMAGE_DIR,
        all_files_list2,
        feature_length=feature_length,
        Run=Run,
        wandb_sweep=True,
    )

    print("Time to train:", perf_counter() - time_start)
    # Pause between consecutive sweep runs.
    print("Cooling down for 5 mins...")
    time.sleep(300)


In [None]:
# Launch the sweep agent: it calls main() once per configuration received for
# this sweep and stops after 75 runs.
wandb.agent(
    sweep_id=sweep_id,
    function=main,
    count=75,
)