In [None]:
import sys
print(sys.version)

# Directory of this chapter's working files; it is appended to the module
# search path below.
CURR_DIR = ''.join([
    '/content/drive/My Drive/google_colab_work/advanced_deep_learning_by_pytorch/',
    '3_semantic_segmentation/',
])
sys.path.append(CURR_DIR)

3.6.9 (default, Jul 17 2020, 12:50:27) 
[GCC 8.4.0]


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
print('torch.__version__ =', torch.__version__)

torch.__version__ = 1.6.0+cu101


# Feature module (Encoder)

In [None]:
class ConvBatchNormRelu(nn.Module):
    """2-D convolution followed by batch normalization and ReLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution; also the
            number of features normalized by the batch-norm layer.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        stride: Stride passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``.
        dilation: Dilation passed to ``nn.Conv2d``.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, bias):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.batchnorm = nn.BatchNorm2d(num_features=out_channels)
        # In-place ReLU overwrites the batch-norm output rather than
        # allocating a new tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, batch norm and ReLU (in that order) to ``x``."""
        return self.relu(self.batchnorm(self.conv(x)))

In [None]:
class FeatureMapConv(nn.Module):
    """Stem of the encoder: three 3x3 conv/batch-norm/ReLU stages and a max pool.

    The channel count goes 3 -> 64 -> 64 -> 128. The first convolution and the
    max pool both use stride 2; the other two convolutions use stride 1.
    """

    def __init__(self):
        super().__init__()

        self.conv_batchnorm_relu_1 = ConvBatchNormRelu(
            in_channels=3, out_channels=64,
            kernel_size=3, stride=2, padding=1, dilation=1, bias=False,
        )
        self.conv_batchnorm_relu_2 = ConvBatchNormRelu(
            in_channels=64, out_channels=64,
            kernel_size=3, stride=1, padding=1, dilation=1, bias=False,
        )
        self.conv_batchnorm_relu_3 = ConvBatchNormRelu(
            in_channels=64, out_channels=128,
            kernel_size=3, stride=1, padding=1, dilation=1, bias=False,
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Run ``x`` through the three conv stages and then the max pool."""
        stages = (
            self.conv_batchnorm_relu_1,
            self.conv_batchnorm_relu_2,
            self.conv_batchnorm_relu_3,
            self.maxpool,
        )
        feature_map = x
        for stage in stages:
            feature_map = stage(feature_map)
        return feature_map

In [None]:
class ConvBatchNorm(nn.Module):
    """2-D convolution followed by batch normalization (no activation).

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution; also the
            number of features normalized by the batch-norm layer.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        stride: Stride passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``.
        dilation: Dilation passed to ``nn.Conv2d``.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, bias):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.batchnorm = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        """Apply the convolution and then batch normalization to ``x``."""
        return self.batchnorm(self.conv(x))

