In [4]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import sklearn

In [7]:
def activation_parser(activation_str):
    """
    Return the activation module described by ``activation_str``.

    Args:
        activation_str: Case-insensitive activation name. Supported strings:
            'relu', 'sigmoid', 'tanh', 'leaky_relu'. An already constructed
            ``nn.Module`` is also accepted and returned unchanged, so callers
            that hold a module instead of a name can use the same entry point.

    Returns:
        nn.Module: a new activation module for a name, or the module that was
        passed in.

    Raises:
        TypeError: if the argument is neither a string nor an ``nn.Module``.
        ValueError: if the name is not one of the supported strings.
    """
    if isinstance(activation_str, nn.Module):
        return activation_str
    if not isinstance(activation_str, str):
        raise TypeError(
            f"activation must be a str or nn.Module, got {type(activation_str).__name__}"
        )

    name = activation_str.lower()
    if name == 'relu':
        return nn.ReLU(inplace=True)
    if name == 'sigmoid':
        return nn.Sigmoid()
    if name == 'tanh':
        return nn.Tanh()
    if name == 'leaky_relu':
        return nn.LeakyReLU(negative_slope=0.01, inplace=True)
    raise ValueError(f"Unsupported activation: {activation_str}")


class ChannelAttention(nn.Module):
    """
    Channel gate: re-weights each channel of a (B, C, H, W) tensor.

    Global average-pooled and global max-pooled descriptors are passed through
    the same two-layer 1x1-conv MLP (C -> C // ratio -> C), summed, squashed
    with a sigmoid and multiplied onto the input.

    Args:
        in_channels: number of channels C of the input.
        ratio: reduction factor of the hidden layer. The hidden width is
            clamped to at least 1 so that ``in_channels < ratio`` does not
            produce a zero-channel convolution.
    """
    def __init__(self, in_channels, ratio=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # in_channels // ratio is 0 whenever in_channels < ratio; keep one unit.
        hidden_channels = max(1, in_channels // ratio)

        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden_channels, in_channels, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_out', nonlinearity='linear')
        nn.init.constant_(self.fc2.bias, 0)

    def _shared_mlp(self, pooled):
        """Apply fc1 -> ReLU -> fc2 to a pooled (B, C, 1, 1) descriptor."""
        return self.fc2(self.relu(self.fc1(pooled)))

    def forward(self, x):
        """Return ``x`` scaled per channel by the attention weights; shape is unchanged."""
        avg_out = self._shared_mlp(self.avg_pool(x))   # (B, C, 1, 1)
        max_out = self._shared_mlp(self.max_pool(x))   # (B, C, 1, 1)
        scale = self.sigmoid(avg_out + max_out)        # (B, C, 1, 1)
        return x * scale                               # broadcast along H, W

class SpatialAttention(nn.Module):
    """
    Spatial gate: re-weights every pixel of a (B, C, H, W) tensor.

    The per-pixel mean and max over the channel axis are stacked into a
    2-channel map, convolved down to one channel with a same-padded
    ``kernel_size`` x ``kernel_size`` kernel, passed through a sigmoid and
    multiplied onto the input.

    Args:
        kernel_size: side of the convolution kernel; only 3 and 7 are accepted
            (enforced with an ``assert``).
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7)
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,   # keeps H and W unchanged
            bias=True,
        )
        self.sigmoid = nn.Sigmoid()
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='sigmoid')

    def forward(self, x):
        """Return ``x`` multiplied by a (B, 1, H, W) attention map."""
        channel_mean = x.mean(dim=1, keepdim=True)               # (B, 1, H, W)
        channel_max = x.max(dim=1, keepdim=True).values          # (B, 1, H, W)
        stats = torch.cat((channel_mean, channel_max), dim=1)    # (B, 2, H, W)
        gate = self.sigmoid(self.conv(stats))                    # (B, 1, H, W)
        return x * gate                                          # broadcast across C

class CBAMBlock(nn.Module):
    """
    Channel attention followed by spatial attention.

    Args:
        in_channels: channel count of the input, forwarded to ChannelAttention.
        ratio: hidden-layer reduction factor, forwarded to ChannelAttention.
        kernel_size: convolution kernel side, forwarded to SpatialAttention.
    """
    def __init__(self, in_channels, ratio=8, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(in_channels, ratio)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        """Apply the channel gate, then the spatial gate; shape is unchanged."""
        channel_refined = self.channel_att(x)
        return self.spatial_att(channel_refined)

class SepConv(nn.Module):
    """
    Depthwise-separable convolution followed by BatchNorm and an activation.

    A depthwise conv (``groups=in_ch``) is followed by a 1x1 pointwise conv
    that maps ``in_ch`` to ``out_ch``; the result goes through BatchNorm2d and
    the activation.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        activation: activation name understood by ``activation_parser``
            (case-insensitive). It is also used as the ``nonlinearity`` for
            Kaiming initialisation.
        kernel_size: kernel size of the depthwise conv.
        padding: padding of the depthwise conv.
        dilation: dilation of the depthwise conv.
    """
    def __init__(self, in_ch, out_ch, activation, kernel_size, padding, dilation=1):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_ch, in_ch, kernel_size=kernel_size,
            padding=padding, dilation=dilation,
            groups=in_ch, bias=True
        )
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=True)
        self.norm = nn.BatchNorm2d(out_ch)
        self.act = activation_parser(activation)

        # activation_parser ignores case, but the gain lookup behind
        # kaiming_normal_ only recognises lower-case names, so normalise here.
        gain_name = activation.lower()
        nn.init.kaiming_normal_(self.depthwise.weight, mode='fan_out', nonlinearity=gain_name)
        nn.init.kaiming_normal_(self.pointwise.weight, mode='fan_out', nonlinearity=gain_name)

    def forward(self, x):
        """depthwise -> pointwise -> BatchNorm -> activation."""
        x = self.depthwise(x)
        x = self.pointwise(x)
        return self.act(self.norm(x))

class ASPP(nn.Module):
    """
    Four parallel SepConv branches with growing kernel size and dilation,
    concatenated on the channel axis and fused by a 1x1 conv + BatchNorm +
    activation.

    Args:
        in_ch: input channels.
        out_ch: output channels of every branch and of the fused result.
        activation: either an activation name accepted by ``activation_parser``
            (case-insensitive) or an ``nn.ReLU`` / ``nn.Sigmoid`` / ``nn.Tanh`` /
            ``nn.LeakyReLU`` instance. For a module only its type is used: the
            activations inside the block are rebuilt through
            ``activation_parser``, so constructor arguments of the passed
            module (e.g. a custom ``negative_slope``) are not carried over.

    Raises:
        ValueError: if ``activation`` cannot be mapped to a supported name.
    """

    # Module types that can be translated back to an activation_parser name.
    _MODULE_NAMES = (
        (nn.ReLU, 'relu'),
        (nn.Sigmoid, 'sigmoid'),
        (nn.Tanh, 'tanh'),
        (nn.LeakyReLU, 'leaky_relu'),
    )

    def __init__(self, in_ch, out_ch, activation):
        super().__init__()
        # SepConv and kaiming_normal_ both need a lower-case name, while the
        # blocks that build an ASPP in this file pass nn.Module instances.
        act_name = self._activation_name(activation)

        dilations = [1, 2, 3, 4]
        kernels   = [1, 3, 5, 7]
        self.branches = nn.ModuleList()
        for d, k in zip(dilations, kernels):
            pad = (k // 2) * d   # same-size output for odd k at dilation d
            self.branches.append(
                SepConv(in_ch, out_ch, act_name, kernel_size=k, padding=pad, dilation=d)
            )
        self.merge = nn.Sequential(
            nn.Conv2d(len(dilations) * out_ch, out_ch, kernel_size=1, bias=True),
            nn.BatchNorm2d(out_ch),
            activation_parser(act_name)
        )
        nn.init.kaiming_normal_(self.merge[0].weight, mode='fan_out', nonlinearity=act_name)

    @staticmethod
    def _activation_name(activation):
        """Map a name or a supported activation module to its lower-case name."""
        if isinstance(activation, str):
            return activation.lower()
        for module_type, name in ASPP._MODULE_NAMES:
            if isinstance(activation, module_type):
                return name
        raise ValueError(f"Unsupported activation: {activation}")

    def forward(self, x):
        """Run all branches on ``x``, concatenate on dim 1 and fuse to ``out_ch`` channels."""
        outs = [branch(x) for branch in self.branches]
        x = torch.cat(outs, dim=1)
        return self.merge(x)

class DoubleConv(nn.Module):
    """
    Two (3x3 conv -> BatchNorm -> activation) stages with same-size padding.

    Args:
        in_ch: input channels.
        out_ch: output channels of both convolutions.
        activation: either an activation name accepted by ``activation_parser``
            (case-insensitive) or an ``nn.ReLU`` / ``nn.Sigmoid`` / ``nn.Tanh`` /
            ``nn.LeakyReLU`` instance. A module is used as-is after both
            convolutions; a name produces two separate modules.

    Raises:
        ValueError: if ``activation`` is a module of an unsupported type.
    """

    # Module type -> name understood by kaiming_normal_'s gain lookup.
    _GAIN_NAMES = (
        (nn.ReLU, 'relu'),
        (nn.Sigmoid, 'sigmoid'),
        (nn.Tanh, 'tanh'),
        (nn.LeakyReLU, 'leaky_relu'),
    )

    def __init__(self, in_ch, out_ch, activation):
        super().__init__()
        # nn.Sequential needs modules while kaiming_normal_ needs a name, so
        # derive both regardless of which form the caller supplied.
        if isinstance(activation, nn.Module):
            gain_name = self._gain_name(activation)
            act_first = act_second = activation
        else:
            gain_name = activation.lower()
            act_first = activation_parser(gain_name)
            act_second = activation_parser(gain_name)

        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            act_first,
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            act_second
        )
        for m in self.block.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity=gain_name)

    @staticmethod
    def _gain_name(module):
        """Return the initialisation name matching a supported activation module."""
        for module_type, name in DoubleConv._GAIN_NAMES:
            if isinstance(module, module_type):
                return name
        raise ValueError(f"Unsupported activation: {module}")

    def forward(self, x):
        """Apply both conv stages; H and W are unchanged."""
        return self.block(x)

class AttentionGate(nn.Module):
    """
    Additive attention gate for skip connections.

    The gating signal and the skip tensor are each projected to ``F_int``
    channels (1x1 conv + BatchNorm), summed, passed through ReLU, reduced to a
    single-channel sigmoid map and multiplied onto the skip tensor.

    Args:
        F_g: channels of the gating signal.
        F_l: channels of the skip connection.
        F_int: channels of the intermediate projection.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = self._projection(F_g, F_int)   # gating-signal branch
        self.W_x = self._projection(F_l, F_int)   # skip-connection branch
        # psi collapses the joint features into a 1-channel attention map
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, bias=True),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

        init_plan = (
            (self.W_g, 'relu'),
            (self.W_x, 'relu'),
            (self.psi, 'sigmoid'),
        )
        for branch, gain_name in init_plan:
            nn.init.kaiming_normal_(branch[0].weight, mode='fan_out', nonlinearity=gain_name)

    @staticmethod
    def _projection(channels_in, channels_out):
        """1x1 convolution followed by BatchNorm."""
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, kernel_size=1, bias=True),
            nn.BatchNorm2d(channels_out)
        )

    def forward(self, g, x):
        """
        g: gating signal from decoder, shape (B, F_g, H, W)
        x: skip connection from encoder, shape (B, F_l, H, W)

        Returns ``x`` scaled by a (B, 1, H, W) attention map.
        """
        joint = self.relu(self.W_g(g) + self.W_x(x))   # (B, F_int, H, W)
        attn_map = self.psi(joint)                     # (B, 1, H, W)
        return x * attn_map                            # broadcast along channel


