Grad-CAM applied to Resnet-18 model for binary classification

In [1]:
import os
import cv2
import torch
import numpy as np
from torchvision import transforms
import torchvision.utils as vutils
import torchvision.transforms.functional as TF
from PIL import Image 
from MedViT_model import MedViT_base
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, HiResCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

>>> [MedViT DEBUG] Loaded this exact file!


In [2]:
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# === Function to recreate and load trained MedViT model ===
def load_medvit_model(weight_path, num_classes=2, device=torch.device("cpu")):
    # Cria modelo com a mesma estrutura usada no treino
    model = MedViT_base(pretrained=False, num_classes=1000) #create inference model in the original format

    # Carrega pesos do modelo treinado
    state_dict = torch.load(weight_path, map_location=device)


    in_features = model.proj_head[0].in_features
    model.proj_head = torch.nn.Sequential(torch.nn.Linear(in_features, num_classes)) # Ajusta a camada de saída para o número correto de classes
    model.load_state_dict(state_dict, strict=False) # Load the state dict with strict=False to ignore the last layer
    model.to(device)
    model.eval()
    print(model.features)
    return model


In [4]:
# Overlay heatmap
def overlay_heatmap(original_image, heatmap):
    # Debugging: Print heatmap shape and type
    #print(f"[Debug] Initial heatmap shape: {heatmap.shape}, dtype: {heatmap.dtype}")

    # Handle cases where the heatmap has an extra dimension (1, H, W)
    if len(heatmap.shape) == 3 and heatmap.shape[0] == 1:
        heatmap = np.squeeze(heatmap, axis=0)
        #print(f"[Debug] Squeezed heatmap shape: {heatmap.shape}")

    # Normalize the heatmap to range [0, 1] and convert to uint8
    heatmap = np.clip(heatmap, 0, 1)
    heatmap_uint8 = np.uint8(255 * heatmap)

    # Resize the heatmap to match the original image size
    heatmap_resized = cv2.resize(heatmap_uint8, (original_image.shape[1], original_image.shape[0]), interpolation=cv2.INTER_CUBIC)

    # Ensure the heatmap is in a compatible format for OpenCV
    if len(heatmap_resized.shape) == 2:  # Grayscale heatmap
        heatmap_color = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
    elif len(heatmap_resized.shape) == 3 and heatmap_resized.shape[2] == 1:  # Single channel but 3D
        heatmap_color = cv2.applyColorMap(heatmap_resized[:, :, 0], cv2.COLORMAP_JET)
    else:
        print(f"[Error] Unexpected heatmap shape after resizing: {heatmap_resized.shape}")
        raise ValueError("Heatmap shape is not compatible with cv2.applyColorMap.")

    # Convert grayscale original image to BGR if necessary
    if len(original_image.shape) == 2:
        original_image = cv2.cvtColor(original_image, cv2.COLOR_GRAY2BGR)

    # Blend the heatmap with the original image
    overlay = cv2.addWeighted(heatmap_color, 0.4, original_image, 0.6, 0)
    return overlay


In [5]:
# === Prediction Frame ===
def add_colored_frame(image, color, thickness=20):
    height, width = image.shape[:2]
    cv2.rectangle(image, (0, 0), (width - 1, height - 1), color, thickness)
    return image

In [6]:
# === Extract Ground Truth from folder path ===
def get_ground_truth(img_name, folder):
    # Combine the folder path and image name for a comprehensive check
    full_path = os.path.join(folder, img_name)
    
    # Check for the ground truth in a case-insensitive way
    #if "cesarean" in full_path.lower():  #DESCOMENTAR QUANDO FOR PARA CORRER O COGIDO A PARTIR DO DATASET ORIGINAL
    if "ces" in full_path.lower():
        return "Cesarean Birth"
    elif "vag" in full_path.lower():
        return "Vaginal Birth"
    else:
        print(f"Warning: Could not determine ground truth for {full_path}. Assuming 'Cesarean Birth'.")
        ground_truth = "Cesarean Birth"

    # Print the ground truth for each image
    #print(f"[Ground Truth] Image: {img_name} | Path: {full_path} | Picked: {ground_truth}")
    return ground_truth