In [None]:
class BottleNeckPSP(nn.Module):
    """Bottleneck residual block whose shortcut is a 1x1 conv + batch norm.

    The main path is 1x1 conv -> 3x3 conv -> 1x1 conv; ``stride`` and
    ``dilation`` are applied by the 3x3 convolution. The shortcut uses the
    same ``stride`` so that both paths can be summed.

    Args:
        in_channels: Channels of the block input; must equal ``2 * mid_channels``.
        mid_channels: Channels inside the bottleneck.
        out_channels: Channels of the block output; must equal ``4 * mid_channels``.
        stride: Stride of the 3x3 convolution and of the shortcut convolution.
        dilation: Dilation (and padding) of the 3x3 convolution.

    Raises:
        AssertionError: If the channel counts do not satisfy the relations above.
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride, dilation):
        assert in_channels == 2*mid_channels
        assert 4*mid_channels == out_channels

        super().__init__()

        # Main path: reduce channels, spatial 3x3 conv, then expand channels.
        self.conv_batchnorm_relu_1 = ConvBatchNormRelu(
            in_channels, mid_channels,
            kernel_size=1, stride=1, padding=0, dilation=1, bias=False,
        )
        # padding == dilation keeps the 3x3 conv size-preserving at stride 1.
        self.conv_batchnorm_relu_2 = ConvBatchNormRelu(
            mid_channels, mid_channels,
            kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False,
        )
        self.conv_batchnorm_3 = ConvBatchNorm(
            mid_channels, out_channels,
            kernel_size=1, stride=1, padding=0, dilation=1, bias=False,
        )

        # Shortcut path: projects the input to out_channels.
        self.conv_batchnorm_res = ConvBatchNorm(
            in_channels, out_channels,
            kernel_size=1, stride=stride, padding=0, dilation=1, bias=False,
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.conv_batchnorm_relu_1(x)
        main = self.conv_batchnorm_relu_2(main)
        main = self.conv_batchnorm_3(main)

        shortcut = self.conv_batchnorm_res(x)

        # Residual skip connection.
        summed = main + shortcut
        return self.relu(summed)

In [None]:
class BottleNeckIdentityPSP(nn.Module):
    """Bottleneck residual block whose shortcut is the unmodified input.

    The main path is 1x1 conv -> 3x3 conv -> 1x1 conv, all at stride 1, and
    maps ``out_channels`` back to ``out_channels`` so the input can be added
    to it directly.

    Args:
        out_channels: Channels of both the input and the output; must equal
            ``4 * mid_channels``.
        mid_channels: Channels inside the bottleneck.
        dilation: Dilation (and padding) of the 3x3 convolution.

    Raises:
        AssertionError: If ``out_channels != 4 * mid_channels``.
    """

    def __init__(self, out_channels, mid_channels, dilation):
        assert out_channels == 4*mid_channels

        super().__init__()

        self.conv_batchnorm_relu_1 = ConvBatchNormRelu(
            out_channels, mid_channels,
            kernel_size=1, stride=1, padding=0, dilation=1, bias=False,
        )
        self.conv_batchnorm_relu_2 = ConvBatchNormRelu(
            mid_channels, mid_channels,
            kernel_size=3, stride=1, padding=dilation, dilation=dilation, bias=False,
        )
        self.conv_batchnorm_3 = ConvBatchNorm(
            mid_channels, out_channels,
            kernel_size=1, stride=1, padding=0, dilation=1, bias=False,
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(main_path(x) + x)``."""
        main = self.conv_batchnorm_relu_1(x)
        main = self.conv_batchnorm_relu_2(main)
        main = self.conv_batchnorm_3(main)

        # Residual skip connection: the input itself is the shortcut.
        summed = main + x
        return self.relu(summed)

In [None]:
class ResBlockPSP(nn.Sequential):
    """Sequence of one ``BottleNeckPSP`` and ``n_blocks - 1`` identity blocks.

    The sub-modules are registered as ``block1``, ``block2``, ...; only
    ``block1`` receives ``in_channels`` and ``stride``.

    Args:
        n_blocks: Total number of bottleneck blocks in the sequence.
        in_channels: Input channels of the first block.
        mid_channels: Bottleneck channels used by every block.
        out_channels: Output channels of every block.
        stride: Stride of the first block.
        dilation: Dilation used by every block.
    """

    def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride, dilation):
        super().__init__()

        first_block = BottleNeckPSP(in_channels, mid_channels, out_channels, stride, dilation)
        self.add_module('block1', first_block)

        # Remaining blocks are numbered 2 .. n_blocks.
        for block_no in range(2, n_blocks + 1):
            identity_block = BottleNeckIdentityPSP(out_channels, mid_channels, dilation)
            self.add_module('block' + str(block_no), identity_block)

# Pyramid Pooling module