class EncoderBlock(nn.Module):
    """
    One encoder stage: ASPP -> optional CBAM -> optional Dropout2d, followed by
    an optional 2x2 max-pool.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        activation: activation spec forwarded unchanged to ASPP.
        dropout_prob: Dropout2d probability; 0 disables dropout.
        attention: when True a CBAMBlock follows the ASPP.
        pool: when True the returned feature map is max-pooled by 2.
    """
    def __init__(self, in_ch, out_ch, activation, dropout_prob=0.0, attention=True, pool=True):
        super().__init__()
        self.aspp        = ASPP(in_ch, out_ch, activation)
        self.cbam        = CBAMBlock(out_ch, ratio=8, kernel_size=7) if attention else nn.Identity()
        self.dropout     = nn.Dropout2d(dropout_prob) if dropout_prob > 0 else nn.Identity()
        self.pool        = pool

    def forward(self, x):
        """
        Returns:
            (x, skip): ``skip`` is the full-resolution feature map; ``x`` is the
            same map max-pooled by 2 when ``pool`` is True, otherwise unpooled.
        """
        x = self.aspp(x)
        x = self.cbam(x)
        x = self.dropout(x)
        if self.pool:
            # max_pool2d returns a new tensor, so the pre-pool features can be
            # handed out as the skip without copying the whole map. Consumers
            # should not modify the returned skip in place.
            return F.max_pool2d(x, kernel_size=2, stride=2), x
        # Without pooling both outputs would alias; keep them distinct tensors.
        return x, x.clone()


