### Hyperparameter sweep for CNN-MLP model (replacing VAE with CNN backbone)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
import os
import sys
import numpy as np
import pandas as pd
from time import perf_counter
from monai.networks.nets import DenseNet121
import time
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score, balanced_accuracy_score, confusion_matrix
from sklearn.calibration import calibration_curve
import torchvision.models as models
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from utils.datasets import Load_CNN_Images, prepare_VAE_MLP_joint_data
from models.MLP_model import MLP_MIL_model_simple, MLP_MIL_model2
from utils.utility_code import get_single_scan_file_list, get_class_distribution, weights_init, plot_MLP_results, error_analysis
from utils.train_and_test_functions import mixup_patient_data, mixup_batch, process_batch_with_noise, calibration_curve_and_distribution

import wandb
from monai.networks.nets import DenseNet #, HighResNet, EfficientNet, ResNet

In [None]:
# Best configuration found so far (kept for reference):
# {'num_epochs': 200, 'threshold': 0.4049798489191535, 'num_synthetic': 30, 'oversample': 1.5, 'batch_size': 64, 'lr': 0.004089429701418951, 'weight_decay': 0.08306021271710541, 'accumulation_steps': 3, 'patch_hidden_dim': 2048, 'max_node_slices': 15, 'model_type': 'MLP_MIL_model2'}


def _choice(*options):
    """Build a discrete sweep parameter: one entry of `options` is used per run."""
    return {"values": list(options)}


def _interval(low, high):
    """Build a continuous sweep parameter bounded by `low` and `high`."""
    return {"max": high, "min": low}


# Search space. Every key here is read back from `wandb.config` inside `main`.
_search_space = {
    # "dataset_version" used to be swept as well; it is currently disabled.
    "num_synthetic": _choice(10, 20, 25, 30),
    "oversample": _choice(1, 1.25, 1.5),
    "max_node_slices": _choice(15, 20, 25, 30),
    "threshold": _interval(0.4, 0.55),
    "batch_size": _choice(64, 100, 128, 150),
    "lr": _interval(0.0001, 0.01),
    "weight_decay": _interval(0.04, 0.2),
    "accumulation_steps": _choice(2, 3, 4, 5),
    "patch_hidden_dim": _choice(128, 256, 512, 1024, 1536, 2048, 2560),
    "patient_hidden_dim": _choice(16, 24, 32, 36, 46, 64, 96, 128),
    "patch_dropout": _choice(0.2, 0.3, 0.4),
    "patient_dropout": _choice(0.2, 0.3, 0.4),
    # Mixing weight: "alpha*max_vals + (0.9-alpha)*classifications + 0.1*attentions"
    "alpha": _choice(0.2, 0.6, 0.8),
    "attention_indicator": _choice(True, False),
    # "MLP_MIL_model_simple" is the other supported value; only model2 is swept here.
    "model_type": _choice("MLP_MIL_model2"),
    "clinical_data_options": _choice(
        ["T_stage", "size", "border", "patient"],
        ["T_stage", "size", "border"],
        ["T_stage", "size", "patient"],
        ["T_stage", "border", "patient"],
    ),
}

# Sweep definition: Bayesian search maximising the "Max Test AUC" value logged by `main`.
sweep_configuration = {
    "method": "bayes",
    "name": "sweep3",
    "metric": {"goal": "maximize", "name": "Max Test AUC"},
    "parameters": _search_space,
}

# Register the sweep configuration with W&B under the named project.
# The returned id is what the agent cell at the bottom of the notebook consumes.
sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project="CNN-MLP-bayesian-sweep2",
)

In [3]:
class InstanceCNN(nn.Module):
    """Frozen 2D DenseNet121 trunk used as a per-image feature extractor.

    The network is built for single-channel input with `pretrained=True`, its last
    child module is dropped, and a ReLU plus global average pooling (output size 1)
    is appended. Every parameter of the resulting extractor has `requires_grad`
    switched off, so an optimiser will not update it.

    NOTE(review): the pretrained weights presumably target 3-channel input; with
    `in_channels=1` confirm that the first convolution actually receives pretrained
    weights rather than staying at its random initialisation.
    NOTE(review): freezing parameters does not stop normalisation layers from
    updating running statistics while the module is in train mode - confirm this
    is intended.
    """

    def __init__(self):
        super().__init__()

        # An earlier variant used a torchvision ResNet-18 whose first convolution was
        # replaced by a 1-channel one; DenseNet121 is the trunk currently in use.
        densenet = DenseNet121(spatial_dims=2, in_channels=1, out_channels=1, pretrained=True)

        # Everything except the final child module, followed by ReLU + global pooling.
        trunk_layers = list(densenet.children())[:-1]
        self.features = nn.Sequential(
            *trunk_layers,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=1),
        )

        # Freeze the extractor: none of its parameters take part in gradient updates.
        self.features.requires_grad_(False)

    def forward(self, x):
        """Run `x` through the frozen extractor and return the pooled feature maps."""
        return self.features(x)

