In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import sklearn
import h5py

In [2]:
# Name -> zero-argument factory, so every call returns a fresh module instance.
_ACTIVATION_FACTORIES = {
    "elu": lambda: nn.ELU(inplace=True),
    "hardshrink": lambda: nn.Hardshrink(lambd=0.5),
    "hardsigmoid": lambda: nn.Hardsigmoid(inplace=True),
    "hardtanh": lambda: nn.Hardtanh(min_val=-1, max_val=1, inplace=True),
    "leakyrelu": lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True),
    "leaky_relu": lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True),
    "logsigmoid": lambda: nn.LogSigmoid(),
    "prelu": lambda: nn.PReLU(num_parameters=1, init=0.25),
    "relu": lambda: nn.ReLU(inplace=True),
    "relu6": lambda: nn.ReLU6(inplace=True),
    "selu": lambda: nn.SELU(inplace=True),
    "celu": lambda: nn.CELU(inplace=True),
    "gelu": lambda: nn.GELU(approximate='none'),  # 'tanh' or 'none'
    "sigmoid": lambda: nn.Sigmoid(),
    "silu": lambda: nn.SiLU(inplace=True),  # also known as Swish
    "mish": lambda: nn.Mish(inplace=True),
    # Softplus and Softshrink take no `inplace` argument; passing one raises TypeError.
    "softplus": lambda: nn.Softplus(beta=1, threshold=20),
    "softshrink": lambda: nn.Softshrink(lambd=0.5),
    "softsign": lambda: nn.Softsign(),
    "tanh": lambda: nn.Tanh(),
    "tanhshrink": lambda: nn.Tanhshrink(),
    "threshold": lambda: nn.Threshold(threshold=0.25, value=0.0, inplace=True),
    "glu": lambda: nn.GLU(dim=1),  # halves dim 1; assumes (B, C, H, W) input with even C
    "softmax": lambda: nn.Softmax(dim=1),  # across channels
    "logsoftmax": lambda: nn.LogSoftmax(dim=1),  # across channels
    "none": lambda: nn.Identity(),
}


def activation_parser(activation_str):
    """
    Return a freshly constructed activation module for a case-insensitive name.

    Supported names: 'elu', 'hardshrink', 'hardsigmoid', 'hardtanh',
    'leakyrelu' (alias 'leaky_relu'), 'logsigmoid', 'prelu', 'relu', 'relu6',
    'selu', 'celu', 'gelu', 'sigmoid', 'silu', 'mish', 'softplus', 'softshrink',
    'softsign', 'tanh', 'tanhshrink', 'threshold', 'glu', 'softmax',
    'logsoftmax', and 'none' (identity).

    Args:
        activation_str: activation name, matched after lower-casing.

    Returns:
        A new nn.Module instance (never shared between calls).

    Raises:
        ValueError: if the name is not supported.
    """
    factory = _ACTIVATION_FACTORIES.get(activation_str.lower())
    if factory is None:
        raise ValueError(f"Unsupported activation: {activation_str}")
    return factory()

class ChannelAttention(nn.Module):
    """
    CBAM-style channel attention: rescales each channel of a (B, C, H, W) input.

    Global average- and max-pooled descriptors go through a shared bottleneck MLP
    (fc1 -> ReLU -> fc2). The two results are summed, batch-normalised and passed
    through a sigmoid, giving one scale in (0, 1) per channel.

    Args:
        in_channels: number of input channels C.
        ratio: bottleneck reduction factor. The hidden width is
            max(1, in_channels // ratio), so a small C never produces a
            zero-width layer.
    """
    def __init__(self, in_channels, ratio=8):
        super().__init__()
        # Clamp to >= 1: with in_channels < ratio the plain floor division gives 0.
        hidden = max(1, in_channels // ratio)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Linear(in_channels, hidden, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, in_channels, bias=False)
        self.sigmoid = nn.Sigmoid()
        # NOTE: BatchNorm1d over a (B, C) tensor cannot compute training statistics when B == 1.
        self.norm = nn.BatchNorm1d(in_channels)

        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc2.weight)

    def _shared_mlp(self, pooled):
        """Flatten a (B, C, 1, 1) pooled map to (B, C) and apply fc1 -> ReLU -> fc2."""
        return self.fc2(self.relu(self.fc1(pooled.flatten(1))))

    def forward(self, x):
        avg_out = self._shared_mlp(self.avg_pool(x))     # (B, C)
        max_out = self._shared_mlp(self.max_pool(x))     # (B, C)
        scale = self.sigmoid(self.norm(avg_out + max_out))
        return x * scale[:, :, None, None]               # broadcast along H, W

class SpatialAttention(nn.Module):
    """
    CBAM-style spatial attention: rescales every spatial position of a (B, C, H, W) input.

    The per-pixel mean and max over channels are stacked into a 2-channel map.
    That map is convolved down to a single channel with 'same' padding,
    batch-normalised and passed through a sigmoid. The resulting (B, 1, H, W)
    gate is broadcast over all C channels.

    Args:
        kernel_size: convolution size, one of 3, 5 or 7.
    """
    def __init__(self, kernel_size=5):
        super().__init__()
        assert kernel_size in (3, 5, 7)
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size,
                              padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.norm = nn.BatchNorm2d(1)
        nn.init.xavier_uniform_(self.conv.weight)

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)           # (B, 1, H, W)
        channel_max = x.max(dim=1, keepdim=True).values      # (B, 1, H, W)
        stats = torch.cat((channel_mean, channel_max), dim=1)  # (B, 2, H, W)
        gate = self.sigmoid(self.norm(self.conv(stats)))     # (B, 1, H, W)
        return x * gate

class CBAMBlock(nn.Module):
    """
    Convolutional Block Attention Module: channel attention, then spatial attention.

    Args:
        in_channels: channel count C of the (B, C, H, W) input.
        ratio: bottleneck reduction factor forwarded to ChannelAttention.
        kernel_size: convolution size forwarded to SpatialAttention.
    """
    def __init__(self, in_channels, ratio=8, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(in_channels, ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial_att(self.channel_att(x))

class SepConv(nn.Module):
    """
    Depthwise-separable convolution followed by BatchNorm and an activation.

    A per-channel (groups=in_ch) convolution is followed by a 1×1 pointwise
    convolution mapping in_ch -> out_ch. BatchNorm2d and the activation named by
    `activation` are then applied.

    Args:
        in_ch, out_ch: input / output channel counts.
        activation: activation name understood by activation_parser.
        kernel_size, padding, dilation: settings of the depthwise convolution.
    """
    def __init__(self, in_ch, out_ch, activation, kernel_size, padding, dilation=1):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size,
                                   padding=padding, dilation=dilation,
                                   groups=in_ch, bias=True)
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=True)
        self.norm = nn.BatchNorm2d(out_ch)
        self.act = activation_parser(activation)

        for conv in (self.depthwise, self.pointwise):
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity="relu")
            nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        mixed = self.pointwise(self.depthwise(x))
        return self.act(self.norm(mixed))

