In [1]:
# Repro & core
import os, math, json, time, copy, random, gc, h5py
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

# Reproducibility
SEED = 1337
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Paths (edit to your env) ----
DATA = {
    "train_h5": "/home/karlo/train_chunked.h5",
    "test_h5":  "../DATA/test.h5",     # optional for end evaluation
    "train_csv": "../DATA/train.csv",  # optional for plots
    "test_csv":  "../DATA/test.csv",   # optional for plots
}

# ---- Training knobs ----
TILE        = 128
BATCH       = 64         # raise if memory allows
NUM_WORKERS = 2          # HDF5 prefers small (0–2)
EPOCHS      = 4          # 3–6 is a good range for this probe
MAX_LR      = 3e-4       # we use a flat LR in this probe
SAVE_DIR    = "./checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

# Toggle AMP for speed; disable if it destabilizes
USE_AMP = (device.type == "cuda")


device: cuda


In [2]:
def _free_cuda():
    if device.type == "cuda":
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

def resize_masks_to(logits, masks):
    """Make masks match logits spatially via nearest (keeps 0/1)."""
    H, W = logits.shape[-2:]
    if masks.dtype != torch.float32:
        masks = masks.float()
    if masks.dim() == 3:  # (B,H,W) -> (B,1,H,W)
        masks = masks.unsqueeze(1)
    if masks.shape[-2:] == (H, W):
        return (masks > 0.5).float()
    out = F.interpolate(masks, size=(H, W), mode="nearest")
    return (out > 0.5).float()


In [3]:
def robust_stats_mad(arr):
    med = np.median(arr); mad = np.median(np.abs(arr - med))
    sigma = 1.4826 * (mad + 1e-12)
    return np.float32(med), np.float32(1.0 if not np.isfinite(sigma) or sigma<=0 else sigma)

class H5TiledDataset(Dataset):
    """Stream tiles from big (H,W) images, robust-normalize per-image, k-sigma clip, pad edges."""
    def __init__(self, h5_path, tile=128, k_sigma=5.0, crop_for_stats=512):
        self.h5_path, self.tile, self.k_sigma, self.crop_for_stats = h5_path, int(tile), float(k_sigma), int(crop_for_stats)
        self._h5 = self._x = self._y = None
        self._stats_cache = {}
        with h5py.File(self.h5_path, "r") as f:
            self.N, self.H, self.W = f["images"].shape
            assert f["masks"].shape == (self.N, self.H, self.W)
        Hb = math.ceil(self.H/self.tile); Wb = math.ceil(self.W/self.tile)
        self.indices = [(i, r, c) for i in range(self.N) for r in range(Hb) for c in range(Wb)]
    def _ensure(self):
        if self._h5 is None:
            self._h5 = h5py.File(self.h5_path, "r")
            self._x, self._y = self._h5["images"], self._h5["masks"]
    def _image_stats(self, i):
        if i in self._stats_cache: return self._stats_cache[i]
        s = min(self.crop_for_stats, self.H, self.W)
        h0, w0 = (self.H-s)//2, (self.W-s)//2
        crop = self._x[i, h0:h0+s, w0:w0+s].astype("float32")
        med, sig = robust_stats_mad(crop); self._stats_cache[i] = (med, sig); return med, sig
    def __len__(self): return len(self.indices)
    def __getitem__(self, idx):
        self._ensure()
        i, r, c = self.indices[idx]; t = self.tile
        r0, c0 = r*t, c*t; r1, c1 = min(r0+t, self.H), min(c0+t, self.W)
        x = self._x[i, r0:r1, c0:c1].astype("float32"); y = self._y[i, r0:r1, c0:c1].astype("float32")
        if x.shape != (t,t):
            xp = np.zeros((t,t), np.float32); yp = np.zeros((t,t), np.float32)
            xp[:x.shape[0], :x.shape[1]] = x; yp[:y.shape[0], :y.shape[1]] = y; x, y = xp, yp
        med, sig = self._image_stats(i); x = np.clip((x-med)/sig, -5, 5)
        return torch.from_numpy(x[None,...]), torch.from_numpy(y[None,...])

class SubsetDS(Dataset):
    """Select full panels by id while reusing tiling of base dataset."""
    def __init__(self, base, panel_ids):
        self.base, self.panel_ids = base, np.asarray(panel_ids)
        t = base.tile; Hb, Wb = math.ceil(base.H/t), math.ceil(base.W/t)
        base_map = {(i,r,c):k for k,(i,r,c) in enumerate(base.indices)}
        self.map = [base_map[(i,r,c)] for i in self.panel_ids for r in range(Hb) for c in range(Wb)]
    def __len__(self): return len(self.map)
    def __getitem__(self, k): return self.base[self.map[k]]

def tile_pos_weights(h5_path, tile=128):
    with h5py.File(h5_path, "r") as f:
        Y = f["masks"]; N,H,W = Y.shape
    Hb, Wb = math.ceil(H/tile), math.ceil(W/tile)
    w = []
    with h5py.File(h5_path,"r") as f:
        Y = f["masks"]
        for i in range(N):
            for r in range(Hb):
                for c in range(Wb):
                    r0,c0=r*tile,c*tile; r1,c1=min(r0+tile,H),min(c0+tile,W)
                    w.append(1.0 + 9.0*(Y[i,r0:r1,c0:c1].any()))
    return np.asarray(w, np.float64)


def panels_with_positives(h5_path, tile=128, max_panels=None):
    ids=[]
    with h5py.File(h5_path,'r') as f:
        Y=f['masks']; N,H,W=Y.shape
        rng = np.random.default_rng(0)
        order = rng.permutation(N) if max_panels else np.arange(N)
        for i in order:
            yi = Y[i]
            if yi.any(): ids.append(i)
            if max_panels and len(ids)>=max_panels: break
    return np.array(sorted(ids))

# per-panel median/MAD normalize, clip like stream_panels_direct
def norm_medmad_clip(x, clip=5.0, eps=1e-6):
    # x: torch.Tensor [B,1,H,W] or [1,H,W]
    if x.ndim == 4:
        med = x.median(dim=-1, keepdim=True).values.median(dim=-2, keepdim=True).values
    else:  # [1,H,W]
        med = x.median()
        med = med.view(1,1,1)
    mad = (x - med).abs().median()
    sigma = 1.4826 * mad + eps
    z = (x - med) / sigma
    return z.clamp_(-clip, clip)

class WithTransform(torch.utils.data.Dataset):
    def __init__(self, base): self.base = base
    def __len__(self): return len(self.base)
    def __getitem__(self, i):
        x, y = self.base[i]
        x = norm_medmad_clip(x)
        return x, y

