<a href="https://colab.research.google.com/github/jerogar/DeepLabV3p-Pytorch-Tutorial_KOR/blob/master/DeepLab_v3_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DeepLab v3 Tutorial

Pseudo Lab (가짜연구소) Season 2, 논문미식회 (paper reading group)

reference: https://github.com/VainF/DeepLabV3Plus-Pytorch



![Main](https://sthalles.github.io/assets/deep_segmentation_network/semantic_segmentation.jpg)


# 1. Environment Setup

## (1) Mount Google Drive
- Mount Google Drive and set the working folder to Colab Notebooks/

In [None]:
from google.colab import drive # import drive from google colab

ROOT = "/content/drive"     # default location for the drive
print(ROOT)                 # print content of ROOT (Optional)

drive.mount(ROOT)           # we mount the google drive at /content/drive

# import join used to join ROOT path and MY_GOOGLE_DRIVE_PATH
from os.path import join  

# path to your project on Google Drive
MY_GOOGLE_DRIVE_PATH = '/content/drive/My Drive/Colab Notebooks'

print("MY_GOOGLE_DRIVE_PATH: ", MY_GOOGLE_DRIVE_PATH)
# Change into the project folder (it must already exist on Google Drive)
%cd "{MY_GOOGLE_DRIVE_PATH}"

/content/drive
Mounted at /content/drive
MY_GOOGLE_DRIVE_PATH:  /content/drive/My Drive/Colab Notebooks
/content/drive/My Drive/Colab Notebooks


## (2) Set Work directory

In [None]:
%cd DeepLabV3Plus-Pytorch 

/content/drive/My Drive/Colab Notebooks/DeepLabV3Plus-Pytorch


In [None]:
!pip install torch torchvision numpy pillow scikit-learn tqdm matplotlib



In [None]:
from tqdm import tqdm
import os
import random
import numpy as np

from torch.utils import data
import torch
import torch.nn as nn

import torch.nn.functional as F
from collections import OrderedDict

from torchvision.transforms.functional import normalize
from sklearn.metrics import confusion_matrix


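# 2. DeepLab v3 Model

- DeepLab v1: introduces atrous convolution and a Conditional Random Field (CRF)
- DeepLab v2: uses *atrous spatial pyramid pooling* (ASPP) to handle objects of different sizes
- DeepLab v3: adopts a ResNet structure for the encoder. It refines ASPP and uses batch normalization so that training converges well, and it matches or exceeds the earlier results without a CRF
- DeepLab v3+: introduces depthwise separable convolution (in ASPP) to improve runtime speed, and improves accuracy with a decoder that simplifies the U-Net structure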

![deeplab](https://miro.medium.com/max/1038/0*_Hm_2fqbcnlwLkoz.png)

## (1) Overall Structure
![deeplabv3](https://www.oreilly.com/library/view/hands-on-image-processing/9781789343731/assets/1aa5b349-5a66-456a-8afa-080a7b07a525.png)

- Backbone network + DeepLab head structure
- The layer4 feature of ResNet is used as the backbone feature
- When output_stride is set to 8
    - Downsampling in the backbone stops at the layer whose stride is 8
    - With ResNet, striding is kept through layer2; layer3 and layer4 replace their stride with atrous convolution so the overall stride stays at 8.
    - The Atrous Spatial Pyramid Pooling (ASPP) in the DeepLab head uses dilation rate=[12, 24, 36]
- When output_stride is set to 16
    - Downsampling in the backbone stops at the layer whose stride is 16
    - With ResNet, striding is kept through layer3; layer4 replaces its stride with atrous convolution so the overall stride stays at 16.
    - The Atrous Spatial Pyramid Pooling (ASPP) in the DeepLab head uses dilation rate=[6, 12, 18]
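
The effect of `output_stride` on the backbone feature map can be checked with plain arithmetic. The sketch below is not part of the referenced repository: the helper names `conv_out` and `backbone_feature_size` are hypothetical, and it assumes the ResNet layout used in this notebook (a 7x7 stride-2 stem, a 3x3 stride-2 max pool, then layer2 to layer4, each of which either strides by 2 or dilates instead).

```python
def conv_out(n, kernel, stride, padding, dilation=1):
    """Output length of a conv/pool layer along one spatial dimension."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def backbone_feature_size(n, replace_stride_with_dilation):
    """Spatial size of the layer4 output for an n x n input (hypothetical helper)."""
    n = conv_out(n, kernel=7, stride=2, padding=3)   # stem conv1
    n = conv_out(n, kernel=3, stride=2, padding=1)   # max pool
    dilation = 1
    for dilate in replace_stride_with_dilation:      # layer2, layer3, layer4
        if dilate:
            dilation *= 2                            # keep the resolution, dilate instead
            n = conv_out(n, kernel=3, stride=1, padding=dilation, dilation=dilation)
        else:
            n = conv_out(n, kernel=3, stride=2, padding=1)
    return n


for name, flags in [("output_stride=32", [False, False, False]),
                    ("output_stride=16", [False, False, True]),
                    ("output_stride=8", [False, True, True])]:
    print(name, "->", backbone_feature_size(513, flags))
```

For a 513x513 crop this gives a 33x33 feature map at output_stride=16 and a 65x65 map at output_stride=8, so halving the output stride roughly quadruples the number of positions the head has to process.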

In [None]:
from collections import OrderedDict

def ResNetbasedDeepLabV3(num_classes, output_stride, pretrained_backbone):
    if output_stride==8:
        replace_stride_with_dilation=[False, True, True]
        aspp_dilate = [12, 24, 36]
    else:
        replace_stride_with_dilation=[False, False, True]
        aspp_dilate = [6, 12, 18]
        
    backbone = resnet50(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)
    
    inplanes = 2048
    low_level_planes = 256

    return_layers = {'layer4': 'out'}
    classifier = DeepLabHead(inplanes , num_classes, aspp_dilate)
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)


    # TODO: expose the interpolation method as a parameter
    model = DeepLabV3(backbone, classifier, 'CARAFE')
    return model

## (2) Backbone - ResNet50

In [None]:
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url  # torchvision.models.utils was removed in newer torchvision releases

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

def resnet50(pretrained=False, progress=True, **kwargs):
    """ResNet-50 model from <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)

class IntermediateLayerGetter(nn.ModuleDict):
    '''
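    Module wrapper that returns intermediate layers of a model.
    It assumes that the layers were registered in the same order in which they are used in forward.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): names of the layers to extract and the keys under which to return them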
    Examples::
        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
    '''
    def __init__(self, model, return_layers):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")

        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

## (3) DeepLab Head
- ASPP (2048ch→256ch) + 3x3 Conv (256ch→256ch) + 1x1 Conv (256ch→21: number of classes)

- The segmentation result produced by the backbone and the DeepLab v3 head is restored to image-level resolution with bilinear upsampling

In [None]:
class DeepLabV3(nn.Module):

    # TODO: expose the interpolation method as a parameter
    def __init__(self, backbone, classifier, upMethod='', num_classes=21):
        super(DeepLabV3, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.upsampleMethod = upMethod
        # Build CARAFE once here so that its weights are registered, trained,
        # and moved together with the model by .cuda() / .to(device).
        # Creating it inside forward() would re-initialize it on every call.
        # CARAFE is defined in section (4); pass num_classes if it is not 21.
        self.carafe = CARAFE(c=num_classes, scale=2) if upMethod == 'CARAFE' else None

    def forward(self, x):
        input_shape = x.shape[-2:]
        features = self.backbone(x)
        x = self.classifier(features)

        if self.carafe is not None:
            # content-aware 2x upsampling of the class logits
            x = self.carafe(x)

        # bilinear interpolation restores the input resolution
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)

        return x

class DeepLabHead(nn.Module):
    def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHead, self).__init__()

        self.classifier = nn.Sequential(
            ASPP(in_channels, aspp_dilate),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )
        self._init_weight()

    def forward(self, feature):
        return self.classifier( feature['out'] )

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


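### ASPP (Atrous Spatial Pyramid Pooling)
- Borrows the idea of SPPNet, which was applied to early object detection, and builds the spatial pyramid pooling from atrous convolutions.

    → Kernels with different atrous rates are applied in parallel to the output of the conv layer to extract multi-scale features.

- As the atrous rate grows, the number of valid weights (kernel taps that land on the feature map rather than on zero padding) tends to shrink.
    ⇒ When the atrous rate becomes large, a 3x3 kernel behaves like a 1x1 kernel.

    ⇒ Large-scale context is not reflected in the output, so the receptive field does not grow as intended.

- To address this degeneration, global average pooling is applied to the last feature map of the image.
- Batch normalization is applied in each module to improve performance.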
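
The degeneration described above can be counted directly. The following is a small sketch, not code from the referenced repository: `mean_valid_taps` is a hypothetical helper, and it assumes a square feature map with zero padding equal to the dilation rate, as in the ASPP convolutions of this notebook.

```python
def mean_valid_taps(size, rate):
    """Average number of 3x3 kernel taps that land inside a size x size
    feature map (rather than on zero padding) for dilation `rate`."""
    per_axis = 0
    for i in range(size):
        # tap positions along one axis: i - rate, i, i + rate
        per_axis += sum(0 <= i + off < size for off in (-rate, 0, rate))
    # rows and columns are independent, so the 2-D mean is the product
    return (per_axis / size) ** 2


for rate in (1, 6, 12, 18, 36):
    print(f"rate={rate:2d}: {mean_valid_taps(33, rate):.2f} of 9 taps valid")
```

On a 33x33 map, a rate of 36 leaves only the centre tap inside the feature map at every position, which is exactly the "3x3 kernel acting as 1x1" case.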

![ASPP](https://gaussian37.github.io/assets/img/vision/segmentation/aspp/0.png)

### Atrous Convolution (=Dilated Convolution)

- The convolution is performed with empty entries (zeros) inserted between the kernel elements.
- The dilation rate determines how many zeros are inserted. A standard convolution has dilation rate 1; a dilation rate of 2 inserts one zero between elements, a rate of 3 inserts two, and so on.

- Applying dilation rate 2 to a 3x3 kernel, as shown below, makes the field of view (receptive field) cover a 5x5 region.

- With input feature map x, filter w, dilation rate r, and output y, the operation can be written as:
$$y[i]=\sum_{k=1}^{K}x[i+r\cdot{k}]\,w[k]$$

![https://miro.medium.com/max/395/1*1okwhewf5KCtIPaFib4XaA.gif](https://miro.medium.com/max/395/1*1okwhewf5KCtIPaFib4XaA.gif)
*convolution*
|
![https://miro.medium.com/max/395/1*SVkgHoFoiMZkjy54zM_SUw.gif](https://miro.medium.com/max/395/1*SVkgHoFoiMZkjy54zM_SUw.gif)*Atrous Conv.*

- **Atrous convolution can cover a wide field of view without a pooling operation.**
(A conventional CNN covers a wide field of view by stacking convolution and pooling, which reduces the resolution of the output feature map.)
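
As a minimal illustration of the formula above, the sketch below evaluates a 1-D atrous convolution with plain Python lists. The function names are hypothetical, the filter index k starts at 0 as is usual in code, and only positions where every tap falls inside the input are computed (no padding).

```python
def atrous_conv1d(x, w, rate=1):
    """y[i] = sum_k x[i + rate * k] * w[k], evaluated where every tap is in range."""
    span = rate * (len(w) - 1)          # distance between the first and last tap
    return [sum(x[i + rate * k] * w[k] for k in range(len(w)))
            for i in range(len(x) - span)]


def effective_kernel_size(kernel, rate):
    """Field of view of a `kernel`-tap filter with dilation `rate`."""
    return kernel + (kernel - 1) * (rate - 1)


x = [1, 2, 3, 4, 5, 6, 7]
w = [1, 0, -1]
print(atrous_conv1d(x, w, rate=1))   # taps at i, i+1, i+2
print(atrous_conv1d(x, w, rate=2))   # taps at i, i+2, i+4
print(effective_kernel_size(3, 2))   # 5, i.e. the 5x5 field of view in 2-D
```

The same three weights span five input samples at rate 2, so the field of view grows while the number of parameters stays fixed.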


In [None]:
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        super(ASPPConv, self).__init__(*modules)

- Global Average Pooling + 1x1 Conv + Upsampling
    - Global Average Pooling (nn.AdaptiveAvgPool2d): averages each channel over its spatial positions (e.g. 2048x33x33 → 2048x1x1)
    - 1x1 Conv: reduces the number of output channels to 256 (e.g. 2048x1x1 → 256x1x1)
    - Upsampling: restores the original spatial size (e.g. 256x1x1 → 256x33x33)
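
The three steps can be traced on a toy example without any framework. This is only a sketch of the data flow: the helper names are hypothetical, batch normalization and ReLU are left out, and nested lists stand in for tensors.

```python
def global_avg_pool(fmap):
    """fmap is a list of channels, each an H x W nested list; returns one mean per channel."""
    return [sum(map(sum, ch)) / (len(ch) * len(ch[0])) for ch in fmap]


def conv1x1(vec, weight):
    """A 1x1 conv on a C x 1 x 1 input is a matrix-vector product (weight: C_out x C_in)."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def upsample_from_1x1(vec, h, w):
    """Interpolating a 1 x 1 map to h x w simply repeats the single value."""
    return [[[v] * w for _ in range(h)] for v in vec]


# toy sizes: 4 input channels on a 3 x 3 map, 2 output channels
fmap = [[[float(c)] * 3 for _ in range(3)] for c in range(4)]   # channel c is constant c
weight = [[1.0, 0.0, 0.0, 0.0],
          [0.0, 1.0, 1.0, 1.0]]
pooled = global_avg_pool(fmap)              # 4 values, one per channel
reduced = conv1x1(pooled, weight)           # 2 values
out = upsample_from_1x1(reduced, 3, 3)      # 2 x 3 x 3
print(pooled, reduced, len(out), len(out[0]), len(out[0][0]))
```

Because the pooled map is 1x1, the upsampled branch carries the same image-level vector at every position, which is how global context reaches each pixel after concatenation.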

In [None]:
class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

- The outputs of the parallel branches (one 1x1 Conv, three 3x3 dilated Convs, and image pooling, which is ASPPPooling in the code) are concatenated
- A 1x1 Conv then projects the concatenated feature to 256 output channels (5x256→256)

In [None]:
class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))

        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),)

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)

## (4) CARAFE
Code from: https://github.com/XiaLiPKU/CARAFE/blob/master/carafe.py#L68


In [None]:
class ConvBNReLU(nn.Module):
    '''Module for the Conv-BN-ReLU tuple.'''
    def __init__(self, c_in, c_out, kernel_size, stride, padding, dilation,
                 use_relu=True):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
                c_in, c_out, kernel_size=kernel_size, stride=stride, 
                padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        if use_relu:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class CARAFE(nn.Module):
    def __init__(self, c, c_mid=64, scale=2, k_up=5, k_enc=3):
        """ The unofficial implementation of the CARAFE module.
        The details are in "https://arxiv.org/abs/1905.02188".
        Args:
            c: The channel number of the input and the output.
            c_mid: The channel number after compression.
            scale: The expected upsample scale.
            k_up: The size of the reassembly kernel.
            k_enc: The kernel size of the encoder.
        Returns:
            X: The upsampled feature map.
        """
        super(CARAFE, self).__init__()
        self.scale = scale

        self.comp = ConvBNReLU(c, c_mid, kernel_size=1, stride=1, 
                               padding=0, dilation=1)
        self.enc = ConvBNReLU(c_mid, (scale*k_up)**2, kernel_size=k_enc, 
                              stride=1, padding=k_enc//2, dilation=1, 
                              use_relu=False)
        self.pix_shf = nn.PixelShuffle(scale)

        self.upsmp = nn.Upsample(scale_factor=scale, mode='nearest')
        self.unfold = nn.Unfold(kernel_size=k_up, dilation=scale, 
                                padding=k_up//2*scale)

    def forward(self, X):
        b, c, h, w = X.size()
        h_, w_ = h * self.scale, w * self.scale
        
        # Shape comments use the defaults scale=2, k_up=5.
        W = self.comp(X)                                # b * c_mid * h * w
        W = self.enc(W)                                 # b * 100 (scale^2 * k_up^2) * h * w
        W = self.pix_shf(W)                             # b * 25 (k_up^2) * h_ * w_
        W = F.softmax(W, dim=1)                         # b * 25 * h_ * w_, normalized over the 25 taps

        X = self.upsmp(X)                               # b * c * h_ * w_
        X = self.unfold(X)                              # b * 25c * (h_ * w_)
        X = X.view(b, c, -1, h_, w_)                    # b * c * 25 * h_ * w_

        X = torch.einsum('bkhw,bckhw->bchw', [W, X])    # b * c * h_ * w_
        return X

# 3. Dataset

## (1) PASCAL VOC 2012 Segmentation Task Dataset

http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
https://academictorrents.com/details/df0aad374e63b3214ef9e92e178580ce27570e59

In [None]:
'''
!wget -P datasets/data/ "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
!tar -xvf datasets/data/VOCtrainval_11-May-2012.tar -C ./datasets/data

# Download SegmentationClassAug.zip (https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0)
!wget -P datasets/data/ "https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=1"
!unzip datasets/data/SegmentationClassAug.zip -d datasets/data/VOCdevkit/VOC2012/
'''

## (2) Data augmentation

- Training data is augmented with random scaling, cropping, and horizontal flipping
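
When an image is rescaled, its label mask has to be rescaled with nearest-neighbour sampling, because interpolating between class ids produces values that are not classes. The sketch below shows this on a single row of labels. It is a simplified 1-D illustration with hypothetical helper names, not the exact filter PIL applies.

```python
def resize_nearest(row, factor):
    """Upscale a 1-D row by an integer factor, copying the nearest source value."""
    return [row[i // factor] for i in range(len(row) * factor)]


def resize_linear(row, factor):
    """Upscale a 1-D row by linear interpolation between neighbouring samples."""
    out = []
    for i in range(len(row) * factor):
        pos = min(i / factor, len(row) - 1)      # source coordinate, clamped at the edge
        lo = int(pos)
        hi = min(lo + 1, len(row) - 1)
        t = pos - lo
        out.append((1 - t) * row[lo] + t * row[hi])
    return out


label_row = [0, 0, 15, 15]                # class ids: 0 = background, 15 = person
print(resize_nearest(label_row, 2))       # only ids that already existed
print(resize_linear(label_row, 2))        # contains 7.5, which is not a class id
```

This is why the scaling transform in this notebook resizes the image with a smooth filter but the label with `Image.NEAREST`, and why every random decision (scale, crop, flip) must be applied to the image and its label together.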

In [None]:
import os
import sys
import tarfile
import collections
import torch.utils.data as data
import shutil
import numpy as np

from PIL import Image


def voc_cmap(N=256, normalized=False):
    def bitget(byteval, idx):
        return ((byteval & (1 << idx)) != 0)

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << 7-j)
            g = g | (bitget(c, 1) << 7-j)
            b = b | (bitget(c, 2) << 7-j)
            c = c >> 3

        cmap[i] = np.array([r, g, b])

    cmap = cmap/255 if normalized else cmap
    return cmap

class VOCSegmentation(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
    """
    cmap = voc_cmap()
    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 #download=False,
                 transform=None):

        is_aug=False
        if year=='2012_aug':
            is_aug = True
            year = '2012'
        
        self.root = os.path.expanduser(root)
        self.year = year
        #self.url = DATASET_YEAR_DICT[year]['url']
        #self.filename = DATASET_YEAR_DICT[year]['filename']
        #self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform
        
        self.image_set = image_set
        base_dir = 'VOCdevkit/VOC2012'
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')


        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' Download and extract it under `root` first.')
        
        if is_aug and image_set=='train':
            mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
            assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
            split_f = os.path.join( self.root, 'train_aug.txt')#'./datasets/data/train_aug.txt'
        else:
            mask_dir = os.path.join(voc_root, 'SegmentationClass')
            splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
            split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        with open(os.path.join(split_f), "r") as f:
            file_names = [x.strip() for x in f.readlines()]
        
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transform is not None:
            img, target = self.transform(img, target)

        return img, target


    def __len__(self):
        return len(self.images)

    @classmethod
    def decode_target(cls, mask):
        """decode semantic mask to RGB image"""
        return cls.cmap[mask]

In [None]:
import torchvision
import torch
import torchvision.transforms.functional as torchvisionF
import random 
import numbers
import numpy as np
from PIL import Image

#
#  Extended Transforms for Semantic Segmentation
#
class ExtRandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, lbl):
        """
        Args:
            img (PIL Image): Image to be flipped.
        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return torchvisionF.hflip(img), torchvisionF.hflip(lbl)
        return img, lbl

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)

class ExtCompose(object):
    """Composes several transforms together.
    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, lbl):
        for t in self.transforms:
            img, lbl = t(img, lbl)
        return img, lbl

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

class ExtCenterCrop(object):
    """Crops the given PIL Image at the center.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img, lbl):
        """
        Args:
            img (PIL Image): Image to be cropped.
        Returns:
            PIL Image: Cropped image.
        """
        return torchvisionF.center_crop(img, self.size), torchvisionF.center_crop(lbl, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)

class ExtRandomScale(object):
    def __init__(self, scale_range, interpolation=Image.BILINEAR):
        self.scale_range = scale_range
        self.interpolation = interpolation

    def __call__(self, img, lbl):
        """
        Args:
            img (PIL Image): Image to be scaled.
            lbl (PIL Image): Label to be scaled.
        Returns:
            PIL Image: Rescaled image.
            PIL Image: Rescaled label.
        """
        assert img.size == lbl.size
        scale = random.uniform(self.scale_range[0], self.scale_range[1])
        target_size = ( int(img.size[1]*scale), int(img.size[0]*scale) )
        return torchvisionF.resize(img, target_size, self.interpolation), torchvisionF.resize(lbl, target_size, Image.NEAREST)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        # ExtRandomScale stores scale_range, not size
        return self.__class__.__name__ + '(scale_range={0}, interpolation={1})'.format(self.scale_range, interpolate_str)

class ExtRandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, lbl):
        """
        Args:
            img (PIL Image): Image to be flipped.
            lbl (PIL Image): Label to be flipped.
        Returns:
            PIL Image: Randomly flipped image.
            PIL Image: Label flipped together with the image.
        """
        if random.random() < self.p:
            return torchvisionF.hflip(img), torchvisionF.hflip(lbl)
        return img, lbl

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)

class ExtToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """
    def __init__(self, normalize=True, target_type='uint8'):
        self.normalize = normalize
        self.target_type = target_type
    def __call__(self, pic, lbl):
        """
        Note that labels will not be normalized to [0, 1].
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
            lbl (PIL Image or numpy.ndarray): Label to be converted to tensor. 
        Returns:
            Tensor: Converted image and label
        """
        if self.normalize:
            return torchvision.transforms.functional.to_tensor(pic), torch.from_numpy( np.array( lbl, dtype=self.target_type) )
        else:
            return torch.from_numpy( np.array( pic, dtype=np.float32).transpose(2, 0, 1) ), torch.from_numpy( np.array( lbl, dtype=self.target_type) )

    def __repr__(self):
        return self.__class__.__name__ + '()'

class ExtNormalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor, lbl):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
            lbl (Tensor): Label tensor, passed through unchanged so the transform composes with ExtCompose.
        Returns:
            Tensor: Normalized Tensor image.
            Tensor: Unchanged Tensor label
        """
        return torchvision.transforms.functional.normalize(tensor, self.mean, self.std), lbl

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

class ExtRandomCrop(object):
    """Crop the given PIL Image at a random location.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.
        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, lbl):
        """
        Args:
            img (PIL Image): Image to be cropped.
            lbl (PIL Image): Label to be cropped.
        Returns:
            PIL Image: Cropped image.
            PIL Image: Cropped label.
        """
        assert img.size == lbl.size, 'size of img and lbl should be the same. %s, %s'%(img.size, lbl.size)
        if self.padding > 0:
            img = torchvisionF.pad(img, self.padding)
            lbl = torchvisionF.pad(lbl, self.padding)

        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = torchvisionF.pad(img, padding=int((1 + self.size[1] - img.size[0]) / 2))
            lbl = torchvisionF.pad(lbl, padding=int((1 + self.size[1] - lbl.size[0]) / 2))

        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = torchvisionF.pad(img, padding=int((1 + self.size[0] - img.size[1]) / 2))
            lbl = torchvisionF.pad(lbl, padding=int((1 + self.size[0] - lbl.size[1]) / 2))

        i, j, h, w = self.get_params(img, self.size)

        return torchvisionF.crop(img, i, j, h, w), torchvisionF.crop(lbl, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
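
Each `Ext*` transform above applies one geometric operation to both the image and its label, drawing the random parameters only once so the two stay aligned. Below is a minimal sketch of that idea, using NumPy arrays in place of PIL images. The helper name `paired_random_crop` is hypothetical and not part of the tutorial code.

```python
import random
import numpy as np

def paired_random_crop(img, lbl, size):
    """Crop an (H, W, C) image and its (H, W) label with one shared random window."""
    th, tw = size
    h, w = lbl.shape
    assert img.shape[:2] == lbl.shape, "image and label must have the same spatial size"
    i = random.randint(0, h - th)  # top offset, drawn once for both arrays
    j = random.randint(0, w - tw)  # left offset, drawn once for both arrays
    return img[i:i + th, j:j + tw], lbl[i:i + th, j:j + tw]

# Toy data: every pixel stores its own flat index, so misalignment is easy to detect.
lbl = np.arange(6 * 8).reshape(6, 8)
img = np.stack([lbl, lbl, lbl], axis=-1)

img_c, lbl_c = paired_random_crop(img, lbl, size=(4, 5))
print(img_c.shape, lbl_c.shape)              # (4, 5, 3) (4, 5)
print(np.array_equal(img_c[..., 0], lbl_c))  # True: pixels still line up
```

If the offsets were drawn separately for the image and the label, the last check would fail for almost every draw, which is why segmentation pipelines cannot reuse the single-input torchvision transforms directly.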

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_dataset(opts):
    """ Dataset And Augmentation
    """
    if opts.dataset == 'voc':
        train_transform = ExtCompose([
            #et.ExtResize(size=opts.crop_size),
            ExtRandomScale((0.5, 2.0)),
            ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            ExtRandomHorizontalFlip(),
            ExtToTensor(),
            ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        if opts.crop_val:
            val_transform = ExtCompose([
                ExtResize(opts.crop_size),
                ExtCenterCrop(opts.crop_size),
                ExtToTensor(),
                ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        else:
            val_transform = ExtCompose([
                ExtToTensor(),
                ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                    image_set='train', transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                  image_set='val', transform=val_transform)
    
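    else:
        # only 'voc' is wired up in this notebook; fail loudly instead of hitting an UnboundLocalError below
        raise ValueError("Unsupported dataset: %s" % opts.dataset)

    return train_dst, val_dst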

# 4. Train/Test


### (1) poly learning rate policy

- The learning rate decays almost linearly at first, then slightly more steeply toward the end of training (for $power = 0.9$).
$$learning\_rate = base\_lr \times \left( 1-\frac{iter}{max\_iter}\right)^{power}$$
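
To get a feel for the schedule, the formula can be evaluated at a few iterations. This is a plain-Python sketch, not the scheduler class itself. The function name `poly_lr` is hypothetical, and the values `base_lr=0.01`, `max_iters=30000` and `min_lr=1e-6` are taken from the configuration and scheduler defaults used later in this notebook.

```python
def poly_lr(base_lr, cur_iter, max_iters, power=0.9, min_lr=1e-6):
    """Poly decay: base_lr * (1 - cur_iter / max_iters) ** power, floored at min_lr."""
    remaining = max(0.0, 1.0 - cur_iter / max_iters)  # fraction of training left
    return max(base_lr * remaining ** power, min_lr)

iters = (0, 7500, 15000, 22500, 30000)
schedule = [poly_lr(0.01, it, 30000) for it in iters]
for it, lr in zip(iters, schedule):
    print("iter %5d  lr %.6f" % (it, lr))
```

The drop between consecutive checkpoints grows slightly as training progresses, which is the "steeper near the end" behaviour, and the floor `min_lr` keeps the final value above zero.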

In [None]:
from torch.optim.lr_scheduler import _LRScheduler, StepLR

class PolyLR(_LRScheduler):
    def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6):
        self.power = power
        self.max_iters = max_iters
        self.min_lr = min_lr  # lower bound so the lr never reaches zero
        super(PolyLR, self).__init__(optimizer, last_epoch)
    
    def get_lr(self):
        # clamp at 0 so stepping past max_iters cannot raise a negative base to a fractional power
        decay = max(0.0, 1 - self.last_epoch / self.max_iters) ** self.power
        return [max(base_lr * decay, self.min_lr) for base_lr in self.base_lrs]

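### (2) Focal Loss
- ignore_index: a target value that is ignored and contributes nothing to the gradient. In `nn.CrossEntropyLoss` with mean reduction, the average is taken over the non-ignored targets only. The `FocalLoss` below instead calls `.mean()` over every pixel when size_average is true, so ignored pixels add zero loss but still count in the denominator.
- With the default `gamma=0` the focal factor `(1-pt)**gamma` equals 1, so the loss reduces to `alpha` times the plain cross-entropy; pass a positive `gamma` to down-weight easy pixels.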
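
The effect of the focal factor and of ignored pixels can be checked on scalars. This is a hedged sketch in plain Python rather than the PyTorch module. The helper `focal_loss_pixel` is hypothetical, and `gamma=2.0` is an illustrative choice (the class in this notebook defaults to `gamma=0`).

```python
import math

def focal_loss_pixel(ce_loss, alpha=1.0, gamma=2.0):
    """Focal loss of one pixel, given its cross-entropy ce_loss = -log(pt)."""
    pt = math.exp(-ce_loss)  # probability assigned to the true class
    return alpha * (1.0 - pt) ** gamma * ce_loss

easy = -math.log(0.9)  # well-classified pixel (pt = 0.9)
hard = -math.log(0.1)  # poorly classified pixel (pt = 0.1)

for name, ce in (("easy", easy), ("hard", hard)):
    print(name, "CE=%.4f" % ce, "focal=%.4f" % focal_loss_pixel(ce))

# An ignored pixel has zero loss; the two averaging conventions then differ.
losses = [focal_loss_pixel(easy), focal_loss_pixel(hard), 0.0]  # last entry: ignored pixel
mean_all = sum(losses) / len(losses)          # plain mean over every pixel
mean_valid = sum(losses) / (len(losses) - 1)  # mean over non-ignored pixels only
print(mean_all, mean_valid)
```

With `gamma=2` the easy pixel is scaled by `(1 - 0.9)**2 = 0.01` while the hard pixel keeps `(1 - 0.1)**2 = 0.81` of its cross-entropy, so training focuses on the hard pixel.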


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch 

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.size_average = size_average

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(
            inputs, targets, reduction='none', ignore_index=self.ignore_index)
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        if self.size_average:
            return focal_loss.mean()
        else:
            return focal_loss.sum()

In [None]:
import torch.nn as nn
from PIL import Image

def set_bn_momentum(model, momentum=0.1):
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = momentum

def save_ckpt(path):
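    """ Save the current training state.
    Relies on the notebook-level globals cur_itrs, model (wrapped in nn.DataParallel),
    optimizer, scheduler and best_score.
    """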
    torch.save({
        "cur_itrs": cur_itrs,
        "model_state": model.module.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_score": best_score,
    }, path)
    print("Model saved as %s" % path)

class Denormalize(object):
    def __init__(self, mean, std):
        mean = np.array(mean)
        std = np.array(std)
        # normalizing with (-mean/std, 1/std) computes x * std + mean, i.e. the inverse of ExtNormalize
        self._mean = -mean/std
        self._std = 1/std

    def __call__(self, tensor):
        if isinstance(tensor, np.ndarray):
            return (tensor - self._mean.reshape(-1,1,1)) / self._std.reshape(-1,1,1)
        return normalize(tensor, self._mean, self._std)
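
The `Denormalize` trick of reusing the normalization formula with `-mean/std` and `1/std` can be verified numerically. This sketch uses NumPy only and a random array as a stand-in for an image; it mirrors the `np.ndarray` branch of the class above and is not a replacement for it.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)

rng = np.random.default_rng(0)
image = rng.random((3, 4, 4))      # stand-in for a (C, H, W) image in [0, 1]

normalized = (image - mean) / std  # what ExtNormalize computes

# Denormalize stores -mean/std and 1/std, then applies (x - m) / s once more.
inv_mean, inv_std = -mean / std, 1.0 / std
restored = (normalized - inv_mean) / inv_std

print(np.allclose(restored, image))  # True
```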

### (3) Validation / Performance Measure

In [None]:
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

import cv2
import matplotlib
import matplotlib.pyplot as plt


def validate(opts, model, loader, device, metrics):
    """Do validation and return specified samples"""
    metrics.reset()
    ret_samples = []
    if opts.save_val_results:
        if not os.path.exists('results'):
            os.mkdir('results')
        denorm = Denormalize(mean=[0.485, 0.456, 0.406], 
                              std=[0.229, 0.224, 0.225])
        img_id = 0

    with torch.no_grad():
        for i, (images, labels) in tqdm(enumerate(loader)):
            
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)

            if opts.save_val_results:
                for j in range(len(images)):  # j: index within the batch (avoids shadowing the batch counter i)
                    image = images[j].detach().cpu().numpy()
                    target = targets[j]
                    pred = preds[j]

                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                    target = loader.dataset.decode_target(target).astype(np.uint8)
                    pred = loader.dataset.decode_target(pred).astype(np.uint8)

                    Image.fromarray(image).save('results/%d_image.png' % img_id)
                    Image.fromarray(target).save('results/%d_target.png' % img_id)
                    Image.fromarray(pred).save('results/%d_pred.png' % img_id)

                    fig = plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig('results/%d_overlay.png' % img_id, bbox_inches='tight', pad_inches=0)
                    plt.close()


                    if opts.test_only:
                      # cv2.imread loads images as BGR, so convert to RGB before showing them with plt
                      for k, name in enumerate(['image', 'target', 'pred', 'overlay']):
                          img_bgr = cv2.imread('results/%d_%s.png' % (img_id, name))
                          plt.subplot(1, 4, k + 1)
                          plt.axis('off')
                          plt.imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
                      plt.show()


                    img_id += 1

        score = metrics.get_results()
    return score, ret_samples

class StreamSegMetrics(object):
    """
    Stream Metrics for Semantic Segmentation Task
    """
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist( lt.flatten(), lp.flatten() )
    
    @staticmethod
    def to_str(results):
        string = "\n"
        for k, v in results.items():
            if k!="Class IoU":
                string += "%s: %f\n"%(k, v)
        
        #string+='Class IoU:\n'
        #for k, v in results['Class IoU'].items():
        #    string += "\tclass %d: %f\n"%(k, v)
        return string

    def _fast_hist(self, label_true, label_pred):
        mask = (label_true >= 0) & (label_true < self.n_classes)
        hist = np.bincount(
            self.n_classes * label_true[mask].astype(int) + label_pred[mask],
            minlength=self.n_classes ** 2,
        ).reshape(self.n_classes, self.n_classes)
        return hist

    def get_results(self):
        """Returns accuracy score evaluation result.
            - overall accuracy
            - mean accuracy
            - mean IU
            - fwavacc
        """
        hist = self.confusion_matrix
        acc = np.diag(hist).sum() / hist.sum()
        acc_cls = np.diag(hist) / hist.sum(axis=1)
        acc_cls = np.nanmean(acc_cls)
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
        mean_iu = np.nanmean(iu)
        freq = hist.sum(axis=1) / hist.sum()
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = dict(zip(range(self.n_classes), iu))

        return {
                "Overall Acc": acc,
                "Mean Acc": acc_cls,
                "FreqW Acc": fwavacc,
                "Mean IoU": mean_iu,
                "Class IoU": cls_iu,
            }
        
    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
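
The `_fast_hist` encoding and the IoU formula in `get_results` are easier to follow on a tiny example. The sketch below repeats the same NumPy steps on seven hand-made pixels with three classes; the label values are made up for illustration.

```python
import numpy as np

n_classes = 3
label_true = np.array([0, 0, 1, 1, 2, 2, 255])  # 255 marks an ignored pixel
label_pred = np.array([0, 1, 1, 1, 2, 0, 0])

# Keep valid pixels, then encode each (true, pred) pair as a single index.
mask = (label_true >= 0) & (label_true < n_classes)
pair_index = n_classes * label_true[mask] + label_pred[mask]
hist = np.bincount(pair_index, minlength=n_classes ** 2).reshape(n_classes, n_classes)
print(hist)  # rows: true class, columns: predicted class

tp = np.diag(hist)
iou = tp / (hist.sum(axis=1) + hist.sum(axis=0) - tp)  # TP / (TP + FN + FP)
print(iou, iou.mean())
```

Here the per-class IoU values are 1/3, 2/3 and 1/2, giving a mean IoU of 0.5; the pixel labelled 255 is dropped by the mask and does not affect any entry.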

### (4) Run

In [None]:
#!wget -P checkpoints/ "https://www.dropbox.com/s/3eag5ojccwiexkq/best_deeplabv3_resnet50_voc_os16.pth?dl=1"

In [None]:
class Config():
    # TODO : ADD compound scaling factor
    def __init__(self):
        super(Config, self).__init__()

        self.data_root='./datasets/data'
        self.dataset ='voc'
        self.download =False
        self.year ='2012' # VOC

        self.num_classes = 21
        self.model ='deeplabv3_resnet50_CARAFE'
        self.output_stride =16

        self.pretrained_backbone = True

        # train option
        self.test_only =False
        self.save_val_results =True
        self.total_itrs =30e3
        self.lr =0.01
        self.lr_policy ='poly' # 'step'
        self.step_size =10000
        self.crop_val =False
        self.batch_size =16
        self.val_batch_size =4
        self.crop_size =513
        self.ckpt ='checkpoints/latest_deeplabv3_resnet50_CARAFE_voc_os16.pth'
        self.continue_training =False

        self.loss_type = 'cross_entropy' #'cross_entropy', 'focal_loss'
        self.gpu_id ='0'
        self.weight_decay =1e-4
        self.random_seed =1
        self.print_interval =10
        self.val_interval =100

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data

import os
import numpy as np
import random

# Load option
opts =Config()

os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: %s" % device)

# Setup random seed
torch.manual_seed(opts.random_seed)
np.random.seed(opts.random_seed)
random.seed(opts.random_seed)


# Set dataloader
if opts.dataset.lower() == 'voc':
    opts.num_classes = 21
#elif opts.dataset.lower() == 'cityscapes':
#    opts.num_classes = 19

if opts.dataset=='voc' and not opts.crop_val:
    opts.val_batch_size = 1

train_dst, val_dst = get_dataset(opts)
train_loader = data.DataLoader(
    train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2)
val_loader = data.DataLoader(
    val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
print("Dataset: %s, Train set: %d, Val set: %d" %
      (opts.dataset, len(train_dst), len(val_dst)))

# Set up model
model  = ResNetbasedDeepLabV3(num_classes=opts.num_classes, output_stride=opts.output_stride, pretrained_backbone=opts.pretrained_backbone)
set_bn_momentum(model.backbone, momentum=0.01)

# Set up optimizer
optimizer = torch.optim.SGD(params=[
    {'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},
    {'params': model.classifier.parameters(), 'lr': opts.lr},
], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

if opts.lr_policy=='poly':
    scheduler = PolyLR(optimizer, opts.total_itrs, power=0.9)
elif opts.lr_policy=='step':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

# Set up criterion
if opts.loss_type == 'focal_loss':
    criterion = FocalLoss(ignore_index=255, size_average=True)
elif opts.loss_type == 'cross_entropy':
    criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

# Set up metrics
metrics = StreamSegMetrics(opts.num_classes)

checkPointPath = 'checkpoints'
if not os.path.exists(checkPointPath):
    os.mkdir(checkPointPath)
# Restore
best_score = 0.0
cur_itrs = 0
cur_epochs = 0

if opts.ckpt is not None and os.path.isfile(opts.ckpt):
    # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
    checkpoint = torch.load(opts.ckpt, map_location=device)  # load onto the selected device so CPU-only runs also work
    model.load_state_dict(checkpoint["model_state"])
    model = nn.DataParallel(model)
    model.to(device)
    if opts.continue_training:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler.load_state_dict(checkpoint["scheduler_state"])
        cur_itrs = checkpoint["cur_itrs"]
        best_score = checkpoint['best_score']
        print("Training state restored from %s" % opts.ckpt)
    print("Model restored from %s" % opts.ckpt)
    del checkpoint  # free memory
else:
    print("[!] Retrain")
    model = nn.DataParallel(model)
    model.to(device)

denorm = Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images

if opts.test_only:
    model.eval()
    val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
    print(metrics.to_str(val_score))

else:
  interval_loss = 0
  while cur_itrs < opts.total_itrs:  # stop once the iteration budget is used up (checked at epoch boundaries)
      # =====  Train  =====
      model.train()
      cur_epochs += 1
      for (images, labels) in train_loader:
          cur_itrs += 1

          images = images.to(device, dtype=torch.float32)
          labels = labels.to(device, dtype=torch.long)

          optimizer.zero_grad()
          outputs = model(images)
          loss = criterion(outputs, labels)
          loss.backward()
          optimizer.step()

          np_loss = loss.detach().cpu().numpy()
          interval_loss += np_loss

          if (cur_itrs) % opts.print_interval == 0:
              interval_loss = interval_loss/opts.print_interval
              print("Epoch %d, Itrs %d/%d, Loss=%f" %
                    (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
              interval_loss = 0.0

          if (cur_itrs) % opts.val_interval == 0:
              save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                        (opts.model, opts.dataset, opts.output_stride))
              
              print("validation...")
              
              model.eval()
              val_score, ret_samples = validate(
                  opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
              print(metrics.to_str(val_score))
              if val_score['Mean IoU'] > best_score:  # save best model
                  best_score = val_score['Mean IoU']
                  save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                            (opts.model, opts.dataset,opts.output_stride))
              
              #vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
              #vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
              #vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

              #for k, (img, target, lbl) in enumerate(ret_samples):
              #    img = (denorm(img) * 255).astype(np.uint8)
              #    target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
              #    lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
              #    concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
              #    vis.vis_image('Sample %d' % k, concat_img)

              model.train()
          scheduler.step()  

          #if cur_itrs >=  opts.total_itrs:
              #return


Device: cuda
Dataset: voc, Train set: 1464, Val set: 1449
Model restored from checkpoints/latest_deeplabv3_resnet50_CARAFE_voc_os16.pth


  "Argument interpolation should be of type InterpolationMode instead of int. "


Epoch 1, Itrs 10/30000, Loss=0.100845
Epoch 1, Itrs 20/30000, Loss=0.092625
Epoch 1, Itrs 30/30000, Loss=0.093359
Epoch 1, Itrs 40/30000, Loss=0.097646
Epoch 1, Itrs 50/30000, Loss=0.089101
Epoch 1, Itrs 60/30000, Loss=0.089750
Epoch 1, Itrs 70/30000, Loss=0.084165
Epoch 1, Itrs 80/30000, Loss=0.090632
Epoch 1, Itrs 90/30000, Loss=0.101166




Epoch 2, Itrs 100/30000, Loss=0.085282
Model saved as checkpoints/latest_deeplabv3_resnet50_CARAFE_voc_os16.pth
validation...


1449it [32:54,  1.36s/it]



Overall Acc: 0.879841
Mean Acc: 0.745281
FreqW Acc: 0.798868
Mean IoU: 0.572428

Model saved as checkpoints/best_deeplabv3_resnet50_CARAFE_voc_os16.pth
Epoch 2, Itrs 110/30000, Loss=0.127214
Epoch 2, Itrs 120/30000, Loss=0.116897
Epoch 2, Itrs 130/30000, Loss=0.127013
Epoch 2, Itrs 140/30000, Loss=0.122655
Epoch 2, Itrs 150/30000, Loss=0.125529
Epoch 2, Itrs 160/30000, Loss=0.120137
Epoch 2, Itrs 170/30000, Loss=0.107338
Epoch 2, Itrs 180/30000, Loss=0.121240




Epoch 3, Itrs 190/30000, Loss=0.118151
Epoch 3, Itrs 200/30000, Loss=0.131133
Model saved as checkpoints/latest_deeplabv3_resnet50_CARAFE_voc_os16.pth
validation...


1449it [07:38,  3.16it/s]



Overall Acc: 0.861704
Mean Acc: 0.789893
FreqW Acc: 0.777186
Mean IoU: 0.549581

Epoch 3, Itrs 210/30000, Loss=0.131211
Epoch 3, Itrs 220/30000, Loss=0.130532
Epoch 3, Itrs 230/30000, Loss=0.139765
Epoch 3, Itrs 240/30000, Loss=0.150003
Epoch 3, Itrs 250/30000, Loss=0.139473
Epoch 3, Itrs 260/30000, Loss=0.135450
Epoch 3, Itrs 270/30000, Loss=0.151406




Epoch 4, Itrs 280/30000, Loss=0.126718
Epoch 4, Itrs 290/30000, Loss=0.151679
Epoch 4, Itrs 300/30000, Loss=0.135294
Model saved as checkpoints/latest_deeplabv3_resnet50_CARAFE_voc_os16.pth
validation...


1190it [1:14:34, 895.05s/it]

RuntimeError: ignored

In [None]:
print(torch.__version__)