In [None]:
class PyramidPooling(nn.Module):
    """Pool the input at four scales and concatenate the results with the input.

    Each branch adaptively average-pools the input to one of ``pool_sizes``,
    applies a 1x1 conv/batch-norm/ReLU, and bilinearly resizes the result to
    ``(height, width)``. The four branch outputs are concatenated with the
    input along the channel dimension.

    Args:
        in_channels: Channels of the input feature map.
        pool_sizes: Output sizes of the adaptive average pools; the first four
            entries are used, and each branch produces
            ``int(in_channels / len(pool_sizes))`` channels.
        height: Height the branch outputs are resized to.
        width: Width the branch outputs are resized to.
    """

    def __init__(self, in_channels, pool_sizes, height, width):
        super().__init__()
        self.height = height
        self.width = width

        out_channels = int(in_channels / len(pool_sizes))

        def make_branch_conv():
            # Every branch uses the same 1x1 conv configuration.
            return ConvBatchNormRelu(
                in_channels, out_channels,
                kernel_size=1, stride=1, padding=0, dilation=1, bias=False,
            )

        self.ada_avg_pool_1 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[0])
        self.conv_batchnorm_relu_1 = make_branch_conv()

        self.ada_avg_pool_2 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[1])
        self.conv_batchnorm_relu_2 = make_branch_conv()

        self.ada_avg_pool_3 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[2])
        self.conv_batchnorm_relu_3 = make_branch_conv()

        self.ada_avg_pool_4 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[3])
        self.conv_batchnorm_relu_4 = make_branch_conv()

    def forward(self, x):
        """Return ``x`` concatenated with the four resized pyramid branches."""
        target_size = (self.height, self.width)
        branches = (
            (self.ada_avg_pool_1, self.conv_batchnorm_relu_1),
            (self.ada_avg_pool_2, self.conv_batchnorm_relu_2),
            (self.ada_avg_pool_3, self.conv_batchnorm_relu_3),
            (self.ada_avg_pool_4, self.conv_batchnorm_relu_4),
        )

        # The input itself is the first entry of the concatenation.
        pyramid = [x]
        for pool, conv in branches:
            branch_out = conv(pool(x))
            branch_out = F.interpolate(branch_out, size=target_size, mode='bilinear', align_corners=True)
            pyramid.append(branch_out)

        return torch.cat(pyramid, dim=1)

# Up Sampling module, Auxiliary Loss module (Decoder)

In [None]:
class UpSampling(nn.Module):
    """Classification head: 3x3 conv block, dropout, 1x1 classifier, resize.

    Args:
        in_channels: Channels of the input feature map.
        height: Height of the returned class-score map.
        width: Width of the returned class-score map.
        n_classes: Number of output channels of the classifier.
    """

    def __init__(self, in_channels, height, width, n_classes):
        super().__init__()

        self.height = height
        self.width = width
        self.mid_channels = 512

        self.conv_batchnorm_relu = ConvBatchNormRelu(
            in_channels,
            out_channels=self.mid_channels,
            kernel_size=3, stride=1, padding=1, dilation=1, bias=False,
        )
        self.dropout = nn.Dropout2d(p=0.1)
        self.classify = nn.Conv2d(
            in_channels=self.mid_channels,
            out_channels=n_classes,
            kernel_size=1, stride=1, padding=0,
        )

    def forward(self, x):
        """Return per-class scores bilinearly resized to ``(height, width)``."""
        scores = self.classify(self.dropout(self.conv_batchnorm_relu(x)))
        return F.interpolate(
            scores,
            size=(self.height, self.width),
            mode='bilinear',
            align_corners=True,
        )

In [None]:
class AuxLoss(nn.Module):
    """Auxiliary classification head: 3x3 conv block, dropout, 1x1 classifier, resize.

    Has the same layer layout as ``UpSampling`` but uses 256 intermediate
    channels.

    Args:
        in_channels: Channels of the input feature map.
        height: Height of the returned class-score map.
        width: Width of the returned class-score map.
        n_classes: Number of output channels of the classifier.
    """

    def __init__(self, in_channels, height, width, n_classes):
        super().__init__()

        self.height = height
        self.width = width
        self.mid_channels = 256

        self.conv_batchnorm_relu = ConvBatchNormRelu(
            in_channels,
            out_channels=self.mid_channels,
            kernel_size=3, stride=1, padding=1, dilation=1, bias=False,
        )
        self.dropout = nn.Dropout2d(p=0.1)
        self.classify = nn.Conv2d(
            in_channels=self.mid_channels,
            out_channels=n_classes,
            kernel_size=1, stride=1, padding=0,
        )

    def forward(self, x):
        """Return per-class scores bilinearly resized to ``(height, width)``."""
        scores = self.classify(self.dropout(self.conv_batchnorm_relu(x)))
        return F.interpolate(
            scores,
            size=(self.height, self.width),
            mode='bilinear',
            align_corners=True,
        )

