In [2]:
import cv2
import numpy as np
import torch
from detectron2.utils.visualizer import Visualizer
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import Boxes, BoxMode

from pathlib import Path

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []

        # register hooks
        self.hook_handles.append(target_layer.register_forward_hook(self._save_activation))
        self.hook_handles.append(target_layer.register_backward_hook(self._save_gradient))

    def _save_activation(self, module, input, output):
        if output.ndim == 3:
            output = output.unsqueeze(0)
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        grad = grad_output[0]
        if grad.ndim == 3:
            grad = grad.unsqueeze(0)
        self.gradients = grad.detach()

    def remove_hooks(self):
        for h in self.hook_handles:
            h.remove()

    def __call__(self, input_tensor, class_idx=None):
        # Forward
        output = self.model(input_tensor)
        
        # If it's a detection model, you'll need to pick a score or logit
        if class_idx is None:
            score = output[0]["instances"].scores[0]  # first detection
        else:
            # print("class_idx: ", class_idx)
            # print(output[0])
            score = output[0]["instances"].scores[int(class_idx)]

        # Backward
        self.model.zero_grad()
        score.backward(retain_graph=True)

        # Global-average-pool gradients
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)  # [C,1,1]

        # Weighted sum of activations
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)
        
        # Normalize to [0,1]
        cam = cam.squeeze().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam


# ---- Visualizer helper ----
def draw_predictions(img, outputs, metadata, gt_boxes, gt_classes):
    v = Visualizer(img[:, :, ::-1], metadata=metadata, scale=1.0)
    inst = outputs["instances"].to("cpu")
    # v = v.draw_instance_predictions(inst)

    # if gt_boxes is not None:
        # Draw GT boxes in red
    for box, cls in zip(gt_boxes, gt_classes):
        x1, y1, x2, y2 = map(int, box)
        cls_name = metadata.thing_classes[cls] if len(metadata.thing_classes) > 0 else str(cls)
        text = f"GT: {cls_name}"

        v.draw_box([x1, y1, x2, y2], edge_color=(0.0,1.0,0.0))
        v.draw_text(text, (x1, y1), color=(0.0,1.0,0.0))

    boxes = inst.pred_boxes.tensor.numpy()
    scores = inst.scores.tolist()
    classes = inst.pred_classes.tolist()

    for box, score, cls_id in zip(boxes, scores, classes):
        x0, y0, x1, y1 = box
        cls_name = metadata.thing_classes[cls_id] if len(metadata.thing_classes) > 0 else str(cls_id)
        text = f"Pred: {cls_name} {score:.2f}"

        v.draw_box([x0, y0, x1, y1], edge_color=(1.0, 0.0, 0.0))
        v.draw_text(text, (x0, y0), color=(1.0, 0.0, 0.0))

    vis_out = v.output
    drawn = vis_out.get_image()[:, :, ::-1]

    return drawn


def get_gt_from_dict(entry):
    gt_boxes, gt_classes = [], []
    for ann in entry["annotations"]:
        bbox = ann["bbox"]
        if ann["bbox_mode"] != BoxMode.XYXY_ABS:
            bbox = BoxMode.convert(bbox, ann["bbox_mode"], BoxMode.XYXY_ABS)
        gt_boxes.append(bbox)
        gt_classes.append(ann["category_id"])
    return gt_boxes, gt_classes  


