Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support QueryInst #6050

Merged
merged 25 commits into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d3467e1
impl queryinst
vealocia Aug 18, 2021
39b3eb9
bug free queryinst with crop and negative samples
vealocia Aug 24, 2021
b62e491
use detr hyperparameters
vealocia Sep 1, 2021
3e87716
Merge https://github.com.cnpmjs.org/open-mmlab/mmdetection into detraug
vealocia Sep 5, 2021
062a085
pre-commit hooks
vealocia Sep 7, 2021
e157c8d
modified dynamic_mask_head docstrings
vealocia Sep 7, 2021
d669826
remove unused dropout in dynamic_mask_head
vealocia Sep 7, 2021
35a4d34
add docstring for dice_loss
vealocia Sep 7, 2021
fbb6e3c
add dice_loss unit test
vealocia Sep 7, 2021
d9b7025
impl unit test for dynamic_mask_head
vealocia Sep 7, 2021
51e5fa5
update queryinst docstring and implementation
vealocia Sep 8, 2021
cc838dc
stability update for dice_loss and dynamic_mask_head
vealocia Sep 10, 2021
be0cfed
Merge branch 'open-mmlab:master' into queryinst
vealocia Sep 12, 2021
b5efc5c
Merge branch 'open-mmlab:master' into queryinst
vealocia Sep 22, 2021
3de895d
update for clarify
vealocia Sep 22, 2021
f6431de
Merge branch 'open-mmlab:master' into queryinst
vealocia Sep 23, 2021
1e3b916
Merge remote-tracking branch 'mmdet/master' into queryinst
vealocia Oct 2, 2021
dd59078
Merge branch 'open-mmlab:master' into queryinst
vealocia Oct 19, 2021
0609df4
bug free in case of num_proposals equal to zero
vealocia Oct 19, 2021
b25d01a
detail docstrings
vealocia Oct 19, 2021
6b04e60
fixed CI issues
vealocia Oct 20, 2021
8de6779
Merge branch 'open-mmlab:master' into queryinst
vealocia Oct 20, 2021
557dbf2
issues resolved
vealocia Oct 22, 2021
37b28ed
Merge branch 'queryinst' of github.com:vealocia/mmdetection into quer…
vealocia Oct 22, 2021
e229d4f
add queryinst docs
vealocia Oct 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# ResNet-101 variant of the 300-proposal, crop + multi-scale 3x config.
# Only the backbone fields listed below are overridden; every other
# setting comes from the R-50 base config referenced in `_base_`.
_base_ = './queryinst_r50_fpn_300_proposals_crop_mstrain_480-800_3x_coco.py'

model = dict(
    backbone=dict(
        # Deeper backbone plus the matching torchvision checkpoint.
        depth=101,
        init_cfg=dict(
            type='Pretrained', checkpoint='torchvision://resnet101'),
    ),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# ResNet-101 variant of the multi-scale 3x config.
# Only the backbone fields listed below are overridden; every other
# setting comes from the R-50 base config referenced in `_base_`.
_base_ = './queryinst_r50_fpn_mstrain_480-800_3x_coco.py'

model = dict(
    backbone=dict(
        # Deeper backbone plus the matching torchvision checkpoint.
        depth=101,
        init_cfg=dict(
            type='Pretrained', checkpoint='torchvision://resnet101'),
    ),
)
125 changes: 125 additions & 0 deletions configs/queryinst/queryinst_r50_fpn_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Base 1x (12-epoch) config with a ResNet-50 + FPN backbone, a multi-stage
# sparse RoI head, and one box head and one dynamic mask head per stage.
_base_ = [
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# Number of refinement stages; the per-stage head / train_cfg lists below
# are all replicated `num_stages` times.
num_stages = 6
# Number of proposals; also reused as `max_per_img` at test time.
num_proposals = 100
model = dict(
    type='SparseRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        add_extra_convs='on_input',
        num_outs=4),
    rpn_head=dict(
        type='EmbeddingRPNHead',
        num_proposals=num_proposals,
        proposal_feature_channel=256),
    roi_head=dict(
        type='SparseRoIHead',
        num_stages=num_stages,
        stage_loss_weights=[1] * num_stages,
        proposal_feature_channel=256,
        # Box branch pools 7x7 RoI features ...
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        # ... while the mask branch pools 14x14 RoI features.
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=2),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        # One independent box-head config per stage.
        bbox_head=[
            dict(
                type='DIIHead',
                num_classes=80,
                num_ffn_fcs=2,
                num_heads=8,
                num_cls_fcs=1,
                num_reg_fcs=3,
                feedforward_channels=2048,
                in_channels=256,
                dropout=0.0,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                dynamic_conv_cfg=dict(
                    type='DynamicConv',
                    in_channels=256,
                    feat_channels=64,
                    out_channels=256,
                    # Matches bbox_roi_extractor's output_size.
                    input_feat_shape=7,
                    act_cfg=dict(type='ReLU', inplace=True),
                    norm_cfg=dict(type='LN')),
                loss_bbox=dict(type='L1Loss', loss_weight=5.0),
                loss_iou=dict(type='GIoULoss', loss_weight=2.0),
                loss_cls=dict(
                    type='FocalLoss',
                    use_sigmoid=True,
                    gamma=2.0,
                    alpha=0.25,
                    loss_weight=2.0),
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    clip_border=False,
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.5, 0.5, 1., 1.]))
            for _ in range(num_stages)
        ],
        # One independent mask-head config per stage.
        mask_head=[
            dict(
                type='DynamicMaskHead',
                dynamic_conv_cfg=dict(
                    type='DynamicConv',
                    in_channels=256,
                    feat_channels=64,
                    out_channels=256,
                    # Matches mask_roi_extractor's output_size.
                    input_feat_shape=14,
                    with_proj=False,
                    act_cfg=dict(type='ReLU', inplace=True),
                    norm_cfg=dict(type='LN')),
                dropout=0.0,
                num_convs=4,
                num_classes=80,
                roi_feat_size=14,
                in_channels=256,
                conv_kernel_size=3,
                conv_out_channels=256,
                class_agnostic=False,
                norm_cfg=dict(type='BN'),
                upsample_cfg=dict(type='deconv', scale_factor=2),
                loss_dice=dict(type='DiceLoss', loss_weight=8.0))
            for _ in range(num_stages)
        ]),
    # training and testing settings
    train_cfg=dict(
        rpn=None,
        # One assigner/sampler config per stage. The matching-cost weights
        # (2.0 / 5.0 / 2.0) mirror the cls / L1 / GIoU loss weights above.
        rcnn=[
            dict(
                assigner=dict(
                    type='HungarianAssigner',
                    cls_cost=dict(type='FocalLossCost', weight=2.0),
                    reg_cost=dict(type='BBoxL1Cost', weight=5.0),
                    iou_cost=dict(type='IoUCost', iou_mode='giou',
                                  weight=2.0)),
                sampler=dict(type='PseudoSampler'),
                pos_weight=1,
                mask_size=28,
            ) for _ in range(num_stages)
        ]),
    test_cfg=dict(
        rpn=None,
        rcnn=dict(max_per_img=num_proposals, mask_thr_binary=0.5)))

