## 1. Backbone Network

In [1]:
# ResNet
import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet, Bottleneck

from collections import OrderedDict

# DeepLab
from torchvision.models._utils import IntermediateLayerGetter
from torch.hub import load_state_dict_from_url
from torchvision.models.segmentation._utils import _SimpleSegmentationModel
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation.fcn import FCNHead
from torchvision.models import resnet50, resnet101

In [2]:
class ResNetXX3(ResNet):
    """torchvision ``ResNet`` whose stem convolution is swapped for a 3x3, stride-1 conv.

    Every constructor argument is forwarded unchanged to ``ResNet``. After the parent
    has built the network, ``conv1`` is replaced by
    ``Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)``, initialised with
    Kaiming-normal (``mode='fan_out'``, ``nonlinearity='relu'``).
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__(
            block,
            layers,
            num_classes=num_classes,
            zero_init_residual=zero_init_residual,
            groups=groups,
            width_per_group=width_per_group,
            replace_stride_with_dilation=replace_stride_with_dilation,
            norm_layer=norm_layer,
        )
        # Build and initialise the replacement stem, then install it as conv1.
        stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        nn.init.kaiming_normal_(stem.weight, mode='fan_out', nonlinearity='relu')
        self.conv1 = stem


def _resnetxx3(arch, layers, pretrained, **kwargs):
    """Build a Bottleneck ``ResNetXX3``; warn when pretrained weights are requested.

    No pretrained checkpoint exists for the 3x3-stem variant, so ``pretrained=True``
    cannot be honoured. Warn instead of silently returning random weights.
    """
    if pretrained:
        import warnings
        warnings.warn(
            f"{arch}: no pretrained weights are available for ResNetXX3; "
            "returning a randomly initialized model.",
            stacklevel=3,  # point at the caller of resnet53/resnet103
        )
    return ResNetXX3(Bottleneck, layers, **kwargs)


def resnet53(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50-depth model ([3, 4, 6, 3] Bottleneck blocks) with a 3x3, stride-1 stem.

    Based on `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): accepted for signature compatibility with torchvision's
            ``resnet50``. No weights are loaded; a ``UserWarning`` is emitted when True.
        progress (bool): accepted for signature compatibility; unused.
        **kwargs: forwarded to ``ResNetXX3`` (e.g. ``replace_stride_with_dilation``).

    Returns:
        A randomly initialised ``ResNetXX3``.
    """
    return _resnetxx3("resnet53", [3, 4, 6, 3], pretrained, **kwargs)