class ASPP(nn.Module):
    """
    Atrous-spatial-pyramid-style block built from parallel SepConv branches.

    Branch i uses kernel size kernels[i], dilation dilations[i] and padding
    (k // 2) * d, so odd kernels keep the spatial size unchanged. Each branch
    produces out_ch channels. The branch outputs are concatenated and merged back
    to out_ch channels by a 1×1 conv, BatchNorm and the activation.

    Args:
        in_ch, out_ch: input / output channel counts.
        activation: activation name understood by activation_parser.
        dilations: per-branch dilation rates (default (1, 2, 3, 4)).
        kernels: per-branch kernel sizes (default (1, 3, 5, 7)). Must have the
            same length as `dilations`.

    Raises:
        ValueError: if `dilations` is empty or its length differs from `kernels`.
    """
    def __init__(self, in_ch, out_ch, activation, dilations=(1, 2, 3, 4), kernels=(1, 3, 5, 7)):
        super().__init__()
        dilations = tuple(dilations)
        kernels = tuple(kernels)
        if not dilations or len(dilations) != len(kernels):
            raise ValueError(
                f"dilations and kernels must be non-empty and equally long, "
                f"got {len(dilations)} and {len(kernels)}")

        self.branches = nn.ModuleList(
            SepConv(in_ch, out_ch, activation, kernel_size=k, padding=(k // 2) * d, dilation=d)
            for d, k in zip(dilations, kernels)
        )
        self.merge = nn.Sequential(
            nn.Conv2d(len(dilations) * out_ch, out_ch, kernel_size=1, bias=True),
            nn.BatchNorm2d(out_ch),
            activation_parser(activation)
        )
        nn.init.kaiming_normal_(self.merge[0].weight, mode='fan_out', nonlinearity="relu")
        nn.init.constant_(self.merge[0].bias, 0)

    def forward(self, x):
        return self.merge(torch.cat([branch(x) for branch in self.branches], dim=1))

class DoubleConv(nn.Module):
    """
    Two consecutive (3×3 conv -> BatchNorm2d -> activation) stages.

    padding=1 keeps H and W unchanged. The first conv maps in_ch -> out_ch and
    the second maps out_ch -> out_ch.

    Args:
        in_ch, out_ch: input / output channel counts.
        activation: activation name understood by activation_parser.
    """
    def __init__(self, in_ch, out_ch, activation):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=True),
                nn.BatchNorm2d(out_ch),
                activation_parser(activation),
            ]
        self.block = nn.Sequential(*layers)

        convs = [layer for layer in self.block if isinstance(layer, nn.Conv2d)]
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity="relu")
            nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        return self.block(x)