# ---- Main routine for one image ----
def visualize_cam_and_bboxes(entry, model, gradcam, metadata, out_dir):
    img_path = entry["file_name"]
    h, w = entry["height"], entry["width"]
    img = cv2.imread(img_path)

    # Detectron2 input
    inputs = [{"image": torch.as_tensor(img.astype("float32").transpose(2, 0, 1)).cuda(),
               "height": h, "width": w}]
    with torch.no_grad():
        outputs = model(inputs)

    # Grad-CAM on top detection (if any)
    if len(outputs[0]["instances"]) > 0:
        score = outputs[0]["instances"].scores[0]
        cam_map = gradcam(inputs, score)
        cam_resized = cv2.resize(cam_map, (w, h))
        heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
        overlay = (0.5 * heatmap + 0.5 * img).astype(np.uint8)
    else:
        overlay = img.copy()

    # GT boxes
    gt_boxes, gt_classes = get_gt_from_dict(entry)

    # print("GT boxes:", gt_boxes[:3])
    # print("GT classes:", gt_classes[:3])

    # Pred + GT
    bbox_vis = draw_predictions(img.copy(), outputs[0], metadata, gt_boxes, gt_classes)

    # Stack horizontally
    stacked = np.hstack([overlay, bbox_vis])

    # Save
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    out_path = Path(out_dir) / Path(img_path).name
    cv2.imwrite(str(out_path), stacked)
    print("Saved:", out_path)


In [3]:
import numpy as np
from sklearn import preprocessing

import _init_paths
from config import cfg, update_config

def update_cfg_with_args(cfg, arg_key, arg_value):
    cfg.defrost()

    arg_key = arg_key.upper()

    cfg.arg_key = arg_value

    cfg.freeze()


def minmax_norm(img):
    minmax_scaler = preprocessing.MinMaxScaler()

    img = np.array([
        minmax_scaler.fit_transform(img[:, :, 0]),
        minmax_scaler.fit_transform(img[:, :, 1]),
        minmax_scaler.fit_transform(img[:, :, 2]),
    ])

    return np.transpose(img, (1, 2, 0)).astype(np.float32)


In [None]:
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.data import DatasetCatalog

from dataset.utils import register_patch_bin_dataset


cfg_path = "/workspace/project/configs/frcnn/frcnn.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(
        "COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"
))

cfg.set_new_allowed(True)
cfg.defrost()
cfg.merge_from_file(cfg_path)

cfg.MODEL.WEIGHTS = "/workspace/project/record/debug2/result_single/frcnn_vis/model_final.pth"
cfg.MODEL.DEVICE = "cuda"

cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.8   # e.g., 0.5, Removes duplicate boxes that overlap too much
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.9 # detection score threshold, Removes weak predictions

cfg.freeze()

predictor = DefaultPredictor(cfg)
model = predictor.model
model.eval()

target_layer = model.backbone.bottom_up.res5[-1]
gradcam = GradCAM(model, target_layer)

register_patch_bin_dataset(
        cfg.DATASETS.TEST[0],
        json_file=cfg.DATASETS.TEST_ANNO_DIR,
        img_root=cfg.DATASETS.IMG_DIR,
        extra_key=["patient_id"]
)

ds_dicts = DatasetCatalog.get(cfg.DATASETS.TEST[0])
metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

out_dir = "./gradcam_results/gradcam_results_80nms_90score"

for i, entry in enumerate(ds_dicts):
        visualize_cam_and_bboxes(entry, model, gradcam, metadata, out_dir)

      

