In [15]:
# ==== Cell 1: Setup & Safeguards ====
import os, math, json, time, copy, random, h5py
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# --- Repro ---
SEED = 123
# Seed every global RNG the notebook touches (Python, NumPy, PyTorch CPU + CUDA).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:", device)

# Input locations. The training HDF5 path is absolute and machine-specific, so it
# can be overridden with the TRAIN_H5_PATH environment variable; when the variable
# is unset the original path is used unchanged.
DATA = {
    "train_h5": os.environ.get("TRAIN_H5_PATH", "/home/karlo/train_chunked.h5"),
    "test_h5":  "../DATA/test.h5",
    "train_csv": "../DATA/train.csv",
    "test_csv":  "../DATA/test.csv",
}
TILE         = 128        # tile edge passed to H5TiledDataset
BATCH        = 64         # DataLoader batch size
NUM_WORKERS  = 2          # HDF5 is happier with 0–2

# --- IMPORTANT: don't shadow torch.nn.functional as "F" anywhere below ---
# Avoid variables named "F" (use f1, fbeta, etc.)

# Expect these to exist from your data code:
# - train_loader_small: small loader with guaranteed positives (for warmup/head/tail probe)
# - train_loader: full training loader
# - val_loader_small or val_loader: a validation loader (either is fine)
# If their names differ, just alias them here:
# val_loader = val_loader_small

# --- Helper: ensure mask matches model output size ---
def resize_masks_to(logits, masks):
    """Return ``masks`` at the spatial size (last two dims) of ``logits``.

    When the sizes already agree the input tensor is returned untouched;
    otherwise it is resampled with nearest-neighbour interpolation, which
    never creates values that were not already present in ``masks``.
    """
    target_hw = logits.shape[-2:]
    if masks.shape[-2:] != target_hw:
        masks = F.interpolate(masks, size=target_hw, mode="nearest")
    return masks

# --- Tiny metric (pixel-level) ---
@torch.no_grad()
def pix_metrics(model, loader, thr=0.5, n_batches=6):
    """Quick pixel-level precision / recall / F1 on the first batches of ``loader``.

    Args:
        model: network returning logits; it is switched to ``eval()`` and left there.
        loader: iterable of ``(inputs, masks)`` batches.
        thr: probability threshold applied to ``sigmoid(logits)``.
        n_batches: stop after this many batches.

    Returns:
        dict with keys ``P``, ``R``, ``F1``, ``pos_mean`` and ``neg_mean``. The two
        means average, over batches, the mean predicted probability on pixels whose
        mask is > 0.5 / <= 0.5; each is NaN when no batch contained such pixels
        (including when ``loader`` is empty).

    Also prints a one-line summary tagged ``[quick_prob_stats]``.
    """
    model.eval()
    dev = next(model.parameters()).device
    tp = fp = fn = 0.0
    pos_means, neg_means = [], []
    i = 0  # stays 0 when the loader yields nothing, so the summary below is still printable
    t0 = time.time()
    for i, (xb, yb) in enumerate(loader, 1):
        xb, yb = xb.to(dev, non_blocking=True), yb.to(dev, non_blocking=True)
        logits = model(xb)
        yb_r = resize_masks_to(logits, yb)
        p = torch.sigmoid(logits)
        pos_mask = yb_r > 0.5
        neg_mask = yb_r <= 0.5
        if pos_mask.any():
            pos_means.append(float(p[pos_mask].mean()))
        # Guard the negative side too: the mean of an empty selection is NaN and
        # would poison the average over batches.
        if neg_mask.any():
            neg_means.append(float(p[neg_mask].mean()))
        pv, tv = p.reshape(-1), yb_r.reshape(-1)
        pred = (pv >= thr).float()
        tp += float((pred * tv).sum())
        fp += float((pred * (1 - tv)).sum())
        fn += float(((1 - pred) * tv).sum())
        if i >= n_batches:
            break
    P = tp / max(tp + fp, 1)
    R = tp / max(tp + fn, 1)
    f1 = 2 * P * R / max(P + R, 1e-8)
    pos_mean = np.mean(pos_means) if pos_means else float('nan')
    neg_mean = np.mean(neg_means) if neg_means else float('nan')
    print(f"[quick_prob_stats] batches={min(n_batches,i)} | pos≈{pos_mean:.4f} | "
          f"neg≈{neg_mean:.4f} | P {P:.3f} R {R:.3f} F1 {f1:.3f} @ thr={thr:.3f} | {time.time()-t0:.1f}s")
    return dict(P=P, R=R, F1=f1, pos_mean=pos_mean, neg_mean=neg_mean)


device: cuda


In [16]:
# ==== Cell 2: Losses (SoftIoU + BCE) ====

class SoftIoU(nn.Module):
    """Soft (differentiable) IoU loss computed per sample and averaged over the batch.

    ``forward`` returns ``mean(1 - (inter + eps) / (union + eps))`` where the
    intersection and union are built from ``sigmoid(logits)`` and the targets
    clamped to [0, 1]; ``eps`` keeps the ratio defined when both are empty.
    """
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # channels-last tensors, are flattened instead of raising.
        p = torch.sigmoid(logits)
        p = p.reshape(p.size(0), -1)
        t = targets.clamp(0, 1)
        t = t.reshape(t.size(0), -1)
        inter = (p * t).sum(1)
        union = p.sum(1) + t.sum(1) - inter
        iou = (inter + self.eps) / (union + self.eps)
        return (1 - iou).mean()