In [4]:
# --- MODEL ---
class SEBlock(nn.Module):
    def __init__(self,c,r=8): super().__init__(); self.fc1=nn.Conv2d(c,c//r,1); self.fc2=nn.Conv2d(c//r,c,1)
    def forward(self,x): s=F.adaptive_avg_pool2d(x,1); s=F.silu(self.fc1(s),inplace=True); s=torch.sigmoid(self.fc2(s)); return x*s
def _norm(c, groups=8): g=min(groups,c) if c%groups==0 else 1; return nn.GroupNorm(g,c)

class ResBlock(nn.Module):
    def __init__(self,c_in,c_out): super().__init__(); p=1
    def __init__(self,c_in,c_out,k=3,act=nn.SiLU,se=True):
        super().__init__(); p=k//2
        self.proj = nn.Identity() if c_in==c_out else nn.Conv2d(c_in,c_out,1)
        self.bn1=_norm(c_in); self.c1=nn.Conv2d(c_in,c_out,k,padding=p,bias=False)
        self.bn2=_norm(c_out); self.c2=nn.Conv2d(c_out,c_out,k,padding=p,bias=False)
        self.act=act(); self.se=SEBlock(c_out) if se else nn.Identity()
    def forward(self,x):
        h=self.act(self.bn1(x)); h=self.c1(h)
        h=self.act(self.bn2(h)); h=self.c2(h)
        h=self.se(h); return h + self.proj(x)

class Down(nn.Module):
    def __init__(self,c_in,c_out): super().__init__(); self.pool=nn.MaxPool2d(2); self.rb=ResBlock(c_in,c_out)
    def forward(self,x): return self.rb(self.pool(x))

class Up(nn.Module):
    def __init__(self,c_in,c_skip,c_out): super().__init__(); self.up=nn.ConvTranspose2d(c_in,c_in,2,stride=2); self.rb1=ResBlock(c_in+c_skip,c_out); self.rb2=ResBlock(c_out,c_out)
    def forward(self,x,skip):
        x=self.up(x)
        dh=skip.size(-2)-x.size(-2); dw=skip.size(-1)-x.size(-1)
        if dh or dw: x=F.pad(x,(0,max(0,dw),0,max(0,dh)))
        x=torch.cat([x,skip],1); x=self.rb1(x); x=self.rb2(x); return x

class ASPP(nn.Module):
    def __init__(self,c,r=[1,6,12,18]):
        super().__init__()
        self.blocks=nn.ModuleList([nn.Sequential(nn.Conv2d(c,c//4,3,padding=d,dilation=d,bias=False), nn.BatchNorm2d(c//4), nn.SiLU(True)) for d in r])
        self.project=nn.Conv2d(c,c,1)
    def forward(self,x): return self.project(torch.cat([b(x) for b in self.blocks],1))

class UNetResSE(nn.Module):
    def __init__(self,in_ch=1,out_ch=1,widths=(32,64,128,256,512)):
        super().__init__(); w=widths
        self.stem=nn.Sequential(nn.Conv2d(in_ch,w[0],3,padding=1,bias=False), nn.BatchNorm2d(w[0]), nn.SiLU(True), ResBlock(w[0],w[0]))
        self.d1=Down(w[0],w[1]); self.d2=Down(w[1],w[2]); self.d3=Down(w[2],w[3]); self.d4=Down(w[3],w[4])
        self.u1=Up(w[4],w[3],w[3]); self.u2=Up(w[3],w[2],w[2]); self.u3=Up(w[2],w[1],w[1]); self.u4=Up(w[1],w[0],w[0])
        self.head=nn.Conv2d(w[0],out_ch,1)
    def forward(self,x):
        s0=self.stem(x); s1=self.d1(s0); s2=self.d2(s1); s3=self.d3(s2); b=self.d4(s3)
        x=self.u1(b,s3); x=self.u2(x,s2); x=self.u3(x,s1); x=self.u4(x,s0); return self.head(x) # logits

class UNetResSEASPP(UNetResSE):
    def __init__(self,in_ch=1,out_ch=1,widths=(32,64,128,256,512)):
        super().__init__(in_ch,out_ch,widths); self.aspp=ASPP(widths[-1]); self.d4=Down(widths[3],widths[4])
    def forward(self,x):
        s0=self.stem(x); s1=self.d1(s0); s2=self.d2(s1); s3=self.d3(s2); b=self.d4(s3); b=self.aspp(b)
        x=self.u1(b,s3); x=self.u2(x,s2); x=self.u3(x,s1); x=self.u4(x,s0); return self.head(x)

In [5]:
class SoftIoU(nn.Module):
    def __init__(self, eps=1e-6): 
        super().__init__(); self.eps=eps
    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        t = targets.clamp(0,1)
        p = p.view(p.size(0), -1)
        t = t.view(t.size(0), -1)
        inter = (p*t).sum(1)
        union = p.sum(1) + t.sum(1) - inter
        iou   = (inter + self.eps) / (union + self.eps)
        return (1 - iou).mean()

class SoftIoUWithBCE(nn.Module):
    """
    total = lambda_bce * BCE(pos_weight) + (1 - lambda_bce) * SoftIoU
    """
    def __init__(self, pos_weight=8.0, lambda_bce=0.7):
        super().__init__()
        self.lambda_bce = float(lambda_bce)
        self.pos_weight = float(pos_weight)
        self.bce  = nn.BCEWithLogitsLoss(reduction='mean')
        self.siou = SoftIoU()
    def forward(self, logits, targets):
        t = targets.clamp(0,1)
        loss_bce  = self.bce(logits, t) if self.pos_weight<=0 else \
                    F.binary_cross_entropy_with_logits(logits, t, pos_weight=torch.tensor(self.pos_weight, device=logits.device))
        loss_siou = self.siou(logits, t)
        return self.lambda_bce*loss_bce + (1.0-self.lambda_bce)*loss_siou


# (Optional) More aggressive FP control for later HPO:
class AsymFocalTversky(nn.Module):
    def __init__(self, alpha=0.35, beta=0.65, gamma=1.2, eps=1e-6):
        super().__init__(); self.alpha, self.beta, self.gamma, self.eps = alpha,beta,gamma,eps
    def forward(self, logits, targets):
        p = torch.sigmoid(logits).clamp(self.eps, 1-self.eps); t = targets.clamp(0,1)
        p = p.view(p.size(0), -1); t = t.view(t.size(0), -1)
        TP = (p*t).sum(1); FP = ((1-t)*p).sum(1); FN = (t*(1-p)).sum(1)
        tv = (TP+self.eps)/(TP+self.alpha*FP+self.beta*FN+self.eps)
        return torch.pow(1.0 - tv, self.gamma).mean()


In [6]:
@torch.no_grad()
def pix_metrics(model, loader, thr=0.5, n_batches=6):
    model.eval()
    dev = next(model.parameters()).device
    tp = fp = fn = 0.0
    pos_means, neg_means = [], []
    t0 = time.time()
    for i,(xb,yb) in enumerate(loader,1):
        xb, yb = xb.to(dev, non_blocking=True), yb.to(dev, non_blocking=True)
        logits = model(xb)
        yb_r   = resize_masks_to(logits, yb)
        p      = torch.sigmoid(logits)
        if (yb_r>0.5).any(): pos_means.append(float(p[yb_r>0.5].mean()))
        neg_means.append(float(p[yb_r<=0.5].mean()))
        pv, tv = p.reshape(-1), yb_r.reshape(-1)
        pred   = (pv>=thr).float()
        tp += float((pred*tv).sum()); fp += float((pred*(1-tv)).sum()); fn += float(((1-pred)*tv).sum())
        if i>=n_batches: break
    P = tp/max(tp+fp,1); R = tp/max(tp+fn,1); f1 = 2*P*R/max(P+R,1e-8)
    print(f"[quick_prob_stats] batches={min(n_batches,i)} | pos≈{np.mean(pos_means) if pos_means else float('nan'):.4f} | "
          f"neg≈{np.mean(neg_means):.4f} | P {P:.3f} R {R:.3f} F1 {f1:.3f} @ thr={thr:.3f} | {time.time()-t0:.1f}s")
    return dict(P=P,R=R,F1=f1,pos_mean=np.mean(pos_means) if pos_means else float('nan'),neg_mean=np.mean(neg_means))

@torch.no_grad()
def pick_thr_under_min(model, loader, max_batches=40, n_bins=256, beta=2.0):
    """Histogram-based pixel threshold selection (recall-lean if beta>1)."""
    model.eval(); dev = next(model.parameters()).device
    hist_pos = torch.zeros(n_bins, device=dev); hist_neg = torch.zeros(n_bins, device=dev)
    edges = torch.linspace(0,1,n_bins+1, device=dev)
    for i,(xb,yb) in enumerate(loader,1):
        xb,yb = xb.to(dev), yb.to(dev)
        p = torch.sigmoid(model(xb))
        yb_r = resize_masks_to(p, yb)
        pv = p.reshape(-1); tv = (yb_r>0.5).reshape(-1)
        hist_pos += torch.histc(pv[tv], bins=n_bins, min=0, max=1)
        hist_neg += torch.histc(pv[~tv], bins=n_bins, min=0, max=1)
        if i>=max_batches: break
    cpos = torch.flip(torch.cumsum(torch.flip(hist_pos, dims=[0]), 0), dims=[0])  # >=t
    cneg = torch.flip(torch.cumsum(torch.flip(hist_neg, dims=[0]), 0), dims=[0])
    TP = cpos; FP = cneg; FN = (hist_pos.sum() - TP).clamp(min=0)
    P = TP / (TP + FP + 1e-8); R = TP / (TP + FN + 1e-8)
    fbeta = (1+beta*beta)*P*R / (beta*beta*P + R + 1e-8)
    idx = int(torch.argmax(fbeta).item())
    thr = float((edges[idx] + edges[idx+1])/2)
    return thr, (float(P[idx]), float(R[idx]), float(fbeta[idx])), dict(pos_rate=float((TP[idx]+FP[idx])/(hist_pos.sum()+hist_neg.sum()+1e-8)))

@torch.no_grad()
def pick_thr_with_floor(model, loader, max_batches=40, n_bins=256, beta=1.0, min_pos_rate=0.05, max_pos_rate=0.10):
    thr, (P,R,F), aux = pick_thr_under_min(model, loader, max_batches=max_batches, n_bins=n_bins, beta=beta)
    # simple clamp pass using percentile of preds to hit pos_rate band
    # (if your earlier “floor” function is available, feel free to swap it in)
    return thr, (P,R,F), aux


In [7]:
def init_head_bias_to_prior(model, p0=0.70):
    b = math.log(p0/(1-p0))
    with torch.no_grad():
        if hasattr(model, "head") and hasattr(model.head, "bias"):
            model.head.bias.data.fill_(b)

def set_requires_grad(mod, flag: bool):
    for p in mod.parameters(): p.requires_grad = flag

def freeze_all(model): set_requires_grad(model, False)

def unfreeze_head_only(model):
    freeze_all(model)
    if hasattr(model, "head"):
        set_requires_grad(model.head, True)
    else:
        raise AttributeError("Model has no attribute 'head'.")

def unfreeze_head_and_tail(model):
    """
    Unfreeze head + late upsample blocks + ASPP (assumes UNetResSEASPP)
    Adjust attribute names if your class differs.
    """
    freeze_all(model)
    if hasattr(model, "head"): set_requires_grad(model.head, True)
    for path in ["u3", "u4", "aspp"]:
        if hasattr(model, path): set_requires_grad(getattr(model, path), True)

def fit_quick_warmup(model, loader, epochs=2, max_batches=800, lr=2e-4, metric_thr=0.20, pos_weight=30.0):
    dev = next(model.parameters()).device
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight, device=dev))
    for ep in range(1, epochs+1):
        seen=tp=fp=fn=0.0; loss_sum=0.0
        t0=time.time()
        for b,(xb,yb) in enumerate(loader,1):
            xb,yb = xb.to(dev), yb.to(dev)
            logits = model(xb)
            yb_r   = resize_masks_to(logits, yb)
            loss = bce(logits, yb_r)
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
            with torch.no_grad():
                p = torch.sigmoid(logits); pv, tv = p.view(-1), yb_r.view(-1)
                pred = (pv>=metric_thr).float()
                tp += float((pred*tv).sum()); fp += float((pred*(1-tv)).sum()); fn += float(((1-pred)*tv).sum())
                loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
            if b>=max_batches: break
        P = tp/max(tp+fp,1); R = tp/max(tp+fn,1); f1 = 2*P*R/max(P+R,1e-8)
        print(f"[WARMUP] ep{ep} loss {loss_sum/seen:.4f} | F1 {f1:.3f} P {P:.3f} R {R:.3f}")

def fit_head_only(model, loader, epochs=2, max_batches=600, lr=3e-5, metric_thr=0.15, pos_weight=5.0):
    dev = next(model.parameters()).device
    unfreeze_head_only(model)
    head_params = [p for p in model.head.parameters() if p.requires_grad]
    opt = torch.optim.Adam(head_params, lr=lr, weight_decay=0.0)
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight, device=dev))
    for ep in range(1, epochs+1):
        seen=tp=fp=fn=0.0; loss_sum=0.0
        for b,(xb,yb) in enumerate(loader,1):
            xb,yb = xb.to(dev), yb.to(dev)
            logits = model(xb); yb_r = resize_masks_to(logits, yb)
            loss = bce(logits, yb_r)
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
            with torch.no_grad():
                p = torch.sigmoid(logits); pv, tv = p.view(-1), yb_r.view(-1)
                pred = (pv>=metric_thr).float()
                tp += float((pred*tv).sum()); fp += float((pred*(1-tv)).sum()); fn += float(((1-pred)*tv).sum())
                loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
            if b>=max_batches: break
        P = tp/max(tp+fp,1); R = tp/max(tp+fn,1); f1 = 2*P*R/max(P+R,1e-8)
        print(f"[HEAD] ep{ep} loss {loss_sum/seen:.4f} | F1 {f1:.3f} P {P:.3f} R {R:.3f}")

def brief_tail_probe(model, loader, criterion, max_batches=400, lr=3e-4, wd=1e-4):
    dev = next(model.parameters()).device
    unfreeze_head_and_tail(model)
    opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
    loss_sum=0.0; seen=0; t0=time.time()
    for b,(xb,yb) in enumerate(loader,1):
        xb,yb = xb.to(dev), yb.to(dev)
        logits = model(xb); yb_r = resize_masks_to(logits, yb)
        loss = criterion(logits, yb_r)
        opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
        loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
        if b>=max_batches: break
    print(f"[tail-probe] loss≈{loss_sum/max(seen,1):.4f}")

def one_epoch(model, loader, criterion, opt):
    t0=time.time()
    model.train(); loss_sum=0.0; seen=0
    for b, (xb,yb) in enumerate(loader,1):
        xb,yb = xb.to(device), yb.to(device)
        logits = model(xb); yb_r = resize_masks_to(logits, yb)
        loss = criterion(logits, yb_r)
        opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
        loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
        print(f"\r[TRAIN] batch {b}/{len(loader)} | loss={loss_sum/seen:.4f} | {time.time()-t0:.1f}s", end='')
    print (f"\r")
    return loss_sum/max(seen,1)

@torch.no_grad()
def val_pick_thr(model, vloader, min_pr=0.05, max_pr=0.10):
    thr, (P,R,F), aux = pick_thr_with_floor(model, vloader, max_batches=60, n_bins=256, beta=1.0,
                                            min_pos_rate=min_pr, max_pos_rate=max_pr)
    return float(thr), (P,R,F), aux

In [8]:
with h5py.File(DATA["train_h5"], "r") as f:
    N = f["images"].shape[0]
idx = np.arange(N); np.random.shuffle(idx)
split = int(0.9*N)
idx_tr, idx_va = np.sort(idx[:split]), np.sort(idx[split:])

ds_full = H5TiledDataset(DATA["train_h5"], tile=TILE, k_sigma=5.0)

pos_panels = panels_with_positives(DATA["train_h5"], max_panels=2000)
sub_tr = np.random.default_rng(SEED).choice(np.intersect1d(idx_tr, pos_panels),
                                            size=min(200, len(pos_panels)), replace=False)
sub_va = np.random.default_rng(SEED+1).choice(np.intersect1d(idx_va, pos_panels),
                                              size=min(80, len(pos_panels)), replace=False)

train_loader_small = DataLoader(SubsetDS(ds_full, np.sort(sub_tr)), batch_size=BATCH, shuffle=True,
                                num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))
