In [1]:
# --- config ---
DATA = {
    "train_h5": "/home/karlo/train_chunked.h5",
    "test_h5":  "../DATA/test.h5",
    "train_csv": "../DATA/train.csv",   # or None if you don’t have it
    "test_csv":  "../DATA/test.csv",
}
TILE         = 128
BATCH        = 128
NUM_WORKERS  = 2          # HDF5 is happier with 0–2
SEED         = 1337
EPOCHS       = 20
MAX_LR       = 1.5e-4
SAVE_BEST    = "./best_unet_resse_aspp.pt"
SAVE_LAST    = "./last_unet_resse.pt"

In [2]:
import os, gc, time, math, random, h5py
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

def set_seed(s=1337):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
set_seed(SEED)

def resize_masks_to(logits, masks):
    """
    Resize binary masks to the spatial size of logits.
    - nearest neighbor to avoid soft edges
    - keep masks in {0,1}
    """
    H, W = logits.shape[-2:]
    if masks.dtype != torch.float32:
        masks = masks.float()
    if masks.dim() == 3:  # (B,H,W) -> (B,1,H,W)
        masks = masks.unsqueeze(1)
    out = torch.nn.functional.interpolate(masks, size=(H, W), mode='nearest')
    # keep it strictly 0/1 after nearest
    return (out > 0.5).float()


def _free_cuda():
    if device.type == 'cuda':
        torch.cuda.synchronize()
        torch.cuda.empty_cache()


In [3]:
def robust_stats_mad(arr):
    med = np.median(arr); mad = np.median(np.abs(arr - med))
    sigma = 1.4826 * (mad + 1e-12)
    return np.float32(med), np.float32(1.0 if not np.isfinite(sigma) or sigma<=0 else sigma)

class H5TiledDataset(Dataset):
    """Stream tiles from big (H,W) images, robust-normalize per-image, k-sigma clip, pad edges."""
    def __init__(self, h5_path, tile=128, k_sigma=5.0, crop_for_stats=512):
        self.h5_path, self.tile, self.k_sigma, self.crop_for_stats = h5_path, int(tile), float(k_sigma), int(crop_for_stats)
        self._h5 = self._x = self._y = None
        self._stats_cache = {}
        with h5py.File(self.h5_path, "r") as f:
            self.N, self.H, self.W = f["images"].shape
            assert f["masks"].shape == (self.N, self.H, self.W)
        Hb = math.ceil(self.H/self.tile); Wb = math.ceil(self.W/self.tile)
        self.indices = [(i, r, c) for i in range(self.N) for r in range(Hb) for c in range(Wb)]
    def _ensure(self):
        if self._h5 is None:
            self._h5 = h5py.File(self.h5_path, "r")
            self._x, self._y = self._h5["images"], self._h5["masks"]
    def _image_stats(self, i):
        if i in self._stats_cache: return self._stats_cache[i]
        s = min(self.crop_for_stats, self.H, self.W)
        h0, w0 = (self.H-s)//2, (self.W-s)//2
        crop = self._x[i, h0:h0+s, w0:w0+s].astype("float32")
        med, sig = robust_stats_mad(crop); self._stats_cache[i] = (med, sig); return med, sig
    def __len__(self): return len(self.indices)
    def __getitem__(self, idx):
        self._ensure()
        i, r, c = self.indices[idx]; t = self.tile
        r0, c0 = r*t, c*t; r1, c1 = min(r0+t, self.H), min(c0+t, self.W)
        x = self._x[i, r0:r1, c0:c1].astype("float32"); y = self._y[i, r0:r1, c0:c1].astype("float32")
        if x.shape != (t,t):
            xp = np.zeros((t,t), np.float32); yp = np.zeros((t,t), np.float32)
            xp[:x.shape[0], :x.shape[1]] = x; yp[:y.shape[0], :y.shape[1]] = y; x, y = xp, yp
        med, sig = self._image_stats(i); x = np.clip((x-med)/sig, -5, 5)
        return torch.from_numpy(x[None,...]), torch.from_numpy(y[None,...])

class SubsetDS(Dataset):
    """Select full panels by id while reusing tiling of base dataset."""
    def __init__(self, base, panel_ids):
        self.base, self.panel_ids = base, np.asarray(panel_ids)
        t = base.tile; Hb, Wb = math.ceil(base.H/t), math.ceil(base.W/t)
        base_map = {(i,r,c):k for k,(i,r,c) in enumerate(base.indices)}
        self.map = [base_map[(i,r,c)] for i in self.panel_ids for r in range(Hb) for c in range(Wb)]
    def __len__(self): return len(self.map)
    def __getitem__(self, k): return self.base[self.map[k]]

def tile_pos_weights(h5_path, tile=128):
    with h5py.File(h5_path, "r") as f:
        Y = f["masks"]; N,H,W = Y.shape
    Hb, Wb = math.ceil(H/tile), math.ceil(W/tile)
    w = []
    with h5py.File(h5_path,"r") as f:
        Y = f["masks"]
        for i in range(N):
            for r in range(Hb):
                for c in range(Wb):
                    r0,c0=r*tile,c*tile; r1,c1=min(r0+tile,H),min(c0+tile,W)
                    w.append(1.0 + 9.0*(Y[i,r0:r1,c0:c1].any()))
    return np.asarray(w, np.float64)


