# refcoco.py
import copy
import os
import random

import mmcv  # only used by the commented-out debug visualization below
import numpy as np
import torch
from mmdet.datasets import CocoDataset
from mmdet.datasets.api_wrappers import COCO

from gpt4roi.train.train import preprocess, preprocess_multimodal

QUESTIONS = [
'<spi_descript>',
]
REFG_QUESTIONS = [
'Can you provide me with a detailed description of the region in the picture marked by <spi_descript>?',
"I'm curious about the region represented by <spi_descript> in the picture. Could you describe it in detail?",
'What can you tell me about the region indicated by <spi_descript> in the image?',
"I'd like to know more about the area in the photo labeled <spi_descript>. Can you give me a detailed description?",
'Could you describe the region shown as <spi_descript> in the picture in great detail?',
'What details can you give me about the region outlined by <spi_descript> in the photo?',
'Please provide me with a comprehensive description of the region marked with <spi_descript> in the image.',
'Can you give me a detailed account of the region labeled as <spi_descript> in the picture?',
"I'm interested in learning more about the region represented by <spi_descript> in the photo. Can you describe it in detail?",
'What is the region outlined by <spi_descript> in the picture like? Could you give me a detailed description?',
'Can you provide me with a detailed description of the region in the picture marked by <spi_descript>, please?',
"I'm curious about the region represented by <spi_descript> in the picture. Could you describe it in detail, please?",
'What can you tell me about the region indicated by <spi_descript> in the image, exactly?',
"I'd like to know more about the area in the photo labeled <spi_descript>, please. Can you give me a detailed description?",
'Could you describe the region shown as <spi_descript> in the picture in great detail, please?',
'What details can you give me about the region outlined by <spi_descript> in the photo, please?',
'Please provide me with a comprehensive description of the region marked with <spi_descript> in the image, please.',
'Can you give me a detailed account of the region labeled as <spi_descript> in the picture, please?',
"I'm interested in learning more about the region represented by <spi_descript> in the photo. Can you describe it in detail, please?",
'What is the region outlined by <spi_descript> in the picture like, please? Could you give me a detailed description?',
]
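
# NOTE: '<spi_descript>' is a placeholder. RefCOCO replaces it with '<bbox>'
# and RefCOCOG with 'region{i} <bbox>' when the questions are built in
# train_process_test below.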
class RefCOCO(CocoDataset):
CLASSES = ('object',)
def __init__(self,
tokenizer,
multimodal_cfg=None,
vis_processor=None,
ann_file=None,
img_prefix=None,
add_eos=True,
ignore_instruction=True,
filter_small=False,
test_mode=False,
max_gt_per_img=15,
):
self.multimodal_cfg = multimodal_cfg
self.tokenizer = tokenizer
self.ann_file = ann_file
self.img_prefix = img_prefix
self.vis_processor = vis_processor
self.max_gt_per_img = max_gt_per_img
self.add_eos = add_eos
self.ignore_instruction = ignore_instruction
self.filter_small = filter_small
self.test_mode = test_mode
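
        # The mean/std below are the CLIP image normalization statistics
        # rescaled to the 0-255 pixel range.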
img_norm_cfg = dict(
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(224, 224), keep_ratio=False),
# dict(type='RandomShift', shift_ratio=0.5, max_shift_px=32),
dict(type='FilterAnnotationsFlickr', min_gt_bbox_wh=(2.0, 2.0)),
dict(type='RandomFlip', flip_ratio=0.),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=1),
dict(type='DefaultFormatBundleFlickr'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(224, 224), keep_ratio=False),
dict(type='FilterAnnotationsFlickr', min_gt_bbox_wh=(2.0, 2.0)),
dict(type='RandomFlip', flip_ratio=0.),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=224),
dict(type='DefaultFormatBundleFlickr'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
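
        # The test pipeline differs from the train pipeline only in padding
        # (size_divisor=224 instead of 1); random flipping is disabled in
        # both (flip_ratio=0).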
        pipeline = test_pipeline if test_mode else train_pipeline

        # Both branches of the original test_mode switch used the same paths,
        # so ann_file/img_prefix are identical in train and test mode.
        dataset_cfg = dict(
            ann_file=self.ann_file,
            img_prefix=self.img_prefix,
            test_mode=False,
            pipeline=pipeline)
        # Bypass CocoDataset.__init__ and call CustomDataset.__init__ directly.
        super(CocoDataset, self).__init__(**dataset_cfg)
        # TODO: filter out images smaller than 32 px?
self.num_classes = len(self.CLASSES)
self.id_cap_dict = dict()
self.begin_str = '<image>\n I will provide you with only one region ' \
'containing only one object, although there may be other ' \
'objects present in the image. It is recommended that you ' \
"describe the object's relative position with respect to other " \
'objects in the image, as well as its position within ' \
'the image and its basic attributes.'
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
valid_inds = []
# TODO: obtain images that contain annotation
valid_img_ids = []
for i, img_info in enumerate(self.data_infos):
img_id = self.img_ids[i]
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
valid_img_ids.append(img_id)
self.img_ids = valid_img_ids
return valid_inds
def load_annotations(self, ann_file):
"""Load annotation from COCO style annotation file.
Args:
ann_file (str): Path of annotation file.
Returns:
list[dict]: Annotation info from COCO api.
"""
self.coco = COCO(ann_file)
# The order of returned `cat_ids` will not
# change with the order of the CLASSES
self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.img_ids = self.coco.get_img_ids()
data_infos = []
total_ann_ids = []
num_remove_images = 0
for i in self.img_ids:
info = self.coco.load_imgs([i])[0]
            # Skip images whose caption has fewer than three words.
            if len(info['caption'].split(' ')) < 3:
                num_remove_images += 1
                continue
            # COCO 2014 file names look like 'COCO_train2014_000000xxxxxx.jpg';
            # keep only the trailing '000000xxxxxx.jpg' part.
            info['filename'] = info['file_name'].split('_')[-1]
# convert data type for flickr
info['height'] = int(info['height'])
info['width'] = int(info['width'])
data_infos.append(info)
ann_ids = self.coco.get_ann_ids(img_ids=[i])
total_ann_ids.extend(ann_ids)
assert len(set(total_ann_ids)) == len(
total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
        print(f'Filtered out {num_remove_images} short-caption images from {self.ann_file}')
return data_infos
    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox annotations.

        Args:
            img_info (dict): Info of the image (file name, size, caption).
            ann_info (list[dict]): Annotation info of the image.

        Returns:
            dict: A dict with keys bboxes, labels, caption, bboxes_ignore,
                masks and seg_map. Masks are raw annotations and are not
                decoded into binary masks.
        """
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
gt_masks_ann = []
        img_path = os.path.join(self.img_prefix,
                                img_info['file_name'].split('_')[-1])
        # Cache the caption per file name; img_path is only needed by the
        # commented-out debug visualization below.
        self.id_cap_dict[img_info['file_name'].split('_')[-1]] = img_info['caption']
for i, ann in enumerate(ann_info):
if ann.get('ignore', False):
continue
            x1, y1, w, h = ann['bbox']
            # Drop boxes that fall completely outside the image.
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]  # xywh -> xyxy
            gt_bboxes.append(bbox)
            # The image caption (the referring expression) is used as the
            # label for every box.
            gt_labels.append(img_info['caption'])
if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
        # Debug visualization:
        # mmcv.imshow_bboxes(img_path, gt_bboxes, win_name=img_info['caption'])
if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
seg_map = img_info['filename'].replace('jpg', 'png')
ann = dict(
bboxes=gt_bboxes,
labels=gt_labels,
caption=img_info['caption'],
bboxes_ignore=gt_bboxes_ignore,
masks=gt_masks_ann,
seg_map=seg_map)
return ann
def process_text(self, data_item):
        if isinstance(data_item['img'], list):
            # Test mode: the pipeline wraps every field in a list, so unwrap
            # the first element.
            data_item = {k: v[0] for k, v in data_item.items()}
return self.train_process_test(data_item)
def train_process_test(self, data_item):
image = data_item['img'].data
ori_labels = data_item['gt_labels']
ori_bboxes = data_item['gt_bboxes'].data
sources = {'conversations': []}
        # Randomly sample at most max_gt_per_img regions.
        shuffle_ids = torch.randperm(len(ori_labels))
        if len(shuffle_ids) > self.max_gt_per_img:
            shuffle_ids = shuffle_ids[:self.max_gt_per_img]
        select_bboxes = ori_bboxes[shuffle_ids]
        select_labels = [ori_labels[i] for i in shuffle_ids]
for i in range(len(select_labels)):
question = random.choice(QUESTIONS).strip()
question = question.replace('<spi_descript>', '<bbox>')
answer = select_labels[i] # already string
sources['conversations'].append(
{'from': 'human', 'value': question})
sources['conversations'].append({'from': 'gpt', 'value': answer})
        # Prepend the system-style instruction to the first question.
        sources['conversations'][0]['value'] = \
            self.begin_str + sources['conversations'][0]['value']

        # Number of visual tokens for a square image with a 14x14 ViT patch
        # size (224 // 14 = 16, so 256 tokens for a 224x224 input).
        cur_token_len = (image.shape[1] // 14) * (image.shape[2] // 14)
        assert image.shape[1] == image.shape[2]

        # preprocess_multimodal expects a batch, hence the extra list level.
        sources = preprocess_multimodal(
            copy.deepcopy([sources['conversations']]),
            self.multimodal_cfg,
            cur_token_len)
data_dict = preprocess(
sources,
self.tokenizer)
        # Unwrap the batch dimension added by preprocess.
        data_dict = dict(input_ids=data_dict['input_ids'][0],
                         labels=data_dict['labels'][0])
        data_dict['image'] = image

        # Normalize the selected boxes to [0, 1] by the (square) image size.
        ori_bboxes = copy.deepcopy(select_bboxes) / image.shape[1]
        data_dict['bboxes'] = ori_bboxes
data_dict['img_metas'] = data_item['img_metas'].data
return data_dict
def __getitem__(self, idx):
data_item = super().__getitem__(idx)
        # Resample a random image (at most max_loops times) until one with
        # ground-truth labels is found.
        max_loops = 10
        i = 0
        while True:
            if i > max_loops:
                raise ValueError('No gt_labels')
i += 1
if len(data_item['gt_labels']) == 0:
idx = random.randint(0, len(self) - 1)
data_item = super().__getitem__(idx)
else:
break
        # data_dict contains image, input_ids, labels, bboxes and img_metas.
data_dict = self.process_text(data_item=data_item)
return data_dict

class RefCOCOP(RefCOCO):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.begin_str = '<image>\n I will provide you with only one region ' \
                         'containing only one object, although there may be other ' \
                         'objects present in the image. It is recommended that you ' \
                         "describe the object's relative position with respect to other " \
                         'objects in the image and its basic attributes; you should ' \
                         'not give its position within the image.'

class RefCOCOG(RefCOCO):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alternative prompt (unused):
        # "<image>\n I will provide you with only one region containing only
        # one object, although there may be other objects present in the
        # image. It is recommended that you describe the region in detail."
self.begin_str = """The <image> provides an overview of the picture.\n"""
def train_process_test(self, data_item):
image = data_item['img'].data
ori_labels = data_item['gt_labels']
ori_bboxes = data_item['gt_bboxes'].data
sources = {'conversations': []}
        # Randomly sample at most max_gt_per_img regions.
        shuffle_ids = torch.randperm(len(ori_labels))
        if len(shuffle_ids) > self.max_gt_per_img:
            shuffle_ids = shuffle_ids[:self.max_gt_per_img]
        select_bboxes = ori_bboxes[shuffle_ids]
        select_labels = [ori_labels[i] for i in shuffle_ids]
for i in range(len(select_labels)):
question = random.choice(REFG_QUESTIONS).strip()
question = question.replace('<spi_descript>', f'region{i+1} <bbox>')
answer = select_labels[i] # already string
sources['conversations'].append(
{'from': 'human', 'value': question})
sources['conversations'].append({'from': 'gpt', 'value': answer})
        # Prepend the system-style instruction to the first question.
        sources['conversations'][0]['value'] = \
            self.begin_str + sources['conversations'][0]['value']

        # Number of visual tokens for a square image with a 14x14 ViT patch
        # size.
        cur_token_len = (image.shape[1] // 14) * (image.shape[2] // 14)
        assert image.shape[1] == image.shape[2]

        # preprocess_multimodal expects a batch, hence the extra list level.
        sources = preprocess_multimodal(
            copy.deepcopy([sources['conversations']]),
            self.multimodal_cfg,
            cur_token_len)
data_dict = preprocess(
sources,
self.tokenizer)
        # Unwrap the batch dimension added by preprocess.
        data_dict = dict(input_ids=data_dict['input_ids'][0],
                         labels=data_dict['labels'][0])
        data_dict['image'] = image

        # Normalize the selected boxes to [0, 1] by the (square) image size.
        ori_bboxes = copy.deepcopy(select_bboxes) / image.shape[1]
        data_dict['bboxes'] = ori_bboxes
        data_dict['img_metas'] = data_item['img_metas'].data
return data_dict
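

if __name__ == '__main__':
    # Usage sketch (illustrative only, not part of the original file). The
    # tokenizer checkpoint, annotation path and image directory below are
    # placeholders, and the multimodal_cfg keys are assumptions based on the
    # LLaVA-style preprocess_multimodal this repo builds on.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('path/to/llama-tokenizer')
    dataset = RefCOCOG(
        tokenizer=tokenizer,
        multimodal_cfg=dict(
            is_multimodal=True,
            sep_image_conv_front=False,
            image_token_len=256,  # (224 // 14) ** 2 visual tokens
            use_im_start_end=True),
        ann_file='data/refcocog/annotations.json',  # placeholder path
        img_prefix='data/coco/train2014',           # placeholder path
        max_gt_per_img=15)

    sample = dataset[0]
    print(sample['input_ids'].shape, sample['bboxes'].shape)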