# In [None]:
# unet25d_k5_train.py
import os, glob, csv, math, random
import numpy as np
import tifffile
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# -------------------------
# CONFIG
# -------------------------
@dataclass
class CFG:
    """Settings for one 2.5D U-Net training run.

    Every field has a default; the module builds a single instance (``cfg``)
    that ``main()`` reads.
    """

    # data
    # Directory that main() searches for "patch_*.npz" volumes.
    PATCH_DIR: str = "Nov23_Crops/patches_256_s128_excl_crop2buf64"
    # Name of the npz array used as the label volume (binarised with "> 0").
    # NOTE(review): the inline list says "or" while the value is "or_" --
    # confirm the exact key names stored in the npz files.
    LABEL_KEY: str = "or_"  # in npz: rb, ten, and, or

    # crop2 eval
    # Raw stack and label stack handed to eval_on_crop2() after every epoch.
    CROP2_RAW: str = "Nov23_Crops/90deg_myelin_curve_CCP_crop2.tif"
    CROP2_GT:  str = "Nov23_Crops/label_crop2_new.tif"

    # global normalization (from full 2960x2960 stack)
    # Intensities are mapped linearly so GLOBAL_MIN -> 0 and GLOBAL_MAX -> 1,
    # then clipped to [0, 1].
    GLOBAL_MIN: float = 1238.0
    GLOBAL_MAX: float = 65535.0

    # 2.5D
    # K is both the model's input-channel count and the z-window size;
    # the window reaches K // 2 slices to each side of the centre slice.
    K: int = 5
    DROP_BOUNDARIES: bool = True  # drop PAD slices at each end
    THR_EVAL: float = 0.7         # threshold for reporting crop2 Dice during training

    # training
    EPOCHS: int = 10
    BATCH: int = 8
    LR: float = 3e-4            # AdamW learning rate
    WEIGHT_DECAY: float = 1e-4  # AdamW weight decay
    CLIP_NORM: float = 1.0      # max gradient norm passed to clip_grad_norm_

    # split
    # Fraction of patch volumes (not slices) held out for validation, and the
    # seed used for both the split permutation and seed_everything().
    VAL_FRAC: float = 0.20
    SEED: int = 42

    # augmentation
    AUGMENT: bool = True  # flips + 90deg rotations

    # perf
    # Forwarded unchanged to both DataLoaders.
    NUM_WORKERS: int = 0
    PIN_MEMORY: bool = False

    # model
    BASE_CH: int = 64                 # channel width of the first encoder stage
    GN_GROUPS: int = 8                # requested GroupNorm group count
    DROPOUT_BOTTLENECK: float = 0.10  # Dropout2d probability after the bottleneck

    # output
    # Directory receiving log.csv, last.pt and best.pt.
    RUN_DIR: str = "Nov23_Crops/unet25d_runs/unet25d_k5_or_gn_drop_dice_bce_lr3e-4_clip1"

# Single module-wide configuration instance (all defaults).
cfg = CFG()
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Small constant added to metric/loss denominators to avoid division by zero.
EPS = 1e-8