In [4]:
class SEBlock(nn.Module):
    def __init__(self,c,r=8): super().__init__(); self.fc1=nn.Conv2d(c,c//r,1); self.fc2=nn.Conv2d(c//r,c,1)
    def forward(self,x): s=F.adaptive_avg_pool2d(x,1); s=F.silu(self.fc1(s),inplace=True); s=torch.sigmoid(self.fc2(s)); return x*s
def _norm(c, groups=8): g=min(groups,c) if c%groups==0 else 1; return nn.GroupNorm(g,c)

class ResBlock(nn.Module):
    def __init__(self,c_in,c_out): super().__init__(); p=1
    def __init__(self,c_in,c_out,k=3,act=nn.SiLU,se=True):
        super().__init__(); p=k//2
        self.proj = nn.Identity() if c_in==c_out else nn.Conv2d(c_in,c_out,1)
        self.bn1=_norm(c_in); self.c1=nn.Conv2d(c_in,c_out,k,padding=p,bias=False)
        self.bn2=_norm(c_out); self.c2=nn.Conv2d(c_out,c_out,k,padding=p,bias=False)
        self.act=act(); self.se=SEBlock(c_out) if se else nn.Identity()
    def forward(self,x):
        h=self.act(self.bn1(x)); h=self.c1(h)
        h=self.act(self.bn2(h)); h=self.c2(h)
        h=self.se(h); return h + self.proj(x)

class Down(nn.Module):
    def __init__(self,c_in,c_out): super().__init__(); self.pool=nn.MaxPool2d(2); self.rb=ResBlock(c_in,c_out)
    def forward(self,x): return self.rb(self.pool(x))

class Up(nn.Module):
    def __init__(self,c_in,c_skip,c_out): super().__init__(); self.up=nn.ConvTranspose2d(c_in,c_in,2,stride=2); self.rb1=ResBlock(c_in+c_skip,c_out); self.rb2=ResBlock(c_out,c_out)
    def forward(self,x,skip):
        x=self.up(x)
        dh=skip.size(-2)-x.size(-2); dw=skip.size(-1)-x.size(-1)
        if dh or dw: x=F.pad(x,(0,max(0,dw),0,max(0,dh)))
        x=torch.cat([x,skip],1); x=self.rb1(x); x=self.rb2(x); return x

class ASPP(nn.Module):
    def __init__(self,c,r=[1,6,12,18]):
        super().__init__()
        self.blocks=nn.ModuleList([nn.Sequential(nn.Conv2d(c,c//4,3,padding=d,dilation=d,bias=False), nn.BatchNorm2d(c//4), nn.SiLU(True)) for d in r])
        self.project=nn.Conv2d(c,c,1)
    def forward(self,x): return self.project(torch.cat([b(x) for b in self.blocks],1))

class UNetResSE(nn.Module):
    def __init__(self,in_ch=1,out_ch=1,widths=(32,64,128,256,512)):
        super().__init__(); w=widths
        self.stem=nn.Sequential(nn.Conv2d(in_ch,w[0],3,padding=1,bias=False), nn.BatchNorm2d(w[0]), nn.SiLU(True), ResBlock(w[0],w[0]))
        self.d1=Down(w[0],w[1]); self.d2=Down(w[1],w[2]); self.d3=Down(w[2],w[3]); self.d4=Down(w[3],w[4])
        self.u1=Up(w[4],w[3],w[3]); self.u2=Up(w[3],w[2],w[2]); self.u3=Up(w[2],w[1],w[1]); self.u4=Up(w[1],w[0],w[0])
        self.head=nn.Conv2d(w[0],out_ch,1)
    def forward(self,x):
        s0=self.stem(x); s1=self.d1(s0); s2=self.d2(s1); s3=self.d3(s2); b=self.d4(s3)
        x=self.u1(b,s3); x=self.u2(x,s2); x=self.u3(x,s1); x=self.u4(x,s0); return self.head(x) # logits

class UNetResSEASPP(UNetResSE):
    def __init__(self,in_ch=1,out_ch=1,widths=(32,64,128,256,512)):
        super().__init__(in_ch,out_ch,widths); self.aspp=ASPP(widths[-1]); self.d4=Down(widths[3],widths[4])
    def forward(self,x):
        s0=self.stem(x); s1=self.d1(s0); s2=self.d2(s1); s3=self.d3(s2); b=self.d4(s3); b=self.aspp(b)
        x=self.u1(b,s3); x=self.u2(x,s2); x=self.u3(x,s1); x=self.u4(x,s0); return self.head(x)

class DiceLoss(nn.Module):
    def __init__(self,eps=1e-6): super().__init__(); self.eps=eps
    def forward(self,logits,targets):
        p=torch.sigmoid(logits); p=p.view(p.size(0),-1); t=targets.view(targets.size(0),-1)
        inter=(p*t).sum(1); denom=p.sum(1)+t.sum(1); dice=(2*inter+self.eps)/(denom+self.eps); return 1-dice.mean()

class AdaptiveComboLoss(nn.Module):
    """Recall/precision-adaptive blend (FT + BCE [+ Dice])."""
    def __init__(self, w_ft=0.6, w_bce=0.4, w_dice=0.0, alpha=0.3, beta=0.7, gamma=0.75, ema=0.9, clamp=1e-6, use_dice=False, bce_pos_weight=None):
        super().__init__(); self.w_ft0, self.w_bce0, self.w_dice0 = w_ft,w_bce,w_dice
        self.alpha0,self.beta0,self.gamma = alpha,beta,gamma; self.ema=ema; self.clamp=clamp; self.use_dice=use_dice
        self.register_buffer('ema_prec', torch.tensor(0.5)); self.register_buffer('ema_rec', torch.tensor(0.5))
        self.bce = nn.BCEWithLogitsLoss(pos_weight=bce_pos_weight) if bce_pos_weight is not None else nn.BCEWithLogitsLoss()
    @staticmethod
    def _dice_prob(p,t,eps): p=p.view(p.size(0),-1); t=t.view(t.size(0),-1); inter=(p*t).sum(1); return (2*inter+eps)/(p.sum(1)+t.sum(1)+eps)
    def forward(self,logits,targets):
        p=torch.sigmoid(logits).clamp(self.clamp,1-self.clamp); t=targets.clamp(0,1)
        with torch.no_grad():
            ph=(p>=0.5).float(); tp=(ph*t).sum().item(); fp=(ph*(1-t)).sum().item(); fn=((1-ph)*t).sum().item()
            prec=tp/(tp+fp+1e-8); rec=tp/(tp+fn+1e-8)
            self.ema_prec = self.ema*self.ema_prec + (1-self.ema)*p.new_tensor(prec)
            self.ema_rec  = self.ema*self.ema_rec  + (1-self.ema)*p.new_tensor(rec)
        pr_gap=(self.ema_prec-self.ema_rec).clamp(-0.5,0.5)   # >0 means precision>recall
        alpha=(self.alpha0+0.3*pr_gap).clamp(0.05,0.95); beta=(self.beta0-0.3*pr_gap).clamp(0.05,0.95)
        pv=p.view(p.size(0),-1); tv=t.view(t.size(0),-1)
        TP=(pv*tv).sum(1); FP=((1-tv)*pv).sum(1); FN=(tv*(1-pv)).sum(1)
        tversky=(TP+self.clamp)/(TP+alpha*FP+beta*FN+self.clamp)
        ft_loss=torch.pow(1.0-tversky, self.gamma).mean()
        bce_loss=self.bce(logits,t)
        dice_loss=(1.0-self._dice_prob(p,t,self.clamp).mean()) if self.use_dice else logits.new_tensor(0.0)
        ft_w=(self.w_ft0 + 0.4*(-pr_gap)).clamp(0.2,0.8); bce_w=(self.w_bce0 + 0.4*(pr_gap)).clamp(0.2,0.8)
        return ft_w*ft_loss + bce_w*bce_loss + self.w_dice0*dice_loss

def estimate_pos_weight(h5_path, max_panels=64, tile=TILE):
    with h5py.File(h5_path,"r") as f:
        Y=f["masks"]; N,H,W=Y.shape
        idx=np.random.choice(N, min(N,max_panels), replace=False)
        pos=tot=0
        for i in idx:
            Hb,Wb=math.ceil(H/tile),math.ceil(W/tile)
            y=Y[i]
            for r in range(Hb):
                for c in range(Wb):
                    r0,c0=r*tile,c*tile; r1,c1=min(r0+tile,H),min(c0+tile,W)
                    patch=y[r0:r1,c0:c1]; pos+=patch.sum(); tot+=patch.size
    p=max(1e-8, pos/max(1,tot))
    return torch.tensor((1-p)/p, device=device)


In [5]:
# -------- Lovasz hinge (binary) --------
# https://arxiv.org/abs/1705.08790 (adapted, compact)
def lovasz_grad(gt_sorted):
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1.0 - intersection / torch.clamp(union, min=1e-12)
    if p > 1:
        jaccard[1:p] = jaccard[1:p] - jaccard[0:p-1]
    return jaccard

def lovasz_hinge_flat(logits, labels):
    # logits: (P,), labels: (P,) in {0,1}
    if logits.numel() == 0:
        return logits*0.0
    signs = 2.0 * labels.float() - 1.0
    errors = (1.0 - logits * signs)  # margin errors
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.clamp(errors_sorted, min=0).dot(grad)
    return loss

def lovasz_hinge(logits, targets):
    # logits/targets: (B,1,H,W)
    losses = []
    B = logits.size(0)
    for b in range(B):
        l = logits[b].view(-1)
        t = targets[b].view(-1).float()
        losses.append(lovasz_hinge_flat(l, t))
    return torch.stack(losses).mean()

# -------- Asymmetric Focal Tversky --------
class AsymFocalTversky(nn.Module):
    def __init__(self, alpha=0.2, beta=0.8, gamma=0.75, eps=1e-6):
        super().__init__()
        self.alpha, self.beta, self.gamma, self.eps = alpha, beta, gamma, eps
    def forward(self, logits, targets):
        p = torch.sigmoid(logits).clamp(self.eps, 1 - self.eps)
        t = targets.clamp(0,1)
        # flatten per-example
        p = p.view(p.size(0), -1)
        t = t.view(t.size(0), -1)
        TP = (p * t).sum(1)
        FP = ((1 - t) * p).sum(1)
        FN = (t * (1 - p)).sum(1)
        tversky = (TP + self.eps) / (TP + self.alpha * FP + self.beta * FN + self.eps)
        loss = torch.pow(1.0 - tversky, self.gamma)
        return loss.mean()

# -------- Hard-negative-mining BCE --------
class OHEMBCE(nn.Module):
    def __init__(self, neg_percent=0.1, pos_weight=None, reduction='mean'):
        super().__init__()
        self.neg_percent = neg_percent
        self.pos_weight = pos_weight  # tensor or None
        self.reduction = reduction
    def forward(self, logits, targets):
        # per-pixel BCE (no reduction)
        if self.pos_weight is not None:
            bce = F.binary_cross_entropy_with_logits(
                logits, targets, pos_weight=self.pos_weight, reduction='none'
            )
        else:
            bce = F.binary_cross_entropy_with_logits(
                logits, targets, reduction='none'
            )
        # Separate pos/neg pixels
        pos_mask = (targets > 0.5)
        neg_mask = ~pos_mask

        pos_loss = bce[pos_mask]
        neg_loss = bce[neg_mask]

        if neg_loss.numel() > 0 and self.neg_percent > 0:
            k = max(1, int(self.neg_percent * neg_loss.numel()))
            hard_neg, _ = torch.topk(neg_loss.view(-1), k, sorted=False)
            loss = torch.cat([pos_loss.view(-1), hard_neg], 0)
        else:
            loss = bce.view(-1)

        return loss.mean() if self.reduction == 'mean' else loss.sum()

# -------- Composite loss --------
class StreakSegLoss(nn.Module):
    """
    total = w_aftl * Asymmetric Focal Tversky
          + w_lovasz * Lovasz Hinge (IoU surrogate on logits)
          + w_ohem  * OHEM-BCE (all positives + hardest q% negatives)
    """
    def __init__(self,
                 w_aftl=0.5, w_lovasz=0.3, w_ohem=0.2,
                 aftl_alpha=0.2, aftl_beta=0.8, aftl_gamma=0.75,
                 ohem_neg_percent=0.1,
                 pos_weight=None):
        super().__init__()
        self.w_aftl   = w_aftl
        self.w_lovasz = w_lovasz
        self.w_ohem   = w_ohem
        self.aftl = AsymFocalTversky(alpha=aftl_alpha, beta=aftl_beta, gamma=aftl_gamma)
        self.ohem = OHEMBCE(neg_percent=ohem_neg_percent, pos_weight=pos_weight)

    def forward(self, logits, targets):
        la = self.aftl(logits, targets)
        ll = lovasz_hinge(logits, (targets>0.5).float())
        lo = self.ohem(logits, targets)
        return self.w_aftl * la + self.w_lovasz * ll + self.w_ohem * lo

In [6]:
from contextlib import nullcontext
scaler = torch.amp.GradScaler('cuda', enabled=(device.type=='cuda'))

def _auc_from_hists(pos_hist, neg_hist):
    tp=np.cumsum(pos_hist[::-1]); fp=np.cumsum(neg_hist[::-1])
    if tp[-1]==0 or fp[-1]==0: return float('nan')
    tpr=tp/tp[-1]; fpr=fp/fp[-1]
    tpr=np.concatenate(([0],tpr)); fpr=np.concatenate(([0],fpr))
    return np.trapz(tpr,fpr)

def run_epoch(loader, model, criterion, optimizer=None, train=True, tag="Train", print_every=10, n_bins=512, grad_clip=1.0):
    model.train(train); start=time.time()
    total=len(loader.dataset); seen=0; loss_sum=0.0; tp=fp=fn=0.0
    pos_hist=np.zeros(n_bins,np.float64); neg_hist=np.zeros(n_bins,np.float64)
    amp_ctx=torch.amp.autocast('cuda', enabled=(device.type=='cuda'))
    for b,(xb,yb) in enumerate(loader,1):
        xb,yb = xb.to(device,non_blocking=True), yb.to(device,non_blocking=True)
        ctx = nullcontext() if train else torch.inference_mode()
        with ctx:
            with amp_ctx:
                logits = model(xb)
                yb_r = resize_masks_to(logits, yb)
                loss  = criterion(logits, yb_r)
            if train:
                optimizer.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                if grad_clip is not None:
                    scaler.unscale_(optimizer); nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(optimizer); scaler.update()
        bs=xb.size(0); seen+=bs; loss_sum += float(loss.item())*bs
        with torch.no_grad():
            p=torch.sigmoid(logits).detach().view(-1).cpu(); t=yb_r.detach().view(-1).cpu()
            pb=(p>=0.5).float(); tp+=float((pb*t).sum()); fp+=float((pb*(1-t)).sum()); fn+=float(((1-pb)*t).sum())
            idx=torch.clamp((p*(n_bins-1)).long(),0,n_bins-1)
            pos_hist += np.bincount(idx[t>0.5].numpy(), minlength=n_bins)
            neg_hist += np.bincount(idx[t<=0.5].numpy(), minlength=n_bins)
        if ((b%print_every==0) or (seen==total)) and train:
            P=tp/max(tp+fp,1); R=tp/max(tp+fn,1); F1=2*P*R/max(P+R,1e-8)
            print(f"\r[{tag}] batch {b}/{len(loader)} | {seen}/{total} ex | loss={loss_sum/seen:.4f} | F1 {F1:.4f} | P {P:.4f} | R {R:.4f} | {time.time()-start:.1f}s", end='', flush=True)
    if train: print()
    auc=_auc_from_hists(pos_hist,neg_hist); P=tp/max(tp+fp,1); R=tp/max(tp+fn,1); F1=2*P*R/max(P+R,1e-8)
    return (loss_sum/total), auc, P, R, F1

def rebalance_loss_blend(loss_obj, ep, total):
    t = min(1.0, max(0.0, ep/float(max(total-1,1))))
    loss_obj.w_ft0 = 0.8 - 0.2*t
    loss_obj.w_bce0 = 0.2 + 0.2*t

def fit(model, train_loader, val_loader, criterion, epochs=EPOCHS, max_lr=MAX_LR, save_path=SAVE_BEST, early_stop_patience=10):
    model=model.to(device)
    opt=torch.optim.Adam(model.parameters(), lr=max_lr, weight_decay=1e-5)
    sched=torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=max_lr, steps_per_epoch=len(train_loader), epochs=epochs, pct_start=0.1, anneal_strategy='cos')
    best_f1=-1; no_improve=0
    for ep in range(1,epochs+1):
        rebalance_loss_blend(criterion, ep-1, epochs)
        t0=time.time()
        trL,_,tP,tR,tF1 = run_epoch(train_loader, model, criterion, opt, train=True, tag="Train")
        _free_cuda()
        vaL,vaA,vaP,vaR,vaF1 = run_epoch(val_loader,   model, criterion, None, train=False, tag="Val")
        _free_cuda()
        sched.step()
        print(f"Epoch {ep:03d} | Train L {trL:.4f} F1 {tF1:.4f} P {tP:.4f} R {tR:.4f} || Val L {vaL:.4f} AUC {vaA:.4f} F1 {vaF1:.4f} P {vaP:.4f} R {vaR:.4f} | {time.time()-t0:.1f}s")
        torch.save({"state_dict": model.state_dict()}, SAVE_LAST)
        if vaF1 > best_f1 + 1e-4:
            best_f1, no_improve = vaF1, 0
            torch.save({"state_dict": model.state_dict()}, save_path)
        else:
            no_improve += 1
            if no_improve >= early_stop_patience:
                print("Early stopping."); break
    return model


In [7]:
with h5py.File(DATA["train_h5"], "r") as f:
    N = f["images"].shape[0]
idx = np.arange(N); np.random.shuffle(idx)
split = int(0.9*N); idx_tr, idx_va = np.sort(idx[:split]), np.sort(idx[split:])

ds_full  = H5TiledDataset(DATA["train_h5"], tile=TILE, k_sigma=5.0)
train_ds = SubsetDS(ds_full, idx_tr)
val_ds   = SubsetDS(ds_full, idx_va)
test_ds  = H5TiledDataset(DATA["test_h5"], tile=TILE, k_sigma=5.0)

# Weighted sampler (positives 10x)
from torch.utils.data import DataLoader, WeightedRandomSampler
w_per_tile  = tile_pos_weights(DATA["train_h5"], tile=TILE)
w_for_train = torch.tensor([w_per_tile[k] for k in train_ds.map], dtype=torch.double)
sampler     = WeightedRandomSampler(w_for_train, num_samples=len(train_ds), replacement=True)

train_loader = DataLoader(train_ds, batch_size=BATCH, sampler=sampler, num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))
val_loader   = DataLoader(val_ds,   batch_size=max(32,BATCH//2), shuffle=False, num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))
test_loader  = DataLoader(test_ds,  batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, pin_memory=(device.type=='cuda'))


In [8]:
from scipy.ndimage import label as cc_label, find_objects, binary_opening
from contextlib import contextmanager

@contextmanager
def cuda_oom_guard():
    try: yield
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            torch.cuda.empty_cache(); raise MemoryError("cuda_oom")
        raise

@torch.no_grad()
def stream_panels_direct(h5_path, model, tile=TILE, start_batch_tiles=64):
    model.eval(); dev=next(model.parameters()).device
    with h5py.File(h5_path,'r') as f:
        X=f['images']; N,H,W=X.shape
    Hb,Wb=math.ceil(H/tile), math.ceil(W/tile); canvas=np.zeros((Hb*tile,Wb*tile), np.float32)
    bt=start_batch_tiles
    for pid in range(N):
        with h5py.File(h5_path,'r') as f:
            img=f['images'][pid].astype(np.float32)
        med=np.median(img); mad=np.median(np.abs(img-med))+1e-12; sigma=1.4826*mad
        img=np.clip((img-med)/max(sigma,1e-6), -5, 5)
        coords=[(r*tile,c*tile) for r in range(Hb) for c in range(Wb)]; canvas.fill(0.0)
        k=0
        while k<len(coords):
            tried=False
            while not tried:
                tried=True; this_bt=min(bt, len(coords)-k)
                patches=[]
                for j in range(this_bt):
                    r0,c0=coords[k+j]; r1,c1=min(r0+tile,H),min(c0+tile,W)
                    patch=np.zeros((tile,tile),np.float32); patch[:r1-r0,:c1-c0]=img[r0:r1,c0:c1]
                    patches.append(patch[None,None,...])
                xb=torch.from_numpy(np.concatenate(patches,0)).to(dev, non_blocking=True)
                try:
                    with cuda_oom_guard(), torch.amp.autocast('cuda', enabled=(dev.type=='cuda')):
                        probs=torch.sigmoid(model(xb)).float().cpu().numpy()[:,0]
                except MemoryError:
                    bt=max(1, bt//2); tried=False; del xb; torch.cuda.empty_cache(); continue
                for p,(r0,c0) in zip(probs, coords[k:k+this_bt]): canvas[r0:r0+tile, c0:c0+tile] = p
                k += this_bt; del xb, probs; torch.cuda.empty_cache()
                if bt<64: bt=min(64, bt*2)
        yield pid, canvas[:H,:W].copy()

def postprocess(bin_img, open_iters=1): return binary_opening(bin_img, iterations=open_iters).astype(np.uint8) if open_iters>0 else bin_img

def component_PR_one(pred_bin, gt, iou_thr=0.1, min_pix=120, min_elong=2.5):
    TP=FP=0
    L,_=cc_label(pred_bin, structure=np.ones((3,3),np.uint8))
    for k,sl in enumerate(find_objects(L) or [], start=1):
        comp=(L[sl]==k); area=int(comp.sum())
        if area < min_pix: continue
        h=sl[0].stop-sl[0].start; w=sl[1].stop-sl[1].start
        elong=max(h,w)/max(1,min(h,w))
        if elong < min_elong: continue
        gtl=gt[sl].astype(bool); inter=(comp & gtl).sum(); uni=(comp | gtl).sum()
        if uni and inter/uni >= iou_thr: TP+=1
        else: FP+=1
    _,GT=cc_label(gt, structure=np.ones((3,3),np.uint8))
    return TP,FP,int(GT)


In [8]:
model = UNetResSEASPP(in_ch=1, out_ch=1)
#criterion = AdaptiveComboLoss(w_ft=0.6, w_bce=0.2, w_dice=0.2, alpha=0.25, beta=0.75, gamma=0.75, use_dice=True).to(device)
#criterion.bce = nn.BCEWithLogitsLoss(pos_weight=estimate_pos_weight(DATA["train_h5"], max_panels=64, tile=TILE))
# build pos_weight on the correct device (optional but often helpful)
bce_pw = estimate_pos_weight(DATA["train_h5"], max_panels=64, tile=TILE).to(device)

criterion = StreakSegLoss(
    w_aftl=0.5, w_lovasz=0.3, w_ohem=0.2,
    aftl_alpha=0.2, aftl_beta=0.8, aftl_gamma=0.75,
    ohem_neg_percent=0.10,
    pos_weight=bce_pw
)
_ = fit(model, train_loader, val_loader, criterion, epochs=EPOCHS, max_lr=MAX_LR, save_path=SAVE_BEST)
torch.save({"state_dict": model.state_dict()}, "./last_unet_resse.pt")

[Train] batch 2800/5760 | 358400/737280 ex | loss=3.0523 | F1 0.0478 | P 0.0245 | R 0.9931 | 719.9s

KeyboardInterrupt: 

In [9]:
ths = np.linspace(0.15, 0.95, 33)
val_counts = {t:{'TP':0,'FP':0,'GT':0} for t in ths}
val_set = set(map(int, idx_va))

start = time.time(); processed=used=0
with h5py.File(DATA["train_h5"], 'r') as f: N_total = len(f['masks'])

def _best(counts):
    best=(-1.0, None)
    for t in ths:
        TP,FP,GT = counts[t]['TP'], counts[t]['FP'], counts[t]['GT']
        FN = max(GT-TP,0); P = TP/max(TP+FP,1); R = TP/max(TP+FN,1); F1 = 2*P*R/max(P+R,1e-8)
        if F1>best[0]: best=(F1,t)
    return best

for pid, probs in stream_panels_direct(DATA["train_h5"], model, tile=TILE, start_batch_tiles=64):
    processed += 1
    if pid in val_set:
        used += 1
        with h5py.File(DATA["train_h5"],'r') as f:
            gt = f['masks'][pid][:].astype(np.uint8)
        for t in ths:
            pred_bin = (probs>=t).astype(np.uint8)
            TP,FP,GT = component_PR_one(pred_bin, gt, iou_thr=0.1, min_pix=30, min_elong=1.8)
            d=val_counts[t]; d['TP']+=TP; d['FP']+=FP; d['GT']+=GT
    if processed%5==0 or processed==N_total:
        elapsed=time.time()-start; rate=processed/max(elapsed,1e-6); eta=(N_total-processed)/max(rate,1e-6)
        curF1,curT = _best(val_counts)
        print(f"\rPanels {processed}/{N_total} | used {used} | best F1≈{curF1:.3f} @ thr≈{curT if curT else float('nan'):.3f} | {elapsed/60:.1f}m | ETA {eta/60:.1f}m", end='', flush=True)
print()

best_f1=-1; best_thr=None; best_stats=None
for t in ths:
    TP,FP,GT = val_counts[t]['TP'], val_counts[t]['FP'], val_counts[t]['GT']
    FN=GT-TP; P=TP/max(TP+FP,1); R=TP/max(TP+FN,1); F1=2*P*R/max(P+R,1e-8)
    if F1>best_f1: best_f1, best_thr, best_stats = F1, float(t), (P,R,TP,FP,GT)
print(f"[VAL] best thr={best_thr:.3f} F1={best_f1:.3f} P={best_stats[0]:.3f} R={best_stats[1]:.3f} | TP={best_stats[2]} FP={best_stats[3]} GT={best_stats[4]}")


NameError: name 'model' is not defined

In [10]:
# PR on TEST at best_thr
TP=FP=GT=0
with h5py.File(DATA["test_h5"], 'r') as f:
    N_test = f['masks'].shape[0]
for pid, probs in stream_panels_direct(DATA["test_h5"], model, tile=TILE):
    if pid >= N_test: break
    with h5py.File(DATA["test_h5"], 'r') as f:
        gt = f['masks'][pid][:].astype(np.uint8)
    pred_bin = postprocess((probs>=best_thr).astype(np.uint8), open_iters=1)
    tpi,fpi,gti = component_PR_one(pred_bin, gt, iou_thr=0.1, min_pix=120, min_elong=2.5)
    TP += tpi; FP += fpi; GT += gti
FN = GT-TP; P = TP/max(TP+FP,1); R = TP/max(TP+FN,1); F1 = 2*P*R/max(P+R,1e-8)
print(f"[TEST] thr={best_thr:.3f} P={P:.3f} R={R:.3f} F1={F1:.3f} | TP={TP} FP={FP} GT={GT}")

# Histograms (stream + mark detections against CSV)
import pandas as pd, matplotlib.pyplot as plt

def stream_mark_nn_and_stack(csv_path, h5_path, model, thr, radius=3):
    cat = pd.read_csv(csv_path).copy()
    if "stack_detection" in cat.columns: cat["stack_detected"] = cat["stack_detection"].astype(bool)
    elif "stack_mag" in cat.columns:     cat["stack_detected"] = ~cat["stack_mag"].isna()
    else:                                 cat["stack_detected"] = False
    nn = np.zeros(len(cat), dtype=bool)
    groups = {int(pid): grp.index.to_numpy() for pid, grp in cat.groupby("image_id")}
    with h5py.File(h5_path,'r') as f: _,H,W=f['images'].shape
    for pid, probs in stream_panels_direct(h5_path, model, tile=TILE):
        if pid not in groups: continue
        mask = (probs>=thr).astype(np.uint8)
        idxs = groups[pid]
        xs=np.clip(cat.loc[idxs,"x"].to_numpy(int),0,W-1); ys=np.clip(cat.loc[idxs,"y"].to_numpy(int),0,H-1)
        for j,(x,y) in zip(idxs, zip(xs,ys)):
            y0,y1=max(0,y-radius),min(H,y+radius+1); x0,x1=max(0,x-radius),min(W,x+radius+1)
            nn[j] = (mask[y0:y1, x0:x1].max()>0)
    cat["nn_detected"] = nn
    return cat

def plot_detect_hist(cat, field, bins=12, title=None):
    vals=cat[field].to_numpy(); vals=vals[np.isfinite(vals)]
    vmin,vmax=np.nanmin(vals),np.nanmax(vals)
    if vmax/max(vmin,1e-6) > 50: edges=np.geomspace(max(vmin,1e-3), vmax, bins+1)
    else: edges=np.linspace(vmin, vmax, bins+1)
    nn_det=cat[cat["nn_detected"]]; stk_det=cat[cat["stack_detected"]]; cum=cat[nn_det.index.union(stk_det.index)]
    fig,ax=plt.subplots(figsize=(6.6,4.4))
    ax.hist(cat[field], bins=edges, histtype="step", label="All injected", alpha=0.7)
    ax.hist(cum[field], bins=edges, histtype="step", label="Cumulative (NN ∪ LSST)")
    ax.hist(nn_det[field], bins=edges, histtype="step", label="NN detected")
    ax.hist(stk_det[field], bins=edges, histtype="step", label="LSST stack detected")
    ax.set_xlabel(field.replace("_"," ")); ax.set_ylabel("Count per bin")
    if title: ax.set_title(title); ax.legend(); ax.grid(True, alpha=0.3); plt.show()

# Build marked catalogs and plot both histograms
cat_test = stream_mark_nn_and_stack(DATA["test_csv"], DATA["test_h5"], model, thr=best_thr, radius=3)
mag_field = "integrated_mag" if "integrated_mag" in cat_test.columns else ("PSF_mag" if "PSF_mag" in cat_test.columns else "mag")
plot_detect_hist(cat_test, field=mag_field,     bins=12, title=f"Detections vs magnitude (thr={best_thr:.2f})")
plot_detect_hist(cat_test, field="trail_length", bins=12, title=f"Detections vs trail length (thr={best_thr:.2f})")


NameError: name 'model' is not defined

## Sanity checks

In [8]:
import torch, torch.nn as nn, torch.nn.functional as F
# IMPORTANT: turn OFF AMP and any heavy augmentations for the probe
torch.backends.cudnn.benchmark = False

# --- Lovasz ---
def _lovasz_grad(gt_sorted):
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    inter = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jacc = 1.0 - inter / torch.clamp(union, min=1e-12)
    if p > 1:
        jacc[1:p] = jacc[1:p] - jacc[0:p-1]
    return jacc

def _lovasz_hinge_flat(logits, labels):
    if logits.numel() == 0:
        return logits*0.0
    signs  = 2.0*labels.float() - 1.0
    errors = 1.0 - logits * signs
    errs_sorted, perm = torch.sort(errors, dim=0, descending=True)
    gt_sorted = labels[perm]
    grad = _lovasz_grad(gt_sorted)
    return torch.clamp(errs_sorted, min=0).dot(grad)

def lovasz_hinge(logits, targets):
    B = logits.size(0)
    losses = []
    for b in range(B):
        l = logits[b].reshape(-1)
        t = targets[b].reshape(-1).float()
        losses.append(_lovasz_hinge_flat(l, t))
    return torch.stack(losses).mean()

# --- Asymmetric Focal Tversky (FP-heavy) ---
class AsymFocalTversky(nn.Module):
    def __init__(self, alpha=0.85, beta=0.15, gamma=1.0, eps=1e-6):
        super().__init__()
        self.alpha, self.beta, self.gamma, self.eps = alpha, beta, gamma, eps
    def forward(self, logits, targets):
        p = torch.sigmoid(logits).clamp(self.eps, 1-self.eps)
        t = targets.clamp(0,1)
        p = p.view(p.size(0), -1)
        t = t.view(t.size(0), -1)
        TP = (p*t).sum(1)
        FP = ((1-t)*p).sum(1)
        FN = (t*(1-p)).sum(1)
        tv = (TP+self.eps) / (TP + self.alpha*FP + self.beta*FN + self.eps)
        return torch.pow(1.0 - tv, self.gamma).mean()

# --- OHEM BCE (sample more negatives) ---
class OHEMBCE(nn.Module):
    def __init__(self, neg_percent=0.4, pos_weight=None, reduction='mean'):
        super().__init__()
        self.neg_percent = float(neg_percent)
        self.pos_weight  = pos_weight  # keep None to avoid positive bias
        self.reduction   = reduction
    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets,
                                                 pos_weight=self.pos_weight,
                                                 reduction='none')
        pos_mask = (targets > 0.5)
        neg_mask = ~pos_mask
        pos_loss = bce[pos_mask]
        neg_loss = bce[neg_mask]
        if neg_loss.numel() > 0 and self.neg_percent > 0:
            k = max(1, int(self.neg_percent * neg_loss.numel()))
            hard_neg, _ = torch.topk(neg_loss.reshape(-1), k, sorted=False)
            loss = torch.cat([pos_loss.reshape(-1), hard_neg], 0)
        else:
            loss = bce.reshape(-1)
        return loss.mean() if self.reduction=='mean' else loss.sum()

# --- Composite with background suppression ---
class StreakSegLossFP(nn.Module):
    """
    Strong FP control:
      - Asym Focal Tversky with high alpha (penalize FP)
      - Lovasz hinge for region quality
      - OHEM-BCE with many negatives
      - Small background mean-prob penalty
    """
    def __init__(self,
                 w_aftl=0.45, w_lovasz=0.35, w_ohem=0.20,
                 aftl_alpha=0.85, aftl_beta=0.15, aftl_gamma=1.0,
                 ohem_neg_percent=0.40,
                 bg_lambda=0.02, bg_gamma=2.0,      
                 pos_weight=None):
        super().__init__()
        self.w_aftl, self.w_lovasz, self.w_ohem = w_aftl, w_lovasz, w_ohem
        self.bg_lambda, self.bg_gamma = bg_lambda, bg_gamma
        self.aftl = AsymFocalTversky(aftl_alpha, aftl_beta, aftl_gamma)
        self.ohem = OHEMBCE(ohem_neg_percent, pos_weight=pos_weight)

    def forward(self, logits, targets):
        t = targets.clamp(0,1)
        la = self.aftl(logits, t)
        ll = lovasz_hinge(logits, (t>0.5).float())
        lo = self.ohem(logits, t)

        # --- background focal suppression ---
        p = torch.sigmoid(logits)
        bg_mask = (t < 0.5).float()
        denom = torch.clamp(bg_mask.sum(dim=(1,2,3)), min=1.0)
        bg_p = (p * bg_mask).sum(dim=(1,2,3)) / denom
        bg_focal = torch.pow(bg_p, self.bg_gamma).mean()   # sharper on mid probs

        return self.w_aftl*la + self.w_lovasz*ll + self.w_ohem*lo + self.bg_lambda*bg_focal



In [17]:
# per-panel median/MAD normalize, clip like stream_panels_direct
def norm_medmad_clip(x, clip=5.0, eps=1e-6):
    # x: torch.Tensor [B,1,H,W] or [1,H,W]
    if x.ndim == 4:
        med = x.median(dim=-1, keepdim=True).values.median(dim=-2, keepdim=True).values
    else:  # [1,H,W]
        med = x.median()
        med = med.view(1,1,1)
    mad = (x - med).abs().median()
    sigma = 1.4826 * mad + eps
    z = (x - med) / sigma
    return z.clamp_(-clip, clip)

class WithTransform(torch.utils.data.Dataset):
    def __init__(self, base): self.base = base
    def __len__(self): return len(self.base)
    def __getitem__(self, i):
        x, y = self.base[i]
        x = norm_medmad_clip(x)
        return x, y

In [18]:
# Collect panel-level “has positive” flags quickly from the H5
import h5py, numpy as np, math
def panels_with_positives(h5_path, tile=128, max_panels=None):
    ids=[]
    with h5py.File(h5_path,'r') as f:
        Y=f['masks']; N,H,W=Y.shape
        rng = np.random.default_rng(0)
        order = rng.permutation(N) if max_panels else np.arange(N)
        for i in order:
            yi = Y[i]
            if yi.any(): ids.append(i)
            if max_panels and len(ids)>=max_panels: break
    return np.array(sorted(ids))

def quick_mask_stats(loader, n_batches=5):
    import numpy as np, torch
    pos_px = neg_px = 0
    nonempty_masks = 0
    for b,(xb,yb) in enumerate(loader,1):
        y = yb.numpy()
        pos_px += (y>0.5).sum()
        neg_px += (y<=0.5).sum()
        nonempty_masks += int((y>0.5).any())
        if b>=n_batches: break
    total = pos_px+neg_px
    print(f"Pos px ratio ≈ {pos_px/max(total,1):.6f} | non-empty masks in {n_batches} batches: {nonempty_masks}/{n_batches}")

pos_panels = panels_with_positives("/home/karlo/train_chunked.h5", max_panels=2000)
# Sample your small train/val from the intersection with your original idx_tr/idx_va
sub_tr = np.random.default_rng(42).choice(np.intersect1d(idx_tr, pos_panels), size=min(200, len(pos_panels)), replace=False)
sub_va = np.random.default_rng(43).choice(np.intersect1d(idx_va, pos_panels), size=min(80,  len(pos_panels)), replace=False)

train_ds_small = SubsetDS(ds_full, np.sort(sub_tr))
val_ds_small   = SubsetDS(ds_full, np.sort(sub_va))
train_loader_small = DataLoader(train_ds_small, batch_size=64, shuffle=True,  num_workers=10, pin_memory=(device.type=='cuda'))
val_loader_small   = DataLoader(val_ds_small,   batch_size=64, shuffle=False, num_workers=10, pin_memory=(device.type=='cuda'))

train_ds_small = WithTransform(train_ds_small)
val_ds_small   = WithTransform(val_ds_small)

quick_mask_stats(train_loader_small, n_batches=8)  # re-check


Pos px ratio ≈ 0.003537 | non-empty masks in 8 batches: 8/8


In [10]:
import time, numpy as np, torch

@torch.no_grad()
def pick_thr_under_min(model, loader, max_batches=40, n_bins=256, beta=1.0):
    """
    Ultra-fast threshold finder from pixel histograms.
    - No CC / IoU; just pixel-level P/R/F1.
    - Processes `max_batches` from `loader` (use val loader).
    Returns: best_thr, (P, R, F1), dict(histograms)
    """
    model.eval().to(device)
    pos_hist = np.zeros(n_bins, np.float64)
    neg_hist = np.zeros(n_bins, np.float64)

    t0 = time.time()
    processed = 0
    amp = torch.amp.autocast('cuda', enabled=(device.type=='cuda'))

    for b, (xb, yb) in enumerate(loader, start=1):
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)

        with amp:
            logits = model(xb)
            yb_r   = resize_masks_to(logits, yb)           # your helper
            p      = torch.sigmoid(logits).float()         # (B,1,t,t)

        # flatten
        p = p.view(-1).detach()
        t = yb_r.view(-1).detach()

        # build histograms on GPU then move to CPU (fast)
        # subsample negatives so they don't dominate time
        neg_mask = (t <= 0.5)
        pos_mask = (t > 0.5)

        # cap negatives to ~2e6 pixels per batch (adjust if needed)
        max_neg = 2_000_000
        n_neg   = int(neg_mask.sum().item())
        if n_neg > 0:
            if n_neg > max_neg:
                # random subsample negatives
                idx = torch.randperm(n_neg, device=device)[:max_neg]
                p_neg = p[neg_mask][idx]
            else:
                p_neg = p[neg_mask]
            neg_hist += torch.histc(p_neg, bins=n_bins, min=0.0, max=1.0).cpu().numpy()

        if int(pos_mask.sum().item()) > 0:
            p_pos = p[pos_mask]
            pos_hist += torch.histc(p_pos, bins=n_bins, min=0.0, max=1.0).cpu().numpy()

        processed += 1
        if (processed % 5 == 0) or (processed == max_batches):
            elapsed = time.time() - t0
            print(f"\r[FAST-T] batches {processed}/{max_batches} | elapsed {elapsed:.1f}s", end='', flush=True)

        if processed >= max_batches:
            break

    print()  # newline

    # choose threshold by Fβ
    # cumulative from high→low
    pos_cum = np.cumsum(pos_hist[::-1])
    neg_cum = np.cumsum(neg_hist[::-1])
    TP = pos_cum
    FP = neg_cum
    FN = (pos_cum[-1] - TP).clip(min=0)

    P  = TP / np.maximum(TP + FP, 1)
    R  = TP / np.maximum(pos_cum[-1], 1)
    beta2 = beta * beta
    Fbeta = (1+beta2) * P * R / np.maximum(beta2*P + R, 1e-8)

    best_bin = int(np.nanargmax(Fbeta))
    best_thr = (best_bin + 0.5) / n_bins
    best_P, best_R, best_F = float(P[best_bin]), float(R[best_bin]), float(Fbeta[best_bin])

    print(f"[FAST-T] best thr≈{best_thr:.3f} | P≈{best_P:.3f} R≈{best_R:.3f} F{beta:.1f}≈{best_F:.3f} "
          f"| time {time.time()-t0:.1f}s")

    return best_thr, (best_P, best_R, best_F), {"pos_hist":pos_hist, "neg_hist":neg_hist}

In [11]:
import math
from contextlib import nullcontext



def init_head_bias_to_prior(model, p0=0.20):
    # p0 in (0,1); logit(p0) biases initial sigmoid outputs toward p0
    b = math.log(p0/(1-p0))
    with torch.no_grad():
        model.head.bias.data.fill_(b)


@torch.no_grad()
def sanitize_batch(x, y):
    """Replace non-finite with 0 and clamp labels to [0,1]."""
    x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    y = torch.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0).clamp_(0.0, 1.0)
    return x, y

class WarmupBCE(torch.nn.Module):
    """Simple, stable warm-up loss (no Dice)."""
    def __init__(self, pos_weight=10.0):
        super().__init__()
        self.bce = torch.nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor(float(pos_weight))
        )
    def forward(self, logits, targets):
        return self.bce(logits, targets)

def init_head_bias_to_prior(model, p0=0.20):
    b = math.log(p0/(1-p0))
    with torch.no_grad():
        if hasattr(model, "head") and hasattr(model.head, "bias"):
            model.head.bias.data.fill_(b)

def fit_quick_warmup(model, loader, epochs=2, max_batches=250,
                     lr=1e-4, metric_thr=0.20, pos_weight=10.0):
    device = next(model.parameters()).device
    amp_ctx = nullcontext()
    crit = WarmupBCE(pos_weight=pos_weight).to(device)
    opt  = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)
    for ep in range(1, epochs+1):
        model.train(); seen=tp=fp=fn=0; loss_sum=0.0; t0=time.time()
        for b,(xb,yb) in enumerate(loader,1):
            xb,yb = xb.to(device), yb.to(device)
            with amp_ctx:
                # forward
                logits = model(xb)
                yb_r   = resize_masks_to(logits, yb)

                # sanitize AFTER resize (resize can create NaNs if input has them)
                xb, yb_r = sanitize_batch(xb, yb_r)
                logits   = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)

                # compute loss
                loss = crit(logits, yb_r)

            # detect NaN early, print once and break
            if not torch.isfinite(loss):
                with torch.no_grad():
                    bad_x = ~torch.isfinite(xb).any().item()
                    bad_y = ~torch.isfinite(yb_r).any().item()
                    print(f"\n[NaN] batch {b}: loss={loss.item()} | "
                          f"x finite? {torch.isfinite(xb).all().item()} | "
                          f"y finite? {torch.isfinite(yb_r).all().item()} | "
                          f"logits finite? {torch.isfinite(logits).all().item()}")
                break

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            # quick metrics
            with torch.no_grad():
                p = torch.sigmoid(logits)
                pv = p.view(-1); tv = yb_r.view(-1)
                pred = (pv>=metric_thr).float()
                tp += float((pred*tv).sum()); fp += float((pred*(1-tv)).sum()); fn += float(((1-pred)*tv).sum())
                loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)

                # Only print “pos/neg” means if there ARE positives to avoid NaN in the print
                if (yb_r>0.5).any():
                    pos_prob = float(p[yb_r>0.5].mean())
                else:
                    pos_prob = float('nan')
                neg_prob = float(p[yb_r<=0.5].mean())
                frac_over = float((p>=0.5).float().mean())

            if b % 5 == 0 or b == max_batches:
                print(f"\r[WARMUP] ep{ep} batch {b}/{max_batches} "
                      f"| loss={loss_sum/seen:.4f} | over>=0.5 {frac_over:.4f} "
                      f"| pos {pos_prob:.4f} | neg {neg_prob:.4f} "
                      f"| {time.time()-t0:.1f}s", end='', flush=True)
            if b>=max_batches: break

        P = tp/max(tp+fp,1); R = tp/max(tp+fn,1); F1 = 2*P*R/max(P+R,1e-8)
        print(f"\n[WARMUP] ep{ep} loss {loss_sum/seen:.4f} | F1 {F1:.3f} P {P:.3f} R {R:.3f} | {(time.time()-t0):.1f}s")