# optimizer
# `_delete_=True` replaces the inherited optimizer settings instead of
# merging into them.
optimizer = dict(_delete_=True, type='AdamW', lr=0.000025, weight_decay=0.0001)
optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=1, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# 300-proposal variant of the multi-scale 3x config, with an additional
# resize -> random-crop -> resize augmentation branch.
_base_ = './queryinst_r50_fpn_mstrain_480-800_3x_coco.py'
num_proposals = 300
model = dict(
    rpn_head=dict(num_proposals=num_proposals),
    # `_delete_=True` replaces the inherited test_cfg rather than merging.
    test_cfg=dict(
        _delete_=True,
        rpn=None,
        rcnn=dict(max_per_img=num_proposals, mask_thr_binary=0.5)))
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR.
# NOTE: the 11 target scales (480, 512, ..., 800) paired with 1333 are
# generated inline; a module-level helper name would show up as an extra
# top-level config entry.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='AutoAugment',
        policies=[
            # Policy 1: a single resize to one of the 11 scales.
            [
                dict(
                    type='Resize',
                    img_scale=[(size, 1333) for size in range(480, 801, 32)],
                    multiscale_mode='value',
                    keep_ratio=True),
            ],
            # Policy 2: coarse resize, random crop, then the same 11-scale
            # resize applied with override=True.
            [
                dict(
                    type='Resize',
                    img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True),
                dict(
                    type='Resize',
                    img_scale=[(size, 1333) for size in range(480, 801, 32)],
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True),
            ],
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'])
]
data = dict(train=dict(pipeline=train_pipeline))
23 changes: 23 additions & 0 deletions configs/queryinst/queryinst_r50_fpn_mstrain_480-800_3x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Multi-scale 3x (36-epoch) schedule layered on top of the 1x base config.
_base_ = './queryinst_r50_fpn_1x_coco.py'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# 480, 512, ..., 800 in steps of 32; each value is paired with 1333 below.
min_values = tuple(range(480, 801, 32))
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(1333, min_value) for min_value in min_values],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'])
]

