In [1]:
import argparse
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image

import os
import sys
from tqdm import tqdm

import matplotlib.pyplot as plt
import matplotlib.patches as patches

#==========================================================================
# JUPYTER PATH STUFF. Not present in the main script coz does not affect it
# =========================================================================

os.chdir("..")


root = Path().resolve()
while root != root.parent:
    if (root / "scripts").is_dir():
        sys.path.insert(0, str(root))
        print("Added to sys.path:", root)
        break
    root = root.parent
else:
    raise RuntimeError("Could not find 'scripts' directory above this notebook")

from scripts.src import (
    build_model_and_tokenizer,
    get_image_transform,
    get_label_text_inputs,
)

Added to sys.path: /home/woody/iwi5/iwi5362h/ALBEF


In [2]:
from scripts.albef_crossattn_gradcam import (
    register_albef_crossattn_gradcam_hooks,
    remove_albef_crossattn_gradcam_hooks,
    generate_albef_crossattn_gradcam,
)
from scripts.albef_gradcam import upsample_cam

In [3]:
def infer_png_path(images_root: Path, image_id: str) -> Path:
    png_path = images_root / f"{image_id}.png"
    if not png_path.exists():
        raise FileNotFoundError(f"PNG not found for image_id={image_id}: {png_path}")
    return png_path

In [4]:
images_root = Path("/home/woody/iwi5/iwi5362h/data/vindr_cxr/test")
output_dir = Path("/home/woody/iwi5/iwi5362h/ALBEF/results/zero_shot_vindr_results/heatmaps_grad_crossattn_token_mask_layers_8_to_11")
output_dir.mkdir(parents=True, exist_ok=True)
config_path = "configs/Pretrain.yaml"
ckpt_path = "output_mimic_a40_transformations/checkpoint_29.pth"
device = "cuda"
labels_csv = "/home/woody/iwi5/iwi5362h/data/vindr_cxr/annotations/image_labels_test.csv"

In [5]:
# Load model/tokenizer/config once and reuse
model, tokenizer, config, device = build_model_and_tokenizer(
    config_path=config_path,
    ckpt_path=ckpt_path,
    device=device,
)

image_res = config["image_res"]

transform = get_image_transform(image_res)
model.eval()

# ---- Register Grad-CAM hooks ----
handles = register_albef_crossattn_gradcam_hooks(model)
print("[CrossAttn-GradCAM] Hooks registered.")

[Model] Building ALBEF...
[Model] State dict loaded: <All keys matched successfully>
[CrossAttn-GradCAM] Registered hooks on cross-attention dropout for layers: [6, 7, 8, 9, 10, 11] (total 12 hooks).
[CrossAttn-GradCAM] Hooks registered.


In [6]:
max_images = 100
only_labels = ["Pleural effusion", "Cardiomegaly"]

In [7]:
# ---- Load CSV & determine labels ----
df = pd.read_csv(labels_csv)
id_col = df.columns[0]
all_label_cols = list(df.columns[1:])
print(f"[Data] {len(df)} rows, {len(all_label_cols)} labels in CSV.")

if max_images is not None:
    df = df.iloc[: max_images].reset_index(drop=True)

def has_png(row):
    return (images_root / f"{row[id_col]}.png").exists()

df["__has_png__"] = df.apply(has_png, axis=1)
df = df[df["__has_png__"]].reset_index(drop=True)
print(f"[Data] After PNG filter: {len(df)} images remain.")

image_ids = df[id_col].tolist()
label_cols = all_label_cols

if only_labels is not None:
    missing = [lb for lb in only_labels if lb not in label_cols]
    if missing:
        raise ValueError(f"Requested labels not in CSV: {missing}")
    label_cols = only_labels
    print(f"[Data] Restricting to labels: {label_cols}")

[Data] 3000 rows, 28 labels in CSV.
[Data] After PNG filter: 100 images remain.
[Data] Restricting to labels: ['Pleural effusion', 'Cardiomegaly']


In [8]:
# ---- Precompute text inputs per label ----
input_ids_dict, attn_mask_dict, token_mask_dict = get_label_text_inputs(
    tokenizer=tokenizer,
    labels=label_cols,
    max_length=32,
)

In [9]:
# ---- Process images ----
index_records = []
for idx_img, image_id in enumerate(image_ids, start=1):
    try:
        img_path = infer_png_path(images_root, image_id)
    except FileNotFoundError as e:
        print("[WARN]", e)
        continue

    img_pil = Image.open(img_path).convert("RGB")
    img_tensor = transform(img_pil).unsqueeze(0)  # (1,3,H,W)

    heatmaps = {}
    for label in label_cols:
        input_ids = input_ids_dict[label]      # (1,T)
        attn_mask = attn_mask_dict[label]      # (1,T)
        text_token_mask = token_mask_dict[label]

        cam_patch = generate_albef_crossattn_gradcam(
            model=model,
            img_tensor=img_tensor,
            input_ids=input_ids,
            attention_mask=attn_mask,
            device=device,
            text_token_mask=text_token_mask,
            layers_to_use=[8, 9, 10, 11]  # trying last 4 instead of all 6
        )

        cam_up = upsample_cam(cam_patch, target_size=image_res)
        heatmaps[label] = cam_up

    out_path = output_dir / f"{image_id}.pt"
    heatmaps_cpu = {k: v.float().cpu() for k, v in heatmaps.items()}
    torch.save(heatmaps_cpu, out_path)
    index_records.append({"image_id": image_id, "heatmap_path": str(out_path)})

    if idx_img % 20 == 0 or idx_img == len(image_ids):
        print(f"[CrossAttn-GradCAM] Processed {idx_img}/{len(image_ids)} images")

index_df = pd.DataFrame(index_records)
index_path = output_dir / "crossattn_gradcam_index.csv"
index_df.to_csv(index_path, index=False)
print(f"[Output] Saved cross-attn Grad-CAM index to: {index_path}")

remove_albef_crossattn_gradcam_hooks(handles)
print("[CrossAttn-GradCAM] Hooks removed.")

[CrossAttn-GradCAM] Processed 20/100 images
[CrossAttn-GradCAM] Processed 40/100 images
[CrossAttn-GradCAM] Processed 60/100 images
[CrossAttn-GradCAM] Processed 80/100 images
[CrossAttn-GradCAM] Processed 100/100 images
[Output] Saved cross-attn Grad-CAM index to: /home/woody/iwi5/iwi5362h/ALBEF/results/zero_shot_vindr_results/heatmaps_grad_crossattn_token_mask_layers_8_to_11/crossattn_gradcam_index.csv
[CrossAttn-GradCAM] Hooks removed and buffers cleared.
[CrossAttn-GradCAM] Hooks removed.