class DecoderBlock(nn.Module):
    """
    One decoder stage: optional 2x transposed-conv upsampling, attention-gated
    skip concatenation, then ASPP -> optional CBAM -> optional Dropout2d.
    """
    def __init__(self, in_ch, skip_ch, out_ch, activation, dropout_prob=0.0, attention=True, upsample=True):
        """
        in_ch:   channels from previous layer (bottleneck or previous decoder)
        skip_ch: channels in the corresponding encoder skip
        out_ch:  desired output channels for this decoder block
        activation: activation spec. After the upsampling BatchNorm it is used
                 as-is when it is an nn.Module and resolved with
                 activation_parser otherwise; it is forwarded unchanged to ASPP.
        dropout_prob: Dropout2d probability; 0 disables dropout.
        attention: when True the skip is gated by an AttentionGate and
                 concatenated to the features, and a CBAMBlock follows the
                 ASPP. When False the skip passed to forward() is ignored.
        upsample: when True the input is first upsampled by 2 and mapped to
                 skip_ch channels with a transposed convolution.
        """
        super().__init__()
        self.upsample = upsample
        self.skip_ch = skip_ch

        if self.upsample:
            # ConvTranspose2d(in_ch → skip_ch) to match spatial & channel dims
            self.up = nn.ConvTranspose2d(in_ch, skip_ch, kernel_size=3,
                                         stride=2, padding=1, output_padding=1, bias=True)
            nn.init.kaiming_normal_(self.up.weight, mode='fan_out', nonlinearity='relu')
            self.bn_up = nn.BatchNorm2d(skip_ch)
            self.act_up = activation if isinstance(activation, nn.Module) else activation_parser(activation)
            feat_ch = skip_ch
        else:
            self.up = None
            self.bn_up = None
            self.act_up = None
            feat_ch = in_ch

        self.attention = AttentionGate(F_g=feat_ch, F_l=feat_ch, F_int=feat_ch // 2) if attention else nn.Identity()
        # forward() concatenates the gated skip only when attention is enabled,
        # so the ASPP sees twice the feature channels in that case and the
        # plain feature channels otherwise.
        in_double = feat_ch * 2 if attention else feat_ch

        self.aspp        = ASPP(in_double, out_ch, activation)
        self.cbam        = CBAMBlock(out_ch, ratio=8, kernel_size=7) if attention else nn.Identity()
        self.dropout     = nn.Dropout2d(dropout_prob) if dropout_prob > 0 else nn.Identity()

    def forward(self, x, skip=None):
        """
        x:    features from the previous layer.
        skip: encoder skip tensor; required when the block was built with
              attention=True, ignored otherwise.

        Raises ValueError when attention is enabled but no skip is given,
        because the ASPP was sized for the concatenated tensor.
        """
        if self.upsample:
            x = self.up(x)       # (B, skip_ch, H*2, W*2)
            x = self.bn_up(x)
            x = self.act_up(x)
        if not isinstance(self.attention, nn.Identity):
            if skip is None:
                raise ValueError("DecoderBlock built with attention=True requires a skip tensor")
            skip = self.attention(g=x, x=skip)
            x = torch.cat([x, skip], dim=1)  # channel count doubles
        x = self.aspp(x)
        x = self.cbam(x)
        x = self.dropout(x)
        return x


class BottleneckTransformer(nn.Module):
    """
    Self-attention over the spatial positions of a (B, C, H, W) feature map.

    Every pixel becomes one token of width C. The tokens go through a
    ``depth``-layer pre-norm TransformerEncoder, are reshaped back to
    (B, C, H, W), and a final LayerNorm over the channel axis is applied.

    Args:
        dim: channel count C (the transformer's ``d_model``).
        heads: number of attention heads.
        depth: number of encoder layers.
        mlp_dim: feed-forward width; falls back to ``4 * dim`` when falsy.
    """
    def __init__(self, dim, heads=8, depth=3, mlp_dim=None):
        super().__init__()
        if not mlp_dim:
            mlp_dim = dim * 4
        # The prototype layer is replicated `depth` times by TransformerEncoder.
        prototype = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            activation='relu',
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(prototype, num_layers=depth)
        self.norm    = nn.LayerNorm(dim)

    def forward(self, x):
        """Return a tensor with the same (B, C, H, W) shape as ``x``."""
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, C, H*W) -> (H*W, B, C): sequence-first token layout
        seq = x.flatten(2).permute(2, 0, 1)
        seq = self.encoder(seq)                                   # (H*W, B, C)
        # (H*W, B, C) -> (B, C, H*W) -> (B, C, H, W)
        grid = seq.permute(1, 2, 0).view(batch, channels, height, width)
        # LayerNorm acts on the last axis, so move C there, normalise, move it back.
        channels_last = self.norm(grid.permute(0, 2, 3, 1))       # (B, H, W, C)
        return channels_last.permute(0, 3, 1, 2)                  # (B, C, H, W)


