Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support Concat dataset #1139

Merged
merged 2 commits into from Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 44 additions & 2 deletions mmpose/datasets/builder.py
@@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import random
from functools import partial

import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from mmcv.utils import Registry, build_from_cfg, is_seq_of
from mmcv.utils.parrots_wrapper import _get_dataloader
from torch.utils.data.dataset import ConcatDataset

from .samplers import DistributedSampler

Expand All @@ -24,6 +26,39 @@
PIPELINES = Registry('pipeline')


def _concat_dataset(cfg, default_args=None):
    """Build a :obj:`ConcatDataset` from a config with sequence-valued keys.

    The config is expected to hold a list/tuple in ``ann_file`` (and
    optionally in ``type``, ``img_prefix``, ``dataset_info``,
    ``data_cfg.num_joints`` and ``data_cfg.dataset_channel``). One dataset
    is built per annotation file; keys holding a scalar value are shared
    by every sub-dataset.

    Args:
        cfg (dict): Dataset config in which ``ann_file`` is a list or tuple.
        default_args (dict | None): Default arguments passed through to
            :func:`build_dataset` for every sub-dataset.

    Returns:
        ConcatDataset: Concatenation of all built datasets.
    """
    types = cfg['type']
    ann_files = cfg['ann_file']
    img_prefixes = cfg.get('img_prefix', None)
    dataset_infos = cfg.get('dataset_info', None)

    num_joints = cfg['data_cfg'].get('num_joints', None)
    dataset_channel = cfg['data_cfg'].get('dataset_channel', None)

    datasets = []
    num_dset = len(ann_files)
    for i in range(num_dset):
        # Deep copy so per-dataset overrides below never leak back into
        # the caller's config (or into the other sub-datasets).
        cfg_copy = copy.deepcopy(cfg)
        cfg_copy['ann_file'] = ann_files[i]

        # A key is overridden per-dataset only when it holds a sequence;
        # otherwise the shared value kept by the deepcopy is used.
        if isinstance(types, (list, tuple)):
            cfg_copy['type'] = types[i]
        if isinstance(img_prefixes, (list, tuple)):
            cfg_copy['img_prefix'] = img_prefixes[i]
        if isinstance(dataset_infos, (list, tuple)):
            cfg_copy['dataset_info'] = dataset_infos[i]

        if isinstance(num_joints, (list, tuple)):
            cfg_copy['data_cfg']['num_joints'] = num_joints[i]

        # `dataset_channel` for a single dataset is itself a list of
        # channel lists, so the per-dataset form is a sequence *of lists*
        # (checked with mmcv's is_seq_of, not a plain isinstance).
        if is_seq_of(dataset_channel, list):
            cfg_copy['data_cfg']['dataset_channel'] = dataset_channel[i]

        datasets.append(build_dataset(cfg_copy, default_args))

    return ConcatDataset(datasets)


def build_dataset(cfg, default_args=None):
    """Build a dataset from config dict.

    Four config layouts are supported:

    1. ``cfg`` is a list/tuple of configs: each element is built and the
       results are wrapped in a :obj:`ConcatDataset`.
    2. ``cfg['type'] == 'ConcatDataset'``: the configs under
       ``cfg['datasets']`` are built and concatenated.
    3. ``cfg['type'] == 'RepeatDataset'``: ``cfg['dataset']`` is built
       once and repeated ``cfg['times']`` times.
    4. ``cfg['ann_file']`` is a list/tuple: one dataset per annotation
       file is built via :func:`_concat_dataset` and concatenated.

    Any other config is built directly from the ``DATASETS`` registry.

    Args:
        cfg (dict | list | tuple): Dataset config, or a sequence of them.
        default_args (dict | None): Default arguments forwarded to the
            dataset constructor(s).

    Returns:
        Dataset: The constructed dataset.
    """
    # Imported locally to avoid a circular import between the builder and
    # the dataset wrappers module.
    from .dataset_wrappers import RepeatDataset

    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'ConcatDataset':
        dataset = ConcatDataset(
            [build_dataset(c, default_args) for c in cfg['datasets']])
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)
    return dataset
Expand Down
67 changes: 67 additions & 0 deletions tests/test_datasets/test_dataset_wrapper.py
@@ -0,0 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv import Config

from mmpose.datasets.builder import build_dataset


def test_concat_dataset():
    """Exercise the three ways ``build_dataset`` can produce a ConcatDataset.

    Case 1: explicit ``type='ConcatDataset'`` config.
    Case 2: a plain sequence of dataset configs.
    Case 3: a single config whose ``ann_file`` (and friends) are sequences.
    """
    # build COCO-like dataset config
    dataset_info = Config.fromfile(
        'configs/_base_/datasets/coco.py').dataset_info

    channel_cfg = dict(
        num_output_channels=17,
        dataset_joints=17,
        dataset_channel=[
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        ],
        inference_channel=[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
        ])

    data_cfg = dict(
        image_size=[192, 256],
        heatmap_size=[48, 64],
        num_output_channels=channel_cfg['num_output_channels'],
        num_joints=channel_cfg['dataset_joints'],
        dataset_channel=channel_cfg['dataset_channel'],
        inference_channel=channel_cfg['inference_channel'],
        soft_nms=False,
        nms_thr=1.0,
        oks_thr=0.9,
        vis_thr=0.2,
        use_gt_bbox=True,
        det_bbox_thr=0.0,
        bbox_file='tests/data/coco/test_coco_det_AP_H_56.json',
    )

    dataset_cfg = dict(
        type='TopDownCocoDataset',
        ann_file='tests/data/coco/test_coco.json',
        img_prefix='tests/data/coco/',
        data_cfg=data_cfg,
        pipeline=[],
        dataset_info=dataset_info)

    dataset = build_dataset(dataset_cfg)

    # Case 1: build ConcatDataset explicitly
    concat_dataset_cfg = dict(
        type='ConcatDataset', datasets=[dataset_cfg, dataset_cfg])
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 2: build ConcatDataset from cfg sequence
    concat_dataset = build_dataset([dataset_cfg, dataset_cfg])
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 3: build ConcatDataset from ann_file sequence.
    # NOTE: dataset_cfg.copy() is shallow, so the nested `data_cfg` dict is
    # still shared with `dataset_cfg` (and the already-built `dataset`);
    # copy it too before mutating, to avoid polluting the shared config.
    concat_dataset_cfg = dataset_cfg.copy()
    concat_dataset_cfg['data_cfg'] = dataset_cfg['data_cfg'].copy()
    for key in ['ann_file', 'type', 'img_prefix', 'dataset_info']:
        val = concat_dataset_cfg[key]
        concat_dataset_cfg[key] = [val] * 2
    for key in ['num_joints', 'dataset_channel']:
        val = concat_dataset_cfg['data_cfg'][key]
        concat_dataset_cfg['data_cfg'][key] = [val] * 2
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)