# -------------------------
# REPRODUCIBILITY
# -------------------------
def seed_everything(seed: int):
    """Seed every random number generator this script draws from.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    and PyTorch (``manual_seed`` plus ``cuda.manual_seed_all``).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all RNGs at import time, before any dataset or model is created.
seed_everything(cfg.SEED)


# -------------------------
# SMALL HELPERS
# -------------------------
def global_norm_uint16_to_01(x: np.ndarray, mn: float, mx: float) -> np.ndarray:
    """Linearly rescale intensities so ``mn`` maps to 0 and ``mx`` maps to 1.

    The input is converted to float32 first and the result is clipped to
    [0, 1]. A tiny constant in the denominator keeps ``mx == mn`` from
    dividing by zero.
    """
    span = mx - mn + 1e-12
    shifted = x.astype(np.float32) - mn
    return np.clip(shifted / span, 0.0, 1.0)

@torch.no_grad()
def dice_prec_rec_from_logits(logits: torch.Tensor, gt01: torch.Tensor, thr: float) -> Tuple[float, float, float]:
    """Return (Dice, precision, recall) pooled over every element of the batch.

    ``logits`` are passed through a sigmoid and thresholded with ``> thr``;
    ``gt01`` is binarised with ``> 0.5``. Both are documented by the callers
    as ``[N,1,H,W]``. ``EPS`` in each denominator makes an all-empty case
    evaluate to 0 instead of dividing by zero.
    """
    predicted = torch.sigmoid(logits) > thr
    actual = gt01 > 0.5

    # Confusion counts over the whole tensor (true negatives are not needed).
    true_pos = torch.logical_and(predicted, actual).sum().item()
    false_pos = torch.logical_and(predicted, torch.logical_not(actual)).sum().item()
    false_neg = torch.logical_and(torch.logical_not(predicted), actual).sum().item()

    dice = float((2 * true_pos) / (2 * true_pos + false_pos + false_neg + EPS))
    precision = float(true_pos / (true_pos + false_pos + EPS))
    recall = float(true_pos / (true_pos + false_neg + EPS))
    return dice, precision, recall

def soft_dice_loss_with_logits(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Soft Dice loss with a squared denominator, averaged over the batch.

    For each sample the score is ``(2*sum(p*t) + EPS) / (sum(p*p) + sum(t*t) + EPS)``
    with ``p = sigmoid(logits)``; sums run over dims 1-3, so both inputs must
    be 4-D (``[N,1,H,W]`` per the callers). Returns ``1 - mean(score)``.
    """
    per_sample_dims = (1, 2, 3)
    probs = torch.sigmoid(logits)

    overlap = (probs * targets).sum(dim=per_sample_dims)
    energy = (probs * probs).sum(dim=per_sample_dims) + (targets * targets).sum(dim=per_sample_dims)

    score = (2.0 * overlap + EPS) / (energy + EPS)
    return 1.0 - score.mean()

def ensure_dir(p: str):
    """Create directory ``p`` (and missing parents); do nothing if it exists."""
    os.makedirs(name=p, exist_ok=True)


# -------------------------
# NPZ PATCH DATASET (2.5D)
# -------------------------
class TinyLRUCache:
    """Least-recently-used cache holding at most ``max_items`` entries.

    Used to keep a handful of loaded npz dicts in memory without letting the
    footprint grow with the number of files.
    """

    def __init__(self, max_items: int = 8):
        self.max_items = max_items
        self.cache = OrderedDict()  # oldest entry first, newest last

    def get(self, key: str):
        """Return the value for ``key`` and mark it most recently used.

        Returns ``None`` when ``key`` is not cached.
        """
        try:
            self.cache.move_to_end(key)
        except KeyError:
            return None
        return self.cache[key]

    def put(self, key: str, val):
        """Insert or overwrite ``key`` as the newest entry, evicting the oldest
        entries until the size limit holds again."""
        store = self.cache
        store[key] = val
        store.move_to_end(key)
        while len(store) > self.max_items:
            store.popitem(last=False)

