In [16]:
import torch
import torch.nn as nn

class conv_block_nested(nn.Module):
    """Two stacked ``3x3 Conv2d -> BatchNorm2d -> ReLU`` stages.

    Channels go ``in_ch -> mid_ch -> out_ch``; ``padding=1`` keeps the
    spatial size unchanged. Used for every node of the UNet++ grid.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        # A single parameter-free ReLU module is reused by both stages.
        self.activation = nn.ReLU(inplace=True)
        # NOTE: the conv biases are redundant ahead of BatchNorm, but they are
        # kept so existing checkpoints (state_dict keys) still load.
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        """Apply conv1/bn1/ReLU, then conv2/bn2/ReLU, and return the result."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.activation(norm(conv(x)))
        return x

class UNetPlusPlus(nn.Module):
    """UNet++ (nested U-Net) with optional deep supervision.

    The encoder column ``X^{i,0}`` has five levels with channel widths
    ``[n1, 2*n1, 4*n1, 8*n1, 16*n1]``. Each decoder node ``X^{i,j}``
    concatenates every earlier node on row ``i`` with the upsampled node
    ``X^{i+1,j-1}`` and runs a ``conv_block_nested`` on the result. A 1x1
    conv head is attached to each top-row decoder node ``X^{0,1..4}``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of each 1x1 head (e.g. classes).
        n1: Channel width of the first encoder level.
        height, width: Nominal input size. Only used to build ``self.Up``,
            which is kept for backward compatibility. ``forward`` resizes
            each upsampled tensor to the size of the skip tensor it is
            concatenated with, so inputs of other sizes (at least 16x16)
            also work.
        supervision: If true, ``forward`` returns all four head outputs
            stacked along a new leading dim; otherwise only the last head.
    """

    def __init__(self, in_ch=3, out_ch=1, n1=64, height=512, width=512, supervision=True):
        super().__init__()

        # Channel counts per level: 64, 128, 256, 512, 1024 for the default n1=64.
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Fixed-size upsamplers to (height, width) / 2**c for c = 0..3
        # (512, 256, 128, 64 for the defaults). forward() no longer uses them,
        # because it sizes each upsample from the skip tensor instead. They are
        # kept so code referencing ``model.Up`` still works; nn.Upsample has no
        # parameters, so state_dict is unaffected.
        self.Up = nn.ModuleList([nn.Upsample(size=(height // (2 ** c), width // (2 ** c)), mode='bilinear', align_corners=True) for c in range(4)])
        self.supervision = supervision

        # Encoder column X^{i,0}.
        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])  # (0,0): in_ch -> 64
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])  # (1,0): 64 -> 128
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])  # (2,0): 128 -> 256
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])  # (3,0): 256 -> 512
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])  # (4,0): 512 -> 1024

        # Decoder column j=1: one same-row input + one upsampled input.
        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])  # (0,1): 64+128 -> 64
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])  # (1,1): 128+256 -> 128
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])  # (2,1): 256+512 -> 256
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])  # (3,1): 512+1024 -> 512

        # Decoder column j=2: two same-row inputs + one upsampled input.
        self.conv0_2 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0])  # (0,2): 64+64+128 -> 64
        self.conv1_2 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1])  # (1,2): 128+128+256 -> 128
        self.conv2_2 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2])  # (2,2): 256+256+512 -> 256

        # Decoder column j=3.
        self.conv0_3 = conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0])  # (0,3): 64*3+128 -> 64
        self.conv1_3 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1])  # (1,3): 128*3+256 -> 128

        # Decoder column j=4.
        self.conv0_4 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0])  # (0,4): 64*4+128 -> 64

        # 1x1 heads mapping each top-row decoder node to out_ch channels.
        self.seg_outputs = nn.ModuleList([nn.Conv2d(filters[0], out_ch, kernel_size=1, padding=0) for _ in range(4)])

    @staticmethod
    def _upsample_like(src, ref):
        """Bilinearly resize ``src`` to the spatial (H, W) size of ``ref``.

        Uses the same mode/align_corners as the original ``self.Up`` modules,
        so results are identical when the input has the nominal size.
        """
        return nn.functional.interpolate(
            src, size=tuple(ref.shape[-2:]), mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Run the nested encoder/decoder on ``x``.

        Args:
            x: Input batch of shape ``(N, in_ch, H, W)`` (BatchNorm2d needs
                4-D input). H and W must stay >= 1 after four 2x poolings,
                i.e. be at least 16.

        Returns:
            If ``self.supervision``: ``torch.stack`` of the four head outputs,
            shape ``(4, N, out_ch, H, W)``. Otherwise the last head only,
            shape ``(N, out_ch, H, W)``.
        """
        # Every upsampled tensor is resized to match the skip tensor it is
        # concatenated with, so the concat never fails on spatial size.
        up = self._upsample_like
        outputs = []

        x0_0 = self.conv0_0(x)  # 64 channels, input resolution
        x1_0 = self.conv1_0(self.pool(x0_0))  # 128 channels, 1/2 resolution
        x0_1 = self.conv0_1(torch.cat([x0_0, up(x1_0, x0_0)], 1))  # 64; x1_0 upsampled back to x0_0's size
        outputs.append(self.seg_outputs[0](x0_1))  # 1x1 conv to out_ch channels

        x2_0 = self.conv2_0(self.pool(x1_0))  # 256
        x1_1 = self.conv1_1(torch.cat([x1_0, up(x2_0, x1_0)], 1))  # 128
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, up(x1_1, x0_0)], 1))  # 64
        outputs.append(self.seg_outputs[1](x0_2))

        x3_0 = self.conv3_0(self.pool(x2_0))  # 512
        x2_1 = self.conv2_1(torch.cat([x2_0, up(x3_0, x2_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, up(x2_1, x1_0)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, up(x1_2, x0_0)], 1))
        outputs.append(self.seg_outputs[2](x0_3))

        x4_0 = self.conv4_0(self.pool(x3_0))  # 1024
        x3_1 = self.conv3_1(torch.cat([x3_0, up(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, up(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, up(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, up(x1_3, x0_0)], 1))
        outputs.append(self.seg_outputs[3](x0_4))

        if self.supervision:
            return torch.stack(outputs)
        return outputs[-1]

In [17]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Smoke test: feed a random input through the model and check the output shape.
# Both the model and the input must live on `device`. Moving only the output
# there (as before) ran the whole forward pass on the CPU.
model = UNetPlusPlus(out_ch=12, supervision=True).to(device)
x = torch.randn([1, 3, 512, 512], device=device)
print("input shape : ", x.shape)
with torch.no_grad():  # only shapes are inspected; skip building an autograd graph
    out = model(x)
print("output shape : ", out.size())

input shape :  torch.Size([1, 3, 512, 512])
output shape :  torch.Size([4, 1, 12, 512, 512])
