In [107]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import PIL
import torch

import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.models import vgg16
from torchvision import datasets, transforms


# Report the runtime environment and choose the compute device once.
cuda_available = torch.cuda.is_available()

print('pytorch version: {}'.format(torch.__version__))
print('GPU 사용 가능 여부: {}'.format(cuda_available))

# Store the device name according to whether a GPU is available.
if cuda_available:
    device = "cuda"
else:
    device = "cpu"

pytorch version: 1.4.0
GPU 사용 가능 여부: True


### 네트워크 설계 I (Pretrained 된 모델 사용 X)

### Front-end Module

In [108]:
import torch
import torch.nn as nn


def conv_relu(in_ch, out_ch, size=3, rate=1):
    """Build a stride-1 ``Conv2d -> ReLU`` block that keeps the spatial size.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        size: Square kernel size. Odd values leave H x W unchanged.
        rate: Dilation rate of the convolution.

    Returns:
        An ``nn.Sequential`` holding the convolution followed by ``nn.ReLU``.
    """
    # A dilated kernel spans rate * (size - 1) + 1 pixels, so padding half of
    # (span - 1) on each side preserves the input resolution at stride 1.
    # For size == 3 this reduces to padding == rate.
    same_padding = rate * (size - 1) // 2

    conv = nn.Conv2d(in_channels=in_ch,
                     out_channels=out_ch,
                     kernel_size=size,
                     stride=1,
                     padding=same_padding,
                     dilation=rate)
    return nn.Sequential(conv, nn.ReLU())


class VGG16(nn.Module):
    """VGG16-style convolutional front end built from ``conv_relu`` blocks.

    Stages 1-3 each end in a 2x2 max-pool, so the output is at 1/8 of the
    input resolution. Stages 4 and 5 have no pooling; stage 5 uses dilation
    rate 2 instead.
    """

    def __init__(self):
        super(VGG16, self).__init__()

        pool_cfg = dict(kernel_size=2, stride=2, padding=0)

        # Stage 1: 3 -> 64 channels, output at 1/2 resolution.
        self.features1 = nn.Sequential(
            conv_relu(3, 64, 3, 1),
            conv_relu(64, 64, 3, 1),
            nn.MaxPool2d(**pool_cfg),
        )

        # Stage 2: 64 -> 128 channels, output at 1/4 resolution.
        self.features2 = nn.Sequential(
            conv_relu(64, 128, 3, 1),
            conv_relu(128, 128, 3, 1),
            nn.MaxPool2d(**pool_cfg),
        )

        # Stage 3: 128 -> 256 channels, output at 1/8 resolution.
        self.features3 = nn.Sequential(
            conv_relu(128, 256, 3, 1),
            conv_relu(256, 256, 3, 1),
            conv_relu(256, 256, 3, 1),
            nn.MaxPool2d(**pool_cfg),
        )

        # Stage 4: 256 -> 512 channels, no pooling (resolution stays at 1/8).
        self.features4 = nn.Sequential(
            conv_relu(256, 512, 3, 1),
            conv_relu(512, 512, 3, 1),
            conv_relu(512, 512, 3, 1),
        )

        # Stage 5: 512 channels, no pooling; the convolutions use rate=2.
        self.features5 = nn.Sequential(
            conv_relu(512, 512, 3, 2),
            conv_relu(512, 512, 3, 2),
            conv_relu(512, 512, 3, 2),
        )

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the features."""
        stages = (self.features1,
                  self.features2,
                  self.features3,
                  self.features4,
                  self.features5)
        feats = x
        for stage in stages:
            feats = stage(feats)
        return feats

In [109]:
class classifier(nn.Module):
    """Fully convolutional head mapping 512-channel features to class scores.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, num_classes):
        super(classifier, self).__init__()
        head_layers = [
            conv_relu(512, 4096, 7, rate=4),
            nn.Dropout2d(0.5),
            conv_relu(4096, 4096, 1, 1),
            nn.Dropout2d(0.5),
            # Final projection to per-class scores (no activation).
            nn.Conv2d(4096, num_classes, kernel_size=1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class-score map for feature tensor ``x``."""
        return self.classifier(x)

### Context Module

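The context module refines the class-score map with a stack of 3×3 dilated convolutions (dilation rates 1, 1, 2, 4, 8, 16, 1) followed by a final 1×1 convolution, as shown below: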

![image.png](https://miro.medium.com/max/1576/1*aj0ymQMfAOCXbvhnSlTY_w.png)

In [110]:
class BasicContextModule(nn.Module):
    """Context module: a chain of dilated 3x3 convolutions on the score map.

    Every layer keeps ``num_classes`` channels. Layers 1-7 are
    ``conv_relu`` blocks with dilation rates 1, 1, 2, 4, 8, 16, 1; layer 8 is
    a plain 1x1 convolution with no ReLU.

    Args:
        num_classes: Number of input and output channels of every layer.
    """

    def __init__(self, num_classes):
        super(BasicContextModule, self).__init__()

        self.layer1 = nn.Sequential(conv_relu(num_classes, num_classes, 3, 1))
        self.layer2 = nn.Sequential(conv_relu(num_classes, num_classes, 3, 1))
        self.layer3 = nn.Sequential(conv_relu(num_classes, num_classes, 3, 2))
        self.layer4 = nn.Sequential(conv_relu(num_classes, num_classes, 3, 4))
        self.layer5 = nn.Sequential(conv_relu(num_classes, num_classes, 3, 8))
        self.layer6 = nn.Sequential(conv_relu(num_classes, num_classes, 3, 16))
        self.layer7 = nn.Sequential(conv_relu(num_classes, num_classes, 3, 1))
        # No truncation (no ReLU) on the final 1x1 layer.
        self.layer8 = nn.Sequential(nn.Conv2d(num_classes, num_classes, 1, 1))

    def forward(self, x):
        """Apply layers 1-8 sequentially, each consuming the previous output."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.layer7(out)
        out = self.layer8(out)

        return out

### DilatedNet

In [111]:
class DilatedNet(nn.Module):
    """Segmentation network: backbone -> classifier -> context module -> 8x upsampling.

    Args:
        backbone: Module producing the feature map from the input image.
        classifier: Module mapping features to a ``num_classes``-channel score map.
        context_module: Module refining the score map (channels unchanged).
        num_classes: Channel count of the score map; sizes the upsampling
            layer. Defaults to 21.
    """

    def __init__(self, backbone, classifier, context_module, num_classes=21):
        super(DilatedNet, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.context_module = context_module

        # kernel=16, stride=8, padding=4 gives (H - 1) * 8 - 2 * 4 + 16 = 8 * H,
        # i.e. an exact 8x upsampling of the score map.
        self.deconv = nn.ConvTranspose2d(in_channels=num_classes,
                                         out_channels=num_classes,
                                         kernel_size=16,
                                         stride=8,
                                         padding=4)

    def forward(self, x):
        """Return the score map produced by the four sub-modules applied in order."""
        x = self.backbone(x)
        x = self.classifier(x)
        x = self.context_module(x)
        x = self.deconv(x)
        return x

In [112]:
# Model output test: build the full network and check the output shape.
num_classes = 21
backbone = VGG16()
# Bind the head to its own name. Rebinding `classifier` to an instance would
# shadow the class and make this cell fail when it is run a second time.
classifier_head = classifier(num_classes)
context_module = BasicContextModule(num_classes)

model = DilatedNet(backbone=backbone, classifier=classifier_head, context_module=context_module)


model.eval()
image = torch.randn(1, 3, 512, 512)
print("input:", image.shape)
# Only the shape is inspected, so skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    print("output:", model(image).shape)

input: torch.Size([1, 3, 512, 512])
output: torch.Size([1, 21, 512, 512])


## CRF

In [2]:
import torch
import torch.nn as nn
from crfseg import CRF

# Minimal CRF usage demo: an identity "network" followed by a 2-D CRF layer,
# applied to an all-zero batch.
batch_size = 10
n_channels = 3
spatial = (100, 100)

model = nn.Sequential(nn.Identity(),           # placeholder for your NN
                      CRF(n_spatial_dims=2))

x = torch.zeros((batch_size, n_channels) + spatial)
log_proba = model(x)

### Reference
---

- [Dilated Convolution for Semantic Image Segmentation using caffe](https://github.com/fyu/dilation/blob/master/network.py)