class Patch25DDataset(Dataset):
    """2.5D slice dataset built from npz patch volumes.

    Each npz file must contain a ``"raw"`` array and a label array stored
    under ``label_key``; both are indexed as ``[Z, H, W]``. One sample is the
    stack of raw slices around a centre slice ``z`` (``k // 2`` slices on each
    side) plus the binarised label of that centre slice.

    Args:
        files: Paths of the npz patch volumes. May be empty, which yields an
            empty dataset.
        label_key: Name of the label array inside each npz file.
        k: Nominal number of input slices. The window always spans
            ``2 * (k // 2) + 1`` slices, so it equals ``k`` only for odd ``k``.
        drop_boundaries: If true, only centres with a complete window are
            used. If false, every slice is a centre and windows that would
            leave the volume repeat the first/last slice (edge clamping).
        global_min, global_max: Intensity range mapped to [0, 1].
        augment: Apply random 90-degree rotations and flips (driven by
            Python's ``random`` module) to each sample.

    Z is read from the first file only and is assumed to be the same for all
    files.
    """

    def __init__(self, files: List[str], label_key: str, k: int, drop_boundaries: bool,
                 global_min: float, global_max: float, augment: bool):
        self.files = files
        self.label_key = label_key
        self.k = k
        self.pad = k // 2
        self.drop_boundaries = drop_boundaries
        self.global_min = global_min
        self.global_max = global_max
        self.augment = augment

        # Inspect the first volume once to build the index. The npz handle is
        # closed right away; an empty file list gives an empty dataset instead
        # of an IndexError.
        if self.files:
            with np.load(self.files[0]) as d0:
                probe_shape = d0["raw"].shape
            self.Z = int(probe_shape[0])
            self.H = int(probe_shape[1])
            self.W = int(probe_shape[2])
        else:
            self.Z = self.H = self.W = 0

        # allowed centers
        if self.drop_boundaries:
            self.centers = list(range(self.pad, self.Z - self.pad))
        else:
            self.centers = list(range(0, self.Z))
        self.centers_per_patch = len(self.centers)

        # flat index: (patch_idx, center_idx)
        self.index = [(pi, ci) for pi in range(len(self.files)) for ci in range(self.centers_per_patch)]

        self.cache = TinyLRUCache(max_items=8)

    def __len__(self):
        return len(self.index)

    @staticmethod
    def _window_indices(zc: int, pad: int, depth: int) -> np.ndarray:
        """Return the ``2 * pad + 1`` slice indices centred on ``zc``, clamped
        to ``[0, depth - 1]`` so out-of-volume neighbours repeat the edge slice."""
        return np.clip(np.arange(zc - pad, zc + pad + 1), 0, depth - 1)

    def _load_patch(self, path: str) -> Dict[str, np.ndarray]:
        """Load (or fetch from the LRU cache) the raw and boolean label volumes."""
        cached = self.cache.get(path)
        if cached is not None:
            return cached
        # Indexing an NpzFile reads the array into memory, so the file can be
        # closed as soon as both arrays have been pulled out.
        with np.load(path) as d:
            out = {
                "raw": d["raw"],                 # [Z,H,W]
                "lbl": (d[self.label_key] > 0),  # [Z,H,W] bool
            }
        self.cache.put(path, out)
        return out

    def _augment_xy(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the same random rotation/flips to the input stack and label."""
        # x: [K,H,W], y: [H,W]
        # random 0/90/180/270 rotation
        krot = random.randint(0, 3)
        if krot:
            x = torch.rot90(x, k=krot, dims=(-2, -1))
            y = torch.rot90(y, k=krot, dims=(-2, -1))
        # random flips
        if random.random() < 0.5:
            x = torch.flip(x, dims=(-1,))
            y = torch.flip(y, dims=(-1,))
        if random.random() < 0.5:
            x = torch.flip(x, dims=(-2,))
            y = torch.flip(y, dims=(-2,))
        return x, y

    def __getitem__(self, idx: int):
        """Return ``(x, y)``: x is float32 ``[K,H,W]`` in [0, 1], y is float32 ``[1,H,W]``."""
        pi, ci = self.index[idx]
        path = self.files[pi]
        d = self._load_patch(path)

        raw = d["raw"]  # [Z,H,W]
        lbl = d["lbl"]  # [Z,H,W] bool

        zc = self.centers[ci]
        depth = raw.shape[0]
        lo, hi = zc - self.pad, zc + self.pad + 1
        if lo >= 0 and hi <= depth:
            # Complete window: a plain slice (this also covers k == 1).
            xk = raw[lo:hi]  # [K,H,W]
        else:
            # Boundary centre (only reachable with drop_boundaries=False):
            # a negative slice start would wrap around and an overlong stop
            # would truncate, so gather clamped indices instead.
            xk = raw[self._window_indices(zc, self.pad, depth)]  # [K,H,W]
        y = lbl[zc].astype(np.float32)  # [H,W]

        xk = global_norm_uint16_to_01(xk, self.global_min, self.global_max)  # float32 [K,H,W]
        x = torch.from_numpy(xk)  # [K,H,W]
        y = torch.from_numpy(y)   # [H,W]

        if self.augment:
            x, y = self._augment_xy(x, y)

        # output format
        # x: [K,H,W] float32, y: [1,H,W] float32
        return x, y.unsqueeze(0)


# -------------------------
# MODEL
# -------------------------
def gn(num_channels: int, groups: int) -> nn.GroupNorm:
    """Build a GroupNorm layer whose group count divides ``num_channels``.

    Starts from ``min(groups, num_channels)`` and steps down until the count
    divides the channel count (1 always does).
    """
    n_groups = groups if groups < num_channels else num_channels
    while num_channels % n_groups != 0 and n_groups > 1:
        n_groups = n_groups - 1
    return nn.GroupNorm(num_groups=n_groups, num_channels=num_channels)

class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> GroupNorm -> ReLU stages.

    Padding of 1 keeps the spatial size; the convolutions carry no bias
    because each is followed directly by a normalisation layer.
    """

    def __init__(self, in_ch: int, out_ch: int, groups: int):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_ch, out_ch, **conv_kwargs)
        self.gn1 = gn(out_ch, groups)
        self.conv2 = nn.Conv2d(out_ch, out_ch, **conv_kwargs)
        self.gn2 = gn(out_ch, groups)

    def forward(self, x):
        stages = ((self.conv1, self.gn1), (self.conv2, self.gn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)), inplace=True)
        return x