class AttentionGate(nn.Module):
    """
    Additive attention gate (in the style of Attention U-Net) applied to skip features.

    The gating signal g and the skip features x are each projected to F_int
    channels (1×1 conv + BatchNorm), summed and passed through ReLU. `psi` then
    maps the sum to an F_g-channel map in (0, 1) (1×1 conv + BatchNorm + sigmoid),
    which multiplies x element-wise.

    NOTE: psi outputs F_g channels, not a single channel. x * psi therefore
    requires F_g == F_l (or one of them to be 1) for broadcasting, and g and x
    must share H and W.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # 1×1 projection of the gating signal g -> F_int channels
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # 1×1 projection of the skip features x -> F_int channels
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # attention map F_int -> F_g channels, squashed into (0, 1)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, F_g, kernel_size=1, bias=True),
            nn.BatchNorm2d(F_g),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

        for proj in (self.W_g[0], self.W_x[0]):
            nn.init.kaiming_normal_(proj.weight, mode='fan_out', nonlinearity='relu')
            nn.init.constant_(proj.bias, 0)
        nn.init.xavier_uniform_(self.psi[0].weight)
        nn.init.constant_(self.psi[0].bias, 0)

    def forward(self, g, x):
        """
        g: gating signal, shape (B, F_g, H, W)
        x: skip features, shape (B, F_l, H, W)
        Returns x multiplied by the (B, F_g, H, W) attention map.
        """
        combined = self.relu(self.W_g(g) + self.W_x(x))   # (B, F_int, H, W)
        gate = self.psi(combined)                           # (B, F_g, H, W), values in (0, 1)
        return x * gate


class EncoderBlock(nn.Module):
    """
    One U-Net encoder stage: feature block -> CBAM (optional) -> Dropout2d (optional),
    then an optional 2×2 max-pool.

    Args:
        in_ch, out_ch: input / output channel counts.
        activation: activation name understood by activation_parser.
        dropout_prob: Dropout2d probability; 0 disables dropout.
        attention: apply a CBAMBlock after the feature block.
        pool: max-pool the returned features by 2.
        ASPP_blocks: use ASPP (True) or DoubleConv (False) as the feature block.

    forward returns (features, skip). `skip` is a clone of the features taken
    before pooling.
    """
    def __init__(self, in_ch, out_ch, activation, dropout_prob=0.0, attention=True, pool=True, ASPP_blocks=True):
        super().__init__()
        feature_cls = ASPP if ASPP_blocks else DoubleConv
        self.conv = feature_cls(in_ch, out_ch, activation)
        self.cbam = CBAMBlock(out_ch, ratio=8, kernel_size=7) if attention else nn.Identity()
        self.dropout = nn.Dropout2d(dropout_prob) if dropout_prob > 0 else nn.Identity()
        self.pool = pool

    def forward(self, x):
        features = self.dropout(self.cbam(self.conv(x)))
        skip = features.clone()
        downsampled = F.max_pool2d(features, kernel_size=2) if self.pool else features
        return downsampled, skip


class DecoderBlock(nn.Module):
    """
    One U-Net decoder stage.

    Steps: optional ×2 up-sampling, optional attention-gated skip fusion, then the
    feature block -> CBAM (optional) -> Dropout2d (optional).

    Channel contract (not checked here): `self.conv` is built for `in_ch` input
    channels. With upsample=True, the tensor that reaches it has `skip_ch`
    channels when no skip is passed and `2 * skip_ch` channels after a skip is
    concatenated. Callers must choose `in_ch` to match the case they use.
    """
    def __init__(self, in_ch, skip_ch, out_ch, activation, dropout_prob=0.0, attention=True, upsample=True, ASPP_blocks=True):
        """
        in_ch:        channels from previous layer (bottleneck or previous decoder)
        skip_ch:      channels in the corresponding encoder skip
        out_ch:       desired output channels for this decoder block
        activation:   activation name understood by activation_parser
        dropout_prob: Dropout2d probability (0 disables dropout)
        attention:    gate the skip with an AttentionGate and apply CBAM after the conv
        upsample:     up-sample x by 2 with a transposed conv before fusion
        ASPP_blocks:  use ASPP (True) or DoubleConv (False) as the feature block
        """
        super().__init__()
        self.upsample = upsample
        self.skip_ch = skip_ch

        if self.upsample:
            # ConvTranspose2d(in_ch -> skip_ch); k=3, s=2, p=1, output_padding=1 exactly doubles H and W
            self.up = nn.ConvTranspose2d(in_ch, skip_ch, kernel_size=3,
                                         stride=2, padding=1, output_padding=1, bias=True)
            nn.init.kaiming_normal_(self.up.weight, mode='fan_out', nonlinearity='relu')
            self.bn_up = nn.BatchNorm2d(skip_ch)
            self.act_up = activation_parser(activation)
            gate_ch = skip_ch
        else:
            self.up = None
            self.bn_up = None
            self.act_up = None
            gate_ch = in_ch
        self.attention = AttentionGate(F_g=gate_ch, F_l=gate_ch, F_int=gate_ch // 2) if attention else nn.Identity()

        if ASPP_blocks:
            self.conv = ASPP(in_ch, out_ch, activation)
        else:
            self.conv = DoubleConv(in_ch, out_ch, activation)
        self.cbam = CBAMBlock(out_ch, ratio=8, kernel_size=7) if attention else nn.Identity()
        self.dropout = nn.Dropout2d(dropout_prob) if dropout_prob > 0 else nn.Identity()

    def forward(self, x, skip=None):
        if self.upsample:
            x = self.act_up(self.bn_up(self.up(x)))   # (B, skip_ch, 2H, 2W)
        if skip is not None:
            # nn.Identity.forward takes one positional tensor and cannot be called
            # with g=/x= keywords, so with attention disabled the skip is used as-is.
            if not isinstance(self.attention, nn.Identity):
                skip = self.attention(g=x, x=skip)
            x = torch.cat([x, skip], dim=1)           # (B, 2*skip_ch, 2H, 2W) when upsampling
        x = self.conv(x)
        x = self.cbam(x)
        x = self.dropout(x)
        return x


class BottleneckTransformer(nn.Module):
    """
    Transformer over the spatial positions of a (B, C, H, W) feature map.

    Each image's H×W positions become a sequence of C-dimensional tokens. The
    tokens pass through:
      * a TransformerEncoder with depth // 2 layers (or `depth` layers when
        depth <= 1);
      * when depth > 1, a TransformerDecoder with depth - depth // 2 layers,
        using the encoder output as both target and memory;
      * a final LayerNorm over C, applied per token.
    The result is reshaped back to (B, C, H, W).

    Both stacks run batch-first on (B, H*W, C) tokens. Attention is therefore
    computed over the positions of each image independently; samples in a batch
    never attend to each other.

    Args:
        dim: channel count C (token embedding size).
        heads: number of attention heads.
        depth: total number of transformer layers (split as described above).
        mlp_dim: feed-forward width; defaults to 4 * dim.
    """
    def __init__(self, dim, heads=8, depth=3, mlp_dim=None):
        super().__init__()
        mlp_dim = mlp_dim or dim * 4
        layer_e = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            activation='relu',
            batch_first=True
        )
        layer_d = nn.TransformerDecoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            activation='relu',
            norm_first=True,
            batch_first=True  # must match the encoder's (B, S, C) token layout
        )

        n_enc = depth // 2 if depth > 1 else depth
        self.encoder = nn.TransformerEncoder(layer_e, num_layers=n_enc)
        self.norm = nn.LayerNorm(dim)
        self.decoder = nn.TransformerDecoder(layer_d, num_layers=depth - n_enc) if depth > 1 else None

    def forward(self, x):
        B, C, H, W = x.shape
        tokens = x.flatten(2).transpose(1, 2)          # (B, H*W, C)
        out = self.encoder(tokens)                     # (B, H*W, C)
        # getattr keeps older pickled instances (no `decoder` attribute) working
        decoder = getattr(self, 'decoder', None)
        if decoder is not None:
            out = decoder(out, out)                    # (B, H*W, C)
        out = self.norm(out)                           # LayerNorm over C for each token
        return out.transpose(1, 2).reshape(B, C, H, W)


class UNet(nn.Module):
    """
    U-Net with attention gates, optional ASPP feature blocks and an optional
    transformer bottleneck.

    Modules built in __init__:
      * input_norm: BatchNorm2d over the `in_channels` input channels.
      * encoders: one EncoderBlock per entry of `down_filters`. Each uses
        dropout 0.1 and a 2×2 max-pool; every encoder except the first uses
        CBAM attention.
      * bottleneck: BottleneckTransformer(dim=down_filters[-1], heads=4, depth=4)
        when `bottleneck_transformer` is True, otherwise nn.Identity.
      * decoders: one DecoderBlock per entry of `up_filters`. Each uses
        dropout 0.1, ×2 up-sampling and attention enabled.
      * final_conv: 5×5 conv to `out_channels`. It is followed by a sigmoid when
        `output_sigmoid` is True; otherwise raw outputs (logits) are returned.

    Args:
        in_channels, out_channels: input / output channel counts.
        down_filters, down_activations: per-encoder output channels and activation
            names. Required: the None defaults raise TypeError in the length asserts.
        up_filters, up_activations: per-decoder output channels and activation
            names. Required, with the same caveat.
        bottleneck_transformer: use the transformer bottleneck.
        ASPP_blocks: use ASPP (True) or DoubleConv (False) inside every block.
        output_sigmoid: append a sigmoid to the output conv.

    NOTE(review): channel/shape contract implied by forward and not checked here
    (N = len(down_filters)):
      * decoder 0 is called without a skip; decoder i >= 1 receives the skip of
        encoder N-1-i, so len(up_filters) must be <= N.
      * decoder i >= 1 concatenates its up-sampled input with that skip, giving
        2 * down_filters[N-1-i] channels. Its conv, however, is built for
        up_filters[i-1] channels, so up_filters[i-1] == 2 * down_filters[N-1-i]
        is required.
      * H and W presumably must be divisible by 2 ** N for up-sampled tensors
        and skips to align. Confirm this.
    """
    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 down_filters=None,
                 down_activations=None,
                 up_filters=None,
                 up_activations=None,
                 bottleneck_transformer=True,
                 ASPP_blocks=True,
                 output_sigmoid=True):
        super().__init__()
        assert len(down_filters) == len(down_activations)
        assert len(up_filters)   == len(up_activations)

        # Build Encoder path
        self.output_sigmoid = output_sigmoid  # also read by train_model to decide whether to apply sigmoid for metrics
        self.input_norm = nn.BatchNorm2d(in_channels)
        self.encoders = nn.ModuleList()
        self.bottleneck_transformer = bottleneck_transformer
        prev_ch = in_channels
        for i, out_ch in enumerate(down_filters):
            act_str = down_activations[i].lower()
            self.encoders.append(
                EncoderBlock(in_ch=prev_ch,
                             out_ch=out_ch,
                             activation=act_str,
                             dropout_prob=0.1,
                             attention=(i != 0),  # no CBAM on the first (full-resolution) encoder
                             pool=True,
                             ASPP_blocks=ASPP_blocks)
            )
            prev_ch = out_ch

        # Bottleneck: operates on the output of the last encoder (down_filters[-1] channels)
        if bottleneck_transformer:
            self.bottleneck  = BottleneckTransformer(dim=down_filters[-1],
                                                           heads=4,
                                                           depth=4)
        else:
            self.bottleneck = nn.Identity()

        # Build Decoder path
        self.decoders = nn.ModuleList()
        N = len(down_filters)
        for i in range(len(up_filters)):
            act_str = up_activations[i].lower()
            # Skip channels of encoder N-1-i (decoder 0 is actually called without a skip in forward)
            skip_ch = down_filters[N - 1 - i]
            # Output channels of this decoder; in_ch_dec below is its input channel count
            out_ch = up_filters[i]
            in_ch_dec = (down_filters[-1] * 1) if (i == 0) else up_filters[i - 1]

            self.decoders.append(
                DecoderBlock(in_ch=in_ch_dec,
                             skip_ch=skip_ch,
                             out_ch=out_ch,
                             activation=act_str,
                             dropout_prob=0.1,
                             attention= True,
                             upsample=True,
                             ASPP_blocks=ASPP_blocks)
            )

        if output_sigmoid:
            self.final_conv = nn.Sequential(
                nn.Conv2d(up_filters[-1], out_channels, kernel_size=5, padding=2, bias=True),
                nn.Sigmoid())
            nn.init.kaiming_normal_(self.final_conv[0].weight, mode='fan_out', nonlinearity='sigmoid')
            nn.init.constant_(self.final_conv[0].bias, 0)
        else:
            self.final_conv = nn.Conv2d(up_filters[-1], out_channels, kernel_size=5, padding=2, bias=True)
            nn.init.kaiming_normal_(self.final_conv.weight, mode='fan_out', nonlinearity='relu')
            nn.init.constant_(self.final_conv.bias, 0)


    def forward(self, x):
        x = self.input_norm(x)  # Normalize input
        # presumably x: (B, in_channels, H, W), e.g. (B, 1, 128, 128) as probed by train_model
        skips = []
        for enc in self.encoders[:-1]:  # every encoder except the last keeps its skip
            x, skip = enc(x)
            skips.append(skip)

        # Bottleneck:
        x, _ = self.encoders[-1](x) # the last encoder's skip is discarded
        skips.append(None)          # placeholder for that discarded skip
        x = self.bottleneck(x)

        x = self.decoders[0](x, skips[-1])  # skips[-1] is the None placeholder: decoder 0 runs without a skip

        skips = skips[::-1]              # now [None, skip_{N-2}, ..., skip_0]; decoder i uses skips[i]

        for i in range(1, len(self.decoders)):
            skip_feat = skips[i]
            x = self.decoders[i](x, skip_feat)

        x = self.final_conv(x)
        return x

In [3]:
def _confusion_counts(pred_bin, target):
    """Return (tp, fp, fn) sums as Python floats for two equally shaped {0, 1} tensors."""
    tp = (pred_bin * target).sum().item()
    fp = (pred_bin * (1 - target)).sum().item()
    fn = ((1 - pred_bin) * target).sum().item()
    return tp, fp, fn


def _precision_recall_f1(tp, fp, fn, eps=1e-8):
    """Return (precision, recall, F1) from confusion counts; eps avoids division by zero."""
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    return prec, rec, f1


def train_model(model, train_ds, val_ds, epochs=100, batch_size=32, lr=1e-3, loss=None, alpha=0.99, gamma=3.1, device=None, num_workers=4):
    """
    Train `model` on `train_ds`, validate on `val_ds`, and print per-epoch metrics.

    Metrics printed each epoch: loss, F1, precision, recall, and validation ROC AUC.

    The model's output spatial size is probed once with a dummy (1, 1, 128, 128)
    input. Every mask batch is then resized to that size with `reshape_masks`, so
    predictions and targets match.

    Args:
        model: network mapping (B, 1, H, W) images to (B, 1, H_out, W_out) outputs.
            If `model.output_sigmoid` is False, outputs are treated as logits and
            passed through a sigmoid for the metrics only; the loss still receives
            the raw outputs. Models without that attribute are assumed to output
            probabilities.
        train_ds, val_ds: datasets yielding (image, mask) pairs.
        epochs, batch_size: training schedule.
        lr: peak learning rate of the OneCycleLR schedule.
        loss: criterion(preds, targets). Defaults to a focal-Tversky-only ComboLossTF.
        alpha, gamma: focal-Tversky parameters, used only when `loss` is None.
        device: torch device; defaults to CUDA when available, else CPU.
        num_workers: worker processes for both DataLoaders.

    Returns:
        The trained model (on `device`).
    """
    # Explicit submodule import: `import sklearn` alone does not guarantee `sklearn.metrics` is loaded.
    from sklearn.metrics import roc_auc_score

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    outputs_probs = getattr(model, 'output_sigmoid', True)

    # 1) Probe the model's output spatial size with a dummy 128×128 patch.
    model.eval()  # BatchNorm uses running stats, not batch stats
    with torch.no_grad():
        dummy = torch.randn(1, 1, 128, 128).to(device)
        output_size = tuple(model(dummy).shape[-2:])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)

    # NOTE: the default ComboLossTF is applied to the raw model output, so it expects
    # probabilities (i.e. a model with output_sigmoid=True).
    if loss is None:
        criterion = ComboLossTF(bce_weight=0.0, dice_weight=0.0, focal_twersky_weight=1, alpha=alpha, gamma=gamma)
    else:
        criterion = loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                max_lr=lr,
                                                steps_per_epoch=len(train_loader),
                                                epochs=epochs,
                                                pct_start=0.1,
                                                anneal_strategy='cos')

    n_batches = len(train_loader)
    for epoch in range(1, epochs + 1):
        # ——— Training ———
        model.train()
        running_loss = 0.0
        tp = fp = fn = 0
        for batch_num, (imgs, masks) in enumerate(train_loader):
            imgs = imgs.to(device)
            m_resized = reshape_masks(masks, new_size=output_size).to(device)

            optimizer.zero_grad()
            preds = model(imgs)
            batch_loss = criterion(preds, m_resized)
            batch_loss.backward()
            optimizer.step()
            sched.step()  # OneCycleLR steps once per batch

            running_loss += batch_loss.item() * imgs.size(0)
            with torch.no_grad():
                probs = preds if outputs_probs else torch.sigmoid(preds)
                b_tp, b_fp, b_fn = _confusion_counts((probs > 0.5).float(), m_resized)
            tp += b_tp
            fp += b_fp
            fn += b_fn

            prec, rec, f1 = _precision_recall_f1(tp, fp, fn)
            print(f"\rEpoch {epoch:03d}  "
                  f"Batch {batch_num+1:03d}/{n_batches}  "
                  f"Batch Loss: {batch_loss.item():.4f}  "
                  f"| train F1: {f1:.4f}  | train precision: {prec:.4f}  | train recall: {rec:.4f}", end='\r')

        train_loss = running_loss / len(train_ds)
        _, _, f1 = _precision_recall_f1(tp, fp, fn)

        # ——— Validation ———
        model.eval()
        val_loss = 0.0
        tp = fp = fn = 0
        val_targets, val_probs = [], []
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs = imgs.to(device)
                m_resized = reshape_masks(masks, new_size=output_size).to(device)
                preds = model(imgs)
                val_loss += criterion(preds, m_resized).item() * imgs.size(0)

                probs = preds if outputs_probs else torch.sigmoid(preds)
                b_tp, b_fp, b_fn = _confusion_counts((probs > 0.5).float(), m_resized)
                tp += b_tp
                fp += b_fp
                fn += b_fn
                val_targets.append(m_resized.cpu().numpy())
                val_probs.append(probs.cpu().numpy())

        val_loss = val_loss / len(val_ds)
        val_prec, val_rec, f1_val = _precision_recall_f1(tp, fp, fn)
        val_y = np.concatenate(val_targets, axis=0).ravel()
        preds_val = np.concatenate(val_probs, axis=0).ravel()
        try:
            auc_val = roc_auc_score(val_y, preds_val)
        except ValueError:
            # ROC AUC is undefined when the validation masks contain only one class.
            auc_val = float('nan')

        print(f"Epoch {epoch:03d}  "
              f"Train Loss: {train_loss:.4f}  "
              f"| Val Loss: {val_loss:.4f}  "
              f"| Train F1: {f1:.4f}  "
              f"| Val F1: {f1_val:.4f}  "
              f"| Val Prec: {val_prec:.4f}  "
              f"| Val Rec: {val_rec:.4f}  "
              f"| Val AUC: {auc_val:.4f}")

    return model


In [4]:
class DiceLoss(nn.Module):
    """
    Soft Dice loss computed over the whole batch at once.

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    Args:
        smooth: additive constant that keeps the ratio defined when both sums are 0.
    """
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """
        preds:   probabilities (e.g. after Sigmoid), e.g. (B, 1, H, W)
        targets: same number of elements as preds, binary {0, 1}
        Returns a scalar tensor.
        """
        # reshape (not view) so non-contiguous inputs are accepted too
        p_flat = preds.reshape(-1)
        t_flat = targets.reshape(-1)
        intersection = (p_flat * t_flat).sum()
        dice_coeff = (2. * intersection + self.smooth) / (p_flat.sum() + t_flat.sum() + self.smooth)
        return 1 - dice_coeff

class FocalTverskyLoss(nn.Module):
    """
    Focal Tversky loss computed over the whole batch.

    tversky = (TP + eps) / (TP + alpha * FN + beta * FP + eps), where beta = 1 - alpha
    loss    = (1 - tversky) ** gamma

    TP, FP and FN are soft counts from probabilities. A larger alpha penalises
    false negatives more than false positives.

    Args:
        alpha: weight of false negatives (beta = 1 - alpha weights false positives).
        gamma: focusing exponent.
        eps: smoothing constant.
    """
    def __init__(self, alpha=0.9, gamma=2.0, eps=1e-6):
        super().__init__()
        self.alpha, self.gamma, self.eps = alpha, gamma, eps
        self.beta = 1 - alpha  # Ensure alpha + beta = 1

    def forward(self, preds, targets):
        """
        preds:   probabilities, any shape
        targets: binary {0, 1}, same number of elements as preds
        Returns a scalar tensor.
        """
        # reshape (not view) so non-contiguous inputs are accepted too
        preds = preds.reshape(-1)
        targets = targets.reshape(-1)
        TP = (preds * targets).sum()
        FP = (preds * (1 - targets)).sum()
        FN = ((1 - preds) * targets).sum()
        tversky = (TP + self.eps) / (TP + self.alpha*FN + self.beta*FP + self.eps)
        return torch.pow((1 - tversky), self.gamma)

class ComboLossTF(nn.Module):
    """
    Weighted sum of BCE, Dice and focal-Tversky losses.

    A term is skipped entirely when its weight is <= 0. `preds` must be
    probabilities in [0, 1], because nn.BCELoss, DiceLoss and FocalTverskyLoss
    are applied to them directly.

    Args:
        bce_weight, dice_weight, focal_twersky_weight: per-term weights.
        alpha, gamma: FocalTverskyLoss parameters.
    """
    def __init__(self, bce_weight=0.33, dice_weight=0.33, focal_twersky_weight=0.33, alpha=0.95, gamma=3.1):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice = DiceLoss(smooth=1e-6)
        self.FW = FocalTverskyLoss(alpha=alpha, gamma=gamma)
        self.bw, self.dw, self.fw = bce_weight, dice_weight, focal_twersky_weight

    def forward(self, preds, targets):
        """preds, targets: matching shapes, e.g. (B, 1, H, W). Returns a scalar tensor."""
        # Start from a tensor (not the int 0) so the result always supports .item()
        # and keeps preds' dtype/device, even when every weight is <= 0.
        total_loss = preds.new_zeros(())
        if self.bw > 0:
            total_loss = total_loss + self.bw * self.bce(preds, targets)
        if self.dw > 0:
            total_loss = total_loss + self.dw * self.dice(preds, targets)
        if self.fw > 0:
            total_loss = total_loss + self.fw * self.FW(preds, targets)
        return total_loss

def sigzi(x, axis=None):
    """
    Robust standard-deviation estimate: 0.741 × the interquartile range of x.

    For Gaussian data IQR ≈ 1.349 σ, so 0.741 · IQR ≈ σ. This is NOT the raw IQR.

    Args:
        x: array-like, e.g. shape (P, H, W), (H, W) or (N, C, H, W).
        axis: axis (or axes) along which to compute the statistic.
              If None, computes over the flattened array.

    Returns:
        A float (axis=None) or an ndarray of per-slice estimates.
    """
    q25, q75 = np.percentile(x, [25, 75], axis=axis)  # one pass for both quartiles
    return 0.741 * (q75 - q25)

def split_stack(arr, nrows, ncols):
    """
    Cut every panel of a (P, H, W) stack into non-overlapping nrows × ncols tiles.

    If H or W is not a multiple of the tile size, the panels are first padded
    with zeros at the bottom/right. Tiles are ordered panel by panel and, within
    a panel, row-major.

    Returns:
        ndarray of shape (P * ceil(H / nrows) * ceil(W / ncols), nrows, ncols).
    """
    n_panels, height, width = arr.shape
    extra_rows = -height % nrows
    extra_cols = -width % ncols
    if extra_rows or extra_cols:
        arr = np.pad(arr, ((0, 0), (0, extra_rows), (0, extra_cols)),
                     mode='constant', constant_values=0)

    tiles_down = arr.shape[1] // nrows
    tiles_across = arr.shape[2] // ncols
    grid = arr.reshape(n_panels, tiles_down, nrows, tiles_across, ncols)
    grid = grid.transpose(0, 1, 3, 2, 4)   # -> (P, tiles_down, tiles_across, nrows, ncols)
    return grid.reshape(n_panels * tiles_down * tiles_across, nrows, ncols)

def build_datasets(h5_path, tile_size=128, clip_min=None, clip_max=None):
    """
    Load HDF5 training data, clip it, tile it, and return PyTorch tensors.

    Images are clipped to [clip_min, clip_max] but are NOT sigma-normalised (no
    `x / sigzi(x)` step is applied). Images and masks are cut into
    non-overlapping tile_size × tile_size tiles with `split_stack`, which
    zero-pads the bottom/right edges when needed.

    Args:
        h5_path (str): Path to the .h5 file with "images" and "masks" datasets
            (presumably each of shape (N, H, W), as split_stack requires).
        tile_size (int): Size of square tiles to extract from each image.
        clip_min (float or None): Lower clip bound for image values; None selects -166.43.
        clip_max (float or None): Upper clip bound for image values; None selects 169.96.

    Returns:
        x_tiles (Tensor): (N_tiles, 1, tile_size, tile_size) float32 image tiles
        y_tiles (Tensor): (N_tiles, 1, tile_size, tile_size) float32 mask tiles
    """
    # The clip bounds used to be ignored in favour of this fixed range; None keeps
    # that range for callers relying on the defaults, explicit values are honoured.
    if clip_min is None:
        clip_min = -166.43
    if clip_max is None:
        clip_max = 169.96

    with h5py.File(h5_path, "r") as f:
        x = f["images"][:]
        y = f["masks"][:]
        print(x.shape)

    x = np.clip(x, clip_min, clip_max)

    # Tile images and masks
    x_tiles = split_stack(x, tile_size, tile_size)
    y_tiles = split_stack(y.astype("float32"), tile_size, tile_size)

    # Convert to PyTorch tensors and add a channel dimension
    x_tiles = torch.from_numpy(x_tiles).float().unsqueeze(1)
    y_tiles = torch.from_numpy(y_tiles).float().unsqueeze(1)
    print(x_tiles.shape)

    return x_tiles, y_tiles


def reshape_masks(masks, new_size):
    """
    Resize binary masks (0/1) to `new_size`.

    Bilinear interpolation (align_corners=False) is followed by torch.ceil and a
    clamp to [0, 1]. Any output pixel whose interpolated value is > 0 becomes 1,
    so small positive regions are never lost when downsampling.

    Input:
      - masks: a numpy array of shape (N, H_orig, W_orig), or a Tensor of shape
               (N, 1, H_orig, W_orig) or (N, H_orig, W_orig) of any dtype
               (bool/int tensors are converted to float)
      - new_size: tuple (new_H, new_W)
    Returns:
      - float Tensor of shape (N, 1, new_H, new_W), values in {0, 1}
    """
    if isinstance(masks, np.ndarray):
        m = torch.from_numpy(masks).float().unsqueeze(1)  # (N, H, W) -> (N, 1, H, W)
    else:
        m = masks
        if m.dim() == 3:
            m = m.unsqueeze(1)          # mirror the numpy branch: add the channel dim
        if not m.is_floating_point():
            m = m.float()               # F.interpolate(bilinear) needs a float tensor
    m_resized = F.interpolate(m, size=new_size, mode='bilinear', align_corners=False)
    return torch.ceil(m_resized).clamp(0, 1)

def split_train_val(x_tiles, y_tiles, train_frac=0.8, seed=42):
    """
    Deterministically shuffle tiles and split them into train / validation TensorDatasets.

    A permutation seeded with `seed` selects the first int(train_frac * N)
    indices for training and the rest for validation. Each index set is sorted,
    so tiles keep their original relative order within a split.
    """
    total = x_tiles.shape[0]
    order = torch.randperm(total, generator=torch.Generator().manual_seed(seed))
    cut = int(train_frac * total)

    train_sel = torch.sort(order[:cut]).values
    val_sel = torch.sort(order[cut:]).values

    train_ds = TensorDataset(x_tiles[train_sel], y_tiles[train_sel])
    val_ds = TensorDataset(x_tiles[val_sel], y_tiles[val_sel])
    return train_ds, val_ds

In [41]:
import h5py, numpy as np, torch
from torch.utils.data import Dataset, DataLoader, random_split

# Clip range for raw image values. build_datasets clips images to this same
# (-166.43, 169.96) range; the original note says it matches the TF pipeline.
CLIP_MIN, CLIP_MAX = -166.43, 169.96  # match TF

def _tiles_for_shape(H, W, tile):
    Hb = (H + tile - 1) // tile
    Wb = (W + tile - 1) // tile
    return Hb, Wb

class H5TiledDataset(Dataset):
    """
    Map-style dataset of fixed-size tiles read lazily from an HDF5 file.

    The file must hold two same-shaped 3-D datasets, ``images`` and ``masks``,
    of shape (N, H, W). Item ``k`` is tile ``self.indices[k] = (i, r, c)``:
    rows ``r*tile:(r+1)*tile`` and cols ``c*tile:(c+1)*tile`` of image ``i``,
    zero-padded to (tile, tile) at the bottom/right borders, returned as a pair
    of float32 tensors of shape (1, tile, tile). Images are clipped to
    [CLIP_MIN, CLIP_MAX].

    Full arrays are never loaded, but ``__init__`` does read a strided view
    ``masks[i, ::tile, ::tile]`` of every mask to build ``pos_tiles``; on large
    files that scan dominates construction time.

    HDF5 handles are opened lazily and tied to the process that opened them:
    a handle inherited through ``fork`` is replaced on first use, and pickling
    (``spawn`` workers) drops it, so each DataLoader worker reads through its
    own handle (h5py handles must not be shared across processes).

    Args:
        h5_path: path to the HDF5 file.
        tile_size: tile edge length in pixels.
        pos_fraction: if in (0, 1), keep only a random fraction of ``pos_tiles``.
        seed: seed for the RNG used by that subsampling.
    """
    def __init__(self, h5_path, tile_size=128, pos_fraction=1.0, seed=42):
        self.h5_path   = h5_path
        self.tile      = tile_size
        self.rng       = np.random.default_rng(seed)
        self.h5        = None  # opened lazily, once per process
        self._h5_pid   = None  # pid of the process that opened self.h5

        # Short-lived handle: probe shapes and scan masks, then close it.
        with h5py.File(self.h5_path, "r") as f:
            ds_x = f["images"]; ds_y = f["masks"]
            self.N, self.H, self.W = ds_x.shape
            assert ds_y.shape == (self.N, self.H, self.W)

            # Every (img_idx, tile_row, tile_col), image-major.
            Hb, Wb = _tiles_for_shape(self.H, self.W, self.tile)
            self.all_indices = [(i, r, c)
                                for i in range(self.N)
                                for r in range(Hb)
                                for c in range(Wb)]

            self.pos_tiles = self._scan_positive_tiles(ds_y)

        # Optionally keep a random `pos_fraction` subset of the positive tiles.
        # This downsamples pos_tiles; fractions >= 1 leave it unchanged.
        if 0 < pos_fraction < 1.0 and len(self.pos_tiles) > 0:
            keep = int(len(self.pos_tiles) * pos_fraction)
            # Sample positions rather than the tuples themselves so entries stay
            # (i, r, c) tuples of Python ints.
            picked = self.rng.choice(len(self.pos_tiles), size=keep, replace=False)
            self.pos_tiles = [self.pos_tiles[k] for k in picked]

        # All tiles are served; pos_tiles is only recorded (e.g. for a custom Sampler).
        self.indices = self.all_indices

    def _scan_positive_tiles(self, ds_y):
        """Return (i, r, c) for every tile whose top-left mask pixel is > 0.

        Only one pixel per tile (stride ``self.tile``) is inspected, so tiles
        whose positives lie elsewhere are not listed: a cheap, lossy probe.
        """
        step = self.tile
        found = []
        for i in range(self.N):
            m = np.asarray(ds_y[i, ::step, ::step])
            # argwhere yields coordinates in row-major order.
            for r, c in np.argwhere(m > 0):
                found.append((i, int(r), int(c)))
        return found

    def _ensure_open(self):
        """Open this process's read-only HDF5 handle, or reopen it after a fork."""
        import os
        pid = os.getpid()
        if self.h5 is None or getattr(self, "_h5_pid", None) != pid:
            # A handle inherited from another process must not be reused;
            # replace it with one owned by this process.
            self.h5 = h5py.File(self.h5_path, "r")
            self.ds_x = self.h5["images"]
            self.ds_y = self.h5["masks"]
            self._h5_pid = pid

    def close(self):
        """Close this process's HDF5 handle, if any; the next access reopens it."""
        if self.h5 is not None:
            self.h5.close()
        self.h5 = None
        self._h5_pid = None

    def __getstate__(self):
        """Pickle without the HDF5 handle (h5py objects cannot be pickled)."""
        state = self.__dict__.copy()
        for key in ("ds_x", "ds_y"):
            state.pop(key, None)
        state["h5"] = None
        state["_h5_pid"] = None
        return state

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return (image, mask) for tile ``self.indices[idx]``, each (1, tile, tile)."""
        self._ensure_open()
        i, r, c = self.indices[idx]
        t = self.tile
        r0, c0 = r * t, c * t
        r1, c1 = min(r0 + t, self.H), min(c0 + t, self.W)

        # Slice the exact window (shorter at the bottom/right borders)
        x = self.ds_x[i, r0:r1, c0:c1].astype("float32")
        y = self.ds_y[i, r0:r1, c0:c1].astype("float32")

        # Zero-pad border tiles to (t, t)
        if x.shape[0] != t or x.shape[1] != t:
            xp = np.zeros((t, t), dtype=np.float32)
            yp = np.zeros((t, t), dtype=np.float32)
            xp[:x.shape[0], :x.shape[1]] = x
            yp[:y.shape[0], :y.shape[1]] = y
            x, y = xp, yp

        # Match TF preprocessing
        x = np.clip(x, CLIP_MIN, CLIP_MAX)
        x = x[None, ...]  # add channel dim
        y = y[None, ...]
        return torch.from_numpy(x), torch.from_numpy(y)

def _worker_init_fn(worker_id):
    # each worker will lazily open its own h5 in __getitem__
    pass

def _make_loader(ds, batch_size, shuffle, num_workers, pin_memory,
                 prefetch_factor, persistent_workers):
    """Build a DataLoader, passing worker-only options only when workers are used."""
    kwargs = dict(batch_size=batch_size, shuffle=shuffle,
                  num_workers=num_workers, pin_memory=pin_memory)
    if num_workers > 0:
        # prefetch_factor / persistent_workers only apply to worker processes;
        # some older PyTorch versions reject a non-default prefetch_factor
        # (even None) when num_workers == 0.
        kwargs.update(prefetch_factor=prefetch_factor,
                      persistent_workers=persistent_workers)
    return DataLoader(ds, **kwargs)


def _confusion_counts(probs, target, threshold=0.5):
    """Return (tp, fp, fn) pixel counts for `probs` thresholded at `threshold`."""
    pred_bin = (probs > threshold).float()
    tp = (pred_bin * target).sum().item()
    fp = (pred_bin * (1 - target)).sum().item()
    fn = ((1 - pred_bin) * target).sum().item()
    return tp, fp, fn


def _prf1(tp, fp, fn, eps=1e-8):
    """Return epsilon-smoothed (precision, recall, F1) from confusion counts."""
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    return prec, rec, f1


def _safe_roc_auc(targets, scores):
    """ROC-AUC of flat targets vs scores; NaN when empty or only one class is present."""
    # Import the submodule explicitly: `import sklearn` alone does not
    # guarantee that `sklearn.metrics` is loaded.
    from sklearn.metrics import roc_auc_score
    if targets.size == 0 or targets.min() == targets.max():
        return float("nan")  # roc_auc_score raises on a single class
    return roc_auc_score(targets, scores)


def train_model(model, train_ds, val_ds, epochs=100, batch_size=32, lr=1e-3,
                loss=None, alpha=0.99, gamma=3.1, device=None,
                num_workers=0, pin_memory=False, prefetch_factor=2, persistent_workers=False,
                log_every=1):
    """
    Train a segmentation model with Adam + OneCycleLR, reporting pixel metrics.

    Each batch's ground-truth masks are resized with ``reshape_masks`` to the
    spatial size of that batch's predictions, so models whose output is
    smaller than their input are handled for any tile size.

    Args:
        model: nn.Module, expected to map (B, 1, H, W) images to (B, 1, h, w)
            outputs. If ``model.output_sigmoid`` is truthy, outputs are treated
            as probabilities; otherwise (or if the attribute is missing) they
            are treated as logits and passed through a sigmoid for metrics.
        train_ds, val_ds: datasets yielding (image, mask) pairs.
        epochs, batch_size: training schedule.
        lr: Adam learning rate, also OneCycleLR's ``max_lr``.
        loss: criterion called as ``loss(preds, masks)``. If None, a
            ``ComboLossTF`` with only its focal-Tversky term enabled is used,
            configured by ``alpha``/``gamma``. ``alpha`` and ``gamma`` are
            ignored when ``loss`` is given.
        device: torch device; defaults to CUDA when available.
        num_workers, pin_memory, prefetch_factor, persistent_workers: DataLoader
            options; the last two apply only when ``num_workers > 0``.
        log_every: print running train metrics every ``log_every`` batches
            (1 = every batch; 0 or None disables per-batch logging).

    Returns:
        The trained model (moved to ``device``).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    outputs_probs = bool(getattr(model, "output_sigmoid", False))

    loader_opts = dict(num_workers=num_workers, pin_memory=pin_memory,
                       prefetch_factor=prefetch_factor,
                       persistent_workers=persistent_workers)
    train_loader = _make_loader(train_ds, batch_size, True, **loader_opts)
    val_loader = _make_loader(val_ds, batch_size, False, **loader_opts)

    if loss is None:
        criterion = ComboLossTF(bce_weight=0.0, dice_weight=0.0, focal_twersky_weight=1, alpha=alpha, gamma=gamma)
    else:
        criterion = loss
    if isinstance(criterion, nn.Module):
        # Loss buffers (e.g. BCEWithLogitsLoss.pos_weight) must live on `device`.
        criterion = criterion.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                max_lr=lr,
                                                steps_per_epoch=len(train_loader),
                                                epochs=epochs,
                                                pct_start=0.1,
                                                anneal_strategy='cos')

    def _to_probs(preds):
        return preds if outputs_probs else torch.sigmoid(preds)

    n_batches = len(train_loader)
    for epoch in range(1, epochs + 1):
        # ——— Training ———
        model.train()
        running_loss = 0.0
        tp = fp = fn = 0
        for batch_num, (imgs, masks) in enumerate(train_loader):
            imgs = imgs.to(device, non_blocking=pin_memory)
            masks = masks.to(device, non_blocking=pin_memory)

            optimizer.zero_grad()
            preds = model(imgs)
            # Resize the ground truth to this prediction's spatial size.
            m_resized = reshape_masks(masks, new_size=tuple(preds.shape[-2:]))
            batch_loss = criterion(preds, m_resized)
            batch_loss.backward()
            optimizer.step()
            sched.step()

            running_loss += batch_loss.item() * imgs.size(0)
            with torch.no_grad():  # metrics need no autograd graph
                btp, bfp, bfn = _confusion_counts(_to_probs(preds), m_resized)
            tp += btp; fp += bfp; fn += bfn

            if log_every and (batch_num + 1) % log_every == 0:
                prec, rec, f1 = _prf1(tp, fp, fn)
                print(f"Epoch {epoch:03d}  "
                      f"Batch {batch_num+1:03d}/{n_batches}  "
                      f"Batch Loss: {batch_loss.item():.4f}  "
                      f"| train F1: {f1:.4f}  | train precision: {prec:.4f}  | train recall: {rec:.4f}",
                      flush=True)

        train_loss = running_loss / len(train_ds)
        _, _, f1 = _prf1(tp, fp, fn)

        # ——— Validation ———
        model.eval()
        val_loss = 0.0
        tp = fp = fn = 0
        val_targets, val_scores = [], []
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs = imgs.to(device, non_blocking=pin_memory)
                masks = masks.to(device, non_blocking=pin_memory)
                preds = model(imgs)
                m_resized = reshape_masks(masks, new_size=tuple(preds.shape[-2:]))
                val_loss += criterion(preds, m_resized).item() * imgs.size(0)

                probs = _to_probs(preds)
                btp, bfp, bfn = _confusion_counts(probs, m_resized)
                tp += btp; fp += bfp; fn += bfn
                # Flattened per batch; concatenated once for AUC below.
                val_targets.append(m_resized.cpu().numpy().ravel())
                val_scores.append(probs.cpu().numpy().ravel())

        val_loss = val_loss / max(len(val_ds), 1)
        prec, rec, f1_val = _prf1(tp, fp, fn)
        auc_val = _safe_roc_auc(
            np.concatenate(val_targets) if val_targets else np.empty(0),
            np.concatenate(val_scores) if val_scores else np.empty(0),
        )

        print(f"Epoch {epoch:03d}  "
              f"Train Loss: {train_loss:.4f}  "
              f"| Val Loss: {val_loss:.4f}  "
              f"| Train F1: {f1:.4f}  "
              f"| Val F1: {f1_val:.4f}  "
              f"| Val Prec: {prec:.4f}  "
              f"| Val Rec: {rec:.4f}  "
              f"| Val AUC: {auc_val:.4f}")

    return model


