In [None]:
import math
import os
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import tifffile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# -------------------------
# CONFIG
# -------------------------
@dataclass
class FTConfig:
    """Settings for fine-tuning the stage-1 2.5D U-Net on the manually labelled crop2.

    Every field has a default, so ``FTConfig()`` reproduces this script's run;
    the field order defines the dataclass-generated ``__init__`` signature.
    """
    # stage-1 checkpoint (K=5, OR pseudo-labels); its weights initialise fine-tuning
    CKPT_PATH: str = "Nov23_Crops/unet25d_runs/unet25d_k5_or_gn_drop_dice_bce_lr3e-4_clip1/best.pt"

    # crop2 raw volume and its manual ground-truth labels (TIFF stacks read with tifffile)
    CROP2_RAW: str = "Nov23_Crops/90deg_myelin_curve_CCP_crop2.tif"
    CROP2_GT:  str = "Nov23_Crops/label_crop2_new.tif"

    # global normalization (same as before): the crop2 dataset maps
    # raw -> (raw - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN) and clips to [0, 1]
    GLOBAL_MIN: float = 1238.0
    GLOBAL_MAX: float = 65535.0

    # 2.5D: K adjacent z-slices are stacked as the model's input channels
    K: int = 5
    PAD: int = 2   # K=5 -> pad=2; must equal K // 2 (it is not derived from K automatically)

    # training hyperparams
    EPOCHS: int = 20     # fixed epoch count; no automatic early stopping (the best val-AUPRC epoch is checkpointed)
    BATCH: int = 4
    LR: float = 1e-4     # smaller LR for fine-tuning
    WEIGHT_DECAY: float = 1e-5
    CLIP_NORM: float = 1.0   # max global gradient norm used for clipping

    # device; evaluated once, when this class body executes
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # output directory for the CSV log and the best/last checkpoints
    RUN_DIR: str = "Nov23_Crops/unet25d_runs/unet25d_k5_or_finetune_crop2"

cfg = FTConfig()  # module-level config read by the dataset, model setup and training loop below
os.makedirs(cfg.RUN_DIR, exist_ok=True)  # create the output dir up front; no-op if it already exists
EPS = 1e-8  # NOTE(review): not referenced in the visible code (metrics/losses use literal 1e-8 values)

