Skip to content

Commit

Permalink
[CodeCamp2023-605] Add new configs of deformable_detr (#10936)
Browse files Browse the repository at this point in the history
  • Loading branch information
RangeKing committed Sep 18, 2023
1 parent 02526bc commit 75c2ada
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 0 deletions.
186 changes: 186 additions & 0 deletions mmdet/configs/deformable_detr/deformable_detr_r50_16xb2_50e_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base

# Pull in the base configs. `backend_args` and `train_dataloader`, which are
# used further down without being defined in this file, come from these
# star-imports. The imports must be indented under `read_base()`.
with read_base():
    from .._base_.datasets.coco_detection import *
    from .._base_.default_runtime import *

from mmcv.transforms import LoadImageFromFile, RandomChoice, RandomChoiceResize
from mmengine.optim.optimizer import OptimWrapper
from mmengine.optim.scheduler import MultiStepLR
from mmengine.runner.loops import EpochBasedTrainLoop, TestLoop, ValLoop
from torch.optim.adamw import AdamW

from mmdet.datasets.transforms import (LoadAnnotations, PackDetInputs,
RandomCrop, RandomFlip, Resize)
from mmdet.models.backbones import ResNet
from mmdet.models.data_preprocessors import DetDataPreprocessor
from mmdet.models.dense_heads import DeformableDETRHead
from mmdet.models.detectors import DeformableDETR
from mmdet.models.losses import FocalLoss, GIoULoss, L1Loss
from mmdet.models.necks import ChannelMapper
from mmdet.models.task_modules import (BBoxL1Cost, FocalLossCost,
HungarianAssigner, IoUCost)

# Deformable DETR detector with a ResNet-50 backbone.
# Box refinement and the two-stage variant are both switched off here.
model = dict(
    type=DeformableDETR,
    num_queries=300,
    num_feature_levels=4,
    with_box_refine=False,
    as_two_stage=False,
    # Per-channel mean/std with BGR -> RGB conversion enabled.
    data_preprocessor=dict(
        type=DetDataPreprocessor,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=1,
    ),
    # ResNet-50 initialised from the torchvision checkpoint. Three outputs
    # are requested (out_indices); BN is configured with requires_grad=False
    # and norm_eval=True.
    backbone=dict(
        type=ResNet,
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
    ),
    # Three input channel sizes (one per backbone output) are mapped to 256
    # channels; num_outs=4 matches num_feature_levels=4 above.
    neck=dict(
        type=ChannelMapper,
        in_channels=[512, 1024, 2048],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4,
    ),
    encoder=dict(  # DeformableDetrTransformerEncoder
        num_layers=6,
        layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
            self_attn_cfg=dict(  # MultiScaleDeformableAttention
                embed_dims=256,
                batch_first=True,
            ),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,
                ffn_drop=0.1,
            ),
        ),
    ),
    decoder=dict(  # DeformableDetrTransformerDecoder
        num_layers=6,
        return_intermediate=True,
        layer_cfg=dict(  # DeformableDetrTransformerDecoderLayer
            self_attn_cfg=dict(  # MultiheadAttention
                embed_dims=256,
                num_heads=8,
                dropout=0.1,
                batch_first=True,
            ),
            cross_attn_cfg=dict(  # MultiScaleDeformableAttention
                embed_dims=256,
                batch_first=True,
            ),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,
                ffn_drop=0.1,
            ),
        ),
        post_norm_cfg=None,
    ),
    positional_encoding=dict(num_feats=128, normalize=True, offset=-0.5),
    # The loss weights (cls 2.0 / L1 5.0 / GIoU 2.0) carry the same values as
    # the matching-cost weights in `train_cfg` below.
    bbox_head=dict(
        type=DeformableDETRHead,
        num_classes=80,
        sync_cls_avg_factor=True,
        loss_cls=dict(
            type=FocalLoss,
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0,
        ),
        loss_bbox=dict(type=L1Loss, loss_weight=5.0),
        loss_iou=dict(type=GIoULoss, loss_weight=2.0),
    ),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type=HungarianAssigner,
            match_costs=[
                dict(type=FocalLossCost, weight=2.0),
                dict(type=BBoxL1Cost, weight=5.0, box_format='xywh'),
                dict(type=IoUCost, iou_mode='giou', weight=2.0),
            ],
        ),
    ),
    test_cfg=dict(max_per_img=100),
)