val_loader_small   = DataLoader(SubsetDS(ds_full, np.sort(sub_va)), batch_size=BATCH, shuffle=False,
                                num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))

# Full loaders (for end-to-end epochs)
train_loader = DataLoader(SubsetDS(ds_full, idx_tr), batch_size=BATCH, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))
val_loader   = DataLoader(SubsetDS(ds_full, idx_va), batch_size=BATCH, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))


In [9]:
probe = UNetResSEASPP(in_ch=1, out_ch=1).to(device)
init_head_bias_to_prior(probe, p0=0.80)
torch.backends.cudnn.benchmark = False

print("Warmup…")
fit_quick_warmup(probe, train_loader_small, epochs=1, max_batches=400, lr=2e-4, metric_thr=0.20, pos_weight=40.0)
thr0, *_ = pick_thr_under_min(probe, val_loader_small, max_batches=40, n_bins=256, beta=2.0)
thr0 = float(np.clip(thr0, 0.10, 0.20))
print(f"[thr0] ≈ {thr0:.3f}")

fit_head_only(probe, train_loader_small, epochs=2, max_batches=600, lr=3e-5, metric_thr=thr0, pos_weight=5.0)
_ = pix_metrics(probe, train_loader_small, thr=thr0, n_batches=6)

# Tail probe with SoftIoU+BCE (default training loss)
criterion = SoftIoUWithBCE(pos_weight=8.0, lambda_bce=0.7).to(device)
brief_tail_probe(probe, train_loader_small, criterion, max_batches=400, lr=3e-4, wd=1e-4)
_ = pix_metrics(probe, train_loader_small, thr=thr0, n_batches=6)

