In [120]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# PSPNet

## Feature Module
### FeatureMap_convolution
- input size : $3\times 475\times 475(c\times h\times w)$
- output size : $128\times 119\times 119$


In [121]:
class conv2DBatchNormRelu(nn.Module):
    """2-D convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, bias:
            forwarded unchanged to ``nn.Conv2d``; ``out_channels`` also sets the
            number of features of the batch-norm layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, bias):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.batchnorm = nn.BatchNorm2d(num_features=out_channels)
        # inplace=True: the activation overwrites the batch-norm output tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(batchnorm(conv(x)))``."""
        normalised = self.batchnorm(self.conv(x))
        return self.relu(normalised)

In [122]:
class FeatureMap_convolution(nn.Module):
    """Input stem: three conv-BN-ReLU layers followed by a max pool.

    The accompanying notebook text gives the shapes as 3x475x475 (c x h x w)
    in and 128x119x119 out.
    """

    def __init__(self):
        super().__init__()
        # Positional order of conv2DBatchNormRelu:
        # (in_channels, out_channels, kernel_size, stride, padding, dilation, bias)
        self.cbnr_1 = conv2DBatchNormRelu(3, 64, 3, 2, 1, 1, False)    # stride 2 halves h and w
        self.cbnr_2 = conv2DBatchNormRelu(64, 64, 3, 1, 1, 1, False)
        self.cbnr_3 = conv2DBatchNormRelu(64, 128, 3, 1, 1, 1, False)
        # The stride-2 pooling halves the spatial size a second time.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Apply the three conv blocks in order, then the max pool."""
        stem = self.cbnr_3(self.cbnr_2(self.cbnr_1(x)))
        return self.maxpool(stem)


### 2 ResidualBlockPSP
> skip connection(= shortcut connection = bypass): degradation 막기 위해 residual block
>    - bottleNeckPSP: bypass에 conv 적용
>    - bottleNeckIdentifyPSP: bypass에 conv 적용하지 않음
>
> dilation convolution: conv filter에 간격을 두어 적용
>    - stride: filter 간 간격 (filter를 적용하는 위치 사이의 간격)
>    - dilation convolution: filter 내 간격

In [123]:
class conv2DBatchNorm(nn.Module):
    """2-D convolution followed by batch normalisation (no activation).

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, bias:
            forwarded unchanged to ``nn.Conv2d``; ``out_channels`` also sets the
            number of features of the batch-norm layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, bias):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.batchnorm = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        """Return ``batchnorm(conv(x))``."""
        return self.batchnorm(self.conv(x))

In [124]:
class bottleNeckPSP(nn.Module):
    """Bottleneck residual block whose skip path is a 1x1 conv + batch norm.

    Main path: 1x1 conv-BN-ReLU (``in_channels`` -> ``mid_channels``),
    3x3 conv-BN-ReLU carrying ``stride`` and ``dilation``, then a 1x1 conv-BN
    (``mid_channels`` -> ``out_channels``).  The skip path projects the input
    to ``out_channels`` with the same ``stride`` so both paths can be added.

    Args:
        in_channels: channels of the incoming tensor.
        mid_channels: channels used inside the bottleneck.
        out_channels: channels of the returned tensor.
        stride: stride of the 3x3 conv and of the skip-path 1x1 conv.
        dilation: dilation of the 3x3 conv; it is also used as its padding.
    """

    def __init__(self,in_channels,mid_channels,out_channels,stride,dilation):
        super(bottleNeckPSP,self).__init__()
        self.cbr_1=conv2DBatchNormRelu(in_channels,mid_channels,1,1,0,1,False)
        self.cbr_2=conv2DBatchNormRelu(mid_channels,mid_channels,3,stride,dilation,dilation,False)
        # No ReLU on the last 1x1 conv: the activation is applied once, after
        # the residual addition in forward(), exactly as bottleNeckIdentifyPSP
        # does.  The attribute keeps the name ``cbr_3`` so state_dict keys
        # (``cbr_3.conv.*`` / ``cbr_3.batchnorm.*``) are unchanged.
        self.cbr_3=conv2DBatchNorm(mid_channels,out_channels,1,1,0,1,False)

        # Skip path: 1x1 conv + BN matching the main path's channels and stride.
        self.cb_residual=conv2DBatchNorm(in_channels,out_channels,1,stride,0,1,False)
        self.relu=nn.ReLU(inplace=True)

    def forward(self,x):
        """Return ``relu(main_path(x) + skip_path(x))``."""
        conv=self.cbr_3(self.cbr_2(self.cbr_1(x)))
        residual=self.cb_residual(x)
        return self.relu(conv+residual)

In [125]:
class bottleNeckIdentifyPSP(nn.Module):
    """Bottleneck residual block with an identity skip connection.

    Main path: 1x1 conv-BN-ReLU (``in_channels`` -> ``mid_channels``),
    3x3 conv-BN-ReLU with ``dilation`` (also used as padding, stride 1), then
    a 1x1 conv-BN back to ``in_channels``.  The unmodified input is added to
    the main-path result before the final ReLU.
    """

    def __init__(self, in_channels, mid_channels, dilation):
        super().__init__()
        self.cbr_1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, 1, 0, 1, False)
        self.cbr_2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3, 1, dilation, dilation, False)
        self.cbr_3 = conv2DBatchNorm(mid_channels, in_channels, 1, 1, 0, 1, False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(main_path(x) + x)``."""
        squeezed = self.cbr_1(x)
        spatial = self.cbr_2(squeezed)
        expanded = self.cbr_3(spatial)
        # The input itself is the residual; no projection is applied to it.
        return self.relu(expanded + x)

