[Feature] Support QueryInst #6050

Merged: 25 commits, Oct 26, 2021
Commits (25)
d3467e1  impl queryinst (vealocia, Aug 18, 2021)
39b3eb9  bug free queryinst with crop and negative samples (vealocia, Aug 24, 2021)
b62e491  use detr hyperparameters (vealocia, Sep 1, 2021)
3e87716  Merge https://github.com.cnpmjs.org/open-mmlab/mmdetection into detraug (vealocia, Sep 5, 2021)
062a085  pre-commit hooks (vealocia, Sep 7, 2021)
e157c8d  modified dynamic_mask_head docstrings (vealocia, Sep 7, 2021)
d669826  remove unused dropout in dynamic_mask_head (vealocia, Sep 7, 2021)
35a4d34  add docstring for dice_loss (vealocia, Sep 7, 2021)
fbb6e3c  add dice_loss unit test (vealocia, Sep 7, 2021)
d9b7025  impl unit test for dynamic_mask_head (vealocia, Sep 7, 2021)
51e5fa5  update queryinst docstring and implementation (vealocia, Sep 8, 2021)
cc838dc  stability update for dice_loss and dynamic_mask_head (vealocia, Sep 10, 2021)
be0cfed  Merge branch 'open-mmlab:master' into queryinst (vealocia, Sep 12, 2021)
b5efc5c  Merge branch 'open-mmlab:master' into queryinst (vealocia, Sep 22, 2021)
3de895d  update for clarify (vealocia, Sep 22, 2021)
f6431de  Merge branch 'open-mmlab:master' into queryinst (vealocia, Sep 23, 2021)
1e3b916  Merge remote-tracking branch 'mmdet/master' into queryinst (vealocia, Oct 2, 2021)
dd59078  Merge branch 'open-mmlab:master' into queryinst (vealocia, Oct 19, 2021)
0609df4  bug free in case of num_proposals equal to zero (vealocia, Oct 19, 2021)
b25d01a  detail docstrings (vealocia, Oct 19, 2021)
6b04e60  fixed CI issues (vealocia, Oct 20, 2021)
8de6779  Merge branch 'open-mmlab:master' into queryinst (vealocia, Oct 20, 2021)
557dbf2  issues resolved (vealocia, Oct 22, 2021)
37b28ed  Merge branch 'queryinst' of github.com:vealocia/mmdetection into quer… (vealocia, Oct 22, 2021)
e229d4f  add queryinst docs (vealocia, Oct 22, 2021)
@@ -0,0 +1,7 @@
_base_ = './queryinst_r50_fpn_300_proposals_crop_mstrain_480-800_3x_coco.py'

model = dict(
backbone=dict(
depth=101,
init_cfg=dict(type='Pretrained',
checkpoint='torchvision://resnet101')))
@@ -0,0 +1,7 @@
_base_ = './queryinst_r50_fpn_mstrain_480-800_3x_coco.py'

model = dict(
backbone=dict(
depth=101,
init_cfg=dict(type='Pretrained',
checkpoint='torchvision://resnet101')))
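
Not part of the diff: the two seven-line configs above are the ResNet-101 variants. They inherit everything from the corresponding R50 base via _base_ and override only the backbone depth and the pretrained checkpoint. A minimal sketch of how that override resolves, assuming mmcv's Config loader and that it runs from an mmdetection checkout (the merge_from_dict call simply emulates the seven-line file):

from mmcv import Config

# Load the R50 base named in the second config above.
cfg = Config.fromfile(
    'configs/queryinst/queryinst_r50_fpn_mstrain_480-800_3x_coco.py')
# Emulate the seven-line override: only backbone depth and checkpoint change.
cfg.merge_from_dict({
    'model.backbone.depth': 101,
    'model.backbone.init_cfg.checkpoint': 'torchvision://resnet101',
})
print(cfg.model.backbone.depth)  # 101; every other key comes from the R50 base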
138 changes: 138 additions & 0 deletions configs/queryinst/queryinst_r50_fpn_1x_coco.py
@@ -0,0 +1,138 @@
_base_ = [
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
num_stages = 6
num_proposals = 100
model = dict(
type='QueryInst',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=0,
add_extra_convs='on_input',
num_outs=4),
rpn_head=dict(
type='EmbeddingRPNHead',
num_proposals=num_proposals,
proposal_feature_channel=256),
roi_head=dict(
type='SparseRoIHead',
num_stages=num_stages,
stage_loss_weights=[1] * num_stages,
proposal_feature_channel=256,
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=[
dict(
type='DIIHead',
num_classes=80,
num_ffn_fcs=2,
num_heads=8,
num_cls_fcs=1,
num_reg_fcs=3,
feedforward_channels=2048,
in_channels=256,
dropout=0.0,
ffn_act_cfg=dict(type='ReLU', inplace=True),
dynamic_conv_cfg=dict(
type='DynamicConv',
in_channels=256,
feat_channels=64,
out_channels=256,
input_feat_shape=7,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN')),
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=2.0),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
clip_border=False,
target_means=[0., 0., 0., 0.],
target_stds=[0.5, 0.5, 1., 1.])) for _ in range(num_stages)
],
mask_head=[
dict(
type='DynamicMaskHead',
dynamic_conv_cfg=dict(
type='DynamicConv',
in_channels=256,
feat_channels=64,
out_channels=256,
input_feat_shape=14,
with_proj=False,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN')),
num_convs=4,
num_classes=80,
roi_feat_size=14,
in_channels=256,
conv_kernel_size=3,
conv_out_channels=256,
class_agnostic=False,
norm_cfg=dict(type='BN'),
upsample_cfg=dict(type='deconv', scale_factor=2),
loss_mask=dict(
type='DiceLoss',
loss_weight=8.0,
use_sigmoid=True,
activate=False,
eps=1e-5)) for _ in range(num_stages)
]),
# training and testing settings
train_cfg=dict(
rpn=None,
rcnn=[
dict(
assigner=dict(
type='HungarianAssigner',
cls_cost=dict(type='FocalLossCost', weight=2.0),
reg_cost=dict(type='BBoxL1Cost', weight=5.0),
iou_cost=dict(type='IoUCost', iou_mode='giou',
weight=2.0)),
sampler=dict(type='PseudoSampler'),
pos_weight=1,
mask_size=28,
) for _ in range(num_stages)
]),
test_cfg=dict(
rpn=None, rcnn=dict(max_per_img=num_proposals, mask_thr_binary=0.5)))

# optimizer
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.0001,
weight_decay=0.0001,
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=0.1, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[8, 11], warmup_iters=1000)
runner = dict(type='EpochBasedRunner', max_epochs=12)
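
For context, not part of the diff: once this config lands, the model can be run with the standard mmdetection 2.x high-level APIs. A rough usage sketch; the checkpoint path is a placeholder for a trained QueryInst weight file:

from mmdet.apis import inference_detector, init_detector

config_file = 'configs/queryinst/queryinst_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/queryinst_r50_fpn_1x_coco.pth'  # placeholder path
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# For an instance-segmentation model the result is (bbox_results, segm_results),
# each a list with one entry per COCO class.
bbox_results, segm_results = inference_detector(model, 'demo/demo.jpg')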
@@ -0,0 +1,54 @@
_base_ = './queryinst_r50_fpn_mstrain_480-800_3x_coco.py'
num_proposals = 300
model = dict(
rpn_head=dict(num_proposals=num_proposals),
test_cfg=dict(
_delete_=True,
rpn=None,
rcnn=dict(max_per_img=num_proposals, mask_thr_binary=0.5)))
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR.
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='AutoAugment',
policies=[[
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
multiscale_mode='value',
keep_ratio=True)
],
[
dict(
type='Resize',
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
multiscale_mode='value',
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
multiscale_mode='value',
override=True,
keep_ratio=True)
]]),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'])
]
data = dict(train=dict(pipeline=train_pipeline))
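
The AutoAugment block above encodes the DETR-style strategy: each training image goes through exactly one of the two policies, either a plain multi-scale resize, or a resize, an absolute-range random crop of 384 to 600 pixels, and a second multi-scale resize. A rough sketch of that per-image policy selection (an assumption about AutoAugment's behavior, not mmdet source code):

import random

def apply_autoaugment(results, policies, build_transform):
    # Pick one policy (a list of transform configs) uniformly at random and
    # apply its transforms in order to the sample dict.
    policy = random.choice(policies)
    for transform_cfg in policy:
        results = build_transform(transform_cfg)(results)
    return results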
23 changes: 23 additions & 0 deletions configs/queryinst/queryinst_r50_fpn_mstrain_480-800_3x_coco.py
@@ -0,0 +1,23 @@
_base_ = './queryinst_r50_fpn_1x_coco.py'

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
min_values = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
type='Resize',
img_scale=[(1333, value) for value in min_values],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'])
]

data = dict(train=dict(pipeline=train_pipeline))
lr_config = dict(policy='step', step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
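
For reference, not part of the diff: the list comprehension above expands to eleven scales with the long edge capped at 1333 and the short edge stepping from 480 to 800 in increments of 32:

min_values = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
img_scales = [(1333, value) for value in min_values]
print(img_scales)  # [(1333, 480), (1333, 512), ..., (1333, 800)]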
3 changes: 2 additions & 1 deletion mmdet/models/detectors/__init__.py
@@ -23,6 +23,7 @@
from .panoptic_fpn import PanopticFPN
from .panoptic_two_stage_segmentor import TwoStagePanopticSegmentor
from .point_rend import PointRend
+from .queryinst import QueryInst
from .reppoints_detector import RepPointsDetector
from .retinanet import RetinaNet
from .rpn import RPN
@@ -46,5 +47,5 @@
'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 'YOLOV3', 'YOLACT',
'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO',
'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
-'TwoStagePanopticSegmentor', 'PanopticFPN'
+'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst'
]
27 changes: 27 additions & 0 deletions mmdet/models/detectors/queryinst.py
@@ -0,0 +1,27 @@
from ..builder import DETECTORS
from .sparse_rcnn import SparseRCNN


@DETECTORS.register_module()
class QueryInst(SparseRCNN):
r"""Implementation of
`Instances as Queries <http://arxiv.org/abs/2105.01928>`_"""

def __init__(self,
backbone,
rpn_head,
roi_head,
train_cfg,
test_cfg,
neck=None,
pretrained=None,
init_cfg=None):
super(QueryInst, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
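
Not part of the diff: QueryInst only re-registers the Sparse R-CNN architecture under a new name, with the mask branch coming entirely from the roi_head config. A minimal sketch of checking the registration by building the detector from the new config, assuming an mmdetection 2.x checkout; the build_detector call mirrors what tools/train.py does:

from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('configs/queryinst/queryinst_r50_fpn_1x_coco.py')
model = build_detector(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
print(type(model).__name__)  # QueryInst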
16 changes: 8 additions & 8 deletions mmdet/models/detectors/sparse_rcnn.py
@@ -10,7 +10,8 @@ class SparseRCNN(TwoStageDetector):

def __init__(self, *args, **kwargs):
super(SparseRCNN, self).__init__(*args, **kwargs)
-assert self.with_rpn, 'Sparse R-CNN do not support external proposals'
+assert self.with_rpn, 'Sparse R-CNN and QueryInst ' \
+'do not support external proposals'

def forward_train(self,
img,
@@ -21,7 +22,7 @@ def forward_train(self,
gt_masks=None,
proposals=None,
**kwargs):
"""Forward function of SparseR-CNN in train stage.
"""Forward function of SparseR-CNN and QueryInst in train stage.

Args:
img (Tensor): of shape (N, C, H, W) encoding input images.
@@ -37,17 +38,16 @@
gt_bboxes_ignore (None | list[Tensor): specify which bounding
boxes can be ignored when computing the loss.
gt_masks (List[Tensor], optional) : Segmentation masks for
-each box. But we don't support it in this architecture.
+each box. This is required to train QueryInst.
proposals (List[Tensor], optional): override rpn proposals with
custom proposals. Use when `with_rpn` is False.

Returns:
dict[str, Tensor]: a dictionary of loss components
"""

-assert proposals is None, 'Sparse R-CNN does not support' \
-' external proposals'
-assert gt_masks is None, 'Sparse R-CNN does not instance segmentation'
+assert proposals is None, 'Sparse R-CNN and QueryInst ' \
+'do not support external proposals'

x = self.extract_feat(img)
proposal_boxes, proposal_features, imgs_whwh = \
@@ -81,14 +81,14 @@ def simple_test(self, img, img_metas, rescale=False):
x = self.extract_feat(img)
proposal_boxes, proposal_features, imgs_whwh = \
self.rpn_head.simple_test_rpn(x, img_metas)
-bbox_results = self.roi_head.simple_test(
+results = self.roi_head.simple_test(
x,
proposal_boxes,
proposal_features,
img_metas,
imgs_whwh=imgs_whwh,
rescale=rescale)
-return bbox_results
+return results

def forward_dummy(self, img):
"""Used for computing network flops.
4 changes: 2 additions & 2 deletions mmdet/models/losses/dice_loss.py
@@ -30,8 +30,8 @@ def dice_loss(pred,
the loss. Defaults to None.
"""

-input = pred.reshape(pred.size()[0], -1)
-target = target.reshape(target.size()[0], -1).float()
+input = pred.flatten(1)
+target = target.flatten(1).float()

a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + eps
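
The change above swaps reshape for flatten(1), a cleaner equivalent that flattens each prediction into one row per RoI. A self-contained sketch (not part of the diff) of the dice term these lines feed into, completed with the standard dice formulation; the full mmdet DiceLoss module additionally handles activation, per-sample weights, and reduction:

import torch

def dice_loss_sketch(pred, target, eps=1e-5):  # eps value as set in the QueryInst config
    # pred: (num_rois, H, W) mask probabilities; target: (num_rois, H, W) binary masks
    input = pred.flatten(1)
    target = target.flatten(1).float()
    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + eps
    c = torch.sum(target * target, 1) + eps
    d = (2 * a) / (b + c)
    return 1 - d  # per-RoI dice loss, near 0 when prediction matches the target

loss = dice_loss_sketch(torch.rand(4, 28, 28), torch.rand(4, 28, 28) > 0.5)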
13 changes: 8 additions & 5 deletions mmdet/models/roi_heads/bbox_heads/dii_head.py
@@ -168,10 +168,10 @@ def forward(self, roi_feat, proposal_feat):
# Self attention
proposal_feat = proposal_feat.permute(1, 0, 2)
proposal_feat = self.attention_norm(self.attention(proposal_feat))
+attn_feats = proposal_feat.permute(1, 0, 2)

# instance interactive
-proposal_feat = proposal_feat.permute(1, 0,
-2).reshape(-1, self.in_channels)
+proposal_feat = attn_feats.reshape(-1, self.in_channels)
proposal_feat_iic = self.instance_interactive_conv(
proposal_feat, roi_feat)
proposal_feat = proposal_feat + self.instance_interactive_conv_dropout(
@@ -189,10 +189,13 @@ def forward(self, roi_feat, proposal_feat):
for reg_layer in self.reg_fcs:
reg_feat = reg_layer(reg_feat)

-cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1)
-bbox_delta = self.fc_reg(reg_feat).view(N, num_proposals, -1)
+cls_score = self.fc_cls(cls_feat).view(
+N, num_proposals, self.num_classes
+if self.loss_cls.use_sigmoid else self.num_classes + 1)
+bbox_delta = self.fc_reg(reg_feat).view(N, num_proposals, 4)

-return cls_score, bbox_delta, obj_feat.view(N, num_proposals, -1)
+return cls_score, bbox_delta, obj_feat.view(
+N, num_proposals, self.in_channels), attn_feats

@force_fp32(apply_to=('cls_score', 'bbox_pred'))
def loss(self,
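
Two things change in DIIHead.forward above: the output shapes are spelled out explicitly (num_classes logits when the classification loss is sigmoid-based, num_classes + 1 with a softmax background column, and exactly 4 box deltas), and the post-attention features attn_feats are returned so the mask branch can reuse them. A toy shape check with illustrative values, not the real module:

import torch
import torch.nn as nn

N, num_proposals, in_channels, num_classes = 2, 100, 256, 80
use_sigmoid = True  # FocalLoss in the QueryInst config is sigmoid-based

cls_out = num_classes if use_sigmoid else num_classes + 1
fc_cls = nn.Linear(in_channels, cls_out)
fc_reg = nn.Linear(in_channels, 4)

cls_feat = torch.randn(N * num_proposals, in_channels)
reg_feat = torch.randn(N * num_proposals, in_channels)

cls_score = fc_cls(cls_feat).view(N, num_proposals, cls_out)  # (2, 100, 80)
bbox_delta = fc_reg(reg_feat).view(N, num_proposals, 4)       # (2, 100, 4)
attn_feats = torch.randn(N, num_proposals, in_channels)       # handed to DynamicMaskHead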
4 changes: 3 additions & 1 deletion mmdet/models/roi_heads/mask_heads/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .coarse_mask_head import CoarseMaskHead
from .dynamic_mask_head import DynamicMaskHead
from .fcn_mask_head import FCNMaskHead
from .feature_relay_head import FeatureRelayHead
from .fused_semantic_head import FusedSemanticHead
@@ -14,5 +15,6 @@
__all__ = [
'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead', 'SCNetMaskHead',
-'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead'
+'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead',
+'DynamicMaskHead'
]