In [None]:
import torch
import torch.nn as nn

class ConvBNReLU(nn.Module):
    """3x3 convolution followed by batch normalisation and an in-place ReLU.

    The convolution uses ``kernel_size=3``, ``stride=1`` and ``padding=1``, so
    the spatial size of the input is preserved and only the channel count
    changes.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the convolution (and
            normalised by the batch-norm layer).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registration order (conv, bn, relu) is kept so that ``print(model)``
        # and seeded parameter initialisation match the layout used elsewhere.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        return self.relu(self.bn(self.conv(x)))

class ResPath(nn.Module):
    """Sum of two parallel ConvBNReLU branches applied to the same input.

    Both ``shortcut`` and ``conv1`` are independent ``ConvBNReLU`` blocks with
    identical channel configuration; their outputs are added element-wise.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by each branch.
    """

    def __init__(self, in_channels, out_channels):
        super(ResPath, self).__init__()
        self.shortcut = ConvBNReLU(in_channels, out_channels)
        self.conv1 = ConvBNReLU(in_channels, out_channels)

    def forward(self, x):
        """Return ``conv1(x) + shortcut(x)`` as a new tensor."""
        shortcut = self.shortcut(x)
        out = self.conv1(x)
        # Must be an out-of-place add: ``out`` is the tensor that the branch's
        # ReLU saved for its backward pass, so ``out += shortcut`` would bump
        # its version counter and make ``.backward()`` raise a RuntimeError
        # ("modified by an inplace operation").
        return out + shortcut

class DenseresNet(nn.Module):
    """Sequential stack of thirteen 3x3 convolution stages, 3 channels in and out.

    Channel widths run
    3 -> 32 -> 64 -> 128 -> 256 -> 512 -> 512 -> 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 3.
    ``conv0``..``conv11`` are ``ConvBNReLU`` blocks; ``conv12`` is a bare
    ``nn.Conv2d`` with no normalisation or activation after it. Every stage
    uses stride 1 with padding 1, so the spatial size is unchanged end to end.

    NOTE(review): ``respath0``..``respath4`` are constructed here, so their
    parameters are part of ``parameters()`` and ``state_dict()``, but
    ``forward`` never calls them. Confirm whether they were meant to be wired
    into the forward pass or should be removed.
    """

    def __init__(self):
        super().__init__()

        # Consecutive pairs of this sequence are the (in, out) channels of
        # conv0..conv11; modules are registered in the same order as before.
        widths = (3, 32, 64, 128, 256, 512, 512, 256, 128, 64, 32, 16, 8)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            setattr(self, f"conv{stage}", ConvBNReLU(c_in, c_out))

        # Final projection back to 3 channels: plain convolution only.
        self.conv12 = nn.Conv2d(8, 3, kernel_size=3, stride=1, padding=1)

        # Channel-preserving ResPath modules (currently unused by forward).
        for level, channels in enumerate((32, 64, 128, 256, 512)):
            setattr(self, f"respath{level}", ResPath(channels, channels))

    def forward(self, x):
        """Run ``x`` through conv0..conv12 in order and return the result."""
        x = self.conv0(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        x = self.conv11(x)
        return self.conv12(x)

# Per-sample input shape as (channels, height, width): a 3-channel 256x256 image is assumed.
input_size = (3, 256, 256)

# Instantiate the DenseresNet network.
model = DenseresNet()

# Show the module hierarchy (the nn.Module repr; no per-layer output shapes).
print(model)


DenseresNet(
  (conv0): ConvBNReLU(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (conv1): ConvBNReLU(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (conv2): ConvBNReLU(
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (conv3): ConvBNReLU(
    (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (conv4): ConvBNReLU(
    (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=

In [None]:

# Tally every module that is a Conv2d, a ConvTranspose2d or a ResPath.
# model.modules() walks the tree recursively, so each ResPath container is
# counted in addition to the Conv2d layers nested inside it.
total_layers = sum(
    isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, ResPath))
    for module in model.modules()
)
print(f"Total layers in the model: {total_layers}")

Total layers in the model: 28


In [None]:
import torch
from torchsummary import summary

model = DenseresNet()

# Per-sample input shape as (channels, height, width); a 256x256 image is assumed.
input_size = (3, 256, 256)

# Print the per-layer summary (output shapes and parameter counts).
# torchsummary defaults to device="cuda" and, when CUDA is available, feeds a
# CUDA tensor into the model. The model above is never moved off the CPU, so
# pin the device explicitly to avoid a device mismatch error on GPU machines.
summary(model, input_size, device="cpu")


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 256, 256]             896
       BatchNorm2d-2         [-1, 32, 256, 256]              64
              ReLU-3         [-1, 32, 256, 256]               0
        ConvBNReLU-4         [-1, 32, 256, 256]               0
            Conv2d-5         [-1, 64, 256, 256]          18,496
       BatchNorm2d-6         [-1, 64, 256, 256]             128
              ReLU-7         [-1, 64, 256, 256]               0
        ConvBNReLU-8         [-1, 64, 256, 256]               0
            Conv2d-9        [-1, 128, 256, 256]          73,856
      BatchNorm2d-10        [-1, 128, 256, 256]             256
             ReLU-11        [-1, 128, 256, 256]               0
       ConvBNReLU-12        [-1, 128, 256, 256]               0
           Conv2d-13        [-1, 256, 256, 256]         295,168
      BatchNorm2d-14        [-1, 256, 2