Skip to content

Commit

Permalink
[Fix] fix coco dataset for year 2017 (#1948)
Browse files Browse the repository at this point in the history
* fix coco dataset for 2017

* add caption style for prompt

---------

Co-authored-by: LeoXing1996 <xingzn1996@hotmail.com>
  • Loading branch information
liuwenran and LeoXing1996 committed Aug 4, 2023
1 parent a493039 commit 5b5f895
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
18 changes: 14 additions & 4 deletions mmagic/datasets/mscoco_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class MSCoCoDataset(BasicConditionalDataset):
the dataset is needed, so it is not necessary to load the annotation
file. ``BaseDataset`` can skip loading annotations to save time by
setting ``lazy_init=True``. Defaults to False.
caption_style (str): If you want to add a style description for each
caption, you can set caption_style to your style prompt. For
example, 'realistic style'. Defaults to empty str.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
METAINFO = dict(dataset_type='text_image_dataset', task_name='editing')
Expand All @@ -51,14 +54,19 @@ def __init__(self,
'.bmp', '.pgm', '.tif'),
lazy_init: bool = False,
classes: Union[str, Sequence[str], None] = None,
caption_style: str = '',
**kwargs):
ann_file = os.path.join('annotations', 'captions_' + phase +
f'{year}.json') if ann_file == '' else ann_file
self.image_prename = 'COCO_' + phase + f'{year}_'
self.year = year
assert self.year == 2014 or self.year == 2017, \
'Caption is only supported in 2014 or 2017.'
self.image_prename = ''
if self.year == 2014:
self.image_prename = 'COCO_' + phase + f'{year}_'
self.phase = phase
self.drop_rate = drop_caption_rate
self.year = year
assert self.year == 2014, 'We only support CoCo2014 now.'
self.caption_style = caption_style

super().__init__(
ann_file=ann_file,
Expand Down Expand Up @@ -90,10 +98,12 @@ def add_prefix(filename, prefix=''):
os.path.join(self.phase + str(self.year), image_name),
self.img_prefix)
caption = item['caption'].lower()
if self.caption_style != '':
caption = caption + ' ' + self.caption_style
info = {
'img_path':
img_path,
'gt_label':
'gt_prompt':
caption if (self.phase != 'train' or self.drop_rate < 1e-6
or random.random() >= self.drop_rate) else ''
}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_datasets/test_mscoco_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_mscoco(self):
# test basic usage
dataset = MSCoCoDataset(data_root=self.data_root, pipeline=[])
assert dataset[0] == dict(
gt_label='a good meal',
gt_prompt='a good meal',
img_path=os.path.join(self.data_root, 'train2014',
'COCO_train2014_000000000009.jpg'),
sample_idx=0)
Expand All @@ -25,7 +25,7 @@ def test_mscoco(self):
dataset = MSCoCoDataset(
data_root=self.data_root, phase='val', pipeline=[])
assert dataset[0] == dict(
gt_label='a pair of slippers',
gt_prompt='a pair of slippers',
img_path=os.path.join(self.data_root, 'val2014',
'COCO_val2014_000000000042.jpg'),
sample_idx=0)

0 comments on commit 5b5f895

Please sign in to comment.