# -------------------------
# REPRO
# -------------------------
def seed_everything(seed: int = 42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

seed_everything(42)

# -------------------------
# MODEL (same as stage 1)
# -------------------------
def gn(num_channels: int, groups: int) -> nn.GroupNorm:
    """Build a GroupNorm layer for ``num_channels`` with at most ``groups`` groups.

    GroupNorm requires the group count to divide the channel count, so the
    largest divisor of ``num_channels`` that is <= ``min(groups, num_channels)``
    is used (falling back to a single group).
    """
    n_groups = min(groups, num_channels)
    # Step down until n_groups divides num_channels evenly (1 always does).
    while num_channels % n_groups and n_groups > 1:
        n_groups -= 1
    return nn.GroupNorm(n_groups, num_channels)

class ConvBlock(nn.Module):
    """Two stacked (3x3 conv -> GroupNorm -> ReLU) stages.

    Convolutions use padding=1 and stride 1, so height and width are
    preserved; only the channel count changes from ``in_ch`` to ``out_ch``.
    The attribute names conv1/gn1/conv2/gn2 form the checkpoint state_dict
    keys and are kept as-is.
    """
    def __init__(self, in_ch: int, out_ch: int, groups: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.gn1 = gn(out_ch, groups)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.gn2 = gn(out_ch, groups)

    def forward(self, x):
        # Same conv -> norm -> in-place ReLU recipe applied twice.
        for conv, norm in ((self.conv1, self.gn1), (self.conv2, self.gn2)):
            x = F.relu(norm(conv(x)), inplace=True)
        return x

class UNet2D(nn.Module):
    """2D U-Net with four down/up levels, GroupNorm conv blocks and bottleneck dropout.

    For 2.5D use ``in_ch`` is the number of stacked z-slices. ``forward``
    returns one raw logit map (no sigmoid). Because of the four 2x max-pools,
    input height and width should be divisible by 16 so that upsampled decoder
    features and encoder skips line up for concatenation.

    Submodule attribute names and their registration order are kept fixed:
    they define the state_dict keys (and parameter order) of saved checkpoints.
    """
    def __init__(self, in_ch: int, base_ch: int = 64, groups: int = 8, dropout_bottleneck: float = 0.10):
        super().__init__()
        # Channel widths double at each level: base, 2x, 4x, 8x, 16x (bottleneck).
        w1, w2, w3, w4, w5 = (base_ch * 2 ** level for level in range(5))

        # Encoder: conv block followed by a 2x downsample, four times.
        self.enc1 = ConvBlock(in_ch, w1, groups)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(w1, w2, groups)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(w2, w3, groups)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = ConvBlock(w3, w4, groups)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck with channel-wise (2D) dropout.
        self.bottleneck = ConvBlock(w4, w5, groups)
        self.drop = nn.Dropout2d(p=dropout_bottleneck)

        # Decoder: 2x transposed-conv upsample, concatenate the matching skip
        # (doubling the channels), then a conv block back to that level's width.
        self.up4 = nn.ConvTranspose2d(w5, w4, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(2 * w4, w4, groups)
        self.up3 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(2 * w3, w3, groups)
        self.up2 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(2 * w2, w2, groups)
        self.up1 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(2 * w1, w1, groups)

        # 1x1 projection to a single logit channel.
        self.out = nn.Conv2d(w1, 1, kernel_size=1)

    def forward(self, x):
        skips = []
        h = x
        for enc, pool in ((self.enc1, self.pool1), (self.enc2, self.pool2),
                          (self.enc3, self.pool3), (self.enc4, self.pool4)):
            h = enc(h)
            skips.append(h)
            h = pool(h)

        h = self.drop(self.bottleneck(h))

        # pop() yields the deepest skip first: enc4, enc3, enc2, enc1.
        for up, dec in ((self.up4, self.dec4), (self.up3, self.dec3),
                        (self.up2, self.dec2), (self.up1, self.dec1)):
            h = dec(torch.cat([up(h), skips.pop()], dim=1))

        return self.out(h)  # logits

# -------------------------
# AUPRC / Dice metrics
# -------------------------
@torch.no_grad()
def binary_auprc_from_logits(logits: torch.Tensor, y: torch.Tensor) -> float:
    """Compute the pixel-wise area under the precision-recall curve (AUPRC).

    All elements of ``logits`` and ``y`` are flattened and pooled, so every
    pixel counts as one binary prediction.

    Args:
        logits: Raw (pre-sigmoid) model outputs of any shape.
        y: Targets with the same number of elements as ``logits``; values
            > 0.5 are positives (same rule as ``dice_from_logits``).

    Returns:
        Trapezoidal area under precision(recall), evaluated once per distinct
        score threshold and anchored at (recall=0, precision=1). Returns 0.0
        when there are no elements or no positives.

    Raises:
        ValueError: if ``logits`` and ``y`` have different element counts.
    """
    # Rank by logits instead of sigmoid(logits): sigmoid is monotonic so the
    # ranking is identical, but float32 sigmoid saturates to exactly 0/1 and
    # would create large artificial ties.
    scores = logits.detach().flatten().double().cpu().numpy()
    y_true = y.detach().flatten().cpu().numpy() > 0.5
    if scores.size != y_true.size:
        raise ValueError(
            f"logits and y must have the same number of elements, "
            f"got {scores.size} and {y_true.size}"
        )
    if scores.size == 0:
        return 0.0

    # Sort by decreasing score.
    order = np.argsort(-scores, kind="stable")
    scores = scores[order]
    y_sorted = y_true[order]

    # Cumulative TP/FP counts in one vectorised pass (replaces a Python loop
    # over every pixel).
    tps = np.cumsum(y_sorted, dtype=np.float64)
    fps = np.arange(1, y_sorted.size + 1, dtype=np.float64) - tps

    # Evaluate only at distinct thresholds: keep the last index of each run of
    # equal scores so tied pixels enter the curve together, independent of
    # their arbitrary order inside the tie.
    last_of_run = np.r_[np.flatnonzero(np.diff(scores)), scores.size - 1]
    tps = tps[last_of_run]
    fps = fps[last_of_run]

    total_pos = tps[-1]
    if total_pos == 0:
        return 0.0

    precision = tps / (tps + fps)  # tps + fps >= 1 at every kept threshold
    recall = tps / total_pos       # already non-decreasing; no re-sort needed

    # Anchor the curve at (0, 1) so the area below the first threshold is not
    # dropped (a perfect ranking scores exactly 1.0).
    recall = np.concatenate(([0.0], recall))
    precision = np.concatenate(([1.0], precision))

    # Trapezoidal rule written out explicitly (np.trapz is deprecated in NumPy 2).
    return float(np.sum(np.diff(recall) * (precision[1:] + precision[:-1]) * 0.5))

@torch.no_grad()
def dice_from_logits(logits: torch.Tensor, y: torch.Tensor, thr: float = 0.5) -> float:
    """Hard Dice score after thresholding sigmoid probabilities.

    Predictions are ``sigmoid(logits) > thr``; ground truth is ``y > 0.5``.
    Dice = 2|P & G| / (|P| + |G|), with a 1e-8 term in the denominator.

    Returns:
        The Dice score as a Python float, or 0.0 when both the prediction and
        the ground truth are empty.
    """
    hit, truth = torch.broadcast_tensors(torch.sigmoid(logits) > thr, y > 0.5)
    overlap = int((hit & truth).sum().item())
    # |P| + |G| == 2*TP + FP + FN
    size_sum = int(hit.sum().item()) + int(truth.sum().item())
    if size_sum == 0:
        return 0.0
    return float((2 * overlap) / (size_sum + 1e-8))

# -------------------------
# Dataset for Crop2 fine-tuning
# -------------------------
def aug_flip_rot(x: torch.Tensor, y: torch.Tensor):
    """Apply one random rotation/flip transform jointly to an input and its mask.

    The same random quarter-turn rotation and horizontal/vertical flips are
    applied to ``x`` and ``y`` over their last two (spatial) dims, so pixels
    stay aligned. Randomness comes from Python's ``random`` module: one
    ``randint(0, 3)`` followed by two ``random()`` draws.

    Args:
        x: Input tensor, e.g. ``[B, C, H, W]``.
        y: Target tensor, e.g. ``[B, 1, H, W]``.

    Returns:
        The transformed ``(x, y)`` pair; inputs are returned as-is when no
        transform was drawn.
    """
    # Draw every random decision up front, in the same order as they are applied.
    turns = random.randint(0, 3)
    flip_w = random.random() < 0.5
    flip_h = random.random() < 0.5

    def _transform(t: torch.Tensor) -> torch.Tensor:
        if turns:
            t = torch.rot90(t, k=turns, dims=(-2, -1))
        if flip_w:
            t = torch.flip(t, dims=(-1,))
        if flip_h:
            t = torch.flip(t, dims=(-2,))
        return t

    return _transform(x), _transform(y)

class Crop2FineTuneDataset(Dataset):
    """
    2.5D patches from rectangular regions ("quads") of a [Z,H,W] volume.

    Each sample is one centre slice ``z`` of one quad. The input is the
    ``k``-slice neighbourhood ``[z - k//2, z + k//2]`` of the globally
    min/max-normalised raw volume, and the target is the binarised ground
    truth (``gt > 0``) of the centre slice only.

    Args:
        raw: Raw intensities, shape ``[Z, H, W]``.
        gt: Ground-truth labels with the same shape as ``raw``; any value > 0
            counts as foreground.
        quads: Half-open crop windows ``(y0, y1, x0, x1)`` inside ``[H, W]``.
        k: Slices per input stack. It must be odd so the centre slice is well
            defined; an even ``k`` would silently yield ``k + 1`` slices.
        augment: If True, apply ``aug_flip_rot`` jointly to input and target.
        global_min, global_max: Intensity range used for normalisation. They
            default to ``cfg.GLOBAL_MIN`` / ``cfg.GLOBAL_MAX``.

    Raises:
        ValueError: on an even or non-positive ``k``, on a ``raw``/``gt`` shape
            mismatch or non-3D volume, or on a quad outside ``[H, W]``.
    """
    def __init__(self, raw: np.ndarray, gt: np.ndarray,
                 quads: List[Tuple[int,int,int,int]],
                 k: int = 5, augment: bool = False,
                 global_min: Optional[float] = None,
                 global_max: Optional[float] = None):
        super().__init__()
        if k < 1 or k % 2 == 0:
            raise ValueError(f"k must be a positive odd integer, got {k}")
        if raw.ndim != 3 or raw.shape != gt.shape:
            raise ValueError(
                f"raw and gt must be matching [Z,H,W] volumes, got {raw.shape} and {gt.shape}"
            )
        Z, H, W = raw.shape
        for qi, (y0, y1, x0, x1) in enumerate(quads):
            if not (0 <= y0 < y1 <= H and 0 <= x0 < x1 <= W):
                raise ValueError(f"quad {qi} {(y0, y1, x0, x1)} lies outside the {H}x{W} plane")

        # Resolved lazily so the module-level cfg is only needed when no range is given.
        if global_min is None:
            global_min = cfg.GLOBAL_MIN
        if global_max is None:
            global_max = cfg.GLOBAL_MAX

        self.raw = raw.astype(np.float32)  # [Z,H,W]
        self.gt = gt.astype(np.float32)    # [Z,H,W]
        self.quads = quads
        self.k = k
        self.pad = k // 2
        self.augment = augment

        # Valid centres keep the whole k-slice window inside the volume
        # (e.g. z=2..27 for K=5 with Z=30).
        self.z_centers = list(range(self.pad, Z - self.pad))
        # (quad_idx, z_center), quad-major order
        self.samples = [(qi, z) for qi in range(len(self.quads)) for z in self.z_centers]

        # Normalise the whole volume once (global min/max), then clip to [0, 1].
        scale = global_max - global_min + 1e-12
        self.raw_n = np.clip((self.raw - global_min) / scale, 0.0, 1.0).astype(np.float32)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(x, y)``: ``x`` is ``[k, h, w]`` float32, ``y`` is ``[1, h, w]`` in {0, 1}."""
        qi, z = self.samples[idx]
        y0, y1, x0, x1 = self.quads[qi]

        # .copy() yields an independent contiguous array, so callers may modify
        # the returned tensor without corrupting the cached normalised volume.
        xk = self.raw_n[z - self.pad : z + self.pad + 1, y0:y1, x0:x1].copy()   # [K,h,w]
        yk = (self.gt[z, y0:y1, x0:x1] > 0).astype(np.float32)                # [h,w]

        x = torch.from_numpy(xk)                 # [K,h,w]
        y = torch.from_numpy(yk).unsqueeze(0)    # [1,h,w]

        if self.augment:
            # aug_flip_rot works on batched tensors; add and then drop a batch dim.
            xb, yb = aug_flip_rot(x.unsqueeze(0), y.unsqueeze(0))
            x, y = xb.squeeze(0), yb.squeeze(0)

        return x, y

# -------------------------
# Dice loss
# -------------------------
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities, averaged over the batch.

    For each sample, probabilities ``p`` and targets ``t`` are flattened and
    ``dice = (2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps)``; the loss is
    ``1 - mean(dice)``. ``eps`` keeps the ratio finite (and close to 1) when a
    sample has an empty target and a near-zero prediction.

    Args:
        eps: Smoothing constant added to both numerator and denominator.
    """
    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar Dice loss for ``logits``/``targets`` of shape ``[N, ...]``."""
        # reshape (not view): works for non-contiguous inputs such as rotated or
        # transposed tensors, and does not copy when the input is already contiguous.
        probs = torch.sigmoid(logits).reshape(logits.size(0), -1)
        targets = targets.reshape(targets.size(0), -1)
        inter = (probs * targets).sum(dim=1)
        denom = probs.sum(dim=1) + targets.sum(dim=1)
        dice = (2 * inter + self.eps) / (denom + self.eps)
        return 1.0 - dice.mean()

# -------------------------
# Build data + model
# -------------------------
# load crop2 volumes
# Both volumes are cast to float32; the dataset normalises raw and binarises gt (> 0) itself.
raw = tifffile.imread(cfg.CROP2_RAW).astype(np.float32)  # [Z,H,W]
gt  = tifffile.imread(cfg.CROP2_GT).astype(np.float32)   # [Z,H,W]
Z,H,W = raw.shape
# Sanity check on the expected crop geometry: 30 slices of 512x512.
# NOTE(review): assert is skipped under `python -O`, and gt.shape is not checked here.
assert H == 512 and W == 512 and Z == 30, (raw.shape, "unexpected crop2 shape")

# quadrants: (y0,y1,x0,x1) half-open pixel ranges, splitting the 512x512 plane into four 256x256 tiles
TL = (0,256,   0,256)
TR = (0,256, 256,512)   # dev quadrant we used for k selection
BL = (256,512, 0,256)
BR = (256,512,256,512)

# Spatial split within the single crop: three quadrants train, TR is held out for validation.
train_quads = [TL, BL, BR]
val_quads   = [TR]

# Only the training set is augmented (random rot90/flips).
ds_tr = Crop2FineTuneDataset(raw, gt, train_quads, k=cfg.K, augment=True)
ds_va = Crop2FineTuneDataset(raw, gt, val_quads,   k=cfg.K, augment=False)

# With Z=30 and the default K=5 there are 26 valid centre slices per quadrant.
print(f"Train samples: {len(ds_tr)} (quads={len(train_quads)}, z-centers={len(ds_tr.z_centers)})")
print(f"Val samples:   {len(ds_va)} (quad=TR, z-centers={len(ds_va.z_centers)})")

# Single-process loading. drop_last=True keeps every training batch at exactly cfg.BATCH samples;
# validation keeps all samples (the last batch may be smaller).
dl_tr = DataLoader(ds_tr, batch_size=cfg.BATCH, shuffle=True,
                   num_workers= 0, pin_memory=True, drop_last=True)
dl_va = DataLoader(ds_va, batch_size=cfg.BATCH, shuffle=False,
                   num_workers= 0, pin_memory=True, drop_last=False)

# load checkpoint & build model with same hyperparams
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints. On PyTorch
# versions where weights_only defaults to True, loading may fail if ckpt["cfg"] holds
# non-primitive objects; confirm against the stage-1 save code.
ckpt = torch.load(cfg.CKPT_PATH, map_location="cpu")
# Architecture hyperparameters come from the checkpoint's saved cfg (upper- or lower-case keys).
# When a key is absent, they fall back to the UNet2D constructor defaults.
cfg_ckpt = ckpt.get("cfg", {})
base_ch = int(cfg_ckpt.get("BASE_CH", cfg_ckpt.get("base_ch", 64)))
gn_groups = int(cfg_ckpt.get("GN_GROUPS", cfg_ckpt.get("gn_groups", 8)))
drop_p = float(cfg_ckpt.get("DROPOUT_BOTTLENECK", cfg_ckpt.get("dropout_bottleneck", 0.10)))

model = UNet2D(in_ch=cfg.K, base_ch=base_ch, groups=gn_groups,
               dropout_bottleneck=drop_p).to(cfg.DEVICE)
# Strict load (PyTorch default): the architecture must match the stage-1 model exactly.
model.load_state_dict(ckpt["model_state"])
print("Loaded base model from:", cfg.CKPT_PATH)

# Fresh optimizer state: only the stage-1 weights are reused, not its optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.LR, weight_decay=cfg.WEIGHT_DECAY)
dice_loss_fn = DiceLoss()
bce_loss_fn = nn.BCEWithLogitsLoss()

# Start a fresh CSV log; opening with "w" overwrites any previous log in RUN_DIR.
log_path = os.path.join(cfg.RUN_DIR, "finetune_log.csv")
with open(log_path, "w") as f:
    f.write("epoch,train_loss,val_loss,val_auprc,val_dice0.5\n")

# Sentinel below any real AUPRC, so the first epoch always writes a best checkpoint.
best_auprc = -1.0
best_path = os.path.join(cfg.RUN_DIR, "best_finetune.pt")
last_path = os.path.join(cfg.RUN_DIR, "last_finetune.pt")

# -------------------------
# Fine-tuning loop
# -------------------------
for epoch in range(1, cfg.EPOCHS + 1):
    # --- train ---
    model.train()  # enables the bottleneck Dropout2d
    tr_losses = []
    pbar = tqdm(dl_tr, desc=f"Epoch {epoch}/{cfg.EPOCHS} [train]", leave=False)
    for xb, yb in pbar:
        xb = xb.to(cfg.DEVICE, non_blocking=True)
        yb = yb.to(cfg.DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)

        # Combined objective: soft Dice (region overlap) + BCE (per pixel), equally weighted.
        loss_d = dice_loss_fn(logits, yb)
        loss_b = bce_loss_fn(logits, yb)
        loss = loss_d + loss_b

        loss.backward()
        # Clip the global gradient norm to cfg.CLIP_NORM before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.CLIP_NORM)
        optimizer.step()

        tr_losses.append(float(loss.item()))
        # The progress bar shows the running mean loss of this epoch.
        pbar.set_postfix(loss=f"{np.mean(tr_losses):.4f}")

    # NaN if the training loader yielded no batches.
    train_loss = float(np.mean(tr_losses)) if tr_losses else float("nan")

    # --- validation (AUPRC + Dice@0.5) ---
    model.eval()
    va_losses = []
    logits_all = []
    y_all = []
    with torch.no_grad():
        for xb, yb in dl_va:
            xb = xb.to(cfg.DEVICE, non_blocking=True)
            yb = yb.to(cfg.DEVICE, non_blocking=True)
            logits = model(xb)
            loss = dice_loss_fn(logits, yb) + bce_loss_fn(logits, yb)
            va_losses.append(float(loss.item()))
            # Keep every batch on the CPU, so the metrics below are computed over all
            # validation pixels at once rather than averaged per batch.
            logits_all.append(logits.cpu())
            y_all.append(yb.cpu())

    # Unweighted mean over batches (the last batch may be smaller since drop_last=False).
    val_loss = float(np.mean(va_losses)) if va_losses else float("nan")
    # NOTE(review): torch.cat raises on an empty list if dl_va yields no batches.
    logits_all = torch.cat(logits_all, dim=0)
    y_all = torch.cat(y_all, dim=0)

    val_auprc = binary_auprc_from_logits(logits_all, y_all)
    val_dice = dice_from_logits(logits_all, y_all, thr=0.5)

    print(f"Epoch {epoch:03d}/{cfg.EPOCHS} | "
          f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} | "
          f"val AUPRC={val_auprc:.4f} Dice@0.5={val_dice:.4f}")

    # One CSV row per epoch, in the column order written in the header.
    with open(log_path, "a") as f:
        f.write(f"{epoch},{train_loss},{val_loss},{val_auprc},{val_dice}\n")

    # save last (overwritten every epoch). Note: "cfg" stores the stage-1
    # checkpoint's cfg (cfg_ckpt), not this script's FTConfig.
    torch.save({"epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "cfg": cfg_ckpt}, last_path)

    # save best by AUPRC; the strict ">" keeps the earliest epoch on ties
    if val_auprc > best_auprc:
        best_auprc = val_auprc
        torch.save({"epoch": epoch,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "cfg": cfg_ckpt,
                    "best_val_auprc": best_auprc}, best_path)

print("\nFine-tuning done.")
print("Best dev AUPRC:", best_auprc)
print("Saved:", log_path)
print("      ", best_path)
print("      ", last_path)
