Define the U-Net model

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [130]:
def Unet_conv(in_channels, out_channels):
    """Build the basic U-Net stage: two (3x3 conv -> ReLU) pairs.

    Both convolutions use ``padding=1``, so the spatial size is unchanged;
    only the channel count goes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channels expected by the first convolution.
        out_channels: Channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` laid out as [Conv2d, ReLU, Conv2d, ReLU].
    """
    stage = []
    # The first conv changes the channel count; the second keeps it.
    for conv_in in (in_channels, out_channels):
        stage.append(nn.Conv2d(conv_in, out_channels, 3, padding=1))
        # inplace=True overwrites the ReLU's input tensor instead of allocating
        # a new one: it saves memory, but the pre-activation values are lost.
        # That is fine here because the conv output is not reused anywhere else.
        stage.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stage)

class Unet(nn.Module):
    """U-Net-style encoder/decoder network.

    Contracting path: four ``Unet_conv`` stages (64, 128, 256, 512 channels),
    each followed by 2x2 max pooling, then a 1024-channel ``Unet_conv``
    bottleneck. Expansive path: four times, upsample bilinearly, concatenate
    the encoder feature map of the same resolution along the channel axis
    (skip connection), and apply a ``Unet_conv`` stage. A final 1x1
    convolution maps the 64 decoder channels to ``num_classes`` channels.

    The 3x3 convolutions are padded, so no cropping of the skip tensors is
    needed.

    Args:
        in_channels: Channels of the input tensor (default 3, the value that
            was previously hard-coded).
        num_classes: Output channels of the final 1x1 convolution (default 2,
            the value that was previously hard-coded).
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()

        # Contracting path (encoder) + bottleneck (conv_down5).
        self.conv_down1 = Unet_conv(in_channels, 64)
        self.conv_down2 = Unet_conv(64, 128)
        self.conv_down3 = Unet_conv(128, 256)
        self.conv_down4 = Unet_conv(256, 512)
        self.conv_down5 = Unet_conv(512, 1024)

        self.maxpool = nn.MaxPool2d(kernel_size=2)
        # forward() reads `mode` / `align_corners` from this module, but resizes
        # to each skip tensor's exact size rather than by a fixed factor of 2.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Expansive path (decoder): input channels = upsampled + skip channels.
        self.conv_up4 = Unet_conv(512 + 1024, 512)
        self.conv_up3 = Unet_conv(256 + 512, 256)
        self.conv_up2 = Unet_conv(128 + 256, 128)
        self.conv_up1 = Unet_conv(64 + 128, 64)
        self.final_layer = nn.Conv2d(64, num_classes, 1)

    def _upsample_and_concat(self, x, skip):
        """Resize ``x`` to ``skip``'s height/width, then concatenate on channels.

        Targeting the skip tensor's exact size, instead of a fixed 2x factor,
        keeps ``torch.cat`` from failing when a pooling step floored an odd
        size (e.g. a 66x66 input). When the sizes are exactly 2x apart, this
        gives the same result as ``self.upsample``, because with
        ``align_corners=True`` the sampling grid depends only on the input and
        output sizes.
        """
        x = F.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch of shape (N, in_channels, H, W). H and W should be
                at least 16, so the four 2x2 poolings leave a non-empty
                feature map. They do not need to be divisible by 16.

        Returns:
            Tensor of shape (N, num_classes, H, W): the output of the final
            1x1 convolution, with no activation applied.
        """
        # Contracting path. conv1..conv4 are kept in separate variables because
        # the expansive path concatenates them back in as skip connections.
        conv1 = self.conv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.conv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.conv_down3(x)
        x = self.maxpool(conv3)

        conv4 = self.conv_down4(x)
        x = self.maxpool(conv4)

        # Bottleneck: the deepest, 1024-channel stage.
        x = self.conv_down5(x)

        # Expansive path: upsample, concatenate the matching skip, convolve.
        x = self.conv_up4(self._upsample_and_concat(x, conv4))
        x = self.conv_up3(self._upsample_and_concat(x, conv3))
        x = self.conv_up2(self._upsample_and_concat(x, conv2))
        x = self.conv_up1(self._upsample_and_concat(x, conv1))

        return self.final_layer(x)
        

In [131]:
from torchsummary import summary

# Prefer the GPU when one is present; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Instantiate the network and move its parameters to the chosen device
# (Module.to returns the module itself, so this can be chained).
model = Unet().to(device)

# Print a per-layer table of output shapes and parameter counts for a
# 3-channel 64x64 input (input_size omits the batch dimension).
summary(model, input_size=(3, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           1,792
              ReLU-2           [-1, 64, 64, 64]               0
            Conv2d-3           [-1, 64, 64, 64]          36,928
              ReLU-4           [-1, 64, 64, 64]               0
         MaxPool2d-5           [-1, 64, 32, 32]               0
            Conv2d-6          [-1, 128, 32, 32]          73,856
              ReLU-7          [-1, 128, 32, 32]               0
            Conv2d-8          [-1, 128, 32, 32]         147,584
              ReLU-9          [-1, 128, 32, 32]               0
        MaxPool2d-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 256, 16, 16]         295,168
             ReLU-12          [-1, 256, 16, 16]               0
           Conv2d-13          [-1, 256, 16, 16]         590,080
             ReLU-14          [-1, 256,

In [133]:
# torch.Tensor(*sizes) is a legacy constructor that returns *uninitialized*
# memory (arbitrary values). torch.zeros gives a well-defined tensor with the
# same shape and default dtype.
# NOTE(review): `input` shadows the Python builtin; the name is kept only in
# case later cells refer to it.
input = torch.zeros(32, 32)
input.size()

torch.Size([32, 32])