class UNet2D(nn.Module):
    """Four-level 2D U-Net that outputs one channel of raw logits.

    Channel widths double at every encoder level (``base_ch`` up to
    ``16 * base_ch`` in the bottleneck). Each decoder level upsamples with a
    stride-2 transposed convolution and concatenates the matching encoder
    output. ``Dropout2d`` is applied to the bottleneck output only. Because of
    the four 2x poolings, H and W should be multiples of 16 for the skip
    concatenations to line up.
    """

    def __init__(self, in_ch: int, base_ch: int = 64, groups: int = 8, dropout_bottleneck: float = 0.10):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_ch * mult for mult in (1, 2, 4, 8, 16))

        # Encoder: conv block followed by 2x max-pooling at each level.
        self.enc1 = ConvBlock(in_ch, c1, groups)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2 = ConvBlock(c1, c2, groups)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3 = ConvBlock(c2, c3, groups)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4 = ConvBlock(c3, c4, groups)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = ConvBlock(c4, c5, groups)
        self.drop = nn.Dropout2d(p=dropout_bottleneck)

        # Decoder: each conv block sees upsampled features + skip (2x channels).
        self.up4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(2 * c4, c4, groups)

        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(2 * c3, c3, groups)

        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(2 * c2, c2, groups)

        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(2 * c1, c1, groups)

        self.out = nn.Conv2d(c1, 1, kernel_size=1)

    def forward(self, x):
        # Contracting path; keep every level's output for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        skip4 = self.enc4(self.pool3(skip3))

        feat = self.drop(self.bottleneck(self.pool4(skip4)))

        # Expanding path, deepest level first.
        decoder_levels = (
            (self.up4, self.dec4, skip4),
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, decode, skip in decoder_levels:
            feat = decode(torch.cat([upsample(feat), skip], dim=1))

        return self.out(feat)  # logits


# -------------------------
# EVAL ON CROP2 MANUAL GT
# -------------------------
@torch.no_grad()
def eval_on_crop2(model: nn.Module, thr: float, k: int, drop_boundaries: bool,
                  crop2_raw_path: str, crop2_gt_path: str,
                  global_min: float, global_max: float,
                  batch: int = 8) -> Tuple[float, float, float]:
    """Score ``model`` on the crop2 stack against its manual ground truth.

    Reads both TIFF stacks, normalises the raw stack with the global min/max,
    builds one 2.5D window (``k // 2`` slices on each side) per centre slice,
    runs the model in mini-batches on ``DEVICE`` and pools all predictions
    into a single Dice / precision / recall triple at threshold ``thr``.

    The model is switched to eval mode for the forward passes and its previous
    train/eval mode is restored afterwards.

    Args:
        model: Network mapping ``[N,K,H,W]`` to ``[N,1,H,W]`` logits.
        thr: Probability threshold applied after the sigmoid.
        k: Nominal window size; the window spans ``2 * (k // 2) + 1`` slices.
        drop_boundaries: If true, only centres with a complete window are
            scored. If false, every slice is scored and windows that would
            leave the stack repeat the first/last slice.
        crop2_raw_path, crop2_gt_path: TIFF stacks of identical shape.
        global_min, global_max: Intensity range mapped to [0, 1].
        batch: Number of windows per forward pass.

    Returns:
        ``(dice, precision, recall)`` as floats.

    Raises:
        ValueError: If the stacks differ in shape, are not 3-D, or leave no
            centre slice to evaluate.
    """
    raw = tifffile.imread(crop2_raw_path)  # [Z,H,W]
    gt = tifffile.imread(crop2_gt_path)    # [Z,H,W]
    # Explicit raise rather than assert: asserts vanish under ``python -O``.
    if raw.shape != gt.shape:
        raise ValueError(f"crop2 raw {raw.shape} != gt {gt.shape}")
    if raw.ndim != 3:
        raise ValueError(f"crop2 stack must be 3-D [Z,H,W], got shape {raw.shape}")
    Z = raw.shape[0]

    pad = k // 2
    x = global_norm_uint16_to_01(raw, global_min, global_max)  # float32 [Z,H,W]
    y = (gt > 0).astype(np.float32)                            # [Z,H,W]

    z_centers = np.arange(pad, Z - pad) if drop_boundaries else np.arange(Z)
    N = len(z_centers)
    if N == 0:
        raise ValueError(
            f"crop2 stack with Z={Z} has no centre slice for k={k} "
            f"(drop_boundaries={drop_boundaries})"
        )

    # Window indices per centre, clamped so boundary centres repeat the edge
    # slice; with drop_boundaries=True nothing is clamped.
    offsets = np.arange(-pad, pad + 1)
    window_idx = np.clip(z_centers[:, None] + offsets[None, :], 0, Z - 1)  # [N,K]
    x_in = x[window_idx]  # [N,K,H,W]
    y_in = y[z_centers]   # [N,H,W]

    xt = torch.from_numpy(x_in).to(DEVICE)               # [N,K,H,W]
    yt = torch.from_numpy(y_in).unsqueeze(1).to(DEVICE)  # [N,1,H,W]

    # Dropout must be inactive while scoring, whatever mode the caller left
    # the model in; restore that mode afterwards.
    was_training = model.training
    model.eval()
    try:
        logits = torch.cat([model(xt[i:i + batch]) for i in range(0, N, batch)], dim=0)
    finally:
        model.train(was_training)

    return dice_prec_rec_from_logits(logits, yt, thr=thr)


# -------------------------
# TRAIN LOOP
# -------------------------
def _split_train_val(patch_files: List[str], val_frac: float, seed: int) -> Tuple[List[str], List[str]]:
    """Split whole patch volumes (never single slices) into train and val lists."""
    rng = np.random.RandomState(seed)
    perm = rng.permutation(len(patch_files))
    n_val = int(round(val_frac * len(patch_files)))
    val_ids = set(perm[:n_val].tolist())
    tr_files = [f for i, f in enumerate(patch_files) if i not in val_ids]
    va_files = [f for i, f in enumerate(patch_files) if i in val_ids]
    return tr_files, va_files


def _train_one_epoch(model: nn.Module, loader, opt, bce: nn.Module, clip_norm: float, desc: str) -> float:
    """Run one optimisation pass over ``loader``; return the mean BCE + Dice loss."""
    model.train()
    losses: List[float] = []
    running = 0.0  # running sum so the progress bar does not re-average the list
    pbar = tqdm(loader, desc=desc, leave=False)
    for xb, yb in pbar:
        xb = xb.to(DEVICE, non_blocking=True)  # [B,K,H,W]
        yb = yb.to(DEVICE, non_blocking=True)  # [B,1,H,W]

        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = bce(logits, yb) + soft_dice_loss_with_logits(logits, yb)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        opt.step()

        step_loss = float(loss.item())
        losses.append(step_loss)
        running += step_loss
        pbar.set_postfix(loss=running / len(losses))

    return float(np.mean(losses)) if losses else float("nan")


def _validation_loss(model: nn.Module, loader, bce: nn.Module) -> float:
    """Mean BCE + Dice loss over ``loader`` in eval mode; NaN if it yields no batch."""
    model.eval()
    losses: List[float] = []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE, non_blocking=True)
            yb = yb.to(DEVICE, non_blocking=True)
            logits = model(xb)
            loss = bce(logits, yb) + soft_dice_loss_with_logits(logits, yb)
            losses.append(float(loss.item()))
    return float(np.mean(losses)) if losses else float("nan")


