In [3]:
from collections import OrderedDict

import torch.nn.functional as F
from torch import nn
from torch.utils.model_zoo import load_url
from torchvision import models
from torchvision.ops import misc

In [4]:
!git clone https://github.com/Okery/PyTorch-Simple-MaskRCNN.git mrcnn

Cloning into 'mrcnn'...
remote: Enumerating objects: 840, done.[K
remote: Counting objects: 100% (153/153), done.[K
remote: Compressing objects: 100% (152/152), done.[K
remote: Total 840 (delta 88), reused 2 (delta 0), pack-reused 687[K
Receiving objects: 100% (840/840), 4.78 MiB | 22.44 MiB/s, done.
Resolving deltas: 100% (485/485), done.


In [10]:
# Modules
from mrcnn.pytorch_mask_rcnn.model.utils import AnchorGenerator
from mrcnn.pytorch_mask_rcnn.model.rpn import RPNHead, RegionProposalNetwork
from mrcnn.pytorch_mask_rcnn.model.pooler import RoIAlign
from mrcnn.pytorch_mask_rcnn.model.roi_heads import RoIHeads
from mrcnn.pytorch_mask_rcnn.model.transform import Transformer

In [None]:
# ResNet Backbone
class ResBackbone(nn.Module):
    """Single-level (non-FPN) ResNet feature extractor for Mask R-CNN.

    Runs the torchvision ResNet stem and ``layer1``-``layer4`` (the
    classifier's ``avgpool`` and ``fc`` are dropped), then projects the final
    feature map to ``out_channels`` with a 1x1 conv followed by a 3x3 conv.

    Args:
        backbone_name (str): name of a torchvision ResNet constructor,
            e.g. ``'resnet50'``.
        pretrained (bool): forwarded to the torchvision constructor to load
            its ImageNet weights.

    Attributes:
        out_channels (int): channel count of the returned feature map (256).
    """

    def __init__(self, backbone_name, pretrained):
        super().__init__()
        body = models.resnet.__dict__[backbone_name](
            pretrained=pretrained, norm_layer=misc.FrozenBatchNorm2d)

        # Freeze every parameter outside layer2-4 (i.e. the stem and layer1);
        # batch-norm layers are FrozenBatchNorm2d and are frozen regardless.
        trainable_stages = ('layer2', 'layer3', 'layer4')
        for name, parameter in body.named_parameters():
            if not any(stage in name for stage in trainable_stages):
                parameter.requires_grad_(False)

        # Keep the first 8 children (conv1, bn1, relu, maxpool, layer1-4),
        # dropping the classifier's avgpool and fc.
        self.body = nn.ModuleDict(d for i, d in enumerate(body.named_children()) if i < 8)
        # Channel count of layer4's output. It is 2048 for resnet50/101/152 but
        # 512 for resnet18/34, so read it from the classifier's input width
        # instead of hard-coding it.
        in_channels = body.fc.in_features
        self.out_channels = 256

        self.inner_block_module = nn.Conv2d(in_channels, self.out_channels, 1)
        self.layer_block_module = nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1)

        # Only the two projection convs are direct Conv2d children (self.body
        # is a ModuleDict), so the pretrained ResNet weights are untouched.
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return the projected layer4 feature map with ``out_channels`` channels."""
        for module in self.body.values():
            x = module(x)
        x = self.inner_block_module(x)
        x = self.layer_block_module(x)
        return x

In [None]:
class MaskRCNN(nn.Module):
    """Mask R-CNN built on a single backbone feature map (no FPN).

    Pipeline: ``Transformer`` pre-processing -> ``backbone`` -> RPN
    (``RegionProposalNetwork``) -> ``RoIHeads`` (box branch plus the mask
    branch attached as ``mask_roi_pool`` / ``mask_predictor``). In eval mode
    the head's result is passed through ``Transformer.postprocess``.

    Args:
        backbone (nn.Module): feature extractor exposing an ``out_channels``
            attribute (e.g. ``ResBackbone``).
        num_classes (int): number of classes, including the background.
        rpn_*: RPN settings forwarded to ``RegionProposalNetwork``; each
            ``*_train`` / ``*_test`` pair is packed into a dict keyed
            ``training`` / ``testing``.
        box_*: RoI-head settings forwarded to ``RoIHeads``.

    NOTE(review): ``FastRCNNPredictor`` and ``MaskRCNNPredictor`` are not
    defined or imported in the cells shown here; confirm they are defined in
    another cell before this one runs.

    NOTE(review): ``maskrcnn_resnet50`` copies pretrained weights by position
    in ``state_dict()``, so the order in which submodules are assigned below
    must not change.
    """
    def __init__(self, backbone, num_classes,
                 # RPN parameters
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_num_samples=256, rpn_positive_fraction=0.5,
                 rpn_reg_weights=(1., 1., 1., 1.),
                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
                 rpn_nms_thresh=0.7,
                 # RoIHeads parameters
                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
                 box_num_samples=512, box_positive_fraction=0.25,
                 box_reg_weights=(10., 10., 5., 5.),
                 box_score_thresh=0.1, box_nms_thresh=0.6, box_num_detections=100):

        super().__init__()
        self.backbone = backbone
        out_channels = backbone.out_channels
        
        #------------- RPN --------------------------
        # 1. Anchor Generator: 3 sizes x 3 aspect ratios = 9 anchors per location
        anchor_sizes = (128, 256, 512)
        anchor_ratios = (0.5, 1, 2)
        num_anchors = len(anchor_sizes) * len(anchor_ratios)
        rpn_anchor_generator = AnchorGenerator(anchor_sizes, anchor_ratios)
        
        # 2. RPN Head
        rpn_head = RPNHead(out_channels, num_anchors)
        
        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)  # number of proposals to keep before NMS
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)  # number of proposals to keep after NMS

        # 3. Region Proposal Network - produces the region proposal (RoI) boxes
        self.rpn = RegionProposalNetwork(
             rpn_anchor_generator, rpn_head,  # anchor generator, RPN head (objectness logits, box regression)
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,  # IoU thresholds (fg: positive / bg: negative)
             rpn_num_samples, rpn_positive_fraction,  # number of sampled anchors, positive fraction
             rpn_reg_weights,
             rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
        
        #------------ RoIHeads --------------------------
        box_roi_pool = RoIAlign(output_size=(7, 7), sampling_ratio=2)  # align RoIs to a fixed 7x7 size (detection)
        
        # 1. Box predictor: input is the flattened 7x7 pooled feature
        resolution = box_roi_pool.output_size[0]
        in_channels = out_channels * resolution ** 2
        mid_channels = 1024
        box_predictor = FastRCNNPredictor(in_channels, mid_channels, num_classes)  # detection (class scores, boxes)
        
        # 2. RoI heads (box branch here; mask branch attached below)
        self.head = RoIHeads(
             box_roi_pool, box_predictor,  # pool features for the RPN proposals, then predict
             box_fg_iou_thresh, box_bg_iou_thresh,  # IoU thresholds (fg: positive / bg: negative)
             box_num_samples, box_positive_fraction,  # number of sampled proposals, positive fraction
             box_reg_weights,
             box_score_thresh, box_nms_thresh, box_num_detections)  # at inference, proposals with score <= box_score_thresh are presumably discarded
        
        self.head.mask_roi_pool = RoIAlign(output_size=(14, 14), sampling_ratio=2)  # align RoIs to a fixed 14x14 size (mask)
        
        # 3. Mask predictor
        layers = (256, 256, 256, 256)
        dim_reduced = 256
        # args: backbone out_channels, conv layer widths (4 x 256), dim_reduced=256, num_classes (e.g. 91 for COCO)
        self.head.mask_predictor = MaskRCNNPredictor(out_channels, layers, dim_reduced, num_classes) 
        
        #------------ Transformer --------------------------
        # Resizing limits and ImageNet normalisation statistics for pre-processing
        self.transformer = Transformer(
            min_size=800, max_size=1333, 
            image_mean=[0.485, 0.456, 0.406], 
            image_std=[0.229, 0.224, 0.225])
        
    def forward(self, image, target=None):
        """Run Mask R-CNN on an image.

        Args:
            image: image tensor; only its last two dims (H, W) are read here.
                Presumably a single (C, H, W) image — TODO confirm against
                ``Transformer``.
            target: ground-truth annotations, passed through the transformer,
                RPN and RoI heads (presumably required in training mode).

        Returns:
            In training mode, a dict merging the RPN and RoI-head losses.
            In eval mode, the RoI-head result after ``Transformer.postprocess``,
            which receives both the resized and the original image shape
            (presumably to map predictions back to the original size).
        """
        ori_image_shape = image.shape[-2:]
        
        # image pre-processing (resize + normalise)
        image, target = self.transformer(image, target)
        image_shape = image.shape[-2:]
        
        # 1. Backbone -> feature map
        feature = self.backbone(image)  
        
        # 2. RPN -> proposals
        # 2.1 the RPN head scores objectness and regresses boxes on the feature map
        # 2.2 NMS over the top-n highest-objectness boxes keeps the best proposals
        # 2.3 in training, anchors are sampled at the configured ratio and compared to GT for the loss
        proposal, rpn_losses = self.rpn(feature, image_shape, target)

        # 3. RoI head
        # 3.1 RoI Align extracts the features for each proposal box
        # 3.2 box / class / mask predictions are made from those features
        result, roi_losses = self.head(feature, proposal, image_shape, target)  
        
        if self.training:
            return dict(**rpn_losses, **roi_losses)
        else:
            result = self.transformer.postprocess(result, image_shape, ori_image_shape)
            return result

In [None]:
def maskrcnn_resnet50(pretrained, num_classes, pretrained_backbone=True):
    """
    Constructs a Mask R-CNN model with a ResNet-50 backbone.

    Arguments:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017.
            The ImageNet backbone weights are then not downloaded, because the
            COCO checkpoint copy below overwrites the backbone anyway.
        num_classes (int): number of classes (including the background).
        pretrained_backbone (bool): If True and ``pretrained`` is False, the
            ResNet-50 body is initialised with torchvision's ImageNet weights.

    Returns:
        MaskRCNN: the constructed model.
    """
    if pretrained:
        # The COCO checkpoint replaces the backbone weights, so skip the
        # ImageNet download. (This previously assigned a misspelled, unused
        # name and therefore had no effect.)
        pretrained_backbone = False

    # Model setting
    backbone = ResBackbone('resnet50', pretrained_backbone)
    model = MaskRCNN(backbone, num_classes)

    if pretrained:
        model_urls = {
            'maskrcnn_resnet50_fpn_coco':
                'https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth',
        }
        # Load the trained weights.
        model_state_dict = load_url(model_urls['maskrcnn_resnet50_fpn_coco'])

        # NOTE(review): the mapping below is purely positional. It presumably
        # relies on torchvision's key order: entries 0-264 are the ResNet-50
        # body, 265-272 the FPN inner blocks and 273-280 the FPN layer blocks
        # (weight, bias per level). Dropping levels 0-2 of each keeps only the
        # last level, matching ResBackbone's inner/layer block convs.
        dropped = set(range(265, 271)) | set(range(273, 279))
        pretrained_msd = [
            tensor for idx, tensor in enumerate(model_state_dict.values())
            if idx not in dropped
        ]

        msd = model.state_dict()
        # Model entries left at their fresh initialisation, presumably because
        # their shapes differ from the checkpoint:
        #   271-274: RPN objectness / box outputs (9 anchors per location here);
        #   279-282, 293-294: class-dependent box and mask outputs, which are
        #   kept only when num_classes == 91 (the COCO class count).
        skip_list = [271, 272, 273, 274, 279, 280, 281, 282, 293, 294]
        if num_classes == 91:
            skip_list = [271, 272, 273, 274]
        for i, name in enumerate(msd):
            if i in skip_list:
                continue
            # state_dict() tensors share storage with the model, so this
            # writes directly into its parameters/buffers.
            msd[name].copy_(pretrained_msd[i])

        model.load_state_dict(msd)

    return model