# Snapshot probe state for full training
best_state = copy.deepcopy(probe.state_dict())


Warmup…
[WARMUP] ep1 loss 0.4237 | F1 0.018 P 0.009 R 0.448
[thr0] ≈ 0.200
[HEAD] ep1 loss 0.1285 | F1 0.026 P 0.014 R 0.371
[HEAD] ep2 loss 0.1043 | F1 0.034 P 0.018 R 0.239
[quick_prob_stats] batches=6 | pos≈0.1231 | neg≈0.0533 | P 0.020 R 0.192 F1 0.036 @ thr=0.200 | 0.7s
[tail-probe] loss≈0.3797
[quick_prob_stats] batches=6 | pos≈0.0550 | neg≈0.0270 | P 0.056 R 0.065 F1 0.060 @ thr=0.200 | 0.7s


In [10]:
# Reset model to snapshot before full train
probe.load_state_dict(best_state, strict=True)
unfreeze_head_and_tail(probe)

criterion = SoftIoUWithBCE(pos_weight=8.0, lambda_bce=0.7).to(device)
opt = torch.optim.Adam(filter(lambda p: p.requires_grad, probe.parameters()), lr=MAX_LR, weight_decay=1e-4)

best = {"F": -1.0, "thr": None, "ep": 0, "state": None}
metric_thr = thr0

for ep in range(1, EPOCHS+1):
    # small LR bump on stagnation
    if ep == 3: 
        for g in opt.param_groups: g["lr"] = 4e-4

    ep_loss = one_epoch(probe, train_loader, criterion, opt)
    tr = pix_metrics(probe, train_loader_small, thr=metric_thr, n_batches=6)
    print(f"[EP{ep:02d}] loss {ep_loss:.4f} | train P {tr['P']:.3f} R {tr['R']:.3f} F {tr['F1']:.3f}")

    metric_thr, (VP,VR,VF), _ = val_pick_thr(probe, val_loader_small, min_pr=0.05, max_pr=0.10)
    print(f"[thr@ep{ep}] thr={metric_thr:.3f} | val P {VP:.3f} R {VR:.3f} F {VF:.3f}")

    if VF > best["F"]:
        best.update(F=VF, thr=metric_thr, ep=ep, state=copy.deepcopy(probe.state_dict()))
        print(f"[VAL ep{ep}] improved: F {VF:.3f} (thr={metric_thr:.3f})")

print("Best summary:", {k: (round(v,3) if isinstance(v,float) else v) for k,v in best.items() if k!='state'})


[TRAIN] batch 11520/11520 | loss=0.3742 | 1367.2s
[quick_prob_stats] batches=6 | pos≈0.2307 | neg≈0.0246 | P 0.272 R 0.336 F1 0.301 @ thr=0.200 | 0.7s
[EP01] loss 0.3742 | train P 0.272 R 0.336 F 0.301
[thr@ep1] thr=0.436 | val P 0.457 R 0.183 F 0.262
[VAL ep1] improved: F 0.262 (thr=0.436)
[TRAIN] batch 11520/11520 | loss=0.3685 | 1350.1s
[quick_prob_stats] batches=6 | pos≈0.0589 | neg≈0.0200 | P 0.000 R 0.000 F1 0.000 @ thr=0.436 | 0.7s
[EP02] loss 0.3685 | train P 0.000 R 0.000 F 0.000
[thr@ep2] thr=0.373 | val P 0.440 R 0.187 F 0.262
[VAL ep2] improved: F 0.262 (thr=0.373)
[TRAIN] batch 11520/11520 | loss=0.3668 | 1342.1s
[quick_prob_stats] batches=6 | pos≈0.1341 | neg≈0.0192 | P 0.428 R 0.204 F1 0.276 @ thr=0.373 | 0.7s
[EP03] loss 0.3668 | train P 0.428 R 0.204 F 0.276
[thr@ep3] thr=0.350 | val P 0.303 R 0.187 F 0.231
[TRAIN] batch 11520/11520 | loss=0.3640 | 1355.8s
[quick_prob_stats] batches=6 | pos≈0.1519 | neg≈0.0184 | P 0.111 R 0.134 F1 0.122 @ thr=0.350 | 0.7s
[EP04] loss 0

In [11]:
torch.save({"state": probe.state_dict(), "thr": 0.717}, "baseline_ep4.pt")


# Test

In [13]:
DATA = {
    "train_h5": "/home/karlo/train_chunked.h5",
    "test_h5":  "../DATA/test.h5",     # optional for end evaluation
    "train_csv": "../DATA/train.csv",  # optional for plots
    "test_csv":  "../DATA/test.csv",   # optional for plots
}
TILE        = 128
BATCH       = 64         # raise if memory allows
NUM_WORKERS = 10          # HDF5 prefers small (0–2)
SEED = 1337
random.seed(SEED); np.random.seed(SEED)

In [14]:
import os, math, json, time, copy, random, gc, h5py
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

In [15]:
def robust_stats_mad(arr):
    med = np.median(arr); mad = np.median(np.abs(arr - med))
    sigma = 1.4826 * (mad + 1e-12)
    return np.float32(med), np.float32(1.0 if not np.isfinite(sigma) or sigma<=0 else sigma)

class H5TiledDataset(Dataset):
    """Stream tiles from big (H,W) images, robust-normalize per-image, k-sigma clip, pad edges."""
    def __init__(self, h5_path, tile=128, k_sigma=5.0, crop_for_stats=512):
        self.h5_path, self.tile, self.k_sigma, self.crop_for_stats = h5_path, int(tile), float(k_sigma), int(crop_for_stats)
        self._h5 = self._x = self._y = None
        self._stats_cache = {}
        with h5py.File(self.h5_path, "r") as f:
            self.N, self.H, self.W = f["images"].shape
            assert f["masks"].shape == (self.N, self.H, self.W)
        Hb = math.ceil(self.H/self.tile); Wb = math.ceil(self.W/self.tile)
        self.indices = [(i, r, c) for i in range(self.N) for r in range(Hb) for c in range(Wb)]
    def _ensure(self):
        if self._h5 is None:
            self._h5 = h5py.File(self.h5_path, "r")
            self._x, self._y = self._h5["images"], self._h5["masks"]
    def _image_stats(self, i):
        if i in self._stats_cache: return self._stats_cache[i]
        s = min(self.crop_for_stats, self.H, self.W)
        h0, w0 = (self.H-s)//2, (self.W-s)//2
        crop = self._x[i, h0:h0+s, w0:w0+s].astype("float32")
        med, sig = robust_stats_mad(crop); self._stats_cache[i] = (med, sig); return med, sig
    def __len__(self): return len(self.indices)
    def __getitem__(self, idx):
        self._ensure()
        i, r, c = self.indices[idx]; t = self.tile
        r0, c0 = r*t, c*t; r1, c1 = min(r0+t, self.H), min(c0+t, self.W)
        x = self._x[i, r0:r1, c0:c1].astype("float32"); y = self._y[i, r0:r1, c0:c1].astype("float32")
        if x.shape != (t,t):
            xp = np.zeros((t,t), np.float32); yp = np.zeros((t,t), np.float32)
            xp[:x.shape[0], :x.shape[1]] = x; yp[:y.shape[0], :y.shape[1]] = y; x, y = xp, yp
        med, sig = self._image_stats(i); x = np.clip((x-med)/sig, -5, 5)
        return torch.from_numpy(x[None,...]), torch.from_numpy(y[None,...])