def fit_quick(model, criterion, train_loader, epochs=2, max_batches=250, lr=3e-4, metric_thr=0.20, weight_decay=0.0):
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.to(device)
    amp_ctx = nullcontext()
    for ep in range(1, epochs+1):
        model.train(); seen=tp=fp=fn=0; loss_sum=0.0; t0=time.time()
        for b,(xb,yb) in enumerate(train_loader,1):
            xb,yb = xb.to(device), yb.to(device)
            #with torch.amp.autocast('cuda', enabled=(device.type=='cuda')):
            with amp_ctx:
                logits = model(xb)
                yb_r   = resize_masks_to(logits, yb)
                # quick sanity: ensure there ARE positives after resize
                if (yb_r>0.5).sum() == 0:
                    pass  # comment out, but keep for debugging if needed
                loss   = criterion(logits, yb_r)
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()

            with torch.no_grad():
                p = torch.sigmoid(logits)
                frac_over = float((p >= metric_thr).float().mean())

                pos_prob  = float(p[yb_r>0.5].mean()) if (yb_r>0.5).any() else float('nan')
                neg_prob  = float(p[yb_r<=0.5].mean())

                posv = p[yb_r > 0.5]; negv = p[yb_r <= 0.5]
                sep = (float(posv.mean()) if posv.numel() else float('nan')) - float(negv.mean())

                pv = p.view(-1); tv = yb_r.view(-1)
                pred = (pv>=metric_thr).float()
                tp += float((pred*tv).sum()); fp += float((pred*(1-tv)).sum()); fn += float(((1-pred)*tv).sum())
                loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)

            if b>=max_batches: break
            if b % 5 == 0:
                print(f"\r[QP] batch {b}/{max_batches} | loss={loss_sum/seen:.4f} | over>={metric_thr:.4f} {frac_over:.4f} | "
                      f"pos {pos_prob:.3f} neg {neg_prob:.3f} | sep {sep:.4f} | {time.time()-t0:.1f}s", end='', flush=True)
        P = tp/max(tp+fp,1); R = tp/max(tp+fn,1); F1 = 2*P*R/max(P+R,1e-8)
        print(f"\n[QP] ep{ep} loss {loss_sum/seen:.4f} | F1 {F1:.3f} P {P:.3f} R {R:.3f} | {(time.time()-t0):.1f}s")