data = dict(train=dict(pipeline=train_pipeline))
# LR steps at epochs 27 and 33 of a 36-epoch run.
lr_config = dict(policy='step', step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
5 changes: 2 additions & 3 deletions mmdet/models/detectors/sparse_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def forward_train(self,

assert proposals is None, 'Sparse R-CNN does not support' \
' external proposals'
assert gt_masks is None, 'Sparse R-CNN does not instance segmentation'

x = self.extract_feat(img)
proposal_boxes, proposal_features, imgs_whwh = \
Expand Down Expand Up @@ -80,14 +79,14 @@ def simple_test(self, img, img_metas, rescale=False):
x = self.extract_feat(img)
proposal_boxes, proposal_features, imgs_whwh = \
self.rpn_head.simple_test_rpn(x, img_metas)
bbox_results = self.roi_head.simple_test(
results = self.roi_head.simple_test(
vealocia marked this conversation as resolved.
Show resolved Hide resolved
x,
proposal_boxes,
proposal_features,
img_metas,
imgs_whwh=imgs_whwh,
rescale=rescale)
return bbox_results
return results

def forward_dummy(self, img):
"""Used for computing network flops.
Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss, sigmoid_focal_loss
from .gaussian_focal_loss import GaussianFocalLoss
from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss
Expand All @@ -26,5 +27,5 @@
'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss',
'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss',
'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss',
'VarifocalLoss', 'KnowledgeDistillationKLDivLoss', 'SeesawLoss'
'VarifocalLoss', 'KnowledgeDistillationKLDivLoss', 'SeesawLoss', 'DiceLoss'
]
37 changes: 37 additions & 0 deletions mmdet/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def dice_loss(pred, target, eps=1e-5):
    """Compute a per-instance dice loss.

    Both inputs are flattened to ``(n_inst, -1)`` where ``n_inst`` is
    ``pred.size(0)``, and for every instance the loss is::

        1 - 2 * sum(pred * target)
              / (sum(pred ** 2) + sum(target ** 2) + eps)

    Reduction is not performed here; ``DiceLoss.forward`` calls the
    decorated function with ``reduction`` and ``avg_factor``, which are
    handled by the ``weighted_loss`` decorator.

    Args:
        pred (Tensor): Predictions; the first dimension indexes instances.
            NOTE(review): presumably already in [0, 1] (e.g. after a
            sigmoid) -- no activation is applied here; confirm at caller.
        target (Tensor): Targets with the same number of elements per
            instance as ``pred``.
        eps (float): Constant added to the denominator so that it is never
            zero. Defaults to 1e-5.

    Returns:
        Tensor: Unreduced loss of shape ``(n_inst,)``.
    """
    n_inst = pred.size(0)

    def _flatten_per_instance(tensor):
        # `reshape(n_inst, -1)` raises when n_inst == 0 because the inferred
        # dimension is ambiguous, so the trailing size is spelled out. For
        # n_inst > 0 this is exactly what `-1` would have resolved to.
        return tensor.reshape(n_inst, tensor.numel() // max(n_inst, 1))

    pred = _flatten_per_instance(pred)
    target = _flatten_per_instance(target)
    intersection = (pred * target).sum(dim=1)
    union = (pred**2.0).sum(dim=1) + (target**2.0).sum(dim=1) + eps
    return 1. - (2 * intersection / union)

@LOSSES.register_module()
class DiceLoss(nn.Module):
    """Registered loss module that wraps :func:`dice_loss`.

    Args:
        reduction (str): Reduction used when ``forward`` receives no
            ``reduction_override``. Defaults to 'mean'.
        loss_weight (float): Scalar the computed loss is multiplied by.
            Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(DiceLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted dice loss.

        Args:
            pred (Tensor): Predictions, forwarded unchanged to
                ``dice_loss``.
            target (Tensor): Targets, forwarded unchanged to ``dice_loss``.
            avg_factor (int | float, optional): Forwarded to ``dice_loss``
                as ``avg_factor``. Defaults to None.
            reduction_override (str, optional): One of 'none', 'mean' or
                'sum'; when given (and truthy) it replaces the
                ``reduction`` set at construction. Defaults to None.

        Returns:
            Tensor: ``loss_weight`` times the result of ``dice_loss``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_dice = self.loss_weight * dice_loss(
            pred, target, reduction=reduction, avg_factor=avg_factor)
        return loss_dice
3 changes: 2 additions & 1 deletion mmdet/models/roi_heads/bbox_heads/dii_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def forward(self, roi_feat, proposal_feat):
# Self attention
proposal_feat = proposal_feat.permute(1, 0, 2)
proposal_feat = self.attention_norm(self.attention(proposal_feat))
attn_feats = proposal_feat.permute(1, 0, 2)

# instance interactive
proposal_feat = proposal_feat.permute(1, 0,
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -191,7 +192,7 @@ def forward(self, roi_feat, proposal_feat):
cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1)
bbox_delta = self.fc_reg(reg_feat).view(N, num_proposals, -1)

return cls_score, bbox_delta, obj_feat.view(N, num_proposals, -1)
return cls_score, bbox_delta, obj_feat.view(N, num_proposals, -1), attn_feats

@force_fp32(apply_to=('cls_score', 'bbox_pred'))
def loss(self,
Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/roi_heads/mask_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .coarse_mask_head import CoarseMaskHead
from .dynamic_mask_head import DynamicMaskHead
from .fcn_mask_head import FCNMaskHead
from .feature_relay_head import FeatureRelayHead
from .fused_semantic_head import FusedSemanticHead
Expand All @@ -13,5 +14,5 @@
__all__ = [
'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead', 'SCNetMaskHead',
'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead'
'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead', 'DynamicMaskHead'
]
Loading