In [None]:
# === ROI2 inference + save prob + masks @ thr 0.8/0.9 for TWO checkpoints (single notebook cell) ===
import os
import numpy as np
import tifffile
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------
# INPUTS (you provided)
# -------------------------
# Paths are relative to the notebook's working directory.
REF_TIF = "Nov23_Crops/CCP_ROI2.tif"        # reference stack; read ONLY to compute the global min/max for normalization
RAW_TIF = "Nov23_Crops/CCP_ROI2_crop1.tif"  # stack that inference is run on

# Two checkpoints are evaluated on the same input so their outputs can be compared.
CKPT_A = "Nov23_Crops/unet25d_runs/unet25d_k5_or_gn_drop_dice_bce_lr3e-4_clip1/best.pt"       # pre-finetune
CKPT_B = "Nov23_Crops/unet25d_runs/unet25d_k5_or_finetune_crop2/best_finetune.pt"             # finetune

OUT_DIR = "Nov23_Crops/roi2_crop1_inference"  # all output TIFFs are written here
THRS = (0.80, 0.90)  # probability thresholds; one binary mask volume is saved per threshold

K = 5         # number of adjacent z-slices stacked as input channels for each prediction
PAD = K // 2  # slices needed on each side of a center slice; for K=5 the first 2 and last 2 slices get no prediction
BATCH = 2     # windows per forward pass; noted as memory-safe for 512x512 slices, can be increased if memory allows

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Create the output directory up front; exist_ok=True makes re-running the cell harmless.
os.makedirs(OUT_DIR, exist_ok=True)

# -------------------------
# MODEL
# -------------------------
def gn(num_channels: int, groups: int) -> nn.GroupNorm:
    """Build a GroupNorm layer whose group count evenly divides ``num_channels``.

    Starts from ``min(groups, num_channels)`` and steps the group count down
    until it divides ``num_channels`` (falling back to a single group).

    Args:
        num_channels: Number of channels the layer normalizes.
        groups: Preferred (maximum) number of groups.

    Returns:
        An ``nn.GroupNorm`` with the largest valid group count not exceeding
        ``groups``.
    """
    n_groups = min(groups, num_channels)
    # Walk downward until the channel count splits evenly; 1 always divides.
    while num_channels % n_groups != 0 and n_groups > 1:
        n_groups = n_groups - 1
    return nn.GroupNorm(num_groups=n_groups, num_channels=num_channels)

class ConvBlock(nn.Module):
    """Two consecutive (3x3 conv -> GroupNorm -> ReLU) stages.

    Both convolutions use ``padding=1`` (spatial size is preserved) and no
    bias term; the normalization layers come from ``gn``.

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels produced by both convolutions.
        groups: Preferred group count passed to ``gn``.
    """

    def __init__(self, in_ch: int, out_ch: int, groups: int):
        super().__init__()
        # Attribute names are part of the state_dict keys, so they must not change.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.gn1 = gn(out_ch, groups)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.gn2 = gn(out_ch, groups)

    def forward(self, x):
        """Apply both conv/norm/ReLU stages and return the resulting feature map."""
        stages = ((self.conv1, self.gn1), (self.conv2, self.gn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)), inplace=True)
        return x

