In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 convolution -> batch norm -> ReLU) stages.

    Both convolutions use ``padding=1`` with the default stride, so the
    spatial size of the input is preserved; only the channel count changes.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the second stage.
        mid_channels: Number of channels produced by the first stage. Any
            falsy value (``None`` or ``0``) means "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # Falsy mid_channels falls back to the output width.
        hidden = mid_channels or out_channels

        # Build the two stages in order so the Sequential indices (0..5) and
        # therefore the state_dict keys stay the same.
        stages = []
        for stage_in, stage_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend([
                nn.Conv2d(stage_in, stage_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(stage_out),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages and return the result."""
        features = self.double_conv(x)
        return features

In [None]:
class ChannelAttentionModule(nn.Module):
    """Channel attention: one sigmoid gate per channel.

    The input is reduced to one value per channel twice (global average
    pooling and global max pooling). Both descriptors go through the same
    two-layer 1x1-convolution bottleneck, the results are summed, and a
    sigmoid maps the sum to a weight in (0, 1).

    Args:
        in_channels: Number of channels of the input feature map.
        reduction_ratio: Factor by which the bottleneck narrows the channel
            dimension. The bottleneck width is
            ``max(in_channels // reduction_ratio, 1)``.

    Returns (from ``forward``):
        Tensor of shape ``(N, in_channels, 1, 1)`` for an ``(N, C, H, W)``
        input, intended to be broadcast-multiplied with that input.
    """
    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Integer division yields 0 when in_channels < reduction_ratio, which
        # would create a zero-width bottleneck. Keep at least one channel.
        hidden_channels = max(in_channels // reduction_ratio, 1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel attention weights for ``x``."""
        # The same bottleneck (shared weights) scores both pooled descriptors.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttentionModule(nn.Module):
    """Spatial attention: one sigmoid gate per spatial location.

    The channel dimension is collapsed twice (mean and max over ``dim=1``).
    The two resulting single-channel maps are stacked and mixed by a 3x3
    convolution (``padding=1``, so height and width are preserved), then a
    sigmoid is applied.

    Returns (from ``forward``):
        Tensor with a single channel and the same height/width as the input.
    """
    def __init__(self):
        super().__init__()
        self.conv3x3 = nn.Conv2d(2, 1, kernel_size=3, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-location attention weights for ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        # Two-channel descriptor: [mean over channels, max over channels].
        descriptor = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv3x3(descriptor))

class DualAttentionModule(nn.Module):
    """Channel attention followed by spatial attention.

    The input is first rescaled by the output of ``ChannelAttentionModule``;
    the rescaled tensor is then rescaled again by the output of
    ``SpatialAttentionModule`` computed on it. Both products rely on
    broadcasting, so the result has the same shape as the input.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction_ratio: Forwarded to ``ChannelAttentionModule``.
    """
    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        self.channel_attention = ChannelAttentionModule(
            in_channels, reduction_ratio
        )
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        """Return ``x`` gated by channel weights, then by spatial weights."""
        channel_refined = x * self.channel_attention(x)
        # The spatial gate is computed on the channel-refined tensor, not on
        # the raw input.
        spatial_gate = self.spatial_attention(channel_refined)
        return channel_refined * spatial_gate


In [None]:
class UNetWithDualAttention(nn.Module):
    """Three-level U-Net with dual attention at the bottleneck and before the head.

    Encoder: ``inc`` followed by three (2x2 max-pool -> ``DoubleConv``) steps
    producing 64, 128, 256 and 512 channels. ``dam1`` gates the 512-channel
    bottleneck. Decoder: three (2x upsample -> concatenate with the matching
    encoder feature map -> ``DoubleConv``) steps back down to 64 channels.
    ``dam2`` gates that map and a 1x1 convolution produces the logits.

    Args:
        n_channels: Number of channels of the input image.
        n_classes: Number of output channels of the final 1x1 convolution.
        bilinear: If true, upsample with bilinear interpolation
            (``align_corners=True``). Otherwise upsample with learned
            2x2, stride-2 transposed convolutions that keep the channel count.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNetWithDualAttention, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = DoubleConv(64, 128)
        self.down2 = DoubleConv(128, 256)
        self.down3 = DoubleConv(256, 512)
        self.dam1 = DualAttentionModule(512)  # Dual attention module
        self.up1 = DoubleConv(256 + 512, 256)
        self.up2 = DoubleConv(128 + 256, 128)
        self.up3 = DoubleConv(64 + 128, 64)
        self.dam2 = DualAttentionModule(64)  # Dual attention module
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

        # Learned upsampling layers are only created for bilinear=False, so
        # the parameters / state_dict of the default configuration are
        # unchanged. The channel count is preserved because up1/up2/up3
        # expect (skip + upsampled) channels.
        if bilinear:
            self.upconv1 = self.upconv2 = self.upconv3 = None
        else:
            self.upconv1 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
            self.upconv2 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
            self.upconv3 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)

    def _upsample_and_merge(self, x, skip, upconv):
        """Upsample ``x`` by 2, align it with ``skip`` and concatenate on channels."""
        if self.bilinear:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        else:
            x = upconv(x)
        # max_pool2d floors odd sizes, so the upsampled map can be one pixel
        # smaller than the skip connection. Zero-pad to match; this is a
        # no-op when the sizes already agree.
        pad_h = skip.size(2) - x.size(2)
        pad_w = skip.size(3) - x.size(3)
        if pad_h or pad_w:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2,
                          pad_h // 2, pad_h - pad_h // 2])
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Return the raw logits (no activation applied) for input ``x``."""
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(F.max_pool2d(x1, 2))
        x3 = self.down2(F.max_pool2d(x2, 2))
        x4 = self.down3(F.max_pool2d(x3, 2))

        x4 = self.dam1(x4)  # Apply dual attention module

        # Decoder
        x = self.up1(self._upsample_and_merge(x4, x3, self.upconv1))
        x = self.up2(self._upsample_and_merge(x, x2, self.upconv2))
        x = self.up3(self._upsample_and_merge(x, x1, self.upconv3))

        x = self.dam2(x)  # Apply dual attention module

        logits = self.outc(x)
        return logits

# Instantiate the network: 3 input channels, 1 output channel in the logits.
model = UNetWithDualAttention(
    n_channels=3,
    n_classes=1,
)

In [None]:
# Print the nested module structure of the model (each layer with its hyperparameters).
print(model)

UNetWithDualAttention(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down2): DoubleConv(
    (double_conv): Sequential(
 

In [None]:
class UNetWithDualAttention(nn.Module):
    """Three-level U-Net with dual attention that also reports feature-map sizes.

    NOTE: this redefines (and shadows) the earlier class of the same name in
    this notebook. The architecture is the same; the difference is that
    ``forward`` returns ``(logits, sizes)`` instead of ``logits`` alone.

    Encoder: ``inc`` followed by three (2x2 max-pool -> ``DoubleConv``) steps
    producing 64, 128, 256 and 512 channels. ``dam1`` gates the 512-channel
    bottleneck. Decoder: three (2x upsample -> concatenate with the matching
    encoder feature map -> ``DoubleConv``) steps back down to 64 channels.
    ``dam2`` gates that map and a 1x1 convolution produces the logits.

    Args:
        n_channels: Number of channels of the input image.
        n_classes: Number of output channels of the final 1x1 convolution.
        bilinear: If true, upsample with bilinear interpolation
            (``align_corners=True``). Otherwise upsample with learned
            2x2, stride-2 transposed convolutions that keep the channel count.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNetWithDualAttention, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = DoubleConv(64, 128)
        self.down2 = DoubleConv(128, 256)
        self.down3 = DoubleConv(256, 512)
        self.dam1 = DualAttentionModule(512)  # Dual attention module
        self.up1 = DoubleConv(256 + 512, 256)
        self.up2 = DoubleConv(128 + 256, 128)
        self.up3 = DoubleConv(64 + 128, 64)
        self.dam2 = DualAttentionModule(64)  # Dual attention module
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

        # Learned upsampling layers are only created for bilinear=False, so
        # the parameters / state_dict of the default configuration are
        # unchanged. The channel count is preserved because up1/up2/up3
        # expect (skip + upsampled) channels.
        if bilinear:
            self.upconv1 = self.upconv2 = self.upconv3 = None
        else:
            self.upconv1 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
            self.upconv2 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
            self.upconv3 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)

    def _upsample_and_merge(self, x, skip, upconv):
        """Upsample ``x`` by 2, align it with ``skip`` and concatenate on channels."""
        if self.bilinear:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        else:
            x = upconv(x)
        # max_pool2d floors odd sizes, so the upsampled map can be one pixel
        # smaller than the skip connection. Zero-pad to match; this is a
        # no-op when the sizes already agree.
        pad_h = skip.size(2) - x.size(2)
        pad_w = skip.size(3) - x.size(3)
        if pad_h or pad_w:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2,
                          pad_h // 2, pad_h - pad_h // 2])
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Return ``(logits, sizes)`` for input ``x``.

        ``sizes`` is a list of ten ``torch.Size`` objects recorded, in order,
        after: inc, down1, down2, down3, dam1, up1, up2, up3, dam2, outc.
        """
        sizes = []

        # Encoder
        x1 = self.inc(x)
        sizes.append(x1.size())
        x2 = self.down1(F.max_pool2d(x1, 2))
        sizes.append(x2.size())
        x3 = self.down2(F.max_pool2d(x2, 2))
        sizes.append(x3.size())
        x4 = self.down3(F.max_pool2d(x3, 2))
        sizes.append(x4.size())

        x4 = self.dam1(x4)  # Apply dual attention module
        sizes.append(x4.size())

        # Decoder
        x = self.up1(self._upsample_and_merge(x4, x3, self.upconv1))
        sizes.append(x.size())
        x = self.up2(self._upsample_and_merge(x, x2, self.upconv2))
        sizes.append(x.size())
        x = self.up3(self._upsample_and_merge(x, x1, self.upconv3))
        sizes.append(x.size())

        x = self.dam2(x)  # Apply dual attention module
        sizes.append(x.size())

        logits = self.outc(x)
        sizes.append(logits.size())
        return logits, sizes

# Create the model
model = UNetWithDualAttention(n_channels=3, n_classes=1)

# Create a dummy input tensor of size (1, 3, 256, 256) (batch_size, channels, height, width)
dummy_input = torch.randn(1, 3, 256, 256)

# Forward pass through the model to get the size of feature maps after each layer.
# This is only a shape probe, so run it in eval mode (BatchNorm running
# statistics are not overwritten with random-noise statistics) and without
# autograd (no graph is built), then restore the previous train/eval mode.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        sizes = model(dummy_input)[1]
finally:
    model.train(was_training)

# Output the sizes
for i, size in enumerate(sizes, start=1):
    print(f"Layer {i} output size: {size}")


Layer 1 output size: torch.Size([1, 64, 256, 256])
Layer 2 output size: torch.Size([1, 128, 128, 128])
Layer 3 output size: torch.Size([1, 256, 64, 64])
Layer 4 output size: torch.Size([1, 512, 32, 32])
Layer 5 output size: torch.Size([1, 512, 32, 32])
Layer 6 output size: torch.Size([1, 256, 64, 64])
Layer 7 output size: torch.Size([1, 128, 128, 128])
Layer 8 output size: torch.Size([1, 64, 256, 256])
Layer 9 output size: torch.Size([1, 64, 256, 256])
Layer 10 output size: torch.Size([1, 1, 256, 256])
