# PSPNet

![fig_3](./fig_3.png)

코드의 구현에는 https://github.com/hszhao/semseg 를 참고하였습니다.

## 순서
1. Dilated ResNet 코드
2. Pyramid Pooling Module 코드
3. PSPNet 전체 코드

## 1. Dilated Residual Network (Dilated ResNet)

In [1]:
import torch
import torch.nn as nn
from torchinfo import summary
import torch.nn.functional as F

import resnet as models

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
class DilatedResNet(nn.Module):
    """ResNet backbone whose last two stages are dilated instead of strided.

    A ResNet-50/101/152 is built from the local ``resnet`` module (imported
    as ``models``).  ``layer3`` and ``layer4`` are then patched in place so
    that they stop downsampling: every submodule whose name contains
    ``conv2`` gets stride 1 with dilation/padding 2 (``layer3``) or 4
    (``layer4``), and every ``downsample.0`` gets stride 1.

    Args:
        layers: Backbone depth; must be 50, 101 or 152.
        pretrained: Forwarded unchanged to the ``models.resnet*`` factory.

    Raises:
        ValueError: If ``layers`` is not one of the supported depths.
    """

    _SUPPORTED_LAYERS = (50, 101, 152)

    def __init__(self, layers=50, pretrained=True):
        super(DilatedResNet, self).__init__()

        # Fail loudly on an unsupported depth rather than silently building
        # a ResNet-152 for e.g. ``layers=34``.
        if layers not in self._SUPPORTED_LAYERS:
            raise ValueError(
                f"layers must be one of {self._SUPPORTED_LAYERS}, got {layers!r}"
            )

        # Resolves to models.resnet50 / resnet101 / resnet152.
        resnet = getattr(models, f"resnet{layers}")(pretrained=pretrained)

        # Stem: three conv-bn-relu groups followed by max pooling.
        # NOTE(review): relies on the local resnet exposing conv1..conv3 and
        # bn1..bn3 (a deep-stem variant) -- confirm against resnet.py.
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                    resnet.conv2, resnet.bn2, resnet.relu,
                                    resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Replace the stride-2 downsampling of the last two stages by dilation.
        self._dilate(self.layer3, 2)
        self._dilate(self.layer4, 4)

    @staticmethod
    def _dilate(stage, rate):
        """Set stride 1 on ``stage`` and give its ``conv2`` modules dilation ``rate``."""
        for name, module in stage.named_modules():
            if 'conv2' in name:
                module.dilation, module.padding, module.stride = (rate, rate), (rate, rate), (1, 1)
            elif 'downsample.0' in name:
                module.stride = (1, 1)

    def forward(self, x, y=None):
        """Return the ``layer4`` feature map for input ``x``.

        ``y`` is accepted for signature compatibility and is not used.
        """
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

In [3]:
# Smoke test: push a random batch of four 200x200 RGB images through the backbone.
layers = 50
inp = torch.rand(4, 3, 200, 200)

# Random initialisation is enough here; only the output shape is inspected.
resnet = DilatedResNet(layers=layers, pretrained=False)
output = resnet(inp)
print(f"Dilated ResNet {layers}'s output size : {output.size()}")

Dilated ResNet 50's output size : torch.Size([4, 2048, 25, 25])


## 2. Pyramid Pooling Module

In [4]:
class PPM(nn.Module):
    """Pyramid Pooling Module.

    For every entry of ``bins`` the input is adaptively average-pooled to
    that output size, projected to ``reduction_dim`` channels by a 1x1
    convolution + batch norm + ReLU, and bilinearly resized back to the
    input's spatial size.  The input and all resized branches are then
    concatenated along the channel dimension, so the output has
    ``in_dim + len(bins) * reduction_dim`` channels.

    Args:
        in_dim: Number of channels of the incoming feature map.
        reduction_dim: Number of channels produced by each pooling branch.
        bins: Iterable of adaptive-pooling output sizes, e.g. ``(1, 2, 3, 6)``
            for 1x1, 2x2, 3x3 and 6x6 grids.
    """

    def __init__(self, in_dim, reduction_dim, bins):
        super().__init__()

        def make_branch(output_size):
            return nn.Sequential(
                # Pool the map down to an output_size x output_size grid.
                nn.AdaptiveAvgPool2d(output_size),
                # 1x1 conv shrinks the channels to reduction_dim (callers in
                # this file pass in_dim / N, N being the number of levels).
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim),
                nn.ReLU(inplace=True),
            )

        # One branch per pyramid level, kept in the order given by ``bins``.
        self.features = nn.ModuleList([make_branch(size) for size in bins])

    def forward(self, x):
        """Return ``x`` concatenated with every upsampled pyramid branch."""
        target_hw = x.size()[2:]
        branches = [
            F.interpolate(level(x), target_hw, mode='bilinear', align_corners=True)
            for level in self.features
        ]

        # Channel-wise concatenation: original features first, then each level.
        return torch.cat([x, *branches], 1)

