In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

In [2]:
# Layer 생성하기
class UNet(nn.Module):
    """U-Net encoder/decoder network built from Conv-BatchNorm-ReLU blocks.

    Contracting path (encoder):
      - Four stages, each made of two 3x3 CBR blocks followed by a 2x2 max-pool.
      - Channels grow 64 -> 128 -> 256 -> 512, then reach a 1024-channel bottleneck.

    Expansive path (decoder):
      - Four 2x2 stride-2 transposed convolutions.
      - Each upsampled map is concatenated along the channel axis with the
        encoder map of the same resolution (the skip connection), then passed
        through two CBR blocks.

    A final 1x1 convolution maps the 64 decoder features to ``out_channels``.
    No output activation is applied.

    All 3x3 convolutions use padding 1, so they preserve spatial size. Because of
    the four 2x poolings, input height and width must be divisible by 16 for the
    skip-connection concatenations to line up. The output then has the same
    height and width as the input.

    Args:
        in_channels: number of channels of the input tensor (default 1).
        out_channels: number of channels of the output tensor (default 1).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()  # initialise nn.Module state

        # Conv2d -> BatchNorm2d -> ReLU, wrapped in an nn.Sequential.
        # (Parameters here shadow the constructor's in/out_channels on purpose;
        # inside this helper they always refer to the block's own sizes.)
        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            layers = [
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=kernel_size, stride=stride, padding=padding,
                          bias=bias),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(),
            ]
            return nn.Sequential(*layers)

        # Contracting path (encoder).
        # kernel_size=3, stride=1, padding=1, bias=True are CBR2d's defaults.
        self.enc1_1 = CBR2d(in_channels=in_channels, out_channels=64)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)

        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)

        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)

        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)

        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck.
        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)

        # Expansive path (decoder).
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)

        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        # "* 2": the upsampled map is concatenated with the encoder skip map,
        # which doubles the channel count fed into the first decoder block.
        self.dec4_2 = CBR2d(in_channels=512 * 2, out_channels=512)
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)

        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec3_2 = CBR2d(in_channels=256 * 2, out_channels=256)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)

        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec2_2 = CBR2d(in_channels=128 * 2, out_channels=128)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)

        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec1_2 = CBR2d(in_channels=64 * 2, out_channels=64)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)

        # 1x1 projection to the output channels; keeps spatial size unchanged.
        self.fc = nn.Conv2d(in_channels=64, out_channels=out_channels,
                            kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        """Run the U-Net on ``x`` of shape (N, in_channels, H, W).

        H and W must be divisible by 16. Returns (N, out_channels, H, W).
        """
        # Encoder: two CBR blocks per stage, then 2x downsampling.
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)

        # Decoder: upsample, concatenate the skip map on the channel axis
        # (dim=1 for N, C, H, W tensors), then two CBR blocks.
        dec5_1 = self.dec5_1(enc5_1)

        unpool4 = self.unpool4(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)

        return x