In [14]:
# drop-in replacement: recall-friendly threshold with a positive-rate floor
def pick_thr_with_floor(model, loader, max_batches=20, n_bins=256, beta=2.0,
                        min_pos_rate=0.01, max_pos_rate=0.10):
    model.eval()
    dev = next(model.parameters()).device

    # histograms over probs for pos/neg pixels + overall count for pos-rate
    pos_hist = np.zeros(n_bins, np.int64)
    neg_hist = np.zeros(n_bins, np.int64)
    all_hist = np.zeros(n_bins, np.int64)

    with torch.inference_mode():
        for b, (xb, yb) in enumerate(loader, 1):
            xb = xb.to(dev, non_blocking=True)
            yb = yb.to(dev, non_blocking=True)
            logits = model(xb)
            yb_r   = resize_masks_to(logits, yb)
            p      = torch.sigmoid(logits).clamp(1e-6, 1-1e-6).float().cpu()

            t  = (yb_r>0.5).cpu()
            idx = torch.clamp((p*(n_bins-1)).long(), 0, n_bins-1)

            # update hist
            for k in range(n_bins):
                pass
            flat_idx = idx.view(-1)
            flat_t   = t.view(-1)
            pos_hist += np.bincount(flat_idx[flat_t].numpy(), minlength=n_bins)
            neg_hist += np.bincount(flat_idx[~flat_t].numpy(), minlength=n_bins)
            all_hist += np.bincount(flat_idx.numpy(), minlength=n_bins)

            if b >= max_batches: break

    # sweep thresholds
    pos_cum = pos_hist[::-1].cumsum()     # TP as threshold lowers
    neg_cum = neg_hist[::-1].cumsum()     # FP as threshold lowers
    all_cum = all_hist[::-1].cumsum()     # predicted-positive count

    TP = pos_cum
    FP = neg_cum
    FN = pos_hist.sum() - TP
    P = TP / np.maximum(TP+FP, 1)
    R = TP / np.maximum(TP+FN, 1)

    # F-beta (beta=2 -> recall-friendly)
    beta2 = beta*beta
    F = (1+beta2) * P * R / np.maximum(beta2*P + R, 1e-12)

    # enforce predicted-positive rate floor/ceiling
    total_px = all_hist.sum()
    pos_rate = all_cum / max(total_px, 1)
    mask = (pos_rate >= min_pos_rate) & (pos_rate <= max_pos_rate)

    if not mask.any():
        # fallback: choose the closest rate above min_pos_rate
        k = np.argmin(np.abs(pos_rate - min_pos_rate))
    else:
        k = np.argmax(F * mask)

    thr = (n_bins-1 - k) / (n_bins-1)
    return float(thr), (float(P[k]), float(R[k]), float(F[k])), dict(pos_rate=float(pos_rate[k]))

