In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.models import vgg16

In [2]:
# Helper for building the repeated "conv + ReLU" blocks used throughout the model.
def conv_relu(in_ch, out_ch, size, rate):
    """Build a stride-1, optionally dilated, Conv2d followed by ReLU.

    Padding is ``rate * (size - 1) // 2``. For odd kernel sizes this keeps the
    output's spatial size equal to the input's ("same" padding). For ``size=3``
    it is exactly ``padding=rate``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        size: kernel size (odd sizes give exact "same" padding).
        rate: dilation rate.

    Returns:
        ``nn.Sequential(Conv2d, ReLU)``.
    """
    conv = nn.Conv2d(in_channels=in_ch,
                     out_channels=out_ch,
                     kernel_size=size,
                     stride=1,
                     # effective kernel extent is rate*(size-1)+1; pad half of the growth
                     padding=rate * (size - 1) // 2,
                     dilation=rate)
    return nn.Sequential(conv, nn.ReLU())

In [3]:
class Backbone(nn.Module):
    """DilatedNet front-end: VGG-16 through conv4_3, then a dilated conv5-style block.

    pool4 of VGG-16 is not kept, so the output stays at 1/8 of the input
    resolution (a 512-channel 28x28 map for a 224x224 input). The three extra
    512->512 3x3 layers use dilation 2, and their padding of 2 keeps the
    spatial size unchanged.

    Args:
        pretrained: if True (the default), load ImageNet VGG-16 weights for
            the conv1_1..conv4_3 layers. torchvision downloads them on first
            use. Pass False to build the same architecture with random weights,
            for example to run shape checks offline.

    NOTE(review): ``features5`` is randomly initialised rather than copied from
    VGG-16's pretrained conv5 weights. Confirm this is intended.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        # Indices 0..22 of VGG-16 ``features`` are conv1_1 .. conv4_3 (+ReLU).
        # Index 23 (pool4) is deliberately excluded so the stride stays at 8.
        self.conv4 = self._vgg16_features(pretrained)[:23]
        # 1/8 resolution (28x28 for a 224x224 input)

        self.features5 = nn.Sequential(conv_relu(512, 512, 3, 2),
                                       conv_relu(512, 512, 3, 2),
                                       conv_relu(512, 512, 3, 2))
        # still 1/8 resolution: the dilated 3x3 convs preserve spatial size

    @staticmethod
    def _vgg16_features(pretrained):
        """Return VGG-16's ``features`` Sequential, using whichever torchvision API exists."""
        try:
            # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
            from torchvision.models import VGG16_Weights
        except ImportError:
            return vgg16(pretrained=pretrained).features
        weights = VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        return vgg16(weights=weights).features

    def forward(self, x):
        x = self.conv4(x)
        return self.features5(x)

In [4]:
class Classifier(nn.Module):
    """Dense classification head that turns 512-channel features into class scores.

    The first layer is a 7x7 conv with dilation 4. Its padding of 12 equals
    4 * (7 - 1) / 2, so the spatial size is preserved. A 4096->4096 1x1 conv
    follows, then a 1x1 conv down to ``num_classes``. Each hidden layer is
    followed by ReLU and channel-wise dropout (p=0.5).

    Args:
        num_classes: number of output score channels.
    """

    def __init__(self, num_classes):
        super().__init__()
        fc6 = nn.Conv2d(512, 4096, kernel_size=7, dilation=4, padding=12)
        fc7 = nn.Conv2d(4096, 4096, 1)
        score = nn.Conv2d(4096, num_classes, kernel_size=1)
        # Layer order (and hence state_dict keys) is:
        # 0 fc6, 1 ReLU, 2 Dropout2d, 3 fc7, 4 ReLU, 5 Dropout2d, 6 score.
        self.classifier = nn.Sequential(
            fc6, nn.ReLU(), nn.Dropout2d(0.5),
            fc7, nn.ReLU(), nn.Dropout2d(0.5),
            score,
        )

    def forward(self, x):
        return self.classifier(x)

In [5]:
class BasicContextModule(nn.Module):
    """Basic context module that refines ``num_classes``-channel score maps.

    Seven 3x3 conv+ReLU layers with dilation rates 1, 1, 2, 4, 8, 16, 1 widen
    the receptive field. Channels stay at ``num_classes`` and the spatial size
    is unchanged. A final 1x1 conv follows with no ReLU ("no truncation"), so
    the output scores may be negative.

    Args:
        num_classes: number of input/output channels.

    NOTE(review): the DilatedNet paper initialises this module to the identity.
    Here every layer uses PyTorch's default initialisation.
    """

    def __init__(self, num_classes):
        super().__init__()
        dilation_rates = (1, 1, 2, 4, 8, 16, 1)
        stages = [conv_relu(num_classes, num_classes, 3, r) for r in dilation_rates]
        # Final 1x1 projection with no ReLU after it (no truncation).
        stages.append(nn.Conv2d(num_classes, num_classes, kernel_size=1))
        self.context_module = nn.Sequential(*stages)

    def forward(self, x):
        return self.context_module(x)

In [6]:
class DilatedNet(nn.Module):
    """Chain backbone -> classifier -> context module -> 8x learned upsampling.

    Args:
        backbone: feature extractor. Assumed to output 1/8-resolution features
            (true for ``Backbone``; confirm for other modules).
        classifier: maps backbone features to ``num_classes`` score maps.
        context_module: refines the score maps without changing their shape.
        num_classes: channel count of the score maps fed to the upsampler.

    The transposed convolution (kernel 16, stride 8, padding 4) maps an HxW
    map to 8H x 8W, since (H - 1) * 8 - 2 * 4 + 16 = 8H.
    """

    def __init__(self, backbone, classifier, context_module, num_classes):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.context_module = context_module
        self.deconv = nn.ConvTranspose2d(num_classes, num_classes,
                                         kernel_size=16, stride=8, padding=4)

    def forward(self, x):
        out = x
        for stage in (self.backbone, self.classifier, self.context_module, self.deconv):
            out = stage(out)
        return out

In [7]:
# Smoke test: push one random 224x224 RGB image through the assembled model and
# check that the output is a per-pixel score map at the input resolution.
# Note: Backbone() loads pretrained VGG-16 weights (downloaded on first use).
device = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 12

backbone = Backbone()
classifier = Classifier(NUM_CLASSES)
context_module = BasicContextModule(NUM_CLASSES)
# Put the model and the input on the target device *before* the forward pass,
# so the check exercises the same device the model will be used on.
model = DilatedNet(backbone, classifier, context_module, num_classes=NUM_CLASSES).to(device)

x = torch.randn([1, 3, 224, 224], device=device)
print("input shape : ", x.shape)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = model(x)
print("output shape : ", out.size())
assert out.shape == (1, NUM_CLASSES, 224, 224), out.shape

input shape :  torch.Size([1, 3, 224, 224])
torch.Size([1, 12, 28, 28])
torch.Size([1, 12, 28, 28])
output shape :  torch.Size([1, 12, 224, 224])
