In [None]:
import os, glob, math
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

NUM_CLASSES = 5
CLASS_NAMES = ["Background", "Yellow", "Green", "Red", "Blue"]

REPO_ROOT = os.getcwd()

TEST_DIR = os.path.join(REPO_ROOT, "project", "test_data_tiff")

CKPT_UNET = os.path.join(REPO_ROOT, "best_unet.pth")
CKPT_RES  = os.path.join(REPO_ROOT, "best_resattn_unet_ds.pth")

OUT_DIR = os.path.join(REPO_ROOT, "project", "inference_outputs_fulltest")
os.makedirs(OUT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE_AUTOMIX = True   
TILE = 512
STRIDE = 400  
def robust_normalize(img: np.ndarray) -> np.ndarray:
    img = img.astype(np.float32)
    p2, p98 = np.percentile(img, (2, 98))
    if (p98 - p2) < 1e-6:
        return np.zeros_like(img, dtype=np.float32)
    img = (img - p2) / (p98 - p2)
    return np.clip(img, 0.0, 1.0).astype(np.float32)

def ensure_2d(img: np.ndarray) -> np.ndarray:

    if img.ndim == 3:
        img = img[..., 0]
    return img

def pad_to_tile(img: np.ndarray, tile=TILE) -> tuple[np.ndarray, tuple[int,int]]:
    h, w = img.shape
    pad_h = (tile - (h % tile)) % tile
    pad_w = (tile - (w % tile)) % tile
    if pad_h or pad_w:
        img = np.pad(img, ((0, pad_h), (0, pad_w)), mode="reflect")
    return img, (pad_h, pad_w)

def unpad(img: np.ndarray, pad_hw: tuple[int,int]) -> np.ndarray:
    pad_h, pad_w = pad_hw
    if pad_h:
        img = img[:-pad_h, :]
    if pad_w:
        img = img[:, :-pad_w]
    return img

def discrete_cmap():
 
    return plt.colormaps["tab10"].resampled(NUM_CLASSES)

def gn(ch): return nn.GroupNorm(8, ch)

class ConvGNAct(nn.Module):
    """Single conv + GN + ReLU. Used to build a DoubleConv style block by stacking."""
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, padding=p, bias=False)
        self.gn = gn(out_ch)
        self.act = nn.ReLU(inplace=True)
    def forward(self, x):
        return self.act(self.gn(self.conv(x)))

class DoubleConvGN(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            ConvGNAct(in_ch, out_ch),
            ConvGNAct(out_ch, out_ch),
        )
    def forward(self, x): return self.net(x)

class SimpleUNet(nn.Module):
    """
    Matches your described UNet:
    inc -> down1 -> down2 -> down3 -> up1+skip(x3) -> up2+skip(x2) -> up3+skip(x1) -> out
    """
    def __init__(self, n_classes=5, base=64):
        super().__init__()
        self.inc = DoubleConvGN(1, base)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConvGN(base, base*2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConvGN(base*2, base*4))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConvGN(base*4, base*8))

        self.up1 = nn.ConvTranspose2d(base*8, base*4, 2, stride=2)
        self.conv1 = DoubleConvGN(base*8, base*4)

        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.conv2 = DoubleConvGN(base*4, base*2)

        self.up3 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.conv3 = DoubleConvGN(base*2, base)

        self.outc = nn.Conv2d(base, n_classes, 1)

    def forward(self, x):
        x1 = self.inc(x)    
        x2 = self.down1(x1) 
        x3 = self.down2(x2) 
        x4 = self.down3(x3) 

        x = self.up1(x4)
        x = self.conv1(torch.cat([x3, x], dim=1))

        x = self.up2(x)
        x = self.conv2(torch.cat([x2, x], dim=1))

        x = self.up3(x)
        x = self.conv3(torch.cat([x1, x], dim=1))

        return self.outc(x)

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.gn1 = gn(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.gn2 = gn(out_ch)
        self.act = nn.SiLU(inplace=True)
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1, bias=False)
    def forward(self, x):
        h = self.act(self.gn1(self.conv1(x)))
        h = self.gn2(self.conv2(h))
        return self.act(h + self.skip(x))

