In [1]:
import cv2
import matplotlib.pyplot as plt
import torch

from importlib import import_module
from pytorch_grad_cam import AblationCAM, GradCAM, XGradCAM, EigenCAM, EigenGradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, RawScoresOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from skp.toolbox.functions import load_model_from_config, overlay_images
from torch import nn

In [2]:
class ModelForGradCAM(nn.Module):
    """Adapter that exposes the dict-based model as a plain tensor -> tensor module.

    The wrapped model is called with ``{"x": x}`` and returns a dict; pytorch_grad_cam
    instead calls its model with a bare input tensor and works on the returned scores.
    This wrapper bridges the two and returns only the ``"logits1"`` entry.

    Args:
        model: the underlying model; it is called as ``model({"x": x})`` and must
            return a mapping containing the key ``"logits1"``.
    """

    def __init__(self, model):
        super().__init__()
        # Stored as an attribute so its layers stay reachable via ``self.model``
        # (e.g. when choosing Grad-CAM target layers).
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its ``"logits1"`` output."""
        outputs = self.model({"x": x})
        logits = outputs["logits1"]
        return logits

In [None]:
# Experiment to inspect. The config module name is also the run directory name under
# cfg.save_dir, so it is defined once and reused for both the import and the checkpoint path.
CFG_NAME = "boneage.cfg_female_channel_reg_cls_match_hist"
RUN_ID = "fa77ff59"  # run directory of the trained checkpoint
FOLD = 0

cfg = import_module(f"skp.configs.{CFG_NAME}").cfg
# Override backbone settings before the model is built. Presumably these must match the
# architecture the checkpoint was trained with — TODO confirm against the run's saved config.
cfg.backbone = "convnextv2_tiny"
cfg.backbone_img_size = False

model = load_model_from_config(
    cfg,
    weights_path=cfg.save_dir + f"{CFG_NAME}/{RUN_ID}/fold{FOLD}/checkpoints/last.ckpt",
    device="cpu",
    eval_mode=True,
)
# Tensor-in/tensor-out view of the same model object, as required by pytorch_grad_cam.
model_gradcam = ModelForGradCAM(model)
dataset = import_module(f"skp.datasets.{cfg.dataset}").Dataset(cfg, mode="val")

In [None]:
# Peek at the final block of the last backbone stage; the Grad-CAM target layer is taken from this stage.
final_stage = model.backbone.stages[-1]
final_stage.blocks[-1]

In [27]:
SAMPLE_IDX = 6  # arbitrary validation example to explain
batch = dataset[SAMPLE_IDX]
x = batch["x"].unsqueeze(0)  # add a leading batch dimension

with torch.inference_mode():
    out = model({"x": x}, return_loss=False)
    # Interpret logits1 as a distribution over integer classes and take its expected value.
    # The class count is read from the output instead of being hard-coded.
    # (Alternative read-out, previously commented out: out["logits0"][0].item() —
    # presumably a direct regression output; TODO confirm.)
    probs = out["logits1"][0].softmax(dim=0)
    class_values = torch.arange(probs.shape[0], dtype=probs.dtype)
    predicted_bone_age = (probs * class_values).sum().item()

# Nearest class index; used as the Grad-CAM target class below.
rounded_bone_age = round(predicted_bone_age)

target_layers = [model_gradcam.model.backbone.stages[-1]]
targets = [ClassifierOutputTarget(rounded_bone_age)]

with GradCAM(model=model_gradcam, target_layers=target_layers) as cam_extractor:
    grayscale_cam = cam_extractor(input_tensor=x, targets=targets, eigen_smooth=True)

# Last expression so the notebook actually displays the prediction.
predicted_bone_age, rounded_bone_age

In [None]:
# Colorize the Grad-CAM map (expected in [0, 1]); cv2 returns BGR, so convert to RGB for matplotlib.
heatmap_bgr = cv2.applyColorMap((grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET)
heatmap = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

# First channel of the input as an RGB background. Assumes pixel values are already on a
# 0-255 scale (TODO confirm for this dataset); clamp so out-of-range values saturate
# instead of wrapping when cast to uint8.
image = cv2.cvtColor(batch["x"][0].clamp(0, 255).numpy().astype("uint8"), cv2.COLOR_GRAY2RGB)

image_weight = 0.6  # blend weight of the input image; the heatmap gets the remainder
overlay = ((1 - image_weight) * heatmap + image_weight * image).astype("uint8")

fig, ax = plt.subplots()
ax.imshow(overlay)
ax.set_title(f"Grad-CAM for class {rounded_bone_age} (expected value {predicted_bone_age:.1f})")
ax.axis("off");