In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class ConvBlock(nn.Module):
    """
    Standard Conv -> BN -> ReLU block used throughout the network.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class DoubleConv(nn.Module):
    """
    The standard U-Net encoder block: Two Conv3x3 blocks.
    """
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.double_conv(x)

class RFB_Skip(nn.Module):
    """
    Multi-level RFB-based skip pathways (Section 2.2).
    It consists of two parallel paths:
    1. A stack of standard 3x3 convolutions (approximating a larger kernel).
    2. A single dilated convolution.
    The outputs are concatenated.
    """
    def __init__(self, in_channels, num_stack, dilation_rate):
        super(RFB_Skip, self).__init__()
        
        # Path 1: Stack of standard convolutions (Green blocks in diagram)
        # To emulate large kernels like 7x7, we stack 3x3 convs.
        stack_layers = []
        for _ in range(num_stack):
            # 1x1 convs are used in the bottom level as per text, 3x3 elsewhere
            k = 3 if num_stack > 1 else 1 
            p = 1 if num_stack > 1 else 0
            stack_layers.append(ConvBlock(in_channels, in_channels, kernel_size=k, padding=p))
        self.stack_path = nn.Sequential(*stack_layers)

        # Path 2: Dilated Convolution (Yellow blocks in diagram)
        # Note: Section 2.2 mentions matching receptive fields.
        # For the bottom level (dilation 1), it acts as a standard conv.
        self.dilated_path = ConvBlock(in_channels, in_channels, 
                                      kernel_size=3, 
                                      padding=dilation_rate, # Padding must equal dilation to keep size
                                      dilation=dilation_rate)

    def forward(self, x):
        out_stack = self.stack_path(x)
        out_dilated = self.dilated_path(x)
        # Concatenate features from both paths
        return torch.cat([out_stack, out_dilated], dim=1)

class CSE_UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=1):
        super(CSE_UNet, self).__init__()
        
        filters = [64, 128, 256, 512, 1024]

        # --- ENCODER (Dual-Path) ---
        
        # Level 1
        self.enc1_main = DoubleConv(in_channels, filters[0])
        self.pool1 = nn.MaxPool2d(2)
        # Aux Path 1: Conv7x7 stride 2 (Section 2.3)
        self.enc1_aux = nn.Sequential(
            nn.Conv2d(in_channels, filters[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(filters[0])
        )

        # Level 2
        self.enc2_main = DoubleConv(filters[0], filters[1])
        self.pool2 = nn.MaxPool2d(2)
        # Aux Path 2: Conv5x5 stride 2
        self.enc2_aux = nn.Sequential(
            nn.Conv2d(filters[0], filters[1], kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(filters[1])
        )

        # Level 3
        self.enc3_main = DoubleConv(filters[1], filters[2])
        self.pool3 = nn.MaxPool2d(2)
        # Aux Path 3: Conv3x3 stride 2
        self.enc3_aux = nn.Sequential(
            nn.Conv2d(filters[1], filters[2], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(filters[2])
        )

        # Level 4
        self.enc4_main = DoubleConv(filters[2], filters[3])
        self.pool4 = nn.MaxPool2d(2)
        # Aux Path 4: Conv2x2 stride 2
        # Note: 2x2 kernel with stride 2 and padding 0 halves the dimension perfectly
        self.enc4_aux = nn.Sequential(
            nn.Conv2d(filters[2], filters[3], kernel_size=2, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(filters[3])
        )

        # Bridge
        self.bridge = DoubleConv(filters[3], filters[4])

        # --- RFB SKIP PATHWAYS ---
        
        # According to Section 2.2 and Diagram
        # Top (Level 1): Stack of 3, Dilation 7
        self.rfb1 = RFB_Skip(filters[0], num_stack=3, dilation_rate=7)
        # Level 2: Stack of 2, Dilation 5
        self.rfb2 = RFB_Skip(filters[1], num_stack=2, dilation_rate=5)
        # Level 3: Stack of 1 (3x3), Dilation 3
        self.rfb3 = RFB_Skip(filters[2], num_stack=1, dilation_rate=3)
        # Level 4: Stack of 1 (1x1), Dilation 1
        # Text says: "one convolution layer with 1x1... and one dilated... dilation rate of 1"
        # We pass num_stack=1, but internal logic handles the 1x1 kernel switch for the bottom layer
        self.rfb4 = RFB_Skip(filters[3], num_stack=1, dilation_rate=1) 
        
        # Note on RFB4: The stack logic in `RFB_Skip` uses 3x3 by default. 
        # We need to manually override or create a specific block if strict adherence to "1x1" 
        # for the stack path is required. I added logic in RFB_Skip to handle this.

        # --- DECODER ---
        
        # Since RFB concatenates output, skip channels are doubled
        
        # Decoder 4
        self.up4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_conv4 = nn.Conv2d(filters[4], filters[3], kernel_size=1) # Conv1x1 after upsample
        self.dec4 = DoubleConv(filters[3] + (filters[3] * 2), filters[3]) # Input = Prev + Skip(RFB x2)

        # Decoder 3
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_conv3 = nn.Conv2d(filters[3], filters[2], kernel_size=1)
        self.dec3 = DoubleConv(filters[2] + (filters[2] * 2), filters[2])

        # Decoder 2
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_conv2 = nn.Conv2d(filters[2], filters[1], kernel_size=1)
        self.dec2 = DoubleConv(filters[1] + (filters[1] * 2), filters[1])

        # Decoder 1
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_conv1 = nn.Conv2d(filters[1], filters[0], kernel_size=1)
        self.dec1 = DoubleConv(filters[0] + (filters[0] * 2), filters[0])

        # Final Output
        self.final_conv = nn.Conv2d(filters[0], num_classes, kernel_size=1)

    def forward(self, x):
        # --- ENCODER + DUAL PATH FUSION ---
        
        # Stage 1
        x1_main = self.enc1_main(x)
        x1_pool = self.pool1(x1_main)
        x1_aux = self.enc1_aux(x)
        # Fusion: Main (Pooled) + Aux
        x1_fused = x1_pool + x1_aux 
        # Apply ReLU after addition (Standard ResNet practice, though not explicitly drawn, implied by block logic)
        x1_fused = F.relu(x1_fused)

        # Stage 2
        x2_main = self.enc2_main(x1_fused)
        x2_pool = self.pool2(x2_main)
        x2_aux = self.enc2_aux(x1_fused)
        x2_fused = x2_pool + x2_aux
        x2_fused = F.relu(x2_fused)

        # Stage 3
        x3_main = self.enc3_main(x2_fused)
        x3_pool = self.pool3(x3_main)
        x3_aux = self.enc3_aux(x2_fused)
        x3_fused = x3_pool + x3_aux
        x3_fused = F.relu(x3_fused)

        # Stage 4
        x4_main = self.enc4_main(x3_fused)
        x4_pool = self.pool4(x4_main)
        x4_aux = self.enc4_aux(x3_fused)
        x4_fused = x4_pool + x4_aux
        x4_fused = F.relu(x4_fused)

        # Bridge
        x_bridge = self.bridge(x4_fused)

        # --- RFB SKIP GENERATION ---
        # The skips come from the fused encoder features before they went to the next level
        # Note: The diagram arrows for skips originate from the output of the "Main" double conv 
        # BEFORE pooling/fusion? Or from the fused result?
        # Looking at Fig 2: The arrows go Inputs -> Conv -> [Conv] -> right arrow to RFB.
        # This implies the skip comes from the STANDARD encoder path before pooling.
        
        s1 = self.rfb1(x1_main)
        s2 = self.rfb2(x2_main)
        s3 = self.rfb3(x3_main)
        s4 = self.rfb4(x4_main)

        # --- DECODER ---
        
        # Block 4
        d4 = self.up4(x_bridge)
        d4 = self.up_conv4(d4)
        # Concatenate with RFB Skip 4
        d4 = torch.cat([d4, s4], dim=1)
        d4 = self.dec4(d4)

        # Block 3
        d3 = self.up3(d4)
        d3 = self.up_conv3(d3)
        d3 = torch.cat([d3, s3], dim=1)
        d3 = self.dec3(d3)

        # Block 2
        d2 = self.up2(d3)
        d2 = self.up_conv2(d2)
        d2 = torch.cat([d2, s2], dim=1)
        d2 = self.dec2(d2)

        # Block 1
        d1 = self.up1(d2)
        d1 = self.up_conv1(d1)
        d1 = torch.cat([d1, s1], dim=1)
        d1 = self.dec1(d1)

        return self.final_conv(d1)

In [3]:
model = CSE_UNet(in_channels=3, num_classes=1)
dummy_input = torch.randn(1, 3, 256, 256)
output = model(dummy_input)
print(f"Input Shape: {dummy_input.shape}")
print(f"Output Shape: {output.shape}")

# Calculate params
total_params = sum(p.numel() for p in model.parameters())
print(f"Total Parameters: {total_params:,}")

Input Shape: torch.Size([1, 3, 256, 256])
Output Shape: torch.Size([1, 1, 256, 256])
Total Parameters: 36,988,417