class SoftIoUWithBCE(nn.Module):
    """
    Blend of binary cross-entropy and soft IoU:

        total = lambda_bce * BCE(pos_weight) + (1 - lambda_bce) * SoftIoU

    Targets are clamped to [0, 1] before both terms. A ``pos_weight`` of zero or
    less selects the plain, unweighted BCE held in ``self.bce``.
    """
    def __init__(self, pos_weight=8.0, lambda_bce=0.7):
        super().__init__()
        self.lambda_bce = float(lambda_bce)
        self.pos_weight = float(pos_weight)
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')
        self.siou = SoftIoU()

    def forward(self, logits, targets):
        clamped = targets.clamp(0, 1)
        if self.pos_weight <= 0:
            bce_term = self.bce(logits, clamped)
        else:
            # The weight tensor is built on the logits' device at every call.
            weight = torch.tensor(self.pos_weight, device=logits.device)
            bce_term = F.binary_cross_entropy_with_logits(logits, clamped, pos_weight=weight)
        iou_term = self.siou(logits, clamped)
        return self.lambda_bce * bce_term + (1.0 - self.lambda_bce) * iou_term


In [17]:
# ==== Cell 3: Fast threshold pickers ====

@torch.no_grad()
def pick_thr_under_min(model, loader, max_batches=40, n_bins=256, beta=2.0):
    """Choose a pixel threshold that maximises F-beta, using probability histograms.

    Predicted probabilities of positive (mask > 0.5) and negative pixels are
    accumulated into ``n_bins`` equal-width bins over [0, 1] for at most
    ``max_batches`` batches. For every bin, precision/recall are evaluated as if
    all pixels in that bin and above were predicted positive, and the bin with
    the highest F-beta wins (``beta > 1`` leans towards recall).

    The model is switched to ``eval()`` and left in that mode.

    Returns:
        ``(thr, (P, R, fbeta), {"pos_rate": ...})`` where ``thr`` is the midpoint
        of the winning bin and ``pos_rate`` is the fraction of all counted pixels
        falling in that bin or above.
    """
    model.eval()
    dev = next(model.parameters()).device
    hist_pos = torch.zeros(n_bins, device=dev)
    hist_neg = torch.zeros(n_bins, device=dev)
    edges = torch.linspace(0, 1, n_bins + 1, device=dev)

    for batch_no, (xb, yb) in enumerate(loader, 1):
        xb, yb = xb.to(dev), yb.to(dev)
        probs = torch.sigmoid(model(xb))
        masks = resize_masks_to(probs, yb)
        flat_p = probs.reshape(-1)
        flat_y = (masks > 0.5).reshape(-1)
        hist_pos += torch.histc(flat_p[flat_y], bins=n_bins, min=0, max=1)
        hist_neg += torch.histc(flat_p[~flat_y], bins=n_bins, min=0, max=1)
        if batch_no >= max_batches:
            break

    def _count_at_or_above(hist):
        # Reverse cumulative sum: entry k = number of pixels in bins k..n_bins-1.
        return torch.flip(torch.cumsum(torch.flip(hist, dims=[0]), 0), dims=[0])

    TP = _count_at_or_above(hist_pos)
    FP = _count_at_or_above(hist_neg)
    FN = (hist_pos.sum() - TP).clamp(min=0)
    P = TP / (TP + FP + 1e-8)
    R = TP / (TP + FN + 1e-8)
    fbeta = (1 + beta * beta) * P * R / (beta * beta * P + R + 1e-8)

    best = int(torch.argmax(fbeta).item())
    thr = float((edges[best] + edges[best + 1]) / 2)
    scores = (float(P[best]), float(R[best]), float(fbeta[best]))
    total_pixels = hist_pos.sum() + hist_neg.sum() + 1e-8
    return thr, scores, dict(pos_rate=float((TP[best] + FP[best]) / total_pixels))

@torch.no_grad()
def pick_thr_with_floor(model, loader, max_batches=40, n_bins=256, beta=1.0, min_pos_rate=0.05, max_pos_rate=0.10):
    """Threshold picker intended to keep the predicted-positive rate inside a band.

    Currently a thin wrapper: it forwards to ``pick_thr_under_min`` and returns
    its result unchanged as ``(thr, (P, R, fbeta), aux)``.

    NOTE: ``min_pos_rate`` and ``max_pos_rate`` are accepted for interface
    compatibility but are NOT applied yet; the clamp pass described below has
    not been implemented.
    """
    # Local is called `fbeta`, not `F`, so it never shadows torch.nn.functional.
    thr, (P, R, fbeta), aux = pick_thr_under_min(model, loader, max_batches=max_batches, n_bins=n_bins, beta=beta)
    # TODO: simple clamp pass using percentile of preds to hit pos_rate band
    # (if your earlier “floor” function is available, feel free to swap it in)
    return thr, (P, R, fbeta), aux


In [18]:
# ==== Cell 4: Freeze/Unfreeze + Warmup/Head fits ====

def set_requires_grad(mod, flag: bool):
    """Set ``requires_grad`` to ``flag`` on every parameter yielded by ``mod.parameters()``."""
    for param in mod.parameters():
        param.requires_grad = flag

def freeze_all(model):
    """Disable gradients for every parameter of ``model``."""
    set_requires_grad(model, flag=False)

def unfreeze_head_only(model):
    """Freeze the whole model, then make only ``model.head`` trainable.

    Raises:
        AttributeError: if ``model`` has no ``head`` attribute. The check runs
            before anything is frozen, so a failing call leaves the model's
            ``requires_grad`` flags untouched.
    """
    if not hasattr(model, "head"):
        raise AttributeError("Model has no attribute 'head'.")
    freeze_all(model)
    set_requires_grad(model.head, True)