In [126]:
class ResidualBlockPSP(nn.Sequential):
    """Stack of ``n_blocks`` bottleneck blocks registered as block1..blockN.

    ``block1`` is a ``bottleNeckPSP`` (it changes the channel count and
    carries ``stride``); every following block is a ``bottleNeckIdentifyPSP``
    operating on ``out_channels``.  ``forward`` is not defined here because
    inheriting from ``nn.Sequential`` already provides it.
    """

    def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride, dilation):
        super().__init__()
        first = bottleNeckPSP(in_channels, mid_channels, out_channels, stride, dilation)
        self.add_module("block1", first)
        # Remaining blocks are numbered from 2 up to n_blocks inclusive.
        for block_no in range(2, n_blocks + 1):
            identity_block = bottleNeckIdentifyPSP(out_channels, mid_channels, dilation)
            self.add_module("block" + str(block_no), identity_block)

## Pyramid Pooling Module

In [127]:
class PyramidPooling(nn.Module):
    """Pyramid pooling: pool the input at several grid sizes and concatenate.

    For every entry of ``pool_sizes`` the input is adaptively average-pooled
    to that output size, reduced to ``in_channels // len(pool_sizes)``
    channels by a 1x1 conv-BN-ReLU, and bilinearly resized to
    ``(height, width)``.  The input and all branch outputs are concatenated
    along the channel dimension (dim 1).

    Args:
        in_channels: channel count of the incoming feature map.
        pool_sizes: sequence of output sizes passed to ``nn.AdaptiveAvgPool2d``.
        height, width: spatial size each branch is resized to.  ``torch.cat``
            requires this to match the spatial size of the input ``x``.
    """

    def __init__(self,in_channels,pool_sizes,height,width):
        super(PyramidPooling,self).__init__()
        self.height=height
        self.width=width

        out_channels=in_channels//len(pool_sizes)
        # nn.ModuleList instead of a plain Python list: layers kept in a plain
        # list are not registered as sub-modules, so their weights would be
        # missing from parameters()/state_dict() and ignored by .to(),
        # .cuda(), .train() and .eval().
        self.avpool=nn.ModuleList()
        self.cbr=nn.ModuleList()
        for pool_size in pool_sizes:
            self.avpool.append(nn.AdaptiveAvgPool2d(output_size=pool_size))
            self.cbr.append(conv2DBatchNormRelu(in_channels,out_channels,1,1,0,1,False))

    def forward(self,x):
        """Return ``x`` concatenated with every resized pooling branch."""
        outList=[x]
        for pool,cbr in zip(self.avpool,self.cbr):
            out=cbr(pool(x))
            # Bring the pooled grid back to the feature-map size before concatenating.
            out=F.interpolate(out,size=(self.height,self.width),mode="bilinear",align_corners=True)
            outList.append(out)
        output=torch.cat(outList,dim=1)
        return output

## Decoder,AuxLoss Module
- decode tensor information
- pixel 별로 class 분류
- upsampling ($475\times 475$)

### DecodePSPFeature Module

