In [12]:
# --- core ---
import os, gc, time, math, random
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score
from scipy.ndimage import label as cc_label, find_objects, binary_opening

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

def set_seed(s=1337):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
set_seed(1337)

# helper: resize masks to model output size (NN/metrics use nearest to preserve {0,1})
def resize_masks_to(pred_like, masks, mode='nearest'):
    H, W = pred_like.shape[-2:]
    return F.interpolate(masks.float(), size=(H, W), mode=mode)


In [13]:
import h5py

def robust_stats_mad(arr):
    med = np.median(arr)
    mad = np.median(np.abs(arr - med))
    sigma = 1.4826 * (mad + 1e-12)
    return np.float32(med), np.float32(sigma if np.isfinite(mad) and mad>0 else 1.0)

class H5TiledDataset(Dataset):
    """
    Streams 128x128 (default) tiles from large (H,W) images in an HDF5.
    Per-image robust standardization + k-sigma clipping.
    Pads edge tiles to full tile size.
    """
    def __init__(self, h5_path, tile=128, k_sigma=5.0, crop_for_stats=512):
        self.h5_path = h5_path
        self.tile    = int(tile)
        self.k_sigma = float(k_sigma)
        self.crop_for_stats = int(crop_for_stats)
        self._h5 = None
        self._x = None; self._y = None
        self._stats_cache = {}   # image_id -> (med, sigma)

        with h5py.File(self.h5_path, "r") as f:
            self.N, self.H, self.W = f["images"].shape
            assert f["masks"].shape == (self.N, self.H, self.W)

        Hb = math.ceil(self.H / self.tile)
        Wb = math.ceil(self.W / self.tile)
        self.indices = [(i, r, c) for i in range(self.N) for r in range(Hb) for c in range(Wb)]

    def _ensure_open(self):
        if self._h5 is None:
            self._h5 = h5py.File(self.h5_path, "r")
            self._x = self._h5["images"]; self._y = self._h5["masks"]

    def _image_stats(self, i):
        if i in self._stats_cache: return self._stats_cache[i]
        H, W = self.H, self.W
        s = min(self.crop_for_stats, H, W)
        h0 = (H - s)//2; w0 = (W - s)//2
        crop = self._x[i, h0:h0+s, w0:w0+s].astype("float32")
        med, sig = robust_stats_mad(crop)
        self._stats_cache[i] = (med, sig)
        return med, sig

    def __len__(self): return len(self.indices)

    def __getitem__(self, idx):
        self._ensure_open()
        i, r, c = self.indices[idx]
        t = self.tile
        r0, c0 = r*t, c*t
        r1, c1 = min(r0+t, self.H), min(c0+t, self.W)

        x = self._x[i, r0:r1, c0:c1].astype("float32")
        y = self._y[i, r0:r1, c0:c1].astype("float32")

        if x.shape != (t, t):
            xp = np.zeros((t, t), np.float32); yp = np.zeros((t, t), np.float32)
            xp[:x.shape[0], :x.shape[1]] = x;  yp[:y.shape[0], :y.shape[1]] = y
            x, y = xp, yp

        med, sig = self._image_stats(i)
        x = (x - med) / sig
        x = np.clip(x, -5.0, 5.0)  # k-sigma clip

        # add channel dim
        return torch.from_numpy(x[None, ...]), torch.from_numpy(y[None, ...])


In [15]:
# <<< set your paths here >>>
train_h5 = "/home/karlo/train_chunked.h5"
test_h5  = "../DATA/test.h5"
tile = 128
batch_size = 128
num_workers = 0  # safer on shared systems

# Split train into train/val by index
with h5py.File(train_h5, "r") as f:
    N = f["images"].shape[0]
idx = np.arange(N)
np.random.shuffle(idx)
split = int(0.9 * N)  # 90/10 split
idx_tr, idx_va = np.sort(idx[:split]), np.sort(idx[split:])

# Datasets
ds_full = H5TiledDataset(train_h5, tile=tile, k_sigma=5.0)
class SubsetDS(Dataset):
    def __init__(self, base, panel_indices):
        self.base = base; self.panel_indices = panel_indices
        # remap base.indices to only those panels
        t = base.tile; Hb = math.ceil(base.H / t); Wb = math.ceil(base.W / t)
        self.map = []
        base_map = {(i,r,c):k for k,(i,r,c) in enumerate(base.indices)}
        for i in panel_indices:
            for r in range(Hb):
                for c in range(Wb):
                    self.map.append(base_map[(i,r,c)])
    def __len__(self): return len(self.map)
    def __getitem__(self, k): return self.base[self.map[k]]

train_ds = SubsetDS(ds_full, idx_tr)
val_ds   = SubsetDS(ds_full, idx_va)
test_ds  = H5TiledDataset(test_h5, tile=tile, k_sigma=5.0)

