# <a href="https://pasus.tistory.com/204">U-Net Implementation in TF</a>
* TensorFlow로 구현되어 있는 U-Net을 PyTorch로 구현해 보기
* 목표: 완전히 이해하면서 구현하기
* Biomedical Image Dataset Segmentation까지가 목표
* * * 
* <a href="https://dotiromoook.tistory.com/14"><code>Conv2d</code> 사용법</a>  
* <a href="https://ban2aru.tistory.com/35"><code>BatchNorm2d</code> 사용법</a>
* * * 
* 모델 구조 시각화 방법 좀 더 찾아보자

In [1]:
import torch
from torch import nn
from torchsummary import summary

# ConvBlock Implementation

In [2]:
class ConvBlock(nn.Module):
    """Two stacked (3x3 conv -> batch norm -> ReLU) stages.

    Both convolutions use a 3x3 kernel with ``padding=1`` (stride left at its
    default of 1), so the spatial size (H, W) of the input is preserved and
    only the channel count changes from ``in_filters`` to ``out_filters``.

    Args:
        in_filters: number of channels of the input feature map.
        out_filters: number of channels produced by both convolutions.
    """

    def __init__(self, in_filters, out_filters):
        super().__init__()

        # "Same" padding (padding=1 for a 3x3 kernel) keeps H and W unchanged.
        # With no padding every conv would shrink H and W by 2, and the sizes
        # would have to be tracked so that pooling still divides them by 2.
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_filters, out_filters, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_filters, out_filters, **conv_kwargs)

        self.bn1 = nn.BatchNorm2d(out_filters)
        self.bn2 = nn.BatchNorm2d(out_filters)

        # A single parameter-free ReLU module, reused after both stages.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> BN -> ReLU twice and return the result."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x

In [3]:
convBlock = ConvBlock(3, 64)
# torchsummary.summary defaults to device="cuda": on a machine with a GPU it
# would feed CUDA tensors to this CPU-resident module and fail with a
# device/type mismatch. The module is never moved off the CPU, so say so.
summary(convBlock, (3, 512, 512), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 512]           1,792
       BatchNorm2d-2         [-1, 64, 512, 512]             128
              ReLU-3         [-1, 64, 512, 512]               0
            Conv2d-4         [-1, 64, 512, 512]          36,928
       BatchNorm2d-5         [-1, 64, 512, 512]             128
              ReLU-6         [-1, 64, 512, 512]               0
Total params: 38,976
Trainable params: 38,976
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 768.00
Params size (MB): 0.15
Estimated Total Size (MB): 771.15
----------------------------------------------------------------


# Encoder Block

In [4]:
class EncoderBlock(nn.Module):
    """One contracting step: a ConvBlock followed by 2x2 max pooling.

    ``forward`` returns a pair ``(features, pooled)``:

    * ``features`` is the ConvBlock output at the input resolution; in
      ``DobyUNet`` it is kept as the skip connection for a decoder stage.
    * ``pooled`` is ``features`` max-pooled with a 2x2 window (stride 2),
      which halves H and W (flooring odd sizes) and feeds the next stage.

    Args:
        in_filters: channels of the input feature map.
        out_filters: channels produced by the ConvBlock (pooling keeps them).
    """

    def __init__(self, in_filters, out_filters):
        super().__init__()
        self.convBlk = ConvBlock(in_filters, out_filters)
        # stride defaults to kernel_size, so this is a 2x downsample.
        self.down = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        features = self.convBlk(x)
        return features, self.down(features)

In [5]:
encoderBlk = EncoderBlock(3, 64)
# torchsummary.summary defaults to device="cuda"; this module lives on the
# CPU, so pass device="cpu" to avoid a device/type mismatch on GPU machines.
summary(encoderBlk, (3, 512, 512), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 512]           1,792
       BatchNorm2d-2         [-1, 64, 512, 512]             128
              ReLU-3         [-1, 64, 512, 512]               0
            Conv2d-4         [-1, 64, 512, 512]          36,928
       BatchNorm2d-5         [-1, 64, 512, 512]             128
              ReLU-6         [-1, 64, 512, 512]               0
         ConvBlock-7         [-1, 64, 512, 512]               0
         MaxPool2d-8         [-1, 64, 256, 256]               0
Total params: 38,976
Trainable params: 38,976
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 928.00
Params size (MB): 0.15
Estimated Total Size (MB): 931.15
----------------------------------------------------------------


# Decoder Block

In [6]:
TEST_BATCH_SIZE = 32
# Two random feature maps laid out as (batch, channels, H, W); test_1 is
# drawn first, then test_2.
test_1, test_2 = [torch.randn((TEST_BATCH_SIZE, 128, 64, 64)) for _ in range(2)]
# Concatenating along dim=1 (channels) sums the channel counts (128 + 128);
# the batch and spatial dims must agree and are carried through unchanged.
test_out = torch.cat((test_1, test_2), dim=1)
test_out.shape

torch.Size([32, 256, 64, 64])

In [7]:
class DecoderBlock(nn.Module):
    """One expanding step: upsample, concatenate the skip, then a ConvBlock.

    ``self.up`` is a 2x2 transposed convolution with stride 2. It doubles H
    and W and maps ``in_filters`` -> ``out_filters`` channels. The result is
    concatenated with ``skip`` along the channel dim and refined by
    ``ConvBlock(in_filters, out_filters)``. The concatenation must therefore
    have ``in_filters`` channels, i.e. ``skip`` must carry
    ``in_filters - out_filters`` channels. Example: 512 -> up -> 256,
    + 256 skip -> 512 -> ConvBlock -> 256.

    If the upsampled map is spatially smaller than ``skip``, it is zero-padded
    to the skip's H and W before concatenation. This happens when an encoder
    max-pool floored an odd size, e.g. 125 -> 62 -> up -> 124. Any extra odd
    pixel goes to the bottom/right. Without this, ``torch.cat`` raises for
    input images whose H or W is not a multiple of 2**depth. When the sizes
    already match, no padding is applied.

    Args:
        in_filters: channels of the incoming (deeper) feature map.
        out_filters: channels after upsampling and after the ConvBlock.
    """

    def __init__(self, in_filters, out_filters):
        super().__init__()

        # No padding here, so the size exactly doubles:
        # (H - 1) * stride + kernel = (H - 1) * 2 + 2 = 2 * H.
        self.up = nn.ConvTranspose2d(in_filters, out_filters, kernel_size=2, stride=2)
        self.convBlk = ConvBlock(in_filters, out_filters)

    def forward(self, x, skip):
        x = self.up(x)

        # Align the spatial size with the skip connection (see class docstring).
        diff_h = skip.shape[-2] - x.shape[-2]
        diff_w = skip.shape[-1] - x.shape[-1]
        if diff_h or diff_w:
            # pad order for the last two dims is (left, right, top, bottom).
            x = nn.functional.pad(
                x,
                (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2),
            )

        # dim 0 is batch, dim 1 is channels. The [upsampled, skip] order only
        # decides which of convBlk.conv1's input channels see which features.
        # Either order works, but it must stay fixed for saved weights to
        # remain valid.
        x = torch.cat([x, skip], dim=1)
        return self.convBlk(x)

In [8]:
# 128-channel input -> upsampled to 64 channels, concatenated with a 64-channel skip.
decoderBlk = DecoderBlock(in_filters=128, out_filters=64)

In [9]:
# Two inputs, (x, skip): x is 128 ch @ 64x64, skip is 64 ch @ 128x128.
# device="cpu" because torchsummary defaults to "cuda" while the block is on the CPU.
summary(decoderBlk, [(128, 64, 64), (64, 128, 128)], device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1         [-1, 64, 128, 128]          32,832
            Conv2d-2         [-1, 64, 128, 128]          73,792
       BatchNorm2d-3         [-1, 64, 128, 128]             128
              ReLU-4         [-1, 64, 128, 128]               0
            Conv2d-5         [-1, 64, 128, 128]          36,928
       BatchNorm2d-6         [-1, 64, 128, 128]             128
              ReLU-7         [-1, 64, 128, 128]               0
         ConvBlock-8         [-1, 64, 128, 128]               0
Total params: 143,808
Trainable params: 143,808
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 64.00
Params size (MB): 0.55
Estimated Total Size (MB): 64.55
----------------------------------------------------------------


# U-Net Implementation

In [10]:
class DobyUNet(nn.Module):
    """U-Net: 4 encoder stages, a 1024-channel bridge and 4 decoder stages.

    Channel widths: 64 -> 128 -> 256 -> 512 -> (bridge) 1024 -> 512 -> 256 ->
    128 -> 64. A final 1x1 convolution projects the 64 decoder channels to
    ``n_classes`` per-pixel scores. That is followed by a sigmoid when
    ``n_classes == 1``, or by a softmax over the channel (class) dimension
    otherwise.

    Input ``(N, in_channels, H, W)`` -> output ``(N, n_classes, H, W)``, for H
    and W divisible by 16 (four 2x2 poolings).

    Args:
        n_classes: number of output channels / segmentation classes (>= 1).
        in_channels: channels of the input image. Defaults to 3, which is what
            the first encoder stage was previously hard-coded to.

    Raises:
        ValueError: if ``n_classes < 1``.
    """

    def __init__(self, n_classes, in_channels=3):
        super().__init__()
        if n_classes < 1:
            raise ValueError(f"n_classes must be >= 1, got {n_classes}")

        # Encoder
        self.e1 = EncoderBlock(in_channels, 64)
        self.e2 = EncoderBlock(64, 128)
        self.e3 = EncoderBlock(128, 256)
        self.e4 = EncoderBlock(256, 512)

        # Bridge
        self.b = ConvBlock(512, 1024)

        # Decoder
        self.d1 = DecoderBlock(1024, 512)
        self.d2 = DecoderBlock(512, 256)
        self.d3 = DecoderBlock(256, 128)
        self.d4 = DecoderBlock(128, 64)

        # 1x1 conv head: maps the 64 decoder features at each pixel to
        # n_classes scores. Without it the activation below would be applied
        # to 64 raw feature channels regardless of n_classes.
        self.head = nn.Conv2d(64, n_classes, kernel_size=1)

        if n_classes == 1:
            self.output = nn.Sigmoid()
        else:
            # Normalise over the class channels of (N, C, H, W). An explicit
            # dim avoids the deprecated implicit-dim Softmax and its warning.
            self.output = nn.Softmax(dim=1)

    def forward(self, x):
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b = self.b(p4)

        # Each decoder stage fuses the matching-resolution encoder features.
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        # NOTE: nn.BCEWithLogitsLoss / nn.CrossEntropyLoss expect raw logits,
        # i.e. self.head(d4) before self.output, if used for training.
        logits = self.head(d4)
        return self.output(logits)

In [11]:
model = DobyUNet(n_classes=1)  # n_classes=1 selects the sigmoid output activation
print(repr(model))  # nn.Module's repr is the nested layer listing that print() shows

DobyUNet(
  (e1): EncoderBlock(
    (convBlk): ConvBlock(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (down): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (e2): EncoderBlock(
    (convBlk): ConvBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (down): MaxPool2d(kernel_size=2, stride=2, pa

In [12]:
# torchsummary defaults to device="cuda"; the model was never moved off the
# CPU, so pass device="cpu" to avoid a device/type mismatch on GPU machines.
summary(model, (3, 512, 512), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 512]           1,792
       BatchNorm2d-2         [-1, 64, 512, 512]             128
              ReLU-3         [-1, 64, 512, 512]               0
            Conv2d-4         [-1, 64, 512, 512]          36,928
       BatchNorm2d-5         [-1, 64, 512, 512]             128
              ReLU-6         [-1, 64, 512, 512]               0
         ConvBlock-7         [-1, 64, 512, 512]               0
         MaxPool2d-8         [-1, 64, 256, 256]               0
      EncoderBlock-9  [[-1, 64, 512, 512], [-1, 64, 256, 256]]               0
           Conv2d-10        [-1, 128, 256, 256]          73,856
      BatchNorm2d-11        [-1, 128, 256, 256]             256
             ReLU-12        [-1, 128, 256, 256]               0
           Conv2d-13        [-1, 128, 256, 256]         147,584
      BatchNorm2d-14    

## <code>nn.ConvTranspose2d</code>에 대한 테스트
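$$H_{out} = (H_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel}$$
(`dilation=1`, `output_padding=0` 기준. 예: $H_{in}=24$, stride 2, kernel 2, padding 0 → $(24-1)\times 2 + 2 = 48$)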


In [13]:
class ConvTransposed2dSample(nn.Module):
    """Single ``nn.ConvTranspose2d`` layer used to check output shapes.

    32 -> 64 channels, kernel_size=2, stride=2, no padding. Each spatial dim
    becomes (H - 1) * 2 + 2 = 2 * H, e.g. 24x24 -> 48x48.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.ConvTranspose2d(
            in_channels=32,
            out_channels=64,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x):
        return self.main(x)

In [14]:
# Instantiate the single-layer transposed-conv test module (32 -> 64 channels, kernel 2, stride 2).
convTest = ConvTransposed2dSample()

In [15]:
# torchsummary defaults to device="cuda"; convTest lives on the CPU.
summary(convTest, (32, 24, 24), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1           [-1, 64, 48, 48]           8,256
Total params: 8,256
Trainable params: 8,256
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.07
Forward/backward pass size (MB): 1.12
Params size (MB): 0.03
Estimated Total Size (MB): 1.23
----------------------------------------------------------------