In [5]:
# Channel count of the backbone feature map from the previous cell (2048).
in_dim = output.size(1)

# Pyramid pooling levels: 1x1, 2x2, 3x3 and 6x6 grids.
bins = (1, 2, 3, 6)

# Each level is reduced to 1/N of the input channels, with N = 4 levels.
reduction_dim = in_dim // len(bins)

ppm = PPM(in_dim=in_dim, reduction_dim=reduction_dim, bins=bins)
output = ppm(output)
print(f"Pyramid Pooling Module's output size : {output.size()}")

Pyramid Pooling Module's output size : torch.Size([4, 4096, 25, 25])


### AdaptiveAvgPool2d

In [6]:
# 1x1x3x3 float tensor holding the values 1..9 in row-major order.
inp = torch.arange(1, 10, dtype=torch.float).reshape(1, 1, 3, 3)
print(inp.shape)
print(inp)

torch.Size([1, 1, 3, 3])
tensor([[[[1., 2., 3.],
          [4., 5., 6.],
          [7., 8., 9.]]]])


In [7]:
# Pooling the 3x3 input to 2x2 averages four overlapping 2x2 windows:
#   [[mean(1, 2, 4, 5), mean(2, 3, 5, 6)],
#    [mean(4, 5, 7, 8), mean(5, 6, 8, 9)]]
pool_2x2 = nn.AdaptiveAvgPool2d(output_size=2)
out = pool_2x2(inp)
print(out)

tensor([[[[3., 4.],
          [6., 7.]]]])


In [8]:
# An output size of 1 averages over the whole map: global average pooling.
global_pool = nn.AdaptiveAvgPool2d(output_size=1)
out = global_pool(inp)
print(out)

tensor([[[[5.]]]])


## 3. PSPNet 전체 코드