import os, re
import torch
import torch.nn as nn
import torch.nn.functional as F

def _clean_state_dict(state):

    if isinstance(state, dict) and ("state_dict" in state):
        state = state["state_dict"]
    if isinstance(state, dict) and ("model" in state):
        state = state["model"]

    if isinstance(state, dict):
        new = {}
        for k, v in state.items():
            nk = k
            if nk.startswith("module."):
                nk = nk[len("module."):]
            new[nk] = v
        return new
    return state

def load_weights_strict(model, ckpt_path, device="cpu"):
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    state = torch.load(ckpt_path, map_location=device)
    state = _clean_state_dict(state)
    missing, unexpected = model.load_state_dict(state, strict=True)
    print(f"[OK] Loaded: {ckpt_path}")
    return missing, unexpected

def gn(ch): return nn.GroupNorm(8, ch)

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.gn1 = gn(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.gn2 = gn(out_ch)
        self.act = nn.SiLU(inplace=True)
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1, bias=False)

    def forward(self, x):
        h = self.act(self.gn1(self.conv1(x)))
        h = self.gn2(self.conv2(h))
        return self.act(h + self.skip(x))

class AttnGate(nn.Module):
    """
    Attention gate with reduced intermediate channels:
      theta: skip_ch -> inter_ch
      phi:   gate_ch -> inter_ch
      psi:   inter_ch -> 1
    This matches your checkpoint shapes like [24,48,1,1], [1,24,1,1], etc.
    """
    def __init__(self, skip_ch, gate_ch, inter_ch=None):
        super().__init__()
        if inter_ch is None:
            inter_ch = max(1, skip_ch // 2)

        self.theta = nn.Conv2d(skip_ch, inter_ch, 1, bias=False)
        self.phi   = nn.Conv2d(gate_ch, inter_ch, 1, bias=False)
        self.psi   = nn.Conv2d(inter_ch, 1, 1, bias=True)
        self.act   = nn.SiLU(inplace=True)
        self.sig   = nn.Sigmoid()

    def forward(self, skip, gate):
        g = F.interpolate(gate, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        a = self.act(self.theta(skip) + self.phi(g))
        a = self.sig(self.psi(a))
        return skip * a

class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.reduce = nn.Conv2d(in_ch, out_ch, 1, bias=False)
       
        self.attn = AttnGate(skip_ch=out_ch, gate_ch=out_ch, inter_ch=max(1, out_ch // 2))
        self.block = ResBlock(out_ch * 2, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        x = self.reduce(x)
        s = self.attn(skip, x)
        return self.block(torch.cat([s, x], dim=1))

class ResAttnUNetDS(nn.Module):
    def __init__(self, n_classes=5, base=48):
        super().__init__()
        self.stem   = ResBlock(1, base)
        self.d1     = nn.Sequential(nn.MaxPool2d(2), ResBlock(base, base*2))
        self.d2     = nn.Sequential(nn.MaxPool2d(2), ResBlock(base*2, base*4))
        self.d3     = nn.Sequential(nn.MaxPool2d(2), ResBlock(base*4, base*8))
        self.d4     = nn.Sequential(nn.MaxPool2d(2), ResBlock(base*8, base*12))
        self.bottle = ResBlock(base*12, base*12)

        self.u3 = UpBlock(base*12, base*8)
        self.u2 = UpBlock(base*8,  base*4)
        self.u1 = UpBlock(base*4,  base*2)
        self.u0 = UpBlock(base*2,  base)

        self.head0 = nn.Conv2d(base, n_classes, 1)

    def forward(self, x):
        s0 = self.stem(x)
        s1 = self.d1(s0)
        s2 = self.d2(s1)
        s3 = self.d3(s2)
        s4 = self.d4(s3)
        b  = self.bottle(s4)

        x = self.u3(b,  s3)
        x = self.u2(x,  s2)
        x = self.u1(x,  s1)
        x = self.u0(x,  s0)
        return self.head0(x)


def load_weights_strict(model, ckpt_path):
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    state = torch.load(ckpt_path, map_location=DEVICE)
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"[OK] Loaded: {ckpt_path}")
    if missing:
        print("  missing keys (sample):", missing[:8], ("..." if len(missing)>8 else ""))
    if unexpected:
        print("  unexpected keys (sample):", unexpected[:8], ("..." if len(unexpected)>8 else ""))

@torch.no_grad()
def predict_full_image_logits(model, img01: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    img01: normalized float32 in [0,1], shape HxW
    returns:
      pred: HxW uint8 argmax
      entropy: HxW float32 entropy map
    """
    model.eval()

    img_pad, pad_hw = pad_to_tile(img01, tile=TILE)
    H, W = img_pad.shape

    num = np.zeros((NUM_CLASSES, H, W), dtype=np.float32)
    den = np.zeros((H, W), dtype=np.float32)

    wy = np.hanning(TILE).astype(np.float32)
    wx = np.hanning(TILE).astype(np.float32)
    w2 = (wy[:, None] * wx[None, :])
    w2 = np.maximum(w2, 1e-3)  

    autocast_on = (DEVICE.startswith("cuda") and DTYPE_AUTOMIX)

    for y in range(0, H - TILE + 1, STRIDE):
        for x in range(0, W - TILE + 1, STRIDE):
            patch = img_pad[y:y+TILE, x:x+TILE]
            t = torch.from_numpy(patch).unsqueeze(0).unsqueeze(0).to(DEVICE)  

            if autocast_on:
                with torch.cuda.amp.autocast():
                    logits = model(t) 
            else:
                logits = model(t)

            probs = torch.softmax(logits, dim=1)[0].float().cpu().numpy()  

            num[:, y:y+TILE, x:x+TILE] += probs * w2[None, :, :]
            den[y:y+TILE, x:x+TILE] += w2

    probs_full = num / (den[None, :, :] + 1e-8) 
    probs_full = probs_full.astype(np.float32)

    ent = -np.sum(probs_full * np.log(probs_full + 1e-8), axis=0).astype(np.float32)

    pred = np.argmax(probs_full, axis=0).astype(np.uint8)

    pred = unpad(pred, pad_hw)
    ent  = unpad(ent,  pad_hw)
    return pred, ent

def save_full_outputs(base_name, img01, pred_unet, ent_unet, pred_res, ent_res):
    cmap = discrete_cmap()

    for tag, pred in [("unet", pred_unet), ("resattn", pred_res)]:
        plt.figure(figsize=(8, 8))
        im = plt.imshow(pred, vmin=0, vmax=NUM_CLASSES-1, cmap=cmap)
        plt.title(f"{base_name} — {tag.upper()} prediction (full image)")
        cbar = plt.colorbar(im, fraction=0.03, pad=0.02, ticks=list(range(NUM_CLASSES)))
        cbar.ax.set_yticklabels(CLASS_NAMES)
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(os.path.join(OUT_DIR, f"{base_name}_{tag}_full_pred.png"), dpi=200)
        plt.close()

    for tag, ent in [("unet", ent_unet), ("resattn", ent_res)]:
        plt.figure(figsize=(8, 8))
        plt.imshow(ent, cmap="magma")
        plt.title(f"{base_name} — {tag.upper()} entropy (uncertainty)")
        plt.colorbar(fraction=0.03, pad=0.02)
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(os.path.join(OUT_DIR, f"{base_name}_{tag}_entropy.png"), dpi=200)
        plt.close()


    fig, ax = plt.subplots(1, 5, figsize=(22, 5))
    ax[0].imshow(img01, cmap="gray")
    ax[0].set_title("Input (normalized)")

    ax[1].imshow(pred_unet, vmin=0, vmax=NUM_CLASSES-1, cmap=cmap)
    ax[1].set_title("UNet: argmax class")

    ax[2].imshow(ent_unet, cmap="magma")
    ax[2].set_title("UNet: entropy")

    ax[3].imshow(pred_res, vmin=0, vmax=NUM_CLASSES-1, cmap=cmap)
    ax[3].set_title("ResAttn: argmax class")

    ax[4].imshow(ent_res, cmap="magma")
    ax[4].set_title("ResAttn: entropy")

    for a in ax:
        a.axis("off")


    caption = (
        "Full-image stitched predictions from 512×512 patches (stride 400). "
        "Argmax maps show predicted class per pixel; entropy highlights uncertainty (higher=less confident). "
        "Colors correspond to the 5 classes shown in the legend/colorbar in the individual prediction figures."
    )
    fig.text(0.5, -0.02, caption, ha="center", va="top", fontsize=10)

    plt.tight_layout()
    out_path = os.path.join(OUT_DIR, f"{base_name}_COMPARISON_PANEL.png")
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close()
    print("[SAVED]", out_path)

def main():
    print("[INFO] Device:", DEVICE)
    print("[INFO] Test dir:", TEST_DIR)
    os.makedirs(OUT_DIR, exist_ok=True)

    test_raws = sorted(glob.glob(os.path.join(TEST_DIR, "*.tif*"))) + sorted(glob.glob(os.path.join(TEST_DIR, "*.tiff")))
    test_raws = sorted(list(set(test_raws)))
    if not test_raws:
        raise RuntimeError(f"No TIFF files found in {TEST_DIR}")

    unet = SimpleUNet(n_classes=NUM_CLASSES, base=64).to(DEVICE)
    res  = ResAttnUNetDS(n_classes=NUM_CLASSES, base=48).to(DEVICE)

    load_weights_strict(unet, CKPT_UNET)
    load_weights_strict(res,  CKPT_RES)

    for raw_p in test_raws:
        base = os.path.splitext(os.path.basename(raw_p))[0]
        print("\n[RUN] Full-image inference:", base)

        img = np.array(Image.open(raw_p))
        img = ensure_2d(img)
        img01 = robust_normalize(img)

        pred_u, ent_u = predict_full_image_logits(unet, img01)
        pred_r, ent_r = predict_full_image_logits(res,  img01)

        save_full_outputs(base, img01, pred_u, ent_u, pred_r, ent_r)

    print("\n[DONE] Outputs in:", OUT_DIR)

if __name__ == "__main__":
    main()


[INFO] Device: cuda
[INFO] Test dir: c:\Users\vonkl\Documents\453_Project\453_Project\project\test_data_tiff
[OK] Loaded: c:\Users\vonkl\Documents\453_Project\453_Project\best_unet.pth
  missing keys (sample): ['inc.net.0.gn.weight', 'inc.net.0.gn.bias', 'inc.net.1.gn.weight', 'inc.net.1.gn.bias', 'down1.1.net.0.gn.weight', 'down1.1.net.0.gn.bias', 'down1.1.net.1.gn.weight', 'down1.1.net.1.gn.bias'] ...
  unexpected keys (sample): ['inc.net.0.norm.weight', 'inc.net.0.norm.bias', 'inc.net.1.norm.weight', 'inc.net.1.norm.bias', 'down1.1.net.0.norm.weight', 'down1.1.net.0.norm.bias', 'down1.1.net.1.norm.weight', 'down1.1.net.1.norm.bias'] ...
[OK] Loaded: c:\Users\vonkl\Documents\453_Project\453_Project\best_resattn_unet_ds.pth
  missing keys (sample): ['stem.conv1.weight', 'stem.gn1.weight', 'stem.gn1.bias', 'stem.conv2.weight', 'stem.gn2.weight', 'stem.gn2.bias', 'd1.1.conv1.weight', 'd1.1.gn1.weight'] ...
  unexpected keys (sample): ['bott.c1.conv.weight', 'bott.c1.norm.weight', 'bott.

  with torch.cuda.amp.autocast():


[SAVED] c:\Users\vonkl\Documents\453_Project\453_Project\project\inference_outputs_fulltest\raw_13_COMPARISON_PANEL.png

[RUN] Full-image inference: raw_14


  with torch.cuda.amp.autocast():


[SAVED] c:\Users\vonkl\Documents\453_Project\453_Project\project\inference_outputs_fulltest\raw_14_COMPARISON_PANEL.png

[DONE] Outputs in: c:\Users\vonkl\Documents\453_Project\453_Project\project\inference_outputs_fulltest
