Support GroundingDINO finetune (#10954)

Co-authored-by: huanghaian <huanghaian@sensetime.com>
Johnson-Wang and hhaAndroid committed Sep 26, 2023
1 parent 2457c4e commit 658c19e

Showing 19 changed files with 1,145 additions and 51 deletions.
17 changes: 17 additions & 0 deletions configs/glip/README.md
@@ -39,6 +39,23 @@ configs/glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py \
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/7b450d96-81ac-462a-92bc-0d4ae7b8721c" width="40%"/>
</div>

## NOTE

GLIP uses BERT as its language model, which requires access to https://huggingface.co/. If you encounter connection errors because your environment cannot reach it, download the required files on a machine with internet access, save them locally, and then point the `lang_model_name` field in the config to that local path. Please refer to the following code:

```python
from transformers import BertConfig, BertModel
from transformers import AutoTokenizer

config = BertConfig.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, config=config)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

config.save_pretrained("your path/bert-base-uncased")
model.save_pretrained("your path/bert-base-uncased")
tokenizer.save_pretrained("your path/bert-base-uncased")
```
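
After saving, the config only needs its `lang_model_name` pointed at that directory. A minimal sketch of the override (the path below is the placeholder from the snippet above, not a real location):

```python
# In the GLIP config, swap the Hugging Face hub name for the local
# directory written by save_pretrained above (placeholder path).
lang_model_name = 'your path/bert-base-uncased'
```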

## Results and Models

| Model | Zero-shot or Finetune | COCO mAP | Official COCO mAP | Pre-Train Data | Config | Download |
34 changes: 29 additions & 5 deletions configs/grounding_dino/README.md
@@ -1,6 +1,6 @@
# Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection

-[GLIP: Grounded Language-Image Pre-training](https://arxiv.org/abs/2112.03857)
+[Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection](https://arxiv.org/abs/2303.05499)

<!-- [ALGORITHM] -->

@@ -24,6 +24,25 @@ pip install -r requirements/multimodal.txt
mim install mmdet[multimodal]
```

## NOTE

Grounding DINO uses BERT as its language model, which requires access to https://huggingface.co/. If you encounter connection errors because your environment cannot reach it, download the required files on a machine with internet access, save them locally, and then point the `lang_model_name` field in the config to that local path. Please refer to the following code:

```python
from transformers import BertConfig, BertModel
from transformers import AutoTokenizer

config = BertConfig.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, config=config)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

config.save_pretrained("your path/bert-base-uncased")
model.save_pretrained("your path/bert-base-uncased")
tokenizer.save_pretrained("your path/bert-base-uncased")
```
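
Before training offline, it can help to confirm that the local copy loads without network access; `local_files_only=True` is a standard `transformers` flag, and the path is the placeholder from the snippet above:

```python
from transformers import AutoTokenizer, BertModel

# Load strictly from disk; this fails fast if any file is missing.
tokenizer = AutoTokenizer.from_pretrained(
    'your path/bert-base-uncased', local_files_only=True)
model = BertModel.from_pretrained(
    'your path/bert-base-uncased',
    add_pooling_layer=False,
    local_files_only=True)
print(tokenizer.tokenize('a photo of a cat.'))
```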

## Inference

```shell
cd $MMDETROOT
```

@@ -42,11 +61,16 @@ python demo/image_demo.py \

## Results and Models

-| Model | backbone | COCO mAP | Pre-Train Data | Config | Download |
-| :--------------: | :------: | :------: | :----------------------------------------------: | :------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: |
-| Grounding DINO-T | Swin-T | 48.5 | O365,GoldG,Cap4M | [config](grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth) |
-| Grounding DINO-B | Swin-B | 56.9 | COCO,O365,GoldG,Cap4M,OpenImage,ODinW-35,RefCOCO | [config](grounding_dino_swin-b_pretrain_mixeddata.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth) |
+| Model | Backbone | Style | COCO mAP | Official COCO mAP | Pre-Train Data | Config | Download |
+| :----------------: | :------: | :-------: | :--------: | :---------------: | :----------------------------------------------: | :------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| Grounding DINO-T | Swin-T | Zero-shot | 48.5 | 48.4 | O365,GoldG,Cap4M | [config](grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth) |
+| Grounding DINO-T | Swin-T | Finetune | 58.1(+0.9) | 57.2 | O365,GoldG,Cap4M | [config](grounding_dino_swin-t_finetune_16xb2_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/grounding_dino_swin-t_finetune_16xb2_1x_coco/grounding_dino_swin-t_finetune_16xb2_1x_coco_20230921_152544-5f234b20.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/grounding_dino_swin-t_finetune_16xb2_1x_coco/grounding_dino_swin-t_finetune_16xb2_1x_coco_20230921_152544.log.json) |
+| Grounding DINO-B | Swin-B | Zero-shot | 56.9 | 56.7 | COCO,O365,GoldG,Cap4M,OpenImage,ODinW-35,RefCOCO | [config](grounding_dino_swin-b_pretrain_mixeddata.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth) |
+| Grounding DINO-B | Swin-B | Finetune | 59.7 | | COCO,O365,GoldG,Cap4M,OpenImage,ODinW-35,RefCOCO | [config](grounding_dino_swin-b_finetune_16xb2_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/grounding_dino_swin-b_finetune_16xb2_1x_coco/grounding_dino_swin-b_finetune_16xb2_1x_coco_20230921_153201-f219e0c0.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/grounding_dino_swin-b_finetune_16xb2_1x_coco/grounding_dino_swin-b_finetune_16xb2_1x_coco_20230921_153201.log.json) |
+| Grounding DINO-R50 | R50 | Scratch | 48.9(+0.8) | 48.1 | | [config](grounding_dino_r50_scratch_8xb2_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/grounding_dino_r50_scratch_8xb2_1x_coco/grounding_dino_r50_scratch_1x_coco-fe0002f2.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/grounding_dino_r50_scratch_8xb2_1x_coco/20230922_114218.json) |

Note:

1. The weights corresponding to the zero-shot model are adopted from the official weights and converted using the [script](../../tools/model_converters/groundingdino_to_mmdet.py). We have not retrained the model for the time being.
2. Finetune refers to fine-tuning on the COCO 2017 dataset. The R50 model is trained with 8 NVIDIA RTX 3090 GPUs, while the remaining models are trained with 16 NVIDIA RTX 3090 GPUs. GPU memory usage is approximately 8.5 GB.
3. Our performance is higher than the official model for two reasons: we modified the initialization strategy and introduced a log scaler (see the sketch below).
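
The log scaler in point 3 corresponds to the `log_scale='auto'` entry of `contrastive_cfg` in the new configs. A minimal sketch of the idea, with illustrative names rather than the exact mmdet implementation: the query-to-text similarity logits are scaled by a learnable `exp(log_scale)` before the bias is added.

```python
import torch
import torch.nn as nn

class ContrastiveLogits(nn.Module):
    """Sketch: query/text alignment logits with a learnable log scaler."""

    def __init__(self, log_scale: float = 0.0, bias: bool = True):
        super().__init__()
        # 'auto' in contrastive_cfg roughly means: learn log_scale end-to-end
        self.log_scale = nn.Parameter(torch.tensor(float(log_scale)))
        self.bias = nn.Parameter(torch.zeros([])) if bias else None

    def forward(self, query_feat: torch.Tensor, text_feat: torch.Tensor):
        # (num_queries, C) @ (C, num_tokens) -> (num_queries, num_tokens)
        logits = query_feat @ text_feat.transpose(-1, -2)
        logits = logits * self.log_scale.exp()
        if self.bias is not None:
            logits = logits + self.bias
        return logits
```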
208 changes: 208 additions & 0 deletions configs/grounding_dino/grounding_dino_r50_scratch_8xb2_1x_coco.py
@@ -0,0 +1,208 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
lang_model_name = 'bert-base-uncased'

model = dict(
type='GroundingDINO',
num_queries=900,
with_box_refine=True,
as_two_stage=True,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_mask=False,
),
language_model=dict(
type='BertModel',
name=lang_model_name,
pad_to_max=False,
use_sub_sentence_represent=True,
special_tokens_list=['[CLS]', '[SEP]', '.', '?'],
add_pooling_layer=False,
),
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='ChannelMapper',
in_channels=[512, 1024, 2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
bias=True,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
encoder=dict(
num_layers=6,
num_cp=6,
# visual layer config
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_levels=4, dropout=0.0),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)),
# text layer config
text_layer_cfg=dict(
self_attn_cfg=dict(num_heads=4, embed_dims=256, dropout=0.0),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=1024, ffn_drop=0.0)),
# fusion layer config
fusion_layer_cfg=dict(
v_dim=256,
l_dim=256,
embed_dim=1024,
num_heads=4,
init_values=1e-4),
),
decoder=dict(
num_layers=6,
return_intermediate=True,
layer_cfg=dict(
# query self attention layer
self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
# cross attention layer query to text
cross_attn_text_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
# cross attention layer query to image
cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)),
post_norm_cfg=None),
positional_encoding=dict(
num_feats=128, normalize=True, offset=0.0, temperature=20),
bbox_head=dict(
type='GroundingDINOHead',
num_classes=80,
sync_cls_avg_factor=True,
contrastive_cfg=dict(max_text_len=256, log_scale='auto', bias=True),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0), # 2.0 in DeformDETR
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
dn_cfg=dict( # TODO: Move to model.train_cfg ?
label_noise_scale=0.5,
box_noise_scale=1.0, # 0.4 for DN-DETR
group_cfg=dict(dynamic=True, num_groups=None,
num_dn_queries=100)), # TODO: half num_dn_queries
# training and testing settings
train_cfg=dict(
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='BinaryFocalLossCost', weight=2.0),
dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
dict(type='IoUCost', iou_mode='giou', weight=2.0)
])),
test_cfg=dict(max_per_img=300))

# dataset settings
train_pipeline = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[
[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
                # The ratio of all images in the train dataset is < 7
                # following the original implementation
scales=[(400, 4200), (500, 4200), (600, 4200)],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
]
]),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction', 'text',
'custom_entities'))
]

test_pipeline = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='FixScaleResize', scale=(800, 1333), keep_ratio=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'text', 'custom_entities'))
]

train_dataloader = dict(
dataset=dict(
filter_cfg=dict(filter_empty_gt=False),
pipeline=train_pipeline,
return_classes=True))
val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, return_classes=True))
test_dataloader = val_dataloader

# We did not adopt the official 24e optimizer strategy
# because the results indicate that the current strategy is superior.
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=0.0001, # 0.0002 for DeformDETR
weight_decay=0.0001),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'backbone': dict(lr_mult=0.1)
}))
# learning policy
max_epochs = 12
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[11],
gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)
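
With an mmdetection checkout, the resolved settings of this config (including the `_base_` chain and the `_delete_=True` optimizer override) can be inspected via mmengine; a minimal sketch, assuming it is run from the repository root:

```python
from mmengine.config import Config

# Resolves the _base_ files and applies the _delete_=True override above.
cfg = Config.fromfile(
    'configs/grounding_dino/grounding_dino_r50_scratch_8xb2_1x_coco.py')
print(cfg.model.bbox_head.contrastive_cfg)  # includes log_scale='auto'
print(cfg.optim_wrapper.optimizer.lr)       # 0.0001
```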
17 changes: 17 additions & 0 deletions configs/grounding_dino/grounding_dino_swin-b_finetune_16xb2_1x_coco.py
@@ -0,0 +1,17 @@
_base_ = [
'./grounding_dino_swin-t_finetune_16xb2_1x_coco.py',
]

load_from = 'https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth' # noqa
model = dict(
type='GroundingDINO',
backbone=dict(
pretrain_img_size=384,
embed_dims=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=12,
drop_path_rate=0.3,
patch_norm=True),
neck=dict(in_channels=[256, 512, 1024]),
)
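
This config only overrides the backbone relative to the Swin-T finetune base; everything else, including the COCO dataloaders, is inherited, and `load_from` initializes from the converted zero-shot checkpoint. A minimal single-process launch sketch via mmengine (the reported results used 16 GPUs with the standard distributed scripts; the `work_dir` below is an arbitrary choice, and COCO must be set up as in the base config):

```python
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/grounding_dino/grounding_dino_swin-b_finetune_16xb2_1x_coco.py')
cfg.work_dir = 'work_dirs/grounding_dino_swin-b_finetune'  # arbitrary output dir

runner = Runner.from_cfg(cfg)
runner.train()
```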
