-
Notifications
You must be signed in to change notification settings - Fork 9.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support V3Det: Vast Vocabulary Visual Detection Dataset (ICCV 2023 Or…
…al) (#10938) Co-authored-by: Yuhang Cao <yhcao6@gmail.com> Co-authored-by: myownskyW7 <727032989@qq.com> Co-authored-by: Jiaqi Wang <wjqdev@gmail.com> Co-authored-by: huanghaian <huanghaian@sensetime.com>
- Loading branch information
1 parent
9915a5e
commit 6f85dfe
Showing
27 changed files
with
14,738 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# dataset settings | ||
dataset_type = 'V3DetDataset' | ||
data_root = 'data/V3Det/' | ||
|
||
backend_args = None | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile', backend_args=backend_args), | ||
dict(type='LoadAnnotations', with_bbox=True), | ||
dict( | ||
type='RandomChoiceResize', | ||
scales=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), | ||
(1333, 768), (1333, 800)], | ||
keep_ratio=True), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict(type='PackDetInputs') | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile', backend_args=backend_args), | ||
dict(type='Resize', scale=(1333, 800), keep_ratio=True), | ||
# If you don't have a gt annotation, delete the pipeline | ||
dict(type='LoadAnnotations', with_bbox=True), | ||
dict( | ||
type='PackDetInputs', | ||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', | ||
'scale_factor')) | ||
] | ||
train_dataloader = dict( | ||
batch_size=2, | ||
num_workers=2, | ||
persistent_workers=True, | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
batch_sampler=dict(type='AspectRatioBatchSampler'), | ||
dataset=dict( | ||
type='ClassBalancedDataset', | ||
oversample_thr=1e-3, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file='annotations/v3det_2023_v1_train.json', | ||
data_prefix=dict(img=''), | ||
filter_cfg=dict(filter_empty_gt=True, min_size=4), | ||
pipeline=train_pipeline, | ||
backend_args=backend_args))) | ||
val_dataloader = dict( | ||
batch_size=1, | ||
num_workers=2, | ||
persistent_workers=True, | ||
drop_last=False, | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file='annotations/v3det_2023_v1_val.json', | ||
data_prefix=dict(img=''), | ||
test_mode=True, | ||
pipeline=test_pipeline, | ||
backend_args=backend_args)) | ||
test_dataloader = val_dataloader | ||
|
||
val_evaluator = dict( | ||
type='CocoMetric', | ||
ann_file=data_root + 'annotations/v3det_2023_v1_val.json', | ||
metric='bbox', | ||
format_only=False, | ||
backend_args=backend_args, | ||
use_mp_eval=True, | ||
proposal_nums=[300]) | ||
test_evaluator = val_evaluator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
<p> | ||
<div align="center"> | ||
|
||
# <img src="v3det_icon.jpg" height="25"> V3Det: Vast Vocabulary Visual Detection Dataset | ||
|
||
<div> | ||
<a href='https://myownskyw7.github.io/' target='_blank'>Jiaqi Wang</a>*, | ||
<a href='https://panzhang0212.github.io/' target='_blank'>Pan Zhang</a>*, | ||
Tao Chu*, | ||
Yuhang Cao*, </br> | ||
Yujie Zhou, | ||
<a href='https://wutong16.github.io/' target='_blank'>Tong Wu</a>, | ||
Bin Wang, | ||
Conghui He, | ||
<a href='http://dahua.site/' target='_blank'>Dahua Lin</a></br> | ||
(* equal contribution)</br> | ||
<strong>Accepted to ICCV 2023 (Oral)</strong> | ||
</div> | ||
</p> | ||
<p> | ||
<div> | ||
<strong> | ||
<a href='https://arxiv.org/pdf/2304.03752.pdf' target='_blank'>Paper</a>, | ||
<a href='https://v3det.openxlab.org.cn/' target='_blank'>Dataset</a></br> | ||
</strong> | ||
</div> | ||
</div> | ||
</p> | ||
|
||
<div align=center> | ||
<img width=960 src="https://github.com/open-mmlab/mmdetection/assets/17425982/9c216387-02be-46e6-b0f2-b856f80f6d84"/> | ||
</div> | ||
|
||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
Recent advances in detecting arbitrary objects in the real world are trained and evaluated on object detection datasets with a relatively restricted vocabulary. To facilitate the development of more general visual object detection, we propose V3Det, a vast vocabulary visual detection dataset with precisely annotated bounding boxes on massive images. V3Det has several appealing properties: 1) Vast Vocabulary: It contains bounding boxes of objects from 13,204 categories on real-world images, which is 10 times larger than the existing large vocabulary object detection dataset, e.g., LVIS. 2) Hierarchical Category Organization: The vast vocabulary of V3Det is organized by a hierarchical category tree which annotates the inclusion relationship among categories, encouraging the exploration of category relationships in vast and open vocabulary object detection. 3) Rich Annotations: V3Det comprises precisely annotated objects in 243k images and professional descriptions of each category written by human experts and a powerful chatbot. By offering a vast exploration space, V3Det enables extensive benchmarks on both vast and open vocabulary object detection, leading to new observations, practices, and insights for future research. It has the potential to serve as a cornerstone dataset for developing more general visual perception systems. V3Det is available at https://v3det.openxlab.org.cn/. | ||
|
||
## Prepare Dataset | ||
|
||
Please download and prepare V3Det Dataset at [V3Det Homepage](https://v3det.openxlab.org.cn/) and [V3Det Github](https://github.com/V3Det/V3Det). | ||
|
||
The data includes a training set, a validation set, comprising 13,204 categories. The training set consists of 183,354 images, while the validation set has 29,821 images. The data organization is: | ||
|
||
``` | ||
data/ | ||
images/ | ||
<category_node>/ | ||
|────<image_name>.png | ||
... | ||
... | ||
annotations/ | ||
|────v3det_2023_v1_category_tree.json # Category tree | ||
|────category_name_13204_v3det_2023_v1.txt # Category name | ||
|────v3det_2023_v1_train.json # Train set | ||
|────v3det_2023_v1_val.json # Validation set | ||
``` | ||
|
||
## Results and Models | ||
|
||
| Backbone | Model | Lr schd | box AP | Config | Download | | ||
| :------: | :-------------: | :-----: | :----: | :----------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | | ||
| R-50 | Faster R-CNN | 2x | 25.4 | [config](./faster_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//faster_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x) | | ||
| R-50 | Cascade R-CNN | 2x | 31.6 | [config](./cascade_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//cascade_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x) | | ||
| R-50 | FCOS | 2x | 9.4 | [config](./fcos_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//fcos_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x) | | ||
| R-50 | Deformable-DETR | 50e | 34.4 | [config](./deformable-detr-refine-twostage_r50_8xb4_sample1e-3_v3det_50e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/Deformable_DETR_V3Det_R50) | | ||
| R-50 | DINO | 36e | 33.5 | [config](./dino-4scale_r50_8xb2_sample1e-3_v3det_36e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/DINO_V3Det_R50) | | ||
| Swin-B | Faster R-CNN | 2x | 37.6 | [config](./faster_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//faster_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x) | | ||
| Swin-B | Cascade R-CNN | 2x | 42.5 | [config](./cascade_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//cascade_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x) | | ||
| Swin-B | FCOS | 2x | 21.0 | [config](./fcos_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight//fcos_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x) | | ||
| Swin-B | Deformable-DETR | 50e | 42.5 | [config](./deformable-detr-refine-twostage_swin_16xb2_sample1e-3_v3det_50e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/Deformable_DETR_V3Det_SwinB) | | ||
| Swin-B | DINO | 36e | 42.0 | [config](./dino-4scale_swin_16xb1_sample1e-3_v3det_36e.py) | [model](https://download.openxlab.org.cn/models/V3Det/V3Det/weight/DINO_V3Det_SwinB) | | ||
|
||
## Citation | ||
|
||
```latex | ||
@inproceedings{wang2023v3det, | ||
title = {V3Det: Vast Vocabulary Visual Detection Dataset}, | ||
author = {Wang, Jiaqi and Zhang, Pan and Chu, Tao and Cao, Yuhang and Zhou, Yujie and Wu, Tong and Wang, Bin and He, Conghui and Lin, Dahua}, | ||
booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, | ||
month = {October}, | ||
year = {2023} | ||
} | ||
``` |
171 changes: 171 additions & 0 deletions
171
configs/v3det/cascade_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
_base_ = [ | ||
'../_base_/models/cascade-rcnn_r50_fpn.py', '../_base_/datasets/v3det.py', | ||
'../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py' | ||
] | ||
# model settings | ||
model = dict( | ||
rpn_head=dict( | ||
loss_bbox=dict(_delete_=True, type='L1Loss', loss_weight=1.0)), | ||
roi_head=dict(bbox_head=[ | ||
dict( | ||
type='Shared2FCBBoxHead', | ||
in_channels=256, | ||
fc_out_channels=1024, | ||
roi_feat_size=7, | ||
num_classes=13204, | ||
bbox_coder=dict( | ||
type='DeltaXYWHBBoxCoder', | ||
target_means=[0., 0., 0., 0.], | ||
target_stds=[0.1, 0.1, 0.2, 0.2]), | ||
reg_class_agnostic=True, | ||
cls_predictor_cfg=dict( | ||
type='NormedLinear', tempearture=50, bias=True), | ||
loss_cls=dict( | ||
type='CrossEntropyCustomLoss', | ||
num_classes=13204, | ||
use_sigmoid=True, | ||
loss_weight=1.0), | ||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)), | ||
dict( | ||
type='Shared2FCBBoxHead', | ||
in_channels=256, | ||
fc_out_channels=1024, | ||
roi_feat_size=7, | ||
num_classes=13204, | ||
bbox_coder=dict( | ||
type='DeltaXYWHBBoxCoder', | ||
target_means=[0., 0., 0., 0.], | ||
target_stds=[0.05, 0.05, 0.1, 0.1]), | ||
reg_class_agnostic=True, | ||
cls_predictor_cfg=dict( | ||
type='NormedLinear', tempearture=50, bias=True), | ||
loss_cls=dict( | ||
type='CrossEntropyCustomLoss', | ||
num_classes=13204, | ||
use_sigmoid=True, | ||
loss_weight=1.0), | ||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)), | ||
dict( | ||
type='Shared2FCBBoxHead', | ||
in_channels=256, | ||
fc_out_channels=1024, | ||
roi_feat_size=7, | ||
num_classes=13204, | ||
bbox_coder=dict( | ||
type='DeltaXYWHBBoxCoder', | ||
target_means=[0., 0., 0., 0.], | ||
target_stds=[0.033, 0.033, 0.067, 0.067]), | ||
reg_class_agnostic=True, | ||
cls_predictor_cfg=dict( | ||
type='NormedLinear', tempearture=50, bias=True), | ||
loss_cls=dict( | ||
type='CrossEntropyCustomLoss', | ||
num_classes=13204, | ||
use_sigmoid=True, | ||
loss_weight=1.0), | ||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)) | ||
]), | ||
# model training and testing settings | ||
train_cfg=dict( | ||
rpn_proposal=dict(nms_pre=4000, max_per_img=2000), | ||
rcnn=[ | ||
dict( | ||
assigner=dict( | ||
type='MaxIoUAssigner', | ||
pos_iou_thr=0.5, | ||
neg_iou_thr=0.5, | ||
min_pos_iou=0.5, | ||
match_low_quality=False, | ||
ignore_iof_thr=-1, | ||
perm_repeat_gt_cfg=dict(iou_thr=0.7, perm_range=0.01)), | ||
sampler=dict( | ||
type='RandomSampler', | ||
num=512, | ||
pos_fraction=0.25, | ||
neg_pos_ub=-1, | ||
add_gt_as_proposals=True), | ||
pos_weight=-1, | ||
debug=False), | ||
dict( | ||
assigner=dict( | ||
type='MaxIoUAssigner', | ||
pos_iou_thr=0.6, | ||
neg_iou_thr=0.6, | ||
min_pos_iou=0.6, | ||
match_low_quality=False, | ||
ignore_iof_thr=-1, | ||
perm_repeat_gt_cfg=dict(iou_thr=0.7, perm_range=0.01)), | ||
sampler=dict( | ||
type='RandomSampler', | ||
num=512, | ||
pos_fraction=0.25, | ||
neg_pos_ub=-1, | ||
add_gt_as_proposals=True), | ||
pos_weight=-1, | ||
debug=False), | ||
dict( | ||
assigner=dict( | ||
type='MaxIoUAssigner', | ||
pos_iou_thr=0.7, | ||
neg_iou_thr=0.7, | ||
min_pos_iou=0.7, | ||
match_low_quality=False, | ||
ignore_iof_thr=-1, | ||
perm_repeat_gt_cfg=dict(iou_thr=0.7, perm_range=0.01)), | ||
sampler=dict( | ||
type='RandomSampler', | ||
num=512, | ||
pos_fraction=0.25, | ||
neg_pos_ub=-1, | ||
add_gt_as_proposals=True), | ||
pos_weight=-1, | ||
debug=False) | ||
]), | ||
test_cfg=dict( | ||
rcnn=dict( | ||
score_thr=0.0001, | ||
nms=dict(type='nms', iou_threshold=0.6), | ||
max_per_img=300))) | ||
# dataset settings | ||
train_dataloader = dict(batch_size=4, num_workers=8) | ||
|
||
# training schedule for 1x | ||
max_iter = 68760 * 2 | ||
train_cfg = dict( | ||
_delete_=True, | ||
type='IterBasedTrainLoop', | ||
max_iters=max_iter, | ||
val_interval=max_iter) | ||
|
||
# learning rate | ||
param_scheduler = [ | ||
dict( | ||
type='LinearLR', | ||
start_factor=1.0 / 2048, | ||
by_epoch=False, | ||
begin=0, | ||
end=5000), | ||
dict( | ||
type='MultiStepLR', | ||
begin=0, | ||
end=max_iter, | ||
by_epoch=False, | ||
milestones=[45840 * 2, 63030 * 2], | ||
gamma=0.1) | ||
] | ||
|
||
# optimizer | ||
optim_wrapper = dict( | ||
type='OptimWrapper', | ||
optimizer=dict(_delete_=True, type='AdamW', lr=1e-4 * 1, weight_decay=0.1), | ||
clip_grad=dict(max_norm=35, norm_type=2)) | ||
|
||
# Default setting for scaling LR automatically | ||
# - `enable` means enable scaling LR automatically | ||
# or not by default. | ||
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). | ||
auto_scale_lr = dict(enable=False, base_batch_size=32) | ||
|
||
default_hooks = dict( | ||
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=5730 * 2)) | ||
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=False) |
27 changes: 27 additions & 0 deletions
27
configs/v3det/cascade_rcnn_swinb_fpn_8x4_sample1e-3_mstrain_v3det_2x.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
_base_ = [ | ||
'./cascade_rcnn_r50_fpn_8x4_sample1e-3_mstrain_v3det_2x.py', | ||
] | ||
|
||
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth' # noqa | ||
|
||
# model settings | ||
model = dict( | ||
backbone=dict( | ||
_delete_=True, | ||
type='SwinTransformer', | ||
embed_dims=128, | ||
depths=[2, 2, 18, 2], | ||
num_heads=[4, 8, 16, 32], | ||
window_size=7, | ||
mlp_ratio=4, | ||
qkv_bias=True, | ||
qk_scale=None, | ||
drop_rate=0., | ||
attn_drop_rate=0., | ||
drop_path_rate=0.3, | ||
patch_norm=True, | ||
out_indices=(0, 1, 2, 3), | ||
with_cp=False, | ||
convert_weights=True, | ||
init_cfg=dict(type='Pretrained', checkpoint=pretrained)), | ||
neck=dict(in_channels=[128, 256, 512, 1024])) |
Oops, something went wrong.