## 1. Anchor Generator


In [None]:
import torch

In [None]:
class AnchorGenerator:
    """Generate anchor boxes for every cell of a feature-map grid.

    Each grid cell receives ``len(sizes) * len(ratios)`` anchors (9 for
    ``sizes=[128, 256, 512]`` and ``ratios=[0.5, 1, 2]``). Anchors are
    (x1, y1, x2, y2) boxes in image coordinates, centred on the points
    ``(col * stride_w, row * stride_h)``.
    """

    def __init__(self, sizes, ratios):
        self.sizes = sizes    # anchor side lengths (sqrt of the area), e.g. [128, 256, 512]
        self.ratios = ratios  # height / width aspect ratios, e.g. [0.5, 1, 2]

        self.cell_anchor = None  # (len(sizes) * len(ratios), 4) template centred on (0, 0)
        self._cache = {}         # grid_size + stride -> grid anchors built from cell_anchor

    def set_cell_anchor(self, dtype, device):
        """Build the template anchors centred on (0, 0) for the given dtype/device.

        The template is rebuilt, and the grid cache dropped, whenever the
        requested dtype or device differs from the cached template. This
        happens, e.g., after the model is moved to another device or
        converted to another precision.
        """
        if (self.cell_anchor is not None
                and self.cell_anchor.dtype == dtype
                and self.cell_anchor.device == device):
            return

        sizes = torch.tensor(self.sizes, dtype=dtype, device=device)
        ratios = torch.tensor(self.ratios, dtype=dtype, device=device)

        # h / w == ratio while h * w == size ** 2
        h_ratios = torch.sqrt(ratios)
        w_ratios = 1 / h_ratios

        hs = (sizes[:, None] * h_ratios[None, :]).view(-1)
        ws = (sizes[:, None] * w_ratios[None, :]).view(-1)

        self.cell_anchor = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        # grid anchors cached so far were built from the previous template
        self._cache.clear()

    def grid_anchor(self, grid_size, stride):
        """Place the template anchors at every cell of a (H, W) grid.

        Arguments:
            grid_size (Tuple[int, int]): (H, W) of the feature map.
            stride (Tuple[int, int]): (stride_h, stride_w) in image pixels.

        Returns:
            Tensor[H * W * A, 4], ordered by row, then column, then template anchor.
        """
        dtype, device = self.cell_anchor.dtype, self.cell_anchor.device

        # cell offsets in image pixels: column * stride_w, row * stride_h
        shift_x = torch.arange(0, grid_size[1], dtype=dtype, device=device) * stride[1]
        shift_y = torch.arange(0, grid_size[0], dtype=dtype, device=device) * stride[0]

        y, x = torch.meshgrid(shift_y, shift_x)
        x = x.reshape(-1)
        y = y.reshape(-1)
        shift = torch.stack((x, y, x, y), dim=1).reshape(-1, 1, 4)

        # (H*W, 1, 4) + (A, 4) broadcasts to (H*W, A, 4)
        anchor = (shift + self.cell_anchor).reshape(-1, 4)

        return anchor

    def cached_grid_anchor(self, grid_size, stride):
        """Return grid anchors, memoised on (grid_size, stride); the cache holds at most 3 entries."""
        key = grid_size + stride
        if key in self._cache:
            return self._cache[key]
        anchor = self.grid_anchor(grid_size, stride)

        if len(self._cache) >= 3:
            self._cache.clear()
        self._cache[key] = anchor

        return anchor

    def __call__(self, feature, image_size):
        """Return anchors (Tensor[H * W * A, 4]) for ``feature`` on an image of ``image_size`` (H, W)."""
        dtype, device = feature.dtype, feature.device
        grid_size = tuple(feature.shape[-2:])  # (h, w)
        # pixel spacing between neighbouring grid cells, (stride_h, stride_w)
        stride = tuple(int(i / g) for i, g in zip(image_size, grid_size))

        self.set_cell_anchor(dtype, device)

        anchor = self.cached_grid_anchor(grid_size, stride)

        return anchor

## 2. RPN


In [None]:
import torch
import torch.nn.functional as F
from torch import nn