class UNet(nn.Module):
    """
    Encoder/decoder network built from EncoderBlock and DecoderBlock stages
    with an ASPP + BottleneckTransformer bottleneck and an ASPP output head.

    Every encoder halves H and W (2x2 max-pool) and every decoder doubles them,
    so with fewer decoders than encoders the output is smaller than the input
    by a factor of ``2 ** (len(down_filters) - len(up_filters))``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        down_filters: output channels of each encoder stage (required).
        down_activations: activation name per encoder stage, 'relu' or
            'sigmoid' (case-insensitive); same length as ``down_filters``.
        up_filters: output channels of each decoder stage (required); at most
            as many entries as ``down_filters``.
        up_activations: activation name per decoder stage, 'relu' or
            'sigmoid'; same length as ``up_filters``.

    Raises:
        TypeError: if one of the four list arguments is omitted.
        AssertionError: if a filter list and its activation list differ in length.
        ValueError: for an unsupported activation name, or when there are more
            decoder stages than encoder stages.
    """
    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 down_filters=None,
                 down_activations=None,
                 up_filters=None,
                 up_activations=None):
        super().__init__()
        required = (
            ('down_filters', down_filters),
            ('down_activations', down_activations),
            ('up_filters', up_filters),
            ('up_activations', up_activations),
        )
        missing = [arg_name for arg_name, value in required if value is None]
        if missing:
            raise TypeError(f"UNet requires the argument(s): {', '.join(missing)}")
        assert len(down_filters) == len(down_activations)
        assert len(up_filters)   == len(up_activations)
        if len(up_filters) > len(down_filters):
            # Each decoder consumes one encoder skip; without this check the
            # negative index below would silently wrap around.
            raise ValueError(
                f"up_filters has {len(up_filters)} stages but only "
                f"{len(down_filters)} encoder skips are available"
            )

        # Build Encoder path (no CBAM on the first stage)
        self.encoders = nn.ModuleList()
        prev_ch = in_channels
        for i, out_ch in enumerate(down_filters):
            self.encoders.append(
                EncoderBlock(in_ch=prev_ch,
                             out_ch=out_ch,
                             activation=self._make_activation(down_activations[i], 'encoder'),
                             dropout_prob=0.1,
                             attention=(i != 0),
                             pool=True)
            )
            prev_ch = out_ch

        # Bottleneck: ASPP(down_filters[-1] → down_filters[-1]*2) + transformer
        self.bottleneck_aspp   = ASPP(down_filters[-1], down_filters[-1]*2, nn.ReLU(inplace=True))
        self.bottleneck_trans  = BottleneckTransformer(dim=down_filters[-1]*2,
                                                 heads=4,
                                                 depth=4,
                                                 mlp_dim=down_filters[-1] * 4
                                                 )

        # Build Decoder path
        self.decoders = nn.ModuleList()
        N = len(down_filters)
        for i, out_ch in enumerate(up_filters):
            act_fn = self._make_activation(up_activations[i], 'decoder')
            # Corresponding skip channels from encoder (deepest first)
            skip_ch = down_filters[N - 1 - i]
            # Input channels for this decoder block
            in_ch_dec = (down_filters[-1] * 2) if (i == 0) else up_filters[i - 1]

            self.decoders.append(
                DecoderBlock(in_ch=in_ch_dec,
                             skip_ch=skip_ch,
                             out_ch=out_ch,
                             activation=act_fn,
                             dropout_prob=0.1,
                             attention=True,
                             upsample=True)
            )

        # Output head: ASPP with a sigmoid activation → out_channels
        self.final_conv  = ASPP(up_filters[-1], out_channels, nn.Sigmoid())

    @staticmethod
    def _make_activation(name, role):
        """Build the activation module for one stage; ``role`` only labels the error."""
        act_str = name.lower()
        if act_str == 'relu':
            return nn.ReLU(inplace=True)
        if act_str == 'sigmoid':
            return nn.Sigmoid()
        raise ValueError(f"Unsupported {role} activation: {act_str}")

    def forward(self, x):
        """
        x: (B, in_channels, H, W).

        Returns the output of the ASPP head after the last decoder.
        """
        skips = []
        for enc in self.encoders:
            x, skip = enc(x)
            skips.append(skip)

        x = self.bottleneck_aspp(x)
        x = self.bottleneck_trans(x)

        # Decoders consume the skips deepest-first; extra shallow skips stay unused.
        for dec, skip_feat in zip(self.decoders, reversed(skips)):
            x = dec(x, skip_feat)

        x = self.final_conv(x)
        return x

