In [8]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt

# ==========================================
# 1. CONFIGURATION & TUNING
# ==========================================
PATH_UNET   = "project/best_unet_godmode_v2.pth"  # The "Detailed" model
PATH_RESNET = "best_resattn_unet_ds.pth"          # The "High Score" model

TEST_DATA_DIR = r"C:\Users\vonkl\Documents\453_Project\453_Project\project\test_data_tiff"
OUTPUT_DIR    = os.path.join(TEST_DATA_DIR, "outputs_fixed")

# --- TUNING KNOBS (The Fix) ---
# 1. Background Suppression: Lower value = More Details (Range: 0.1 to 1.0)
#    - 1.0 = Default (Conservative)
#    - 0.5 = Aggressive (Forces model to predict classes even if uncertain)
RESNET_BG_SCALE = 0.4  

# 2. Ensemble Weights: How much to trust each model?
#    - Trust ResNet for structure (0.6), UNet for detail (0.4)
WEIGHT_RESNET = 0.6
WEIGHT_UNET   = 0.4

# 3. TTA (Test Time Augmentation): Averages original + flipped images
#    - Increases runtime but significantly improves detail/smoothness
USE_TTA = True

# Standard Config
NUM_CLASSES = 5
TILE_SIZE   = 512
STRIDE      = 400
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"

# ==========================================
# 2. PRE-PROCESSING
# ==========================================
def robust_normalize(img):
    img = img.astype(np.float32)
    p2, p98 = np.percentile(img, (2, 98))
    if p98 - p2 < 1e-6: return np.zeros_like(img, dtype=np.float32)
    img = (img - p2) / (p98 - p2)
    return np.clip(img, 0.0, 1.0).astype(np.float32)

def calculate_entropy(probs):
    epsilon = 1e-6
    entropy = -torch.sum(probs * torch.log(probs + epsilon), dim=0)
    return entropy