class SubsetDS(Dataset):
    """Select full panels by id while reusing tiling of base dataset."""
    def __init__(self, base, panel_ids):
        self.base, self.panel_ids = base, np.asarray(panel_ids)
        t = base.tile; Hb, Wb = math.ceil(base.H/t), math.ceil(base.W/t)
        base_map = {(i,r,c):k for k,(i,r,c) in enumerate(base.indices)}
        self.map = [base_map[(i,r,c)] for i in self.panel_ids for r in range(Hb) for c in range(Wb)]
    def __len__(self): return len(self.map)
    def __getitem__(self, k): return self.base[self.map[k]]

def tile_pos_weights(h5_path, tile=128):
    with h5py.File(h5_path, "r") as f:
        Y = f["masks"]; N,H,W = Y.shape
    Hb, Wb = math.ceil(H/tile), math.ceil(W/tile)
    w = []
    with h5py.File(h5_path,"r") as f:
        Y = f["masks"]
        for i in range(N):
            for r in range(Hb):
                for c in range(Wb):
                    r0,c0=r*tile,c*tile; r1,c1=min(r0+tile,H),min(c0+tile,W)
                    w.append(1.0 + 9.0*(Y[i,r0:r1,c0:c1].any()))
    return np.asarray(w, np.float64)


def panels_with_positives(h5_path, tile=128, max_panels=None):
    ids=[]
    with h5py.File(h5_path,'r') as f:
        Y=f['masks']; N,H,W=Y.shape
        rng = np.random.default_rng(0)
        order = rng.permutation(N) if max_panels else np.arange(N)
        for i in order:
            yi = Y[i]
            if yi.any(): ids.append(i)
            if max_panels and len(ids)>=max_panels: break
    return np.array(sorted(ids))

# per-panel median/MAD normalize, clip like stream_panels_direct
def norm_medmad_clip(x, clip=5.0, eps=1e-6):
    # x: torch.Tensor [B,1,H,W] or [1,H,W]
    if x.ndim == 4:
        med = x.median(dim=-1, keepdim=True).values.median(dim=-2, keepdim=True).values
    else:  # [1,H,W]
        med = x.median()
        med = med.view(1,1,1)
    mad = (x - med).abs().median()
    sigma = 1.4826 * mad + eps
    z = (x - med) / sigma
    return z.clamp_(-clip, clip)

class WithTransform(torch.utils.data.Dataset):
    def __init__(self, base): self.base = base
    def __len__(self): return len(self.base)
    def __getitem__(self, i):
        x, y = self.base[i]
        x = norm_medmad_clip(x)
        return x, y