In [20]:
# ======= One-Click Probe Tuning (no re-sweeps) =======
import copy, time, torch, torch.nn as nn

#device = next(iter(train_loader_small))[0].device if hasattr(train_loader_small, '__iter__') else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -- tiny helpers (safe to redefine) --
def _set_requires_grad(module, flag: bool):
    for p in module.parameters(): p.requires_grad = flag

def freeze_all(model): _set_requires_grad(model, False)

def unfreeze_head_only(model):
    freeze_all(model)
    if hasattr(model, "head"):
        _set_requires_grad(model.head, True)
    else:
        raise AttributeError("Model has no attribute 'head'.")

def unfreeze_head_and_last_conv(model):
    freeze_all(model)
    # head
    if hasattr(model, "head"):
        _set_requires_grad(model.head, True)
    # last conv of last up block (u4.rb2.c2) if present
    try:
        _set_requires_grad(model.u4.rb2.c2, True)
    except AttributeError:
        # fallback: just leave head if model layout differs
        pass

# define once if missing
if "fit_quick_head_only" not in globals():
    def fit_quick_head_only(model, loader, epochs=2, max_batches=200, lr=5e-5, metric_thr=0.20, pos_weight=2.0):
        amp_ctx = nullcontext()
        dev = next(model.parameters()).device
        unfreeze_head_only(model)
        head_params = [p for p in model.head.parameters() if p.requires_grad]
        opt = torch.optim.Adam(head_params, lr=lr, weight_decay=0.0)
        bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight, device=dev))
        for ep in range(1, epochs+1):
            model.train()
            seen=tp=fp=fn=0.0; loss_sum=0.0; t0=time.time()
            for b,(xb,yb) in enumerate(loader, 1):
                xb,yb = xb.to(dev), yb.to(dev)
                with amp_ctx:
                    logits = model(xb)
                    yb_r = resize_masks_to(logits, yb)
                    loss = bce(logits, yb_r)
                opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
                with torch.no_grad():
                    p = torch.sigmoid(logits)
                    pv, tv = p.view(-1), yb_r.view(-1)
                    pred = (pv>=metric_thr).float()
                    tp += float((pred*tv).sum()); fp += float((pred*(1-tv)).sum()); fn += float(((1-pred)*tv).sum())
                    loss_sum += float(loss.item())*xb.size(0); seen += xb.size(0)
                if b>=max_batches: break
            P = tp/max(tp+fp,1); R = tp/max(tp+fn,1); F1 = 2*P*R/max(P+R,1e-8)
            print(f"[HEAD] ep{ep} loss {loss_sum/seen:.4f} | F1 {F1:.3f} P {P:.3f} R {R:.3f}")