In [44]:
# UNet encoder/decoder configuration: one activation per stage.
down_filters     = [32, 32, 64, 128, 256, 512, 1024]
down_activations = ['relu', 'selu', 'selu', 'selu', 'selu', 'selu', 'selu']

up_filters       = [1024, 512, 256, 128, 64]
up_activations   = ['selu', 'selu', 'selu', 'selu', 'relu']

# Class-imbalance weighting for BCE-with-logits: pos_weight = neg / pos.
# POS_NEG_PIXEL_RATIO is presumably the positive/negative pixel ratio printed by
# this recipe:
#   x_tiles, y_tiles = build_datasets("../DATA/test.h5", tile_size=128)
#   ratio = (y_tiles == 1).sum().item() / (y_tiles == 0).sum().item()
# NOTE(review): that recipe loaded test.h5, not the training file used below.
# TODO: confirm the ratio on the training masks.
POS_NEG_PIXEL_RATIO = 0.003329051612807811
loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1 / POS_NEG_PIXEL_RATIO))
model = UNet(
        down_filters=down_filters,
        down_activations=down_activations,
        up_filters=up_filters,
        up_activations=up_activations,
        bottleneck_transformer=False,
        ASPP_blocks=False,
        output_sigmoid=False)  # logits out; train_model applies sigmoid for metrics