# Loaders
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=(device.type=='cuda'))
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=(device.type=='cuda'))
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=(device.type=='cuda'))


In [16]:
class SEBlock(nn.Module):
    def __init__(self, c, r=8):
        super().__init__()
        self.fc1 = nn.Conv2d(c, c//r, 1)
        self.fc2 = nn.Conv2d(c//r, c, 1)
    def forward(self, x):
        s = F.adaptive_avg_pool2d(x,1)
        s = F.silu(self.fc1(s))
        s = torch.sigmoid(self.fc2(s))
        return x * s

class ResBlock(nn.Module):
    def __init__(self, c_in, c_out, k=3, se=True):
        super().__init__()
        p = k//2
        self.proj = nn.Identity() if c_in==c_out else nn.Conv2d(c_in, c_out, 1)
        self.bn1 = nn.BatchNorm2d(c_in)
        self.c1  = nn.Conv2d(c_in, c_out, k, padding=p, bias=False)
        self.bn2 = nn.BatchNorm2d(c_out)
        self.c2  = nn.Conv2d(c_out, c_out, k, padding=p, bias=False)
        self.se  = SEBlock(c_out) if se else nn.Identity()
    def forward(self, x):
        h = F.silu(self.bn1(x))
        h = self.c1(h)
        h = F.silu(self.bn2(h))
        h = self.c2(h)
        h = self.se(h)
        return h + self.proj(x)

class Down(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.rb   = ResBlock(c_in, c_out)
    def forward(self, x): return self.rb(self.pool(x))

class Up(nn.Module):
    def __init__(self, c_in, c_skip, c_out):
        super().__init__()
        self.up  = nn.ConvTranspose2d(c_in, c_in, 2, stride=2)
        self.rb1 = ResBlock(c_in + c_skip, c_out)
        self.rb2 = ResBlock(c_out, c_out)
    def forward(self, x, skip):
        x = self.up(x)
        # pad if odd sizes
        dh = skip.size(-2) - x.size(-2); dw = skip.size(-1) - x.size(-1)
        if dh or dw: x = F.pad(x, (0,max(0,dw),0,max(0,dh)))
        x = torch.cat([x, skip], 1)
        x = self.rb1(x); x = self.rb2(x)
        return x

class UNetResSE(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, widths=(32,64,128,256,512)):
        super().__init__()
        w = widths
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, w[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(w[0]), nn.SiLU(inplace=True),
            ResBlock(w[0], w[0])
        )
        self.d1 = Down(w[0], w[1])
        self.d2 = Down(w[1], w[2])
        self.d3 = Down(w[2], w[3])
        self.d4 = Down(w[3], w[4])
        self.u1 = Up(w[4], w[3], w[3])
        self.u2 = Up(w[3], w[2], w[2])
        self.u3 = Up(w[2], w[1], w[1])
        self.u4 = Up(w[1], w[0], w[0])
        self.head = nn.Conv2d(w[0], out_ch, 1)  # logits
    def forward(self, x):
        s0 = self.stem(x)
        s1 = self.d1(s0)
        s2 = self.d2(s1)
        s3 = self.d3(s2)
        b  = self.d4(s3)
        x  = self.u1(b, s3)
        x  = self.u2(x, s2)
        x  = self.u3(x, s1)
        x  = self.u4(x, s0)
        return self.head(x)


In [17]:
class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6): super().__init__(); self.eps=eps
    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        p = p.view(p.size(0), -1); t = targets.view(targets.size(0), -1)
        inter = (p*t).sum(1)
        denom = p.sum(1) + t.sum(1)
        dice  = (2*inter + self.eps)/(denom + self.eps)
        return 1 - dice.mean()

class BCEDice(nn.Module):
    def __init__(self, pos_weight=None, dice_weight=1.0, bce_weight=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        self.dice= DiceLoss()
        self.dw, self.bw = dice_weight, bce_weight
    def forward(self, logits, targets):
        return self.bw*self.bce(logits, targets) + self.dw*self.dice(logits, targets)

def make_optimizer(model, lr=1.5e-4, wd=1e-4):
    return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

def make_scheduler(opt, steps_per_epoch, epochs, max_lr):
    return torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=max_lr,
                                               steps_per_epoch=steps_per_epoch,
                                               epochs=epochs, pct_start=0.1,
                                               anneal_strategy='cos')


In [18]:
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))

def _auc_from_hists(pos_hist, neg_hist):
    tp = np.cumsum(pos_hist[::-1]); fp = np.cumsum(neg_hist[::-1])
    if tp[-1]==0 or fp[-1]==0: return float('nan')
    tpr = tp/tp[-1]; fpr = fp/fp[-1]
    tpr = np.concatenate(([0.0], tpr)); fpr = np.concatenate(([0.0], fpr))
    return np.trapz(tpr, fpr)