In [128]:
class DecodePSPFeature(nn.Module):
    """Main decoder head: per-pixel class scores resized to ``(height, width)``.

    A 3x3 conv-BN-ReLU reduces the 4096-channel input to 512 channels, 2-D
    dropout (p=0.1) is applied, a 1x1 convolution produces ``n_classes``
    channels, and the result is bilinearly resized to ``(height, width)``.
    """

    def __init__(self, height, width, n_classes):
        super().__init__()
        self.height = height
        self.width = width

        self.cbr = conv2DBatchNormRelu(4096, 512, 3, 1, 1, 1, False)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(512, n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the class-score map resized to ``(self.height, self.width)``."""
        scores = self.classification(self.dropout(self.cbr(x)))
        target_size = (self.height, self.width)
        return F.interpolate(scores, size=target_size, mode="bilinear", align_corners=True)

### AuxiliaryPSPLayers Module

In [129]:
class AuxiliaryPSPLayers(nn.Module):
    """Auxiliary head: per-pixel class scores resized to ``(height, width)``.

    A 3x3 conv-BN-ReLU maps ``in_channels`` to 256 channels, 2-D dropout
    (p=0.1) is applied, a 1x1 convolution produces ``n_classes`` channels, and
    the result is bilinearly resized to ``(height, width)``.
    """

    def __init__(self, in_channels, height, width, n_classes):
        super().__init__()
        self.height = height
        self.width = width
        self.cbr = conv2DBatchNormRelu(in_channels, 256, 3, 1, 1, 1, False)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(256, n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the class-score map resized to ``(self.height, self.width)``."""
        scores = self.classification(self.dropout(self.cbr(x)))
        target_size = (self.height, self.width)
        return F.interpolate(scores, size=target_size, mode="bilinear", align_corners=True)

## PSPNet Class

In [130]:
class PSPNet(nn.Module):
    """PSPNet assembled from the feature, pyramid-pooling, decoder and aux modules.

    Args:
        n_classes: number of output channels of both classification heads.

    ``forward`` returns the tuple ``(output, output_aux)``; both heads resize
    their result to 475x475.
    """

    def __init__(self, n_classes):
        super().__init__()
        blocks_per_stage = (3, 4, 6, 3)
        input_side = 475
        # Side length handed to the pyramid pooling module
        # (the original note describes it as img_size/8).
        feature_side = 60

        # Feature extractor: stem followed by four residual stages; the last
        # two stages use dilation (2, then 4) with stride 1.
        self.feature_conv = FeatureMap_convolution()
        self.feature_res_1 = ResidualBlockPSP(blocks_per_stage[0], 128, 64, 256, 1, 1)
        self.feature_res_2 = ResidualBlockPSP(blocks_per_stage[1], 256, 128, 512, 2, 1)
        self.feature_dilated_res_1 = ResidualBlockPSP(blocks_per_stage[2], 512, 256, 1024, 1, 2)
        self.feature_dilated_res_2 = ResidualBlockPSP(blocks_per_stage[3], 1024, 512, 2048, 1, 4)

        self.pyramid_pooling = PyramidPooling(2048, [6, 3, 2, 1], feature_side, feature_side)
        self.decode_feature = DecodePSPFeature(input_side, input_side, n_classes)
        self.aux = AuxiliaryPSPLayers(1024, input_side, input_side, n_classes)

    def forward(self, x):
        """Run the network and return ``(output, output_aux)``."""
        features = self.feature_conv(x)
        features = self.feature_res_1(features)
        features = self.feature_res_2(features)
        features = self.feature_dilated_res_1(features)

        # The auxiliary head reads the 1024-channel features of the third stage.
        output_aux = self.aux(features)

        features = self.feature_dilated_res_2(features)
        pooled = self.pyramid_pooling(features)
        output = self.decode_feature(pooled)

        return (output, output_aux)



In [131]:
# Build the network with n_classes=21 output channels on both heads.
net=PSPNet(n_classes=21)
# Bare expression as the last line of the cell: the notebook displays the module tree.
net

PSPNet(
  (feature_conv): FeatureMap_convolution(
    (cbnr_1): conv2DBatchNormRelu(
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (cbnr_2): conv2DBatchNormRelu(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (cbnr_3): conv2DBatchNormRelu(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (feature_res_1): ResidualBlockPSP(
    (block1): bottleNec

In [132]:
'''check'''
# Smoke test: one forward pass on a random batch shaped like the 3x475x475 input.
batch_size=2
dummy_img=torch.rand(batch_size,3,475,475)
# PSPNet.forward returns the tuple (output, output_aux).
outputs=net(dummy_img)
# NOTE(review): this prints both full tensors; printing [o.shape for o in outputs]
# would be a more compact check.
print(outputs)

(tensor([[[[-0.0914, -0.0021,  0.0873,  ..., -0.1116, -0.1086, -0.1055],
          [-0.0872, -0.0076,  0.0719,  ..., -0.0750, -0.0671, -0.0593],
          [-0.0829, -0.0132,  0.0565,  ..., -0.0383, -0.0257, -0.0130],
          ...,
          [ 0.3204,  0.3839,  0.4474,  ...,  0.2155,  0.2086,  0.2018],
          [ 0.3114,  0.3884,  0.4655,  ...,  0.2434,  0.2361,  0.2289],
          [ 0.3023,  0.3929,  0.4835,  ...,  0.2712,  0.2637,  0.2561]],

         [[ 0.4771,  0.4641,  0.4511,  ...,  0.9820,  0.9808,  0.9796],
          [ 0.5261,  0.5214,  0.5167,  ...,  0.9641,  0.9743,  0.9846],
          [ 0.5751,  0.5788,  0.5824,  ...,  0.9461,  0.9679,  0.9896],
          ...,
          [ 1.1746,  1.0974,  1.0202,  ...,  1.1683,  1.1517,  1.1351],
          [ 1.1943,  1.1152,  1.0362,  ...,  1.1776,  1.1525,  1.1274],
          [ 1.2140,  1.1330,  1.0521,  ...,  1.1870,  1.1533,  1.1196]],

         [[-0.0500, -0.1125, -0.1749,  ..., -0.3345, -0.3112, -0.2879],
          [-0.0606, -0.1170, 