In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [2]:
# Select the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Display the device selected above: "cuda" when a GPU is available, otherwise "cpu".

device

device(type='cuda')

In [5]:
class U_Net(nn.Module):
    """U-Net (Ronneberger et al., 2015) built from unpadded 3x3 convolutions.

    Every 3x3 convolution has no padding and shrinks each spatial dimension
    by 2. The output map is therefore smaller than the input (e.g. 572x572 in
    -> 388x388 out). Encoder skip features are center-cropped to the size of
    the upsampled decoder features before being concatenated with them.

    Args:
        init_weights: if True, re-initialise conv / transposed-conv weights
            with Kaiming-normal init and set their biases to zero.
        in_channels: number of channels of the input image (default 1).
        out_channels: number of channels of the predicted map (default 1).
        base_channels: width of the first encoder level. Level k (0-based)
            uses ``base_channels * 2**k`` channels, and the bottleneck uses
            ``16 * base_channels`` (default 64, as in the paper).
    """

    def __init__(self, init_weights=True, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()

        def CBR2d(input_ch, output_ch, kernel_size=3, stride=1):
            """Conv2d (no padding) -> BatchNorm2d -> ReLU.

            Args: input channels, output channels, kernel size, stride.
            """
            layer = nn.Sequential(
                nn.Conv2d(input_ch, output_ch, kernel_size=kernel_size, stride=stride),
                nn.BatchNorm2d(output_ch),
                nn.ReLU(),
            )
            return layer

        # Channel width per level: c1 (top level) ... c5 (bottleneck).
        c1 = base_channels
        c2, c3, c4, c5 = 2 * c1, 4 * c1, 8 * c1, 16 * c1

        ## Contraction
        self.conv1 = nn.Sequential(
            CBR2d(in_channels, c1, 3, 1),
            CBR2d(c1, c1, 3, 1),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Sequential(
            CBR2d(c1, c2, 3, 1),
            CBR2d(c2, c2, 3, 1),
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Sequential(
            CBR2d(c2, c3, 3, 1),
            CBR2d(c3, c3, 3, 1),
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Sequential(
            CBR2d(c3, c4, 3, 1),
            CBR2d(c4, c4, 3, 1),
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        ## BottleNeck
        self.bottleNeck = nn.Sequential(
            CBR2d(c4, c5, 3, 1),
            CBR2d(c5, c5, 3, 1)
        )

        ## Expansion
        # Each ex_conv takes the concatenation of the cropped skip feature and
        # the upsampled feature, i.e. twice the channels of its output.
        self.upconv1 = nn.ConvTranspose2d(in_channels=c5, out_channels=c4, kernel_size=2, stride=2)
        self.ex_conv1 = nn.Sequential(
            CBR2d(c5, c4, 3, 1),
            CBR2d(c4, c4, 3, 1)
        )

        self.upconv2 = nn.ConvTranspose2d(in_channels=c4, out_channels=c3, kernel_size=2, stride=2)
        self.ex_conv2 = nn.Sequential(
            CBR2d(c4, c3, 3, 1),
            CBR2d(c3, c3, 3, 1)
        )

        self.upconv3 = nn.ConvTranspose2d(in_channels=c3, out_channels=c2, kernel_size=2, stride=2)
        self.ex_conv3 = nn.Sequential(
            CBR2d(c3, c2, 3, 1),
            CBR2d(c2, c2, 3, 1)
        )

        self.upconv4 = nn.ConvTranspose2d(in_channels=c2, out_channels=c1, kernel_size=2, stride=2)
        self.ex_conv4 = nn.Sequential(
            CBR2d(c2, c1, 3, 1),
            CBR2d(c1, c1, 3, 1)
        )

        ## Final Prediction
        self.segment = nn.Conv2d(c1, out_channels, kernel_size=1, stride=1)

        if init_weights:
            self._initialize_weights()

    @staticmethod
    def _center_crop(feature, height, width):
        """Center-crop the last two (spatial) dims of ``feature`` to (height, width).

        The offsets use ``int(round(diff / 2.0))``, the same rule as
        torchvision's ``CenterCrop``. Unlike torchvision, this never pads and
        raises ``ValueError`` if the requested size exceeds the feature size.
        """
        in_h, in_w = feature.shape[-2:]
        if in_h < height or in_w < width:
            raise ValueError(
                f"cannot center-crop a {in_h}x{in_w} feature map to {height}x{width}"
            )
        top = int(round((in_h - height) / 2.0))
        left = int(round((in_w - width) / 2.0))
        return feature[..., top:top + height, left:left + width]

    def _crop_and_concat(self, skip, up):
        """Crop encoder feature ``skip`` to ``up``'s H x W, then concat on channels (skip first)."""
        cropped = self._center_crop(skip, up.shape[2], up.shape[3])
        return torch.cat((cropped, up), dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: batched image tensor (N, in_channels, H, W). BatchNorm2d
               requires 4-D input.

        Returns:
            Tensor of shape (N, out_channels, H', W'). H' and W' are smaller
            than H and W because all 3x3 convolutions are unpadded.
        """
        # Contraction: keep each level's pre-pool feature for the skip connections.
        enc1 = self.conv1(x)
        enc2 = self.conv2(self.pool1(enc1))
        enc3 = self.conv3(self.pool2(enc2))
        enc4 = self.conv4(self.pool3(enc3))

        # BottleNeck
        out = self.bottleNeck(self.pool4(enc4))

        # Expansion: upsample, then fuse with the matching encoder level
        # (deepest first).
        out = self.ex_conv1(self._crop_and_concat(enc4, self.upconv1(out)))
        out = self.ex_conv2(self._crop_and_concat(enc3, self.upconv2(out)))
        out = self.ex_conv3(self._crop_and_concat(enc2, self.upconv3(out)))
        out = self.ex_conv4(self._crop_and_concat(enc1, self.upconv4(out)))

        # Final Prediction: 1x1 conv to the requested number of output channels.
        return self.segment(out)

    def _initialize_weights(self):
        """Kaiming-normal init for conv / transposed-conv weights, zero biases.

        ``self.modules()`` yields every submodule, including:
        - containers (``nn.Sequential``, the model itself);
        - ``ReLU`` / ``MaxPool2d``, which have no ``weight``;
        - ``BatchNorm2d``, whose weight is 1-D and undefined for Kaiming init.

        Only convolution layers are touched. BatchNorm keeps PyTorch's
        default init (weight=1, bias=0).
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