full_ds = H5TiledDataset("../DATA/train_chunked.h5", tile_size=128)

# 80/20 split by index (random_split wraps indices; no tile data is read)
n = len(full_ds)
n_tr = int(0.8 * n)
n_va = n - n_tr

train_ds, val_ds = random_split(full_ds, [n_tr, n_va],
                                generator=torch.Generator().manual_seed(42))

trained_model = train_model(
    model,
    train_ds, val_ds,
    epochs=150,
    batch_size=64,          # start smaller to reduce I/O pressure
    lr=1.5e-4,
    loss=loss,              # loss=None selects train_model's focal-Tversky ComboLossTF
    num_workers=32,         # multiprocess loading: full_ds is not indexed in this
                            # process first, so each worker opens its own HDF5 handle
    pin_memory=False
)


Epoch 001  Batch 001/10240  Batch Loss: 2.1704  | train F1: 0.0020  | train precision: 0.0010  | train recall: 0.8710
Epoch 001  Batch 002/10240  Batch Loss: 2.8776  | train F1: 0.0032  | train precision: 0.0016  | train recall: 0.6098
Epoch 001  Batch 003/10240  Batch Loss: 2.3687  | train F1: 0.0056  | train precision: 0.0028  | train recall: 0.7712
Epoch 001  Batch 004/10240  Batch Loss: 2.8409  | train F1: 0.0067  | train precision: 0.0033  | train recall: 0.7694
Epoch 001  Batch 005/10240  Batch Loss: 2.0287  | train F1: 0.0053  | train precision: 0.0027  | train recall: 0.7691
Epoch 001  Batch 006/10240  Batch Loss: 2.3819  | train F1: 0.0054  | train precision: 0.0027  | train recall: 0.7802
Epoch 001  Batch 007/10240  Batch Loss: 2.9573  | train F1: 0.0063  | train precision: 0.0032  | train recall: 0.7807
Epoch 001  Batch 008/10240  Batch Loss: 2.2109  | train F1: 0.0058  | train precision: 0.0029  | train recall: 0.7787
Epoch 001  Batch 009/10240  Batch Loss: 2.4495  | train 

