In [None]:
#@title Setup project root

import os, sys

# Repository root. The default is the original author's checkout; set the
# MV_MAMMO_PROJECT_ROOT environment variable to run the notebook from a
# different machine or location without editing this cell.
PROJECT_ROOT = os.environ.get(
    "MV_MAMMO_PROJECT_ROOT",
    "/Users/yunfanbao/Documents/work/mv-mammo-transformer",
)

# Work from the project root so project-relative paths and imports resolve.
os.chdir(PROJECT_ROOT)

# Guarded so that re-running this cell does not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

print("Current directory:", os.getcwd())

In [None]:
#@title CKPT Path

from config import CHECKPOINT_DIR

# Which experiment folder and which saved checkpoint file to evaluate.
experiment_name = "exp_mammo_v1"
ckpt_name = "mv_transformer_ep10-30_acc0.9326.pt"

# CHECKPOINT_DIR/<experiment_name>/<ckpt_name>. Despite the `_dir` suffix,
# this variable holds the path of the checkpoint file itself.
checkpoint_dir = os.path.join(
    os.path.join(CHECKPOINT_DIR, experiment_name),
    ckpt_name,
)

In [None]:
#@title study_id for Grad-CAM

# Study identifier handed to evaluate() as `gradcam_study_id` in the
# evaluation cell below.
gradcam_study_id = "2f4d26ae21e1fb85ec2d97f9464aadff"

In [None]:
#@title Evaluate Checkpoint

from scripts.evaluate import evaluate

# Evaluate the selected checkpoint. Later cells read res["metrics"] and
# res["gradcam"] (with "key", "cams" and "images") from the returned mapping.
res = evaluate(checkpoint_path=checkpoint_dir, gradcam_study_id=gradcam_study_id)

In [None]:
#@title Extract metrics and GradCAM

# Evaluation metrics, consumed by the metric display cells.
metrics = res["metrics"]

# Grad-CAM payload: the sample key plus its CAM maps and input images,
# pulled out of res["gradcam"] in that order.
key, cams, images = (
    res["gradcam"][field] for field in ("key", "cams", "images")
)

In [None]:
#@title Show evaluation metrics

from pprint import pprint

# Pretty-print the complete metrics structure taken from res["metrics"].
pprint(metrics)

In [None]:
#@title Quick metric summary

# Compact per-task summary: one header line per task, then one aligned line
# per metric. Values that are not `float` instances are skipped.
for task_name, task_metrics in metrics.items():
    print(f"[{task_name}]")
    for metric_name, metric_value in task_metrics.items():
        if not isinstance(metric_value, float):
            continue
        print(f"  {metric_name:20s}: {metric_value:.4f}")

In [None]:
#@title Load bounding boxes

# NOTE(review): this imports from `src.data`, while the key classes used in
# the plotting cells come from `src.dataio` — confirm the module path.
from src.data.bbox import BBox

# Bounding-box store; show_cam_with_bbox() queries it through
# bbox_db.by_index.get(study_id, laterality, view).
bbox_db = BBox()

In [None]:
#@title Visualization helper

import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def show_cam_with_bbox(ax, img, cam, image_key):
    """Draw one view with its Grad-CAM heatmap and bounding boxes on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        img: Image tensor (or array) shaped (H, W), (1, H, W) or (C, H, W).
            Multi-channel images are averaged over the channel axis so they
            can be shown as a single grayscale plane.
        cam: Grad-CAM map shaped (h, w); leading singleton dimensions are
            tolerated. It is resized bilinearly to the image size and
            rescaled to [0, 1] before being overlaid.
        image_key: Object exposing ``study_id``, ``laterality`` and ``view``;
            used for the bounding-box lookup and for the axes title.

    Raises:
        ValueError: If ``img`` cannot be reduced to a single 2-D plane.
    """
    # Local import: the notebook only imports torch.nn.functional at cell level.
    import torch

    # detach() is required because tensors that still track gradients (common
    # around Grad-CAM) cannot be converted to NumPy directly.
    img_t = torch.as_tensor(img).detach().squeeze().cpu()
    if img_t.dim() == 3:
        # (C, H, W) with C > 1: collapse the channels for grayscale display.
        img_t = img_t.float().mean(dim=0)
    if img_t.dim() != 2:
        raise ValueError(
            f"img must reduce to a 2-D plane, got shape {tuple(img_t.shape)}"
        )
    img_np = img_t.numpy()
    H, W = img_np.shape

    # Resize CAM to image size (bilinear interpolation needs a float
    # (1, 1, h, w) tensor).
    cam_t = torch.as_tensor(cam).detach().float()
    cam_t = cam_t.reshape(1, 1, *cam_t.shape[-2:])
    cam_up = F.interpolate(
        cam_t,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )
    # Index instead of squeeze() so a genuine size-1 H or W is not dropped.
    cam_np = cam_up[0, 0].cpu().numpy()  # (H, W)

    # Re-normalize after resize; the epsilon avoids division by zero for a
    # constant map.
    cam_np = cam_np - cam_np.min()
    cam_np = cam_np / (cam_np.max() + 1e-6)

    # Plot the image first, then the semi-transparent heatmap on top.
    ax.imshow(img_np, cmap="gray")
    ax.imshow(cam_np, cmap="jet", alpha=0.45)

    # Draw all bboxes registered for this view; a lookup that yields None is
    # treated as "no boxes" instead of crashing the plot.
    boxes = bbox_db.by_index.get(
        image_key.study_id,
        image_key.laterality,
        image_key.view,
    )
    if boxes is None:
        boxes = ()
    for (x1, y1, x2, y2) in boxes:
        rect = patches.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=2,
            edgecolor="lime",
            facecolor="none",
        )
        ax.add_patch(rect)

    ax.set_title(f"{image_key.laterality}-{image_key.view}")
    ax.axis("off")


In [None]:
#@title Show single-view GradCAM

from src.dataio.keys import ImageKey

if isinstance(key, ImageKey):
    # Single-view sample: one axes, using the first entry of `images`/`cams`.
    fig, ax = plt.subplots(figsize=(4, 4))

    cam, img = cams[0], images[0]
    show_cam_with_bbox(ax, img, cam, key)

    # `fig` is the current figure, so this matches plt.tight_layout().
    fig.tight_layout()
    plt.show()

In [None]:
#@title Show multi-view GradCAM

from src.dataio.keys import MultiViewKey

if isinstance(key, MultiViewKey):
    view_keys = list(key.views())
    n_views = len(view_keys)

    # Size the grid from the actual number of views instead of assuming four:
    # two columns and as many rows as needed (4 views -> 2x2 at 8x8 inches,
    # same as before).
    n_cols = 2
    n_rows = max(1, -(-n_views // n_cols))  # ceiling division

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 4 * n_rows),
        squeeze=False,  # always a 2-D array, even for a single row
    )
    axes = axes.flatten()

    for i, image_key in enumerate(view_keys):
        # images[i] and cams[i] are matched to view_keys[i] by position.
        img = images[i]
        cam = cams[i]

        show_cam_with_bbox(axes[i], img, cam, image_key)

    # Hide any grid cells left over when the view count is odd or zero.
    for unused_ax in axes[n_views:]:
        unused_ax.axis("off")

    # The original title contained a mis-encoded en dash ("â€“").
    plt.suptitle(f"Study {key.study_id} – GradCAM (predicted class)", y=0.95)
    plt.tight_layout()
    plt.show()