From ce8d585a4a0eb07f5814a94a2257d8efd5a52c3e Mon Sep 17 00:00:00 2001
From: Tau-J <674106399@qq.com>
Date: Thu, 12 Jan 2023 11:53:04 +0800
Subject: [PATCH 1/4] fix bug in CombinedDataset

---
 mmpose/datasets/__init__.py         |  4 ++--
 mmpose/datasets/dataset_wrappers.py | 32 ++------------------------------
 2 files changed, 4 insertions(+), 32 deletions(-)

diff --git a/mmpose/datasets/__init__.py b/mmpose/datasets/__init__.py
index 042c1b7e28..001155172b 100644
--- a/mmpose/datasets/__init__.py
+++ b/mmpose/datasets/__init__.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import build_dataset
-from .dataset_wrappers import CombinedDataset, RepeatDataset
+from .dataset_wrappers import CombinedDataset
 from .datasets import *  # noqa
 from .transforms import *  # noqa
 
-__all__ = ['build_dataset', 'RepeatDataset', 'CombinedDataset']
+__all__ = ['build_dataset', 'CombinedDataset']
diff --git a/mmpose/datasets/dataset_wrappers.py b/mmpose/datasets/dataset_wrappers.py
index 2997615afd..2ea335c227 100644
--- a/mmpose/datasets/dataset_wrappers.py
+++ b/mmpose/datasets/dataset_wrappers.py
@@ -10,35 +10,6 @@
 from .datasets.utils import parse_pose_metainfo
 
 
-@DATASETS.register_module()
-class RepeatDataset:
-    """A wrapper of repeated dataset.
-
-    The length of repeated dataset will be `times` larger than the original
-    dataset. This is useful when the data loading time is long but the dataset
-    is small. Using RepeatDataset can reduce the data loading time between
-    epochs.
-
-    Args:
-        dataset (:obj:`Dataset`): The dataset to be repeated.
-        times (int): Repeat times.
-    """
-
-    def __init__(self, dataset, times):
-        self.dataset = dataset
-        self.times = times
-
-        self._ori_len = len(self.dataset)
-
-    def __getitem__(self, idx):
-        """Get data."""
-        return self.dataset[idx % self._ori_len]
-
-    def __len__(self):
-        """Length after repetition."""
-        return self.times * self._ori_len
-
-
 @DATASETS.register_module()
 class CombinedDataset(BaseDataset):
     """A wrapper of combined dataset.
@@ -116,7 +87,8 @@ def prepare_data(self, idx: int) -> Any:
         subset_idx, sample_idx = self._get_subset_index(idx)
         # Get data sample from the subset
         data_info = self.datasets[subset_idx].get_data_info(sample_idx)
-        data_info = self.datasets[subset_idx].pipeline(data_info)
+        if hasattr(self.datasets[subset_idx], 'pipeline'):
+            data_info = self.datasets[subset_idx].pipeline(data_info)
 
         # Add metainfo items that are required in the pipeline and the model
         metainfo_keys = [

From 98468e0c428d2e0b8ee1b9957fe9aa8211dd2bae Mon Sep 17 00:00:00 2001
From: Tau-J <674106399@qq.com>
Date: Thu, 12 Jan 2023 12:52:07 +0800
Subject: [PATCH 2/4] update

---
 mmpose/datasets/dataset_wrappers.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mmpose/datasets/dataset_wrappers.py b/mmpose/datasets/dataset_wrappers.py
index 2ea335c227..04aed5bca6 100644
--- a/mmpose/datasets/dataset_wrappers.py
+++ b/mmpose/datasets/dataset_wrappers.py
@@ -86,9 +86,7 @@ def prepare_data(self, idx: int) -> Any:
 
         subset_idx, sample_idx = self._get_subset_index(idx)
         # Get data sample from the subset
-        data_info = self.datasets[subset_idx].get_data_info(sample_idx)
-        if hasattr(self.datasets[subset_idx], 'pipeline'):
-            data_info = self.datasets[subset_idx].pipeline(data_info)
+        data_info = self.datasets[subset_idx][sample_idx]
 
         # Add metainfo items that are required in the pipeline and the model
         metainfo_keys = [

From 64f680fe1e64f5b6b0613977cdc2665f659a7b9d Mon Sep 17 00:00:00 2001
From: Tau-J <674106399@qq.com>
Date: Thu, 12 Jan 2023 12:59:39 +0800
Subject: [PATCH 3/4] add get_data_info()

---
 mmpose/datasets/dataset_wrappers.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/mmpose/datasets/dataset_wrappers.py b/mmpose/datasets/dataset_wrappers.py
index 04aed5bca6..b4fcbb86f1 100644
--- a/mmpose/datasets/dataset_wrappers.py
+++ b/mmpose/datasets/dataset_wrappers.py
@@ -84,8 +84,19 @@ def prepare_data(self, idx: int) -> Any:
             Any: Depends on ``self.pipeline``.
         """
 
+        data_info = self.get_data_info(idx)
+        return self.pipeline(data_info)
+
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``CombinedDataset``.
+        Returns:
+            dict: The idx-th annotation of the datasets.
+ """ subset_idx, sample_idx = self._get_subset_index(idx) - # Get data sample from the subset + # Get data sample processed by ``self.pipeline`` from the subset data_info = self.datasets[subset_idx][sample_idx] # Add metainfo items that are required in the pipeline and the model @@ -100,7 +111,7 @@ def prepare_data(self, idx: int) -> Any: 'exists in the `data_info`.') data_info[key] = deepcopy(self._metainfo[key]) - return self.pipeline(data_info) + return data_info def full_init(self): """Fully initialize all sub datasets.""" From 7a4ec966f82b5d0005c89bdfca597e1a1bc4e3ef Mon Sep 17 00:00:00 2001 From: Tau-J <674106399@qq.com> Date: Thu, 12 Jan 2023 13:31:50 +0800 Subject: [PATCH 4/4] update --- mmpose/datasets/dataset_wrappers.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/mmpose/datasets/dataset_wrappers.py b/mmpose/datasets/dataset_wrappers.py index b4fcbb86f1..3836100ed2 100644 --- a/mmpose/datasets/dataset_wrappers.py +++ b/mmpose/datasets/dataset_wrappers.py @@ -85,6 +85,16 @@ def prepare_data(self, idx: int) -> Any: """ data_info = self.get_data_info(idx) + + # Add metainfo items that are required in the pipeline and the model + metainfo_keys = [ + 'upper_body_ids', 'lower_body_ids', 'flip_pairs', + 'dataset_keypoint_weights', 'flip_indices' + ] + + for key in metainfo_keys: + data_info[key] = deepcopy(self._metainfo[key]) + return self.pipeline(data_info) def get_data_info(self, idx: int) -> dict: @@ -96,21 +106,9 @@ def get_data_info(self, idx: int) -> dict: dict: The idx-th annotation of the datasets. """ subset_idx, sample_idx = self._get_subset_index(idx) - # Get data sample processed by ``self.pipeline`` from the subset + # Get data sample processed by ``subset.pipeline`` data_info = self.datasets[subset_idx][sample_idx] - # Add metainfo items that are required in the pipeline and the model - metainfo_keys = [ - 'upper_body_ids', 'lower_body_ids', 'flip_pairs', - 'dataset_keypoint_weights', 'flip_indices' - ] - - for key in metainfo_keys: - assert key not in data_info, ( - f'"{key}" is a reserved key for `metainfo`, but already ' - 'exists in the `data_info`.') - data_info[key] = deepcopy(self._metainfo[key]) - return data_info def full_init(self):