[Training is in progress] [Feature] Support RT-DETR #10498

Open · wants to merge 21 commits into base: dev-3.x (showing changes from 17 commits)
34 changes: 34 additions & 0 deletions configs/rtdetr/README.md
@@ -0,0 +1,34 @@
# RT-DETR

> [DETRs Beat YOLOs on Real-time Object Detection](https://arxiv.org/abs/2304.08069)

<!-- [ALGORITHM] -->

## Abstract

Recently, end-to-end transformer-based detectors (DETRs) have achieved remarkable performance. However, the issue of the high computational cost of DETRs has not been effectively addressed, limiting their practical application and preventing them from fully exploiting the benefits of no post-processing, such as non-maximum suppression (NMS). In this paper, we first analyze the influence of NMS in modern real-time object detectors on inference speed, and establish an end-to-end speed benchmark. To avoid the inference delay caused by NMS, we propose a Real-Time DEtection TRansformer (RT-DETR), the first real-time end-to-end object detector to our best knowledge. Specifically, we design an efficient hybrid encoder to efficiently process multi-scale features by decoupling the intra-scale interaction and cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexible adjustment of the inference speed by using different decoder layers without the need for retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS. Source code and pretrained models will be available at PaddleDetection.

<div align=center>
<img src="https://user-images.githubusercontent.com/17582080/245363952-196b0a10-d2e8-401c-9132-54b9126e0a33.png"/>
</div>

## Results and Models

| Backbone | Model | Lr schd | box AP | Config | Download |
| :------: | :-----------: | :-----: | :----: | :---------------------------------------: | :---------------------------------------------------------------------------------------------------: |
| R-50 | RT-DETR-R50\* | 72e | 53.1 | [config](./rtdetr_r50vd_8xb2-72e_coco.py) | [model](https://github.com/nijkah/storage/releases/download/v0.0.1/rtdetr_r50vd_6x_coco_mmdet.pth) \| |

### NOTE

Models with \* are converted from the [official repo](https://github.com/PaddlePaddle/PaddleDetection/). The config files of these models are only for inference. We haven't reproduced the training results yet.
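
Since the converted config is inference-only for now, here is a minimal inference sketch using MMDetection 3.x's `DetInferencer` (the config, weight, and image paths below are placeholders; it assumes this branch is installed and the checkpoint has been downloaded locally):

```python
from mmdet.apis import DetInferencer

# Placeholder paths: point `model` at this config and `weights` at the
# checkpoint linked in the table above.
inferencer = DetInferencer(
    model='configs/rtdetr/rtdetr_r50vd_8xb2-72e_coco.py',
    weights='rtdetr_r50vd_6x_coco_mmdet.pth',
    device='cuda:0')

# Runs the 640x640 test pipeline and writes visualizations and predictions.
inferencer('demo/demo.jpg', out_dir='outputs/')
```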

## Citation

```latex
@article{lv2023detrs,
title={Detrs beat yolos on real-time object detection},
author={Lv, Wenyu and Xu, Shangliang and Zhao, Yian and Wang, Guanzhong and Wei, Jinman and Cui, Cheng and Du, Yuning and Dang, Qingqing and Liu, Yi},
journal={arXiv preprint arXiv:2304.08069},
year={2023}
}
```
31 changes: 31 additions & 0 deletions configs/rtdetr/metafile.yml
@@ -0,0 +1,31 @@
Collections:
  - Name: RT-DETR
    Metadata:
      Training Data: COCO
      Training Techniques:
        - AdamW
        - Gradient Clip
      Training Resources: 8x A100 GPUs
      Architecture:
        - ResNet
        - Transformer
    Paper:
      URL: https://arxiv.org/abs/2304.08069
      Title: 'DETRs Beat YOLOs on Real-time Object Detection'
    README: configs/rtdetr/README.md
    Code:
      URL: https://github.com/open-mmlab/mmdetection/blob/f4112c9e5611468ffbd57cfba548fd1289264b52/mmdet/models/detectors/dino.py#L17
      Version: v3.0.0rc6

Models:
  - Name: rtdetr_r50vd_8xb2-72e_coco.py
    In Collection: RT-DETR
    Config: configs/rtdetr/rtdetr_r50vd_8xb2-72e_coco.py
    Metadata:
      Epochs: 72
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 53.1
    Weights: https://github.com/nijkah/storage/releases/download/v0.0.1/rtdetr_r50vd_6x_coco_mmdet.pth
186 changes: 186 additions & 0 deletions configs/rtdetr/rtdetr_r50vd_8xb2-72e_coco.py
@@ -0,0 +1,186 @@
_base_ = [
    '../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
pretrained = 'https://github.com/nijkah/storage/releases/download/v0.0.1/resnet50vd_ssld_v2_pretrained.pth' # noqa

eval_size = (640, 640)
model = dict(
    type='RTDETR',
    num_queries=300,  # num_matching_queries
    with_box_refine=True,
    as_two_stage=True,
    eval_size=eval_size,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[0, 0, 0],
        std=[255, 255, 255],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    backbone=dict(
        type='ResNetV1d',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='HybridEncoder',
        num_encoder_layers=1,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8,
                               dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,  # 1024 for DeformDETR
                ffn_drop=0.0,
                act_cfg=dict(type='GELU'))),
        projector=dict(
            type='ChannelMapper',
            in_channels=[256, 256, 256],
            kernel_size=1,
            out_channels=256,
            act_cfg=None,
            norm_cfg=dict(type='BN'),
            num_outs=3)),  # 0.1 for DeformDETR
    encoder=None,
    decoder=dict(
        num_layers=6,
        eval_idx=-1,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8,
                               dropout=0.0),  # 0.1 for DeformDETR
            cross_attn_cfg=dict(
                embed_dims=256,
                num_levels=3,  # 4 for DeformDETR
                dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,  # 2048 for DINO
                ffn_drop=0.0)),  # 0.1 for DeformDETR
        post_norm_cfg=None),
    positional_encoding=dict(
        num_feats=128,
        normalize=True,
        offset=0.0,  # -0.5 for DeformDETR
        temperature=20),  # 10000 for DeformDETR
    bbox_head=dict(
        type='RTDETRHead',
        num_classes=80,
        loss_cls=dict(
            type='VarifocalLoss',
            use_sigmoid=True,
            use_rtdetr=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),  # 2.0 in DeformDETR
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    dn_cfg=dict(  # TODO: Move to model.train_cfg ?
        label_noise_scale=0.5,
        box_noise_scale=1.0,  # 0.4 for DN-DETR
        group_cfg=dict(dynamic=True, num_groups=None,
                       num_dn_queries=100)),  # TODO: half num_dn_queries
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])),
    test_cfg=dict(max_per_img=300))  # 100 for DeformDETR

# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
# from the default setting in mmdet.
train_pipeline = [
    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='RandomChoiceResize',
        scales=[(480, 480), (512, 512), (544, 544), (576, 576), (608, 608),
                (640, 640), (640, 640), (640, 640), (672, 672), (704, 704),
                (736, 736), (768, 768), (800, 800)],
        keep_ratio=False),
    dict(type='PhotoMetricDistortion'),
    dict(
        type='Expand',
        mean=[123.675, 116.28, 103.53],
        to_rgb=True,
        ratio_range=(1, 2)),
    dict(type='RandomCrop', crop_size=(640, 640)),
    dict(type='RandomFlip', prob=0.5),
    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(type='PackDetInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
    dict(
        type='Resize',
        scale=eval_size,
        keep_ratio=False,
        interpolation='bicubic'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

train_dataloader = dict(
    batch_size=4,
    dataset=dict(
        filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline))

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=0.0001,  # 0.0002 for DeformDETR
        weight_decay=0.0001),
    clip_grad=dict(max_norm=0.1, norm_type=2),
    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})
)  # custom_keys contains sampling_offsets and reference_points in DeformDETR # noqa

# learning policy
max_epochs = 72
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=6)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
        end=2000),
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[100],
        gamma=1.0)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        priority=49),
]
find_unused_parameters = True
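
As a quick way to exercise the config above, it can be driven programmatically through MMEngine's `Runner`. This is a sketch, assuming an MMDetection 3.x environment with this PR installed, COCO prepared under `data/coco`, and a locally downloaded checkpoint; the paths are placeholders:

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Load the RT-DETR config; a work_dir is required by the Runner.
cfg = Config.fromfile('configs/rtdetr/rtdetr_r50vd_8xb2-72e_coco.py')
cfg.work_dir = 'work_dirs/rtdetr_r50vd_8xb2-72e_coco'

# To evaluate the converted checkpoint instead of training from scratch,
# point `load_from` at the downloaded weights (placeholder path).
cfg.load_from = 'rtdetr_r50vd_6x_coco_mmdet.pth'

runner = Runner.from_cfg(cfg)
runner.test()    # COCO bbox evaluation with the test_pipeline defined above
# runner.train() # full 72-epoch schedule; training results are not reproduced yet
```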
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/__init__.py
@@ -38,6 +38,7 @@
from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .rtdetr_head import RTDETRHead
from .rtmdet_head import RTMDetHead, RTMDetSepBNHead
from .rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHead
from .sabl_retina_head import SABLRetinaHead
@@ -66,5 +67,5 @@
    'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'CondInstBboxHead',
    'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead',
    'BoxInstBboxHead', 'BoxInstMaskHead', 'ConditionalDETRHead', 'DINOHead',
    'ATSSVLFusionHead', 'DABDETRHead'
    'ATSSVLFusionHead', 'DABDETRHead', 'RTDETRHead'
]
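
With the export added above, a quick sanity check (assuming this PR branch is installed) is that the new head both imports and resolves through MMDetection's model registry by its type name:

```python
from mmdet.models.dense_heads import RTDETRHead
from mmdet.registry import MODELS

# Importing the package runs the register_module decorators, so the head
# should be retrievable from the registry under its class name.
assert MODELS.get('RTDETRHead') is RTDETRHead
```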