### Create Grad-CAM heatmaps from the VAE-MLP model for the lymph nodes (LNs) with the largest short-axis diameter

In [None]:
import numpy as np
import os
import sys
import math
from time import perf_counter
import time
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, balanced_accuracy_score, confusion_matrix
from scipy.spatial.distance import cdist
from hyperopt import hp, fmin, tpe, Trials
import nibabel as nib
import wandb
import random
import pickle
from torchvision import transforms
import cv2


sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from utils.datasets import Load_Latent_Vectors, LoadImages, prepare_VAE_MLP_joint_data
from utils.utility_code import get_single_scan_file_list, get_class_distribution, weights_init, plot_MLP_results, error_analysis
from models.MLP_model import MLP_MIL_model_simple, MLP_MIL_model2
from models.VAE_2D_model import VAE_2D
from utils.train_and_test_functions import mixup_patient_data, mixup_batch, process_batch_with_noise, calibration_curve_and_distribution

In [None]:
# Build the model inputs for one patient: per-node size features, label, patient-level clinical data, node count, and the original MRI patches
def patient_test_data(patient_id):
    """Build the model inputs for a single patient.

    Reads the module-level objects ``patient_slices_dict``, ``mask_sizes``,
    ``all_files_list``, ``short_long_axes_dict``, ``patient_labels_dict``,
    ``cohort1``, ``clinical_data_options`` and ``max_node_slices``.

    Args:
        patient_id: Key into ``patient_slices_dict`` and
            ``patient_labels_dict``. The part before the first ``_`` is used
            to look the patient up in ``cohort1`` and ``short_long_axes_dict``.

    Returns:
        A 5-tuple of tensors:
        ``LN_features`` (float32) - one row per node slice holding the
        ``(long, short, ratio)`` triple when ``"size"`` is enabled, padded
        with rows of 0.5 up to ``max_node_slices``;
        ``label`` (long) - the patient's entry in ``patient_labels_dict``;
        ``patient_clinical_data`` (float32) - patient-level values, extended
        with the size features of the node with the largest first feature;
        ``number_of_nodes`` (float32) - number of real (unpadded) slices;
        ``images`` (float32) - the 1x32x32 patches, zero-padded up to
        ``max_node_slices``.
    """
    indices = patient_slices_dict[patient_id]
    if len(indices) > max_node_slices:
        # Keep only the slices with the largest masks, restored to ascending
        # index order. A separate local name is required here: assigning to
        # ``mask_sizes`` would make it function-local and raise
        # UnboundLocalError on the lookup itself.
        # assumes mask_sizes[patient_id] is aligned with indices - TODO confirm
        patient_mask_sizes = mask_sizes[patient_id]
        ranked = sorted(enumerate(patient_mask_sizes), key=lambda x: x[1], reverse=True)
        biggest_n_mask_idx = [i for i, _ in ranked[:max_node_slices]]
        indices = sorted(indices[i] for i in biggest_n_mask_idx)

    pat_id = patient_id.split('_')[0]

    # Patient-level clinical columns requested by the model configuration.
    patient_options = []
    if "patient" in clinical_data_options:
        patient_options.extend(['age_scaled', 'sex_numeric'])
    if "T_stage" in clinical_data_options:
        patient_options.append('TumourLabel_numeric')
    patient_indicator = bool(patient_options)

    if patient_indicator:
        patient_clinical_data = cohort1[cohort1['shortpatpseudoid'] == pat_id][patient_options].values.tolist()
    else:
        patient_clinical_data = None

    # Only the size features are produced per node. The border metrics
    # (compactness / convexity) were loaded from border_metrics.pkl on every
    # call but never used, so that load has been dropped.
    # NOTE(review): selecting "border" without "size" cannot work until the
    # border features are restored - LN_features then has no columns.
    node_indicator = "size" in clinical_data_options

    # Loop-invariant helpers hoisted out of the per-slice loop.
    transform = transforms.Compose([transforms.ToTensor()])
    img_dir = r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1_2D_slices/"

    LN_features = []
    images = np.zeros((len(indices), 1, 32, 32))
    for i, index in enumerate(indices):
        file_name = all_files_list[index]
        # The node number is the 7th '_'-separated token of the file name.
        node_number = float(file_name.split('//')[1].split('_')[6])
        long, short, ratio = short_long_axes_dict[pat_id][node_number]
        LN_features.append([long, short, ratio] if node_indicator else [])

        img = nib.load(img_dir + file_name).get_fdata()
        images[i] = transform(img)

    LN_features = np.array(LN_features)

    # Combine patient-level data with the features of one representative node.
    if patient_indicator and ("size" in clinical_data_options):
        # Node with the largest first size feature.
        patient_clinical_data = patient_clinical_data[0] + LN_features[np.argmax(LN_features[:, 0])].tolist()
    if patient_indicator and ("border" in clinical_data_options) and ("size" not in clinical_data_options):
        # Node with the smallest first feature (min compactness once border features exist).
        patient_clinical_data = patient_clinical_data[0] + LN_features[np.argmin(LN_features[:, 0])].tolist()
    if patient_indicator and not node_indicator:
        patient_clinical_data = patient_clinical_data[0]
    if not patient_indicator and node_indicator:
        if "size" in clinical_data_options:
            patient_clinical_data = LN_features[np.argmax(LN_features[:, 0])].tolist()
        if ("border" in clinical_data_options) and ("size" not in clinical_data_options):
            patient_clinical_data = LN_features[np.argmin(LN_features[:, 0])].tolist()
    if not patient_indicator and not node_indicator:
        patient_clinical_data = []

    label = patient_labels_dict[patient_id]

    # Record the real node count before padding to a fixed bag size.
    number_of_nodes = len(images)
    if number_of_nodes < max_node_slices:
        n_missing = max_node_slices - number_of_nodes
        images = np.concatenate((images, np.zeros((n_missing, 1, 32, 32))), axis=0)
        LN_features = np.concatenate((LN_features, np.ones((n_missing, LN_features.shape[1])) * 0.5), axis=0)

    return (
        torch.tensor(LN_features, dtype=torch.float32),
        torch.tensor(label, dtype=torch.long),
        torch.tensor(patient_clinical_data, dtype=torch.float32),
        torch.tensor(number_of_nodes, dtype=torch.float32),
        torch.tensor(images, dtype=torch.float32),
    )