# PSPNet

In [None]:
class PSPNet(nn.Module):
    """Full network: feature encoder, pyramid pooling, main and auxiliary heads.

    Both heads resize their class scores to 475x475. The pyramid pooling
    module resizes its branches to 60x60, so the encoder output must have that
    spatial size for the channel-wise concatenation to succeed.

    Args:
        n_classes: Number of classes predicted by both heads.
    """

    def __init__(self, n_classes):
        super().__init__()

        img_size = 475
        img_size_small = 60

        block_nums = [3, 4, 6, 3]
        pool_sizes = [6, 3, 2, 1]

        ### Feature module
        self.feature_conv = FeatureMapConv()
        self.feature_res_1 = ResBlockPSP(
            n_blocks=block_nums[0], in_channels=128, mid_channels=64,
            out_channels=256, stride=1, dilation=1,
        )
        self.feature_res_2 = ResBlockPSP(
            n_blocks=block_nums[1], in_channels=256, mid_channels=128,
            out_channels=512, stride=2, dilation=1,
        )
        self.feature_dilated_res_1 = ResBlockPSP(
            n_blocks=block_nums[2], in_channels=512, mid_channels=256,
            out_channels=1024, stride=1, dilation=2,
        )
        self.feature_dilated_res_2 = ResBlockPSP(
            n_blocks=block_nums[3], in_channels=1024, mid_channels=512,
            out_channels=2048, stride=1, dilation=4,
        )

        ### Pyramid Pooling module
        self.pyramid_pooling = PyramidPooling(
            in_channels=2048, pool_sizes=pool_sizes,
            height=img_size_small, width=img_size_small,
        )

        ### Up Sampling module
        self.up_sampling = UpSampling(
            in_channels=4096, height=img_size, width=img_size, n_classes=n_classes,
        )

        ### Auxiliary Loss module
        self.aux_loss = AuxLoss(
            in_channels=1024, height=img_size, width=img_size, n_classes=n_classes,
        )

    def forward(self, x):
        """Return ``(outputs, outputs_aux)`` for the image batch ``x``.

        ``outputs_aux`` is computed from the output of ``feature_dilated_res_1``;
        ``outputs`` is computed from the pyramid-pooled output of
        ``feature_dilated_res_2``.
        """
        features = self.feature_conv(x)
        features = self.feature_res_1(features)
        features = self.feature_res_2(features)
        features = self.feature_dilated_res_1(features)

        # The auxiliary head branches off before the last residual stage.
        outputs_aux = self.aux_loss(features)

        features = self.feature_dilated_res_2(features)
        pooled = self.pyramid_pooling(features)
        outputs = self.up_sampling(pooled)

        return (outputs, outputs_aux)

In [None]:
# Build the network with 21 output classes and print its module hierarchy.
net = PSPNet(n_classes=21)
print(net)

PSPNet(
  (feature_conv): FeatureMapConv(
    (conv_batchnorm_relu_1): ConvBatchNormRelu(
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (conv_batchnorm_relu_2): ConvBatchNormRelu(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (conv_batchnorm_relu_3): ConvBatchNormRelu(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (feature_res_1): ResBlockPS

In [None]:
# Smoke test: push a batch of random 3-channel 475x475 images through the
# network and report the size of both returned tensors.
batch_size = 2
dummy_imgs = torch.rand(batch_size, 3, 475, 475)

# The network returns a pair: (main prediction, auxiliary prediction).
outputs, outputs_aux = net(dummy_imgs)

print('outputs.size() =', outputs.size())
# Print the auxiliary tensor's own size (previously this line reported
# outputs.size() a second time, so the auxiliary head was never checked).
print('outputs_aux.size() =', outputs_aux.size())

outputs.size() = torch.Size([2, 21, 475, 475])
outputs_aux.size() = torch.Size([2, 21, 475, 475])