### 2.1 Utils

In [None]:
class Matcher:
    """Label predicted boxes from an IoU matrix against ground-truth boxes.

    Labels: 1 = positive (best IoU >= high_threshold), 0 = negative
    (best IoU < low_threshold), -1 = ignored (anything in between).
    """
    def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        # when True, every pred box that ties for a gt box's best IoU is forced positive
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, iou):
        """
        Arguments:
            iou (Tensor[M, N]): IoU between M gt boxes and N predicted boxes.

        Returns:
            label (Tensor[N], float): 1., 0. or -1. for each predicted box.
            matched_idx (Tensor[N]): index of the gt box with the highest IoU for each predicted box.
        """
        # best gt for every predicted box (reduction over the gt axis)
        best_iou, matched_idx = iou.max(dim=0)

        label = torch.full((iou.shape[1],), -1, dtype=torch.float, device=iou.device)
        is_positive = best_iou >= self.high_threshold
        is_negative = best_iou < self.low_threshold
        label[is_positive] = 1
        label[is_negative] = 0

        if not self.allow_low_quality_matches:
            return label, matched_idx

        best_per_gt, _ = iou.max(dim=1)
        forced_cols = torch.where(iou == best_per_gt.unsqueeze(1))[1]
        label[forced_cols] = 1
        return label, matched_idx

In [None]:
class BalancedPositiveNegativeSampler:
    """Draw a class-balanced subset of indices from a label vector.

    Negatives usually far outnumber positives, so at most
    ``num_samples * positive_fraction`` positives are drawn. The remaining
    budget is filled with negatives, as many as are available.
    """
    def __init__(self, num_samples, positive_fraction):
        self.num_samples = num_samples              # total samples drawn per call (upper bound)
        self.positive_fraction = positive_fraction  # maximum share of positives among them

    def __call__(self, label):
        """Return (pos_idx, neg_idx): random indices of entries with label 1 and label 0."""
        pos_candidates = torch.where(label == 1)[0]
        neg_candidates = torch.where(label == 0)[0]

        pos_budget = int(self.num_samples * self.positive_fraction)
        n_pos = min(pos_candidates.numel(), pos_budget)
        n_neg = min(neg_candidates.numel(), self.num_samples - n_pos)

        pos_pick = torch.randperm(pos_candidates.numel(), device=pos_candidates.device)[:n_pos]
        neg_pick = torch.randperm(neg_candidates.numel(), device=neg_candidates.device)[:n_neg]

        return pos_candidates[pos_pick], neg_candidates[neg_pick]

In [None]:
import math
import torch

class BoxCoder:
    """Convert between (x1, y1, x2, y2) boxes and weighted (dx, dy, dw, dh) regression deltas."""
    def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
        self.weights = weights  # (wx, wy, ww, wh) scale factors applied to the deltas
        # upper bound on dw / dh before exp() in decode, guarding against overflow
        self.bbox_xform_clip = bbox_xform_clip

    @staticmethod
    def _center_form(boxes):
        """Return (ctr_x, ctr_y, width, height) columns of (x1, y1, x2, y2) boxes."""
        w = boxes[:, 2] - boxes[:, 0]
        h = boxes[:, 3] - boxes[:, 1]
        return boxes[:, 0] + 0.5 * w, boxes[:, 1] + 0.5 * h, w, h

    def encode(self, reference_box, proposal):
        """
        Encode target boxes as deltas relative to proposals.

        Arguments:
            reference_box (Tensor[N, 4]): target (e.g. ground-truth) boxes.
            proposal (Tensor[N, 4]): boxes the deltas are expressed relative to.
        Returns:
            delta (Tensor[N, 4]): weighted (dx, dy, dw, dh).
        """
        p_cx, p_cy, p_w, p_h = self._center_form(proposal)
        g_cx, g_cy, g_w, g_h = self._center_form(reference_box)
        w = self.weights

        return torch.stack((
            w[0] * (g_cx - p_cx) / p_w,
            w[1] * (g_cy - p_cy) / p_h,
            w[2] * torch.log(g_w / p_w),
            w[3] * torch.log(g_h / p_h),
        ), dim=1)

    def decode(self, delta, box):
        """
        Apply regression deltas to reference boxes.

        Arguments:
            delta (Tensor[N, 4]): weighted (dx, dy, dw, dh) deltas.
            box (Tensor[N, 4]): reference boxes in (x1, y1, x2, y2) format.
        Returns:
            target (Tensor[N, 4]): decoded boxes in (x1, y1, x2, y2) format.
        """
        w = self.weights
        dx = delta[:, 0] / w[0]
        dy = delta[:, 1] / w[1]
        dw = (delta[:, 2] / w[2]).clamp(max=self.bbox_xform_clip)
        dh = (delta[:, 3] / w[3]).clamp(max=self.bbox_xform_clip)

        cx, cy, bw, bh = self._center_form(box)
        new_cx = dx * bw + cx
        new_cy = dy * bh + cy
        half_w = 0.5 * (torch.exp(dw) * bw)
        half_h = 0.5 * (torch.exp(dh) * bh)

        return torch.stack((new_cx - half_w, new_cy - half_h, new_cx + half_w, new_cy + half_h), dim=1)