# Training pipeline.
# NOTE: the image scales and the Pad's size_divisor are different from the
# default settings in mmdet.
train_pipeline = [
    dict(type=LoadImageFromFile, backend_args=backend_args),
    dict(type=LoadAnnotations, with_bbox=True),
    dict(type=RandomFlip, prob=0.5),
    # Two alternative augmentation branches are handed to RandomChoice.
    dict(
        type=RandomChoice,
        transforms=[
            # Branch 1: multi-scale resize only.
            [
                dict(
                    type=RandomChoiceResize,
                    scales=[
                        (480, 1333), (512, 1333), (544, 1333), (576, 1333),
                        (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                        (736, 1333), (768, 1333), (800, 1333),
                    ],
                    resize_type=Resize,
                    keep_ratio=True,
                ),
            ],
            # Branch 2: resize, crop, then the same multi-scale resize as
            # branch 1.
            [
                dict(
                    type=RandomChoiceResize,
                    # The aspect ratio of every image in the train dataset
                    # is < 7 (600 * 7 = 4200); this follows the original
                    # implementation.
                    scales=[(400, 4200), (500, 4200), (600, 4200)],
                    resize_type=Resize,
                    keep_ratio=True,
                ),
                dict(
                    type=RandomCrop,
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True,
                ),
                dict(
                    type=RandomChoiceResize,
                    scales=[
                        (480, 1333), (512, 1333), (544, 1333), (576, 1333),
                        (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                        (736, 1333), (768, 1333), (800, 1333),
                    ],
                    resize_type=Resize,
                    keep_ratio=True,
                ),
            ],
        ],
    ),
    dict(type=PackDetInputs),
]
# Patch the inherited training dataloader: plug in `train_pipeline` and set
# filter_empty_gt=False in the dataset's filter_cfg.
train_dataloader.update(
    dataset=dict(
        filter_cfg=dict(filter_empty_gt=False),
        pipeline=train_pipeline,
    ),
)

# Optimizer: AdamW with gradient clipping (L2 norm capped at 0.1).
# Parameters whose names match 'backbone', 'sampling_offsets' or
# 'reference_points' get an lr multiplier of 0.1.
optim_wrapper = dict(
    type=OptimWrapper,
    optimizer=dict(type=AdamW, lr=2e-4, weight_decay=1e-4),
    clip_grad=dict(max_norm=0.1, norm_type=2),
    paramwise_cfg=dict(
        custom_keys=dict(
            backbone=dict(lr_mult=0.1),
            sampling_offsets=dict(lr_mult=0.1),
            reference_points=dict(lr_mult=0.1),
        ),
    ),
)

# Learning policy: epoch-based training for `max_epochs` epochs with
# val_interval=1; validation and testing use ValLoop / TestLoop.
max_epochs = 50
train_cfg = dict(
    type=EpochBasedTrainLoop,
    max_epochs=max_epochs,
    val_interval=1,
)
val_cfg = dict(type=ValLoop)
test_cfg = dict(type=TestLoop)

# Epoch-based MultiStepLR spanning the whole run (0 .. max_epochs), with a
# single milestone at epoch 40 and gamma=0.1.
param_scheduler = [
    dict(
        type=MultiStepLR,
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[40],
        gamma=0.1,
    ),
]

# NOTE: `auto_scale_lr` is used for automatic LR scaling;
# USERS SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (16 GPUs) x (2 samples per GPU) = 32
auto_scale_lr = dict(base_batch_size=16 * 2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base

# Inherit the baseline Deformable DETR config and enable box refinement
# (the baseline sets `with_box_refine=False`). The import must be indented
# under `read_base()`.
with read_base():
    from .deformable_detr_r50_16xb2_50e_coco import *

model.update(dict(with_box_refine=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base

# Inherit the box-refinement config and additionally enable the two-stage
# variant. The import must be indented under `read_base()`.
with read_base():
    from .deformable_detr_refine_r50_16xb2_50e_coco import *

model.update(dict(as_two_stage=True))

0 comments on commit 75c2ada

Please sign in to comment.