def unfreeze_head_and_tail(model):
    """
    Freeze everything, then re-enable gradients for the head, the two last
    upsample blocks and the ASPP module (attribute names follow UNetResSEASPP;
    adjust them if your class differs).

    Attributes that the model does not have are skipped silently, so a model
    exposing none of them ends up fully frozen.
    """
    freeze_all(model)
    for attr_name in ("head", "u3", "u4", "aspp"):
        if hasattr(model, attr_name):
            set_requires_grad(getattr(model, attr_name), True)

def init_head_bias_to_prior(model, p0=0.70):
    """Fill ``model.head.bias`` with ``logit(p0)`` so ``sigmoid(bias) == p0``.

    Does nothing when the model has no ``head``, or when the head has no bias
    (attribute missing, or ``None`` as for layers built with ``bias=False``).

    Raises:
        ValueError: if ``p0`` is not strictly between 0 and 1 (the logit is
            undefined there).
    """
    if not 0.0 < p0 < 1.0:
        raise ValueError(f"p0 must lie strictly between 0 and 1, got {p0!r}")
    b = math.log(p0 / (1 - p0))
    head_bias = getattr(getattr(model, "head", None), "bias", None)
    if head_bias is None:
        return
    with torch.no_grad():
        head_bias.fill_(b)

def fit_quick_warmup(model, loader, epochs=2, max_batches=800, lr=2e-4, metric_thr=0.20, pos_weight=30.0):
    """Short warm-up: train all parameters with positive-weighted BCE.

    Puts the model in ``train()`` mode, builds a fresh Adam optimiser over
    ``model.parameters()`` and runs up to ``max_batches`` batches per epoch.
    After each epoch it prints the mean loss and pixel-level P/R/F1 measured at
    ``metric_thr`` on the batches just trained on. Returns ``None``.

    An empty ``loader`` is tolerated: the epoch line is printed with zeros.
    """
    dev = next(model.parameters()).device
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight, device=dev))
    for ep in range(1, epochs + 1):
        tp = fp = fn = 0.0
        loss_sum = 0.0
        seen = 0
        for b, (xb, yb) in enumerate(loader, 1):
            xb, yb = xb.to(dev), yb.to(dev)
            logits = model(xb)
            yb_r = resize_masks_to(logits, yb)
            loss = bce(logits, yb_r)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            with torch.no_grad():
                # Running confusion counts, taken from the pre-step logits.
                pv = torch.sigmoid(logits).reshape(-1)
                tv = yb_r.reshape(-1)
                pred = (pv >= metric_thr).float()
                tp += float((pred * tv).sum())
                fp += float((pred * (1 - tv)).sum())
                fn += float(((1 - pred) * tv).sum())
                loss_sum += float(loss.item()) * xb.size(0)
                seen += xb.size(0)
            if b >= max_batches:
                break
        P = tp / max(tp + fp, 1)
        R = tp / max(tp + fn, 1)
        f1 = 2 * P * R / max(P + R, 1e-8)
        # max(seen, 1): no ZeroDivisionError when the loader produced no samples.
        print(f"[WARMUP] ep{ep} loss {loss_sum/max(seen,1):.4f} | F1 {f1:.3f} P {P:.3f} R {R:.3f}")

def fit_head_only(model, loader, epochs=2, max_batches=600, lr=3e-5, metric_thr=0.15, pos_weight=5.0):
    """Fit only ``model.head`` with positive-weighted BCE; everything else is frozen.

    Calls ``unfreeze_head_only`` (raises ``AttributeError`` if there is no head),
    optimises the head's parameters with a fresh Adam instance for up to
    ``max_batches`` batches per epoch, and prints the mean loss plus pixel-level
    P/R/F1 at ``metric_thr`` after every epoch. Returns ``None``.

    NOTE(review): this function never calls ``model.train()`` or ``model.eval()``,
    so it runs in whatever mode the caller left the model in; confirm that is
    the intended behaviour for the frozen normalisation layers.
    """
    dev = next(model.parameters()).device
    unfreeze_head_only(model)
    head_params = [p for p in model.head.parameters() if p.requires_grad]
    opt = torch.optim.Adam(head_params, lr=lr, weight_decay=0.0)
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight, device=dev))
    for ep in range(1, epochs + 1):
        tp = fp = fn = 0.0
        loss_sum = 0.0
        seen = 0
        for b, (xb, yb) in enumerate(loader, 1):
            xb, yb = xb.to(dev), yb.to(dev)
            logits = model(xb)
            yb_r = resize_masks_to(logits, yb)
            loss = bce(logits, yb_r)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            with torch.no_grad():
                pv = torch.sigmoid(logits).reshape(-1)
                tv = yb_r.reshape(-1)
                pred = (pv >= metric_thr).float()
                tp += float((pred * tv).sum())
                fp += float((pred * (1 - tv)).sum())
                fn += float(((1 - pred) * tv).sum())
                loss_sum += float(loss.item()) * xb.size(0)
                seen += xb.size(0)
            if b >= max_batches:
                break
        P = tp / max(tp + fp, 1)
        R = tp / max(tp + fn, 1)
        f1 = 2 * P * R / max(P + R, 1e-8)
        # max(seen, 1): no ZeroDivisionError when the loader produced no samples.
        print(f"[HEAD] ep{ep} loss {loss_sum/max(seen,1):.4f} | F1 {f1:.3f} P {P:.3f} R {R:.3f}")