@torch.no_grad()
def quick_prob_stats(model, loader, n_batches=3, thr=0.5):
    model.eval()
    dev = next(model.parameters()).device
    pos_means, neg_means = [], []
    tp = fp = fn = 0.0
    t0=time.time()
    for b, (xb, yb) in enumerate(loader, 1):
        xb, yb = xb.to(dev, non_blocking=True), yb.to(dev, non_blocking=True)
        logits = model(xb)
        yb_r = resize_masks_to(logits, yb)
        p = torch.sigmoid(logits)
        if (yb_r > 0.5).any():
            pos_means.append(float(p[yb_r > 0.5].mean().item()))
        neg_means.append(float(p[yb_r <= 0.5].mean().item()))
        pv, tv = p.view(-1), yb_r.view(-1)
        pred = (pv >= thr).float()
        tp += float((pred*tv).sum()); fp += float((pred*(1-tv)).sum()); fn += float(((1-pred)*tv).sum())
        if b>=n_batches: break
    P = tp/max(tp+fp,1); R = tp/max(tp+fn,1); F1 = 2*P*R/max(P+R,1e-8)
    import numpy as np
    print(f"[quick_prob_stats] batches={min(n_batches, b)} | pos_mean≈{(np.mean(pos_means) if pos_means else float('nan')):.4f} | neg_mean≈{np.mean(neg_means):.4f} | P {P:.3f} R {R:.3f} F1 {F1:.3f} @ thr={thr:.3f} | {time.time()-t0:.1f}s")

