In [1]:
import torch
from torch import nn
from torchsummary import summary

In [2]:
class ConvBlock(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    With a 3x3 kernel, stride 1 and ``padding=1`` the spatial size is preserved,
    so an (N, in_filters, H, W) input becomes an (N, out_filters, H, W) output.

    Args:
        in_filters: channels expected by the first convolution.
        out_filters: channels produced by both convolutions.
    """

    def __init__(self, in_filters, out_filters):
        super().__init__()
        same_3x3 = dict(kernel_size=3, padding=1)  # size-preserving 3x3 conv

        # Registration order (conv1, conv2, bn1, bn2, relu) is deliberate: it fixes
        # the state_dict key order and the order in which conv weights are drawn.
        self.conv1 = nn.Conv2d(in_filters, out_filters, **same_3x3)
        self.conv2 = nn.Conv2d(out_filters, out_filters, **same_3x3)
        self.bn1, self.bn2 = nn.BatchNorm2d(out_filters), nn.BatchNorm2d(out_filters)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Same call sequence as conv1, bn1, relu, conv2, bn2, relu.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))
        return x

In [3]:
# Smoke-check ConvBlock: the padded 3x3 convs keep H and W; only channels change.
test1 = ConvBlock(16, 32)
# torchsummary's `summary` defaults to device="cuda". The block lives on the CPU, so
# pass device="cpu" explicitly; otherwise a CUDA-capable machine would feed CUDA
# inputs to CPU weights and the cell would fail.
summary(test1, (16, 512, 512), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 512, 512]           4,640
       BatchNorm2d-2         [-1, 32, 512, 512]              64
              ReLU-3         [-1, 32, 512, 512]               0
            Conv2d-4         [-1, 32, 512, 512]           9,248
       BatchNorm2d-5         [-1, 32, 512, 512]              64
              ReLU-6         [-1, 32, 512, 512]               0
Total params: 14,016
Trainable params: 14,016
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 16.00
Forward/backward pass size (MB): 384.00
Params size (MB): 0.05
Estimated Total Size (MB): 400.05
----------------------------------------------------------------


In [4]:
class EncoderBlock(nn.Module):
    """One contracting-path stage: ConvBlock features plus a 2x max-pooled copy.

    Args:
        in_filters: channels of the incoming tensor.
        out_filters: channels produced by the ConvBlock.

    ``forward`` returns ``(features, pooled)``. ``features`` is the ConvBlock output,
    which DobyUNet keeps as a skip connection. ``pooled`` is that output after
    ``MaxPool2d(2)``, which is passed on to the next stage.
    """

    def __init__(self, in_filters, out_filters):
        super().__init__()
        self.convBlk = ConvBlock(in_filters, out_filters)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = self.convBlk(x)
        return features, self.pool(features)

In [5]:
# Smoke-check EncoderBlock: expect the ConvBlock output at full size plus a 2x-pooled copy.
test2 = EncoderBlock(64, 32)
# torchsummary defaults to device="cuda"; the block is on the CPU, so say so explicitly.
summary(test2, (64, 512, 512), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 512, 512]          18,464
       BatchNorm2d-2         [-1, 32, 512, 512]              64
              ReLU-3         [-1, 32, 512, 512]               0
            Conv2d-4         [-1, 32, 512, 512]           9,248
       BatchNorm2d-5         [-1, 32, 512, 512]              64
              ReLU-6         [-1, 32, 512, 512]               0
         ConvBlock-7         [-1, 32, 512, 512]               0
         MaxPool2d-8         [-1, 32, 256, 256]               0
Total params: 27,840
Trainable params: 27,840
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 64.00
Forward/backward pass size (MB): 464.00
Params size (MB): 0.11
Estimated Total Size (MB): 528.11
----------------------------------------------------------------


In [6]:
class DecoderBlock(nn.Module):
    """One expanding-path stage: 2x upsample, concatenate the skip tensor, then ConvBlock.

    Args:
        in_filters: channels of the incoming (deeper) feature map ``x``.
        out_filters: channels produced by the transposed conv and by this block.

    The ConvBlock is built with ``in_filters`` input channels, and it receives the
    upsampled ``x`` (``out_filters`` channels) concatenated with ``skip``. So ``skip``
    must carry ``in_filters - out_filters`` channels (e.g. 1024 -> 512 with a
    512-channel skip).
    """

    def __init__(self, in_filters, out_filters):
        super().__init__()
        self.transposeConv = nn.ConvTranspose2d(in_filters, out_filters, kernel_size=2, stride=2)
        self.convBlk = ConvBlock(in_filters, out_filters)

    def forward(self, x, skip):
        x = self.transposeConv(x)

        # The kernel-2/stride-2 transposed conv exactly doubles H and W. The encoder's
        # MaxPool2d(2) floors odd sizes, however, so for inputs not divisible by 16
        # the upsampled map can be smaller than `skip` and torch.cat would fail.
        # Zero-pad it (split as evenly as possible) to the skip's spatial size.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = nn.functional.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x, skip], dim=1)
        return self.convBlk(x)

In [7]:
# Smoke-check DecoderBlock with a (64, 16, 16) deep map and a (32, 32, 32) skip tensor.
test3 = DecoderBlock(64, 32)
# device="cpu": torchsummary defaults to "cuda", but the block is on the CPU.
# NOTE: for multi-input models torchsummary's "Input size (MB)" line is not meaningful
# (it reports 2048 MB here for well under 1 MB of actual input).
summary(test3, [(64, 16, 16), (32, 32, 32)], device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1           [-1, 32, 32, 32]           8,224
            Conv2d-2           [-1, 32, 32, 32]          18,464
       BatchNorm2d-3           [-1, 32, 32, 32]              64
              ReLU-4           [-1, 32, 32, 32]               0
            Conv2d-5           [-1, 32, 32, 32]           9,248
       BatchNorm2d-6           [-1, 32, 32, 32]              64
              ReLU-7           [-1, 32, 32, 32]               0
         ConvBlock-8           [-1, 32, 32, 32]               0
Total params: 36,064
Trainable params: 36,064
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 2048.00
Forward/backward pass size (MB): 2.00
Params size (MB): 0.14
Estimated Total Size (MB): 2050.14
----------------------------------------------------------------


In [19]:
class DobyUNet(nn.Module):
    """Four-level U-Net: EncoderBlocks -> ConvBlock bridge -> DecoderBlocks -> 1x1 conv.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels of the output map (default 1).
        base_filters: filters at the first level. Each deeper level doubles it
            (default 64 -> 64/128/256/512 with a 1024-filter bridge).

    The defaults reproduce the original fixed architecture, with the same submodule
    names, so existing state_dicts still load. The output is the raw 1x1-conv
    result; no activation is applied. Spatial dims divisible by 16 (four 2x
    poolings) keep every upsampled map aligned with its skip connection.
    """

    def __init__(self, in_channels=1, out_channels=1, base_filters=64):
        super().__init__()
        f = base_filters

        # Contracting path: each EncoderBlock returns (skip features, pooled features).
        self.e1 = EncoderBlock(in_channels, f)
        self.e2 = EncoderBlock(f, f * 2)
        self.e3 = EncoderBlock(f * 2, f * 4)
        self.e4 = EncoderBlock(f * 4, f * 8)

        # Bridge
        self.b = ConvBlock(f * 8, f * 16)

        # Expanding path: each DecoderBlock halves the channels and consumes the
        # mirrored encoder skip.
        self.d1 = DecoderBlock(f * 16, f * 8)
        self.d2 = DecoderBlock(f * 8, f * 4)
        self.d3 = DecoderBlock(f * 4, f * 2)
        self.d4 = DecoderBlock(f * 2, f)

        self.convOut = nn.Conv2d(f, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b = self.b(p4)

        # Skips are consumed deepest-first (s4 ... s1).
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        return self.convOut(d4)

In [20]:
# Conv weights are randomly initialised; seed first so the model (and the checkpoint
# saved below) is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
model = DobyUNet()

In [21]:
# The model is on the CPU; torchsummary defaults to device="cuda", so pass "cpu" explicitly.
summary(model, (1, 512, 512), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 512]             640
       BatchNorm2d-2         [-1, 64, 512, 512]             128
              ReLU-3         [-1, 64, 512, 512]               0
            Conv2d-4         [-1, 64, 512, 512]          36,928
       BatchNorm2d-5         [-1, 64, 512, 512]             128
              ReLU-6         [-1, 64, 512, 512]               0
         ConvBlock-7         [-1, 64, 512, 512]               0
         MaxPool2d-8         [-1, 64, 256, 256]               0
      EncoderBlock-9  [[-1, 64, 512, 512], [-1, 64, 256, 256]]               0
           Conv2d-10        [-1, 128, 256, 256]          73,856
      BatchNorm2d-11        [-1, 128, 256, 256]             256
             ReLU-12        [-1, 128, 256, 256]               0
           Conv2d-13        [-1, 128, 256, 256]         147,584
      BatchNorm2d-14    

In [23]:
# Persist the whole pickled module (architecture + weights) next to the notebook.
# NOTE(review): loading this file needs the DobyUNet/DecoderBlock/... classes importable
# under the same names, and recent torch versions need torch.load(..., weights_only=False).
# Saving model.state_dict() would be more portable, but that changes the artifact format.
SAVE_PATH = './DobyUnet.pth'
torch.save(obj=model, f=SAVE_PATH)