KeyboardInterrupt: 

In [21]:
# Earlier, shallower encoder kept for reference:
#down_filters     = [32, 64, 128, 256, 512]
down_filters =  [32, 32, 64, 128, 256, 512, 1024]
down_activations = ['relu', 'selu', 'selu', 'selu', 'selu', 'selu', 'selu']

up_filters       = [1024, 512, 256, 128, 64]
up_activations   = ['selu', 'selu', 'selu', 'selu', 'relu']
# BCE-with-logits loss with positives up-weighted by pos_weight = 1 / 0.00333 (~300).
# The constant is presumably the positive/negative pixel ratio from the recipe
# below, which loaded ../DATA/test.h5. TODO: confirm against the training masks.
#x_tiles, y_tiles = build_datasets("../DATA/test.h5", tile_size=128)

#print("Positive/negative ratio:", (y_tiles==1).sum().item() / (y_tiles==0).sum().item())

#neg = (y_tiles==0).sum()
#pos = (y_tiles==1).sum()
#pos_weight = (neg / pos).float()
#loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1 / 0.003329051612807811))
# output_sigmoid=False: presumably UNet returns raw logits (what BCEWithLogitsLoss
# expects); train_model then applies a sigmoid itself when computing metrics.
model = UNet(
        down_filters=down_filters,
        down_activations=down_activations,
        up_filters=up_filters,
        up_activations=up_activations,
        bottleneck_transformer=False,
        ASPP_blocks=False,
        output_sigmoid=False)


