Skip to content

Commit

Permalink
[Enhance] CombinedDataset element are now browsed in turn in dataset …
Browse files Browse the repository at this point in the history
…browser (#2985)
  • Loading branch information
drazicmartin committed Mar 20, 2024
1 parent d60c043 commit de67839
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 34 deletions.
3 changes: 2 additions & 1 deletion docs/en/user_guides/prepare_datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ MMPose offers a convenient and versatile solution for training with mixed datase
`tools/analysis_tools/browse_dataset.py` helps the user to browse a pose dataset visually, or save the image to a designated directory.

```shell
python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] [--not-show] [--phase ${PHASE}] [--mode ${MODE}] [--show-interval ${SHOW_INTERVAL}]
python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] [--max-item-per-dataset ${MAX_ITEM_PER_DATASET}] [--not-show] [--phase ${PHASE}] [--mode ${MODE}] [--show-interval ${SHOW_INTERVAL}]
```

| ARGS | Description |
Expand All @@ -138,6 +138,7 @@ python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}]
| `--phase {train, val, test}` | Options for dataset. |
| `--mode {original, transformed}` | Specify the type of visualized images. `original` means to show images without pre-processing; `transformed` means to show images are pre-processed. |
| `--show-interval SHOW_INTERVAL` | Time interval between visualizing two images. |
| `--max-item-per-dataset` | Define the maximum item processed per dataset, default to 50 |

For instance, users who want to visualize images and annotations in COCO dataset use:

Expand Down
3 changes: 2 additions & 1 deletion docs/zh_cn/user_guides/prepare_datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ MMPose 提供了一个方便且多功能的解决方案,用于训练混合数
`tools/analysis_tools/browse_dataset.py` 帮助用户可视化地浏览姿态数据集,或将图像保存到指定的目录。

```shell
python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] [--not-show] [--phase ${PHASE}] [--mode ${MODE}] [--show-interval ${SHOW_INTERVAL}]
python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}] [--max-item-per-dataset ${MAX_ITEM_PER_DATASET}] [--not-show] [--phase ${PHASE}] [--mode ${MODE}] [--show-interval ${SHOW_INTERVAL}]
```

| ARGS | Description |
Expand All @@ -138,6 +138,7 @@ python tools/misc/browse_dataset.py ${CONFIG} [-h] [--output-dir ${OUTPUT_DIR}]
| `--phase {train, val, test}` | 数据集选项 |
| `--mode {original, transformed}` | 指定可视化图片类型。 `original` 为不使用数据增强的原始图片及标注可视化; `transformed` 为经过增强后的可视化 |
| `--show-interval SHOW_INTERVAL` | 显示图片的时间间隔 |
| `--max-item-per-dataset` | 定义每个数据集可视化的最大样本数。默认为 50 |

例如,用户想要可视化 COCO 数据集中的图像和标注,可以使用:

Expand Down
4 changes: 4 additions & 0 deletions mmpose/datasets/dataset_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def __init__(self,
def metainfo(self):
return deepcopy(self._metainfo)

@property
def lens(self):
return deepcopy(self._lens)

def __len__(self):
return self._len

Expand Down
93 changes: 61 additions & 32 deletions tools/misc/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
import argparse
import os
import os.path as osp
from itertools import accumulate

import mmcv
import mmengine
import mmengine.fileio as fileio
import numpy as np
from mmengine import Config, DictAction
from mmengine.registry import build_from_cfg, init_default_scope
from mmengine.structures import InstanceData

from mmpose.datasets import CombinedDataset
from mmpose.registry import DATASETS, VISUALIZERS
from mmpose.structures import PoseDataSample

Expand All @@ -24,6 +25,11 @@ def parse_args():
type=str,
help='If there is no display interface, you can save it.')
parser.add_argument('--not-show', default=False, action='store_true')
parser.add_argument(
'--max-item-per-dataset',
default=50,
type=int,
help='Define the maximum item processed per dataset')
parser.add_argument(
'--phase',
default='train',
Expand Down Expand Up @@ -99,50 +105,73 @@ def main():
visualizer = VISUALIZERS.build(cfg.visualizer)
visualizer.set_dataset_meta(dataset.metainfo)

progress_bar = mmengine.ProgressBar(len(dataset))
if isinstance(dataset, CombinedDataset):

def generate_index_generator(dataset_starting_indexes: list,
max_item_datasets: int):
"""Generates indexes to traverse each dataset element in turn,
based on starting indexes and maximum items per dataset."""
for relative_idx in range(max(max_item_datasets)):
for dataset_idx, dataset_starting_idx in enumerate(
dataset_starting_indexes):
if relative_idx >= max_item_datasets[dataset_idx]:
continue
yield dataset_starting_idx + relative_idx

# Generate starting indexes for each dataset
dataset_starting_indexes = list(accumulate([0] + dataset.lens[:-1]))
max_item_datasets = [
min(dataset_len, args.max_item_per_dataset)
for dataset_len in dataset.lens
]

# Generate indexes using the generator
indexes = generate_index_generator(dataset_starting_indexes,
max_item_datasets)

total = sum(max_item_datasets)
multiple_datasets = True
else:
max_length = min(len(dataset), args.max_item_per_dataset)
indexes = iter(range(max_length))
total = max_length
multiple_datasets = False

idx = 0
item = dataset[0]
progress_bar = mmengine.ProgressBar(total)

while idx < len(dataset):
idx += 1
next_item = None if idx >= len(dataset) else dataset[idx]
for idx in indexes:
item = dataset[idx]

if args.mode == 'original':
if next_item is not None and item['img_path'] == next_item[
'img_path']:
# merge annotations for one image
item['keypoints'] = np.concatenate(
(item['keypoints'], next_item['keypoints']))
item['keypoints_visible'] = np.concatenate(
(item['keypoints_visible'],
next_item['keypoints_visible']))
item['bbox'] = np.concatenate(
(item['bbox'], next_item['bbox']))
progress_bar.update()
continue
img_path = item['img_path']
img_bytes = fileio.get(img_path, backend_args=backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='bgr')
dataset_name = item.get('dataset_name', None)

# forge pseudo data_sample
gt_instances = InstanceData()
gt_instances.keypoints = item['keypoints']
if item['keypoints_visible'].ndim == 3:
gt_instances.keypoints_visible = item['keypoints_visible'][...,
0]
else:
img_path = item['img_path']
img_bytes = fileio.get(img_path, backend_args=backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='bgr')

# forge pseudo data_sample
gt_instances = InstanceData()
gt_instances.keypoints = item['keypoints']
gt_instances.keypoints_visible = item['keypoints_visible']
gt_instances.bboxes = item['bbox']
data_sample = PoseDataSample()
data_sample.gt_instances = gt_instances
gt_instances.bboxes = item['bbox']
data_sample = PoseDataSample()
data_sample.gt_instances = gt_instances

item = next_item
else:
img = item['inputs'].permute(1, 2, 0).numpy()
data_sample = item['data_samples']
img_path = data_sample.img_path
item = next_item
dataset_name = data_sample.metainfo.get('dataset_name', None)

# save image with annotation
output_dir = osp.join(
args.output_dir, dataset_name
) if multiple_datasets and dataset_name else args.output_dir
out_file = osp.join(
args.output_dir,
output_dir,
osp.basename(img_path)) if args.output_dir is not None else None
out_file = generate_dup_file_name(out_file)

Expand Down

0 comments on commit de67839

Please sign in to comment.