Config '/workspace/project/configs/frcnn/frcnn.yaml' has no VERSION. Assuming it to be compatible with latest v2.
  return torch.load(f, map_location=torch.device("cpu"))

Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Saved: gradcam_results_80nms_90score/raw1_09-D2(30m)-4.JPG
Saved: gradcam_results_80nms_90score/raw1_09-D2(30m)-1.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw1_09-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw1_09-D2(30m)-3.JPG
Saved: gradcam_results_80nms_90score/raw1_09-D3(24h)-3.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw1_09-D3(24h)-4.JPG
Saved: gradcam_results_80nms_90score/raw1_28-D2(30m)-4.JPG
Saved: gradcam_results_80nms_90score/raw1_28-D2(30m)-3.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw1_33-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw1_33-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw1_33-D2(30m)-3.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw1_33-D3(24h)-2.JPG
Saved: gradcam_results_80nms_90score/raw1_33-D2(30m)-4.JPG
Saved: gradcam_results_80nms_90score/raw1_33-D3(24h)-4.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_02-D2(30m)-4.JPG
Saved: gradcam_results_80nms_90score/raw2_02-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw2_02-D2(30m)-2.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_02-D2(30m)-3.JPG
Saved: gradcam_results_80nms_90score/raw2_06-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw2_06-D2(30m)-4.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_06-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw2_06-D2(30m)-3.JPG
Saved: gradcam_results_80nms_90score/raw2_09-D2(30m)-4.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_09-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw2_09-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw2_09-D2(30m)-3.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_15-D2(30m)-3.JPG
Saved: gradcam_results_80nms_90score/raw2_15-(D3)24h-4.JPG
Saved: gradcam_results_80nms_90score/raw2_15-D2(30m)-1.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_15-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw2_15-D2(30m)-4.JPG
Saved: gradcam_results_80nms_90score/raw2_17-(D3)24h-3.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_17-D2(30m)-4.JPG
Saved: gradcam_results_80nms_90score/raw2_17-(D3)24h-1.JPG
Saved: gradcam_results_80nms_90score/raw2_17-(D3)24h-4.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_17-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw2_17-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw2_17-D2(30m)-3.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_20-D2(30m)-4.JPG
Saved: gradcam_results_80nms_90score/raw2_20-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw2_20-D2(30m)-2.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_20-D2(30m)-3.JPG
Saved: gradcam_results_80nms_90score/raw2_26-D2(30m)-3.JPG
Saved: gradcam_results_80nms_90score/raw2_26-D2(30m)-4.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_26-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw2_26-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw2_27-D2(30m)-3.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_27-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw2_27-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw2_27-D2(30m)-4.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_28-D2(30m)-4.JPG
Saved: gradcam_results_80nms_90score/raw2_28-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw2_28-D2(30m)-2.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_28-D2(30m)-3.JPG
Saved: gradcam_results_80nms_90score/raw2_29-(D3)24h-4.JPG
Saved: gradcam_results_80nms_90score/raw2_29-D2(30m)-1.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_29-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw2_29-D2(30m)-3.JPG
Saved: gradcam_results_80nms_90score/raw2_29-D2(30m)-4.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw1_02-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw1_02-D2(30m)-3.JPG
Saved: gradcam_results_80nms_90score/raw1_02-D2(30m)-4.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw1_02-D3(24h)-1.JPG
Saved: gradcam_results_80nms_90score/raw1_02-D3(24h)-4.JPG
Saved: gradcam_results_80nms_90score/raw2_33-(D3)24h-4.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_33-D2(30m)-1.JPG
Saved: gradcam_results_80nms_90score/raw2_33-D2(30m)-2.JPG
Saved: gradcam_results_80nms_90score/raw2_33-D2(30m)-3.JPG


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: gradcam_results_80nms_90score/raw2_33-D2(30m)-4.JPG


In [None]:



img_path = "/workspace/datasets/seg_by_patient/preprocessed/pos_cropped_patch_all_r1_r2/raw1_01-D2(30m)-1.JPG"
# Prepare image
img = cv2.imread(img_path)[:, :, ::-1]  # BGR->RGB

img = minmax_norm(img)

# inputs = predictor.transform_gen.get_transform(img).apply_image(img)
tensor = torch.as_tensor(img.transpose(2,0,1)).cuda().float()

input = [{
    "image":tensor,
    "height": 1200,
    "width": 600
}]

# print(tensor.shape)

# Generate CAM
cam_map = gradcam(input)

# Overlay heatmap
heatmap = cv2.applyColorMap(np.uint8(255 * cam_map), cv2.COLORMAP_JET)

H, W, _ = img.shape

r_hm = cv2.resize(heatmap, (W, H))

r_hm = (r_hm-r_hm.min()) / (r_hm.max()-r_hm.min()+1e-8)

heatmap = cv2.applyColorMap(np.uint8(255 * r_hm), cv2.COLORMAP_JET)


overlay = 0.5 * heatmap[:, :, ::-1] + 0.5 * img
cv2.imwrite("./gradcam_custom.jpg", overlay[:, :, ::-1])

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


True

In [47]:
heatmap.shape

(38, 19, 3)