In [1]:
# =========================================================================
# JUPYTER PATH SETUP. Notebook-only; not present in the main script because
# the script does not need it.
# =========================================================================

import os
from pathlib import Path
import sys

# Remember the directory the kernel started in, so re-running this cell always moves to the
# same parent instead of walking one level further up each time (os.chdir is process-wide and
# persists across cells).
_NOTEBOOK_START_DIR = globals().get("_NOTEBOOK_START_DIR", Path.cwd())
os.chdir(_NOTEBOOK_START_DIR.parent)

# Walk up from the new working directory (itself included, up to and including the filesystem
# root) until a directory containing `scripts/` is found, and make that directory importable.
search_start = Path().resolve()
for root in (search_start, *search_start.parents):
    if (root / "scripts").is_dir():
        if str(root) not in sys.path:  # avoid duplicate sys.path entries on re-run
            sys.path.insert(0, str(root))
            print("Added to sys.path:", root)
        else:
            print("Already on sys.path:", root)
        break
else:
    raise RuntimeError("Could not find 'scripts' directory above this notebook")

Added to sys.path: /home/woody/iwi5/iwi5362h/ALBEF


In [2]:
import numpy as np
import pandas as pd
import torch

In [3]:
def load_meta(meta_csv):
    """
    Load the VinDr meta CSV into an image-size lookup.

    meta_csv: anything `pd.read_csv` accepts (path or file-like); must contain the columns
        image_id, dim0, dim1.
    Returns dict: image_id -> (orig_height, orig_width) as Python ints, taking dim0 as the
        height and dim1 as the width. If an image_id occurs more than once, the last row wins.
    Raises KeyError if one of the required columns is missing.
    """
    df = pd.read_csv(meta_csv)
    # Iterate the three columns directly rather than via DataFrame.iterrows(), which builds a
    # full Series per row just to read three fields.
    return {
        image_id: (int(height), int(width))
        for image_id, height, width in zip(df["image_id"], df["dim0"], df["dim1"])
    }


def _scale_box_span(lo, hi, scale, target_size):
    """Map one axis' [lo, hi] original-pixel interval to an integer slice [lo_i, hi_i) in target space."""
    # floor the start / ceil the end so the slice covers every partially overlapped pixel
    lo_i = int(np.floor(lo * scale))
    hi_i = int(np.ceil(hi * scale))

    # the start must be a valid index; the end is exclusive, so it may equal target_size
    lo_i = max(0, min(target_size - 1, lo_i))
    hi_i = max(0, min(target_size, hi_i))

    # guarantee a non-empty slice (at least one pixel)
    if hi_i <= lo_i:
        hi_i = min(target_size, lo_i + 1)
    return lo_i, hi_i


def scale_box_to_256(r, orig_h, orig_w, target_size=256):
    """
    Scale a GT box from original resolution to target_size x target_size space.

    Each axis is scaled independently (target_size / orig_w, target_size / orig_h).
    NOTE(review): this assumes the heatmap covers the whole image resized without preserving
    aspect ratio — TODO confirm against the heatmap generator.

    r: mapping / DataFrame row with x_min, y_min, x_max, y_max in original px.
    orig_h, orig_w: original image height and width in px; both must be positive.
    target_size: side length of the target square (default 256).

    Returns (x_min_256, y_min_256, x_max_256, y_max_256) as Python ints, usable directly as
    slice bounds heatmap[y_min:y_max, x_min:x_max]:
      - x_min / y_min are clamped to [0, target_size - 1];
      - x_max / y_max are exclusive and clamped to [0, target_size];
      - every returned box spans at least one pixel per axis.

    Raises ValueError if orig_h or orig_w is not positive.
    """
    if orig_h <= 0 or orig_w <= 0:
        raise ValueError(f"Original image size must be positive, got h={orig_h}, w={orig_w}")

    scale_x = float(target_size) / float(orig_w)
    scale_y = float(target_size) / float(orig_h)

    x_min_i, x_max_i = _scale_box_span(r["x_min"], r["x_max"], scale_x, target_size)
    y_min_i, y_max_i = _scale_box_span(r["y_min"], r["y_max"], scale_y, target_size)

    return x_min_i, y_min_i, x_max_i, y_max_i