# LN_features, label, patient_clinical_data, num_nodes, images = patient_test_data(test_ids[0])
# print(LN_features.shape, label, patient_clinical_data, num_nodes, images.shape)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

IMAGE_DIR = r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1_2D_slices"
cohort1 = pd.read_excel(r"C:\Users\mm17b2k.DS\Documents\ARCANE_Data\Cohort1.xlsx")
latent_vectors = np.load(r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE2_results\latent_vectors_36.npy")

# Relative slice paths of the form '\mri//<file>' and '\mri_aug//<file>'.
# Raw strings are used because '\m' is not a valid escape sequence (a
# SyntaxWarning on recent Python versions); the resulting text is unchanged.
all_files_list = (
    [r'\mri' + '//' + f for f in os.listdir(IMAGE_DIR + r'\mri')]
    + [r'\mri_aug' + '//' + f for f in os.listdir(IMAGE_DIR + r'\mri_aug')]
)
all_files_list.sort()
all_files_list = get_single_scan_file_list(all_files_list, IMAGE_DIR, cohort1)


# Restore the trained VAE together with the train/test split it was fitted on.
VAE_params_path = r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\VAE2_results\VAE_36.pt"
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written on a GPU machine still loads when only a CPU is present.
checkpoint = torch.load(VAE_params_path, map_location=device)
train_test_split_dict = checkpoint['train_test_split']
train_ids = train_test_split_dict['train']
test_ids = train_test_split_dict['test']
hyperparams = checkpoint['hyperparams']
vae_model = VAE_2D(hyperparams)
vae_model = vae_model.to(device)
# Load the saved weights
vae_model.load_state_dict(checkpoint["state_dict"])
# Evaluation mode: the model is only used for inference / gradient maps here
vae_model.eval()
encoder = vae_model.encoder


# Restore the trained MLP classifier and the options it was trained with.
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written on a GPU machine still loads when only a CPU is present.
mlp_checkpoint = torch.load(
    r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\MLP_Results\best_model\MLP_62.pt",
    map_location=device,
)
mlp_hyperparams = mlp_checkpoint['hyperparams']
print(mlp_hyperparams)
clinical_data_options = mlp_hyperparams['clinical_data_options']

# Number of per-node clinical features appended to each latent vector:
# three for "size" and two for "border".
clinical_length = 0
if "size" in clinical_data_options:
    clinical_length += 3
if "border" in clinical_data_options:
    clinical_length += 2
print("clinical_length", clinical_length)

# Each patch input is 400 values plus the per-node clinical features.
mlp_model = MLP_MIL_model2(patch_input_dim=400+clinical_length, hyperparams=mlp_hyperparams, grad_cam=True)
mlp_model = mlp_model.to(device)
mlp_model.load_state_dict(mlp_checkpoint["state_dict"])
mlp_model.eval()

max_node_slices = mlp_hyperparams['max_node_slices']
n_synthetic = 0
oversample = 1



# The per-patient lookup tables are built once and cached on disk; set
# first_time to True to rebuild and re-save them.
first_time = False
if not first_time:
    # Reuse the cached tables.
    data_dictionaries = torch.load(r'C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\MLP_Results\data_dictionaries.pth')
    (
        patient_slices_dict,
        patient_labels_dict,
        patient_file_names_dict,
        short_long_axes_dict,
        mask_sizes,
        mlp_train_ids,
        test_ids,
        mlp_train_labels,
        test_labels,
        train_images,
        test_images,
        train_test_split_dict,
    ) = (
        data_dictionaries[key]
        for key in (
            "slices",
            "labels",
            "files",
            "sizes",
            "mask_sizes",
            "mlp_train_ids",
            "test_ids",
            "mlp_train_labels",
            "test_labels",
            "train_images",
            "test_images",
            "train_test_split_dict",
        )
    )
else:
    # Rebuild the tables from the existing train/test split, then cache them.
    (
        patient_slices_dict,
        patient_labels_dict,
        patient_file_names_dict,
        short_long_axes_dict,
        mlp_train_ids,
        test_ids,
        mlp_train_labels,
        test_labels,
        train_images,
        test_images,
        train_test_split_dict,
        mask_sizes,
    ) = prepare_VAE_MLP_joint_data(
        first_time_train_test_split=False,
        train_ids=train_ids,
        test_ids=test_ids,
        num_synthetic=n_synthetic,
        oversample_ratio=oversample,
    )
    data_dictionaries = {
        "slices": patient_slices_dict,
        "labels": patient_labels_dict,
        "files": patient_file_names_dict,
        "sizes": short_long_axes_dict,
        "mask_sizes": mask_sizes,
        "mlp_train_ids": mlp_train_ids,
        "test_ids": test_ids,
        "mlp_train_labels": mlp_train_labels,
        "test_labels": test_labels,
        "train_images": train_images,
        "test_images": test_images,
        "train_test_split_dict": train_test_split_dict,
    }
    torch.save(data_dictionaries, r'C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\MLP_Results\data_dictionaries.pth')


# Image datasets for the two splits, served one image per batch; only the
# training loader shuffles.
train_dataset, test_dataset = (
    LoadImages(main_dir=IMAGE_DIR + '/', files_list=split_files)
    for split_files in (train_images, test_images)
)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


original_images = []
reconstructed_images = []

# random seed
random.seed(40)

#random_indices = random.sample(range(59), 16)
examples = [6, 12, 14, 33, 37, 41, 43, 21]

criterion = nn.BCELoss()

important_latent_variables = {i: 0 for i in range(400)}

# Outputs of the hooked layer for the most recent VAE forward pass. The list
# is cleared in place before each pass so the hook always appends to it.
feature_maps = []


def save_feature_maps(module, input, output):
    """Forward hook: record the hooked layer's output and keep its gradient."""
    feature_maps.append(output)
    output.retain_grad()  # Ensure gradients are stored for this tensor


# Hook the layer used for the heatmaps (intended as the encoder's last
# convolutional layer). It is registered once and removed in the `finally`
# below; registering it per patient would stack duplicate hooks on the layer.
last_conv_layer = vae_model.encoder[4].conv[0]
hook = last_conv_layer.register_forward_hook(save_feature_maps)

try:
    for patient_idx, patient_id in enumerate(test_ids):
        print('Patient ID:', patient_id, 'Patient index:', patient_idx)
        # NOTE(review): patient ids are split on '_' elsewhere, so they look
        # like strings and this comparison may never be true - confirm whether
        # patient_idx == 42 was intended.
        if patient_id == 42:
            continue

        LN_features, label, clinical_data, num_nodes, images = patient_test_data(patient_id)

        # Gradients are taken with respect to the input patches.
        images = images.float().to(device)
        images.requires_grad_()

        # Pass the data through the VAE to get the latent means
        vae_model.eval()
        vae_model.zero_grad()
        feature_maps.clear()
        recons, mu, log_var = vae_model(images)

        mu = mu.squeeze()

        LN_features = LN_features.to(device)
        print('LN_features:', LN_features.shape, 'mu:', mu.shape)
        features = torch.cat((mu, LN_features), dim=1)

        # Add a batch dimension of 1 for the MLP.
        label = label.unsqueeze(dim=0).to(device)
        clinical_data = clinical_data.unsqueeze(dim=0).to(device)
        num_nodes = num_nodes.unsqueeze(dim=0).to(device)
        features = features.unsqueeze(dim=0)
        n_nodes = int(num_nodes.squeeze().item())
        print('Patient ID:', patient_id, 'Label:', label.squeeze().item(), 'Num nodes:', num_nodes.squeeze().item())

        output, max_vals, refined_LNM_predictions, patch_features = mlp_model(features, clinical_data, num_nodes, label)

        print('Max prediction:', max_vals)
        print('prediction', output, 'mlp pred', refined_LNM_predictions)

        # Size features of the real (unpadded) nodes only.
        node_features = LN_features.detach().cpu().numpy()[:n_nodes, :]
        max_node_index = np.argmax(node_features[:, 0])
        print('Max node index: ', max_node_index, 'short-axis diameter: ', node_features)

        if label != 1:
            continue

        # For positive patients, draw a heatmap for every node whose first
        # size feature sets a new running maximum above the 0.35 threshold.
        biggest_size = 0.35
        for node_idx in range(n_nodes):
            if not (node_features[node_idx][0] > biggest_size):
                continue
            biggest_size = node_features[node_idx][0]
            max_node_index = node_idx
            print('Max node index:', max_node_index, 'short-axis diameter:', node_features[node_idx][0])

            # Gradient of the node's mean activation with respect to the
            # input patches. Averaging the gradients of all 80 x (8x8)
            # activations one at a time gives the same result by linearity,
            # but needs one forward/backward pass per activation.
            images.grad = None
            vae_model.eval()
            vae_model.zero_grad()
            feature_maps.clear()
            recons, mu, log_var = vae_model(images)
            print('Max node index:', max_node_index, feature_maps[0].shape)
            node_maps = feature_maps[0][max_node_index].squeeze()
            node_maps.mean().backward()
            gradients = images.grad

            print('clinical data:', clinical_data)

            # Keep only positive gradients. Squeeze the channel axis
            # explicitly so a batch of one keeps its batch axis.
            grad_cam_map = F.relu(gradients.squeeze(1))

            # Normalize heatmap to [0, 1], guarding against a constant map
            heatmap = grad_cam_map.detach().cpu().numpy()[max_node_index]
            value_range = heatmap.max() - heatmap.min()
            if value_range > 0:
                heatmap = (heatmap - heatmap.min()) / value_range
            else:
                heatmap = np.zeros_like(heatmap)
            # Invert and scale to [0, 255]; the inverted map's maximum is 1.
            heatmap = np.uint8((1 - heatmap) * 255)

            # Original patch scaled to [0, 255]; left as-is if its peak is 0
            original_image = images.detach().squeeze(1).cpu().numpy()[max_node_index]
            peak = original_image.max()
            if peak > 0:
                original_image = (original_image / peak) * 255
            original_image = np.uint8(original_image)  # uint8 for OpenCV compatibility

            # Overlay the colour-mapped heatmap onto the original image
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            overlayed_img = cv2.addWeighted(cv2.cvtColor(original_image, cv2.COLOR_GRAY2BGR), 0.5, heatmap, 0.5, 0)

            # Display
            plt.figure(figsize=(10, 10))
            plt.subplot(1, 2, 1)
            plt.title("Original Image", fontsize=18)
            plt.imshow(original_image, cmap='gray')
            plt.axis('off')
            plt.subplot(1, 2, 2)
            plt.title("Grad-CAM Heatmap", fontsize=18)
            plt.imshow(overlayed_img)
            plt.axis('off')
            # save the plot
            plt.savefig(r"C:\Users\mm17b2k.DS\Documents\Python\ARCANE_Results\Grad-CAM\Grad-CAM_" + str(patient_idx) + "_" + str(max_node_index) + ".png")
            plt.show()
            # Release the figure so figures do not accumulate across nodes.
            plt.close()
finally:
    # Always detach the hook so later forward passes are unaffected.
    hook.remove()