def main():
    """Train the 2.5D U-Net described by ``cfg`` and write artifacts to ``cfg.RUN_DIR``.

    Steps: find the patch volumes, split them into train/val by volume, train
    for ``cfg.EPOCHS`` epochs with BCE + soft Dice, and after every epoch
    compute the validation loss and the crop2 Dice/precision/recall. Writes
    ``log.csv`` (one row per epoch), ``last.pt`` (every epoch) and ``best.pt``
    (highest crop2 Dice so far).

    Raises:
        FileNotFoundError: If no patch file or one of the crop2 stacks is missing.
        ValueError: If the split leaves no training volume.
    """
    ensure_dir(cfg.RUN_DIR)

    patch_files = sorted(glob.glob(os.path.join(cfg.PATCH_DIR, "patch_*.npz")))
    if len(patch_files) == 0:
        raise FileNotFoundError(f"No patch_*.npz found in {cfg.PATCH_DIR}")
    print(f"Found {len(patch_files)} patch volumes")

    # Fail before training starts rather than after the first full epoch.
    for crop2_path in (cfg.CROP2_RAW, cfg.CROP2_GT):
        if not os.path.isfile(crop2_path):
            raise FileNotFoundError(f"crop2 evaluation file not found: {crop2_path}")

    # split on patch volumes (not per-slice), to avoid leakage
    tr_files, va_files = _split_train_val(patch_files, cfg.VAL_FRAC, cfg.SEED)
    if not tr_files:
        raise ValueError(f"VAL_FRAC={cfg.VAL_FRAC} leaves no training patches out of {len(patch_files)}")
    print(f"Train patches: {len(tr_files)} | Val patches: {len(va_files)}")
    print(f"Label key: {cfg.LABEL_KEY}")
    print(f"Device: {DEVICE}")

    pad = cfg.K // 2
    print(f"2.5D K={cfg.K} (PAD={pad}), drop boundaries: {cfg.DROP_BOUNDARIES}")
    # dataset info (close the npz handle as soon as the shape is known)
    with np.load(patch_files[0]) as ds_probe:
        Z = ds_probe["raw"].shape[0]
    dropped = 2 * pad if cfg.DROP_BOUNDARIES else 0
    centers = Z - dropped
    print(f"Detected Z slices per patch: {Z}")
    print(f"Training centers per patch: {centers} (dropping {dropped} boundary slices)")

    ds_tr = Patch25DDataset(
        tr_files, label_key=cfg.LABEL_KEY, k=cfg.K, drop_boundaries=cfg.DROP_BOUNDARIES,
        global_min=cfg.GLOBAL_MIN, global_max=cfg.GLOBAL_MAX, augment=cfg.AUGMENT
    )
    # With very few volumes the val split can round to zero files; skip the
    # val dataset in that case so val_loss is reported as NaN.
    ds_va = None
    if va_files:
        ds_va = Patch25DDataset(
            va_files, label_key=cfg.LABEL_KEY, k=cfg.K, drop_boundaries=cfg.DROP_BOUNDARIES,
            global_min=cfg.GLOBAL_MIN, global_max=cfg.GLOBAL_MAX, augment=False
        )
    else:
        print("Warning: validation split is empty; val_loss will be NaN")

    # sanity check
    x0, y0 = ds_tr[0]
    print(f"Sanity ds_tr[0] x: {float(x0.min()):.6f} {float(x0.max()):.6f} mean: {float(x0.mean()):.6f} shape: {tuple(x0.shape)}")
    print(f"Sanity ds_tr[0] y unique: {torch.unique(y0)} y mean: {float(y0.mean()):.6f}")

    dl_tr = DataLoader(ds_tr, batch_size=cfg.BATCH, shuffle=True,
                       num_workers=cfg.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY, drop_last=True)
    if ds_va is not None:
        dl_va = DataLoader(ds_va, batch_size=cfg.BATCH, shuffle=False,
                           num_workers=cfg.NUM_WORKERS, pin_memory=cfg.PIN_MEMORY, drop_last=False)
    else:
        dl_va = []  # iterating yields no batches

    model = UNet2D(in_ch=cfg.K, base_ch=cfg.BASE_CH, groups=cfg.GN_GROUPS, dropout_bottleneck=cfg.DROPOUT_BOTTLENECK).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.LR, weight_decay=cfg.WEIGHT_DECAY)
    bce = nn.BCEWithLogitsLoss()

    # logging (header row; one row is appended per epoch)
    log_path = os.path.join(cfg.RUN_DIR, "log.csv")
    with open(log_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["epoch", "train_loss", "val_loss", "crop2_dice_thr", "crop2_precision_thr", "crop2_recall_thr"])

    best_dice = -1.0

    for epoch in range(1, cfg.EPOCHS + 1):
        train_loss = _train_one_epoch(
            model, dl_tr, opt, bce, cfg.CLIP_NORM,
            desc=f"Epoch {epoch}/{cfg.EPOCHS} [train]",
        )

        # val loss on held-out patches (leaves the model in eval mode)
        val_loss = _validation_loss(model, dl_va, bce)

        # eval on crop2 manual GT; boundary centres are always dropped here,
        # independent of cfg.DROP_BOUNDARIES, so every window is complete
        crop2_d, crop2_p, crop2_r = eval_on_crop2(
            model, thr=cfg.THR_EVAL, k=cfg.K, drop_boundaries=True,
            crop2_raw_path=cfg.CROP2_RAW, crop2_gt_path=cfg.CROP2_GT,
            global_min=cfg.GLOBAL_MIN, global_max=cfg.GLOBAL_MAX,
            batch=cfg.BATCH
        )

        # print epoch summary
        print(f"Epoch {epoch:03d}/{cfg.EPOCHS} | train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
              f"| crop2 Dice@{cfg.THR_EVAL:.2f}={crop2_d:.4f} P={crop2_p:.4f} R={crop2_r:.4f}")

        # log
        with open(log_path, "a", newline="") as f:
            w = csv.writer(f)
            w.writerow([epoch, train_loss, val_loss, crop2_d, crop2_p, crop2_r])

        # save last
        ckpt = {
            "model_state": model.state_dict(),
            "opt_state": opt.state_dict(),
            "epoch": epoch,
            "cfg": cfg.__dict__,
        }
        torch.save(ckpt, os.path.join(cfg.RUN_DIR, "last.pt"))

        # save best by crop2 Dice@thr
        if crop2_d > best_dice:
            best_dice = crop2_d
            torch.save(ckpt, os.path.join(cfg.RUN_DIR, "best.pt"))

    print(f"\nDone. Best crop2 Dice@{cfg.THR_EVAL:.2f}: {best_dice:.4f}")
    print(f"Saved in: {cfg.RUN_DIR}")
    print("Artifacts: best.pt, last.pt, log.csv")


# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