def init_head_bias_to_prior(model, p0=0.70):
    import math
    b = math.log(p0/(1-p0))
    with torch.no_grad():
        if hasattr(model, "head") and hasattr(model.head, "bias"):
            model.head.bias.data.fill_(b)

# ===== main recipe =====
def run_probe_recipe(model_cls, p0=0.80):
    # 0) build probe + warmup once
    probe = model_cls(in_ch=1, out_ch=1).to(device)
    init_head_bias_to_prior(probe, p0=p0)
    torch.backends.cudnn.benchmark = False
    print("Warm-up (BCE-only)…")
    fit_quick_warmup(probe, train_loader_small, epochs=2, max_batches=800, lr=2e-4, metric_thr=0.20, pos_weight=40.0)

    # lock warm state
    warm_state = copy.deepcopy(probe.state_dict())

    # 1) fast, recall-friendly threshold (pixel hist)
    thr0, *_ = pick_thr_under_min(probe, val_loader_small, max_batches=200, n_bins=256, beta=2.0)
    thr0 = float(__import__("numpy").clip(thr0, 0.05, 0.15))
    print("thr0 =", thr0)

    # 2) head-only BCE calibration
    probe.load_state_dict(warm_state, strict=True)
    fit_quick_head_only(probe, train_loader_small, epochs=2, max_batches=2000, lr=3e-5, metric_thr=thr0, pos_weight=5.0)
    quick_prob_stats(probe, train_loader_small, n_batches=30, thr=thr0)

    # 3) unfreeze head+tiny tail; gentle BCE+Tversky (no OHEM/BG)
    class _BCEPlusTversky(nn.Module):
        def __init__(self, pos_weight=2.0, alpha=0.30, beta=0.70, gamma=1.1, w_bce=0.85, w_tv=0.15):
            super().__init__()
            self.bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
            # reuse your AsymFocalTversky from StreakSegLossFP module if available
            self.aftl = AsymFocalTversky(alpha=alpha, beta=beta, gamma=gamma)
            self.w_bce, self.w_tv = w_bce, w_tv
        def forward(self, logits, targets):
            t = targets.clamp(0,1)
            return self.w_bce*self.bce(logits, t) + self.w_tv*self.aftl(logits, t)

    probe.load_state_dict(warm_state, strict=True)
    unfreeze_head_and_last_conv(probe)
    crit_gentle = _BCEPlusTversky(pos_weight=2.0, alpha=0.30, beta=0.70, gamma=1.1, w_bce=0.9, w_tv=0.1).to(device)
    fit_quick(probe, crit_gentle, train_loader_small, epochs=2, max_batches=2500, lr=1.5e-4, metric_thr=thr0, weight_decay=1e-4)
    quick_prob_stats(probe, train_loader_small, n_batches=30, thr=thr0)

    # 4) pick threshold with a sensible positive-rate band (recall bias)
    thr1, (P1, R1, F1), aux = pick_thr_with_floor(probe, val_loader_small, max_batches=20, n_bins=256, beta=2.0, min_pos_rate=0.03, max_pos_rate=0.10)
    print(f"thr1 = {thr1:.3f} | P≈{P1:.3f} R≈{R1:.3f} F≈{F1:.3f} pos_rate≈{aux['pos_rate']:.3f}")
    quick_prob_stats(probe, train_loader_small, n_batches=30, thr=thr1)

    # 5) tiny OHEM nudge to trim glow (optional, very short)
    crit_ohem = StreakSegLossFP(w_aftl=0.0, w_lovasz=0.0, w_ohem=1.0,
                                ohem_neg_percent=0.02, bg_lambda=0.0,
                                pos_weight=torch.tensor(2.0, device=device)).to(device)
    fit_quick(probe, crit_ohem, train_loader_small, epochs=1, max_batches=1500, lr=5e-5, metric_thr=thr1, weight_decay=0.0)

    # re-pick within a narrower 3–8% band
    thr2, (P2, R2, F2), aux2 = pick_thr_with_floor(probe, val_loader_small, max_batches=200, n_bins=256, beta=2.0, min_pos_rate=0.03, max_pos_rate=0.08)
    print(f"thr2 = {thr2:.3f} | P≈{P2:.3f} R≈{R2:.3f} F≈{F2:.3f} pos_rate≈{aux2['pos_rate']:.3f}")
    quick_prob_stats(probe, train_loader_small, n_batches=3, thr=thr2)

    # 6) light mixed pass to lift precision a touch (keep recall)
    crit_mix = StreakSegLossFP(w_aftl=0.30, w_lovasz=0.10, w_ohem=0.02,
                               aftl_alpha=0.35, aftl_beta=0.65, aftl_gamma=1.2,
                               ohem_neg_percent=0.02, bg_lambda=0.00,
                               pos_weight=torch.tensor(1.0, device=device)).to(device)
    fit_quick(probe, crit_mix, train_loader_small, epochs=1, max_batches=2000, lr=1.5e-4, metric_thr=thr2, weight_decay=0.0)

    # final threshold in a precision-friendlier band (8–12%)
    thr_final, (Pf, Rf, Ff), auxf = pick_thr_with_floor(probe, val_loader_small, max_batches=200, n_bins=256, beta=1.0, min_pos_rate=0.08, max_pos_rate=0.12)
    print(f"[FINAL] thr = {thr_final:.3f} | P≈{Pf:.3f} R≈{Rf:.3f} F≈{Ff:.3f} pos_rate≈{auxf['pos_rate']:.3f}")
    quick_prob_stats(probe, train_loader_small, n_batches=3, thr=thr_final)

    return probe, thr_final