def run_epoch(loader, model, criterion, optimizer=None, train=True,
              tag="Train", print_every=10, n_bins=512, grad_clip=1.0):
    model.train(train)
    start = time.time()
    total = len(loader.dataset)
    seen = 0; running_loss=0.0; tp=fp=fn=0.0
    pos_hist = np.zeros(n_bins, np.float64)
    neg_hist = np.zeros(n_bins, np.float64)

    for b, (xb, yb) in enumerate(loader, start=1):
        xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
            logits = model(xb)                     # logits
            yb_r   = resize_masks_to(logits, yb)   # nearest for metrics/targets
            loss   = criterion(logits, yb_r)

        if train:
            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            if grad_clip is not None:
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer); scaler.update()

        bs = xb.size(0)
        seen += bs
        running_loss += float(loss.item()) * bs

        with torch.no_grad():
            p = torch.sigmoid(logits).detach().view(-1).cpu()
            t = yb_r.detach().view(-1).cpu()
            p_bin = (p >= 0.5).float()
            tp += float((p_bin*t).sum())
            fp += float((p_bin*(1-t)).sum())
            fn += float(((1-p_bin)*t).sum())
            # streaming AUC
            idx = torch.clamp((p*(n_bins-1)).long(), 0, n_bins-1)
            pos_hist += np.bincount(idx[t>0.5].numpy(), minlength=n_bins)
            neg_hist += np.bincount(idx[t<=0.5].numpy(), minlength=n_bins)

        if ((b % print_every == 0) or (seen == total)) and train:
            prec = tp / (tp + fp + 1e-8)
            rec  = tp / (tp + fn + 1e-8)
            f1   = 2 * prec * rec / (prec + rec + 1e-8)
            avg_loss = running_loss / seen
            elapsed = time.time() - start
            print(
                f"\r[{tag}] batch {b}/{len(loader)} | {seen}/{total} ex | "
                f"loss={avg_loss:.4f} | F1 {f1:.4f} | P {prec:.4f} | R {rec:.4f} | {elapsed:.1f}s",
                end='', flush=True
            )
    if train: print()
    auc = _auc_from_hists(pos_hist, neg_hist)
    prec = tp / (tp + fp + 1e-8)
    rec  = tp / (tp + fn + 1e-8)
    f1   = 2 * prec * rec / (prec + rec + 1e-8)
    epoch_loss = running_loss / total
    return epoch_loss, auc, prec, rec, f1

def fit(model, train_loader, val_loader, criterion,
        epochs=20, max_lr=1.5e-4, early_stop_patience=10,
        save_path="./best_unet_resse.pt"):
    model = model.to(device)
    opt   = make_optimizer(model, lr=max_lr, wd=1e-4)
    sched = make_scheduler(opt, steps_per_epoch=len(train_loader), epochs=epochs, max_lr=max_lr)
    best_f1, no_improve = -1, 0

    for ep in range(1, epochs+1):
        t0=time.time()
        tr_loss, _, tr_p, tr_r, tr_f1 = run_epoch(train_loader, model, criterion, opt, train=True,  tag="Train", print_every=10)
        va_loss, va_auc, va_p, va_r, va_f1 = run_epoch(val_loader,   model, criterion, None,      train=False, tag="Val")
        sched.step()

        print(f"Epoch {ep:03d} | "
              f"Train L {tr_loss:.4f} F1 {tr_f1:.4f} P {tr_p:.4f} R {tr_r:.4f} || "
              f"Val L {va_loss:.4f} AUC {va_auc:.4f} F1 {va_f1:.4f} P {va_p:.4f} R {va_r:.4f} | "
              f"{time.time()-t0:.1f}s")

        if va_f1 > best_f1 + 1e-4:
            best_f1, no_improve = va_f1, 0
            torch.save({"state_dict": model.state_dict()}, save_path)
        else:
            no_improve += 1
            if no_improve >= early_stop_patience:
                print("Early stopping."); break

    return model


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))


In [None]:
model = UNetResSE(in_ch=1, out_ch=1, widths=(32,64,128,256,512))
criterion = BCEDice(pos_weight=None, dice_weight=1.0, bce_weight=1.0)
_ = fit(model, train_loader, val_loader, criterion,
        epochs=20, max_lr=1.5e-4, early_stop_patience=10,
        save_path="./best_unet_resse.pt")


  with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):


[Train] batch 4520/5760 | 578560/737280 ex | loss=1.0560 | F1 0.0675 | P 0.0594 | R 0.0783 | 912.1s

