In [8]:
# Import M3d-CAM
from medcam import medcam
import torch
from get_model import get_trained_model
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
%matplotlib inline

BACKENDS = ["gcampp", "gcam", "gbp", "ggcam"]
MODALITIES = ["DWI", "T2W", "ADC"]

pairings = [
    [139,197],
    [17, 175],
    [93, 55],
    [95, 67],
    [118, 85]
]

def setup_image(image: np.array, slice: int, modality: int):
    data = np.transpose(image, (2, 3, 0, 1))
    data = data[modality, :, :, slice]
    normalized_image = (data-np.min(data))/(np.max(data)-np.min(data)) * 255
    normalized_image = normalized_image.astype(np.uint8)
    normalized_image = cv2.resize(normalized_image, (84, 84))
    return normalized_image

for i in pairings:
    patient_id = i[1]
    fold = i[0]
    MODEL_PATH = f'pca-results/3_modalities_cropped_T2W_clinical_sig/best_weights/best_metric_model_fold_{fold}.pth'
    SAVE_DIR = f"pca-results/attention_maps/3_modalities_v2/{patient_id}/individual_modalities"
    IMAGE_PATH = f'data/pca_processed_data/3_modalities_combined/ProstateX-{str(patient_id).zfill(4)}_multi_modal_img.npy'
    for backend in BACKENDS:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = get_trained_model(MODEL_PATH)
        model = model.to(device)
        # replace so that doesn't save the attention map to the directory 
        model = medcam.inject(model, output_dir="attention_maps", save_maps=True, backend=backend, layer="auto", replace=True)
        model.eval()

        image_data = np.load(IMAGE_PATH)
        print(image_data.shape) # (3, 12, 84, 84)

        new_input = torch.from_numpy(np.array([image_data])).float().to(device)
        print(new_input.shape) # torch.Size([1, 3, 12, 84, 84])

        output = model(new_input)
        print(output.shape) # torch.Size([1, 1, 12, 84, 84])

        output_np = output.cpu().numpy()

        output_np = (output_np - np.min(output_np)) / (np.max(output_np) - np.min(output_np)) # normalizing
        print(output_np.shape)
        num_slices = output_np.shape[2]  # Get the number of slices

        # Set up the figure size and layout
        for modality in range(3):
            fig, axes = plt.subplots(3, 4, figsize=(15, 10))
            axes = axes.flatten()

            for slice in range(num_slices):
                ax = axes[slice]
                ax.imshow(setup_image(image_data, slice, modality))  # Overlay with mask
                ax.imshow(output_np[0, 0, slice, :, :], cmap='jet', alpha=0.5)
                ax.axis('off') 
                ax.set_title(f'Slice {slice+1}')

            plt.tight_layout()
            os.makedirs(SAVE_DIR, exist_ok=True)
            plt.savefig(f"{SAVE_DIR}/{backend}-{MODALITIES[modality]}.png")
            plt.close(fig)

(3, 12, 84, 84)
torch.Size([1, 3, 12, 84, 84])
torch.Size([1, 1, 12, 84, 84])
(1, 1, 12, 84, 84)
(3, 12, 84, 84)
torch.Size([1, 3, 12, 84, 84])
torch.Size([1, 1, 12, 84, 84])
(1, 1, 12, 84, 84)
(3, 12, 84, 84)
torch.Size([1, 3, 12, 84, 84])
torch.Size([1, 1, 12, 84, 84])
(1, 1, 12, 84, 84)
(3, 12, 84, 84)
torch.Size([1, 3, 12, 84, 84])
torch.Size([1, 1, 12, 84, 84])
(1, 1, 12, 84, 84)