class UNet2D(nn.Module):
    """Four-level 2D U-Net that outputs a single-channel logit map.

    The encoder is four ``ConvBlock`` + 2x max-pool stages, followed by a
    bottleneck ``ConvBlock`` with 2D dropout. The decoder mirrors it with
    stride-2 transposed convolutions, concatenating the matching encoder
    output before each ``ConvBlock``. Channel widths double at every level
    (``base_ch`` .. ``base_ch * 16``).

    Because the input is halved four times and the skip tensors are
    concatenated with ``torch.cat``, input height and width should be
    multiples of 16 so the shapes line up.

    Args:
        in_ch: Number of input channels.
        base_ch: Channel width of the first encoder level.
        groups: Preferred GroupNorm group count for every ``ConvBlock``.
        dropout_bottleneck: Drop probability of the bottleneck ``Dropout2d``.
    """

    def __init__(self, in_ch: int, base_ch: int = 64, groups: int = 8, dropout_bottleneck: float = 0.10):
        super().__init__()
        c1 = base_ch
        c2 = base_ch * 2
        c3 = base_ch * 4
        c4 = base_ch * 8
        c5 = base_ch * 16

        # NOTE: attribute names and registration order define the state_dict
        # layout expected by existing checkpoints; keep them as they are.

        # --- encoder ---
        self.enc1 = ConvBlock(in_ch, c1, groups)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2 = ConvBlock(c1, c2, groups)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3 = ConvBlock(c2, c3, groups)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4 = ConvBlock(c3, c4, groups)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # --- bottleneck ---
        self.bottleneck = ConvBlock(c4, c5, groups)
        self.drop = nn.Dropout2d(p=dropout_bottleneck)

        # --- decoder: each ConvBlock sees upsampled + skip channels ---
        self.up4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(c4 * 2, c4, groups)

        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(c3 * 2, c3, groups)

        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(c2 * 2, c2, groups)

        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(c1 * 2, c1, groups)

        # 1x1 projection to a single logit channel.
        self.out = nn.Conv2d(c1, 1, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return raw logits (no sigmoid applied)."""
        # Encoder: keep every level's output for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        skip4 = self.enc4(self.pool3(skip3))

        feat = self.drop(self.bottleneck(self.pool4(skip4)))

        # Decoder: upsample, concatenate [upsampled, skip] on channels, refine.
        decoder_levels = (
            (self.up4, self.dec4, skip4),
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, refine, skip in decoder_levels:
            feat = refine(torch.cat([upsample(feat), skip], dim=1))

        return self.out(feat)  # logits

def get_model_cfg(ckpt: dict):
    """Read the U-Net hyper-parameters stored in a checkpoint dict.

    Looks in ``ckpt["cfg"]`` (treated as empty when missing or falsy). Each
    setting is looked up under its upper-case key first, then its lower-case
    key, then a built-in default.

    Args:
        ckpt: Loaded checkpoint dictionary.

    Returns:
        Tuple ``(base_ch, gn_groups, drop_p)`` as ``(int, int, float)``;
        defaults are ``(64, 8, 0.10)``.
    """
    cfg = ckpt.get("cfg", {}) or {}

    def _pick(upper_key, lower_key, fallback):
        # Upper-case key wins; lower-case key is the secondary spelling.
        return cfg.get(upper_key, cfg.get(lower_key, fallback))

    return (
        int(_pick("BASE_CH", "base_ch", 64)),
        int(_pick("GN_GROUPS", "gn_groups", 8)),
        float(_pick("DROPOUT_BOTTLENECK", "dropout_bottleneck", 0.10)),
    )

def load_model(ckpt_path: str):
    """Rebuild a ``UNet2D`` from a checkpoint file and return it in eval mode.

    The weights are taken from the checkpoint's ``"model_state"`` entry, or
    from ``"state_dict"`` when the former is missing or empty. Architecture
    settings come from ``get_model_cfg``; the input channel count is the
    module-level ``K``. The model is placed on ``DEVICE``.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        The loaded ``UNet2D`` with ``eval()`` applied.

    Raises:
        KeyError: If neither weight entry is present in the checkpoint.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    weights = ckpt.get("model_state", None)
    if not weights:
        weights = ckpt.get("state_dict", None)
    if weights is None:
        raise KeyError(f"Checkpoint missing model_state/state_dict: {ckpt_path}")

    base_ch, gn_groups, drop_p = get_model_cfg(ckpt)
    net = UNet2D(
        in_ch=K,
        base_ch=base_ch,
        groups=gn_groups,
        dropout_bottleneck=drop_p,
    )
    net = net.to(DEVICE)
    # strict=True: any missing or unexpected key is an error rather than a silent mismatch.
    net.load_state_dict(weights, strict=True)
    net.eval()
    return net

# -------------------------
# LOAD + NORMALIZE
# -------------------------
# Reads both stacks, derives the normalization range from the reference stack,
# rescales the evaluation stack to [0, 1] and builds the K-slice input windows.
ref = tifffile.imread(REF_TIF)
raw = tifffile.imread(RAW_TIF)

# ensure [Z,H,W]: a single 2-D image becomes a one-slice stack
if ref.ndim == 2:
    ref = ref[None, ...]
if raw.ndim == 2:
    raw = raw[None, ...]

# Fail early with a readable message instead of an unpacking error further down.
if raw.ndim != 3:
    raise ValueError(
        f"Expected {RAW_TIF} to be a [Z,H,W] stack, got shape {tuple(raw.shape)}"
    )

ref = ref.astype(np.float32)
raw = raw.astype(np.float32)

# Normalization range comes from the reference stack, not from the stack being evaluated.
gmin = float(np.min(ref))
gmax = float(np.max(ref))
print(f"Ref {REF_TIF} shape={tuple(ref.shape)}, GLOBAL_MIN={gmin:.3f}, GLOBAL_MAX={gmax:.3f}")
print(f"Eval stack {RAW_TIF} shape={tuple(raw.shape)}  (expect [Z,H,W])")

Z, H, W = raw.shape
z_centers = list(range(PAD, Z - PAD))  # drops first 2 & last 2 for K=5
N = len(z_centers)
# A stack shorter than one full window has no valid center slice; without this
# check the print below would fail with an opaque IndexError.
if N == 0:
    raise ValueError(
        f"{RAW_TIF} has Z={Z} slices; need at least {2 * PAD + 1} for K={K} (PAD={PAD})"
    )
print(f"K={K}, PAD={PAD} => centers z={z_centers[0]}..{z_centers[-1]} (N={N})")

# Rescale to [0, 1]; the epsilon guards against gmax == gmin, and values outside
# the reference range are clipped.
x01 = (raw - gmin) / (gmax - gmin + 1e-12)
x01 = np.clip(x01, 0.0, 1.0).astype(np.float32)

# Build input tensor [N,K,H,W]: one window of neighbouring slices per center slice
x_in = np.stack([x01[z - PAD : z + PAD + 1] for z in z_centers], axis=0)  # [N,K,H,W]
# The whole window tensor is moved to DEVICE at once.
xt = torch.from_numpy(x_in).to(DEVICE)

# -------------------------
# INFERENCE + SAVE
# -------------------------
def run_and_save(tag: str, ckpt_path: str):
    """Run one checkpoint over the prepared windows and write its outputs.

    Uses the module-level ``xt``, ``N``, ``BATCH``, ``Z``, ``H``, ``W``,
    ``z_centers``, ``THRS``, ``OUT_DIR`` and ``RAW_TIF``. Writes to
    ``OUT_DIR``:

    * one float16 probability volume (slices without a prediction are NaN);
    * one uint8 mask volume per threshold in ``THRS`` (255 where
      probability >= threshold, 0 elsewhere, including slices without a
      prediction).

    Args:
        tag: Label inserted into the output file names.
        ckpt_path: Checkpoint file passed to ``load_model``.
    """
    model = load_model(ckpt_path)

    # Forward pass in chunks of BATCH windows; gradients are not needed.
    chunks = []
    with torch.no_grad():
        for start in range(0, N, BATCH):
            window_batch = xt[start:start + BATCH]
            batch_probs = torch.sigmoid(model(window_batch)).squeeze(1)  # [b,H,W]
            chunks.append(batch_probs.detach().cpu().numpy().astype(np.float32))
    center_probs = np.concatenate(chunks, axis=0)  # [N,H,W]

    # Scatter the per-center predictions back into a full-depth volume.
    prob_vol = np.full((Z, H, W), np.nan, dtype=np.float32)
    prob_vol[z_centers] = center_probs

    # One binary (0/255) volume per threshold.
    mask_vols = {}
    for thr in THRS:
        mask = np.zeros((Z, H, W), dtype=np.uint8)
        mask[z_centers] = (center_probs >= thr).astype(np.uint8) * 255
        mask_vols[thr] = mask

    # Output names: <raw stem>__<tag>__<kind>.tif
    raw_stem = os.path.splitext(os.path.basename(RAW_TIF))[0]
    base = os.path.join(OUT_DIR, f"{raw_stem}__{tag}")

    prob_path = base + "__prob_float16.tif"
    tifffile.imwrite(prob_path, prob_vol.astype(np.float16))  # for visualization

    for thr in THRS:
        tifffile.imwrite(base + f"__mask_thr{thr:.2f}.tif", mask_vols[thr])

    # Summary of what was written.
    print(f"\n[{tag}] ckpt={ckpt_path}")
    print(f"  saved prob:  {prob_path}")
    for thr in THRS:
        print(f"  saved mask:  {base}__mask_thr{thr:.2f}.tif")

# Evaluate both checkpoints on the same prepared input, pre-finetune first.
run_and_save(tag="A_pre", ckpt_path=CKPT_A)
run_and_save(tag="B_finetune", ckpt_path=CKPT_B)

# Final pointer to where the TIFFs were written.
print(f"\nDone. Outputs in: {OUT_DIR}")