In [None]:
@torch.no_grad()
def predict_tiles_to_full(h5_path, loader, model, tile=128):
    model.eval(); model.to(device)
    with h5py.File(h5_path,'r') as f:
        N,H,W = f['images'].shape
    Hb, Wb = math.ceil(H/tile), math.ceil(W/tile)
    tiles_per_panel = Hb*Wb
    full_preds = np.zeros((N,H,W), np.float32)
    ptr=0; buf=[]

    for xb,_ in loader:
        xb = xb.to(device, non_blocking=True)
        logits = model(xb)
        probs  = torch.sigmoid(logits).cpu().numpy()[:,0]
        buf.append(probs)

        # flush complete panels
        while len(buf)>0:
            cur = buf[0]
            if cur.shape[0] < tiles_per_panel: break
            tile_buf = cur[:tiles_per_panel]
            buf[0] = cur[tiles_per_panel:]
            if buf[0].shape[0]==0: buf.pop(0)

            p = ptr; ptr += 1
            panel = np.zeros((Hb*tile, Wb*tile), np.float32)
            t=0
            for r in range(Hb):
                for c in range(Wb):
                    r0, c0 = r*tile, c*tile
                    panel[r0:r0+tile, c0:c0+tile] = tile_buf[t]; t+=1
            full_preds[p] = panel[:H,:W]

    return full_preds

def postprocess(bin_full, open_iters=1):
    if open_iters>0:
        bin_full = np.stack([binary_opening(b, iterations=open_iters) for b in bin_full], 0)
    return bin_full.astype(np.uint8)

def component_PR(bin_full, gt_full, iou_thr=0.1, min_pix=120, min_elong=2.5):
    TP=FP=0; total_gt=0
    for i in range(bin_full.shape[0]):
        pred = bin_full[i]; gt = gt_full[i].astype(np.uint8)
        L, n = cc_label(pred, structure=np.ones((3,3), np.uint8))
        for k, sl in enumerate(find_objects(L) or [], start=1):
            comp = (L[sl]==k)
            area = int(comp.sum())
            if area < min_pix: continue
            h = sl[0].stop - sl[0].start; w = sl[1].stop - sl[1].start
            elong = max(h,w)/max(1,min(h,w))
            if elong < min_elong: continue
            gtl = gt[sl].astype(bool)
            inter = (comp & gtl).sum(); union = (comp | gtl).sum()
            iou = inter/union if union else 0.0
            if iou >= iou_thr: TP += 1
            else: FP += 1
        _, ng = cc_label(gt, structure=np.ones((3,3), np.uint8))
        total_gt += ng
    FN = total_gt - TP
    P = TP/max(TP+FP,1); R = TP/max(TP+FN,1); F1=2*P*R/max(P+R,1e-8)
    return P,R,F1,TP,FP,total_gt

# Load best weights
ckpt = torch.load("./best_unet_resse.pt", map_location=device)
model.load_state_dict(ckpt["state_dict"])
model.eval();

# Predict on validation full frames and pick threshold by F1
val_panels_h5 = train_h5  # we stitched the same file; extract only val indices
p_val_full = predict_tiles_to_full(val_panels_h5, val_loader, model, tile=tile)
with h5py.File(val_panels_h5,'r') as f:
    gt_all = f['masks'][:].astype(np.uint8)
gt_val_full = gt_all[idx_va]

ths = np.linspace(0.5, 0.99, 20)
best_thr, best_metrics = None, None
for t in ths:
    bin_v = postprocess((p_val_full>=t).astype(np.uint8), open_iters=1)
    P,R,F1,TP,FP,TG = component_PR(bin_v, gt_val_full, iou_thr=0.1, min_pix=120, min_elong=2.5)
    if (best_metrics is None) or (F1 > best_metrics[2]):
        best_thr, best_metrics = t, (P,R,F1,TP,FP,TG)

print(f"[VAL] best thr={best_thr:.3f}  P={best_metrics[0]:.3f} R={best_metrics[1]:.3f} F1={best_metrics[2]:.3f} "
      f"| TP={best_metrics[3]} FP={best_metrics[4]} GT={best_metrics[5]}")

# Final TEST evaluation
p_test_full = predict_tiles_to_full(test_h5, test_loader, model, tile=tile)
with h5py.File(test_h5,'r') as f:
    gt_test_full = f['masks'][:].astype(np.uint8)

bin_t = postprocess((p_test_full>=best_thr).astype(np.uint8), open_iters=1)
P,R,F1,TP,FP,TG = component_PR(bin_t, gt_test_full, iou_thr=0.1, min_pix=120, min_elong=2.5)
print(f"[TEST] thr={best_thr:.3f}  P={P:.3f} R={R:.3f} F1={F1:.3f} | TP={TP} FP={FP} GT={TG}")