# train_ds / val_ds come from another cell's random_split.
# alpha/gamma only configure train_model's default ComboLossTF; because `loss`
# is passed explicitly here, they have no effect.
trained_model = train_model(model,
                            train_ds, val_ds,
                            epochs=150,
                            batch_size=128,
                            lr=0.00015,
                            alpha=0.95, gamma=3.1, loss=loss)

Epoch 001  Batch 160/5120  Batch Loss: 0.8639  | train F1: 0.0077  | train precision: 0.0039  | train recall: 0.5039

KeyboardInterrupt: 

In [16]:
# Time a single __getitem__. The dataset opens its HDF5 file lazily on first
# access, so this figure includes opening the file, not just reading one tile.
# After this call `ds` keeps that handle open in the current process.
from time import perf_counter
ds = H5TiledDataset("../DATA/train.h5", tile_size=128)
t0 = perf_counter()
x,y = ds[0]
print("one __getitem__ took:", perf_counter()-t0, "sec", x.shape, y.shape)


⚠️ HDF5 datasets are contiguous (no chunks). Random tile reads will be slow. Consider h5repack with CHUNK=1x128x128.
one __getitem__ took: 0.0007629280000855942 sec torch.Size([1, 128, 128]) torch.Size([1, 128, 128])


In [11]:
# Sanity check: fetch one shuffled batch in-process (num_workers=0), then stop.
# train_ds comes from another cell's random_split.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,
                          num_workers=0, pin_memory=False)