In [7]:
def apply_gradcam_and_save(model, input_folder, output_folder, grid_image_size=(224, 224)):
    os.makedirs(output_folder, exist_ok=True)
    
    target_layers = [model.features[28]]


    transform = transforms.Compose([
        transforms.Resize((224, 224)), # Correct input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for 1 channel 
    ])

    original_images = []
    correct_heatmaps = []
    incorrect_heatmaps = []
    misclassified_tensors = []
    correctly_classified_tensors = []
    #prediction_results = []

    class_names = ["Cesarean Birth", "Vaginal Birth"]

    for img_name in os.listdir(input_folder):
        if not img_name.lower().endswith((".png")):
            continue

        img_path = os.path.join(input_folder, img_name)
        image = Image.open(img_path).convert("RGB") 

        
        input_tensor = transform(image).unsqueeze(0).to(device)
        # Save the resized image as original (224x224) to ensure overlay consistency
        resized_image = transforms.Resize((224, 224))(image)
        original_np = np.array(resized_image)
        original_images.append(original_np)


        # Previsão para obter classe
        with torch.no_grad():
            output = model(input_tensor)
            prediction = output.argmax(dim=1).item()
            probabilities = torch.softmax(output, dim=1)
            confidence = probabilities[0, prediction].item()

        # Grad-CAM
        with GradCAM(model=model, target_layers=target_layers) as cam:   #use_cuda=torch.cuda.is_available()
            targets = [ClassifierOutputTarget(prediction)]
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]  # [H, W]


        # Access raw activations (before interpolation)
        raw_activations = cam.activations_and_grads.activations[-1]  # shape: [B, C, H, W]
        print("RAW FEATURE MAP SHAPE:", raw_activations.shape)  # For example: torch.Size([1, 1664, 5, 5])

        # Just spatial resolution (pre-Grad-CAM aggregation)
        h_raw, w_raw = raw_activations.shape[-2:]
        print(f"Raw heatmap resolution: {h_raw}x{w_raw}")



        # === Ground truth from folder name ===
        ground_truth_label = get_ground_truth(img_name, input_folder)
        ground_truth_class = class_names.index(ground_truth_label)

        # Add frame
        frame_color = (0, 255, 0)  # green
        if prediction != ground_truth_class:
            frame_color = (0, 0, 255)  # red

        # === Overlay and save ===
        overlay = overlay_heatmap(original_np, grayscale_cam[np.newaxis, ...])  # adiciona dimensão extra
        overlay_with_frame = add_colored_frame(overlay, frame_color, thickness=7)

        # === Save heatmap image ===
        save_path = os.path.join(output_folder, f"heatmap_{class_names[prediction]}_{img_name}")
        cv2.imwrite(save_path, overlay_with_frame)

        # === Save heatmap for mean and tensor for grid ===
        overlay_rgb = cv2.cvtColor(overlay_with_frame, cv2.COLOR_BGR2RGB) #converts to rgb
        overlay_pil = Image.fromarray(overlay_rgb).resize(grid_image_size) # downsizes the image to fit the grid
        tensor_for_grid = TF.to_tensor(overlay_pil) #converts into a pytorch tensor [3,H,W] and normalizes pixel values

        if prediction == ground_truth_class:
            correctly_classified_tensors.append(tensor_for_grid)
            correct_heatmaps.append(grayscale_cam)  # 2D spatial heatmap , cam shape = [7,7]
        else:
            misclassified_tensors.append(tensor_for_grid)
            incorrect_heatmaps.append(grayscale_cam)


    # === CREATE GRID ===
    final_sorted_grid = misclassified_tensors + correctly_classified_tensors
    if final_sorted_grid:
        grid = vutils.make_grid(final_sorted_grid, nrow=20, padding=5, normalize=True)
        grid_path = os.path.join(output_folder, "grid_overlay_ordered.jpg")
        vutils.save_image(grid, grid_path)
        #print(f"Grid saved at: {grid_path}")

    # === CALCULATE CORRECT MEAN HEATMAP ===
    if correct_heatmaps: #list of NumPy arrays (raw grad-cam maps before resizing or coloring), each w\ shape [3, 3]
        mean_heatmap = np.mean(np.stack(correct_heatmaps), axis=0) # turns list of N heatmaps into 3D array w\ shape [N, 7, 7]
        mean_heatmap -= mean_heatmap.min() # subtracts the minimum value from every pixel so the lowest value is 0
        mean_heatmap /= (mean_heatmap.max() + 1e-8) # divides every pixel by the new maximum value (normalize) ; adds 1e-8 to ensure it wont be divided by 0
        heatmap_uint8 = np.uint8(255 * mean_heatmap) # Converts [0.0 – 1.0] → [0 – 255] format needed for OpenCV
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET) # assigns color to the values

        # Resize to match original image size
        h, w = original_images[0].shape[:2]
        heatmap_resized_blocky = cv2.resize(heatmap_colored, (w, h), interpolation=cv2.INTER_NEAREST) #INTER_CUBIC for smoother transitions ; INTER_NEAREST for blocky look
        heatmap_resized_smooth = cv2.resize(heatmap_colored, (w, h), interpolation=cv2.INTER_CUBIC)


        mean_path_blocky = os.path.join(output_folder, "correct_mean_heatmap_blocky.png")
        cv2.imwrite(mean_path_blocky, heatmap_resized_blocky)
        #print(f"Mean heatmap (correct only) saved at: {mean_path_blocky}")

        mean_path_smooth = os.path.join(output_folder, "correct_mean_smooth.png")
        cv2.imwrite(mean_path_smooth, heatmap_resized_smooth)
        #print(f"Mean heatmap (correct only) saved at: {mean_path_smooth}")

    # === CALCULATE INCORRECT MEAN HEATMAP ===
    if incorrect_heatmaps:
        mean_heatmap_incorrect = np.mean(np.stack(incorrect_heatmaps), axis=0)
        mean_heatmap_incorrect -= mean_heatmap_incorrect.min()
        mean_heatmap_incorrect /= (mean_heatmap_incorrect.max() + 1e-8)
        heatmap_uint8_incorrect = np.uint8(255 * mean_heatmap_incorrect)
        heatmap_colored_incorrect = cv2.applyColorMap(heatmap_uint8_incorrect, cv2.COLORMAP_JET)

        h, w = original_images[0].shape[:2]
        heatmap_resized_blocky_incorrect = cv2.resize(heatmap_colored_incorrect, (w, h), interpolation=cv2.INTER_NEAREST)
        heatmap_resized_smooth_incorrect = cv2.resize(heatmap_colored_incorrect, (w, h), interpolation=cv2.INTER_CUBIC)

        mean_path_blocky_incorrect = os.path.join(output_folder, "incorrect_mean_heatmap_blocky.png")
        cv2.imwrite(mean_path_blocky_incorrect, heatmap_resized_blocky_incorrect)
        print(f"Mean heatmap (incorrect only) saved at: {mean_path_blocky_incorrect}")

        mean_path_smooth_incorrect = os.path.join(output_folder, "incorrect_mean_smooth.png")
        cv2.imwrite(mean_path_smooth_incorrect, heatmap_resized_smooth_incorrect)
        print(f"Mean heatmap (incorrect only) saved at: {mean_path_smooth_incorrect}")


    return correct_heatmaps, original_images