def compute_box_stats_for_heatmap(
    heatmap: np.ndarray,
    box_coords,
    percentiles,
):
    """
    Compare one single-label heatmap against one GT box.

    heatmap: 2-D array (H, W); expected to be normalised to [0, 1] by the caller.
    box_coords: (x_min, y_min, x_max, y_max) integer slice bounds in heatmap pixels; the box is
        heatmap[y_min:y_max, x_min:x_max] (x_max / y_max exclusive).
    percentiles: iterable of percentiles to evaluate (e.g. [90, 95, 99]).

    Returns a dict with:
      - mean_in_box, max_in_box: mean / max heatmap value inside the box (NaN if the box is empty)
      - mean_global: mean heatmap value over the whole map
      - for each p in percentiles, with mask = pixels >= the p-th percentile of the heatmap:
          top{p}_coverage: fraction of mask pixels that lie inside the box
          top{p}_recall:   fraction of box pixels that are in the mask (0.0 for an empty box)

    Raises ValueError if heatmap is not 2-D.

    NOTE(review): because the mask uses `>=`, every pixel tied at the threshold is included;
    for a constant heatmap (e.g. an all-zero map) the mask is the whole image and recall is 1.0.
    """
    if heatmap.ndim != 2:
        raise ValueError(f"heatmap must be 2-D, got shape {heatmap.shape}")
    x_min, y_min, x_max, y_max = box_coords
    percentiles = list(percentiles)

    # region inside box
    box_region = heatmap[y_min:y_max, x_min:x_max]
    box_area = box_region.size  # number of pixels in box
    # .max() raises on an empty slice, so report NaN for an empty box instead of crashing.
    if box_area > 0:
        mean_in_box = float(box_region.mean())
        max_in_box = float(box_region.max())
    else:
        mean_in_box = max_in_box = float("nan")

    results = {
        "mean_in_box": mean_in_box,
        "max_in_box": max_in_box,
        "mean_global": float(heatmap.mean()),
    }

    # np.percentile flattens the map and evaluates all requested percentiles in one call.
    thresholds = np.percentile(heatmap, percentiles) if percentiles else []
    for p, thr in zip(percentiles, thresholds):
        mask = heatmap >= thr  # True on the top-p% pixels
        mask_area = int(np.count_nonzero(mask))
        # intersection: pixels in mask AND in box (exact integer counts, no float32 rounding)
        inter = int(np.count_nonzero(mask[y_min:y_max, x_min:x_max]))

        # how much of the top-p% area lies inside the box
        results[f"top{p}_coverage"] = inter / mask_area if mask_area > 0 else 0.0
        # how much of the box is "hot"
        results[f"top{p}_recall"] = inter / box_area if box_area > 0 else 0.0

    return results

In [4]:
# Base directories; every input location below is derived from one of them.
vindr_root = Path("/home/woody/iwi5/iwi5362h/data/vindr_cxr")
zero_shot_results_root = Path("/home/woody/iwi5/iwi5362h/ALBEF/results/zero_shot_vindr_results")

annotations_csv = vindr_root / "annotations" / "annotations_test.csv"
meta_csv = vindr_root / "test_meta.csv"
heatmaps_root = zero_shot_results_root / "heatmaps_grad_crossattn"

In [5]:
df_ann = pd.read_csv(annotations_csv)
meta = load_meta(meta_csv)

# Fail fast with a clear message instead of a KeyError deep inside the evaluation loop,
# which reads exactly these annotation columns.
required_ann_cols = {"image_id", "class_name", "x_min", "y_min", "x_max", "y_max"}
missing_ann_cols = required_ann_cols - set(df_ann.columns)
if missing_ann_cols:
    raise ValueError(f"{annotations_csv} is missing columns: {sorted(missing_ann_cols)}")

In [6]:
# Percentiles p to evaluate: a pixel counts as "hot" when it is >= the p-th percentile of its heatmap.
PERCENTILES_TO_EVALUATE = (90, 95, 99)
percentiles = list(PERCENTILES_TO_EVALUATE)