In [None]:
def process_box(box, score, image_shape, min_size):
    """
    Clip boxes to the image and remove boxes that are too small.

    Arguments:
        box (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format.
        score (Tensor[N]): score of each box.
        image_shape: (height, width) of the image.
        min_size (float): minimum width and height a box must have to be kept.

    Returns:
        box (Tensor[K, 4]), score (Tensor[K]): clipped surviving boxes and their scores.

    The caller's ``box`` tensor is not modified.
    """
    box = box.clone()  # clip a copy instead of silently mutating the caller's tensor
    box[:, [0, 2]] = box[:, [0, 2]].clamp(0, image_shape[1])
    box[:, [1, 3]] = box[:, [1, 3]].clamp(0, image_shape[0])

    w, h = box[:, 2] - box[:, 0], box[:, 3] - box[:, 1]
    keep = torch.where((w >= min_size) & (h >= min_size))[0]
    box, score = box[keep], score[keep]
    return box, score

In [None]:
def nms(box, score, threshold):
    """
    Non-maximum suppression.

    Arguments:
        box (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format.
        score (Tensor[N]): scores of the boxes.
        threshold (float): boxes whose IoU with a higher-scoring kept box is
            greater than this value are discarded.
    Returns:
        keep (Tensor[K], int64): indices of boxes kept by NMS, in decreasing score order.

    Uses the compiled ``torchvision::nms`` operator when it has been registered
    (i.e. torchvision has been imported). Otherwise falls back to a
    pure-PyTorch implementation, so this module works without torchvision loaded.
    """
    try:
        op = torch.ops.torchvision.nms
    except (AttributeError, RuntimeError):  # operator not registered
        op = None

    if op is not None:
        return op(box, score, threshold)
    return _nms_reference(box, score, threshold)


def _nms_reference(box, score, threshold):
    """Greedy pure-PyTorch NMS following ``torchvision.ops.nms`` semantics."""
    x1, y1, x2, y2 = box.unbind(dim=1)
    area = (x2 - x1) * (y2 - y1)
    order = score.argsort(descending=True)

    keep = []
    while order.numel() > 0:
        best = order[0]
        keep.append(best)
        rest = order[1:]
        inter_w = (torch.min(x2[best], x2[rest]) - torch.max(x1[best], x1[rest])).clamp(min=0)
        inter_h = (torch.min(y2[best], y2[rest]) - torch.max(y1[best], y1[rest])).clamp(min=0)
        inter = inter_w * inter_h
        iou = inter / (area[best] + area[rest] - inter)
        # suppress only IoU > threshold; NaN IoUs (degenerate boxes) are kept, as in torchvision
        order = rest[~(iou > threshold)]

    if not keep:
        return torch.empty((0,), dtype=torch.long, device=box.device)
    return torch.stack(keep)