In [16]:
# --- MODEL ---
class SEBlock(nn.Module):
    def __init__(self,c,r=8): super().__init__(); self.fc1=nn.Conv2d(c,c//r,1); self.fc2=nn.Conv2d(c//r,c,1)
    def forward(self,x): s=F.adaptive_avg_pool2d(x,1); s=F.silu(self.fc1(s),inplace=True); s=torch.sigmoid(self.fc2(s)); return x*s
def _norm(c, groups=8): g=min(groups,c) if c%groups==0 else 1; return nn.GroupNorm(g,c)

class ResBlock(nn.Module):
    def __init__(self,c_in,c_out): super().__init__(); p=1
    def __init__(self,c_in,c_out,k=3,act=nn.SiLU,se=True):
        super().__init__(); p=k//2
        self.proj = nn.Identity() if c_in==c_out else nn.Conv2d(c_in,c_out,1)
        self.bn1=_norm(c_in); self.c1=nn.Conv2d(c_in,c_out,k,padding=p,bias=False)
        self.bn2=_norm(c_out); self.c2=nn.Conv2d(c_out,c_out,k,padding=p,bias=False)
        self.act=act(); self.se=SEBlock(c_out) if se else nn.Identity()
    def forward(self,x):
        h=self.act(self.bn1(x)); h=self.c1(h)
        h=self.act(self.bn2(h)); h=self.c2(h)
        h=self.se(h); return h + self.proj(x)

class Down(nn.Module):
    def __init__(self,c_in,c_out): super().__init__(); self.pool=nn.MaxPool2d(2); self.rb=ResBlock(c_in,c_out)
    def forward(self,x): return self.rb(self.pool(x))

class Up(nn.Module):
    def __init__(self,c_in,c_skip,c_out): super().__init__(); self.up=nn.ConvTranspose2d(c_in,c_in,2,stride=2); self.rb1=ResBlock(c_in+c_skip,c_out); self.rb2=ResBlock(c_out,c_out)
    def forward(self,x,skip):
        x=self.up(x)
        dh=skip.size(-2)-x.size(-2); dw=skip.size(-1)-x.size(-1)
        if dh or dw: x=F.pad(x,(0,max(0,dw),0,max(0,dh)))
        x=torch.cat([x,skip],1); x=self.rb1(x); x=self.rb2(x); return x

class ASPP(nn.Module):
    def __init__(self,c,r=[1,6,12,18]):
        super().__init__()
        self.blocks=nn.ModuleList([nn.Sequential(nn.Conv2d(c,c//4,3,padding=d,dilation=d,bias=False), nn.BatchNorm2d(c//4), nn.SiLU(True)) for d in r])
        self.project=nn.Conv2d(c,c,1)
    def forward(self,x): return self.project(torch.cat([b(x) for b in self.blocks],1))

class UNetResSE(nn.Module):
    def __init__(self,in_ch=1,out_ch=1,widths=(32,64,128,256,512)):
        super().__init__(); w=widths
        self.stem=nn.Sequential(nn.Conv2d(in_ch,w[0],3,padding=1,bias=False), nn.BatchNorm2d(w[0]), nn.SiLU(True), ResBlock(w[0],w[0]))
        self.d1=Down(w[0],w[1]); self.d2=Down(w[1],w[2]); self.d3=Down(w[2],w[3]); self.d4=Down(w[3],w[4])
        self.u1=Up(w[4],w[3],w[3]); self.u2=Up(w[3],w[2],w[2]); self.u3=Up(w[2],w[1],w[1]); self.u4=Up(w[1],w[0],w[0])
        self.head=nn.Conv2d(w[0],out_ch,1)
    def forward(self,x):
        s0=self.stem(x); s1=self.d1(s0); s2=self.d2(s1); s3=self.d3(s2); b=self.d4(s3)
        x=self.u1(b,s3); x=self.u2(x,s2); x=self.u3(x,s1); x=self.u4(x,s0); return self.head(x) # logits

class UNetResSEASPP(UNetResSE):
    def __init__(self,in_ch=1,out_ch=1,widths=(32,64,128,256,512)):
        super().__init__(in_ch,out_ch,widths); self.aspp=ASPP(widths[-1]); self.d4=Down(widths[3],widths[4])
    def forward(self,x):
        s0=self.stem(x); s1=self.d1(s0); s2=self.d2(s1); s3=self.d3(s2); b=self.d4(s3); b=self.aspp(b)
        x=self.u1(b,s3); x=self.u2(x,s2); x=self.u3(x,s1); x=self.u4(x,s0); return self.head(x)

In [37]:
def resize_masks_to(logits, masks):
    """Make masks match logits spatially via nearest (keeps 0/1)."""
    H, W = logits.shape[-2:]
    if masks.dtype != torch.float32:
        masks = masks.float()
    if masks.dim() == 3:  # (B,H,W) -> (B,1,H,W)
        masks = masks.unsqueeze(1)
    if masks.shape[-2:] == (H, W):
        return (masks > 0.5).float()
    out = F.interpolate(masks, size=(H, W), mode="nearest")
    return (out > 0.5).float()

@torch.no_grad()
def _pix_eval(m, loader, thr=0.2, max_batches=12):
    m.eval(); dev = next(m.parameters()).device
    tp=fp=fn=0.0; posm=[]; negm=[]
    t0=time.time()
    for i,(xb,yb) in enumerate(loader,1):
        xb,yb = xb.to(dev), yb.to(dev)
        logits = m(xb)
        yb_r   = resize_masks_to(logits, yb)
        p      = torch.sigmoid(logits)
        if (yb_r>0.5).any(): posm.append(float(p[yb_r>0.5].mean()))
        negm.append(float(p[yb_r<=0.5].mean()))
        pv,tv = p.view(-1), yb_r.view(-1)
        pred = (pv>=thr).float()
        tp += float((pred*tv).sum()); fp += float((pred*(1-tv)).sum()); fn += float(((1-pred)*tv).sum())
        if i>=max_batches: break
    P = tp/max(tp+fp,1); R = tp/max(tp+fn,1); F = 2*P*R/max(P+R,1e-8)
    return {"P":P,"R":R,"F":F,"pos_mean":float(sum(posm)/max(len(posm),1)), "neg_mean":float(sum(negm)/len(negm))}

# ============ Losses ============
class SoftIoULoss(nn.Module):
    def __init__(self, eps=1e-6): super().__init__(); self.eps=eps
    def forward(self, logits, targets):
        p = torch.sigmoid(logits); t = targets.clamp(0,1)
        inter = (p*t).sum(dim=(1,2,3))
        union = (p + t - p*t).sum(dim=(1,2,3)) + self.eps
        iou = inter/union
        return (1 - iou).mean()

class AFTL(nn.Module):
    def __init__(self, alpha=0.45, beta=0.55, gamma=1.3, eps=1e-6):
        super().__init__(); self.alpha, self.beta, self.gamma, self.eps = alpha,beta,gamma,eps
    def forward(self, logits, targets):
        p = torch.sigmoid(logits).clamp(self.eps, 1-self.eps)
        t = targets.clamp(0,1)
        p = p.view(p.size(0), -1); t = t.view(t.size(0), -1)
        TP = (p*t).sum(1); FP = ((1-t)*p).sum(1); FN = (t*(1-p)).sum(1)
        tv = (TP+self.eps)/(TP + self.alpha*FP + self.beta*FN + self.eps)
        return torch.pow(1.0 - tv, self.gamma).mean()

class BCEIoUEdge(nn.Module):
    """
    λ_bce * BCE(pos_weight) + (1-λ_bce) * SoftIoU [+ λ_edge * Sobel L1]
    """
    def __init__(self, lambda_bce=0.6, pos_weight=8.0, lambda_edge=0.0):
        super().__init__()
        self.lambda_bce = float(lambda_bce)
        self.lambda_edge = float(lambda_edge)
        self.iou = SoftIoULoss()
        self.posw = float(pos_weight)
        # Sobel
        kx = torch.tensor([[[-1,0,1],[-2,0,2],[-1,0,1]]], dtype=torch.float32).unsqueeze(0)
        ky = torch.tensor([[[-1,-2,-1],[0,0,0],[1,2,1]]], dtype=torch.float32).unsqueeze(0)
        self.register_buffer("kx", kx); self.register_buffer("ky", ky)
    def _edge(self, x):
        gx = F.conv2d(x, self.kx, padding=1)
        gy = F.conv2d(x, self.ky, padding=1)
        return torch.sqrt(gx*gx + gy*gy + 1e-12)
    def forward(self, logits, targets):
        t = targets.clamp(0,1)
        posw = torch.tensor(self.posw, device=logits.device)
        bce  = F.binary_cross_entropy_with_logits(logits, t, pos_weight=posw)
        siou = self.iou(logits, t)
        loss = self.lambda_bce*bce + (1.0-self.lambda_bce)*siou
        if self.lambda_edge>0:
            p = torch.sigmoid(logits)
            loss = loss + self.lambda_edge * F.l1_loss(self._edge(p), self._edge(t))
        return loss

def blended_loss(core, aftl, w, logits, targets):
    loss = w["w_core"] * core(logits, targets)
    if aftl is not None and w.get("w_aftl", 0) > 0:
        loss = loss + w["w_aftl"] * aftl(logits, targets)
    return loss

def _make_loss_for_epoch(ep: int):
    # Early recall → mid mixed → late precision (edge)
    if ep <= 10:
        core = BCEIoUEdge(lambda_bce=0.6, pos_weight=8.0, lambda_edge=0.00).to(device); aftl=None
        return core, aftl, {"w_core":1.0, "w_aftl":0.0}
    elif ep <= 25:
        core = BCEIoUEdge(lambda_bce=0.6, pos_weight=8.0, lambda_edge=0.00).to(device)
        aftl = AFTL(alpha=0.45, beta=0.55, gamma=1.3).to(device)
        return core, aftl, {"w_core":0.85, "w_aftl":0.15}
    else:
        core = BCEIoUEdge(lambda_bce=0.8, pos_weight=8.0, lambda_edge=0.03).to(device)
        aftl = AFTL(alpha=0.45, beta=0.55, gamma=1.3).to(device)
        return core, aftl, {"w_core":0.85, "w_aftl":0.15}

def _make_opt_sched(ep: int, base_lrs, weight_decay):
    if ep <= 12:  base_lr = base_lrs[0]
    elif ep <= 25: base_lr = base_lrs[1]
    else:          base_lr = base_lrs[2]
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=base_lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=6, T_mult=2, eta_min=base_lr/10)
    return opt, sched

def set_requires_grad(mod, flag: bool):
    for p in mod.parameters(): p.requires_grad = flag
        
def freeze_all(model): set_requires_grad(model, False)
    
def _unfreeze_if_exists(path: str):
    mod = model
    for name in path.split('.'):
        if not hasattr(mod, name): return False
        mod = getattr(mod, name)
    for p in mod.parameters(): p.requires_grad = True
    return True

def _apply_phase(ep: int):
    """
    Phases:
      1–3:   head only
      4–12:  head + tail (u4, u3, aspp)
      13–25: head + u2,u3,u4,aspp
      26+:   full
    """
    freeze_all(model)
    groups = []
    if ep <= 3:
        if hasattr(model, "head"):
            for p in model.head.parameters(): p.requires_grad = True
        groups = ["head"]
    elif ep <= 12:
        if hasattr(model, "head"):
            for p in model.head.parameters(): p.requires_grad = True
        for g in ["u4","u3","aspp"]:
            if _unfreeze_if_exists(g): groups.append(g)
    elif ep <= 25:
        if hasattr(model, "head"):
            for p in model.head.parameters(): p.requires_grad = True
        for g in ["u4","u3","u2","aspp"]:
            if _unfreeze_if_exists(g): groups.append(g)
    else:
        for p in model.parameters(): p.requires_grad = True
        groups = ["<FULL>"]
    ntrain = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[phase] ep={ep} | trainable params={ntrain:,} | groups={groups}")
    return groups

def _maybe_init_head_bias_to_prior(model, p0=0.70):
    """If model.head.bias exists, set it to logit(p0) to start slightly positive-prior aware."""
    if p0 is None: return
    if hasattr(model, "head") and hasattr(model.head, "bias"):
        with torch.no_grad():
            b = math.log(p0/(1-p0))
            model.head.bias.data.fill_(b)

In [38]:
@torch.no_grad()
def pix_metrics(model, loader, thr=0.5, n_batches=6):
    model.eval()
    dev = next(model.parameters()).device
    tp = fp = fn = 0.0
    pos_means, neg_means = [], []
    t0 = time.time()
    for i,(xb,yb) in enumerate(loader,1):
        xb, yb = xb.to(dev, non_blocking=True), yb.to(dev, non_blocking=True)
        logits = model(xb)
        yb_r   = resize_masks_to(logits, yb)
        p      = torch.sigmoid(logits)
        if (yb_r>0.5).any(): pos_means.append(float(p[yb_r>0.5].mean()))
        neg_means.append(float(p[yb_r<=0.5].mean()))
        pv, tv = p.reshape(-1), yb_r.reshape(-1)
        pred   = (pv>=thr).float()
        tp += float((pred*tv).sum()); fp += float((pred*(1-tv)).sum()); fn += float(((1-pred)*tv).sum())
        if i>=n_batches: break
    P = tp/max(tp+fp,1); R = tp/max(tp+fn,1); f1 = 2*P*R/max(P+R,1e-8)
    print(f"[quick_prob_stats] batches={min(n_batches,i)} | pos≈{np.mean(pos_means) if pos_means else float('nan'):.4f} | "
          f"neg≈{np.mean(neg_means):.4f} | P {P:.3f} R {R:.3f} F1 {f1:.3f} @ thr={thr:.3f} | {time.time()-t0:.1f}s")
    return dict(P=P,R=R,F1=f1,pos_mean=np.mean(pos_means) if pos_means else float('nan'),neg_mean=np.mean(neg_means))

@torch.no_grad()
def pick_thr_under_min(model, loader, max_batches=40, n_bins=256, beta=2.0):
    """Histogram-based pixel threshold selection (recall-lean if beta>1)."""
    model.eval(); dev = next(model.parameters()).device
    hist_pos = torch.zeros(n_bins, device=dev); hist_neg = torch.zeros(n_bins, device=dev)
    edges = torch.linspace(0,1,n_bins+1, device=dev)
    for i,(xb,yb) in enumerate(loader,1):
        xb,yb = xb.to(dev), yb.to(dev)
        p = torch.sigmoid(model(xb))
        yb_r = resize_masks_to(p, yb)
        pv = p.reshape(-1); tv = (yb_r>0.5).reshape(-1)
        hist_pos += torch.histc(pv[tv], bins=n_bins, min=0, max=1)
        hist_neg += torch.histc(pv[~tv], bins=n_bins, min=0, max=1)
        if i>=max_batches: break
    cpos = torch.flip(torch.cumsum(torch.flip(hist_pos, dims=[0]), 0), dims=[0])  # >=t
    cneg = torch.flip(torch.cumsum(torch.flip(hist_neg, dims=[0]), 0), dims=[0])
    TP = cpos; FP = cneg; FN = (hist_pos.sum() - TP).clamp(min=0)
    P = TP / (TP + FP + 1e-8); R = TP / (TP + FN + 1e-8)
    fbeta = (1+beta*beta)*P*R / (beta*beta*P + R + 1e-8)
    idx = int(torch.argmax(fbeta).item())
    thr = float((edges[idx] + edges[idx+1])/2)
    return thr, (float(P[idx]), float(R[idx]), float(fbeta[idx])), dict(pos_rate=float((TP[idx]+FP[idx])/(hist_pos.sum()+hist_neg.sum()+1e-8)))

@torch.no_grad()
def pick_thr_with_floor(model, loader, max_batches=40, n_bins=256, beta=1.0, min_pos_rate=0.05, max_pos_rate=0.10):
    thr, (P,R,F), aux = pick_thr_under_min(model, loader, max_batches=max_batches, n_bins=n_bins, beta=beta)
    # simple clamp pass using percentile of preds to hit pos_rate band
    # (if your earlier “floor” function is available, feel free to swap it in)
    return thr, (P,R,F), aux


In [39]:
# ===========================================================
# End-to-end training pipeline (warmup → probe → train → pick thr)
# ===========================================================
def train_full_probe(
    model: nn.Module,
    train_loader,
    val_loader,
    resize_masks_to,                 # callable(logits, y) -> y_resized
    pick_thr_with_floor,             # callable(model, loader, max_batches, n_bins, beta, min_pos_rate, max_pos_rate)
    *,
    device=None,
    seed: int = 1337,
    init_head_prior: float = 0.70,   # initialize head bias to prior P(Y=1)
    # Warmup (BCE only, whole net) — quick grads + stable head
    warmup_epochs: int = 1,
    warmup_batches: int = 800,
    warmup_lr: float = 2e-4,
    warmup_pos_weight: float = 40.0,
    # Head-only calibration (BCE)
    head_epochs: int = 2,
    head_batches: int = 2000,
    head_lr: float = 3e-5,
    head_pos_weight: float = 5.0,
    # Tail probe (optional gentle shape)
    tail_epochs: int = 2,
    tail_batches: int = 2500,
    tail_lr: float = 1.5e-4,
    tail_pos_weight: float = 2.0,
    # Long training (cosine restarts + curriculum)
    max_epochs: int = 60,
    val_every: int = 3,
    base_lrs=(3e-4, 2e-4, 1e-4),     # (early, mid, late)
    weight_decay: float = 1e-4,
    # Threshold selection
    thr_beta: float = 1.0,           # Fβ for picking thresholds during long training
    thr_pos_rate_early=(0.03, 0.10),
    thr_pos_rate_late=(0.08, 0.12),
    # Checkpointing
    save_best_to: str | None = "ckpt_best.pt",
    # Print & eval settings
    quick_eval_train_batches: int = 6,
    quick_eval_val_batches: int = 12,
):
    """
    Returns:
        model (nn.Module): trained model (best weights loaded if save_best_to is not None)
        metric_thr (float): final recommended threshold
        summary (dict): basic metrics {'best_F','best_P','best_R','best_ep','final_thr'}
    Notes:
        - Assumes model has attributes like 'head', optionally 'u2','u3','u4','aspp'. Uses best-effort unfreezing.
        - Requires your existing resize_masks_to and pick_thr_with_floor utilities.
    """
    # --------------- Setup ---------------
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

    model.to(device)
    _maybe_init_head_bias_to_prior(model, init_head_prior)
    # --------------- Warmup (BCE-only) ---------------
    print("Warmup…")
    freeze_all(model)
    for p in model.parameters(): p.requires_grad = True  # warm everything briefly
    posw = torch.tensor(warmup_pos_weight, device=device)
    opt = torch.optim.Adam(model.parameters(), lr=warmup_lr, weight_decay=0.0)
    for ep in range(1, warmup_epochs+1):
        model.train(); seen=0; loss_sum=0.0; tp=fp=fn=0.0
        for b,(xb,yb) in enumerate(train_loader, 1):
            xb,yb = xb.to(device), yb.to(device)
            logits = model(xb); yb_r = resize_masks_to(logits, yb)
            loss = F.binary_cross_entropy_with_logits(logits, yb_r, pos_weight=posw)
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
            loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
            if b>=warmup_batches: break
        stats = _pix_eval(model, train_loader, thr=0.2, max_batches=quick_eval_train_batches)
        print(f"[WARMUP] ep{ep} loss {loss_sum/seen:.4f} | F1 {stats['F']:.3f} P {stats['P']:.3f} R {stats['R']:.3f}")

    # initial recall-friendly threshold
    thr0, *_ = pick_thr_with_floor(model, val_loader, max_batches=200, n_bins=256,
                                   beta=2.0, min_pos_rate=thr_pos_rate_early[0], max_pos_rate=thr_pos_rate_early[1])
    thr0 = float(max(0.05, min(0.20, thr0)))
    print(f"[thr0] ≈ {thr0:.3f}")

    # --------------- Head-only calibration (BCE) ---------------
    freeze_all(model)
    if hasattr(model, "head"):
        for p in model.head.parameters(): p.requires_grad = True
    head_posw = torch.tensor(head_pos_weight, device=device)
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=head_lr, weight_decay=0.0)
    for ep in range(1, head_epochs+1):
        model.train(); seen=0; loss_sum=0.0
        tp=fp=fn=0.0
        for b,(xb,yb) in enumerate(train_loader, 1):
            xb,yb = xb.to(device), yb.to(device)
            logits = model(xb); yb_r = resize_masks_to(logits, yb)
            loss = F.binary_cross_entropy_with_logits(logits, yb_r, pos_weight=head_posw)
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
            loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
            if b>=head_batches: break
        stats = _pix_eval(model, train_loader, thr=thr0, max_batches=quick_eval_train_batches)
        print(f"[HEAD] ep{ep} loss {loss_sum/seen:.4f} | F1 {stats['F']:.3f} P {stats['P']:.3f} R {stats['R']:.3f}")

    # --------------- Tail probe (gentle BCE+IoU) ---------------
    # Unfreeze tail (best-effort): head + (u4,u3,aspp)
    freeze_all(model)
    if hasattr(model, "head"):
        for p in model.head.parameters(): p.requires_grad = True
    tails = []
    for g in ["u4","u3","aspp"]:
        if _unfreeze_if_exists(g): tails.append(g)
    core_probe = BCEIoUEdge(lambda_bce=0.9, pos_weight=tail_pos_weight, lambda_edge=0.0).to(device)
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=tail_lr, weight_decay=1e-4)
    for ep in range(1, tail_epochs+1):
        model.train(); seen=0; loss_sum=0.0
        for b,(xb,yb) in enumerate(train_loader, 1):
            xb,yb = xb.to(device), yb.to(device)
            logits = model(xb); yb_r = resize_masks_to(logits, yb)
            loss = core_probe(logits, yb_r)
            opt.zero_grad(set_to_none=True); loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
            if b>=tail_batches: break
    stats = _pix_eval(model, train_loader, thr=thr0, max_batches=quick_eval_train_batches)
    print(f"[tail-probe] loss≈{loss_sum/seen:.4f}")
    print("[quick_prob_stats] train @ thr0:", {k:round(v,3) for k,v in stats.items()})

    # --------------- Long training (cosine restarts + curriculum) ---------------
    best = {"F": -1.0, "state": None, "thr": thr0, "ep": 0}
    metric_thr = thr0

    for ep in range(1, max_epochs+1):
        # freeze/unfreeze per phase
        groups = _apply_phase(ep)
        core, aftl, w = _make_loss_for_epoch(ep)
        opt, sched = _make_opt_sched(ep, base_lrs, weight_decay)

        model.train(); seen=0; loss_sum=0.0; t0=time.time()
        for i,(xb,yb) in enumerate(train_loader, 1):
            xb,yb = xb.to(device), yb.to(device)
            logits = model(xb); yb_r = resize_masks_to(logits, yb)
            loss = blended_loss(core, aftl, w, logits, yb_r)
            opt.zero_grad(set_to_none=True); loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step(); sched.step(i)
            loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
        train_loss = loss_sum/seen
        tr_stats = _pix_eval(model, train_loader, thr=metric_thr, max_batches=quick_eval_train_batches)
        print(f"[EP{ep:02d}] loss {train_loss:.4f} | train P {tr_stats['P']:.3f} R {tr_stats['R']:.3f} F {tr_stats['F']:.3f} "
              f"| pos≈{tr_stats['pos_mean']:.3f} neg≈{tr_stats['neg_mean']:.3f} | {time.time()-t0:.1f}s")

        if ep % val_every == 0 or ep <= 3:
            # pick threshold; early epochs looser pos-rate, late epochs tighter
            pr_min, pr_max = (thr_pos_rate_early if ep < 26 else thr_pos_rate_late)
            thr, (VP,VR,VF), aux = pick_thr_with_floor(
                model, val_loader,
                max_batches=120, n_bins=256, beta=thr_beta,
                min_pos_rate=pr_min, max_pos_rate=pr_max
            )
            metric_thr = float(thr)
            print(f"[thr@ep{ep}] thr={metric_thr:.3f} | val P {VP:.3f} R {VR:.3f} F {VF:.3f} | pos_rate≈{aux['pos_rate']:.3f}")

            val_stats = _pix_eval(model, val_loader, thr=metric_thr, max_batches=quick_eval_val_batches)
            print(f"[VAL ep{ep}] P {val_stats['P']:.3f} R {val_stats['R']:.3f} F {val_stats['F']:.3f}")
            if val_stats['F'] > best["F"]:
                best = {"F": val_stats['F'], "state": copy.deepcopy(model.state_dict()),
                        "thr": metric_thr, "ep": ep, "P": val_stats["P"], "R": val_stats["R"]}
                if save_best_to:
                    torch.save({"state": best["state"], "thr": best["thr"], "ep": best["ep"],
                                "P":best["P"], "R":best["R"], "F":best["F"]}, save_best_to)
                    print(f"  ↳ saved best → {save_best_to} (F={best['F']:.3f}, thr={best['thr']:.3f}, ep={best['ep']})")

    # Load best and return
    if best["state"] is not None:
        model.load_state_dict(best["state"], strict=True)
    summary = {"best_F": float(best["F"]), "best_P": float(best.get("P", 0.0)), "best_R": float(best.get("R", 0.0)),
               "best_ep": int(best["ep"]), "final_thr": float(best["thr"])}
    print("=== DONE ===")
    print("Best summary:", summary)
    return model, best["thr"], summary

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with h5py.File(DATA["train_h5"], "r") as f:
    N = f["images"].shape[0]
idx = np.arange(N); np.random.shuffle(idx)
split = int(0.9*N)
idx_tr, idx_va = np.sort(idx[:split]), np.sort(idx[split:])

ds_full = H5TiledDataset(DATA["train_h5"], tile=TILE, k_sigma=5.0)

pos_panels = panels_with_positives(DATA["train_h5"], max_panels=2000)
sub_tr = np.random.default_rng(SEED).choice(np.intersect1d(idx_tr, pos_panels),
                                            size=min(200, len(pos_panels)), replace=False)
sub_va = np.random.default_rng(SEED+1).choice(np.intersect1d(idx_va, pos_panels),
                                              size=min(80, len(pos_panels)), replace=False)

train_loader_small = DataLoader(SubsetDS(ds_full, np.sort(sub_tr)), batch_size=BATCH, shuffle=True,
                                num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))
val_loader_small   = DataLoader(SubsetDS(ds_full, np.sort(sub_va)), batch_size=BATCH, shuffle=False,
                                num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))