In [7]:
# ------------------ Loop over GT boxes ------------------
HEATMAP_SIZE = 256  # every saved heatmap must be HEATMAP_SIZE x HEATMAP_SIZE
BOX_COLS = ["x_min", "y_min", "x_max", "y_max"]


def _load_image_heatmaps(image_id, heatmaps_dir, cache, missing_images):
    """Return the label -> heatmap mapping saved for `image_id`, or None if its .pt file is missing.

    `cache` keeps at most one entry (the most recently loaded image), so consecutive GT boxes of
    the same image share one torch.load instead of re-reading the file for every box, while
    memory stays bounded to a single file.
    """
    if image_id in cache:
        return cache[image_id]
    hm_path = heatmaps_dir / f"{image_id}.pt"
    if not hm_path.exists():
        if image_id not in missing_images:  # warn once per image, not once per box
            print(f"[WARN] Missing heatmap file for {image_id}: {hm_path}")
            missing_images.add(image_id)
        return None
    # NOTE(review): torch.load unpickles the file; only point heatmaps_root at trusted output.
    cache.clear()
    cache[image_id] = torch.load(hm_path, map_location="cpu")
    return cache[image_id]


def _normalized_heatmap(hm_tensor):
    """Min-max normalise a heatmap tensor to [0, 1] (all zeros if constant) and return it as numpy."""
    hm_tensor = hm_tensor.float()
    hm_min, hm_max = hm_tensor.min(), hm_tensor.max()
    if hm_max > hm_min:
        hm_tensor = (hm_tensor - hm_min) / (hm_max - hm_min)
    else:
        hm_tensor = torch.zeros_like(hm_tensor)
    return hm_tensor.detach().cpu().numpy()


def _evaluate_gt_box(r, meta, percentiles, heatmaps_dir, cache, missing_images):
    """Score one GT annotation row against its label's heatmap; return a record dict or None if skipped."""
    image_id = r["image_id"]
    label = r["class_name"]

    if image_id not in meta:
        print(f"[WARN] image_id={image_id} not found in meta CSV; skipping.")
        return None

    # Skip boxes with invalid / NaN coordinates (for No Findings)
    if any(pd.isna(r[col]) for col in BOX_COLS):
        print(f"[WARN] NaN box coordinates for image_id={image_id}, label={label} → skipping GT box.")
        return None

    orig_h, orig_w = meta[image_id]

    heatmaps = _load_image_heatmaps(image_id, heatmaps_dir, cache, missing_images)
    if heatmaps is None:
        return None
    if label not in heatmaps:
        print(f"[WARN] No heatmap for label '{label}' in {image_id}.pt; skipping this GT box.")
        return None

    # Normalization (defensive; should already be ~[0,1])
    heatmap = _normalized_heatmap(heatmaps[label])
    H, W = heatmap.shape
    if (H, W) != (HEATMAP_SIZE, HEATMAP_SIZE):
        raise ValueError(
            f"Expected heatmap {HEATMAP_SIZE}x{HEATMAP_SIZE}, got {H}x{W} for {image_id}, label={label}"
        )

    # Scale GT box to heatmap space
    box_256 = scale_box_to_256(r, orig_h=orig_h, orig_w=orig_w, target_size=HEATMAP_SIZE)
    stats = compute_box_stats_for_heatmap(heatmap=heatmap, box_coords=box_256, percentiles=percentiles)

    record = {
        "image_id": image_id,
        "label": label,
        "x_min_256": box_256[0],
        "y_min_256": box_256[1],
        "x_max_256": box_256[2],
        "y_max_256": box_256[3],
    }
    record.update(stats)
    return record


records = []
missing_heatmap_images = set()
heatmap_cache = {}
n_boxes = len(df_ann)

for pos, (_, r) in enumerate(df_ann.iterrows(), start=1):
    record = _evaluate_gt_box(r, meta, percentiles, heatmaps_root, heatmap_cache, missing_heatmap_images)
    if record is not None:
        records.append(record)
    # Report by row position, so progress also prints when the 500th / last row was skipped.
    if pos % 500 == 0 or pos == n_boxes:
        print(f"[Eval] Processed {pos}/{n_boxes} GT boxes")

heatmap_cache.clear()  # release the last loaded heatmap file

