-
Notifications
You must be signed in to change notification settings - Fork 9.4k
/
grounding_dino_swin-t_pretrain_obj365_goldg_v3det.py
101 lines (96 loc) · 3.52 KB
/
grounding_dino_swin-t_pretrain_obj365_goldg_v3det.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Inherit the model, optimizer, schedule, default runtime and the base
# train_pipeline from the Objects365 pre-training config; this file only
# overrides the training datasets (O365 + GoldG + V3Det mixture).
_base_ = 'grounding_dino_swin-t_pretrain_obj365.py'
# Objects365 v1 detection data converted to the ODVG (object detection +
# visual grounding) annotation format; uses the label map to turn category
# ids into text.
o365v1_od_dataset = {
    'type': 'ODVGDataset',
    'data_root': 'data/objects365v1/',
    'ann_file': 'o365v1_train_odvg.json',
    'label_map_file': 'o365v1_label_map.json',
    'data_prefix': {'img': 'train/'},
    'filter_cfg': {'filter_empty_gt': False},
    # Reuse the augmentation pipeline from the base config unchanged.
    'pipeline': _base_.train_pipeline,
    'return_classes': True,
    'backend_args': None,
}
# Flickr30K Entities phrase-grounding data (part of the GoldG mixture).
# Grounding annotations carry their own text, so no label map is needed.
flickr30k_dataset = {
    'type': 'ODVGDataset',
    'data_root': 'data/flickr30k_entities/',
    'ann_file': 'final_flickr_separateGT_train_vg.json',
    'label_map_file': None,
    'data_prefix': {'img': 'flickr30k_images/'},
    'filter_cfg': {'filter_empty_gt': False},
    # Reuse the augmentation pipeline from the base config unchanged.
    'pipeline': _base_.train_pipeline,
    'return_classes': True,
    'backend_args': None,
}
# GQA / mixed Visual Genome grounding data with COCO images removed
# (part of the GoldG mixture); like Flickr30K, no label map is needed.
gqa_dataset = {
    'type': 'ODVGDataset',
    'data_root': 'data/gqa/',
    'ann_file': 'final_mixed_train_no_coco_vg.json',
    'label_map_file': None,
    'data_prefix': {'img': 'images/'},
    'filter_cfg': {'filter_empty_gt': False},
    # Reuse the augmentation pipeline from the base config unchanged.
    'pipeline': _base_.train_pipeline,
    'return_classes': True,
    'backend_args': None,
}
# Dedicated training pipeline for V3Det. It mirrors the base augmentation
# scheme (flip + two multi-scale resize branches) but adds its own
# ``RandomSamplingNegPos`` step pointed at the V3Det label map, since V3Det's
# category space is far too large to put every class name in the text prompt.
v3d_train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomChoice',
        transforms=[
            [
                # Branch 1: plain multi-scale resize, short side 480-800
                # with the long side capped at 1333.
                dict(
                    type='RandomChoiceResize',
                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                            (736, 1333), (768, 1333), (800, 1333)],
                    keep_ratio=True)
            ],
            [
                # Branch 2: coarse resize -> random absolute-range crop ->
                # final multi-scale resize (same scales as branch 1).
                dict(
                    type='RandomChoiceResize',
                    # The aspect ratio of every image in the train dataset
                    # is < 7 (4200 / 600), following the original
                    # implementation.
                    scales=[(400, 4200), (500, 4200), (600, 4200)],
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True),
                dict(
                    type='RandomChoiceResize',
                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                            (736, 1333), (768, 1333), (800, 1333)],
                    keep_ratio=True)
            ]
        ]),
    # Drop boxes that became (near-)empty after cropping/resizing.
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
    dict(
        type='RandomSamplingNegPos',
        tokenizer_name=_base_.lang_model_name,
        # Number of negative category names sampled per image when building
        # the text prompt.
        num_sample_negative=85,
        # NOTE: hard-coded absolute path to the V3Det label map — change this
        # if your data root differs from 'data/V3Det/'.
        label_map_file='data/V3Det/annotations/v3det_2023_v1_label_map.json',
        max_tokens=256),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'flip', 'flip_direction', 'text',
                   'custom_entities', 'tokens_positive', 'dataset_mode'))
]
# V3Det detection data (very large vocabulary) in ODVG format, driven by the
# custom pipeline defined above.
v3det_dataset = {
    'type': 'ODVGDataset',
    'data_root': 'data/V3Det/',
    'ann_file': 'annotations/v3det_2023_v1_train_od.json',
    'label_map_file': 'annotations/v3det_2023_v1_label_map.json',
    'data_prefix': {'img': ''},
    'filter_cfg': {'filter_empty_gt': False},
    # NOTE(review): text is not requested from the dataset here — presumably
    # RandomSamplingNegPos in v3d_train_pipeline builds the prompt from the
    # label map instead; confirm against ODVGDataset.
    'need_text': False,
    'pipeline': v3d_train_pipeline,
    'return_classes': True,
    'backend_args': None,
}
# Override only the dataset list of the inherited train_dataloader; the
# wrapper type, sampler, batch size, etc. presumably come from the base
# config — verify there.
train_dataloader = {
    'dataset': {
        'datasets': [
            o365v1_od_dataset,
            flickr30k_dataset,
            gqa_dataset,
            v3det_dataset,
        ],
    },
}