Gradient-weighted attention rollout (a Grad-CAM-style attribution method) applied to a MedViT model for binary classification

In [None]:
import os
import cv2
import torch
import numpy as np
import torchvision.transforms as T
from PIL import Image
import torchvision.utils as vutils
import torchvision.transforms.functional as TF
from MedViT import MedViT_base

In [None]:
# Function to create and load trained MedViT model
def load_medvit_model(weight_path, num_classes=2, device=torch.device("cpu")):
    """Build a MedViT-base network and load trained weights into it.

    The backbone is created with a 1000-class head, the head is replaced by
    a single ``num_classes``-way linear layer, and the checkpoint is then
    applied to the resulting model.

    Args:
        weight_path: Path of the checkpoint file read with ``torch.load``.
        num_classes: Output size of the replacement classification head.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or tensor shapes do not match the model.
    """
    # Create the model with the same structure that was used for training.
    model = MedViT_base(pretrained=False, num_classes=1000)

    # Manually replace the last layer so the head has `num_classes` outputs.
    # NOTE(review): presumably the checkpoint was saved from a model with this
    # same replaced head -- verify against the training script.
    in_features = model.proj_head[0].in_features
    model.proj_head = torch.nn.Sequential(
        torch.nn.Linear(in_features, num_classes)
    )

    # Load the trained weights.
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(weight_path, map_location=device)

    # Accept a pickled module or a training checkpoint that nests the weights
    # under a conventional key, in addition to a bare state dict.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    elif isinstance(checkpoint, dict):
        for key in ("state_dict", "model_state_dict", "model"):
            if isinstance(checkpoint.get(key), dict):
                checkpoint = checkpoint[key]
                break

    # Previously the loaded weights were never applied, leaving the network
    # randomly initialised.
    model.load_state_dict(checkpoint)

    model.to(device)
    model.eval()
    return model