def resnet103(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101-depth model ([3, 4, 23, 3] Bottleneck blocks) with a 3x3, stride-1 stem.

    Based on `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): accepted for signature compatibility with torchvision's
            ``resnet101``. No weights are loaded; a ``UserWarning`` is emitted when True.
        progress (bool): accepted for signature compatibility; unused.
        **kwargs: forwarded to ``ResNetXX3`` (e.g. ``replace_stride_with_dilation``).

    Returns:
        A randomly initialised ``ResNetXX3``.
    """
    return _resnetxx3("resnet103", [3, 4, 23, 3], pretrained, **kwargs)

In [12]:
class SmallDeepLab(_SimpleSegmentationModel):
    """DeepLab wrapper that returns every backbone output plus the coarse prediction.

    The classifier runs on the backbone's ``"out"`` entry. Its result is stored in the
    same dict under ``"coarse"`` at the backbone's resolution; no upsampling to the
    input size is performed.
    """

    def forward(self, input_):
        features = self.backbone(input_)
        coarse_logits = self.classifier(features["out"])
        features["coarse"] = coarse_logits
        return features

def deeplabv3(pretrained=False, resnet="res103", head_in_ch=2048, num_classes=21):
    """Build a DeepLabV3-style net that also exposes intermediate layer2 features.

    Args:
        pretrained (bool): forwarded to the backbone constructor. For "res50"/"res101"
            this requests torchvision's ImageNet weights (a download). The custom
            "res53"/"res103" backbones accept the flag but load no weights.
        resnet (str): backbone key, one of "res53", "res103", "res50", "res101".
        head_in_ch (int): channel count of the layer4 output fed to ``DeepLabHead``.
        num_classes (int): number of classes predicted by the coarse head.

    Returns:
        SmallDeepLab whose forward returns a dict with "res2" (layer2 features),
        "out" (layer4 features) and "coarse" (coarse class logits).

    Raises:
        KeyError: if ``resnet`` is not a supported backbone key.
    """
    backbones = {
        "res53": resnet53,
        "res103": resnet103,
        "res50": resnet50,
        "res101": resnet101,
    }
    if resnet not in backbones:
        raise KeyError(f"Unknown backbone {resnet!r}; expected one of {sorted(backbones)}")
    resnet_fn = backbones[resnet]

    # Dilate layer3/layer4 instead of striding them, so layer4 keeps layer2's resolution.
    # Previously pretrained=True was hard-coded here and the argument was ignored.
    backbone_net = resnet_fn(pretrained=pretrained,
                             replace_stride_with_dilation=[False, True, True])

    return SmallDeepLab(
        backbone=IntermediateLayerGetter(
            backbone_net,
            return_layers={'layer2': 'res2', 'layer4': 'out'},
        ),
        # DeepLabHead consumes the layer4 ("out") features and predicts the coarse mask.
        classifier=DeepLabHead(head_in_ch, num_classes),
    )

In [13]:
# Use the GPU when available so the notebook also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible weights and demo input

net = deeplabv3(False).to(device)
x = torch.randn(3, 3, 256, 256).to(device)
result = net(x)

In [14]:
# Pull the layer2 features and the coarse prediction out of the backbone's output dict.
res2, out = result['res2'], result['coarse']

print('layer2 :', res2.shape)  # intermediate (fine-grained) features
print('coarse :', out.shape)   # coarse mask logits

layer2 : torch.Size([3, 512, 64, 64])
coarse : torch.Size([3, 21, 64, 64])


## 2. Point Selection



Train : feature map에서 학습할 N개의 point 샘플링

Inference : N개의 가장 **불확실한** (이웃 값과 값이 달라질 가능성이 높은, binary mask의 확률이 0.5에 가까운) point 샘플링

<br/>

**Sampling Strategy (Training)**

![image](https://user-images.githubusercontent.com/44194558/158567732-a03c3005-cf54-4499-bdc5-b1244ed60de1.png)

참고 : https://doooob.tistory.com/79

In [4]:
import torch
import torch.nn.functional as F

In [5]:
def point_sample(input, point_coords, **kwargs):
    """Bilinearly sample ``input`` at normalized point coordinates.

    Wraps ``torch.nn.functional.grid_sample`` so that coordinates are given in
    [0, 1] x [0, 1] instead of grid_sample's [-1, 1] x [-1, 1].
    See https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

    Args:
        input (Tensor): (N, C, H, W) feature map, e.g. the coarse mask (3, 21, 64, 64).
        point_coords (Tensor): (N, P, 2) normalized (x, y) coordinates of the selected
            points. A 4-D grid is also accepted and passed through as-is.
        **kwargs: forwarded to ``F.grid_sample`` (e.g. ``align_corners``).

    Returns:
        Tensor: (N, C, P) point-wise features for 3-D ``point_coords``. For a 4-D grid,
        the raw ``grid_sample`` output is returned.
    """
    is_flat = point_coords.dim() == 3
    # grid_sample wants a 4-D grid: (N, P, 2) -> (N, P, 1, 2).
    grid = point_coords.unsqueeze(2) if is_flat else point_coords

    # Map [0, 1] -> [-1, 1], the coordinate range grid_sample expects.
    sampled = F.grid_sample(input, grid * 2.0 - 1.0, **kwargs)  # (N, C, P, 1) when flat

    return sampled.squeeze(3) if is_flat else sampled

In [6]:
@torch.no_grad()
def sampling_points(mask, N, k=3, beta=0.75, training=True):
    """Select point locations on a coarse mask for PointRend.

    Training ("sampling strategy" of the PointRend paper):
      1. Over-generation: draw ``k * N`` candidate points uniformly in [0, 1]^2.
      2. Importance sampling: keep the ``int(beta * N)`` most uncertain candidates.
         Uncertainty is the negated gap between the two largest channel values.
      3. Coverage: add ``N - int(beta * N)`` further uniformly random points.

    Inference: return the ``min(N, H * W)`` most uncertain pixel centres of ``mask``.

    Args:
        mask (Tensor): (B, C, H, W) coarse mask scores. C >= 2 is required because
            uncertainty compares the top two channels.
        N (int): number of points to return.
        k (int): over-generation multiplier (k > 1); training only.
        beta (float): fraction of importance-sampled points; training only.
        training (bool): selects the training or the inference strategy.

    Returns:
        training: Tensor (B, N, 2) of normalized (x, y) coordinates in [0, 1].
        inference: tuple ``(idx, points)``. ``idx`` is a (B, N') long tensor of indices
            into the flattened H*W map, and ``points`` is (B, N', 2) normalized (x, y)
            pixel-centre coordinates, with N' = min(N, H * W).
    """
    assert mask.dim() == 4,"Dim must be N(Batch)CHW"
    device = mask.device
    B, _, H, W = mask.shape
    # Sort channels descending so [:, 0] / [:, 1] hold the largest / second-largest scores.
    mask, _ = mask.sort(1, descending=True)

    if not training:
        H_step, W_step = 1 / H, 1 / W
        N = min(H * W, N)
        # Small top-2 gap means high uncertainty; negate so topk picks the smallest gaps.
        uncertainty_map = -1 * (mask[:, 0] - mask[:, 1])
        _, idx = uncertainty_map.view(B, -1).topk(N, dim=1)

        # Convert flat indices into the normalized centre of the corresponding pixel.
        points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
        points[:, :, 0] = W_step / 2.0 + (idx % W).to(torch.float) * W_step
        points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step
        return idx, points

    num_importance = int(beta * N)

    # 1. Over-generation: k*N uniform candidates in [0, 1]^2. This must be torch.rand,
    #    not randn: normal samples land outside the image, where grid_sample reads zero
    #    padding. That yields a 0 top-2 gap, so importance sampling would prefer them.
    over_generation = torch.rand(B, k * N, 2, device=device)
    over_generation_map = point_sample(mask, over_generation, align_corners=False)  # (B, C, k*N)

    # 2. Importance sampling: keep the candidates with the smallest top-2 gap.
    uncertainty_map = -1 * (over_generation_map[:, 0] - over_generation_map[:, 1])  # (B, k*N)
    _, idx = uncertainty_map.topk(num_importance, -1)  # (B, num_importance)
    importance = torch.gather(
        over_generation, 1, idx.unsqueeze(-1).expand(-1, -1, 2)
    )  # (B, num_importance, 2)

    # 3. Coverage: fill the remaining N - num_importance points uniformly at random.
    coverage = torch.rand(B, N - num_importance, 2, device=device)

    return torch.cat([importance, coverage], 1)  # (B, N, 2)

## 3. PointRend

### 3.1 PointHead

`Given the point-wise feature representation at each selected point, PointRend makes point-wise segmentation predictions using a simple MLP.`

<br/>

2. Point Selection에서 정의한 `sampling_points`, `point_sample` 함수를 활용

In [7]:
class PointHead(nn.Module):
    """Point-wise refinement head of PointRend.

    A 1x1 ``Conv1d`` acts as a per-point MLP. Its input is the concatenation of
    coarse-mask features and fine-grained backbone features sampled at selected points.

    Args:
        in_c (int): channels of the concatenated point feature (coarse classes + fine
            channels; 21 + 512 = 533 for the demo backbone).
        num_classes (int): number of classes predicted per point.
        k (int): over-generation multiplier passed to ``sampling_points`` in training.
        beta (float): fraction of importance-sampled points in training.
        inference_points (int): maximum number of points re-predicted at each
            upsampling step during inference.
    """

    def __init__(self, in_c=533, num_classes=21, k=3, beta=0.75, inference_points=8096):
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta
        self.inference_points = inference_points

    def forward(self, x, res2, out):
        """Predict logits at sampled points (training) or refine the mask (eval).

        Args:
            x: (B, C, H, W) network input; only its spatial size is used.
            res2: fine-grained features (the backbone's layer2 output).
            out: coarse mask logits (coarse prediction).

        Returns:
            training: {"rend": (B, num_classes, P) point logits,
                       "points": (B, P, 2) normalized point coordinates}.
            eval: the dict returned by ``inference``.
        """
        if not self.training:
            return self.inference(x, res2, out)

        # NOTE(review): this is width // 16 (16 points for a 256-wide input), not the
        # (H/16)*(W/16) points of a full stride-16 map quoted in sampling_points'
        # docstring. Confirm intent.
        num_points = x.shape[-1] // 16
        points = sampling_points(out, num_points, self.k, self.beta)  # (B, P, 2)

        # Point-wise features from the coarse prediction and from the fine-grained map.
        coarse = point_sample(out, points, align_corners=False)  # (B, num_classes, P)
        fine = point_sample(res2, points, align_corners=False)    # (B, C_res2, P)

        feature_representation = torch.cat([coarse, fine], dim=1)  # (B, in_c, P)
        rend = self.mlp(feature_representation)  # (B, num_classes, P)

        return {"rend": rend,
                "points": points}

    @torch.no_grad()
    def inference(self, x, res2, out):
        """Iteratively upsample the coarse mask and re-predict its most uncertain points.

        At each step the mask is upsampled 2x (bilinear). The most uncertain locations
        (at most ``self.inference_points``) are re-predicted by the MLP, and the results
        are written back in place of the interpolated values.

        Returns:
            {"fine": refined logits (B, num_classes, H', W')}. H' and W' are the coarse
            size doubled until the width is at least ``x``'s width.
        """
        # '<' rather than '!=': a width that is not coarse_width * 2**n would otherwise
        # overshoot and loop forever.
        while out.shape[-1] < x.shape[-1]:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)

            points_idx, points = sampling_points(out, self.inference_points, training=False)

            coarse = point_sample(out, points, align_corners=False)
            fine = point_sample(res2, points, align_corners=False)
            rend = self.mlp(torch.cat([coarse, fine], dim=1))  # (B, C, P)

            # Scatter the refined logits back to the flat positions they were sampled at.
            B, C, H, W = out.shape
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            out = out.reshape(B, C, -1).scatter(2, points_idx, rend).view(B, C, H, W)

        return {"fine": out}

### 3.2 Network

In [8]:
class PointRend(nn.Module):
    """Segmentation backbone combined with a point-based refinement head.

    The backbone must return a dict that contains at least ``"res2"`` (fine-grained
    features) and ``"coarse"`` (coarse mask logits). The head is called as
    ``head(x, res2, coarse)``, and its output dict is merged into the backbone's dict,
    which is then returned.
    """

    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head  # e.g. PointHead

    def forward(self, x):
        features = self.backbone(x)
        head_outputs = self.head(x, features["res2"], features["coarse"])
        features.update(head_outputs)
        return features

In [9]:
# Use the GPU when available so the notebook also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible demo input and weights

x = torch.randn(3, 3, 256, 256).to(device)
net = PointRend(deeplabv3(False), PointHead()).to(device)
out = net(x)

In [11]:
# Show the shape of every tensor in the PointRend output dict.
for name, tensor in out.items():
    print(name, ':', tensor.shape)

res2 : torch.Size([3, 512, 64, 64])
out : torch.Size([3, 2048, 64, 64])
coarse : torch.Size([3, 21, 64, 64])
rend : torch.Size([3, 21, 16])
points : torch.Size([3, 16, 2])