In [19]:
# --- MODEL ---
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate: rescales each channel of ``x`` by a learned factor in (0, 1).

    Args:
        c: number of input/output channels.
        r: reduction ratio; the bottleneck has ``max(1, c // r)`` channels so it
           never collapses to zero channels when ``c < r``.
    """
    def __init__(self, c, r=8):
        super().__init__()
        hidden = max(1, c // r)
        self.fc1 = nn.Conv2d(c, hidden, 1)
        self.fc2 = nn.Conv2d(hidden, c, 1)

    def forward(self, x):
        s = F.adaptive_avg_pool2d(x, 1)          # global average per channel -> [B, C, 1, 1]
        s = F.silu(self.fc1(s), inplace=True)
        s = torch.sigmoid(self.fc2(s))           # per-channel gate
        return x * s
def _norm(c, groups=8): g=min(groups,c) if c%groups==0 else 1; return nn.GroupNorm(g,c)

class ResBlock(nn.Module):
    """Pre-activation residual block: (norm → act → conv) twice, optional SE gate, plus skip.

    Args:
        c_in, c_out: input / output channels. When they differ, the skip path
            goes through a 1×1 convolution; otherwise it is the identity.
        k: kernel size of both convolutions (padding ``k // 2``).
        act: activation class, instantiated once and shared by both stages.
        se: append an ``SEBlock`` after the second convolution when true.

    ``bn1`` / ``bn2`` are the normalisation layers returned by ``_norm``.
    """
    # A stub ``__init__(self, c_in, c_out)`` used to precede this one; Python kept
    # only the later definition, so the dead duplicate has been dropped.
    def __init__(self, c_in, c_out, k=3, act=nn.SiLU, se=True):
        super().__init__()
        p = k // 2
        self.proj = nn.Identity() if c_in == c_out else nn.Conv2d(c_in, c_out, 1)
        self.bn1 = _norm(c_in)
        self.c1 = nn.Conv2d(c_in, c_out, k, padding=p, bias=False)
        self.bn2 = _norm(c_out)
        self.c2 = nn.Conv2d(c_out, c_out, k, padding=p, bias=False)
        self.act = act()
        self.se = SEBlock(c_out) if se else nn.Identity()

    def forward(self, x):
        h = self.act(self.bn1(x))
        h = self.c1(h)
        h = self.act(self.bn2(h))
        h = self.c2(h)
        h = self.se(h)
        return h + self.proj(x)

class Down(nn.Module):
    """Encoder stage: 2×2 max-pool followed by one ``ResBlock(c_in, c_out)``."""
    def __init__(self, c_in, c_out):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.rb = ResBlock(c_in, c_out)

    def forward(self, x):
        pooled = self.pool(x)
        return self.rb(pooled)

class Up(nn.Module):
    """Decoder stage: 2× transposed-conv upsample, concat with the skip, two ResBlocks.

    Args:
        c_in: channels of the tensor being upsampled (kept by the transposed conv).
        c_skip: channels of the skip connection.
        c_out: channels produced by both residual blocks.
    """
    def __init__(self, c_in, c_skip, c_out):
        super().__init__()
        self.up = nn.ConvTranspose2d(c_in, c_in, 2, stride=2)
        self.rb1 = ResBlock(c_in + c_skip, c_out)
        self.rb2 = ResBlock(c_out, c_out)

    def forward(self, x, skip):
        x = self.up(x)
        # If the upsampled map is smaller than the skip, zero-pad its bottom/right
        # edge; a negative difference is ignored by the max(0, ·).
        dh = skip.size(-2) - x.size(-2)
        dw = skip.size(-1) - x.size(-1)
        if dh != 0 or dw != 0:
            x = F.pad(x, (0, max(0, dw), 0, max(0, dh)))
        merged = torch.cat([x, skip], 1)
        return self.rb2(self.rb1(merged))

class ASPP(nn.Module):
    """ASPP-style block: parallel dilated 3×3 convolutions, concatenated and projected back to ``c``.

    Args:
        c: input and output channels.
        r: iterable of dilation rates, one branch per rate. Each branch emits
           ``c // 4`` channels (conv → BatchNorm → SiLU); the 1×1 projection is
           sized from the actual number of branches, so any number of rates works.
           With the default four rates and ``c`` divisible by 4 the layer shapes
           are the same as a fixed ``Conv2d(c, c, 1)`` projection.
    """
    def __init__(self, c, r=(1, 6, 12, 18)):
        super().__init__()
        rates = tuple(r)  # immutable copy; avoids sharing a mutable default
        branch_ch = c // 4
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, branch_ch, 3, padding=d, dilation=d, bias=False),
                nn.BatchNorm2d(branch_ch),
                nn.SiLU(True),
            )
            for d in rates
        ])
        self.project = nn.Conv2d(branch_ch * len(rates), c, 1)

    def forward(self, x):
        return self.project(torch.cat([b(x) for b in self.blocks], 1))

class UNetResSE(nn.Module):
    """U-Net built from residual blocks: stem, four Down stages, four Up stages, 1×1 head.

    Args:
        in_ch: input channels.
        out_ch: output channels of the head (raw logits, no activation).
        widths: channel widths; indices 0..4 are used for the stem and the four
            successively deeper levels.
    """
    def __init__(self, in_ch=1, out_ch=1, widths=(32, 64, 128, 256, 512)):
        super().__init__()
        w = widths
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, w[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(w[0]),
            nn.SiLU(True),
            ResBlock(w[0], w[0]),
        )
        # Encoder
        self.d1 = Down(w[0], w[1])
        self.d2 = Down(w[1], w[2])
        self.d3 = Down(w[2], w[3])
        self.d4 = Down(w[3], w[4])
        # Decoder; each stage consumes the skip of the matching encoder level
        self.u1 = Up(w[4], w[3], w[3])
        self.u2 = Up(w[3], w[2], w[2])
        self.u3 = Up(w[2], w[1], w[1])
        self.u4 = Up(w[1], w[0], w[0])
        self.head = nn.Conv2d(w[0], out_ch, 1)

    def forward(self, x):
        s0 = self.stem(x)
        s1 = self.d1(s0)
        s2 = self.d2(s1)
        s3 = self.d3(s2)
        bottom = self.d4(s3)
        y = self.u1(bottom, s3)
        y = self.u2(y, s2)
        y = self.u3(y, s1)
        y = self.u4(y, s0)
        return self.head(y)  # logits

class UNetResSEASPP(UNetResSE):
    """``UNetResSE`` with an ``ASPP`` module applied to the bottleneck before decoding."""
    def __init__(self, in_ch=1, out_ch=1, widths=(32, 64, 128, 256, 512)):
        super().__init__(in_ch, out_ch, widths)
        self.aspp = ASPP(widths[-1])
        # NOTE(review): the parent constructor already builds d4 with these same
        # arguments; this second construction only re-initialises its weights.
        # Kept as-is so seeded initialisation stays what it was.
        self.d4 = Down(widths[3], widths[4])

    def forward(self, x):
        s0 = self.stem(x)
        s1 = self.d1(s0)
        s2 = self.d2(s1)
        s3 = self.d3(s2)
        bottom = self.aspp(self.d4(s3))
        y = self.u1(bottom, s3)
        y = self.u2(y, s2)
        y = self.u3(y, s1)
        y = self.u4(y, s0)
        return self.head(y)

In [20]:
def robust_stats_mad(arr):
    """Robust location/scale of ``arr``: ``(median, 1.4826 * MAD)`` as float32 scalars.

    The scale falls back to 1.0 when it is non-finite or effectively zero
    (<= 1e-12), e.g. for a constant array, so that dividing by it never blows
    the data up. The median is returned as computed (it may be NaN for NaN or
    empty input).
    """
    med = np.median(arr)
    mad = np.median(np.abs(arr - med))
    sigma = 1.4826 * mad
    # Compare the raw scale against the epsilon instead of adding the epsilon
    # first; adding it made the "sigma <= 0" fallback unreachable for MAD == 0.
    if not np.isfinite(sigma) or sigma <= 1e-12:
        sigma = 1.0
    return np.float32(med), np.float32(sigma)

class H5TiledDataset(Dataset):
    """Stream tiles from big (H,W) images, robust-normalize per-image, k-sigma clip, pad edges.

    The HDF5 file must contain ``images`` and ``masks`` datasets of identical
    shape ``(N, H, W)``. Item ``idx`` is the tile ``self.indices[idx] = (image,
    tile_row, tile_col)`` returned as two float32 tensors of shape
    ``(1, tile, tile)``: the normalised image tile and the mask tile.

    Args:
        h5_path: path to the HDF5 file.
        tile: tile edge length.
        k_sigma: normalised values are clipped to ``[-k_sigma, +k_sigma]``.
        crop_for_stats: edge of the centred crop used to estimate each image's
            median / MAD scale (limited to the image size).

    The file handle is opened lazily on first access rather than in
    ``__init__``; per-image statistics are cached on the instance.
    """
    def __init__(self, h5_path, tile=128, k_sigma=5.0, crop_for_stats=512):
        self.h5_path = h5_path
        self.tile = int(tile)
        self.k_sigma = float(k_sigma)
        self.crop_for_stats = int(crop_for_stats)
        self._h5 = self._x = self._y = None
        self._stats_cache = {}
        with h5py.File(self.h5_path, "r") as f:
            self.N, self.H, self.W = f["images"].shape
            assert f["masks"].shape == (self.N, self.H, self.W)
        Hb = math.ceil(self.H / self.tile)
        Wb = math.ceil(self.W / self.tile)
        # Image-major, then tile row, then tile column.
        self.indices = [(i, r, c) for i in range(self.N) for r in range(Hb) for c in range(Wb)]

    def _ensure(self):
        """Open the HDF5 file on first use and keep handles to both datasets."""
        if self._h5 is None:
            self._h5 = h5py.File(self.h5_path, "r")
            self._x, self._y = self._h5["images"], self._h5["masks"]

    def _image_stats(self, i):
        """Return cached ``(median, sigma)`` of image ``i``, computed on a centred crop."""
        if i in self._stats_cache:
            return self._stats_cache[i]
        self._ensure()  # safe to call directly, not only from __getitem__
        s = min(self.crop_for_stats, self.H, self.W)
        h0, w0 = (self.H - s) // 2, (self.W - s) // 2
        crop = self._x[i, h0:h0 + s, w0:w0 + s].astype("float32")
        med, sig = robust_stats_mad(crop)
        self._stats_cache[i] = (med, sig)
        return med, sig

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        self._ensure()
        i, r, c = self.indices[idx]
        t = self.tile
        r0, c0 = r * t, c * t
        r1, c1 = min(r0 + t, self.H), min(c0 + t, self.W)
        x = self._x[i, r0:r1, c0:c1].astype("float32")
        y = self._y[i, r0:r1, c0:c1].astype("float32")
        if x.shape != (t, t):
            # Edge tile: zero-pad bottom/right up to (t, t).
            # NOTE(review): padding happens before normalisation, so padded image
            # pixels end up at clip((0 - med) / sig), not at 0 — confirm intended.
            xp = np.zeros((t, t), np.float32)
            yp = np.zeros((t, t), np.float32)
            xp[:x.shape[0], :x.shape[1]] = x
            yp[:y.shape[0], :y.shape[1]] = y
            x, y = xp, yp
        med, sig = self._image_stats(i)
        # Honour the configured k_sigma (previously a hard-coded ±5).
        k = self.k_sigma
        x = np.clip((x - med) / sig, -k, k)
        return torch.from_numpy(x[None, ...]), torch.from_numpy(y[None, ...])

class SubsetDS(Dataset):
    """View of a tiled base dataset restricted to whole panels.

    For every id in ``panel_ids`` (in the given order) all of that panel's
    tiles are included, row by row. ``self.map[k]`` is the index into ``base``
    of the k-th tile of this subset; a panel id unknown to ``base`` raises
    ``KeyError`` at construction.
    """
    def __init__(self, base, panel_ids):
        self.base = base
        self.panel_ids = np.asarray(panel_ids)
        tile = base.tile
        n_rows = math.ceil(base.H / tile)
        n_cols = math.ceil(base.W / tile)

        # (panel, row, col) -> position in the base dataset
        position_of = {}
        for pos, (i, r, c) in enumerate(base.indices):
            position_of[(i, r, c)] = pos

        self.map = []
        for pid in self.panel_ids:
            for r in range(n_rows):
                for c in range(n_cols):
                    self.map.append(position_of[(pid, r, c)])

    def __len__(self):
        return len(self.map)

    def __getitem__(self, k):
        base_index = self.map[k]
        return self.base[base_index]

def tile_pos_weights(h5_path, tile=128):
    """Per-tile sampling weights from the ``masks`` dataset of an HDF5 file.

    Tiles are enumerated image-major, then tile row, then tile column. A tile
    gets weight 10.0 if its mask contains any non-zero value and 1.0 otherwise.

    Returns:
        float64 array with one weight per tile.
    """
    weights = []
    # Single open, and one read per panel: each mask is loaded once and tiled
    # in memory instead of issuing one HDF5 read per tile.
    with h5py.File(h5_path, "r") as f:
        Y = f["masks"]
        N, H, W = Y.shape
        Hb, Wb = math.ceil(H / tile), math.ceil(W / tile)
        for i in range(N):
            panel = Y[i]
            for r in range(Hb):
                for c in range(Wb):
                    r0, c0 = r * tile, c * tile
                    r1, c1 = min(r0 + tile, H), min(c0 + tile, W)
                    weights.append(1.0 + 9.0 * (panel[r0:r1, c0:c1].any()))
    return np.asarray(weights, np.float64)


def panels_with_positives(h5_path, tile=128, max_panels=None):
    """Return the sorted ids of panels whose mask has at least one non-zero pixel.

    Args:
        h5_path: HDF5 file with a ``masks`` dataset indexed by panel on axis 0.
        tile: unused; kept for call compatibility.
        max_panels: when truthy, panels are visited in a fixed pseudo-random
            order (generator seeded with 0) and the scan stops once this many
            positive panels were found. Otherwise every panel is scanned in order.

    Returns:
        Integer array of panel ids in ascending order. The dtype is an integer
        dtype even when no panel qualifies, so the result can always be used
        for indexing and set operations.
    """
    ids = []
    with h5py.File(h5_path, 'r') as f:
        Y = f['masks']
        N = Y.shape[0]
        rng = np.random.default_rng(0)
        order = rng.permutation(N) if max_panels else np.arange(N)
        for i in order:
            if Y[i].any():
                ids.append(i)
            if max_panels and len(ids) >= max_panels:
                break
    return np.array(sorted(ids), dtype=order.dtype)

# per-panel median/MAD normalize, clip like stream_panels_direct
def norm_medmad_clip(x, clip=5.0, eps=1e-6):
    """Median / MAD normalise a tensor and clip the result to ``[-clip, clip]``.

    Args:
        x: tensor of shape ``[B, C, H, W]`` or ``[1, H, W]``.
        clip: symmetric clipping bound applied after normalisation.
        eps: added to the scale ``1.4826 * MAD`` to avoid division by zero.

    For a 4-D input the median and MAD are computed separately for every
    (sample, channel) plane over all of its pixels, so each plane of a batch is
    normalised exactly like the same plane passed on its own as ``[1, H, W]``.
    For any other input a single median and MAD over the whole tensor are used.

    ``torch.median`` returns the lower of the two middle values for an even
    number of elements.
    """
    if x.ndim == 4:
        b, c = x.size(0), x.size(1)
        med = x.reshape(b, c, -1).median(dim=-1).values.view(b, c, 1, 1)
        mad = (x - med).abs().reshape(b, c, -1).median(dim=-1).values.view(b, c, 1, 1)
    else:  # [1,H,W]
        med = x.median().view(1, 1, 1)
        mad = (x - med).abs().median()
    sigma = 1.4826 * mad + eps
    z = (x - med) / sigma
    return z.clamp_(-clip, clip)

class WithTransform(torch.utils.data.Dataset):
    """Wrap a dataset of ``(x, y)`` pairs and pass every ``x`` through ``norm_medmad_clip``.

    The length and the targets ``y`` are those of the wrapped dataset.
    """
    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        return norm_medmad_clip(x), y





In [28]:
with h5py.File(DATA["train_h5"], "r") as f:
    N = f["images"].shape[0]

# 90/10 panel-level split. The shuffle draws from NumPy's global RNG, so the
# split only repeats when this cell runs at the same point after seeding;
# re-running the cell on its own produces a different split.
idx = np.arange(N); np.random.shuffle(idx)
split = int(0.9*N); idx_tr, idx_va = np.sort(idx[:split]), np.sort(idx[split:])

ds_full  = H5TiledDataset(DATA["train_h5"], tile=TILE, k_sigma=5.0)

pos_panels = panels_with_positives(DATA["train_h5"], max_panels=2000)
# Sample small train/val from the intersection with idx_tr/idx_va.
# The sample size is capped by the size of each intersection (not of pos_panels):
# choice(..., replace=False) raises if asked for more items than are available.
pos_tr = np.intersect1d(idx_tr, pos_panels)
pos_va = np.intersect1d(idx_va, pos_panels)
assert len(pos_tr) > 0 and len(pos_va) > 0, "need at least one positive panel in both train and val splits"
sub_tr = np.random.default_rng(SEED).choice(pos_tr, size=min(200, len(pos_tr)), replace=False)
sub_va = np.random.default_rng(SEED+1).choice(pos_va, size=min(80,  len(pos_va)), replace=False)

# Small loaders: positive-containing panels only, per-tile normalisation on top.
train_ds_small = WithTransform(SubsetDS(ds_full, np.sort(sub_tr)))
val_ds_small   = WithTransform(SubsetDS(ds_full, np.sort(sub_va)))

train_loader_small = DataLoader(train_ds_small, batch_size=BATCH, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))
val_loader_small   = DataLoader(val_ds_small,   batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))

# Full loaders over every panel of each split.
train_ds = WithTransform(SubsetDS(ds_full, idx_tr))
val_ds   = WithTransform(SubsetDS(ds_full, idx_va))

train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))

In [29]:
# ==== Cell 5: Build model & warmup ====
# If UNetResSEASPP is already defined in your session, this will work directly.
# Otherwise import/define it before running this cell.

# Build the ASPP U-Net on the selected device and start the head bias at logit(0.80).
probe = UNetResSEASPP(in_ch=1, out_ch=1).to(device)
init_head_bias_to_prior(probe, p0=0.80)
torch.backends.cudnn.benchmark = False

print("Warmup…")
# One epoch (at most 400 batches) on the small loader, all parameters trainable.
fit_quick_warmup(probe, train_loader_small, epochs=1, max_batches=400, lr=2e-4, metric_thr=0.20, pos_weight=40.0)

# Pick an F2-optimal pixel threshold on the small validation loader, then
# confine it to [0.10, 0.20]; only the threshold is kept from the returned tuple.
thr0, *_ = pick_thr_under_min(probe, val_loader_small, max_batches=40, n_bins=256, beta=2.0)
thr0 = float(np.clip(thr0, 0.10, 0.20))
print(f"[thr0] ≈ {thr0:.3f}")

# Head-only calibration
# NOTE(review): pick_thr_under_min leaves the model in eval() and fit_head_only
# does not set a mode itself, so this fit appears to run in eval mode — confirm
# that is intended.
fit_head_only(probe, train_loader_small, epochs=2, max_batches=600, lr=3e-5, metric_thr=thr0, pos_weight=5.0)
_ = pix_metrics(probe, train_loader_small, thr=thr0, n_batches=6)
# Independent copy of the weights at this point, kept for later reloading.
warm_state = copy.deepcopy(probe.state_dict())


Warmup…
[WARMUP] ep1 loss 0.3919 | F1 0.015 P 0.008 R 0.433
[thr0] ≈ 0.200
[HEAD] ep1 loss 0.1040 | F1 0.032 P 0.017 R 0.272
[HEAD] ep2 loss 0.0908 | F1 0.034 P 0.021 R 0.084
[quick_prob_stats] batches=6 | pos≈0.0896 | neg≈0.0350 | P 0.011 R 0.022 F1 0.015 @ thr=0.200 | 0.8s


In [30]:
# ==== Cell 6: Unfreeze tail + loss + brief probe ====
# Restore the weights stored in `warm_state`, then freeze everything except
# the head and whichever of u3 / u4 / aspp the model has.
probe.load_state_dict(warm_state, strict=True)
unfreeze_head_and_tail(probe)

criterion = SoftIoUWithBCE(pos_weight=8.0, lambda_bce=0.7).to(device)
# Only parameters that are still trainable are handed to the optimiser.
opt = torch.optim.Adam(filter(lambda p: p.requires_grad, probe.parameters()),
                       lr=3e-4, weight_decay=1e-4)

# Brief tail probe (optional)
# Trains for at most 400 batches of the small loader and reports the
# sample-weighted mean loss.
probe.train(); loss_sum=0.0; seen=0
for b,(xb,yb) in enumerate(train_loader_small,1):
    xb,yb = xb.to(device), yb.to(device)
    logits = probe(xb); yb_r = resize_masks_to(logits, yb)
    loss = criterion(logits, yb_r)
    opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
    loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
    if b>=400: break
print(f"[tail-probe] loss≈{loss_sum/max(seen,1):.4f}")

# Quick pixel metrics at the warm-up threshold (this switches the model to eval()).
_ = pix_metrics(probe, train_loader_small, thr=thr0, n_batches=6)


[tail-probe] loss≈0.3756
[quick_prob_stats] batches=6 | pos≈0.0436 | neg≈0.0169 | P 0.000 R 0.000 F1 0.000 @ thr=0.200 | 0.8s


In [None]:
# ==== Cell 7: Full training (3–6 epochs) with early-stop + threshold repick ====

def val_pick_thr(model, vloader, min_pr=0.05, max_pr=0.10):
    """Re-pick the pixel threshold on a validation loader (F1, up to 60 batches).

    Forwards to ``pick_thr_with_floor`` with ``min_pr`` / ``max_pr`` as the
    positive-rate band and returns ``(thr, (P, R, fbeta), aux)`` with ``thr``
    converted to a Python float.
    """
    # Local is called `fbeta`, not `F`, so it never shadows torch.nn.functional.
    thr, (P, R, fbeta), aux = pick_thr_with_floor(model, vloader, max_batches=60, n_bins=256, beta=1.0,
                                                  min_pos_rate=min_pr, max_pos_rate=max_pr)
    return float(thr), (P, R, fbeta), aux

def one_epoch(model, loader, criterion, opt):
    """Train ``model`` for one full pass over ``loader``.

    Batches are moved to the device the model's parameters live on. A
    carriage-return progress line with the running mean loss is printed after
    every batch.

    Returns:
        Sample-weighted mean loss over the epoch (0.0 if the loader is empty).
    """
    t0 = time.time()
    # Follow the model's own device rather than a notebook-level global, like
    # the other training/evaluation helpers do.
    dev = next(model.parameters()).device
    model.train()
    loss_sum = 0.0
    seen = 0
    for b, (xb, yb) in enumerate(loader, 1):
        xb, yb = xb.to(dev), yb.to(dev)
        logits = model(xb)
        yb_r = resize_masks_to(logits, yb)
        loss = criterion(logits, yb_r)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        loss_sum += float(loss.item()) * xb.size(0)
        seen += xb.size(0)
        print(f"\r[TRAIN] batch {b}/{len(loader)} | loss={loss_sum/max(seen,1):.4f} | {time.time()-t0:.1f}s", end='')
    print (f"\r")
    return loss_sum/max(seen,1)

def pix_eval(model, loader, thr):
    """Run ``pix_metrics`` on six batches and return ``(P, R, F1)`` as a tuple."""
    stats = pix_metrics(model, loader, thr=thr, n_batches=6)
    return tuple(stats[key] for key in ("P", "R", "F1"))

# --- train ---
best_F = -1
best = {"state": None, "thr": None, "ep": 0}
epochs = 4  # adjust 3–6
metric_thr = thr0

for ep in range(1, epochs+1):
    # optional small LR bump if stall
    if ep == 3:
        for g in opt.param_groups: g["lr"] = 4e-4

    ep_loss = one_epoch(probe, train_loader, criterion, opt)
    P_tr, R_tr, F_tr = pix_eval(probe, train_loader_small, thr=metric_thr)
    print(f"[EP{ep:02d}] loss {ep_loss:.4f} | train P {P_tr:.3f} R {R_tr:.3f} F {F_tr:.3f}")

    # repick thr on val (pos-rate band recall→precision balance)
    metric_thr, (VP,VR,VF), aux = val_pick_thr(probe, val_loader_small, min_pr=0.05, max_pr=0.10)
    print(f"[thr@ep{ep}] thr={metric_thr:.3f} | val P {VP:.3f} R {VR:.3f} F {VF:.3f} | pos_rate≈{aux['pos_rate']:.3f}")

    # keep the best validation checkpoint (all epochs always run; nothing stops early)
    if VF > best_F:
        best_F = VF
        best["state"] = copy.deepcopy(probe.state_dict())
        best["thr"] = metric_thr
        best["ep"] = ep
        print(f"[VAL ep{ep}] P {VP:.3f} R {VR:.3f} F {VF:.3f}  ↳ saved best (F={best_F:.3f})")

# Print a summary only: `best["state"]` holds the full state_dict, and printing
# it would dump every weight tensor into the cell output.
print("Best so far:", {k: v for k, v in best.items() if k != "state"}, f"| F={best_F:.3f}")


[TRAIN] batch 744/11520 | loss=0.3783 | 88.7s

In [None]:
# ==== Cell 8: Load best, final eval, save ====
# Restore the best checkpoint recorded during training, if any; otherwise the
# current weights and the current `metric_thr` are used as they are.
if best["state"] is not None:
    probe.load_state_dict(best["state"], strict=True)
    metric_thr = best["thr"]

# Validation metrics at the chosen threshold, measured on 12 batches of the small val loader.
val_stats = pix_metrics(probe, val_loader_small, thr=metric_thr, n_batches=12)
print(f"[FINAL] thr={metric_thr:.3f} | val: P {val_stats['P']:.3f} R {val_stats['R']:.3f} F {val_stats['F1']:.3f}")

# quick train sanity
_ = pix_metrics(probe, train_loader_small, thr=metric_thr, n_batches=6)

# Save weights & threshold
# The checkpoint bundles the state_dict with the threshold, best epoch and seed;
# the file name carries the best epoch (00 when no epoch was recorded).
os.makedirs("checkpoints", exist_ok=True)
ckpt_path = f"checkpoints/probe_softiou_bce_best_ep{best['ep']:02d}.pt"
torch.save({"model": probe.state_dict(), "thr": metric_thr, "epoch": best["ep"], "seed": SEED}, ckpt_path)

# The threshold is also written on its own as JSON.
with open("checkpoints/probe_threshold.json","w") as f:
    json.dump({"metric_thr": float(metric_thr)}, f, indent=2)

print("Saved:", ckpt_path)