In [8]:
"""def generate_prediction_table_only(model, input_folder, split_name, structure, birth_type, dataset):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    class_names = ["Cesarean Birth", "Vaginal Birth"]
    prediction_results = []

    for img_name in os.listdir(input_folder):
        if not img_name.lower().endswith(".png"):
            continue

        img_path = os.path.join(input_folder, img_name)
        image = Image.open(img_path).convert("RGB")
        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)
            prediction = output.argmax(dim=1).item()
            probabilities = torch.softmax(output, dim=1)
            confidence = probabilities[0, prediction].item()

        ground_truth_label = get_ground_truth(img_name, input_folder)
        ground_truth_class = class_names.index(ground_truth_label)

        prediction_results.append({
            "Image": img_name,
            "Predicted Class": class_names[prediction],
            "Ground Truth": ground_truth_label,
            "Correct": prediction == ground_truth_class,
            "Confidence": round(confidence, 4),
            "Split": split_name,
            "Structure": structure,
            "Birth Type": birth_type,
            "Dataset": dataset
        })

    return pd.DataFrame(prediction_results)"""

'def generate_prediction_table_only(model, input_folder, split_name, structure, birth_type, dataset):\n    transform = transforms.Compose([\n        transforms.Resize((224, 224)),\n        transforms.ToTensor(),\n        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n    ])\n\n    class_names = ["Cesarean Birth", "Vaginal Birth"]\n    prediction_results = []\n\n    for img_name in os.listdir(input_folder):\n        if not img_name.lower().endswith(".png"):\n            continue\n\n        img_path = os.path.join(input_folder, img_name)\n        image = Image.open(img_path).convert("RGB")\n        input_tensor = transform(image).unsqueeze(0).to(device)\n\n        with torch.no_grad():\n            output = model(input_tensor)\n            prediction = output.argmax(dim=1).item()\n            probabilities = torch.softmax(output, dim=1)\n            confidence = probabilities[0, prediction].item()\n\n        ground_truth_label = get_ground_truth(img_name, input_folder)\

