In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# ------------------------------------------------------------
# 1) CBAM (Channel + Spatial Attention)
# ------------------------------------------------------------
class ChannelAttention(nn.Module):
    """Channel half of CBAM: scale every channel by a learned gate in (0, 1).

    Global average-pooled and max-pooled descriptors are pushed through one
    shared two-layer bottleneck of 1x1 convolutions, summed, and passed
    through a sigmoid. The resulting (B, C, 1, 1) gate multiplies the input.

    Args:
        in_channels: number of channels C of the incoming feature map.
        ratio: reduction factor of the bottleneck. The hidden width is
            ``in_channels // ratio``, but never less than 1.
    """

    def __init__(self, in_channels, ratio=8):
        super().__init__()
        # Clamp to 1: with in_channels < ratio the integer division would
        # otherwise yield a zero-channel bottleneck.
        hidden = max(1, in_channels // ratio)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_channels, hidden, kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, in_channels, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_out', nonlinearity='linear')
        nn.init.constant_(self.fc2.bias, 0)

    def _mlp(self, pooled):
        """Apply the shared bottleneck to a (B, C, 1, 1) descriptor."""
        return self.fc2(self.relu(self.fc1(pooled)))

    def forward(self, x):
        """Return ``x`` (B, C, H, W) multiplied by its per-channel gate."""
        # Both descriptors go through the SAME weights before being summed.
        logits = self._mlp(self.avg_pool(x)) + self._mlp(self.max_pool(x))
        scale = self.sigmoid(logits)      # (B, C, 1, 1)
        return x * scale                  # broadcast along H, W

class SpatialAttention(nn.Module):
    """Spatial half of CBAM: scale every pixel by a learned gate in (0, 1).

    The channel-wise mean and channel-wise max of the input are stacked into
    a 2-channel map, convolved down to one channel with "same" padding, and
    squashed with a sigmoid.

    Args:
        kernel_size: side of the square convolution kernel; must be 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7)
        # Half the (odd) kernel on each side keeps H and W unchanged.
        self.conv = nn.Conv2d(2, 1,
                              kernel_size=kernel_size,
                              padding=kernel_size // 2,
                              bias=False)
        self.sigmoid = nn.Sigmoid()
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='sigmoid')

    def forward(self, x):
        """Return ``x`` (B, C, H, W) multiplied by its (B, 1, H, W) gate."""
        channel_mean = x.mean(dim=1, keepdim=True)            # (B, 1, H, W)
        channel_max = x.max(dim=1, keepdim=True).values       # (B, 1, H, W)
        descriptor = torch.cat((channel_mean, channel_max), dim=1)   # (B, 2, H, W)
        gate = self.sigmoid(self.conv(descriptor))            # (B, 1, H, W)
        return x * gate                                       # broadcast across C

class CBAMBlock(nn.Module):
    """CBAM: channel attention followed by spatial attention.

    Args:
        in_channels: channel count of the feature map being refined.
        ratio: bottleneck reduction factor handed to ``ChannelAttention``.
        kernel_size: convolution kernel size handed to ``SpatialAttention``.
    """

    def __init__(self, in_channels, ratio=8, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(in_channels, ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        """Gate ``x`` per channel first, then per spatial location."""
        return self.spatial_att(self.channel_att(x))

# ------------------------------------------------------------
# 2) DoubleConv: (3×3 conv → BatchNorm → Activation) applied twice
# ------------------------------------------------------------
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> activation) stages.

    Padding of 1 keeps the spatial size; the first stage maps ``in_ch`` to
    ``out_ch`` and the second stays at ``out_ch``. The convolutions carry no
    bias.

    Args:
        in_ch: channels of the input.
        out_ch: channels produced by both stages.
        activation: activation module; the very same instance is placed
            after both BatchNorm layers.
    """

    def __init__(self, in_ch, out_ch, activation):
        super().__init__()
        stages = []
        width = in_ch
        for _ in range(2):
            stages.extend([
                nn.Conv2d(width, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                activation,
            ])
            width = out_ch
        self.block = nn.Sequential(*stages)

        # Kaiming (fan_out, relu gain) for every conv found inside the block.
        for conv in (m for m in self.block.modules() if isinstance(m, nn.Conv2d)):
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Run ``x`` through both stages."""
        return self.block(x)

# ------------------------------------------------------------
# 3) AttentionGate for skip‐connection fusion
# ------------------------------------------------------------
class AttentionGate(nn.Module):
    """Additive attention gate that re-weights a skip connection.

    Gating signal and skip features are each projected to ``F_int`` channels
    by a bias-free 1x1 conv + BatchNorm, added, rectified, and collapsed by
    another 1x1 conv + sigmoid into a single-channel map that multiplies the
    skip features.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip features ``x``.
        F_int: channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def _projection(channels):
            # 1x1 conv to the shared intermediate width, then BatchNorm.
            return nn.Sequential(
                nn.Conv2d(channels, F_int, kernel_size=1, bias=False),
                nn.BatchNorm2d(F_int),
            )

        self.W_g = _projection(F_g)   # gating-signal branch
        self.W_x = _projection(F_l)   # skip-connection branch
        # psi turns the fused features into a 1-channel attention map.
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

        for conv, gain in ((self.W_g[0], 'relu'),
                           (self.W_x[0], 'relu'),
                           (self.psi[0], 'sigmoid')):
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity=gain)

    def forward(self, g, x):
        """
        g: gating signal from the decoder, shape (B, F_g, H, W)
        x: skip connection from the encoder, shape (B, F_l, H, W)
        Returns x scaled by the attention map, shape (B, F_l, H, W).
        """
        fused = self.relu(self.W_g(g) + self.W_x(x))   # (B, F_int, H, W)
        weights = self.psi(fused)                      # (B, 1, H, W)
        return x * weights                             # broadcast along channel

# ------------------------------------------------------------
# 4) Encoder Block: DoubleConv → CBAM → Dropout → Optional MaxPool
# ------------------------------------------------------------
class EncoderBlock(nn.Module):
    """One encoder level: DoubleConv -> optional CBAM -> optional dropout,
    followed by an optional 2x2 max-pool.

    Args:
        in_ch: channels entering the block.
        out_ch: channels produced by the block.
        activation: activation module handed to ``DoubleConv``.
        dropout_prob: ``Dropout2d`` probability; a value <= 0 disables dropout.
        attention: when true a ``CBAMBlock`` refines the features.
        pool: when true the returned features are max-pooled by a factor 2.
    """

    def __init__(self, in_ch, out_ch, activation, dropout_prob=0.0, attention=True, pool=True):
        super().__init__()
        self.double_conv = DoubleConv(in_ch, out_ch, activation)
        if attention:
            self.cbam = CBAMBlock(out_ch, ratio=8, kernel_size=7)
        else:
            self.cbam = nn.Identity()
        if dropout_prob > 0:
            self.dropout = nn.Dropout2d(dropout_prob)
        else:
            self.dropout = nn.Identity()
        self.pool = pool

    def forward(self, x):
        """Return ``(features, skip)``.

        ``skip`` is a copy of the un-pooled features; ``features`` is the
        pooled tensor when pooling is enabled, otherwise the un-pooled one.
        """
        feats = self.dropout(self.cbam(self.double_conv(x)))
        skip = feats.clone()
        if not self.pool:
            return feats, skip
        return F.max_pool2d(feats, kernel_size=2, stride=2), skip

# ------------------------------------------------------------
# 5) Decoder Block:
#       - Upsample from in_ch → skip_ch
#       - AttentionGate on skip
#       - DoubleConv(2*skip_ch → out_ch) → CBAM → Dropout
# ------------------------------------------------------------
class DecoderBlock(nn.Module):
    """One decoder level: optional 2x upsampling, attention-gated skip fusion,
    then DoubleConv -> optional CBAM -> optional dropout.

    When ``attention`` is true the encoder skip is filtered by an
    ``AttentionGate`` (gated by the decoder features) and concatenated to the
    decoder features. When ``attention`` is false no skip is fused and CBAM
    is replaced by an identity.

    Args:
        in_ch: channels from the previous layer (bottleneck or previous decoder).
        skip_ch: channels of the corresponding encoder skip.
        out_ch: output channels of this block.
        activation: activation module used after upsampling and in DoubleConv.
        dropout_prob: ``Dropout2d`` probability; a value <= 0 disables dropout.
        attention: enable attention-gated skip fusion and CBAM.
        upsample: when true a stride-2 transposed conv maps ``in_ch`` to
            ``skip_ch`` and doubles H and W; otherwise the input is used as is.
    """

    def __init__(self, in_ch, skip_ch, out_ch, activation, dropout_prob=0.0, attention=True, upsample=True):
        super().__init__()
        self.upsample = upsample
        self.skip_ch = skip_ch

        if self.upsample:
            # ConvTranspose2d(in_ch → skip_ch) to match spatial & channel dims
            self.up = nn.ConvTranspose2d(in_ch, skip_ch, kernel_size=3,
                                         stride=2, padding=1, output_padding=1, bias=False)
            nn.init.kaiming_normal_(self.up.weight, mode='fan_out', nonlinearity='relu')
            self.bn_up = nn.BatchNorm2d(skip_ch)
            self.act_up = activation
            gate_ch = skip_ch
        else:
            self.up = None
            self.bn_up = None
            self.act_up = None
            # Without upsampling the decoder features keep their in_ch channels.
            gate_ch = in_ch

        if attention:
            # The gate sees decoder features (gate_ch) and the skip (skip_ch);
            # the intermediate width is clamped so it can never be zero.
            self.attention = AttentionGate(F_g=gate_ch, F_l=skip_ch, F_int=max(1, skip_ch // 2))
            in_double = gate_ch + skip_ch
        else:
            # No gate means no skip fusion, so DoubleConv only sees the
            # decoder features themselves.
            self.attention = nn.Identity()
            in_double = gate_ch

        self.double_conv = DoubleConv(in_double, out_ch, activation)
        self.cbam        = CBAMBlock(out_ch, ratio=8, kernel_size=7) if attention else nn.Identity()
        self.dropout     = nn.Dropout2d(dropout_prob) if dropout_prob > 0 else nn.Identity()

    def forward(self, x, skip=None):
        """Decode ``x``, fusing ``skip`` when attention is enabled.

        Raises:
            ValueError: if attention is enabled but ``skip`` is None.
        """
        if self.upsample:
            x = self.act_up(self.bn_up(self.up(x)))   # (B, skip_ch, H*2, W*2)

        if not isinstance(self.attention, nn.Identity):
            if skip is None:
                raise ValueError("DecoderBlock with attention=True requires a skip tensor")
            if x.shape[-2:] != skip.shape[-2:]:
                # Odd encoder sizes make 2x upsampling land one pixel off;
                # resample so gating and concatenation line up.
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            skip = self.attention(g=x, x=skip)
            x = torch.cat([x, skip], dim=1)

        x = self.double_conv(x)
        x = self.cbam(x)
        x = self.dropout(x)
        return x

# ------------------------------------------------------------
# 6) UNetTFEquivalent: exactly follows your JSON architecture
# ------------------------------------------------------------
class UNetTFEquivalent(nn.Module):
    """Attention U-Net assembled from per-level configuration lists.

    Layout: one ``EncoderBlock`` per entry of ``down_filters`` (CBAM is
    disabled in the first one), a ``DoubleConv`` bottleneck that doubles the
    last encoder width, one ``DecoderBlock`` per entry of ``up_filters``
    (decoder ``i`` is paired with encoder ``N - 1 - i``), and a final
    3x3 conv + sigmoid.

    There may be fewer decoder levels than encoder levels, in which case the
    output is spatially smaller than the input.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the predicted map.
        down_filters: output channels of each encoder level.
        down_activations: ``'relu'`` or ``'sigmoid'`` (case-insensitive) per
            encoder level.
        down_dropouts: dropout probability per encoder level.
        down_pool: whether each encoder level ends with a 2x2 max-pool; the
            mirrored flag decides whether the paired decoder upsamples.
        up_filters: output channels of each decoder level.
        up_activations: ``'relu'`` or ``'sigmoid'`` per decoder level.
        up_dropouts: dropout probability per decoder level.

    Raises:
        ValueError: if a list is missing or empty, the per-level lists have
            inconsistent lengths, there are more decoder than encoder levels,
            or an activation name is not supported.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 down_filters=None,
                 down_activations=None,
                 down_dropouts=None,
                 down_pool=None,
                 up_filters=None,
                 up_activations=None,
                 up_dropouts=None):
        super().__init__()
        self._validate_config(down_filters, down_activations, down_dropouts, down_pool,
                              up_filters, up_activations, up_dropouts)

        # Build Encoder path
        self.encoders = nn.ModuleList()
        prev_ch = in_channels
        for i, out_ch in enumerate(down_filters):
            act_fn = self._make_activation(down_activations[i], "encoder")
            self.encoders.append(
                EncoderBlock(in_ch=prev_ch,
                             out_ch=out_ch,
                             activation=act_fn,
                             dropout_prob=down_dropouts[i],
                             attention=(i != 0),      # no CBAM in very first block
                             pool=down_pool[i])
            )
            prev_ch = out_ch

        # Bottleneck: DoubleConv(down_filters[-1] → down_filters[-1]*2)
        bottleneck_ch = down_filters[-1] * 2
        self.bottleneck = DoubleConv(down_filters[-1], bottleneck_ch, nn.ReLU(inplace=True))

        # Build Decoder path
        self.decoders = nn.ModuleList()
        n_levels = len(down_filters)
        prev_ch = bottleneck_ch
        for i, out_ch in enumerate(up_filters):
            act_fn = self._make_activation(up_activations[i], "decoder")
            mirror = n_levels - 1 - i   # index of the paired encoder level
            self.decoders.append(
                DecoderBlock(in_ch=prev_ch,
                             skip_ch=down_filters[mirror],
                             out_ch=out_ch,
                             activation=act_fn,
                             dropout_prob=up_dropouts[i],
                             attention=True,
                             upsample=down_pool[mirror])
            )
            prev_ch = out_ch

        # Final 3×3 conv + Sigmoid
        self.final_conv = nn.Conv2d(up_filters[-1], out_channels, kernel_size=3, padding=1, bias=False)
        nn.init.kaiming_normal_(self.final_conv.weight, mode='fan_out', nonlinearity='sigmoid')
        self.final_sigmoid = nn.Sigmoid()

    @staticmethod
    def _validate_config(down_filters, down_activations, down_dropouts, down_pool,
                         up_filters, up_activations, up_dropouts):
        """Raise ValueError when the architecture lists are unusable."""
        named = {
            'down_filters': down_filters,
            'down_activations': down_activations,
            'down_dropouts': down_dropouts,
            'down_pool': down_pool,
            'up_filters': up_filters,
            'up_activations': up_activations,
            'up_dropouts': up_dropouts,
        }
        missing = [name for name, value in named.items() if value is None]
        if missing:
            raise ValueError(f"Missing architecture arguments: {', '.join(missing)}")
        if not (len(down_filters) == len(down_activations) == len(down_dropouts) == len(down_pool)):
            raise ValueError("down_filters, down_activations, down_dropouts and down_pool "
                             "must all have the same length")
        if not (len(up_filters) == len(up_activations) == len(up_dropouts)):
            raise ValueError("up_filters, up_activations and up_dropouts must all have the same length")
        if len(down_filters) == 0 or len(up_filters) == 0:
            raise ValueError("at least one encoder level and one decoder level are required")
        if len(up_filters) > len(down_filters):
            # Each decoder consumes one encoder skip; extra decoders would
            # have nothing to pair with.
            raise ValueError("there cannot be more decoder levels than encoder levels")

    @staticmethod
    def _make_activation(name, role):
        """Map an activation name to a fresh module; ``role`` labels the error."""
        act_str = name.lower()
        if act_str == 'relu':
            return nn.ReLU(inplace=True)
        if act_str == 'sigmoid':
            return nn.Sigmoid()
        raise ValueError(f"Unsupported {role} activation: {act_str}")

    def forward(self, x):
        """Map ``x`` (B, in_channels, H, W) to sigmoid outputs in (0, 1).

        The output has ``out_channels`` channels and the spatial size of the
        encoder skip paired with the last decoder.
        """
        skips = []
        for enc in self.encoders:
            x, skip = enc(x)
            skips.append(skip)

        x = self.bottleneck(x)          # (B, 2 * down_filters[-1], h, w)

        # Deepest skip first: decoder i is paired with encoder N - 1 - i.
        for dec, skip_feat in zip(self.decoders, reversed(skips)):
            x = dec(x, skip_feat)

        x = self.final_conv(x)
        x = self.final_sigmoid(x)
        return x

# ------------------------------------------------------------
# 7) Dataset utilities: split_stack, build_datasets, reshape_masks, split_train_val
# ------------------------------------------------------------
def split_stack(arr, nrows, ncols):
    """
    Cut every 2D panel of a stack into non-overlapping (nrows × ncols) tiles.

    Panels whose height/width is not a multiple of the tile size are first
    zero-padded at the bottom/right edge. Tiles are emitted panel by panel,
    row-major within each panel.

    arr: ndarray, shape (P, H, W)
    Returns: ndarray, shape (P * ceil(H/nrows) * ceil(W/ncols), nrows, ncols)
    """
    n_panels, height, width = arr.shape

    # Amount needed to reach the next multiple of the tile size (0 if aligned).
    extra_rows = -height % nrows
    extra_cols = -width % ncols
    if any((extra_rows, extra_cols)):
        arr = np.pad(arr,
                     [(0, 0), (0, extra_rows), (0, extra_cols)],
                     mode='constant',
                     constant_values=0)

    tiles_down = arr.shape[1] // nrows
    tiles_across = arr.shape[2] // ncols

    # (P, tiles_down, nrows, tiles_across, ncols) -> bring both tile-grid axes
    # in front of both in-tile axes, then merge the leading three axes.
    grid = arr.reshape(n_panels, tiles_down, nrows, tiles_across, ncols)
    grid = grid.transpose(0, 1, 3, 2, 4)
    return grid.reshape(n_panels * tiles_down * tiles_across, nrows, ncols)

def build_datasets(npz_file, tile_size=128, clip_range=(-166.43, 169.96)):
    """
    Load images and masks from an .npz archive and return them as tiled tensors.

      - Reads arrays 'x' and 'y' from the archive (the archive is closed
        again before returning)
      - Clips x to `clip_range` (default [-166.43, 169.96]); y is not clipped
      - Splits each panel into (tile_size × tile_size) patches via split_stack
      - Converts to float32 and adds a channel dimension

    Args:
        npz_file: path (or file object) accepted by np.load.
        tile_size: side length of the square tiles.
        clip_range: (low, high) bounds applied to x, or None to skip clipping.

    Returns:
        (x_tiles, y_tiles): FloatTensors of shape (N, 1, tile_size, tile_size).
    """
    # NpzFile keeps the underlying zip open until closed; the arrays are
    # materialised on access, so they stay valid after the `with` block.
    with np.load(npz_file) as data:
        x = data['x']  # shape (P, H, W)
        y = data['y']

    # 1) Clip the inputs
    if clip_range is not None:
        low, high = clip_range
        x = np.clip(x, low, high)

    # 2) Split into tiles (tile_size × tile_size)
    x_tiles = split_stack(x, tile_size, tile_size)  # (N_tiles, tile_size, tile_size)
    y_tiles = split_stack(y, tile_size, tile_size)

    # 3) Convert to FloatTensor and add channel dimension
    x_tiles = torch.from_numpy(x_tiles).float().unsqueeze(1)  # (N, 1, tile_size, tile_size)
    y_tiles = torch.from_numpy(y_tiles).float().unsqueeze(1)  # (N, 1, tile_size, tile_size)

    return x_tiles, y_tiles

def reshape_masks(masks, new_size):
    """
    Resize binary masks (0/1) to `new_size`:
      - Uses bilinear interpolation (align_corners=False)
      - Applies torch.ceil(...) and clamps to [0, 1], so any output pixel
        whose interpolated value is above zero becomes 1.
    Input:
      - masks: a Tensor or numpy array of shape (N, 1, H_orig, W_orig) or
               (N, H_orig, W_orig); any numeric or bool dtype
      - new_size: tuple (new_H, new_W)
    Returns:
      - float Tensor of shape (N, 1, new_H, new_W), values in {0,1}
    """
    if isinstance(masks, np.ndarray):
        m = torch.from_numpy(masks).float()
    else:
        m = masks
        # Bilinear interpolation needs a floating dtype; floating inputs are
        # left in their own precision.
        if not torch.is_floating_point(m):
            m = m.float()
    if m.dim() == 3:
        m = m.unsqueeze(1)  # (N,H,W) → (N,1,H,W)
    m_resized = F.interpolate(m, size=new_size, mode='bilinear', align_corners=False)
    m_resized = torch.ceil(m_resized)
    return m_resized.clamp(0, 1)

def split_train_val(x_tiles, y_tiles, train_frac=0.8, seed=42):
    """
    Randomly partition paired tiles into a training and a validation set.

    A permutation of the sample indices is drawn from a private generator
    seeded with `seed`; the first int(train_frac * N) indices form the
    training set and the rest the validation set.

    Returns: (train TensorDataset, val TensorDataset)
    """
    total = x_tiles.shape[0]

    rng = torch.Generator()
    rng.manual_seed(seed)
    order = torch.randperm(total, generator=rng)

    cut = int(train_frac * total)
    train_part, val_part = order[:cut], order[cut:]

    train_set = TensorDataset(x_tiles[train_part], y_tiles[train_part])
    val_set = TensorDataset(x_tiles[val_part], y_tiles[val_part])
    return train_set, val_set

# ------------------------------------------------------------
# 8) Loss: Weighted BCE + Dice (TF used BCE + ceil(targets) for binary masks)
# ------------------------------------------------------------
class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - (2*|P∩T| + s) / (|P| + |T| + s)``.

    The coefficient is computed once over ALL elements of the batch (not per
    sample). ``smooth`` (s) keeps the ratio defined when both prediction and
    target are empty, in which case the loss is 0.

    Args:
        smooth: additive smoothing term for numerator and denominator.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """
        preds:   Tensor (B,1,H,W) after Sigmoid
        targets: Tensor (B,1,H,W) binary {0,1}
        Returns a scalar tensor.
        """
        # reshape (not view) so non-contiguous inputs, e.g. slices or
        # transposes, are accepted as well.
        p_flat = preds.reshape(-1)
        t_flat = targets.reshape(-1)
        intersection = (p_flat * t_flat).sum()
        dice_coeff = (2. * intersection + self.smooth) / (p_flat.sum() + t_flat.sum() + self.smooth)
        return 1 - dice_coeff

class ComboLossTF(nn.Module):
    """Weighted sum of binary cross-entropy and Dice loss.

    Both terms operate on probabilities (the model already applies a
    sigmoid), hence ``nn.BCELoss`` rather than a logits-based variant.

    Args:
        bce_weight: multiplier of the BCE term.
        dice_weight: multiplier of the Dice term.
    """

    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice = DiceLoss(smooth=1e-6)
        self.bw = bce_weight
        self.dw = dice_weight

    def forward(self, preds, targets):
        """Return ``bw * BCE + dw * Dice`` for preds/targets of shape (B,1,H,W)."""
        bce_term = self.bw * self.bce(preds, targets)
        dice_term = self.dw * self.dice(preds, targets)
        return bce_term + dice_term

# ------------------------------------------------------------
# 9) Training loop (resizes masks to match model’s output)
# ------------------------------------------------------------
def _prf1(tp, fp, fn, eps=1e-8):
    """Return (precision, recall, F1) from confusion counts; eps avoids 0/0."""
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    return prec, rec, f1


def _evaluate(model, loader, criterion, output_size, device):
    """Run one no-grad pass; return (summed loss, tp, fp, fn) at threshold 0.5."""
    model.eval()
    loss_sum = 0.0
    tp = fp = fn = 0.0
    with torch.no_grad():
        for imgs, masks in loader:
            imgs = imgs.to(device)
            targets = reshape_masks(masks, new_size=output_size).to(device)
            preds = model(imgs)
            loss_sum += criterion(preds, targets).item() * imgs.size(0)

            pred_bin = (preds > 0.5).float()
            tp += (pred_bin * targets).sum().item()
            fp += (pred_bin * (1 - targets)).sum().item()
            fn += ((1 - pred_bin) * targets).sum().item()
    return loss_sum, tp, fp, fn


def train_model(model, train_ds, val_ds, epochs=100, batch_size=32, lr=1e-3, device=None,
                num_workers=4):
    """
    Train the model on train_ds, validate on val_ds, and print per epoch the
    train loss plus validation loss, F1, precision and recall.

    Ground-truth masks are resized to the model's output spatial size so that
    preds and targets match. Optimiser is Adam; loss is ComboLossTF(0.5, 0.5).

    Args:
        model: network mapping (B, C, H, W) images to (B, 1, h, w) probabilities.
        train_ds, val_ds: datasets yielding (image, mask) pairs.
        epochs: number of passes over train_ds.
        batch_size: batch size for both loaders.
        lr: Adam learning rate.
        device: torch.device or device string; defaults to CUDA when available.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        The trained model (moved to `device`).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)
    model = model.to(device)

    # 1) Figure out the model's output spatial size with a dummy forward pass.
    #    It runs in eval mode so BatchNorm running statistics are not
    #    polluted by the dummy input; the previous mode is restored after.
    if len(train_ds) > 0:
        sample_shape = tuple(train_ds[0][0].shape)   # assumes (C, H, W) — TODO confirm for custom datasets
    else:
        sample_shape = (1, 128, 128)
    was_training = model.training
    model.eval()
    with torch.no_grad():
        dummy = torch.zeros((1, *sample_shape), device=device)
        out_dummy = model(dummy)
        output_size = (out_dummy.shape[-2], out_dummy.shape[-1])
    model.train(was_training)

    criterion = ComboLossTF(bce_weight=0.5, dice_weight=0.5)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Pinned host memory only helps host→GPU copies.
    pin = device.type == 'cuda'
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)

    for epoch in range(1, epochs + 1):
        # ——— Training ———
        model.train()
        running_loss = 0.0
        for imgs, masks in train_loader:
            imgs = imgs.to(device)

            # Resize the ground‐truth masks to output_size
            m_resized = reshape_masks(masks, new_size=output_size).to(device)

            optimizer.zero_grad()
            preds = model(imgs)              # (B,1, output_H, output_W)
            loss = criterion(preds, m_resized)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch value is a per-sample mean.
            running_loss += loss.item() * imgs.size(0)

        train_loss = running_loss / max(len(train_ds), 1)

        # ——— Validation ———
        val_sum, tp, fp, fn = _evaluate(model, val_loader, criterion, output_size, device)
        val_loss = val_sum / max(len(val_ds), 1)
        prec, rec, f1_val = _prf1(tp, fp, fn)

        print(f"Epoch {epoch:03d}  "
              f"Train Loss: {train_loss:.4f}  "
              f"| Val Loss: {val_loss:.4f}  "
              f"| Val F1: {f1_val:.4f}  "
              f"| Prec: {prec:.4f}  "
              f"| Rec: {rec:.4f}")

    return model

# ------------------------------------------------------------
# 10) Main() – build data, instantiate model, train
# ------------------------------------------------------------
if __name__ == "__main__":
    # 10.1) Data: tile the arrays stored in ../DATA/train.npz into 128×128
    #       patches and split them 80/20 with a fixed seed.
    npz_file = "../DATA/train.npz"
    x_tiles, y_tiles = build_datasets(npz_file, tile_size=128)
    train_ds, val_ds = split_train_val(x_tiles, y_tiles, train_frac=0.8, seed=42)

    # 10.2) Architecture: six pooled encoder levels, four decoder levels.
    down_filters     = [32, 32, 32, 64, 128, 256]
    down_activations = ['relu'] * 4 + ['sigmoid', 'relu']
    n_down           = len(down_filters)
    down_dropouts    = [0] * n_down
    down_pool        = [True] * n_down

    up_filters       = [512, 256, 128, 64]
    up_activations   = ['relu', 'sigmoid'] * 2
    up_dropouts      = [0] * len(up_filters)

    arch = dict(
        in_channels=1,
        out_channels=1,
        down_filters=down_filters,
        down_activations=down_activations,
        down_dropouts=down_dropouts,
        down_pool=down_pool,
        up_filters=up_filters,
        up_activations=up_activations,
        up_dropouts=up_dropouts,
    )
    model = UNetTFEquivalent(**arch)

    # 10.3) Train for 20 epochs with batch_size=32, lr=1e-3
    trained_model = train_model(model, train_ds, val_ds,
                                epochs=20, batch_size=32, lr=1e-3)


Epoch 001  Train Loss: 0.5123  | Val Loss: 0.5068  | Val F1: 0.0000  | Prec: 0.0000  | Rec: 0.0000
Epoch 002  Train Loss: 0.5061  | Val Loss: 0.5084  | Val F1: 0.0000  | Prec: 0.0000  | Rec: 0.0000
Epoch 003  Train Loss: 0.5046  | Val Loss: 0.4981  | Val F1: 0.0733  | Prec: 0.0782  | Rec: 0.0689
Epoch 004  Train Loss: 0.5034  | Val Loss: 0.5004  | Val F1: 0.0468  | Prec: 0.0585  | Rec: 0.0390
Epoch 005  Train Loss: 0.5038  | Val Loss: 0.4974  | Val F1: 0.0741  | Prec: 0.0729  | Rec: 0.0754
Epoch 006  Train Loss: 0.5003  | Val Loss: 0.4949  | Val F1: 0.0672  | Prec: 0.1178  | Rec: 0.0470
Epoch 007  Train Loss: 0.4990  | Val Loss: 0.5052  | Val F1: 0.0808  | Prec: 0.0518  | Rec: 0.1835
Epoch 008  Train Loss: 0.4986  | Val Loss: 0.4895  | Val F1: 0.1273  | Prec: 0.0989  | Rec: 0.1784
Epoch 009  Train Loss: 0.4960  | Val Loss: 0.4896  | Val F1: 0.1118  | Prec: 0.2962  | Rec: 0.0689
Epoch 010  Train Loss: 0.4914  | Val Loss: 0.4812  | Val F1: 0.1339  | Prec: 0.1288  | Rec: 0.1394
Epoch 011 