# Full loaders (for end-to-end epochs)
train_loader = DataLoader(SubsetDS(ds_full, idx_tr), batch_size=BATCH, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))
val_loader   = DataLoader(SubsetDS(ds_full, idx_va), batch_size=BATCH, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))


In [40]:
model = UNetResSEASPP(in_ch=1, out_ch=1).to('cuda')
model, thr, summary = train_full_probe(
    model,
    train_loader=train_loader,
    val_loader=val_loader,
    resize_masks_to=resize_masks_to,
    pick_thr_with_floor=pick_thr_with_floor,
    save_best_to="ckpt_best.pt",
)
print("Final threshold:", thr)
print("Summary:", summary)


Warmup…
[WARMUP] ep1 loss 0.3786 | F1 0.028 P 0.014 R 0.478
[thr0] ≈ 0.200
[HEAD] ep1 loss 0.0908 | F1 0.011 P 0.006 R 0.055
[HEAD] ep2 loss 0.0827 | F1 0.037 P 0.025 R 0.068
[tail-probe] loss≈0.1325
[quick_prob_stats] train @ thr0: {'P': 0.0, 'R': 0.0, 'F': 0.0, 'pos_mean': 0.012, 'neg_mean': 0.005}
[phase] ep=1 | trainable params=33 | groups=['head']
[EP01] loss 0.4639 | train P 0.059 R 0.063 F 0.061 | pos≈0.081 neg≈0.023 | 609.8s
[thr@ep1] thr=0.178 | val P 0.055 R 0.112 F 0.074 | pos_rate≈0.007
[VAL ep1] P 0.357 R 0.163 F 0.224
  ↳ saved best → ckpt_best.pt (F=0.224, thr=0.178, ep=1)
[phase] ep=2 | trainable params=33 | groups=['head']
[EP02] loss 0.4637 | train P 0.030 R 0.063 F 0.041 | pos≈0.047 neg≈0.024 | 610.6s
[thr@ep2] thr=0.189 | val P 0.054 R 0.114 F 0.074 | pos_rate≈0.007
[VAL ep2] P 0.353 R 0.167 F 0.226
  ↳ saved best → ckpt_best.pt (F=0.226, thr=0.189, ep=2)
[phase] ep=3 | trainable params=33 | groups=['head']
[EP03] loss 0.4636 | train P 0.043 R 0.148 F 0.067 | pos≈0.

KeyboardInterrupt: 