In [9]:
class PSPNet(nn.Module):
    """Pyramid Scene Parsing Network: dilated ResNet + pyramid pooling + classifier.

    The backbone is a ResNet-50/101/152 from the local ``resnet`` module
    (imported as ``models``) whose ``layer3``/``layer4`` are switched from
    striding to dilation.  Its 2048-channel output goes through the ``PPM``
    defined in this file and then a small convolutional classifier.  An
    auxiliary classifier on the 1024-channel ``layer3`` output provides an
    extra loss in training mode.

    Args:
        layers: Backbone depth; must be 50, 101 or 152.
        bins: Adaptive-pooling output sizes of the pyramid pooling module.
        dropout: Drop probability of the ``Dropout2d`` in both classifiers.
        classes: Number of output classes (channels of the logits).
        zoom_factor: Scale used to compute the size the logits are resized
            to; with the default 8 the logits match the input size.
        pretrained: Forwarded unchanged to the ``models.resnet*`` factory.

    Raises:
        ValueError: If ``layers`` is not one of the supported depths.
    """

    _SUPPORTED_LAYERS = (50, 101, 152)

    def __init__(self, layers=50, bins=(1, 2, 3, 6), dropout=0.1, classes=2, zoom_factor=8, pretrained=True):
        super(PSPNet, self).__init__()

        # Fail loudly on an unsupported depth rather than silently building
        # a ResNet-152 for e.g. ``layers=34``.
        if layers not in self._SUPPORTED_LAYERS:
            raise ValueError(
                f"layers must be one of {self._SUPPORTED_LAYERS}, got {layers!r}"
            )

        # Used in forward() to compute the size the logits are upsampled to.
        # Per the original author's note the feature map is 1/8 of the input.
        self.zoom_factor = zoom_factor

        self.criterion = nn.CrossEntropyLoss()

        # Resolves to models.resnet50 / resnet101 / resnet152.
        resnet = getattr(models, f"resnet{layers}")(pretrained=pretrained)

        # Stem: three conv-bn-relu groups followed by max pooling.
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                    resnet.conv2, resnet.bn2, resnet.relu,
                                    resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Replace the stride-2 downsampling of the last two stages by dilation.
        self._dilate(self.layer3, 2)
        self._dilate(self.layer4, 4)

        # Channel count of the layer4 output fed to the pyramid pooling module.
        fea_dim = 2048
        reduction_dim = int(fea_dim / len(bins))
        self.ppm = PPM(in_dim=fea_dim, reduction_dim=reduction_dim, bins=bins)

        # PPM concatenates its input with one reduction_dim-channel map per
        # bin.  Deriving the width from that (4096 for the default bins) keeps
        # the classifier consistent when len(bins) does not divide 2048.
        fea_dim += reduction_dim * len(bins)

        # Main classifier on top of the pyramid-pooled features.
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(512, classes, kernel_size=1)
        )

        # Auxiliary classifier on the layer3 output.  It is always built
        # because a freshly constructed module is in training mode; forward()
        # only uses it while training.
        self.aux = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(256, classes, kernel_size=1)
        )

    @staticmethod
    def _dilate(stage, rate):
        """Set stride 1 on ``stage`` and give its ``conv2`` modules dilation ``rate``."""
        for name, module in stage.named_modules():
            if 'conv2' in name:
                module.dilation, module.padding, module.stride = (rate, rate), (rate, rate), (1, 1)
            elif 'downsample.0' in name:
                module.stride = (1, 1)

    def forward(self, x, y=None):
        """Segment ``x``; additionally compute the losses in training mode.

        Args:
            x: Input batch, indexed as (batch, channels, height, width).
            y: Target passed to ``nn.CrossEntropyLoss`` together with the
                logits; required in training mode, ignored otherwise.

        Returns:
            In training mode the tuple ``(prediction, main_loss, aux_loss)``
            where ``prediction`` is the per-pixel argmax over the class
            dimension.  In evaluation mode, the upsampled logits.

        Raises:
            ValueError: If the module is in training mode and ``y`` is None.
        """
        # Both losses need a target, so reject a missing one before doing any
        # work instead of failing inside the loss with an obscure error.
        if self.training and y is None:
            raise ValueError("PSPNet.forward requires the target `y` in training mode")

        x_size = x.size()

        # Spatial size the logits are resized to; equals the input size when
        # zoom_factor is 8.
        h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)
        w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1)

        # Dilated ResNet backbone; the layer3 output feeds the auxiliary head.
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)

        # Pyramid Pooling Module
        x = self.ppm(x)

        # Master branch
        x = self.cls(x)

        # Upsample the logits to the target size.
        if self.zoom_factor != 1:
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)

        if not self.training:
            return x

        # The auxiliary loss is only used during training.
        aux = self.aux(x_tmp)
        if self.zoom_factor != 1:
            aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)

        main_loss = self.criterion(x, y)
        aux_loss = self.criterion(aux, y)
        return x.max(1)[1], main_loss, aux_loss

In [10]:
# Smoke test: run a random batch of four 473x473 RGB images through PSPNet.
layers = 50
inp = torch.rand(4, 3, 473, 473).to(device)

pspnet = PSPNet(layers=layers, bins=(1, 2, 3, 6), dropout=0.1, classes=2, zoom_factor=8, pretrained=False)
pspnet = pspnet.to(device)

# Evaluation mode makes forward() return only the logits (no target needed).
pspnet.eval()
output = pspnet(inp)
print(f"PSPNet with Dilated ResNet {layers}'s output size : {output.size()}")

PSPNet with Dilated ResNet 50's output size : torch.Size([4, 2, 473, 473])
