CodeCamp #1492: Support Boxinst #9525

Merged: 10 commits, Dec 28, 2022

Changes from 5 commits
31 changes: 31 additions & 0 deletions configs/boxinst/README.md
@@ -0,0 +1,31 @@
# BoxInst

> [BoxInst: High-Performance Instance Segmentation with Box Annotations](https://arxiv.org/pdf/2012.02310.pdf)

<!-- [ALGORITHM] -->

## Abstract

We present a high-performance method that can achieve mask-level instance segmentation with only bounding-box annotations for training. While this setting has been studied in the literature, here we show significantly stronger performance with a simple design (e.g., dramatically improving previous best reported mask AP of 21.1% to 31.6% on the COCO dataset). Our core idea is to redesign the loss of learning masks in instance segmentation, with no modification to the segmentation network itself. The new loss functions can supervise the mask training without relying on mask annotations. This is made possible with two loss terms, namely, 1) a surrogate term that minimizes the discrepancy between the projections of the ground-truth box and the predicted mask; 2) a pairwise loss that can exploit the prior that proximal pixels with similar colors are very likely to have the same category label. Experiments demonstrate that the redesigned mask loss can yield surprisingly high-quality instance masks with only box annotations. For example, without using any mask annotations, with a ResNet-101 backbone and 3× training schedule, we achieve 33.2% mask AP on COCO test-dev split (vs. 39.1% of the fully supervised counterpart). Our excellent experiment results on COCO and Pascal VOC indicate that our method dramatically narrows the performance gap between weakly and fully supervised instance segmentation.

<div align=center>
<img src="https://user-images.githubusercontent.com/57584090/209087723-756b76d7-5061-4000-a93c-df1194a439a0.png"/>
</div>
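
The two loss terms above can be pictured with a short sketch. The snippet below is an illustrative (hypothetical) rendering of the projection term only, not the authors' code: the predicted mask probabilities are max-projected onto the x and y axes and compared, dice-style, with the projections of a binary mask filled from the ground-truth box. The function name and the exact dice form are assumptions.

```python
import torch


def box_projection_loss(pred_mask: torch.Tensor,
                        box_mask: torch.Tensor,
                        eps: float = 1e-5) -> torch.Tensor:
    """Illustrative projection term.

    pred_mask: (H, W) mask probabilities in [0, 1].
    box_mask:  (H, W) binary mask, 1 inside the ground-truth box.
    """

    def dice(p: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        inter = (p * t).sum()
        return 1. - (2. * inter + eps) / (p.pow(2).sum() + t.pow(2).sum() + eps)

    # Max over rows projects onto the x axis, max over columns onto the y axis.
    loss_x = dice(pred_mask.max(dim=0).values, box_mask.max(dim=0).values)
    loss_y = dice(pred_mask.max(dim=1).values, box_mask.max(dim=1).values)
    return loss_x + loss_y
```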

## Results and Models

| Backbone | Style | MS train | Lr schd | bbox AP | mask AP | Config | Download |
| :------: | :-----: | :------: | :-----: | :-----: | :-----: | :----------------------------------------: | :----------------------: |
| R-50 | pytorch | Y | 1x | 39.4 | 30.8 | [config](./boxinst_r50_fpn_ms-90k_coco.py) | [model](<>) \| [log](<>) |

## Citation

```latex
@inproceedings{tian2020boxinst,
  title     = {{BoxInst}: High-Performance Instance Segmentation with Box Annotations},
  author    = {Tian, Zhi and Shen, Chunhua and Wang, Xinlong and Chen, Hao},
  booktitle = {Proc. IEEE Conf. Computer Vision and Pattern Recognition (CVPR)},
  year      = {2021}
}
```
93 changes: 93 additions & 0 deletions configs/boxinst/boxinst_r50_fpn_ms-90k_coco.py
@@ -0,0 +1,93 @@
_base_ = '../common/ms-90k_coco.py'

# model settings
model = dict(
    type='BoxInst',
    data_preprocessor=dict(
        type='BoxInstDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32,
        mask_stride=4,
        pairwise_size=3,
        pairwise_dilation=2,
        pairwise_color_thresh=0.3,
        bottom_pixels_removed=10),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',  # use P5
        num_outs=5,
        relu_before_extra_convs=True),
    bbox_head=dict(
        type='BoxInstBboxHead',
        num_params=593,
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        strides=[8, 16, 32, 64, 128],
        norm_on_bbox=True,
        centerness_on_reg=True,
        dcn_on_last_conv=False,
        center_sampling=True,
        conv_bias=True,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.0),
        loss_centerness=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
    mask_head=dict(
        type='BoxInstMaskHead',
        num_layers=3,
        feat_channels=16,
        size_of_interest=8,
        mask_out_stride=4,
        topk_masks_per_img=64,
        mask_feature_head=dict(
            in_channels=256,
            feat_channels=128,
            start_level=0,
            end_level=2,
            out_channels=16,
            mask_stride=8,
            num_stacked_convs=4,
            norm_cfg=dict(type='BN', requires_grad=True)),
        loss_mask=dict(
            type='DiceLoss',
            use_sigmoid=True,
            activate=True,
            eps=5e-6,
            loss_weight=1.0)),
    # model training and testing settings
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100,
        mask_thr=0.5))

# optimizer
optim_wrapper = dict(optimizer=dict(lr=0.01))

# evaluator
val_evaluator = dict(metric=['bbox', 'segm'])
test_evaluator = val_evaluator
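
A rough usage sketch for this config (an assumption, not part of the PR: it presumes an MMDetection 3.x environment with MMEngine installed and COCO prepared as the base config expects; the `work_dir` path is illustrative):

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Illustrative local paths; adjust to your checkout and output directory.
cfg = Config.fromfile('configs/boxinst/boxinst_r50_fpn_ms-90k_coco.py')
cfg.work_dir = './work_dirs/boxinst_r50_fpn_ms-90k_coco'

runner = Runner.from_cfg(cfg)  # builds model, dataloaders and loops from the config
runner.train()
```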
5 changes: 3 additions & 2 deletions mmdet/models/data_preprocessors/__init__.py
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .data_preprocessor import (BatchFixedSizePad, BatchResize,
-                                BatchSyncRandomResize, DetDataPreprocessor,
+                                BatchSyncRandomResize, BoxInstDataPreprocessor,
+                                DetDataPreprocessor,
                                 MultiBranchDataPreprocessor)

 __all__ = [
     'DetDataPreprocessor', 'BatchSyncRandomResize', 'BatchFixedSizePad',
-    'MultiBranchDataPreprocessor', 'BatchResize'
+    'MultiBranchDataPreprocessor', 'BatchResize', 'BoxInstDataPreprocessor'
 ]
126 changes: 126 additions & 0 deletions mmdet/models/data_preprocessors/data_preprocessor.py
@@ -12,11 +12,14 @@
from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
from mmengine.structures import PixelData
from mmengine.utils import is_list_of
from skimage import color
from torch import Tensor

from mmdet.models.utils import unfold_wo_center
from mmdet.models.utils.misc import samplelist_boxtype2tensor
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.structures.mask import BitmapMasks
from mmdet.utils import ConfigType


@@ -645,3 +648,126 @@ def get_padded_tensor(self, tensor: Tensor, pad_value: int) -> Tensor:
        padded_tensor = padded_tensor.type_as(tensor)
        padded_tensor[:, :, :target_height, :target_width] = tensor
        return padded_tensor


@MODELS.register_module()
class BoxInstDataPreprocessor(DetDataPreprocessor):
    """Pseudo mask pre-processor for BoxInst.

    Compared with :class:`mmdet.DetDataPreprocessor`,

    1. It generates masks using box annotations.
    2. It computes the image color similarity in LAB color space.

    Args:
        mask_stride (int): The mask output stride in BoxInst. Defaults to 4.
        pairwise_size (int): The size of neighborhood for each pixel.
            Defaults to 3.
        pairwise_dilation (int): The dilation of neighborhood for each pixel.
            Defaults to 2.
        pairwise_color_thresh (float): The threshold of image color similarity.
            Defaults to 0.3.
        bottom_pixels_removed (int): The number of pixels removed from the
            bottom. Defaults to 10.
    """

    def __init__(self,
                 *arg,
                 mask_stride: int = 4,
                 pairwise_size: int = 3,
                 pairwise_dilation: int = 2,
                 pairwise_color_thresh: float = 0.3,
                 bottom_pixels_removed: int = 10,
                 **kwargs) -> None:
        super().__init__(*arg, **kwargs)
        self.mask_stride = mask_stride
        self.pairwise_size = pairwise_size
        self.pairwise_dilation = pairwise_dilation
        self.pairwise_color_thresh = pairwise_color_thresh
        self.bottom_pixels_removed = bottom_pixels_removed

    def get_images_color_similarity(self, inputs: Tensor, image_masks: Tensor):
        """Compute the image color similarity in LAB color space."""
        assert inputs.dim() == 4
        assert inputs.size(0) == 1

        unfolded_images = unfold_wo_center(
            inputs,
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)
        diff = inputs[:, :, None] - unfolded_images
        similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5)

        unfolded_weights = unfold_wo_center(
            image_masks[None, None],
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)
        unfolded_weights = torch.max(unfolded_weights, dim=1)[0]

        return similarity * unfolded_weights

    def forward(self, data: dict, training: bool = False) -> dict:
        """Get pseudo mask labels using color similarity."""
        det_data = super().forward(data, training)
        inputs, data_samples = det_data['inputs'], det_data['data_samples']

        if training:
            # get image masks and remove bottom pixels
            b_img_h, b_img_w = data_samples[0].batch_input_shape
            img_masks = []
            for i in range(inputs.shape[0]):
                img_h, img_w = data_samples[i].img_shape
                img_mask = inputs.new_ones((img_h, img_w))
                pixels_removed = int(self.bottom_pixels_removed *
                                     float(img_h) / float(b_img_h))
                if pixels_removed > 0:
                    img_mask[-pixels_removed:, :] = 0
                pad_w = b_img_w - img_w
                pad_h = b_img_h - img_h
                img_mask = F.pad(img_mask, (0, pad_w, 0, pad_h), 'constant',
                                 0.)
                img_masks.append(img_mask)
            img_masks = torch.stack(img_masks, dim=0)
            start = int(self.mask_stride // 2)
            img_masks = img_masks[:, start::self.mask_stride,
                                  start::self.mask_stride]

            # Get original RGB image for color similarity
            ori_imgs = inputs * self.std + self.mean
            downsampled_imgs = F.avg_pool2d(
                ori_imgs.float(),
                kernel_size=self.mask_stride,
                stride=self.mask_stride,
                padding=0)

            # Compute color similarity for pseudo mask generation
            for im_i, data_sample in enumerate(data_samples):
                # TODO: Support rgb2lab in mmengine?
                images_lab = color.rgb2lab(
                    downsampled_imgs[im_i].byte().permute(1, 2,
                                                          0).cpu().numpy())
                images_lab = torch.as_tensor(
                    images_lab, device=ori_imgs.device, dtype=torch.float32)
                images_lab = images_lab.permute(2, 0, 1)[None]
                images_color_similarity = self.get_images_color_similarity(
                    images_lab, img_masks[im_i])
                pairwise_masks = (images_color_similarity >=
                                  self.pairwise_color_thresh).float()

                per_im_bboxes = data_sample.gt_instances.bboxes
                per_im_masks = []
                for per_box in per_im_bboxes:
                    mask_full = torch.zeros((b_img_h, b_img_w),
                                            device=self.device).float()
                    mask_full[int(per_box[1]):int(per_box[3] + 1),
                              int(per_box[0]):int(per_box[2] + 1)] = 1.0
                    per_im_masks.append(mask_full)
                per_im_masks = torch.stack(per_im_masks, dim=0)

                # TODO: Support BitmapMasks with tensor?
                data_sample.gt_instances.masks = BitmapMasks(
                    per_im_masks.cpu().numpy(), b_img_h, b_img_w)
                data_sample.gt_instances.pairwise_masks = torch.cat(
                    [pairwise_masks for _ in range(per_im_bboxes.shape[0])],
                    dim=0)
        return {'inputs': inputs, 'data_samples': data_samples}
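
For intuition about the pairwise prior computed above, here is a small self-contained toy check of the LAB color similarity (an illustrative sketch, not part of the PR): it emulates `unfold_wo_center` with an explicit neighbour loop, uses dilation 1 for brevity where the config uses `pairwise_dilation=2`, and thresholds at 0.3 like `pairwise_color_thresh`.

```python
import numpy as np
import torch
import torch.nn.functional as F
from skimage import color

# Random 32x32 RGB image standing in for a downsampled input.
rgb = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
lab = torch.as_tensor(
    color.rgb2lab(rgb), dtype=torch.float32).permute(2, 0, 1)[None]  # (1, 3, 32, 32)

# Gather the 8 immediate neighbours of every pixel (dilation 1 for brevity).
padded = F.pad(lab, (1, 1, 1, 1), mode='replicate')
neighbours = []
for dy in (-1, 0, 1):
    for dx in (-1, 0, 1):
        if dy == 0 and dx == 0:
            continue
        neighbours.append(padded[:, :, 1 + dy:33 + dy, 1 + dx:33 + dx])
neighbours = torch.stack(neighbours, dim=2)             # (1, 3, 8, 32, 32)

# Same exp(-||diff|| * 0.5) similarity as get_images_color_similarity.
diff = lab[:, :, None] - neighbours
similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5)  # (1, 8, 32, 32)
pairwise_mask = (similarity >= 0.3).float()             # edges kept by the pairwise prior
print(pairwise_mask.mean())                             # fraction of neighbour pairs kept
```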
4 changes: 3 additions & 1 deletion mmdet/models/dense_heads/__init__.py
Expand Up @@ -3,6 +3,7 @@
from .anchor_head import AnchorHead
from .atss_head import ATSSHead
from .autoassign_head import AutoAssignHead
from .boxinst_head import BoxInstBboxHead, BoxInstMaskHead
from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead
from .centernet_head import CenterNetHead
from .centernet_update_head import CenterNetUpdateHead
@@ -59,5 +60,6 @@
     'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOV2Head', 'LADHead',
     'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead',
     'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'CondInstBboxHead',
-    'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead'
+    'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead',
+    'BoxInstBboxHead', 'BoxInstMaskHead'
 ]