In [9]:
# Print a layer-by-layer summary of the UNet with torchinfo
# (install it first with `pip install torchinfo` if it is missing).
import torch
from torchinfo import summary

# Encoder: five stages, all ReLU.
down_filters     = [32, 64, 128, 256, 512]
down_activations = ['relu' for _ in down_filters]

# Decoder: three stages, all ReLU (down_filters[::-1] would mirror the encoder).
up_filters       = [512, 256, 128]
up_activations   = ['relu' for _ in up_filters]

# Build the model with the same configuration as the training code.
unet_config = dict(
    down_filters=down_filters,
    down_activations=down_activations,
    up_filters=up_filters,
    up_activations=up_activations,
)
model = UNet(**unet_config)

# Summarise on a dummy batch of 128 single-channel 128×128 inputs.
_ = summary(
    model,
    input_size=(128, 1, 128, 128),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    verbose=1,
)


Layer (type:depth-idx)                                  Input Shape               Output Shape              Param #                   Trainable
UNet                                                    [128, 1, 128, 128]        [128, 1, 32, 32]          --                        True
├─ModuleList: 1-1                                       --                        --                        --                        True
│    └─EncoderBlock: 2-1                                [128, 1, 128, 128]        [128, 32, 64, 64]         --                        True
│    │    └─ASPP: 3-1                                   [128, 1, 128, 128]        [128, 32, 128, 128]       4,792                     True
│    │    └─Identity: 3-2                               [128, 32, 128, 128]       [128, 32, 128, 128]       --                        --
│    │    └─Dropout2d: 3-3                              [128, 32, 128, 128]       [128, 32, 128, 128]       --                        --
│    └─EncoderBlock: 2-2  