In [None]:
def box_iou(box_a, box_b):
    """
    Pairwise intersection-over-union of two sets of (x1, y1, x2, y2) boxes.

    Arguments:
        box_a (Tensor[N, 4])
        box_b (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): IoU of every box in box_a against every box in box_b.
    """
    def _area(boxes):
        return torch.prod(boxes[:, 2:] - boxes[:, :2], 1)

    # broadcast (N, 1, 2) against (M, 2) -> (N, M, 2)
    top_left = torch.max(box_a[:, None, :2], box_b[:, :2])
    bottom_right = torch.min(box_a[:, None, 2:], box_b[:, 2:])

    overlap = (bottom_right - top_left).clamp(min=0)
    intersection = overlap[:, :, 0] * overlap[:, :, 1]
    union = _area(box_a)[:, None] + _area(box_b) - intersection

    return intersection / union

### 2.2 RPN Head

In [None]:
class RPNHead(nn.Module):
    """RPN head: a shared 3x3 conv followed by sibling 1x1 convs for objectness and box deltas."""
    def __init__(self, in_channels, num_anchors):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, 1)      # one objectness logit per anchor
        self.bbox_pred = nn.Conv2d(in_channels, 4 * num_anchors, 1)   # 4 box deltas per anchor

        # small-normal weights, zero biases for all three convs
        for layer in (self.conv, self.cls_logits, self.bbox_pred):
            nn.init.normal_(layer.weight, std=0.01)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Return (objectness logits, box deltas) for feature map ``x``."""
        hidden = F.relu(self.conv(x))
        return self.cls_logits(hidden), self.bbox_pred(hidden)

### 2.3 RPN Network

In [None]:
class RegionProposalNetwork(nn.Module):
    """Region Proposal Network.

    Scores every anchor as object/background and decodes the predicted offsets
    into proposals. Proposals are then filtered (top-k, clipping, NMS). In
    training mode the RPN objectness and box-regression losses are also computed.

    NOTE(review): head outputs are flattened across the batch dimension, while
    anchors are generated for a single grid. This appears to assume one image
    per call — confirm with the caller.
    """
    def __init__(self, anchor_generator, head,
                 fg_iou_thresh, bg_iou_thresh,
                 num_samples, positive_fraction,
                 reg_weights,
                 pre_nms_top_n, post_nms_top_n, nms_thresh):
        super().__init__()

        self.anchor_generator = anchor_generator  # called as anchor_generator(feature, image_shape)
        self.head = head  # prediction head returning (objectness, pred_bbox_delta)

        # utils
        self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=True)
        self.fg_bg_sampler = BalancedPositiveNegativeSampler(num_samples, positive_fraction)
        self.box_coder = BoxCoder(reg_weights)  # converts between boxes and regression deltas

        # both are dicts with 'training' / 'testing' keys
        self._pre_nms_top_n = pre_nms_top_n    # proposals kept before NMS
        self._post_nms_top_n = post_nms_top_n  # proposals kept after NMS
        self.nms_thresh = nms_thresh           # IoU threshold for proposal NMS
        self.min_size = 1                      # proposals narrower/shorter than this are dropped

    def create_proposal(self, anchor, objectness, pred_bbox_delta, image_shape):
        """Decode the top-scoring anchors into clipped, NMS-filtered proposals (Tensor[K, 4])."""
        if self.training:
            pre_nms_top_n = self._pre_nms_top_n['training']
            post_nms_top_n = self._post_nms_top_n['training']
        else:
            pre_nms_top_n = self._pre_nms_top_n['testing']
            post_nms_top_n = self._post_nms_top_n['testing']

        # ranking raw logits gives the same order as ranking sigmoid probabilities
        pre_nms_top_n = min(objectness.shape[0], pre_nms_top_n)
        top_n_idx = objectness.topk(pre_nms_top_n)[1]

        score = objectness[top_n_idx]
        proposal = self.box_coder.decode(pred_bbox_delta[top_n_idx], anchor[top_n_idx])

        proposal, score = process_box(proposal, score, image_shape, self.min_size)
        keep = nms(proposal, score, self.nms_thresh)[:post_nms_top_n]
        proposal = proposal[keep]

        return proposal

    def compute_loss(self, objectness, pred_bbox_delta, gt_box, anchor):
        """Return (objectness BCE loss, L1 box loss normalised by the number of sampled anchors)."""
        iou = box_iou(gt_box, anchor)
        # per-anchor label (1 / 0 / -1) and index of the best-matching gt box
        label, matched_idx = self.proposal_matcher(iou)

        # class-balanced sampling; regression targets only for sampled positives
        pos_idx, neg_idx = self.fg_bg_sampler(label)
        idx = torch.cat((pos_idx, neg_idx))
        regression_target = self.box_coder.encode(gt_box[matched_idx[pos_idx]], anchor[pos_idx])

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[idx], label[idx])
        box_loss = F.l1_loss(pred_bbox_delta[pos_idx], regression_target, reduction='sum') / idx.numel()

        return objectness_loss, box_loss

    def forward(self, feature, image_shape, target=None):
        """
        Arguments:
            feature: backbone feature map.
            image_shape: (height, width) of the input image.
            target (dict, optional): ground truth with a 'boxes' entry; required in training mode.

        Returns:
            proposal (Tensor[K, 4]) and a dict of RPN losses (empty in eval mode).

        Raises:
            ValueError: in training mode when ``target`` is None.
        """
        if self.training and target is None:
            raise ValueError("RegionProposalNetwork requires `target` with 'boxes' in training mode")

        anchor = self.anchor_generator(feature, image_shape)

        objectness, pred_bbox_delta = self.head(feature)
        # (N, A, H, W) -> flat per-anchor values ordered (h, w, a), matching the anchor layout
        objectness = objectness.permute(0, 2, 3, 1).flatten()
        pred_bbox_delta = pred_bbox_delta.permute(0, 2, 3, 1).reshape(-1, 4)

        # proposals are built from detached predictions: no gradient flows through them
        proposal = self.create_proposal(anchor, objectness.detach(), pred_bbox_delta.detach(), image_shape)
        if self.training:
            objectness_loss, box_loss = self.compute_loss(objectness, pred_bbox_delta, target['boxes'], anchor)
            return proposal, dict(rpn_objectness_loss=objectness_loss, rpn_box_loss=box_loss)

        return proposal, {}

## 3. RoI Heads

### 3.1 RoI Align

In [1]:
import math
import torch

In [None]:
def _torch_major_minor(version):
    """Parse the leading ``major.minor`` of a version string such as '1.13.1+cu117'.

    Returns a tuple of ints, or None when no leading ``<int>.<int>`` is found.
    """
    import re

    match = re.match(r'\s*(\d+)\.(\d+)', str(version))
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


# https://pytorch.org/vision/stable/_modules/torchvision/ops/roi_align.html
def roi_align(features, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
    """Call the ``torchvision::roi_align`` operator.

    Returns a (K, C, pooled_height, pooled_width) tensor, one per row of ``rois``.
    The version check is numeric: a plain string comparison would rank "1.10.0"
    below "1.5.0". Unparseable versions are treated as modern.
    """
    version = _torch_major_minor(torch.__version__)
    if version is None or version >= (1, 5):
        # newer operator signature takes a trailing ``aligned`` flag; False keeps legacy sampling
        return torch.ops.torchvision.roi_align(
            features, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, False)

    return torch.ops.torchvision.roi_align(
        features, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio)

In [None]:
class RoIAlign:
    """Pool every proposal of one feature map to a fixed ``output_size`` with RoI Align."""
    def __init__(self, output_size, sampling_ratio):
        self.output_size = output_size  # (pooled_height, pooled_width) of each pooled RoI
        self.sampling_ratio = sampling_ratio
        self.spatial_scale = None  # feature / image size ratio, inferred on first call

    def setup_scale(self, feature_shape, image_shape):
        """Infer the feature-to-image scale once, snapped to the nearest power of two.

        Feature sizes are rarely exact divisors of the image size (e.g. 84 / 1333
        for a stride-16 backbone), so log2 of the ratio is rounded to the nearest
        integer. Truncating with ``int`` would round toward zero, doubling the
        scale for such sizes and making the H/W scales disagree.
        """
        if self.spatial_scale is not None:
            return

        possible_scales = []
        for s1, s2 in zip(feature_shape, image_shape):
            scale = 2 ** round(math.log2(s1 / s2))
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        self.spatial_scale = possible_scales[0]

    def __call__(self, feature, proposal, image_shape):
        """
        Arguments:
            feature (Tensor[N, C, H, W])
            proposal (Tensor[K, 4])
            image_shape (Torch.Size([H, W]))

        Returns:
            output (Tensor[K, C, self.output_size[0], self.output_size[1]])

        Every proposal is given batch index 0, so only the first image of ``feature`` is sampled.
        """
        idx = proposal.new_full((proposal.shape[0], 1), 0)
        roi = torch.cat((idx, proposal), dim=1)

        self.setup_scale(feature.shape[-2:], image_shape)

        # align the feature-map region under each RoI to a fixed output size
        return roi_align(feature.to(roi), roi, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio)

### 3.2 Predictors

In [None]:
from collections import OrderedDict

In [None]:
class FastRCNNPredictor(nn.Module):
    """Two fully-connected layers followed by per-class score and per-class box-delta outputs."""
    def __init__(self, in_channels, mid_channels, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, mid_channels)
        self.fc2 = nn.Linear(mid_channels, mid_channels)
        self.cls_score = nn.Linear(mid_channels, num_classes)      # class logits
        self.bbox_pred = nn.Linear(mid_channels, num_classes * 4)  # 4 deltas per class

    def forward(self, x):
        """Return (class logits [K, num_classes], box deltas [K, num_classes * 4]) for RoI features ``x``."""
        hidden = x.flatten(start_dim=1)
        for fc in (self.fc1, self.fc2):
            hidden = F.relu(fc(hidden))
        return self.cls_score(hidden), self.bbox_pred(hidden)

In [None]:
class MaskRCNNPredictor(nn.Sequential):
    """Mask head built from three stages.

    The stages are a stack of 3x3 conv + ReLU layers, a 2x up-sampling
    transposed conv + ReLU, and a 1x1 conv that produces one mask logit map
    per class.
    """
    def __init__(self, in_channels, layers, dim_reduced, num_classes):
        """
        Arguments:
            in_channels (int): channels of the RoI-aligned input feature.
            layers (Tuple[int]): output channels of each 3x3 conv, e.g. (256, 256, 256, 256).
            dim_reduced (int): output channels of the transposed conv.
            num_classes (int): number of mask logit maps (one per class).
        """

        d = OrderedDict()
        next_feature = in_channels
        num_layers = 0
        for layer_idx, layer_features in enumerate(layers, 1):
            d['mask_fcn{}'.format(layer_idx)] = nn.Conv2d(next_feature, layer_features, 3, 1, 1)
            d['relu{}'.format(layer_idx)] = nn.ReLU(inplace=True)
            next_feature = layer_features
            num_layers = layer_idx

        # kernel 2, stride 2: doubles the spatial size
        d['mask_conv5'] = nn.ConvTranspose2d(next_feature, dim_reduced, 2, 2, 0)
        # 'relu5' is kept for configs with fewer than 5 convs. With 5+ convs that key already
        # exists, and re-assigning an OrderedDict key keeps its old position, which would
        # silently drop this activation; use the next free index instead.
        relu_name = 'relu5' if num_layers < 5 else 'relu{}'.format(num_layers + 1)
        d[relu_name] = nn.ReLU(inplace=True)
        d['mask_fcn_logits'] = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)
        super().__init__(d)

        # Kaiming init for weights only; biases keep PyTorch's default init
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.kaiming_normal_(param, mode='fan_out', nonlinearity='relu')

In [None]:
import torch
import torch.nn.functional as F
from torch import nn

### 3.3 Multi task loss

Second-stage (RoI head) multi-task loss:

$L = L_{cls} + L_{box} + L_{mask}$



In [None]:
def fastrcnn_loss(class_logit, box_regression, label, regression_target):
    """Return (L_cls, L_box) for the Fast R-CNN head.

    The first ``regression_target.shape[0]`` rows are taken to be the positive
    samples. L_box is a summed smooth-L1 loss over those rows, using each row's
    own class deltas, divided by the total number of RoIs.
    """
    cls_loss = F.cross_entropy(class_logit, label)

    num_rois = class_logit.shape[0]
    num_pos = regression_target.shape[0]
    pos_deltas = box_regression.reshape(num_rois, -1, 4)[:num_pos]
    pos_labels = label[:num_pos]

    # for each positive RoI pick the 4 deltas predicted for its own class
    rows = torch.arange(num_pos, device=label.device)
    chosen = pos_deltas[rows, pos_labels]

    box_loss = F.smooth_l1_loss(chosen, regression_target, reduction='sum') / num_rois
    return cls_loss, box_loss


def maskrcnn_loss(mask_logit, proposal, matched_idx, label, gt_mask):
    """Return L_mask: per-pixel BCE between each RoI's class mask logits and its cropped gt mask.

    The gt mask matched to each proposal is cropped and resized to the mask
    resolution with ``roi_align``. The matched gt index serves as the RoI's
    batch index.
    """
    mask_size = mask_logit.shape[-1]

    batch_col = matched_idx[:, None].to(proposal)
    rois = torch.cat((batch_col, proposal), dim=1)

    masks = gt_mask[:, None].to(rois)
    mask_target = roi_align(masks, rois, 1., mask_size, mask_size, -1)[:, 0]

    # logits of each RoI's own class only
    rows = torch.arange(label.shape[0], device=label.device)
    return F.binary_cross_entropy_with_logits(mask_logit[rows, label], mask_target)

### 3.4 RoI Heads

In [None]:
class RoIHeads(nn.Module):
    """Second-stage heads: box classification/regression and, optionally, mask prediction.

    ``mask_roi_pool`` and ``mask_predictor`` start as None. The mask branch
    runs only once both have been assigned (see ``has_mask``).
    """
    def __init__(self, box_roi_pool, box_predictor,
                 fg_iou_thresh, bg_iou_thresh,
                 num_samples, positive_fraction,
                 reg_weights,
                 score_thresh, nms_thresh, num_detections):
        super().__init__()
        self.box_roi_pool = box_roi_pool    # RoI pooling for the box branch
        self.box_predictor = box_predictor  # returns (class_logit, box_regression)

        self.mask_roi_pool = None
        self.mask_predictor = None

        # utils
        self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
        self.fg_bg_sampler = BalancedPositiveNegativeSampler(num_samples, positive_fraction)
        self.box_coder = BoxCoder(reg_weights)

        self.score_thresh = score_thresh      # minimum class probability kept at inference
        self.nms_thresh = nms_thresh          # per-class NMS IoU threshold
        self.num_detections = num_detections  # max detections kept per class after NMS
        self.min_size = 1                     # boxes narrower/shorter than this are dropped

    def has_mask(self):
        """True when both mask components have been attached."""
        return self.mask_roi_pool is not None and self.mask_predictor is not None

    def select_training_samples(self, proposal, target):
        """Sample proposals for training and build their labels and regression targets.

        Returns:
            proposal: sampled proposals, positives first.
            matched_idx: index of the matched gt box for each sampled proposal.
            label: gt class for positives, 0 (background) for negatives.
            regression_target: encoded deltas for the positives only.
        """
        gt_box = target['boxes']
        gt_label = target['labels']

        # gt boxes are appended so they are also candidates for positive samples
        proposal = torch.cat((proposal, gt_box))

        iou = box_iou(gt_box, proposal)
        # per-proposal label (1 / 0 / -1) and index of the best-matching gt box
        pos_neg_label, matched_idx = self.proposal_matcher(iou)
        pos_idx, neg_idx = self.fg_bg_sampler(pos_neg_label)
        idx = torch.cat((pos_idx, neg_idx))

        regression_target = self.box_coder.encode(gt_box[matched_idx[pos_idx]], proposal[pos_idx])
        proposal = proposal[idx]
        matched_idx = matched_idx[idx]
        label = gt_label[matched_idx]
        num_pos = pos_idx.shape[0]
        label[num_pos:] = 0  # sampled negatives are background

        return proposal, matched_idx, label, regression_target

    def fastrcnn_inference(self, class_logit, box_regression, proposal, image_shape):
        """Turn head outputs into per-class, thresholded, NMS-filtered detections.

        Class 0 is skipped as background. Returns a dict with 'boxes', 'labels' and 'scores'.
        """
        N, num_classes = class_logit.shape

        device = class_logit.device
        pred_score = F.softmax(class_logit, dim=-1)
        box_regression = box_regression.reshape(N, -1, 4)

        boxes = []
        labels = []
        scores = []
        for l in range(1, num_classes):
            score, box_delta = pred_score[:, l], box_regression[:, l]

            keep = score >= self.score_thresh
            box, score, box_delta = proposal[keep], score[keep], box_delta[keep]
            box = self.box_coder.decode(box_delta, box)

            box, score = process_box(box, score, image_shape, self.min_size)

            keep = nms(box, score, self.nms_thresh)[:self.num_detections]
            box, score = box[keep], score[keep]
            label = torch.full((len(keep),), l, dtype=keep.dtype, device=device)

            boxes.append(box)
            labels.append(label)
            scores.append(score)

        results = dict(boxes=torch.cat(boxes), labels=torch.cat(labels), scores=torch.cat(scores))
        return results

    def forward(self, feature, proposal, image_shape, target):
        """Run the box head, and the mask head if attached, on RPN proposals.

        Arguments:
            feature: backbone feature map.
            proposal (Tensor[K, 4]): RPN proposals.
            image_shape: (height, width) of the input image.
            target (dict): ground truth with 'boxes' and 'labels', plus 'masks' when the
                mask branch is attached; only read in training mode.

        Returns:
            result (dict): eval mode only — 'boxes', 'labels', 'scores' (and 'masks').
            losses (dict): training mode only — 'roi_classifier_loss', 'roi_box_loss'
                (and 'roi_mask_loss').
        """
        if self.training:
            proposal, matched_idx, label, regression_target = self.select_training_samples(proposal, target)

        # Detection: pool a fixed-size feature for each proposal, then predict classes and deltas
        box_feature = self.box_roi_pool(feature, proposal, image_shape)
        class_logit, box_regression = self.box_predictor(box_feature)

        result, losses = {}, {}
        if self.training:
            classifier_loss, box_reg_loss = fastrcnn_loss(class_logit, box_regression, label, regression_target)
            losses = dict(roi_classifier_loss=classifier_loss, roi_box_loss=box_reg_loss)
        else:
            result = self.fastrcnn_inference(class_logit, box_regression, proposal, image_shape)

        if not self.has_mask():
            return result, losses

        if self.training:
            # select_training_samples puts positives first: the first num_pos rows are positives
            num_pos = regression_target.shape[0]
            mask_proposal = proposal[:num_pos]
            pos_matched_idx = matched_idx[:num_pos]
            mask_label = label[:num_pos]
            # NOTE: mask RoIs are the sampled positive proposals, not refined by box_regression

            if mask_proposal.shape[0] == 0:
                # zero loss on the same device/dtype as the other losses and attached to the graph
                losses.update(dict(roi_mask_loss=class_logit.sum() * 0))
                return result, losses
        else:
            mask_proposal = result['boxes']

            if mask_proposal.shape[0] == 0:
                # NOTE(review): 28x28 assumes the configured mask resolution — confirm with the mask branch
                result.update(dict(masks=mask_proposal.new_empty((0, 28, 28))))
                return result, losses

        mask_feature = self.mask_roi_pool(feature, mask_proposal, image_shape)
        mask_logit = self.mask_predictor(mask_feature)

        if self.training:
            gt_mask = target['masks']
            mask_loss = maskrcnn_loss(mask_logit, mask_proposal, pos_matched_idx, mask_label, gt_mask)
            losses.update(dict(roi_mask_loss=mask_loss))
        else:
            label = result['labels']
            idx = torch.arange(label.shape[0], device=label.device)
            # keep only the mask logits of each detection's predicted class
            mask_logit = mask_logit[idx, label]

            mask_prob = mask_logit.sigmoid()
            result.update(dict(masks=mask_prob))

        return result, losses