# ===== Run it =====
probe, metric_thr = run_probe_recipe(UNetResSEASPP, p0=0.80)
print ("FINISHED")

Warm-up (BCE-only)…
[WARMUP] ep1 batch 800/800 | loss=0.3742 | over>=0.5 0.0100 | pos nan | neg 0.0854 | 120.9s.1s
[WARMUP] ep1 loss 0.3742 | F1 0.021 P 0.011 R 0.486 | 120.9s
[WARMUP] ep2 batch 800/800 | loss=0.3516 | over>=0.5 0.0352 | pos 0.4961 | neg 0.1326 | 120.9s
[WARMUP] ep2 loss 0.3516 | F1 0.028 P 0.014 R 0.517 | 121.0s
[FAST-T] batches 200/200 | elapsed 7.6s
[FAST-T] best thr≈0.447 | P≈0.030 R≈0.241 F2.0≈0.101 | time 7.6s
thr0 = 0.15
[HEAD] ep1 loss 0.1301 | F1 0.027 P 0.014 R 0.583
[HEAD] ep2 loss 0.0908 | F1 0.040 P 0.021 R 0.467
[quick_prob_stats] batches=30 | pos_mean≈0.1681 | neg_mean≈0.0326 | P 0.028 R 0.439 F1 0.053 @ thr=0.150 | 2.2s
[QP] batch 2495/2500 | loss=0.1590 | over>=0.1500 0.0043 | pos 0.074 neg 0.047 | sep 0.0266 | 99.9ss
[QP] ep1 loss 0.1590 | F1 0.026 P 0.014 R 0.136 | 100.2s
[QP] batch 2495/2500 | loss=0.1524 | over>=0.1500 0.0006 | pos 0.066 neg 0.026 | sep 0.0401 | 99.9ss
[QP] ep2 loss 0.1524 | F1 0.041 P 0.048 R 0.036 | 100.1s
[quick_prob_stats] batc