In [4]:
# Cell 1 — Imports & helpers

import os
from pathlib import Path
import random
import re
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import nibabel as nib  # for NIfTI lesion masks

# Models / dataset
from segment_anything import sam_model_registry
from ISUPMedSAM import IMG_SIZE, MedSAMSliceSpatialAttn, resize_to_img_size
from dataset_picai_slices import PicaiSliceDataset

# --- plotting defaults
plt.rcParams["figure.dpi"] = 120

# ===================== Utils =====================

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def to_uint8_gray(channel_tensor: torch.Tensor) -> np.ndarray:
    """
    Convert a single-channel tensor in [0,1] with shape (1,H,W) or (H,W)
    into an HxWx3 uint8 grayscale image.
    """
    if channel_tensor.ndim == 3 and channel_tensor.shape[0] == 1:
        ch = channel_tensor[0]
    else:
        ch = channel_tensor
    x = ch.detach().cpu().clamp(0, 1).numpy()
    x = (x * 255.0).round().astype(np.uint8)
    return np.stack([x, x, x], axis=-1)

def normalize01(arr: torch.Tensor | np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """
    Normalize a tensor/ndarray to [0,1] (returns numpy).
    If nearly constant, returns zeros.
    """
    a = arr.detach().cpu().numpy() if torch.is_tensor(arr) else arr
    mn, mx = float(a.min()), float(a.max())
    if mx <= mn + eps:
        return np.zeros_like(a, dtype=np.float32)
    return ((a - mn) / (mx - mn)).astype(np.float32)

def overlay_red(img_uint8_hwc: np.ndarray, heat01_hw: np.ndarray, alpha: float = 0.25) -> np.ndarray:
    """
    Faint red overlay (alpha default 0.25) on an RGB base image (uint8).
    heat01_hw should be in [0,1] with shape HxW.
    """
    base = img_uint8_hwc.astype(np.float32) / 255.0
    heat = np.clip(heat01_hw, 0.0, 1.0)[..., None]
    red = np.concatenate([heat, np.zeros_like(heat), np.zeros_like(heat)], axis=-1)
    out = (1 - alpha) * base + alpha * red
    return (out * 255.0).clip(0, 255).astype(np.uint8)

def rotate_ccw(arr: np.ndarray) -> np.ndarray:
    """
    Rotate HxW or HxWx3 array 90° counterclockwise for display.
    """
    return np.rot90(arr, k=1, axes=(0, 1))

def dbg_stats(name: str, hm: torch.Tensor):
    """
    Print quick diagnostics for heatmaps:
    per-image min/max/std and mass in top-5% pixels (uniform ≈ 0.05).
    Accepts [B,H,W] or [B,1,H,W].
    """
    if hm.ndim == 4:
        hm = hm.squeeze(1)
    flat = hm.reshape(hm.shape[0], -1)
    mins = flat.min(dim=1).values[:8]
    maxs = flat.max(dim=1).values[:8]
    stds = flat.std(dim=1)[:8]
    k = max(1, int(flat.shape[1] * 0.05))
    topk_mass = torch.topk(flat, k, dim=1).values.sum(dim=1) / flat.sum(dim=1).clamp(min=1e-12)
    print(f"[{name}] min[:8]={mins.tolist()}")
    print(f"[{name}] max[:8]={maxs.tolist()}")
    print(f"[{name}] std[:8]={stds.tolist()}")
    print(f"[{name}] top5% mass (uniform≈0.05)[:8]={topk_mass.tolist()}")

# ===================== Model helpers =====================

def build_model(
    sam_type: str,
    sam_ckpt: str,
    num_classes: int = 3,
    proj_dim: int = 512,
    attn_dim: int = 256,
    head_hidden: int = 256,
    head_dropout: float = 0.0,   # keep 0.0 for eval
    use_pre_neck: bool = True,
    device: str = "cuda",
) -> MedSAMSliceSpatialAttn:
    """
    Build MedSAM backbone safely on CPU if CUDA isn't available.
    Some checkpoints were saved with CUDA tensors; we force map_location='cpu'.
    """
    # Always build the SAM model with no weights first
    sam = sam_model_registry[sam_type](checkpoint=None)

    # Load checkpoint on CPU no matter what environment we're in
    sd = torch.load(sam_ckpt, map_location="cpu")
    # Handle common wrappers
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    elif isinstance(sd, dict) and "model" in sd:
        sd = sd["model"]

    # Strict where possible; relax if keys differ slightly across forks
    missing, unexpected = sam.load_state_dict(sd, strict=False)
    if missing or unexpected:
        print(f"[sam] missing keys: {missing[:8] if missing else []}")
        print(f"[sam] unexpected keys: {unexpected[:8] if unexpected else []}")

    # Now wrap in your spatial-attn head
    model = MedSAMSliceSpatialAttn(
        sam_model=sam,
        num_classes=num_classes,
        proj_dim=proj_dim,
        attn_dim=attn_dim,
        head_hidden=head_hidden,
        head_dropout=head_dropout,
        use_pre_neck=use_pre_neck,
    )

    # Final device placement (CPU if no CUDA)
    final_device = device if torch.cuda.is_available() and device.startswith("cuda") else "cpu"
    return model.to(final_device).eval()

def load_ckpt_strict_filtered(model: torch.nn.Module, ckpt_path: str):
    """
    Load a checkpoint but drop keys whose shapes don't match current model
    (useful when proj_dim/head differ slightly).
    """
    sd = torch.load(ckpt_path, map_location="cpu")
    sd = sd.get("model", sd)
    own = model.state_dict()
    filtered, dropped = {}, []
    for k, v in sd.items():
        if k in own and own[k].shape == v.shape:
            filtered[k] = v
        else:
            dropped.append(k)
    missing, unexpected = model.load_state_dict(filtered, strict=False)
    print(f"[load] loaded {ckpt_path}")
    if dropped:
        print(f"[load] dropped {len(dropped)} incompatible keys (e.g., {dropped[:6]})")
    if missing:
        print(f"[load] missing: {missing[:6]}")
    if unexpected:
        print(f"[load] unexpected: {unexpected[:6]}")

# ===================== Mask helpers =====================

SUFFIX_RE = re.compile(r"(_000\d)?(\.nii(\.gz)?)$")

def path_root_from_channel_path(p: str) -> str:
    """
    Strip modality suffix like '_0000.nii.gz' from a channel path
    to get a root ID (e.g., '10000_1000000').
    """
    name = Path(p).name
    name = SUFFIX_RE.sub("", name)
    return name

def extract_slice_index(meta_sample: dict) -> int | None:
    """
    Best-effort slice index extraction from the sample dict.
    """
    for k in ("slice_idx", "slice", "z", "z_idx", "index_z"):
        if k in meta_sample and meta_sample[k] is not None:
            try:
                return int(meta_sample[k])
            except Exception:
                pass
    if "meta" in meta_sample and isinstance(meta_sample["meta"], dict):
        for k in ("slice_idx", "slice", "z", "z_idx", "index_z"):
            if k in meta_sample["meta"]:
                try:
                    return int(meta_sample["meta"][k])
                except Exception:
                    pass
    return None

def load_mask_slice(mask_dir: Path, root_name: str, slice_idx: int | None) -> np.ndarray | None:
    """
    Load a 2D lesion mask slice (float32 in {0,1}) from <mask_dir>/<root>.nii(.gz).
    Chooses the provided slice index; if None or out of bounds, uses the middle slice.
    """
    f = mask_dir / f"{root_name}.nii.gz"
    if not f.exists():
        f = mask_dir / f"{root_name}.nii"
        if not f.exists():
            print(f"[mask] missing for root={root_name}")
            return None
    try:
        vol = nib.load(str(f))
        m = vol.get_fdata().astype(np.float32)
        z = slice_idx if (slice_idx is not None and 0 <= slice_idx < m.shape[-1]) else m.shape[-1] // 2
        sl = m[..., z]
        sl = (sl > 0).astype(np.float32)
        return sl
    except Exception as e:
        print(f"[mask] failed to load {f}: {e}")
        return None


In [5]:

# === Cell 2: Args (edit me) ===
# Data
MANIFEST   = "/project/aip-medilab/shared/picai/manifests/slices_manifest.csv"
FOLDS_TEST = "4"  # e.g., "4" or "3,4"
TARGET     = "isup3"  # ["isup3", "binary_all", "binary_low_high", "raw"]
CHANNELS   = "path_T2,path_ADC,path_HBV"
NUM        = 8           # used only if IDS is None
IDS        = "10012_1000012:6"        # e.g., "10034_1000345:18,10077_1000999:22" or None
SEED       = 42
MASK_DIR   = "/home/ewillis/projects/aip-medilab/shared/picai/picai_prepped_registered/labelsTr_lesion"

# Models
SAM_TYPE   = "vit_b"   # ["vit_b", "vit_l", "vit_h"]
SAM_CKPT   = "/project/aip-medilab/ewillis/pca_contrastive/mri_model_medsam_finetune/work_dir/MedSAM/medsam_vit_b.pth"
BASELINE_CKPT = "/home/ewillis/projects/aip-medilab/ewillis/pca_contrastive/mri_model_medsam_finetune_2D/results_no_loading/baseline/ckpt_best.pt"
ALIGNED_CKPT  = "/home/ewillis/projects/aip-medilab/ewillis/pca_contrastive/mri_model_medsam_finetune_2D/results_no_loading/triplet/ckpt_head_best.pt"

NUM_CLASSES = 3
PROJ_DIM    = 512
ATTN_DIM    = 256
HEAD_HIDDEN = 256
HEAD_DROPOUT= 0.1
USE_PRE_NECK= True

# Viz
OVERLAY_ALPHA = 0.25   # faint overlay so white anatomy pops
ROTATE_CCW    = True   # rotate for readability


In [6]:
# --- Sampling helpers missing in this notebook cell ---

import torch
import torch.nn.functional as F
import numpy as np

# Resize each sample to IMG_SIZE so the batch stacks cleanly
def collate_resize_to_imgsize(batch):
    imgs, labels, metas = [], [], []
    for s in batch:
        x = s["image"].unsqueeze(0)  # [1,C,H,W]
        x = F.interpolate(x, size=(IMG_SIZE, IMG_SIZE), mode="bilinear", align_corners=False).squeeze(0)
        imgs.append(x)
        labels.append(torch.as_tensor(s["label"], dtype=torch.long))
        metas.append(s)  # keep original dict for paths/ids
    return {"image": torch.stack(imgs, 0), "label": torch.stack(labels, 0), "meta": metas}

# Random sample N items from the dataset
def sample_batch(ds, n, seed=42):
    rng = np.random.default_rng(seed)
    idxs = rng.choice(len(ds), size=min(n, len(ds)), replace=False)
    batch = [ds[i] for i in idxs]
    return collate_resize_to_imgsize(batch)

# Optional: sample by explicit ids list if you’re using --ids
# Expects each ds[i] to include 'case_id' in the returned dict.
def sample_by_ids(ds, ids_list):
    wanted = set(str(i) for i in ids_list)
    picked = []
    for i in range(len(ds)):
        s = ds[i]
        cid = str(s.get("case_id", ""))
        if cid in wanted:
            picked.append(s)
        if len(picked) == len(wanted):
            break
    if not picked:
        raise ValueError("No matching case_ids found in dataset for provided --ids.")
    return collate_resize_to_imgsize(picked)

# === Cell 3: Logic (display only; no files saved) ===
set_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

chan_keys  = tuple(CHANNELS.split(","))
chan_names = ("T2", "ADC", "HBV")  # display names (must match order)

# Dataset
folds = [int(x.strip()) for x in FOLDS_TEST.split(",") if x.strip()]
ds_test = PicaiSliceDataset(
    manifest_csv=MANIFEST,
    folds=folds,
    use_skip=True,
    label6_column="label6",
    target=TARGET,
    channels=chan_keys,
    missing_channel_mode="zeros",
    pct_lower=0.5, pct_upper=99.5,
    cache_size=64,
)

# Choose samples
if IDS and len(str(IDS).strip()) > 0:
    ids_list = [t.strip() for t in str(IDS).split(",") if t.strip()]
    batch = sample_by_ids(ds_test, ids_list)
else:
    batch = sample_batch(ds_test, n=NUM, seed=SEED)

imgs, labels, metas = batch["image"].to(device), batch["label"].to(device), batch["meta"]
H, W = imgs.shape[-2], imgs.shape[-1]

# Build + load models
m_base = build_model(SAM_TYPE, SAM_CKPT, num_classes=NUM_CLASSES,
                     proj_dim=PROJ_DIM, attn_dim=ATTN_DIM,
                     head_hidden=HEAD_HIDDEN, head_dropout=HEAD_DROPOUT,
                     use_pre_neck=USE_PRE_NECK, device=device)
load_ckpt_strict_filtered(m_base, BASELINE_CKPT)

m_align = build_model(SAM_TYPE, SAM_CKPT, num_classes=NUM_CLASSES,
                      proj_dim=PROJ_DIM, attn_dim=ATTN_DIM,
                      head_hidden=HEAD_HIDDEN, head_dropout=HEAD_DROPOUT,
                      use_pre_neck=USE_PRE_NECK, device=device)
load_ckpt_strict_filtered(m_align, ALIGNED_CKPT)

# Forward once to get attention + logits (no grads)
with torch.no_grad():
    logits_b, emb_b, attn_b, _ = m_base(imgs, return_attn=True, return_feats=True, attn_upsample_to=(H, W))
    logits_a, emb_a, attn_a, _ = m_align(imgs, return_attn=True, return_feats=True, attn_upsample_to=(H, W))

pred_b = logits_b.argmax(1).cpu().numpy()
pred_a = logits_a.argmax(1).cpu().numpy()
y_true = labels.cpu().numpy()

# Prepare masks
mask_dir = Path(MASK_DIR)
mask_list = []
for meta in metas:
    sample_path = None
    for k in chan_keys:
        if k in meta and meta[k] is not None:
            sample_path = meta[k]; break
        if "paths" in meta and isinstance(meta["paths"], dict) and k in meta["paths"]:
            sample_path = meta["paths"][k]; break
    root = path_root_from_channel_path(str(sample_path)) if sample_path is not None else None
    z = extract_slice_index(meta)
    m2d = load_mask_slice(mask_dir, root, z) if root is not None else None
    mask_list.append(m2d)

# ---- Display per sample ----
B = imgs.size(0)
for i in range(B):
    # Per-sample fig: rows = channels (T2/ADC/HBV), cols = Input | Mask | Base-Attn | Align-Attn
    cols = 4
    rows = 3
    fig, axes = plt.subplots(rows, cols, figsize=(cols*3.2, rows*3.2))
    if rows == 1:
        axes = np.expand_dims(axes, 0)
    fig.suptitle("Sample {}  |  true={}  base_pred={}  aligned_pred={}".format(
        i, int(y_true[i]), int(pred_b[i]), int(pred_a[i])
    ), fontsize=12)

    # Common attention maps (0..1)
    a_b = normalize01(attn_b[i].squeeze(0))
    a_a = normalize01(attn_a[i].squeeze(0))

    # Mask prep
    mask = mask_list[i]
    if mask is not None:
        mask_rgb = (np.stack([mask, mask, mask], axis=-1) * 255).astype(np.uint8)
    else:
        mask_rgb = np.zeros((H, W, 3), dtype=np.uint8)

    # Channel loop
    for r, (cidx, cname) in enumerate(zip([0,1,2], ["T2", "ADC", "HBV"])):
        base_img = to_uint8_gray(imgs[i, cidx:cidx+1])

        # Rotate for readability
        if ROTATE_CCW:
            base_img_r = rotate_ccw(base_img)
            mask_rgb_r = rotate_ccw(mask_rgb) if mask is not None else mask_rgb
            a_b_r = rotate_ccw(a_b); a_a_r = rotate_ccw(a_a)
        else:
            base_img_r = base_img
            mask_rgb_r = mask_rgb
            a_b_r = a_b; a_a_r = a_a

        # Panels
        axes[r,0].imshow(base_img_r); axes[r,0].set_title(f"{cname} (input)", fontsize=10)
        axes[r,1].imshow(mask_rgb_r); axes[r,1].set_title("Lesion mask", fontsize=10)
        axes[r,2].imshow(overlay_red(base_img_r, a_b_r, alpha=OVERLAY_ALPHA)); axes[r,2].set_title("Baseline—Attn", fontsize=10)
        axes[r,3].imshow(overlay_red(base_img_r, a_a_r, alpha=OVERLAY_ALPHA)); axes[r,3].set_title("Aligned—Attn", fontsize=10)

        for cc in range(cols):
            axes[r,cc].set_axis_off()

    plt.tight_layout()
    plt.show()


  return torch._C._cuda_getDeviceCount() > 0


ValueError: No matching case_ids found in dataset for provided --ids.

In [1]:
import torch, re, sys, json

ckpt_path = "/home/ewillis/projects/aip-medilab/ewillis/pca_contrastive/mri_model_medsam_finetune_2D/results_no_loading/baseline/ckpt_best.pt"
sd = torch.load(ckpt_path, map_location="cpu")
sd = sd.get("model", sd)

# strip potential 'module.' prefix from DDP
if any(k.startswith("module.") for k in sd):
    sd = {k.replace("module.","",1): v for k,v in sd.items()}

def shapes(prefix):
    rgx = re.compile(rf"^{re.escape(prefix)}")
    return {k: tuple(v.shape) for k,v in sd.items() if rgx.search(k)}

out = {
    "has_pool": any(k.startswith("pool.") for k in sd),
    "has_attn_pool": any(k.startswith("attn_pool.") for k in sd),
    "pool.theta.weight": tuple(sd.get("pool.theta.weight", torch.empty(0)).shape) if "pool.theta.weight" in sd else None,
    "pool.gate.weight":  tuple(sd.get("pool.gate.weight",  torch.empty(0)).shape) if "pool.gate.weight" in sd else None,
    "proj.0.weight":     tuple(sd.get("proj.0.weight",     torch.empty(0)).shape) if "proj.0.weight" in sd else None,  # LayerNorm weight -> (C,)
    "proj.1.weight":     tuple(sd.get("proj.1.weight",     torch.empty(0)).shape) if "proj.1.weight" in sd else None,  # Linear -> (proj_dim, C)
    "head.0.weight":     tuple(sd.get("head.0.weight",     torch.empty(0)).shape) if "head.0.weight" in sd else None,
    "head.1.weight":     tuple(sd.get("head.1.weight",     torch.empty(0)).shape) if "head.1.weight" in sd else None,
    "all_pool_keys": sorted([k for k in sd if k.startswith("pool.")])[:12],
    "all_proj_keys": sorted([k for k in sd if k.startswith("proj.")])[:12],
}
print(json.dumps(out, indent=2))


{
  "has_pool": true,
  "has_attn_pool": false,
  "pool.theta.weight": [
    256,
    16,
    1,
    1
  ],
  "pool.gate.weight": [
    256,
    16,
    1,
    1
  ],
  "proj.0.weight": [
    16
  ],
  "proj.1.weight": [
    512,
    16
  ],
  "head.0.weight": [
    512
  ],
  "head.1.weight": [
    256,
    512
  ],
  "all_pool_keys": [
    "pool.gate.bias",
    "pool.gate.weight",
    "pool.norm.bias",
    "pool.norm.weight",
    "pool.score.bias",
    "pool.score.weight",
    "pool.theta.bias",
    "pool.theta.weight"
  ],
  "all_proj_keys": [
    "proj.0.bias",
    "proj.0.weight",
    "proj.1.bias",
    "proj.1.weight"
  ]
}