In [9]:
def main():
    #only_generate_csv = True  # <-- MUDA AQUI

    structures = ["Abdomen_"] #,"Femur_", "Head_"] 
    birth_types = ["Cesarean Birth"]#, "Vaginal Birth"]
    datasets = ["test"]#, "test"]
    splits = ["cv3"] #["cv1", "cv2","cv3"]

    #dfs_by_split = {split: {} for split in splits}  # dict de dicts

    base_input = "C:/Users/anale/OneDrive/Documentos/Universidade/TESE/image-dataset/dataset_images_{split}/{structure}/{dataset}/{birth_type}"
    base_output = "C:/Users/anale/OneDrive/Documentos/Universidade/TESE/RESULTS/TARGETLAYERTESTE"#HiRes_medvit_{split}_results/X_hires_medvit_{structure_clean}_{birth_type_short}_{dataset}_{split}" 
    base_weight = "C:/Users/anale/OneDrive/Documentos/Universidade/TESE/model_paths/MedViT_paths/{structure_clean}_{split}_medvit_pretrained_best-model.pth"
    #output_dir = "C:/Users/anale/OneDrive/Documentos/Universidade/TESE/RESULTS/prediction_confidence_medvit_xlsx"
    #os.makedirs(output_dir, exist_ok=True)


    for split in splits:
        for structure in structures:
            structure_clean = structure.rstrip("_").lower()
            for birth_type in birth_types:
                birth_type_short = birth_type.split()[0].lower()  # 'cesarean' or 'vaginal'
                for dataset in datasets:
                    input_folder = base_input.format(split=split, structure=structure, birth_type=birth_type, dataset=dataset)
                    output_folder = base_output.format(split=split, structure_clean=structure_clean, birth_type_short=birth_type_short, dataset=dataset)
                    weight_path = base_weight.format(structure_clean=structure_clean, split=split)

                    if not os.path.exists(weight_path):
                        print(f"[SKIPPED] Weight file does not exist: {weight_path}")
                        continue

                    print(f"Processing: {structure}, {birth_type}, {dataset}, {split}")
                    model = load_medvit_model(weight_path, num_classes=2)
                    all_heatmaps, original_images = apply_gradcam_and_save(model, input_folder, output_folder)

                    """df = generate_prediction_table_only(
                            model=model,
                            input_folder=input_folder,
                            split_name=split,
                            structure=structure_clean,
                            birth_type=birth_type_short,
                            dataset=dataset
                    )
                    
                    sheet_name = f"{structure_clean}_{birth_type_short}_{dataset}"[:31]
                    dfs_by_split[split][sheet_name] = df

    # === Guardar cada split como um ficheiro .xlsx ===
    for split, sheet_dict in dfs_by_split.items():
        excel_path = os.path.join(output_dir, f"{split}.xlsx")
        with pd.ExcelWriter(excel_path, engine="openpyxl") as writer:
            for sheet_name, df in sheet_dict.items():
                df.to_excel(writer, sheet_name=sheet_name, index=False)
        print(f"Guardado: {excel_path}")"""

    print("All Grad-CAM heatmaps were generated successfully.")

if __name__ == "__main__":
    main() 


Processing: Abdomen_, Cesarean Birth, test, cv3
initialize_weights...


  state_dict = torch.load(weight_path, map_location=device)


Sequential(
  (0): ECB(
    (patch_embed): PatchEmbed(
      (avgpool): Identity()
      (conv): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (norm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (mhca): MHCA(
      (group_conv3x3): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3, bias=False)
      (norm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
      (projection): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (attention_path_dropout): DropPath(drop_prob=0.000)
    (conv): LocalityFeedForward(
      (conv): Sequential(
        (0): Conv2d(96, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): h_swish(
          (sigmoid): h_sigmoid(
            (relu): ReLU6(inplace=True)
          )
       

KeyboardInterrupt: 