[Fix] fix bug in CombinedDataset and remove RepeatDataset #1930

Merged · 4 commits · Jan 16, 2023
4 changes: 2 additions & 2 deletions mmpose/datasets/__init__.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import build_dataset
-from .dataset_wrappers import CombinedDataset, RepeatDataset
+from .dataset_wrappers import CombinedDataset
 from .datasets import * # noqa
 from .transforms import * # noqa

-__all__ = ['build_dataset', 'RepeatDataset', 'CombinedDataset']
+__all__ = ['build_dataset', 'CombinedDataset']
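
Note: deleting the in-house wrapper does not drop support for repeat-style training; MMEngine ships an equivalent `RepeatDataset` wrapper that resolves through the same `DATASETS` registry mmpose inherits. A minimal config sketch, assuming `mmengine.dataset.RepeatDataset` is reachable that way and using a hypothetical inner dataset config (neither is part of this diff):

# Hypothetical config sketch -- not from this PR.
# Assumes mmengine's RepeatDataset is available via the DATASETS
# registry after the mmpose copy is removed.
train_dataset = dict(
    type='RepeatDataset',  # resolved from MMEngine, not mmpose
    times=5,               # length becomes 5 * len(inner dataset)
    dataset=dict(          # placeholder inner dataset settings
        type='CocoDataset',
        data_root='data/coco/',
        ann_file='annotations/person_keypoints_train2017.json',
    ),
)

The semantics match the class deleted below: index `i` maps to `i % len(inner dataset)`, so a small dataset is iterated `times` times per epoch without reloading between epochs.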
51 changes: 15 additions & 36 deletions mmpose/datasets/dataset_wrappers.py
@@ -10,35 +10,6 @@
 from .datasets.utils import parse_pose_metainfo


-@DATASETS.register_module()
-class RepeatDataset:
-    """A wrapper of repeated dataset.
-
-    The length of repeated dataset will be `times` larger than the original
-    dataset. This is useful when the data loading time is long but the dataset
-    is small. Using RepeatDataset can reduce the data loading time between
-    epochs.
-
-    Args:
-        dataset (:obj:`Dataset`): The dataset to be repeated.
-        times (int): Repeat times.
-    """
-
-    def __init__(self, dataset, times):
-        self.dataset = dataset
-        self.times = times
-
-        self._ori_len = len(self.dataset)
-
-    def __getitem__(self, idx):
-        """Get data."""
-        return self.dataset[idx % self._ori_len]
-
-    def __len__(self):
-        """Length after repetition."""
-        return self.times * self._ori_len
-
-
 @DATASETS.register_module()
 class CombinedDataset(BaseDataset):
     """A wrapper of combined dataset.
@@ -113,10 +84,7 @@ def prepare_data(self, idx: int) -> Any:
             Any: Depends on ``self.pipeline``.
         """

-        subset_idx, sample_idx = self._get_subset_index(idx)
-        # Get data sample from the subset
-        data_info = self.datasets[subset_idx].get_data_info(sample_idx)
-        data_info = self.datasets[subset_idx].pipeline(data_info)
+        data_info = self.get_data_info(idx)

         # Add metainfo items that are required in the pipeline and the model
         metainfo_keys = [
@@ -125,13 +93,24 @@ def prepare_data(self, idx: int) -> Any:
         ]

         for key in metainfo_keys:
-            assert key not in data_info, (
-                f'"{key}" is a reserved key for `metainfo`, but already '
-                'exists in the `data_info`.')
             data_info[key] = deepcopy(self._metainfo[key])

         return self.pipeline(data_info)

+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``CombinedDataset``.
+        Returns:
+            dict: The idx-th annotation of the datasets.
+        """
+        subset_idx, sample_idx = self._get_subset_index(idx)
+        # Get data sample processed by ``subset.pipeline``
+        data_info = self.datasets[subset_idx][sample_idx]
+
+        return data_info
+
     def full_init(self):
         """Fully initialize all sub datasets."""

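
The fix routes sample lookup through the subset's own `__getitem__`, so each sample is produced by `subset.pipeline`, and factors that logic into an overriding `get_data_info`; the combined dataset then adds its shared metainfo keys and runs its joint pipeline. The helper `_get_subset_index` is referenced but lies outside the changed hunks. The sketch below shows how such a global-to-local index mapping is commonly implemented over cumulative subset lengths; names and body are illustrative, not the code from this repository:

from itertools import accumulate
from typing import Sequence, Sized, Tuple

def get_subset_index(datasets: Sequence[Sized], idx: int) -> Tuple[int, int]:
    """Map a global sample index to (subset index, local sample index)."""
    # Cumulative end boundary of each subset, e.g. [100, 250, 400].
    bounds = list(accumulate(len(d) for d in datasets))
    if idx < 0 or idx >= bounds[-1]:
        raise IndexError(f'index {idx} out of range for {bounds[-1]} samples')
    for subset_idx, bound in enumerate(bounds):
        if idx < bound:
            # Subtract the start offset of this subset to get the local index.
            start = bounds[subset_idx - 1] if subset_idx else 0
            return subset_idx, idx - start

For example, `get_subset_index([range(100), range(150)], 120)` returns `(1, 20)`: the 121st global sample is the 21st sample of the second subset, which is exactly the pair the new `get_data_info` uses to index `self.datasets[subset_idx][sample_idx]`.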