Skip to content

Commit

Permalink
Support CO-DETR (#10740)
Browse files Browse the repository at this point in the history
Co-authored-by: huanghaian <huanghaian@localhost.localdomain>
  • Loading branch information
hhaAndroid and huanghaian committed Aug 22, 2023
1 parent b4f62a4 commit c1b8677
Show file tree
Hide file tree
Showing 21 changed files with 3,486 additions and 8 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,12 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
<li><a href="configs/dab_detr">DAB-DETR (ICLR'2022)</a></li>
<li><a href="configs/dino">DINO (ICLR'2023)</a></li>
<li><a href="configs/glip">GLIP (CVPR'2022)</a></li>
<li><a href="configs/ddq">DDQ (CVPR'2023)</a></li>
<li><a href="projects/DiffusionDet">DiffusionDet (ArXiv'2023)</a></li>
<li><a href="projects/EfficientDet">EfficientDet (CVPR'2020)</a></li>
<li><a href="projects/ViTDet">ViTDet (ECCV'2022)</a></li>
<li><a href="projects/Detic">Detic (ECCV'2022)</a></li>
<li><a href="projects/CO-DETR">CO-DETR (ICCV'2023)</a></li>
</ul>
</td>
<td>
Expand All @@ -260,13 +263,15 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
<li><a href="projects/SparseInst">SparseInst (CVPR'2022)</a></li>
<li><a href="configs/rtmdet">RTMDet (ArXiv'2022)</a></li>
<li><a href="configs/boxinst">BoxInst (CVPR'2021)</a></li>
<li><a href="projects/ConvNeXt-V2">ConvNeXt-V2 (ArXiv'2023)</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/panoptic_fpn">Panoptic FPN (CVPR'2019)</a></li>
<li><a href="configs/maskformer">MaskFormer (NeurIPS'2021)</a></li>
<li><a href="configs/mask2former">Mask2Former (ArXiv'2021)</a></li>
<li><a href="configs/XDecoder">XDecoder (CVPR'2023)</a></li>
</ul>
</td>
<td>
Expand Down
5 changes: 5 additions & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,12 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
<li><a href="configs/dab_detr">DAB-DETR (ICLR'2022)</a></li>
<li><a href="configs/dino">DINO (ICLR'2023)</a></li>
<li><a href="configs/glip">GLIP (CVPR'2022)</a></li>
<li><a href="configs/ddq">DDQ (CVPR'2023)</a></li>
<li><a href="projects/DiffusionDet">DiffusionDet (ArXiv'2023)</a></li>
<li><a href="projects/EfficientDet">EfficientDet (CVPR'2020)</a></li>
<li><a href="projects/ViTDet">ViTDet (ECCV'2022)</a></li>
<li><a href="projects/Detic">Detic (ECCV'2022)</a></li>
<li><a href="projects/CO-DETR">CO-DETR (ICCV'2023)</a></li>
</ul>
</td>
<td>
Expand All @@ -261,13 +264,15 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
<li><a href="projects/SparseInst">SparseInst (CVPR'2022)</a></li>
<li><a href="configs/rtmdet">RTMDet (ArXiv'2022)</a></li>
<li><a href="configs/boxinst">BoxInst (CVPR'2021)</a></li>
<li><a href="projects/ConvNeXt-V2">ConvNeXt-V2 (ArXiv'2023)</a></li>
</ul>
</td>
<td>
<ul>
<li><a href="configs/panoptic_fpn">Panoptic FPN (CVPR'2019)</a></li>
<li><a href="configs/maskformer">MaskFormer (NeurIPS'2021)</a></li>
<li><a href="configs/mask2former">Mask2Former (ArXiv'2021)</a></li>
<li><a href="configs/XDecoder">XDecoder (CVPR'2023)</a></li>
</ul>
</td>
<td>
Expand Down
26 changes: 23 additions & 3 deletions mmdet/models/dense_heads/detr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps,
bbox_xyxy_to_cxcywh)
from mmdet.utils import (ConfigType, InstanceList, OptInstanceList,
OptMultiConfig, reduce_mean)
from ..losses import QualityFocalLoss
from ..utils import multi_apply


Expand Down Expand Up @@ -290,8 +292,26 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor,
cls_scores.new_tensor([cls_avg_factor]))
cls_avg_factor = max(cls_avg_factor, 1)

loss_cls = self.loss_cls(
cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
if isinstance(self.loss_cls, QualityFocalLoss):
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().squeeze(1)
scores = label_weights.new_zeros(labels.shape)
pos_bbox_targets = bbox_targets[pos_inds]
pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets)
pos_bbox_pred = bbox_preds.reshape(-1, 4)[pos_inds]
pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred)
scores[pos_inds] = bbox_overlaps(
pos_decode_bbox_pred.detach(),
pos_decode_bbox_targets,
is_aligned=True)
loss_cls = self.loss_cls(
cls_scores, (labels, scores),
label_weights,
avg_factor=cls_avg_factor)
else:
loss_cls = self.loss_cls(
cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

# Compute the average number of gt boxes across all gpus, for
# normalization purposes
Expand Down
29 changes: 26 additions & 3 deletions mmdet/models/dense_heads/dino_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps,
bbox_xyxy_to_cxcywh)
from mmdet.utils import InstanceList, OptInstanceList, reduce_mean
from ..losses import QualityFocalLoss
from ..utils import multi_apply
from .deformable_detr_head import DeformableDETRHead

Expand Down Expand Up @@ -248,8 +250,29 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor,
cls_avg_factor = max(cls_avg_factor, 1)

if len(cls_scores) > 0:
loss_cls = self.loss_cls(
cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
if isinstance(self.loss_cls, QualityFocalLoss):
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().squeeze(1)
scores = label_weights.new_zeros(labels.shape)
pos_bbox_targets = bbox_targets[pos_inds]
pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets)
pos_bbox_pred = dn_bbox_preds.reshape(-1, 4)[pos_inds]
pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred)
scores[pos_inds] = bbox_overlaps(
pos_decode_bbox_pred.detach(),
pos_decode_bbox_targets,
is_aligned=True)
loss_cls = self.loss_cls(
cls_scores, (labels, scores),
weight=label_weights,
avg_factor=cls_avg_factor)
else:
loss_cls = self.loss_cls(
cls_scores,
labels,
label_weights,
avg_factor=cls_avg_factor)
else:
loss_cls = torch.zeros(
1, dtype=cls_scores.dtype, device=cls_scores.device)
Expand Down
32 changes: 32 additions & 0 deletions projects/CO-DETR/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# CO-DETR

> [DETRs with Collaborative Hybrid Assignments Training](https://arxiv.org/abs/2211.12860)
<!-- [ALGORITHM] -->

## Abstract

In this paper, we provide the observation that too few queries assigned as positive samples in DETR with one-to-one set matching leads to sparse supervision on the encoder's output which considerably hurts the discriminative feature learning of the encoder and vice versa for attention learning in the decoder. To alleviate this, we present a novel collaborative hybrid assignments training scheme, namely Co-DETR, to learn more efficient and effective DETR-based detectors from versatile label assignment manners. This new training scheme can easily enhance the encoder's learning ability in end-to-end detectors by training the multiple parallel auxiliary heads supervised by one-to-many label assignments such as ATSS and Faster RCNN. In addition, we conduct extra customized positive queries by extracting the positive coordinates from these auxiliary heads to improve the training efficiency of positive samples in the decoder. In inference, these auxiliary heads are discarded and thus our method introduces no additional parameters and computational cost to the original detector while requiring no hand-crafted non-maximum suppression (NMS). We conduct extensive experiments to evaluate the effectiveness of the proposed approach on DETR variants, including DAB-DETR, Deformable-DETR, and DINO-Deformable-DETR. The state-of-the-art DINO-Deformable-DETR with Swin-L can be improved from 58.5% to 59.5% AP on COCO val. Surprisingly, incorporated with ViT-L backbone, we achieve 66.0% AP on COCO test-dev and 67.9% AP on LVIS val, outperforming previous methods by clear margins with much fewer model sizes.

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/dceaf7ee-cd6c-4be0-b7b1-5b01a7f11724"/>
</div>

## Results and Models

| Model | Backbone | Epochs | Aug | Dataset | box AP | Config | Download |
| :-------: | :------: | :----: | :--: | :---------------------------: | :----: | :--------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Co-DINO | R50 | 12 | LSJ | COCO | 52.0 | [config](configs/codino/co_dino_5scale_r50_lsj_8xb2_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_r50_lsj_8xb2_1x_coco/co_dino_5scale_r50_lsj_8xb2_1x_coco-69a72d67.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_r50_lsj_8xb2_1x_coco/co_dino_5scale_r50_lsj_8xb2_1x_coco_20230818_150457.json) |
| Co-DINO\* | R50 | 12 | DETR | COCO | 52.1 | [config](configs/codino/co_dino_5scale_r50_8xb2_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_r50_1x_coco-7481f903.pth) |
| Co-DINO\* | R50 | 36 | LSJ | COCO | 54.8 | [config](configs/codino/co_dino_5scale_r50_lsj_8xb2_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_lsj_r50_3x_coco-fe5a6829.pth) |
| Co-DINO\* | Swin-L | 12 | DETR | COCO | 58.9 | [config](configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_swin_large_1x_coco-27c13da4.pth) |
| Co-DINO\* | Swin-L | 12 | LSJ | COCO | 59.3 | [config](configs/codino/co_dino_5scale_swin_l_lsj_16xb1_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_lsj_swin_large_1x_coco-3af73af2.pth) |
| Co-DINO\* | Swin-L | 36 | DETR | COCO | 60.0 | [config](configs/codino/co_dino_5scale_swin_l_16xb1_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_swin_large_3x_coco-d7a6d8af.pth) |
| Co-DINO\* | Swin-L | 36 | LSJ | COCO | 60.7 | [config](configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_lsj_swin_large_1x_coco-3af73af2.pth) |
| Co-DINO\* | Swin-L | 16 | DETR | Objects365 pre-trained + COCO | 64.1 | [config](configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_swin_large_16e_o365tococo-614254c9.pth) |

Note

- Models labeled \* were not trained by us; their weights come from the official [CO-DETR](https://github.com/Sense-X/Co-DETR) repository.
- We find that the performance is unstable and may fluctuate by about 0.3 mAP.
- If you want to save GPU memory by enabling checkpointing, please install fairscale first with `pip install fairscale`.
13 changes: 13 additions & 0 deletions projects/CO-DETR/codetr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .co_atss_head import CoATSSHead
from .co_dino_head import CoDINOHead
from .co_roi_head import CoStandardRoIHead
from .codetr import CoDETR
from .transformer import (CoDinoTransformer, DetrTransformerDecoderLayer,
DetrTransformerEncoder, DinoTransformerDecoder)

# Names exported by `from <package> import *`; every class imported above is
# listed here, so importing the package exposes all of them.
__all__ = [
'CoDETR', 'CoDinoTransformer', 'DinoTransformerDecoder', 'CoDINOHead',
'CoATSSHead', 'CoStandardRoIHead', 'DetrTransformerEncoder',
'DetrTransformerDecoderLayer'
]
Loading

0 comments on commit c1b8677

Please sign in to comment.