Skip to content

Commit

Permalink
Merge 304f011 into 3eca326
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis committed Sep 6, 2023
2 parents 3eca326 + 304f011 commit f42ec26
Show file tree
Hide file tree
Showing 13 changed files with 591 additions and 6 deletions.
118 changes: 118 additions & 0 deletions docs/en/user_guides/train_and_test.md
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,121 @@ Here are the environment variables that can be used to configure the slurm job.
| `GPUS_PER_NODE` | The number of GPUs to be allocated per node. Defaults to 8. |
| `CPUS_PER_TASK` | The number of CPUs to be allocated per task (Usually one GPU corresponds to one task). Defaults to 5. |
| `SRUN_ARGS` | The other arguments of `srun`. Available options can be found [here](https://slurm.schedmd.com/srun.html). |

## Custom Testing Features

### Test with Custom Metrics

If you're looking to assess models using unique metrics not already supported by MMPose, you'll need to code these metrics yourself and include them in your config file. For guidance on how to accomplish this, check out our [customized evaluation guide](https://mmpose.readthedocs.io/en/latest/advanced_guides/customize_evaluation.html).

### Evaluating Across Multiple Datasets

MMPose offers a handy tool known as `MultiDatasetEvaluator` for streamlined assessment across multiple datasets. Setting up this evaluator in your config file is a breeze. Below is a quick example demonstrating how to evaluate a model using both the COCO and AIC datasets:

```python
# Set up validation datasets
coco_val = dict(type='CocoDataset', ...)
aic_val = dict(type='AicDataset', ...)
val_dataset = dict(
type='CombinedDataset',
datasets=[coco_val, aic_val],
pipeline=val_pipeline,
...)

# configurate the evaluator
val_evaluator = dict(
type='MultiDatasetEvaluator',
metrics=[ # metrics for each dataset
dict(type='CocoMetric',
ann_file='data/coco/annotations/person_keypoints_val2017.json'),
dict(type='CocoMetric',
ann_file='data/aic/annotations/aic_val.json',
use_area=False,
prefix='aic')
],
# the number and order of datasets must align with metrics
datasets=[coco_val, aic_val],
)
```

Keep in mind that different datasets, like COCO and AIC, have various keypoint definitions. Yet, the model's output keypoints are standardized. This results in a discrepancy between the model outputs and the actual ground truth. To address this, you can employ `KeypointConverter` to align the keypoint configurations between different datasets. Here’s a full example that shows how to leverage `KeypointConverter` to align AIC keypoints with COCO keypoints:

```python
aic_to_coco_converter = dict(
type='KeypointConverter',
num_keypoints=17,
mapping=[
(0, 6),
(1, 8),
(2, 10),
(3, 5),
(4, 7),
(5, 9),
(6, 12),
(7, 14),
(8, 16),
(9, 11),
(10, 13),
(11, 15),
])

# val datasets
coco_val = dict(
type='CocoDataset',
data_root='data/coco/',
data_mode='topdown',
ann_file='annotations/person_keypoints_val2017.json',
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=[],
)

aic_val = dict(
type='AicDataset',
data_root='data/aic/',
data_mode=data_mode,
ann_file='annotations/aic_val.json',
data_prefix=dict(img='ai_challenger_keypoint_validation_20170911/'
'keypoint_validation_images_20170911/'),
test_mode=True,
pipeline=[],
)

val_dataset = dict(
type='CombinedDataset',
metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
datasets=[coco_val, aic_val],
pipeline=val_pipeline,
test_mode=True,
)

val_dataloader = dict(
batch_size=32,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=val_dataset)

test_dataloader = val_dataloader

val_evaluator = dict(
type='MultiDatasetEvaluator',
metrics=[
dict(type='CocoMetric',
ann_file=data_root + 'annotations/person_keypoints_val2017.json'),
dict(type='CocoMetric',
ann_file='data/aic/annotations/aic_val.json',
use_area=False,
gt_converter=aic_to_coco_converter,
prefix='aic')
],
datasets=val_dataset['datasets'],
)

test_evaluator = val_evaluator
```

For further clarification on converting AIC keypoints to COCO keypoints, please consult [this guide](https://mmpose.readthedocs.io/en/latest/user_guides/mixed_datasets.html#merge-aic-into-coco).
120 changes: 120 additions & 0 deletions docs/zh_cn/user_guides/train_and_test.md
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,123 @@ NNODES=2 NODE_RANK=1 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR bash tools/dist_
| `GPUS_PER_NODE` | 每台机器使用的 GPU 总数,默认为 8 |
| `CPUS_PER_TASK` | 每个任务分配的 CPU 总数(通常为 1 张 GPU 对应 1 个任务进程),默认为 5 |
| `SRUN_ARGS` | `srun` 的其他参数,可选项见 [这里](https://slurm.schedmd.com/srun.html). |

## 自定义测试

### 用自定义度量进行测试

如果您希望使用 MMPose 中尚未支持的独特度量来评估模型,您将需要自己编写这些度量并将它们包含在您的配置文件中。关于如何实现这一点的指导,请查看我们的 [自定义评估指南](https://mmpose.readthedocs.io/zh_CN/dev-1.x/advanced_guides/customize_evaluation.html)

### 在多个数据集上进行评估

MMPose 提供了一个名为 `MultiDatasetEvaluator` 的便捷工具,用于在多个数据集上进行简化评估。在配置文件中设置此评估器非常简单。下面是一个快速示例,演示如何使用 COCO 和 AIC 数据集评估模型:

```python
# 设置验证数据集
coco_val = dict(type='CocoDataset', ...)

aic_val = dict(type='AicDataset', ...)

val_dataset = dict(
type='CombinedDataset',
datasets=[coco_val, aic_val],
pipeline=val_pipeline,
...)

# 配置评估器
val_evaluator = dict(
type='MultiDatasetEvaluator',
metrics=[ # 为每个数据集配置度量
dict(type='CocoMetric',
ann_file='data/coco/annotations/person_keypoints_val2017.json'),
dict(type='CocoMetric',
ann_file='data/aic/annotations/aic_val.json',
use_area=False,
prefix='aic')
],
# 数据集个数和顺序与度量必须匹配
datasets=[coco_val, aic_val],
)
```

同的数据集(如 COCO 和 AIC)具有不同的关键点定义。然而,模型的输出关键点是标准化的。这导致了模型输出与真值之间关键点顺序的差异。为解决这一问题,您可以使用 `KeypointConverter` 来对齐不同数据集之间的关键点顺序。下面是一个完整示例,展示了如何利用 `KeypointConverter` 来对齐 AIC 关键点与 COCO 关键点:

```python
aic_to_coco_converter = dict(
type='KeypointConverter',
num_keypoints=17,
mapping=[
(0, 6),
(1, 8),
(2, 10),
(3, 5),
(4, 7),
(5, 9),
(6, 12),
(7, 14),
(8, 16),
(9, 11),
(10, 13),
(11, 15),
])

# val datasets
coco_val = dict(
type='CocoDataset',
data_root='data/coco/',
data_mode='topdown',
ann_file='annotations/person_keypoints_val2017.json',
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=[],
)

aic_val = dict(
type='AicDataset',
data_root='data/aic/',
data_mode=data_mode,
ann_file='annotations/aic_val.json',
data_prefix=dict(img='ai_challenger_keypoint_validation_20170911/'
'keypoint_validation_images_20170911/'),
test_mode=True,
pipeline=[],
)

val_dataset = dict(
type='CombinedDataset',
metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
datasets=[coco_val, aic_val],
pipeline=val_pipeline,
test_mode=True,
)

val_dataloader = dict(
batch_size=32,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=val_dataset)

test_dataloader = val_dataloader

val_evaluator = dict(
type='MultiDatasetEvaluator',
metrics=[
dict(type='CocoMetric',
ann_file=data_root + 'annotations/person_keypoints_val2017.json'),
dict(type='CocoMetric',
ann_file='data/aic/annotations/aic_val.json',
use_area=False,
gt_converter=aic_to_coco_converter,
prefix='aic')
],
datasets=val_dataset['datasets'],
)

test_evaluator = val_evaluator
```

如需进一步了解如何将 AIC 关键点转换为 COCO 关键点,请查阅 [该指南](https://mmpose.readthedocs.io/zh_CN/dev-1.x/user_guides/mixed_datasets.html#aic-coco)
2 changes: 1 addition & 1 deletion mmpose/datasets/datasets/base/base_coco_style_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def get_data_info(self, idx: int) -> dict:

# Add metainfo items that are required in the pipeline and the model
metainfo_keys = [
'upper_body_ids', 'lower_body_ids', 'flip_pairs',
'dataset_name', 'upper_body_ids', 'lower_body_ids', 'flip_pairs',
'dataset_keypoint_weights', 'flip_indices', 'skeleton_links'
]

Expand Down
40 changes: 39 additions & 1 deletion mmpose/datasets/transforms/converting.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, num_keypoints: int,
self.source_index2 = src2

self.source_index = src1
self.target_index = target_index
self.target_index = list(target_index)
self.interpolation = interpolation

def transform(self, results: dict) -> dict:
Expand Down Expand Up @@ -122,6 +122,44 @@ def transform(self, results: dict) -> dict:
[keypoints_visible, keypoints_visible_weights], axis=2)
return results

def transform_sigmas(self, sigmas: Union[List, np.ndarray]):
"""Transforms the sigmas based on the mapping."""
list_input = False
if isinstance(sigmas, list):
sigmas = np.array(sigmas)
list_input = True

Check warning on line 130 in mmpose/datasets/transforms/converting.py

View check run for this annotation

Codecov / codecov/patch

mmpose/datasets/transforms/converting.py#L129-L130

Added lines #L129 - L130 were not covered by tests

new_sigmas = np.ones(self.num_keypoints, dtype=sigmas.dtype)
new_sigmas[self.target_index] = sigmas[self.source_index]

if list_input:
new_sigmas = new_sigmas.tolist()

Check warning on line 136 in mmpose/datasets/transforms/converting.py

View check run for this annotation

Codecov / codecov/patch

mmpose/datasets/transforms/converting.py#L136

Added line #L136 was not covered by tests

return new_sigmas

def transform_ann(self, ann_info: Union[dict, list]):
"""Transforms the annotations based on the mapping."""

list_input = True
if not isinstance(ann_info, list):
ann_info = [ann_info]
list_input = False

for ann in ann_info:
if 'keypoints' in ann:
keypoints = np.array(ann['keypoints']).reshape(-1, 3)
new_keypoints = np.zeros((self.num_keypoints, 3),
dtype=keypoints.dtype)
new_keypoints[self.target_index] = keypoints[self.source_index]
ann['keypoints'] = new_keypoints.reshape(-1).tolist()
if 'num_keypoints' in ann:
ann['num_keypoints'] = self.num_keypoints

if not list_input:
ann_info = ann_info[0]

return ann_info

def __repr__(self) -> str:
"""print the basic information of the transform.
Expand Down
2 changes: 1 addition & 1 deletion mmpose/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(self,
'crowd_index', 'ori_shape', 'img_shape',
'input_size', 'input_center', 'input_scale',
'flip', 'flip_direction', 'flip_indices',
'raw_ann_info'),
'raw_ann_info', 'dataset_name'),
pack_transformed=False):
self.meta_keys = meta_keys
self.pack_transformed = pack_transformed
Expand Down
1 change: 1 addition & 0 deletions mmpose/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluators import * # noqa: F401,F403
from .functional import * # noqa: F401,F403
from .metrics import * # noqa: F401,F403
4 changes: 4 additions & 0 deletions mmpose/evaluation/evaluators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mutli_dataset_evaluator import MultiDatasetEvaluator

__all__ = ['MultiDatasetEvaluator']
Loading

0 comments on commit f42ec26

Please sign in to comment.