for i,(xb,yb) in enumerate(train_loader):
    print("first batch OK:", xb.shape); break


first batch OK: torch.Size([64, 1, 128, 128])


In [19]:
import h5py, numpy as np, torch
from torch.utils.data import Dataset

CLIP_MIN, CLIP_MAX = -166.43, 169.96  # image clip range applied in H5TiledDataset.__getitem__; values chosen to match the TF pipeline

def _tiles_for_shape(H, W, tile):
    Hb = (H + tile - 1) // tile
    Wb = (W + tile - 1) // tile
    return Hb, Wb

class H5TiledDataset(Dataset):
    """
    Map-style dataset of fixed-size tiles read lazily from an HDF5 file.

    The file must hold same-shaped 3-D datasets ``images`` and ``masks`` of
    shape (N, H, W); ``__init__`` reads only their metadata (no mask scanning).
    Item ``k`` is tile ``self.indices[k] = (i, r, c)`` of image ``i``,
    zero-padded to (tile, tile) at the bottom/right borders, returned as a pair
    of float32 tensors of shape (1, tile, tile). Images are clipped to
    [CLIP_MIN, CLIP_MAX].

    HDF5 handles are opened lazily and tied to the process that opened them:
    a handle inherited through ``fork`` is replaced on first use, and pickling
    (``spawn`` workers) drops it, so each DataLoader worker reads through its
    own handle.

    Args:
        h5_path: path to the HDF5 file.
        tile_size: tile edge length in pixels.
        warn_unchunked: if True, warn when either dataset uses contiguous
            (unchunked) storage, which makes random tile reads slow.
    """
    def __init__(self, h5_path, tile_size=128, warn_unchunked=False):
        self.h5_path = h5_path
        self.tile    = tile_size
        self.h5      = None  # opened lazily, once per process
        self._h5_pid = None  # pid of the process that opened self.h5

        # Probe shapes through a short-lived handle; no pixel data is read.
        with h5py.File(self.h5_path, "r") as f:
            ds_x = f["images"]; ds_y = f["masks"]
            self.N, self.H, self.W = ds_x.shape
            assert ds_y.shape == (self.N, self.H, self.W)
            # h5py reports `.chunks` as None for contiguous storage.
            unchunked = warn_unchunked and (ds_x.chunks is None or ds_y.chunks is None)

        if unchunked:
            import warnings
            warnings.warn(
                "HDF5 datasets are contiguous (no chunks). Random tile reads will be slow. "
                f"Consider h5repack with CHUNK=1x{self.tile}x{self.tile}.",
                stacklevel=2,
            )

        # Tile indices grouped by image, so consecutive items share an image.
        Hb, Wb = _tiles_for_shape(self.H, self.W, self.tile)
        self.indices = [(i, r, c)
                        for i in range(self.N)
                        for r in range(Hb)
                        for c in range(Wb)]

    def _ensure_open(self):
        """Open this process's read-only HDF5 handle, or reopen it after a fork."""
        import os
        pid = os.getpid()
        if self.h5 is None or getattr(self, "_h5_pid", None) != pid:
            # A handle inherited from another process must not be reused;
            # replace it with one owned by this process.
            self.h5 = h5py.File(self.h5_path, "r")
            self.ds_x = self.h5["images"]
            self.ds_y = self.h5["masks"]
            self._h5_pid = pid

    def close(self):
        """Close this process's HDF5 handle, if any; the next access reopens it."""
        if self.h5 is not None:
            self.h5.close()
        self.h5 = None
        self._h5_pid = None

    def __getstate__(self):
        """Pickle without the HDF5 handle (h5py objects cannot be pickled)."""
        state = self.__dict__.copy()
        for key in ("ds_x", "ds_y"):
            state.pop(key, None)
        state["h5"] = None
        state["_h5_pid"] = None
        return state

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return (image, mask) for tile ``self.indices[idx]``, each (1, tile, tile)."""
        self._ensure_open()
        i, r, c = self.indices[idx]
        t = self.tile
        r0, c0 = r * t, c * t
        r1, c1 = min(r0 + t, self.H), min(c0 + t, self.W)

        # read exact window (shorter at the bottom/right borders)
        x = self.ds_x[i, r0:r1, c0:c1].astype("float32")
        y = self.ds_y[i, r0:r1, c0:c1].astype("float32")

        # zero-pad border tiles to (t, t)
        if x.shape[0] != t or x.shape[1] != t:
            xp = np.zeros((t, t), dtype=np.float32)
            yp = np.zeros((t, t), dtype=np.float32)
            xp[:x.shape[0], :x.shape[1]] = x
            yp[:y.shape[0], :y.shape[1]] = y
            x, y = xp, yp

        # preprocessing to match TF
        x = np.clip(x, CLIP_MIN, CLIP_MAX)
        return torch.from_numpy(x[None, ...]), torch.from_numpy(y[None, ...])


In [15]:
from time import perf_counter
from torch.utils.data import DataLoader, random_split

full_ds = H5TiledDataset("../DATA/train.h5", tile_size=128)

# Quick timing check (should be ~milliseconds, not 20+ seconds). The first
# __getitem__ also opens the HDF5 file lazily, so the figure includes the open.
t0 = perf_counter()
_probe = full_ds[0]
print("getitem:", perf_counter() - t0, "s")

# The probe left an HDF5 handle open in this process. Close it so DataLoader
# workers forked later (e.g. train_model with num_workers > 0) open their own
# handle instead of inheriting this one; the next access reopens it lazily.
if full_ds.h5 is not None:
    full_ds.h5.close()
    full_ds.h5 = None

# 80/20 split by index (random_split wraps indices; no tile data is read)
n = len(full_ds); n_tr = int(0.8*n); n_va = n - n_tr
train_ds, val_ds = random_split(full_ds, [n_tr, n_va],
                                generator=torch.Generator().manual_seed(42))

# start with num_workers=0 to sidestep h5py/fork issues
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,
                          num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False,
                          num_workers=0, pin_memory=False)


⚠️ HDF5 datasets are contiguous (no chunks). Random tile reads will be slow. Consider h5repack with CHUNK=1x128x128.
getitem: 0.0008007400001588394 s