if not records:
    # describe() / groupby() below cannot run on an empty frame; stop with a clear message.
    raise RuntimeError("[Result] No records to evaluate. Check inputs / filters.")

df_res = pd.DataFrame(records)
print("\n[Result] Overall statistics:")
print(df_res.describe())

# ------------------ Per-label summaries ------------------
print("\n[Result] Per-label statistics (mean over boxes):")
group_cols = ["mean_in_box", "max_in_box", "mean_global"]
for p in percentiles:
    group_cols += [f"top{p}_coverage", f"top{p}_recall"]

df_label = df_res.groupby("label")[group_cols].mean().sort_values("mean_in_box", ascending=False)
print(df_label)

[WARN] No heatmap for label 'Calcification' in e0dc2e79105ad93532484e956ef8a71a.pt; skipping this GT box.
[WARN] No heatmap for label 'ILD' in e0dc2e79105ad93532484e956ef8a71a.pt; skipping this GT box.
[WARN] No heatmap for label 'Pneumothorax' in e0dc2e79105ad93532484e956ef8a71a.pt; skipping this GT box.
[WARN] No heatmap for label 'Pneumothorax' in e0dc2e79105ad93532484e956ef8a71a.pt; skipping this GT box.
[WARN] No heatmap for label 'Atelectasis' in e0dc2e79105ad93532484e956ef8a71a.pt; skipping this GT box.
[WARN] No heatmap for label 'Pneumothorax' in e0dc2e79105ad93532484e956ef8a71a.pt; skipping this GT box.
[WARN] No heatmap for label 'Infiltration' in 0aed23e64ebdea798486056b4f174424.pt; skipping this GT box.
[WARN] No heatmap for label 'Consolidation' in 0aed23e64ebdea798486056b4f174424.pt; skipping this GT box.
[WARN] No heatmap for label 'Pulmonary fibrosis' in aa15cfcfca7605465ca0513902738b95.pt; skipping this GT box.
[WARN] No heatmap for label 'Pleural thickening' in aa15c

In [8]:
# Per-label means of the box metrics from the evaluation cell, sorted by mean_in_box (descending).
df_label

Unnamed: 0_level_0,mean_in_box,max_in_box,mean_global,top90_coverage,top90_recall,top95_coverage,top95_recall,top99_coverage,top99_recall
label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
Cardiomegaly,0.228701,0.703796,0.160336,0.167301,0.181749,0.159755,0.088638,0.089717,0.010594
Pleural effusion,0.092109,0.304311,0.109048,0.026414,0.057601,0.022801,0.023425,0.003925,0.000648


In [9]:
# Create the output folder for the evaluation CSVs (no-op if it already exists).
Path("/home/woody/iwi5/iwi5362h/ALBEF/results/zero_shot_vindr_results/evaluate_heatmaps_grad_crossattn").mkdir(parents=True, exist_ok=True)

In [10]:
# Save the per-box and per-label results as CSVs for deeper analysis.
eval_out_dir = Path("/home/woody/iwi5/iwi5362h/ALBEF/results/zero_shot_vindr_results/evaluate_heatmaps_grad_crossattn")
eval_out_dir.mkdir(parents=True, exist_ok=True)  # keep this cell runnable on its own
out_full = eval_out_dir / "heatmap_eval_per_box.csv"
out_label = eval_out_dir / "heatmap_eval_per_label.csv"
df_res.to_csv(out_full, index=False)
df_label.to_csv(out_label)  # keep the "label" index as the first column
print(f"\n[Output] Saved per-box results to: {out_full}")
print(f"[Output] Saved per-label summary to: {out_label}")


[Output] Saved per-box results to: /home/woody/iwi5/iwi5362h/ALBEF/results/zero_shot_vindr_results/evaluate_heatmaps_grad_crossattn/heatmap_eval_per_box.csv
[Output] Saved per-label summary to: /home/woody/iwi5/iwi5362h/ALBEF/results/zero_shot_vindr_results/evaluate_heatmaps_grad_crossattn/heatmap_eval_per_label.csv


In [None]:
# Per-box results: one row per evaluated GT box (use df_res.head() for a compact view).
df_res