In [None]:
# Authenticate this session with Weights & Biases before any run is started.
# NOTE(review): the sweep is registered in an earlier cell (wandb.sweep), i.e. before
# this login call - confirm the intended cell order for a fresh kernel.
wandb.login()

In [5]:
Run = 0  # Number of sweep trials started in this kernel; used to name saved artefacts.
# Test predictions / probabilities of every epoch that met the TP/FP criterion.
# These lists are module-level, so they keep growing across all trials of the sweep.
best_test_preds, best_test_probs = [], []


def _epoch_metrics(labels, preds, probs):
    """Summarise binary classification performance for one pass over a dataset.

    Args:
        labels: ground-truth binary labels, one per sample.
        preds: hard class predictions (already thresholded by the caller).
        probs: predicted positive-class scores, in the same order as `labels`.

    Returns:
        dict with keys 'accuracy', 'balanced_accuracy', 'auc', 'confusion_matrix',
        'tn', 'fp', 'fn', 'tp', 'sensitivity' and 'specificity'.
    """
    matrix = confusion_matrix(labels, preds)
    tn, fp, fn, tp = matrix.ravel()
    return {
        'accuracy': accuracy_score(labels, preds),
        'balanced_accuracy': balanced_accuracy_score(labels, preds),
        # ROC AUC is a ranking metric and must be computed from the continuous scores.
        # Feeding it the thresholded predictions collapses it to balanced accuracy at
        # that single threshold.
        'auc': roc_auc_score(labels, probs),
        'confusion_matrix': matrix,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'tp': tp,
        'sensitivity': tp / (tp + fn),
        'specificity': tn / (tn + fp),
    }