In [None]:
class VITAttentionGradRolloutWithClassAttribution:
    """Gradient-weighted attention rollout for a model exposing ``last_attn``.

    A forward hook on the module named ``attention_module_path`` reads that
    module's ``last_attn`` attribute after every forward pass and records the
    gradient of the selected class score with respect to it. ``rollout``
    combines the attention maps and gradients into a 2-D relevance mask.

    Args:
        model: Network to explain; must contain a sub-module whose
            ``named_modules`` name equals ``attention_module_path``.
        attention_module_path: Name of the module to hook.
        discard_ratio: Fraction of the lowest fused attention values that is
            zeroed before the rollout.
        verbose: When true (default), print the module listing and capture
            diagnostics.

    Raises:
        ValueError: If no module is named ``attention_module_path``.

    NOTE(review): assumes ``last_attn`` has shape
    (batch, heads, tokens, tokens) and is still attached to the autograd
    graph when the hook runs -- confirm against the MedViT implementation.
    """

    def __init__(self, model, attention_module_path="features.19", discard_ratio=0.9, verbose=True):
        self.model = model
        self.discard_ratio = discard_ratio
        self.verbose = verbose
        self.attentions = []
        # Parallel to `self.attentions`; an entry stays None until backward
        # delivers the gradient for the matching attention map.
        self.attention_gradients = []
        self.handles = []

        if verbose:
            for name, module in self.model.named_modules():
                if name.startswith("features."):
                    print(f"{name:<20} {type(module)}")

        # Locate the requested module and register the hook on it.
        for name, module in self.model.named_modules():
            if name == attention_module_path:
                if verbose:
                    print(f">> HOOKING: {name} → {type(module)}")
                    print(f">> Has 'last_attn'? {'last_attn' in dir(module)}")

                self._register_custom_hook(module)

        if not self.handles:
            raise ValueError(
                f"No module named {attention_module_path!r} found in the model; nothing to hook."
            )

    def _register_custom_hook(self, module):
        """Attach a forward hook that records ``module.last_attn`` and its gradient."""

        def forward_hook(mod, inputs, output):
            attn = mod.last_attn
            slot = len(self.attentions)
            self.attentions.append(attn.detach())
            self.attention_gradients.append(None)

            # The gradient hook must sit on the tensor that is part of the
            # graph: a detached copy never receives a gradient. Forward passes
            # run under torch.no_grad() produce tensors that cannot take hooks.
            if attn.requires_grad:
                def save_grad(grad, slot=slot):
                    # Store by slot: backward visits layers in reverse order,
                    # so appending would mis-pair gradients and attentions.
                    if slot < len(self.attention_gradients):
                        self.attention_gradients[slot] = grad.detach()

                attn.register_hook(save_grad)

        handle = module.register_forward_hook(forward_hook)
        self.handles.append(handle)

    def rollout(self, input_tensor, target_index=None):
        """Compute the relevance mask for ``input_tensor``.

        Args:
            input_tensor: Batch fed to the model; only the first sample is
                explained. Its last two dimensions give the mask size.
            target_index: Class whose score is back-propagated. Defaults to
                the arg-max of the model output.

        Returns:
            A 2-D ``numpy`` array scaled to the range [0, 1] with the spatial
            size of ``input_tensor``.

        Raises:
            RuntimeError: If no attention map with a matching gradient was
                captured during the forward/backward pass.
        """
        self.attentions = []
        self.attention_gradients = []

        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)
            if isinstance(output, tuple):
                output = output[0]
            pred_class = output.argmax(dim=1).item() if target_index is None else target_index

            self.model.zero_grad()
            one_hot = torch.zeros_like(output)
            one_hot[0, pred_class] = 1
            output.backward(gradient=one_hot)

        result = None

        with torch.no_grad():
            if self.verbose:
                print(">> ATTENTIONS CAPTURED:", len(self.attentions))

            for attention, grad in zip(self.attentions, self.attention_gradients):
                if grad is None:
                    print("Skipping attention map without gradient")
                    continue
                if attention.shape != grad.shape:
                    print(f"Skipping incompatible pair: attention {attention.shape}, grad {grad.shape}")
                    continue

                # Weight attention by its gradient, average over the head
                # dimension and keep only positive evidence.
                attn_heads_fused = (attention * grad).mean(dim=1)
                attn_heads_fused = torch.clamp(attn_heads_fused, min=0)

                # Zero the lowest `discard_ratio` fraction of entries.
                flat = attn_heads_fused.view(attn_heads_fused.size(0), -1)
                _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), dim=-1, largest=False)
                flat.scatter_(1, indices, 0)
                attn_heads_fused = flat.view_as(attn_heads_fused)

                # Add the identity (residual path) and renormalise each row.
                token_dim = attn_heads_fused.size(-1)
                identity = torch.eye(token_dim, device=attn_heads_fused.device).unsqueeze(0)
                a = (attn_heads_fused + identity) / 2
                a = a / a.sum(dim=-1, keepdim=True)

                result = a if result is None else torch.matmul(a, result)

        if result is None:
            raise RuntimeError(
                "No attention map with a gradient was captured; check that the hooked "
                "module stores 'last_attn' without detaching it from the graph."
            )

        # Row 0 of the first sample; drop the class token for typical ViT
        # token counts.
        # NOTE(review): for token grids without a class token, row 0 is the
        # first patch -- confirm this is the intended query row.
        mask = result[0, 0]
        if mask.shape[0] in [65, 197, 577]:
            mask = mask[1:]

        num_tokens = mask.shape[0]
        side = int(num_tokens ** 0.5)
        mask = mask[:side * side].reshape(1, 1, side, side)
        # Upsample to the spatial size of the input instead of a fixed 224x224.
        out_size = tuple(input_tensor.shape[-2:])
        mask = torch.nn.functional.interpolate(mask, size=out_size, mode="bilinear", align_corners=False)
        mask = mask.squeeze().cpu().numpy()

        # Min-max normalisation: the denominator must be the value range.
        low = mask.min()
        high = mask.max()
        mask = (mask - low) / (high - low + 1e-8)
        return mask

    def clear_hooks(self):
        """Remove every registered forward hook and forget the handles."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()


In [None]:
# Utility to overlay heatmap
def overlay_heatmap(gray_img, heatmap):
    """Blend a JET-coloured heatmap over a grayscale image.

    Args:
        gray_img: 2-D image array; it is converted with ``COLOR_GRAY2BGR``.
            NOTE(review): presumably uint8, as produced from a PIL "L"
            image -- confirm for other callers.
        heatmap: 2-D float array, scaled by 255 before colouring, so values
            are expected in [0, 1].

    Returns:
        BGR image of the same height and width as ``gray_img`` with the
        heatmap and the image mixed 50/50.
    """
    # Resize the heatmap to the original image size; cv2 takes (width, height).
    target_size = (gray_img.shape[1], gray_img.shape[0])
    heatmap_resized = cv2.resize(heatmap, target_size, interpolation=cv2.INTER_CUBIC)

    # Bicubic interpolation can overshoot the input range; without clipping,
    # values slightly below 0 or above 1 wrap around in the uint8 cast and
    # show up as wrongly coloured speckles.
    heatmap_resized = np.clip(heatmap_resized, 0.0, 1.0)

    heatmap_uint8 = np.uint8(255 * heatmap_resized)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)

    gray_bgr = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2BGR)
    overlay = cv2.addWeighted(heatmap_color, 0.5, gray_bgr, 0.5, 0)
    return overlay

In [None]:
# Utility to add colored frame
def add_colored_frame(image, color, thickness=20):
    """Return ``image`` surrounded by a constant-colour border.

    The same ``thickness`` (in pixels) is added on all four sides and filled
    with ``color``.
    """
    top = bottom = left = right = thickness
    framed = cv2.copyMakeBorder(
        image,
        top,
        bottom,
        left,
        right,
        cv2.BORDER_CONSTANT,
        value=color,
    )
    return framed

In [None]:
# Apply Grad-Rollout on folder of images
def apply_gradrollout_and_save(model, rollout, input_folder, output_folder, class_names, grid_image_size=(80, 80)):
    """Explain every image in a folder and save overlays, a grid and mean maps.

    For each ``.png``/``.jpg``/``.jpeg`` file in ``input_folder`` the model's
    prediction is computed, ``rollout.rollout`` produces a relevance mask for
    that prediction, and the mask is blended over the grayscale image, framed
    green (correct) or red (incorrect) and written to ``output_folder`` as
    ``heatmap_<name>``. Afterwards a thumbnail grid (incorrect first) and the
    mean heatmap of the correct and incorrect groups are saved.

    Args:
        model: Classifier used for the prediction.
        rollout: Explainer exposing ``rollout(input_tensor, target_index)``
            and ``clear_hooks()``.
        input_folder: Directory containing the images to explain.
        output_folder: Directory that receives the results; created if needed.
        class_names: Currently unused; kept for interface compatibility.
        grid_image_size: Thumbnail size used in the overview grid.

    The explainer's hooks are always released on exit, including on error.
    """
    os.makedirs(output_folder, exist_ok=True)

    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[.5], std=[.5])
    ])

    # Inputs must live on the same device as the model weights; otherwise the
    # forward pass fails as soon as the model is on CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    correct_heatmaps = []
    incorrect_heatmaps = []
    correct_tensors = []
    incorrect_tensors = []
    last_shape = None  # (height, width) of the most recently processed image

    try:
        # Sorted so the grid layout is reproducible across runs and platforms.
        for img_name in sorted(os.listdir(input_folder)):
            if not img_name.lower().endswith((".png", ".jpg", ".jpeg")):
                continue

            img_path = os.path.join(input_folder, img_name)

            # Load as grayscale; convert() returns an independent image, so
            # the file handle can be closed immediately.
            with Image.open(img_path) as handle:
                image = handle.convert("L")

            # Keep a copy of the original grayscale image for the overlay.
            original = np.array(image)
            last_shape = original.shape[:2]

            input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(input_tensor)
                if isinstance(output, tuple):
                    output = output[0]
                pred_class = output.argmax(dim=1).item()

            mask = rollout.rollout(input_tensor, target_index=pred_class)

            # Example ground truth taken from the file name.
            # NOTE(review): files whose name lacks "Cesarean" are labelled 1
            # even if they sit in a Cesarean folder -- confirm naming scheme.
            gt_label = 0 if "Cesarean" in img_name else 1
            is_correct = pred_class == gt_label
            frame_color = (0, 255, 0) if is_correct else (0, 0, 255)
            (correct_heatmaps if is_correct else incorrect_heatmaps).append(mask)

            overlay = overlay_heatmap(original, mask)
            framed = add_colored_frame(overlay, frame_color)
            _save_image(os.path.join(output_folder, f"heatmap_{img_name}"), framed)

            pil_resized = Image.fromarray(cv2.cvtColor(framed, cv2.COLOR_BGR2RGB)).resize(grid_image_size)
            tensor_img = TF.to_tensor(pil_resized)
            (correct_tensors if is_correct else incorrect_tensors).append(tensor_img)

        final_grid = incorrect_tensors + correct_tensors
        if final_grid:
            grid = vutils.make_grid(final_grid, nrow=20, padding=5, normalize=True)
            vutils.save_image(grid, os.path.join(output_folder, "grid_overlay_ordered.jpg"))

        for tag, maps in zip(["correct", "incorrect"], [correct_heatmaps, incorrect_heatmaps]):
            if not maps:
                continue
            mean_map = np.mean(np.stack(maps), axis=0)
            mean_map -= mean_map.min()
            mean_map /= (mean_map.max() + 1e-8)
            mean_uint8 = np.uint8(255 * mean_map)
            mean_colored = cv2.applyColorMap(mean_uint8, cv2.COLORMAP_JET)
            # The mean maps are written at the size of the last image read.
            h, w = last_shape
            _save_image(
                os.path.join(output_folder, f"mean_heatmap_{tag}_blocky.jpg"),
                cv2.resize(mean_colored, (w, h), interpolation=cv2.INTER_NEAREST),
            )
            _save_image(
                os.path.join(output_folder, f"mean_heatmap_{tag}_smooth.jpg"),
                cv2.resize(mean_colored, (w, h), interpolation=cv2.INTER_CUBIC),
            )
    finally:
        rollout.clear_hooks()


def _save_image(path, image):
    """Write ``image`` with ``cv2.imwrite`` and warn when OpenCV reports failure."""
    # cv2.imwrite signals most failures through its return value only.
    if not cv2.imwrite(path, image):
        print(f"Warning: could not write image to {path}")

In [None]:
if __name__ == "__main__":
    # Prefer the GPU whenever PyTorch can see one.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Label order: index 0 is the first entry, index 1 the second.
    class_names = ["Cesarean Birth", "Vaginal Birth"]

    # Trained checkpoint, images to explain, and destination for the heatmaps.
    weight_path = "C:/Users/anale/OneDrive/Documentos/Universidade/TESE/MSc-Thesis/model_paths/abdomen_cv1_medvit_pretrained_best-model.pth"
    input_folder = "C:/Users/anale/OneDrive/Documentos/Universidade/TESE/image-dataset/dataset_images_cv_1/Abdomen_/test/Cesarean Birth"
    output_folder = "C:/Users/anale/OneDrive/Documentos/Universidade/TESE/RESULTS/X_ABDOMEN_cesarean_test_cv1_medvit_pretrained"

    # Build the trained MedViT model with one output per class.
    model = load_medvit_model(
        weight_path,
        num_classes=len(class_names),
        device=device,
    )

    # Wrap the model in the Grad-Rollout explainer (default hook location).
    rollout = VITAttentionGradRolloutWithClassAttribution(model)

    # Generate and store the explanations for every image in the folder.
    apply_gradrollout_and_save(
        model,
        rollout,
        input_folder,
        output_folder,
        class_names,
    )