# ==========================================
# 3. MODEL ARCHITECTURES
# ==========================================
# --- UNet ---
class DoubleConvGN_Start2(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class SimpleUNet(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.inc = DoubleConvGN_Start2(1, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConvGN_Start2(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConvGN_Start2(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConvGN_Start2(256, 512))
        self.up1 = nn.ConvTranspose2d(512, 256, 2, 2); self.conv1 = DoubleConvGN_Start2(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2); self.conv2 = DoubleConvGN_Start2(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, 2, 2); self.conv3 = DoubleConvGN_Start2(128, 64)
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1); x3 = self.down2(x2); x4 = self.down3(x3)
        x = self.up1(x4); x = torch.cat([x3, x], dim=1); x = self.conv1(x)
        x = self.up2(x); x = torch.cat([x2, x], dim=1); x = self.conv2(x)
        x = self.up3(x); x = torch.cat([x1, x], dim=1); x = self.conv3(x)
        return self.outc(x)

# --- ResAttnUNet ---
def gn_resu(ch, groups=8):
    groups = min(groups, ch)
    while groups > 1 and (ch % groups != 0): groups -= 1
    return nn.GroupNorm(groups, ch)

class ConvGNAct_Resu(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, padding=p, bias=False)
        self.gn = gn_resu(out_ch)
        self.act = nn.SiLU(inplace=True)
    def forward(self, x): return self.act(self.gn(self.conv(x)))

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.c1 = ConvGNAct_Resu(in_ch, out_ch)
        self.c2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.g2 = gn_resu(out_ch)
        self.act = nn.SiLU(inplace=True)
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1, bias=False)
    def forward(self, x):
        h = self.c1(x); h = self.g2(self.c2(h))
        return self.act(h + self.skip(x))

class AttnGate(nn.Module):
    def __init__(self, skip_ch, gate_ch, inter_ch):
        super().__init__()
        self.theta = nn.Conv2d(skip_ch, inter_ch, 1, bias=False); self.phi = nn.Conv2d(gate_ch, inter_ch, 1, bias=False)
        self.psi = nn.Conv2d(inter_ch, 1, 1, bias=True); self.act = nn.SiLU(inplace=True); self.sig = nn.Sigmoid()
    def forward(self, skip, gate):
        g = F.interpolate(gate, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        a = self.act(self.theta(skip) + self.phi(g))
        return skip * self.sig(self.psi(a))

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__(); self.pool = nn.MaxPool2d(2); self.block = ResBlock(in_ch, out_ch)
    def forward(self, x): return self.block(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.reduce = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.attn = AttnGate(skip_ch, out_ch, inter_ch=max(16, out_ch // 2))
        self.block = ResBlock(out_ch + skip_ch, out_ch)
    def forward(self, x, skip):
        x = self.up(x); x = self.reduce(x); skip = self.attn(skip, x)
        return self.block(torch.cat([skip, x], dim=1))

class ResAttnUNetDS(nn.Module):
    def __init__(self, n_classes=5, base=48):
        super().__init__()
        c1, c2, c3, c4, c5 = base, base*2, base*4, base*8, base*12
        self.stem = ResBlock(1, c1)
        self.d1 = Down(c1, c2); self.d2 = Down(c2, c3); self.d3 = Down(c3, c4); self.d4 = Down(c4, c5)
        self.bottleneck = ResBlock(c5, c5)
        self.u3 = Up(c5, c4, c4); self.u2 = Up(c4, c3, c3); self.u1 = Up(c3, c2, c2); self.u0 = Up(c2, c1, c1)
        self.head0 = nn.Conv2d(c1, n_classes, 1)
    def forward(self, x):
        s0 = self.stem(x); s1 = self.d1(s0); s2 = self.d2(s1); s3 = self.d3(s2); s4 = self.d4(s3)
        b = self.bottleneck(s4)
        x3 = self.u3(b, s3); x2 = self.u2(x3, s2); x1 = self.u1(x2, s1); x0 = self.u0(x1, s0)
        return self.head0(x0)

# ==========================================
# 4. ADVANCED INFERENCE (TTA + Sliding Window)
# ==========================================
def get_model_probs(model, tile_tensor, use_tta=False):
    """Returns softmax probabilities, optionally using TTA (HFlip/VFlip)"""
    # 1. Standard
    logits = model(tile_tensor)
    if isinstance(logits, tuple): logits = logits[0]
    probs = torch.softmax(logits, dim=1)
    
    if use_tta:
        # 2. Horizontal Flip
        logits_hf = model(torch.flip(tile_tensor, [3]))
        if isinstance(logits_hf, tuple): logits_hf = logits_hf[0]
        probs_hf = torch.softmax(logits_hf, dim=1)
        probs += torch.flip(probs_hf, [3])
        
        # 3. Vertical Flip
        logits_vf = model(torch.flip(tile_tensor, [2]))
        if isinstance(logits_vf, tuple): logits_vf = logits_vf[0]
        probs_vf = torch.softmax(logits_vf, dim=1)
        probs += torch.flip(probs_vf, [2])
        
        probs /= 3.0 # Average
        
    return probs

def run_inference(models, img_array):
    """
    Runs sliding window inference with Ensemble + TTA + Sensitivity Adjustment
    """
    h, w = img_array.shape
    # Accumulators for the ENSEMBLE
    ens_prob_sum = torch.zeros((NUM_CLASSES, h, w), dtype=torch.float32, device=DEVICE)
    count_map = torch.zeros((1, h, w), dtype=torch.float32, device=DEVICE)
    
    y_starts = sorted(list(set(list(range(0, h, STRIDE)) + [max(0, h - TILE_SIZE)])))
    x_starts = sorted(list(set(list(range(0, w, STRIDE)) + [max(0, w - TILE_SIZE)])))
    
    print(f"  > Processing {len(y_starts)*len(x_starts)} tiles (TTA={USE_TTA})...")
    
    with torch.no_grad():
        for y in y_starts:
            for x in x_starts:
                y_end, x_end = min(y + TILE_SIZE, h), min(x + TILE_SIZE, w)
                tile = img_array[y:y_end, x:x_end]
                th, tw = tile.shape
                
                # Reflect Pad
                pad_h, pad_w = TILE_SIZE - th, TILE_SIZE - tw
                if pad_h > 0 or pad_w > 0:
                    tile = np.pad(tile, ((0, pad_h), (0, pad_w)), mode='reflect')
                
                inp = torch.from_numpy(tile).unsqueeze(0).unsqueeze(0).to(DEVICE)
                
                # --- A. UNet Prediction ---
                probs_unet = get_model_probs(models['UNet'], inp, use_tta=USE_TTA)
                
                # --- B. ResNet Prediction (With Sensitivity Fix) ---
                probs_res = get_model_probs(models['ResNet'], inp, use_tta=USE_TTA)
                
                # FIX: Suppress ResNet Background
                # We multiply Class 0 prob by RESNET_BG_SCALE (e.g., 0.5)
                # Then re-normalize so they sum to 1.0
                probs_res[:, 0, :, :] *= RESNET_BG_SCALE
                probs_res = probs_res / probs_res.sum(dim=1, keepdim=True)
                
                # --- C. Weighted Ensemble ---
                # Combine: (0.4 * UNet) + (0.6 * Fixed_ResNet)
                probs_tile = (probs_unet * WEIGHT_UNET) + (probs_res * WEIGHT_RESNET)
                
                # Crop and Accumulate
                probs_tile = probs_tile[0, :, :th, :tw]
                ens_prob_sum[:, y:y_end, x:x_end] += probs_tile
                count_map[:, y:y_end, x:x_end] += 1.0

    avg_probs = ens_prob_sum / count_map
    pred_map = torch.argmax(avg_probs, dim=0).cpu().numpy().astype(np.uint8)
    ent_map = calculate_entropy(avg_probs).cpu().numpy()
    
    return pred_map, ent_map

# ==========================================
# 5. MAIN
# ==========================================
def main():
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    models = {}

    # Load Models
    print(f"Loading UNet: {PATH_UNET}")
    m1 = SimpleUNet(NUM_CLASSES).to(DEVICE)
    m1.load_state_dict(torch.load(PATH_UNET, map_location=DEVICE), strict=False)
    m1.eval()
    models['UNet'] = m1
    
    print(f"Loading ResNet: {PATH_RESNET}")
    m2 = ResAttnUNetDS(n_classes=NUM_CLASSES, base=48).to(DEVICE)
    m2.load_state_dict(torch.load(PATH_RESNET, map_location=DEVICE), strict=False)
    m2.eval()
    models['ResNet'] = m2

    # Find Images
    image_paths = sorted(glob.glob(os.path.join(TEST_DATA_DIR, "raw_*.tiff")))
    if not image_paths: image_paths = sorted(glob.glob(os.path.join(TEST_DATA_DIR, "raw_*.tif")))
    
    print(f"\nProcessing {len(image_paths)} images...")
    
    for img_path in image_paths:
        fname = os.path.splitext(os.path.basename(img_path))[0]
        print(f"\nTarget: {fname}")
        
        try:
            raw_img = np.array(Image.open(img_path))
            norm_img = robust_normalize(raw_img)
            
            # Run the "Fixed" Inference
            pred, ent = run_inference(models, norm_img)
            
            # Save
            np.save(os.path.join(OUTPUT_DIR, f"{fname}_fixed_pred.npy"), pred)
            np.save(os.path.join(OUTPUT_DIR, f"{fname}_fixed_entropy.npy"), ent)
            
            # Visualize
            fig, ax = plt.subplots(1, 3, figsize=(18, 6))
            ax[0].imshow(raw_img, cmap="gray"); ax[0].set_title("Input")
            
            # Plot Fixed Prediction
            cmap = plt.get_cmap("jet", NUM_CLASSES)
            ax[1].imshow(pred, cmap=cmap, vmin=0, vmax=NUM_CLASSES-1, interpolation='nearest')
            ax[1].set_title(f"Ensemble (ResScale={RESNET_BG_SCALE})")
            
            # Plot Entropy
            im = ax[2].imshow(ent, cmap="inferno"); ax[2].set_title("Uncertainty")
            plt.colorbar(im, ax=ax[2])
            
            for a in ax: a.axis("off")
            plt.savefig(os.path.join(OUTPUT_DIR, f"{fname}_fixed.png"), dpi=150)
            plt.close()
            print("  ✅ Saved fixed output.")
            
        except Exception as e:
            print(f"  ❌ Error: {e}")

if __name__ == "__main__":
    main()

Loading UNet: project/best_unet_godmode_v2.pth
Loading ResNet: best_resattn_unet_ds.pth

Processing 2 images...

Target: raw_13
  > Processing 210 tiles (TTA=True)...
  ✅ Saved fixed output.

Target: raw_14
  > Processing 20 tiles (TTA=True)...
  ✅ Saved fixed output.


In [1]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from tqdm import tqdm

# ==========================================
# 1. CONFIG
# ==========================================
# Path to the "Lazy" ResNet weights
LOAD_PATH = "best_resattn_unet_ds.pth"
SAVE_PATH = "best_resattn_finetuned.pth"
DATA_DIR  = "project/processed_tiles"

# Training Config
BATCH_SIZE = 4
LR = 1e-4  # Low LR to preserve features
EPOCHS = 5 # Fast fine-tune
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 5

# --- THE MAGIC FIX ---
# We down-weight Background (Class 0) to 0.1.
# This tells the model: "I barely care about background accuracy, 
# but missing a colored object is a HUGE error."
CLASS_WEIGHTS = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0]).to(DEVICE)

# ==========================================
# 2. MODEL (ResAttnUNetDS)
# ==========================================
def gn(ch, groups=8):
    groups = min(groups, ch)
    while groups > 1 and (ch % groups != 0): groups -= 1
    return nn.GroupNorm(groups, ch)

class ConvGNAct(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, padding=p, bias=False)
        self.gn = gn(out_ch); self.act = nn.SiLU(inplace=True)
    def forward(self, x): return self.act(self.gn(self.conv(x)))

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.c1 = ConvGNAct(in_ch, out_ch)
        self.c2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.g2 = gn(out_ch); self.act = nn.SiLU(inplace=True)
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1, bias=False)
    def forward(self, x): return self.act(self.g2(self.c2(self.c1(x))) + self.skip(x))

class AttnGate(nn.Module):
    def __init__(self, skip_ch, gate_ch, inter_ch):
        super().__init__()
        self.theta = nn.Conv2d(skip_ch, inter_ch, 1, bias=False)
        self.phi = nn.Conv2d(gate_ch, inter_ch, 1, bias=False)
        self.psi = nn.Conv2d(inter_ch, 1, 1, bias=True)
        self.act = nn.SiLU(inplace=True); self.sig = nn.Sigmoid()
    def forward(self, skip, gate):
        g = torch.nn.functional.interpolate(gate, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        return skip * self.sig(self.psi(self.act(self.theta(skip) + self.phi(g))))

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__(); self.pool = nn.MaxPool2d(2); self.block = ResBlock(in_ch, out_ch)
    def forward(self, x): return self.block(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.reduce = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.attn = AttnGate(skip_ch, out_ch, inter_ch=max(16, out_ch // 2))
        self.block = ResBlock(out_ch + skip_ch, out_ch)
    def forward(self, x, skip):
        x = self.up(x); x = self.reduce(x); skip = self.attn(skip, x)
        return self.block(torch.cat([skip, x], dim=1))

class ResAttnUNetDS(nn.Module):
    def __init__(self, n_classes=5, base=48):
        super().__init__()
        c1, c2, c3, c4, c5 = base, base*2, base*4, base*8, base*12
        self.stem = ResBlock(1, c1)
        self.d1 = Down(c1, c2); self.d2 = Down(c2, c3); self.d3 = Down(c3, c4); self.d4 = Down(c4, c5)
        self.bottleneck = ResBlock(c5, c5)
        self.u3 = Up(c5, c4, c4); self.u2 = Up(c4, c3, c3); self.u1 = Up(c3, c2, c2); self.u0 = Up(c2, c1, c1)
        self.head0 = nn.Conv2d(c1, n_classes, 1)
    def forward(self, x):
        s0 = self.stem(x); s1 = self.d1(s0); s2 = self.d2(s1); s3 = self.d3(s2); s4 = self.d4(s3)
        b = self.bottleneck(s4)
        x3 = self.u3(b, s3); x2 = self.u2(x3, s2); x1 = self.u1(x2, s1); x0 = self.u0(x1, s0)
        return self.head0(x0)

# ==========================================
# 3. DATASET
# ==========================================
class TiledNPYDataset(Dataset):
    def __init__(self, root_dir):
        self.img_paths = sorted(glob.glob(os.path.join(root_dir, "train", "images", "*.npy")))
        self.msk_paths = sorted(glob.glob(os.path.join(root_dir, "train", "masks", "*.npy")))
    def __len__(self): return len(self.img_paths)
    def __getitem__(self, idx):
        # Using mmap to keep it fast and low RAM
        img = np.load(self.img_paths[idx], mmap_mode="r").astype(np.float32)
        msk = np.load(self.msk_paths[idx], mmap_mode="r").astype(np.int64)
        img = torch.from_numpy(img).float().unsqueeze(0)
        msk = torch.from_numpy(msk).long()
        # Simple flip augmentation
        if np.random.rand() > 0.5: img = TF.hflip(img); msk = TF.hflip(msk)
        return img, msk

# ==========================================
# 4. FINE-TUNE LOOP
# ==========================================
def main():
    print(f"Loading Lazy ResNet from {LOAD_PATH}...")
    model = ResAttnUNetDS(n_classes=NUM_CLASSES, base=48).to(DEVICE)
    model.load_state_dict(torch.load(LOAD_PATH, map_location=DEVICE), strict=False)
    
    print("Preparing Data...")
    ds = TiledNPYDataset(DATA_DIR)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS)
    
    print(f"Starting Fine-Tune for {EPOCHS} epochs...")
    model.train()
    
    for epoch in range(EPOCHS):
        total_loss = 0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for img, lbl in pbar:
            img, lbl = img.to(DEVICE), lbl.to(DEVICE)
            
            optimizer.zero_grad()
            logits = model(img)
            
            # If model returns deep supervision tuple, take the first one
            if isinstance(logits, tuple): logits = logits[0]
                
            loss = criterion(logits, lbl)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
            
    print(f"Saving Fine-Tuned Model to {SAVE_PATH}...")
    torch.save(model.state_dict(), SAVE_PATH)
    print("✅ Done! You can now use this .pth file in the inference script.")

if __name__ == "__main__":
    main()

Loading Lazy ResNet from best_resattn_unet_ds.pth...
Preparing Data...
Starting Fine-Tune for 5 epochs...


Epoch 1/5: 100%|██████████| 557/557 [03:55<00:00,  2.36it/s, Loss=0.0819]
Epoch 2/5: 100%|██████████| 557/557 [03:44<00:00,  2.48it/s, Loss=0.0658]
Epoch 3/5: 100%|██████████| 557/557 [03:44<00:00,  2.48it/s, Loss=0.1672]
Epoch 4/5: 100%|██████████| 557/557 [03:49<00:00,  2.43it/s, Loss=0.4009]
Epoch 5/5: 100%|██████████| 557/557 [04:44<00:00,  1.96it/s, Loss=0.0616]

Saving Fine-Tuned Model to best_resattn_finetuned.pth...
✅ Done! You can now use this .pth file in the inference script.





In [None]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import cv2 

# ==========================================
# 1. OPTIMIZED CONFIGURATION
# ==========================================
# Use the fine-tuned model if available
MODEL_PATH = "best_resattn_finetuned.pth" 
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = "best_resattn_unet_ds.pth"

TEST_DATA_DIR = r"C:\Users\vonkl\Documents\453_Project\453_Project\project\test_data_tiff"
OUTPUT_DIR = os.path.join(TEST_DATA_DIR, "outputs_fast_boost")


STRIDE = 256 
TILE_SIZE = 512


ZOOM_LEVEL = 1.25 

BG_SUPPRESSION = 0.4

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 5


def gn(ch, groups=8):
    groups = min(groups, ch)
    while groups > 1 and (ch % groups != 0): groups -= 1
    return nn.GroupNorm(groups, ch)

class ConvGNAct(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, padding=p, bias=False)
        self.gn = gn(out_ch); self.act = nn.SiLU(inplace=True)
    def forward(self, x): return self.act(self.gn(self.conv(x)))

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.c1 = ConvGNAct(in_ch, out_ch)
        self.c2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.g2 = gn(out_ch); self.act = nn.SiLU(inplace=True)
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1, bias=False)
    def forward(self, x): return self.act(self.g2(self.c2(self.c1(x))) + self.skip(x))

class AttnGate(nn.Module):
    def __init__(self, skip_ch, gate_ch, inter_ch):
        super().__init__()
        self.theta = nn.Conv2d(skip_ch, inter_ch, 1, bias=False)
        self.phi = nn.Conv2d(gate_ch, inter_ch, 1, bias=False)
        self.psi = nn.Conv2d(inter_ch, 1, 1, bias=True)
        self.act = nn.SiLU(inplace=True); self.sig = nn.Sigmoid()
    def forward(self, skip, gate):
        g = F.interpolate(gate, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        return skip * self.sig(self.psi(self.act(self.theta(skip) + self.phi(g))))

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__(); self.pool = nn.MaxPool2d(2); self.block = ResBlock(in_ch, out_ch)
    def forward(self, x): return self.block(self.pool(x))

class Up(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.reduce = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.attn = AttnGate(skip_ch, out_ch, inter_ch=max(16, out_ch // 2))
        self.block = ResBlock(out_ch + skip_ch, out_ch)
    def forward(self, x, skip):
        x = self.up(x); x = self.reduce(x); skip = self.attn(skip, x)
        return self.block(torch.cat([skip, x], dim=1))

class ResAttnUNetDS(nn.Module):
    def __init__(self, n_classes=5, base=48):
        super().__init__()
        c1, c2, c3, c4, c5 = base, base*2, base*4, base*8, base*12
        self.stem = ResBlock(1, c1)
        self.d1 = Down(c1, c2); self.d2 = Down(c2, c3); self.d3 = Down(c3, c4); self.d4 = Down(c4, c5)
        self.bottleneck = ResBlock(c5, c5)
        self.u3 = Up(c5, c4, c4); self.u2 = Up(c4, c3, c3); self.u1 = Up(c3, c2, c2); self.u0 = Up(c2, c1, c1)
        self.head0 = nn.Conv2d(c1, n_classes, 1)
    def forward(self, x):
        s0 = self.stem(x); s1 = self.d1(s0); s2 = self.d2(s1); s3 = self.d3(s2); s4 = self.d4(s3)
        b = self.bottleneck(s4)
        x3 = self.u3(b, s3); x2 = self.u2(x3, s2); x1 = self.u1(x2, s1); x0 = self.u0(x1, s0)
        return self.head0(x0)


def robust_normalize(img):
    img = img.astype(np.float32)
    p2, p98 = np.percentile(img, (2, 98))
    if p98 - p2 < 1e-6: return np.zeros_like(img, dtype=np.float32)
    img = (img - p2) / (p98 - p2)
    return np.clip(img, 0.0, 1.0).astype(np.float32)

def calculate_entropy(probs):
    epsilon = 1e-6
    entropy = -torch.sum(probs * torch.log(probs + epsilon), dim=0)
    return entropy

def predict_on_array(model, img_arr, stride):
    """
    Standard sliding window inference on a specific numpy array
    """
    h, w = img_arr.shape
    prob_sum = torch.zeros((NUM_CLASSES, h, w), dtype=torch.float32, device=DEVICE)
    count_map = torch.zeros((1, h, w), dtype=torch.float32, device=DEVICE)
    
    y_starts = sorted(list(set(list(range(0, h, stride)) + [max(0, h - TILE_SIZE)])))
    x_starts = sorted(list(set(list(range(0, w, stride)) + [max(0, w - TILE_SIZE)])))

    with torch.no_grad():
        for y in y_starts:
            for x in x_starts:
                y_end, x_end = min(y + TILE_SIZE, h), min(x + TILE_SIZE, w)
                tile = img_arr[y:y_end, x:x_end]
                th, tw = tile.shape
                
                pad_h, pad_w = TILE_SIZE - th, TILE_SIZE - tw
                if pad_h > 0 or pad_w > 0:
                    tile = np.pad(tile, ((0, pad_h), (0, pad_w)), mode='reflect')

                inp = torch.from_numpy(tile).unsqueeze(0).unsqueeze(0).to(DEVICE)
                logits = model(inp)
                if isinstance(logits, tuple): logits = logits[0]
                probs = torch.softmax(logits, dim=1)
                
                probs = probs[0, :, :th, :tw]
                prob_sum[:, y:y_end, x:x_end] += probs
                count_map[:, y:y_end, x:x_end] += 1.0
                
    return prob_sum / count_map

def fast_boost_inference(model, raw_img):
    """
    Runs Single-Pass Zoomed Inference
    """
    orig_h, orig_w = raw_img.shape

    target_h, target_w = int(orig_h * ZOOM_LEVEL), int(orig_w * ZOOM_LEVEL)
    print(f"  > Processing Single Pass at {ZOOM_LEVEL}x Zoom ({target_w}x{target_h})...")
    
    scaled_img = cv2.resize(raw_img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    

    scaled_probs = predict_on_array(model, scaled_img, stride=STRIDE)
    

    scaled_probs_t = scaled_probs.unsqueeze(0) 
    
    final_probs = F.interpolate(
        scaled_probs_t, 
        size=(orig_h, orig_w), 
        mode='bilinear', 
        align_corners=False
    )[0] 

    if BG_SUPPRESSION < 1.0:
        final_probs[0, :, :] *= BG_SUPPRESSION
        final_probs = final_probs / final_probs.sum(dim=0, keepdim=True)
        
    return final_probs


def main():
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"Loading ResNet from {MODEL_PATH}...")
    model = ResAttnUNetDS(n_classes=NUM_CLASSES, base=48).to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE), strict=False)
    model.eval()

    image_paths = sorted(glob.glob(os.path.join(TEST_DATA_DIR, "raw_*.tiff")))
    if not image_paths: image_paths = sorted(glob.glob(os.path.join(TEST_DATA_DIR, "raw_*.tif")))
    
    for img_path in image_paths:
        fname = os.path.splitext(os.path.basename(img_path))[0]
        print(f"\nProcessing {fname}...")
        
        try:
            raw_img = np.array(Image.open(img_path))
            norm_img = robust_normalize(raw_img)

            probs = fast_boost_inference(model, norm_img)

            pred = torch.argmax(probs, dim=0).cpu().numpy().astype(np.uint8)
            ent = calculate_entropy(probs).cpu().numpy()
 
            np.save(os.path.join(OUTPUT_DIR, f"{fname}_boost_pred.npy"), pred)
            np.save(os.path.join(OUTPUT_DIR, f"{fname}_boost_entropy.npy"), ent)
            
 
            fig, ax = plt.subplots(1, 3, figsize=(18, 6))
            ax[0].imshow(raw_img, cmap="gray"); ax[0].set_title("Input")
            
            cmap = plt.get_cmap("jet", NUM_CLASSES)
            ax[1].imshow(pred, cmap=cmap, vmin=0, vmax=NUM_CLASSES-1, interpolation='nearest')
            ax[1].set_title(f"ResNet Fast Boost (Zoom={ZOOM_LEVEL}x) Prediction")
            
            im = ax[2].imshow(ent, cmap="inferno"); ax[2].set_title("ResNet Entropy")
            plt.colorbar(im, ax=ax[2])
            
            for a in ax: a.axis("off")
            plt.tight_layout()
            save_path = os.path.join(OUTPUT_DIR, f"{fname}ResNet_FastBoost.png")
            plt.savefig(save_path, dpi=150)
            plt.close()
            print(f"  ✅ Saved results to {save_path}")
            
        except Exception as e:
            print(f"  ❌ Error: {e}")

if __name__ == "__main__":
    main()

Loading ResNet from best_resattn_finetuned.pth...

Processing raw_13...
  > Processing Single Pass at 1.25x Zoom (6213x6901)...
  ✅ Saved results to C:\Users\vonkl\Documents\453_Project\453_Project\project\test_data_tiff\outputs_fast_boost\raw_13ResNet_FastBoost.png

Processing raw_14...
  > Processing Single Pass at 1.25x Zoom (1783x1453)...
  ✅ Saved results to C:\Users\vonkl\Documents\453_Project\453_Project\project\test_data_tiff\outputs_fast_boost\raw_14ResNet_FastBoost.png


In [None]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt


CHECKPOINT_PATH = "project/best_unet_godmode_v2.pth"

TEST_DATA_DIR = r"C:\Users\vonkl\Documents\453_Project\453_Project\project\test_data_tiff"
OUTPUT_DIR = os.path.join(TEST_DATA_DIR, "outputs_unet_only")


NUM_CLASSES = 5
TILE_SIZE = 512
STRIDE = 400
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class DoubleConvGN(nn.Module):

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class SimpleUNet(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.inc = DoubleConvGN(1, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConvGN(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConvGN(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConvGN(256, 512))
        
        self.up1 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.conv1 = DoubleConvGN(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.conv2 = DoubleConvGN(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.conv3 = DoubleConvGN(128, 64)
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        
        x = self.up1(x4)
        x = torch.cat([x3, x], dim=1)
        x = self.conv1(x)
        
        x = self.up2(x)
        x = torch.cat([x2, x], dim=1)
        x = self.conv2(x)
        
        x = self.up3(x)
        x = torch.cat([x1, x], dim=1)
        x = self.conv3(x)
        return self.outc(x)

def robust_normalize(img):
    """Normalizes based on 2nd and 98th percentiles."""
    img = img.astype(np.float32)
    p2, p98 = np.percentile(img, (2, 98))
    if p98 - p2 < 1e-6:
        return np.zeros_like(img, dtype=np.float32)
    img = (img - p2) / (p98 - p2)
    return np.clip(img, 0.0, 1.0).astype(np.float32)

def calculate_entropy(probs):
    """Entropy = - sum(p * log(p))"""
    epsilon = 1e-6
    entropy = -torch.sum(probs * torch.log(probs + epsilon), dim=0)
    return entropy

def sliding_window_inference(model, img_array, tile_size, stride, num_classes, device):
    """
    Runs sliding window inference on a large image.
    """
    h, w = img_array.shape
    prob_sum = torch.zeros((num_classes, h, w), dtype=torch.float32, device=device)
    count_map = torch.zeros((1, h, w), dtype=torch.float32, device=device)
    
   
    y_starts = sorted(list(set(list(range(0, h, stride)) + [max(0, h - tile_size)])))
    x_starts = sorted(list(set(list(range(0, w, stride)) + [max(0, w - tile_size)])))

    print(f"  > Processing {len(y_starts) * len(x_starts)} tiles...")

    model.eval()
    with torch.no_grad():
        for y in y_starts:
            for x in x_starts:
                y_end, x_end = min(y + tile_size, h), min(x + tile_size, w)
                tile = img_array[y:y_end, x:x_end]
                th, tw = tile.shape

                pad_h, pad_w = tile_size - th, tile_size - tw
                if pad_h > 0 or pad_w > 0:
                    tile = np.pad(tile, ((0, pad_h), (0, pad_w)), mode='reflect')


                input_tensor = torch.from_numpy(tile).unsqueeze(0).unsqueeze(0).to(device)
                logits = model(input_tensor)
                probs = torch.softmax(logits, dim=1)
       
                probs = probs[0, :, :th, :tw]
                prob_sum[:, y:y_end, x:x_end] += probs
                count_map[:, y:y_end, x:x_end] += 1.0


    avg_probs = prob_sum / count_map
    

    pred_map = torch.argmax(avg_probs, dim=0).cpu().numpy().astype(np.uint8)

    ent_map = calculate_entropy(avg_probs).cpu().numpy()
    
    return pred_map, ent_map


def main():
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
    
    print(f"Loading SimpleUNet from {CHECKPOINT_PATH}...")
    if not os.path.exists(CHECKPOINT_PATH):
        print(f"❌ Error: Model file not found at {CHECKPOINT_PATH}")
        return

    model = SimpleUNet(n_classes=NUM_CLASSES).to(DEVICE)
    try:
       
        model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE), strict=False)
        print("✅ Model loaded.")
    except Exception as e:
        print(f"❌ Failed to load weights: {e}")
        return

    
    image_paths = sorted(glob.glob(os.path.join(TEST_DATA_DIR, "raw_*.tiff")))
    if not image_paths:
        image_paths = sorted(glob.glob(os.path.join(TEST_DATA_DIR, "raw_*.tif")))
    
    if not image_paths:
        print("❌ No images found in test_data_tiff folder.")
        return

    for img_path in image_paths:
        fname = os.path.splitext(os.path.basename(img_path))[0]
        print(f"\nProcessing {fname}...")
        
        try:
            
            raw_img = np.array(Image.open(img_path))
            norm_img = robust_normalize(raw_img)
            
     
            pred_mask, entropy_map = sliding_window_inference(
                model, norm_img, TILE_SIZE, STRIDE, NUM_CLASSES, DEVICE
            )
            
          
            np.save(os.path.join(OUTPUT_DIR, f"{fname}_pred.npy"), pred_mask)
            np.save(os.path.join(OUTPUT_DIR, f"{fname}_entropy.npy"), entropy_map)
            
        
            fig, ax = plt.subplots(1, 3, figsize=(18, 6))
            
        
            ax[0].imshow(raw_img, cmap="gray")
            ax[0].set_title("Input Image")
            ax[0].axis("off")
            
   
            cmap = plt.get_cmap("jet", NUM_CLASSES)
            ax[1].imshow(pred_mask, cmap=cmap, vmin=0, vmax=NUM_CLASSES-1, interpolation='nearest')
            ax[1].set_title("UNet Prediction")
            ax[1].axis("off")
            
      
            im = ax[2].imshow(entropy_map, cmap="inferno")
            ax[2].set_title("Entropy (Uncertainty)")
            ax[2].axis("off")
            plt.colorbar(im, ax=ax[2], fraction=0.046, pad=0.04)
            
            save_vis = os.path.join(OUTPUT_DIR, f"{fname}_unet_result.png")
            plt.tight_layout()
            plt.savefig(save_vis, dpi=150)
            plt.close()
            print(f"✅ Saved results to {save_vis}")
            
        except Exception as e:
            print(f"❌ Error processing {fname}: {e}")

if __name__ == "__main__":
    main()

Loading SimpleUNet from project/best_unet_godmode_v2.pth...
✅ Model loaded.

Processing raw_13...
  > Processing 210 tiles...