def main():
    """Run a single W&B sweep trial: train and evaluate one CNN-MLP model.

    Reads the hyperparameters of the current trial from `wandb.config`, builds the
    train/test datasets from the stored train/test split, trains the selected MLP MIL
    model on top of a frozen `InstanceCNN` backbone for up to 200 epochs with gradient
    accumulation, evaluates on the test set after every epoch and logs the metrics to
    W&B (including the running "Max Test AUC" that the sweep maximises).

    Side effects:
        * increments the global `Run` counter and appends to the global
          `best_test_preds` / `best_test_probs` lists;
        * writes checkpoints, plots and `.npy` files below `results_path`;
        * sleeps 10 s after every epoch and 5 min at the end (GPU cool-down).

    Raises:
        ValueError: if the sweep supplies a `model_type` this function does not know.
    """
    global Run, best_test_preds, best_test_probs
    # Minimum bar a checkpoint has to beat before it is saved.
    best_score = {'TP': 10, 'FP':10, 'Train_Sensitivity': 0.6}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    run = wandb.init()
    Run += 1
    print('Run:', Run)
    num_logged = 0
    # Hyperparams
    num_epochs = 200
    num_synthetic = wandb.config.num_synthetic
    oversample = wandb.config.oversample
    max_node_slices = wandb.config.max_node_slices
    threshold = wandb.config.threshold
    batch_size = wandb.config.batch_size
    lr = wandb.config.lr
    weight_decay = wandb.config.weight_decay
    accumulation_steps = wandb.config.accumulation_steps
    patch_hidden_dim = wandb.config.patch_hidden_dim
    patient_hidden_dim = wandb.config.patient_hidden_dim
    patch_dropout = wandb.config.patch_dropout
    patient_dropout = wandb.config.patient_dropout
    model_type = wandb.config.model_type
    attention_indicator = wandb.config.attention_indicator
    alpha = wandb.config.alpha
    clinical_data_options = wandb.config.clinical_data_options

    # Extra input width added to the patch features by the selected clinical options.
    # NOTE(review): only "size" and "border" are counted here; "T_stage" and "patient"
    # are presumably consumed elsewhere in the model - confirm.
    clinical_length = 0
    if "size" in clinical_data_options:
        clinical_length += 3
    if "border" in clinical_data_options:
        clinical_length += 2

    hyperparams = {'num_epochs': num_epochs, 'threshold': threshold, 'num_synthetic': num_synthetic, 'oversample': oversample,
                   'batch_size': batch_size, 'lr': lr, 'weight_decay': weight_decay, 'accumulation_steps': accumulation_steps,
                   'patch_hidden_dim': patch_hidden_dim, 'patient_hidden_dim': patient_hidden_dim,
                   'patch_dropout': patch_dropout, 'patient_dropout': patient_dropout, 'alpha': alpha,
                   'attention_indicator': attention_indicator, 'max_node_slices': max_node_slices, 'model_type': model_type,
                   'clinical_data_options': clinical_data_options, 'device': device}

    print(hyperparams)
    print('Device:', device)

    time_start = perf_counter()

    results_path = r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\CNN_MLP_Results"
    save_results_path = rf"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\CNN_MLP_Results\MLP_{Run}.pt"
    # Load the dataset
    IMAGE_DIR = r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1_2D_slices"
    cohort1 = pd.read_excel(r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1.xlsx")

    # Only the stored train/test split is read from the VAE checkpoint, so load it onto
    # the CPU. This also works on machines without CUDA and keeps the VAE tensors off
    # the GPU.
    VAE_params_path = r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE2_results\VAE_36.pt"
    checkpoint = torch.load(VAE_params_path, map_location='cpu')
    train_test_split_dict = checkpoint['train_test_split']
    train_ids = train_test_split_dict['train']
    test_ids = train_test_split_dict['test']
    patient_slices_dict, patient_labels_dict, patient_file_names_dict, short_long_axes_dict, mlp_train_ids, test_ids, mlp_train_labels, test_labels, train_images, test_images, train_test_split_dict, mask_sizes = prepare_VAE_MLP_joint_data(first_time_train_test_split=False, train_ids=train_ids, test_ids=test_ids, num_synthetic=num_synthetic, oversample_ratio=oversample)

    # Raw strings: '\m' is not a valid escape sequence and warns on recent Python
    # versions. The resulting text is unchanged.
    all_files_list = [r'\mri' + '//' + f for f in os.listdir(IMAGE_DIR + r'\mri')] + [r'\mri_aug' + '//' + f for f in os.listdir(IMAGE_DIR + r'\mri_aug')]
    all_files_list.sort()
    all_files_list = get_single_scan_file_list(all_files_list, IMAGE_DIR, cohort1)

    # Rebuild the patient -> file-name mapping from the slice indices, replacing the
    # mapping returned by prepare_VAE_MLP_joint_data.
    patient_file_names_dict = {}
    for patient in patient_slices_dict.keys():
        for idx in patient_slices_dict[patient]:
            patient_file_names_dict.setdefault(patient, []).append(all_files_list[idx])

    train_dataset = Load_CNN_Images(patient_file_names_dict, patient_labels_dict, mlp_train_ids, cohort1, all_files_list, short_long_axes_dict, mask_sizes, clinical_data_options, max_nodes=max_node_slices)

    test_dataset = Load_CNN_Images(patient_file_names_dict, patient_labels_dict, test_ids, cohort1, all_files_list, short_long_axes_dict, mask_sizes, clinical_data_options, max_nodes=max_node_slices)

    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=False)

    # Initialise CNN model, loss function, and optimizer
    backbone = InstanceCNN()
    patch_input_dim = 1024

    # Instantiate the model
    if model_type == 'MLP_MIL_model_simple':
        model = MLP_MIL_model_simple(patch_input_dim=patch_input_dim+clinical_length, hyperparams=hyperparams,
                                     backbone_indicator=True, backbone=backbone)
    elif model_type == 'MLP_MIL_model2':
        model = MLP_MIL_model2(patch_input_dim=patch_input_dim+clinical_length, hyperparams=hyperparams,
                               backbone_indicator=True, backbone=backbone)
    else:
        # Fail with a clear message instead of a NameError on `model` further down.
        raise ValueError(f"Unknown model_type: {model_type!r}")

    # NOTE(review): apply() visits every submodule. If the model registers `backbone`
    # as a submodule and weights_init touches its layer types, the pretrained backbone
    # weights would be overwritten here - confirm.
    model.apply(weights_init)
    model.to(device)

    criterion = nn.BCELoss()
    optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, mode='min', factor=0.5, patience=40,
                                                              verbose=True, threshold=0.001, threshold_mode='abs')

    train_losses, test_losses = [], []
    train_AUCs, test_AUCs = [], []
    train_sensitivitys, test_sensitivitys = [], []
    batches_mixed = 0
    early_stopping = 0
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        all_train_labels = []
        all_train_preds = []
        all_train_probs = []
        steps = 0
        optimiser.zero_grad()

        for features, label, clinical_data, number_of_nodes, LN_features in train_dataloader:
            features, label = features.to(device), label.to(device)
            clinical_data, number_of_nodes = clinical_data.to(device), number_of_nodes.to(device)
            LN_features = LN_features.to(device)
            steps += 1
            # Forward pass
            output, max_vals, attentions, classifications = model(features, clinical_data, number_of_nodes, label, LN_features)
            output = output.squeeze(1)

            loss = criterion(output, label.float())
            train_loss += loss.item()

            # Backward pass: gradients accumulate over `accumulation_steps` batches and
            # are applied once the group is complete (after batches k, 2k, 3k, ...).
            loss.backward()
            if steps % accumulation_steps == 0:
                optimiser.step()
                optimiser.zero_grad()

            # The model output is used directly as the positive-class probability.
            # Threshold it to obtain the predicted class.
            predicted_probs = output
            predicted_class = (predicted_probs >= threshold).long()

            # Store predictions and labels
            all_train_labels.extend(label.cpu().numpy())
            all_train_preds.extend(predicted_class.cpu().numpy())
            all_train_probs.extend(predicted_probs.tolist())

        # Apply the gradients of an incomplete final accumulation group, if there is
        # one. When the last batch already triggered a step there is nothing to flush.
        if steps % accumulation_steps != 0:
            optimiser.step()
            optimiser.zero_grad()

        epoch_train_loss = train_loss/len(train_dataloader)
        lr_scheduler.step(epoch_train_loss)
        if epoch % 5 == 0 or epoch + 20 > num_epochs-1:
            print('Learning rate:', optimiser.param_groups[0]['lr'])
        train_losses.append(epoch_train_loss)

        train_metrics = _epoch_metrics(all_train_labels, all_train_preds, all_train_probs)
        train_accuracy = train_metrics['accuracy']
        train_auc = train_metrics['auc']
        train_bal_accuracy = train_metrics['balanced_accuracy']
        train_confusion_matrix = train_metrics['confusion_matrix']
        train_sensitivity = train_metrics['sensitivity']
        train_specificity = train_metrics['specificity']
        train_AUCs.append(train_auc)
        train_sensitivitys.append(train_sensitivity)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train: Loss: {epoch_train_loss:.4f}, Accuracy: {train_accuracy:.4f}, Balanced Accuracy: {train_bal_accuracy:.4f}, AUC: {train_auc:.4f}, Sensitivity: {train_sensitivity:.4f}, Specificity: {train_specificity:.4f}')
        print('Train Confusion Matrix:')
        print(train_confusion_matrix)

        # Evaluation phase
        model.eval()
        test_loss = 0
        all_test_labels = []
        all_test_preds = []
        all_test_probs = []
        with torch.no_grad():
            for features, label, clinical_data, number_of_nodes, LN_features in test_dataloader:
                features, label = features.to(device), label.to(device)
                clinical_data, number_of_nodes = clinical_data.to(device), number_of_nodes.to(device)
                LN_features = LN_features.to(device)
                # NOTE(review): the label is also passed to the model at evaluation
                # time - confirm the model ignores it outside training.
                output, max_vals, attentions, classifications = model(features, clinical_data, number_of_nodes, label, LN_features)
                output = output.squeeze(1)

                loss = criterion(output, label.float())
                test_loss += loss.item()

                # Store predictions and labels
                predicted_probs = output
                predicted_class = (predicted_probs >= threshold).type(torch.long)
                all_test_labels.extend(label.cpu().numpy())
                all_test_preds.extend(predicted_class.cpu().numpy())
                all_test_probs.extend(predicted_probs.cpu().numpy())

        epoch_test_loss = test_loss/len(test_dataloader)
        test_losses.append(epoch_test_loss)

        test_metrics = _epoch_metrics(all_test_labels, all_test_preds, all_test_probs)
        test_accuracy = test_metrics['accuracy']
        test_auc = test_metrics['auc']
        test_bal_accuracy = test_metrics['balanced_accuracy']
        test_confusion_matrix = test_metrics['confusion_matrix']
        test_sensitivity = test_metrics['sensitivity']
        test_specificity = test_metrics['specificity']
        tp, fp = test_metrics['tp'], test_metrics['fp']
        test_AUCs.append(test_auc)
        test_sensitivitys.append(test_sensitivity)

        print(f'Test: Loss: {epoch_test_loss:.4f}, Accuracy: {test_accuracy:.4f}, Balanced Accuracy: {test_bal_accuracy:.4f}, AUC: {test_auc:.4f}, Sensitivity: {test_sensitivity:.4f}, Specificity: {test_specificity:.4f}')
        print('Test Confusion Matrix:')
        print(test_confusion_matrix)
        # Wait for GPU to cool down for 10 seconds
        time.sleep(10)

        # The test loader is not shuffled, so the label order from the first epoch is
        # valid for every later epoch.
        if epoch == 0:
            test_labels = np.array(all_test_labels)

        if tp >= 10 and fp <= 10:
            # error analysis
            best_test_preds.append(all_test_preds)
            best_test_probs.append(all_test_probs)
            print('number of preds logged:', len(best_test_probs))
            error_analysis(np.array(best_test_probs), test_labels, results_path, threshold)
            num_logged += 1
            # Save when this epoch improves on the best so far: more TPs, or at least as
            # many TPs with fewer FPs, or a tie on both with higher train sensitivity.
            if tp > best_score['TP'] or (tp >= best_score['TP'] and fp < best_score['FP']) or (tp >= best_score['TP'] and fp <= best_score['FP'] and train_sensitivity > best_score['Train_Sensitivity']):
                best_score['TP'] = tp
                best_score['FP'] = fp
                best_score['Train_Sensitivity'] = train_sensitivity
                print('Saving model with TP:', tp, 'and FP:', fp, 'at epoch:', epoch)
                training_results = {'train_losses': train_losses, 'test_losses': test_losses, 'train_AUCs': train_AUCs, 'test_AUCs': test_AUCs, 'train_sensitivitys': train_sensitivitys, 'test_sensitivitys': test_sensitivitys,
                                    'all test labels': all_test_labels, 'all test probs': all_test_probs}
                torch.save({"state_dict": model.state_dict(), "training_results": training_results,
                            "hyperparams": hyperparams, "train_test_split": train_test_split_dict}, save_results_path)
                calibration_curve_and_distribution(all_train_labels, all_train_probs, 'Train', results_path, 'saved_result_' + str(Run), save=True)
                calibration_curve_and_distribution(all_test_labels, all_test_probs, 'Test', results_path, 'saved_result_' + str(Run), save=True)

                # plot results at this stage (updating until the best run)
                plot_MLP_results(training_results, hyperparams, results_path=results_path, filename='MLP_training_results_run_{}.png'.format(Run))

        if epoch == num_epochs-1:
            calibration_curve_and_distribution(all_train_labels, all_train_probs, 'Train', results_path, Run)
            calibration_curve_and_distribution(all_test_labels, all_test_probs, 'Test', results_path, Run)

        # log to wandb
        wandb.log(
            {
                "Test Loss": epoch_test_loss,
                "Test Accuracy": test_accuracy,
                "Test AUC": test_auc,
                "Test Sensitivity": test_sensitivity,
                "Test Specificity": test_specificity,
                "Test TP": tp,
                "Test FP": fp,
                "Train Loss": epoch_train_loss,
                "Train Accuracy": train_accuracy,
                "Train AUC": train_auc,
                "Train Sensitivity": train_sensitivity,
                "Train Specificity": train_specificity,
                "Max Test AUC": np.max(test_AUCs),
            }
        )

        # Early stopping: `early_stopping` counts the epochs (after epoch 25) with a
        # test AUC below 0.7 and is never reset. The longer the run has been poor, the
        # more lenient the AUC bar for giving up becomes.
        if epoch > 25 and test_auc < 0.7:
            early_stopping+=1
            if early_stopping > 15 and test_auc < 0.6 and num_logged == 0:
                print('Early stopping')
                break
            if early_stopping > 25 and test_auc < 0.65 and num_logged == 0:
                print('Early stopping')
                break

            if early_stopping > 50 and test_auc < 0.7 and num_logged <= 1:
                print('Early stopping')
                break

    # save test preds and probs
    np.save(results_path + '//best_test_preds.npy', np.array(best_test_preds))
    np.save(results_path + '//best_test_probs.npy', np.array(best_test_probs))

    print('Batches mixed:', batches_mixed, 'out of', len(train_dataloader)*num_epochs, 'percentage:', batches_mixed/(len(train_dataloader)*num_epochs))
    print(hyperparams)
    print('Time taken:', perf_counter() - time_start)
    # Wait for GPU to cool down after each model run
    print("Cooling down for 5 mins...")
    time.sleep(60*5)

In [None]:
# Start sweep job.
# The agent calls `main` once per trial for the sweep registered as `sweep_id`;
# `count=100` limits this agent to 100 trials.
wandb.agent(sweep_id, function=main, count=100)