From 4216cafd14f2224ac876ac51fad81c11baf69c51 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Tue, 18 Jul 2023 16:20:49 +0800 Subject: [PATCH 1/6] edit motionbert config --- ...-lift_motionbert-243frm_8xb32-120e_h36m.py | 7 +++-- .../datasets/transforms/pose3d_transforms.py | 27 ++++++++----------- .../test_transforms/test_pose3d_transforms.py | 20 -------------- 3 files changed, 14 insertions(+), 40 deletions(-) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py index 88f6c3897d..094f919fd3 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py @@ -69,12 +69,11 @@ # pipelines train_pipeline = [ + dict(type='GenerateTarget', encoder=train_codec), dict( type='RandomFlipAroundRoot', - keypoints_flip_cfg={}, - target_flip_cfg={}, - flip_image=True), - dict(type='GenerateTarget', encoder=train_codec), + keypoints_flip_cfg=dict(center_mode='static', center_x=0.), + target_flip_cfg=dict(center_mode='static', center_x=0.)), dict( type='PackPoseInputs', meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', diff --git a/mmpose/datasets/transforms/pose3d_transforms.py b/mmpose/datasets/transforms/pose3d_transforms.py index 2149d7cb30..f92f9787ea 100644 --- a/mmpose/datasets/transforms/pose3d_transforms.py +++ b/mmpose/datasets/transforms/pose3d_transforms.py @@ -25,29 +25,31 @@ class RandomFlipAroundRoot(BaseTransform): flip_prob (float): Probability of flip. Default: 0.5. flip_camera (bool): Whether to flip horizontal distortion coefficients. Default: ``False``. - flip_image (bool): Whether to flip keypoints horizontally according - to image size. Default: ``False``. Required keys: - keypoints - lifting_target + - keypoints + - lifting_target + - keypoints_visible (optional) + - lifting_target_visible (optional) + - flip_indices (optional) Modified keys: - (keypoints, keypoints_visible, lifting_target, lifting_target_visible, - camera_param) + - keypoints (optional) + - keypoints_visible (optional) + - lifting_target (optional) + - lifting_target_visible (optional) + - camera_param (optional) """ def __init__(self, keypoints_flip_cfg, target_flip_cfg, flip_prob=0.5, - flip_camera=False, - flip_image=False): + flip_camera=False): self.keypoints_flip_cfg = keypoints_flip_cfg self.target_flip_cfg = target_flip_cfg self.flip_prob = flip_prob self.flip_camera = flip_camera - self.flip_image = flip_image def transform(self, results: Dict) -> dict: """The transform function of :class:`RandomFlipAroundRoot`. @@ -81,13 +83,6 @@ def transform(self, results: Dict) -> dict: # flip joint coordinates _camera_param = deepcopy(results['camera_param']) - if self.flip_image: - assert 'camera_param' in results, \ - 'Camera parameters are missing.' 
- assert 'w' in _camera_param - w = _camera_param['w'] / 2 - self.keypoints_flip_cfg['center_x'] = w - self.target_flip_cfg['center_x'] = w keypoints, keypoints_visible = flip_keypoints_custom_center( keypoints, keypoints_visible, flip_indices, diff --git a/tests/test_datasets/test_transforms/test_pose3d_transforms.py b/tests/test_datasets/test_transforms/test_pose3d_transforms.py index b87931bb74..db7a612dee 100644 --- a/tests/test_datasets/test_transforms/test_pose3d_transforms.py +++ b/tests/test_datasets/test_transforms/test_pose3d_transforms.py @@ -153,23 +153,3 @@ def test_transform(self): -self.data_info['camera_param']['p'][0], camera2['p'][0], atol=4.)) - - # test flipping w.r.t. image - transform = RandomFlipAroundRoot({}, {}, flip_prob=1, flip_image=True) - results = deepcopy(self.data_info) - results = transform(results) - kpts2 = results['keypoints'] - tar2 = results['lifting_target'] - - camera_param = results['camera_param'] - for left, right in enumerate(flip_indices): - self.assertTrue( - np.allclose( - camera_param['w'] - kpts1[0][left][:1], - kpts2[0][right][:1], - atol=4.)) - self.assertTrue( - np.allclose(kpts1[0][left][1:], kpts2[0][right][1:], atol=4.)) - self.assertTrue( - np.allclose( - tar1[..., left, 1:], tar2[..., right, 1:], atol=4.)) From 843667ce7aaeb730f47256a87a551c4eb7cab9c9 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Tue, 18 Jul 2023 20:33:32 +0800 Subject: [PATCH 2/6] fix transform --- ...-lift_motionbert-243frm_8xb32-120e_h36m.py | 9 +-- mmpose/codecs/motionbert_label.py | 24 +++---- .../datasets/transforms/pose3d_transforms.py | 63 ++++++++++++------- tests/test_codecs/test_motionbert_label.py | 4 -- .../test_transforms/test_pose3d_transforms.py | 44 +++++++++++++ 5 files changed, 101 insertions(+), 43 deletions(-) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py index 094f919fd3..25399cfa75 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py @@ -32,11 +32,7 @@ # codec settings train_codec = dict( - type='MotionBERTLabel', - num_keypoints=17, - concat_vis=True, - rootrel=True, - factor_label=False) + type='MotionBERTLabel', num_keypoints=17, concat_vis=True, mode='train') val_codec = dict( type='MotionBERTLabel', num_keypoints=17, concat_vis=True, rootrel=True) @@ -73,7 +69,8 @@ dict( type='RandomFlipAroundRoot', keypoints_flip_cfg=dict(center_mode='static', center_x=0.), - target_flip_cfg=dict(center_mode='static', center_x=0.)), + target_flip_cfg=dict(center_mode='static', center_x=0.), + flip_label=True), dict( type='PackPoseInputs', meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', diff --git a/mmpose/codecs/motionbert_label.py b/mmpose/codecs/motionbert_label.py index d0c8cd0d40..08ff4ccd1a 100644 --- a/mmpose/codecs/motionbert_label.py +++ b/mmpose/codecs/motionbert_label.py @@ -34,8 +34,8 @@ class MotionBERTLabel(BaseKeypointCodec): Default: ``False``. rootrel (bool): If true, the root keypoint will be set to the coordinate origin. Default: ``False``. - factor_label (bool): If true, the label will be multiplied by a factor. - Default: ``True``. + mode (str): Indicating whether the current mode is 'train' or 'test'. + Default: ``'test'``. 
""" auxiliary_encode_keys = { @@ -49,7 +49,7 @@ def __init__(self, save_index: bool = False, concat_vis: bool = False, rootrel: bool = False, - factor_label: bool = True): + mode: str = 'test'): super().__init__() self.num_keypoints = num_keypoints @@ -58,7 +58,8 @@ def __init__(self, self.save_index = save_index self.concat_vis = concat_vis self.rootrel = rootrel - self.factor_label = factor_label + assert mode.lower() in {'train', 'test'} + self.mode = mode.lower() def encode(self, keypoints: np.ndarray, @@ -92,8 +93,6 @@ def encode(self, shape (K, C) or (K-1, C). - lifting_target_weights (np.ndarray): The target weights in shape (K, ) or (K-1, ). - - trajectory_weights (np.ndarray): The trajectory weights in - shape (K, ). - factor (np.ndarray): The factor mapping camera and image coordinate in shape (T, 1). """ @@ -104,16 +103,13 @@ def encode(self, lifting_target = [keypoints[..., 0, :, :]] # set initial value for `lifting_target_weights` - # and `trajectory_weights` if lifting_target_visible is None: lifting_target_visible = np.ones( lifting_target.shape[:-1], dtype=np.float32) lifting_target_weights = lifting_target_visible - trajectory_weights = (1 / lifting_target[:, 2]) else: valid = lifting_target_visible > 0.5 lifting_target_weights = np.where(valid, 1., 0.).astype(np.float32) - trajectory_weights = lifting_target_weights if camera_param is None: camera_param = dict() @@ -140,6 +136,13 @@ def encode(self, if 'f' in _camera_param and 'c' in _camera_param: lifting_target_label, factor_ = camera_to_image_coord( self.root_index, lifting_target_label, _camera_param) + if self.mode == 'train': + w, h = w / 1000, h / 1000 + lifting_target_label[ + ..., :2] = lifting_target_label[..., :2] / w * 2 - [ + 0.001, h / w + ] + lifting_target_label[..., 2] = lifting_target_label[..., 2] / w * 2 lifting_target_label[..., :, :] = lifting_target_label[ ..., :, :] - lifting_target_label[..., self.root_index:self.root_index + @@ -148,7 +151,7 @@ def encode(self, factor = factor_ if factor.ndim == 1: factor = factor[:, None] - if self.factor_label: + if self.mode == 'test': lifting_target_label *= factor[..., None] if self.concat_vis: @@ -164,7 +167,6 @@ def encode(self, encoded['lifting_target_weights'] = lifting_target_weights encoded['lifting_target'] = lifting_target_label encoded['lifting_target_visible'] = lifting_target_visible - encoded['trajectory_weights'] = trajectory_weights encoded['factor'] = factor return encoded diff --git a/mmpose/datasets/transforms/pose3d_transforms.py b/mmpose/datasets/transforms/pose3d_transforms.py index f92f9787ea..5831692000 100644 --- a/mmpose/datasets/transforms/pose3d_transforms.py +++ b/mmpose/datasets/transforms/pose3d_transforms.py @@ -25,18 +25,20 @@ class RandomFlipAroundRoot(BaseTransform): flip_prob (float): Probability of flip. Default: 0.5. flip_camera (bool): Whether to flip horizontal distortion coefficients. Default: ``False``. + flip_label (bool): Whether to flip labels instead of data. + Default: ``False``. 
Required keys: - - keypoints - - lifting_target - - keypoints_visible (optional) + - keypoints or keypoint_labels + - lifting_target or lifting_target_label + - keypoints_visible or keypoint_labels_visible (optional) - lifting_target_visible (optional) - flip_indices (optional) Modified keys: - - keypoints (optional) - - keypoints_visible (optional) - - lifting_target (optional) + - keypoints or keypoint_labels (optional) + - keypoints_visible or keypoint_labels_visible (optional) + - lifting_target or lifting_target_label (optional) - lifting_target_visible (optional) - camera_param (optional) """ @@ -45,11 +47,13 @@ def __init__(self, keypoints_flip_cfg, target_flip_cfg, flip_prob=0.5, - flip_camera=False): + flip_camera=False, + flip_label=False): self.keypoints_flip_cfg = keypoints_flip_cfg self.target_flip_cfg = target_flip_cfg self.flip_prob = flip_prob self.flip_camera = flip_camera + self.flip_label = flip_label def transform(self, results: Dict) -> dict: """The transform function of :class:`RandomFlipAroundRoot`. @@ -63,19 +67,34 @@ def transform(self, results: Dict) -> dict: dict: The result dict. """ - keypoints = results['keypoints'] - if 'keypoints_visible' in results: - keypoints_visible = results['keypoints_visible'] - else: - keypoints_visible = np.ones(keypoints.shape[:-1], dtype=np.float32) - lifting_target = results['lifting_target'] - if 'lifting_target_visible' in results: - lifting_target_visible = results['lifting_target_visible'] - else: - lifting_target_visible = np.ones( - lifting_target.shape[:-1], dtype=np.float32) - if np.random.rand() <= self.flip_prob: + if self.flip_label: + assert 'keypoint_labels' in results + assert 'lifting_target_label' in results + keypoints_key = 'keypoint_labels' + keypoints_visible_key = 'keypoint_labels_visible' + target_key = 'lifting_target_label' + else: + assert 'keypoints' in results + assert 'lifting_target' in results + keypoints_key = 'keypoints' + keypoints_visible_key = 'keypoints_visible' + target_key = 'lifting_target' + + keypoints = results[keypoints_key] + if keypoints_visible_key in results: + keypoints_visible = results[keypoints_visible_key] + else: + keypoints_visible = np.ones( + keypoints.shape[:-1], dtype=np.float32) + + lifting_target = results[target_key] + if 'lifting_target_visible' in results: + lifting_target_visible = results['lifting_target_visible'] + else: + lifting_target_visible = np.ones( + lifting_target.shape[:-1], dtype=np.float32) + if 'flip_indices' not in results: flip_indices = list(range(self.num_keypoints)) else: @@ -91,9 +110,9 @@ def transform(self, results: Dict) -> dict: lifting_target, lifting_target_visible, flip_indices, **self.target_flip_cfg) - results['keypoints'] = keypoints - results['keypoints_visible'] = keypoints_visible - results['lifting_target'] = lifting_target + results[keypoints_key] = keypoints + results[keypoints_visible_key] = keypoints_visible + results[target_key] = lifting_target results['lifting_target_visible'] = lifting_target_visible # flip horizontal distortion coefficients diff --git a/tests/test_codecs/test_motionbert_label.py b/tests/test_codecs/test_motionbert_label.py index 01c9c654a2..a42b3d0793 100644 --- a/tests/test_codecs/test_motionbert_label.py +++ b/tests/test_codecs/test_motionbert_label.py @@ -73,10 +73,6 @@ def test_encode(self): 1, 17, )) - self.assertEqual(encoded['trajectory_weights'].shape, ( - 1, - 17, - )) # test concatenating visibility codec = self.build_pose_lifting_label(concat_vis=True) diff --git 
a/tests/test_datasets/test_transforms/test_pose3d_transforms.py b/tests/test_datasets/test_transforms/test_pose3d_transforms.py index db7a612dee..c057dba4e7 100644 --- a/tests/test_datasets/test_transforms/test_pose3d_transforms.py +++ b/tests/test_datasets/test_transforms/test_pose3d_transforms.py @@ -153,3 +153,47 @@ def test_transform(self): -self.data_info['camera_param']['p'][0], camera2['p'][0], atol=4.)) + + # test label flipping + self.data_info['keypoint_labels'] = kpts1 + self.data_info['keypoint_labels_visible'] = kpts_vis1 + self.data_info['lifting_target_label'] = tar1 + + transform = RandomFlipAroundRoot( + self.keypoints_flip_cfg, + self.target_flip_cfg, + flip_prob=1, + flip_label=True) + results = transform(deepcopy(self.data_info)) + + kpts2 = results['keypoint_labels'] + kpts_vis2 = results['keypoint_labels_visible'] + tar2 = results['lifting_target_label'] + tar_vis2 = results['lifting_target_visible'] + + self.assertEqual(kpts_vis2.shape, (1, 17)) + self.assertEqual(tar_vis2.shape, ( + 1, + 17, + )) + self.assertEqual(kpts2.shape, (1, 17, 2)) + self.assertEqual(tar2.shape, (1, 17, 3)) + + flip_indices = [ + 0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13 + ] + for left, right in enumerate(flip_indices): + self.assertTrue( + np.allclose(-kpts1[0][left][:1], kpts2[0][right][:1], atol=4.)) + self.assertTrue( + np.allclose(kpts1[0][left][1:], kpts2[0][right][1:], atol=4.)) + self.assertTrue( + np.allclose( + tar1[..., left, 1:], tar2[..., right, 1:], atol=4.)) + + self.assertTrue( + np.allclose( + kpts_vis1[..., left], kpts_vis2[..., right], atol=4.)) + self.assertTrue( + np.allclose( + tar_vis1[..., left], tar_vis2[..., right], atol=4.)) From 056ad65fb145f3031231614409a317f825bb94b4 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 19 Jul 2023 15:55:15 +0800 Subject: [PATCH 3/6] add flip test --- ...-lift_motionbert-243frm_8xb32-120e_h36m.py | 2 +- .../motion_regression_head.py | 20 ++++++++++++++++++- mmpose/models/pose_estimators/pose_lifter.py | 7 ++++++- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py index 25399cfa75..25b9d216a2 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py @@ -57,7 +57,7 @@ loss=dict(type='MPJPEVelocityJointLoss'), decoder=val_codec, ), -) + test_cfg=dict(flip_test=True)) # base dataset settings dataset_type = 'Human36mDataset' diff --git a/mmpose/models/heads/regression_heads/motion_regression_head.py b/mmpose/models/heads/regression_heads/motion_regression_head.py index a0037180c7..0e40e9d31f 100644 --- a/mmpose/models/heads/regression_heads/motion_regression_head.py +++ b/mmpose/models/heads/regression_heads/motion_regression_head.py @@ -7,6 +7,7 @@ from torch import Tensor, nn from mmpose.evaluation.functional import keypoint_mpjpe +from mmpose.models.utils.tta import flip_heatmaps from mmpose.registry import KEYPOINT_CODECS, MODELS from mmpose.utils.tensor_utils import to_numpy from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList, @@ -95,7 +96,24 @@ def predict(self, (B, N, K). 
""" - batch_coords = self.forward(feats) # (B, K, D) + if test_cfg.get('flip_test', False): + # TTA: flip test -> feats = [orig, flipped] + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + _feats, _feats_flip = feats + _batch_coords = self.forward(_feats) + _batch_coords_flip = np.stack([ + flip_heatmaps( + self.forward(_feat_flip), + flip_mode=test_cfg.get('flip_mode', 'heatmap'), + flip_indices=flip_indices, + shift_heatmap=test_cfg.get('shift_heatmap', False)) + for _feat_flip in _feats_flip + ], + axis=0) + batch_coords = (_batch_coords + _batch_coords_flip) * 0.5 + else: + batch_coords = self.forward(feats) # Restore global position with camera_param and factor camera_param = batch_data_samples[0].metainfo.get('camera_param', None) diff --git a/mmpose/models/pose_estimators/pose_lifter.py b/mmpose/models/pose_estimators/pose_lifter.py index 5bad3dde3c..89ec29fb9d 100644 --- a/mmpose/models/pose_estimators/pose_lifter.py +++ b/mmpose/models/pose_estimators/pose_lifter.py @@ -244,7 +244,12 @@ def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList: assert self.with_head, ( 'The model must have head to perform prediction.') - feats = self.extract_feat(inputs) + if self.test_cfg.get('flip_test', False): + _feats = self.extract_feat(inputs) + _feats_flip = self.extract_feat(inputs.flip(-1)) + feats = [_feats, _feats_flip] + else: + feats = self.extract_feat(inputs) pose_preds, batch_pred_instances, batch_pred_fields = None, None, None traj_preds, batch_traj_instances, batch_traj_fields = None, None, None From 7d85e85f10fe67bb07ba367843e16f80e6aaeec4 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 19 Jul 2023 19:56:17 +0800 Subject: [PATCH 4/6] fix --- .../regression_heads/motion_regression_head.py | 16 ++++++++-------- mmpose/models/losses/__init__.py | 6 ++++-- mmpose/models/pose_estimators/pose_lifter.py | 14 +++++++++++++- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/mmpose/models/heads/regression_heads/motion_regression_head.py b/mmpose/models/heads/regression_heads/motion_regression_head.py index 0e40e9d31f..3870e3c59e 100644 --- a/mmpose/models/heads/regression_heads/motion_regression_head.py +++ b/mmpose/models/heads/regression_heads/motion_regression_head.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from mmpose.evaluation.functional import keypoint_mpjpe -from mmpose.models.utils.tta import flip_heatmaps +from mmpose.models.utils.tta import flip_coordinates from mmpose.registry import KEYPOINT_CODECS, MODELS from mmpose.utils.tensor_utils import to_numpy from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList, @@ -102,15 +102,15 @@ def predict(self, flip_indices = batch_data_samples[0].metainfo['flip_indices'] _feats, _feats_flip = feats _batch_coords = self.forward(_feats) - _batch_coords_flip = np.stack([ - flip_heatmaps( - self.forward(_feat_flip), - flip_mode=test_cfg.get('flip_mode', 'heatmap'), + _batch_coords_flip = torch.stack([ + flip_coordinates( + _batch_coord_flip, flip_indices=flip_indices, - shift_heatmap=test_cfg.get('shift_heatmap', False)) - for _feat_flip in _feats_flip + shift_coords=test_cfg.get('shift_coords', True), + input_size=(1, 1)) + for _batch_coord_flip in self.forward(_feats_flip) ], - axis=0) + dim=0) batch_coords = (_batch_coords + _batch_coords_flip) * 0.5 else: batch_coords = self.forward(feats) diff --git a/mmpose/models/losses/__init__.py b/mmpose/models/losses/__init__.py index f21071e156..523e4df133 100644 --- 
a/mmpose/models/losses/__init__.py +++ b/mmpose/models/losses/__init__.py @@ -4,7 +4,8 @@ from .heatmap_loss import (AdaptiveWingLoss, KeypointMSELoss, KeypointOHKMMSELoss) from .loss_wrappers import CombinedLoss, MultipleLossWrapper -from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, MSELoss, RLELoss, +from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, + MPJPEVelocityJointLoss, MSELoss, RLELoss, SemiSupervisionLoss, SmoothL1Loss, SoftWeightSmoothL1Loss, SoftWingLoss, WingLoss) @@ -13,5 +14,6 @@ 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss', 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss', 'KLDiscretLoss', 'MultipleLossWrapper', 'JSDiscretLoss', 'CombinedLoss', - 'AssociativeEmbeddingLoss', 'SoftWeightSmoothL1Loss' + 'AssociativeEmbeddingLoss', 'SoftWeightSmoothL1Loss', + 'MPJPEVelocityJointLoss' ] diff --git a/mmpose/models/pose_estimators/pose_lifter.py b/mmpose/models/pose_estimators/pose_lifter.py index 89ec29fb9d..ec8401d1a2 100644 --- a/mmpose/models/pose_estimators/pose_lifter.py +++ b/mmpose/models/pose_estimators/pose_lifter.py @@ -2,9 +2,11 @@ from itertools import zip_longest from typing import Tuple, Union +import torch from torch import Tensor from mmpose.models.utils import check_and_update_config +from mmpose.models.utils.tta import flip_coordinates from mmpose.registry import MODELS from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, Optional, OptMultiConfig, OptSampleList, @@ -245,8 +247,18 @@ def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList: 'The model must have head to perform prediction.') if self.test_cfg.get('flip_test', False): + flip_indices = data_samples[0].metainfo['flip_indices'] _feats = self.extract_feat(inputs) - _feats_flip = self.extract_feat(inputs.flip(-1)) + _feats_flip = self.extract_feat( + torch.stack([ + flip_coordinates( + _input, + flip_indices=flip_indices, + shift_coords=self.test_cfg.get('shift_coords', True), + input_size=(1, 1)) for _input in inputs + ], + dim=0)) + feats = [_feats, _feats_flip] else: feats = self.extract_feat(inputs) From 3559e99ecf24e0ca54f9dfb4fed8df9334199bd4 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 19 Jul 2023 20:24:40 +0800 Subject: [PATCH 5/6] update results --- .../body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md | 8 ++++---- .../body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md index d830d65c18..8e7bea451a 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md +++ b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md @@ -40,14 +40,14 @@ Testing results on Human3.6M dataset with ground truth 2D detections | Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | | :-------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------------------------------------------------------------------------------------: | -| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 35.3 | 27.7 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | -| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 27.4 | 21.6 | 
[ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | +| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 34.5 | 34.6 | 27.1 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | +| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 26.9 | 26.8 | 21.0 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | Testing results on Human3.6M dataset from the [official repo](https://github.com/Walter0807/MotionBERT) with ground truth 2D detections | Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | | :-------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------------------------------------------------------------------------------------: | -| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 40.5 | 39.9 | 34.1 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | -| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 38.2 | 37.7 | 32.6 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | +| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 39.8 | 39.2 | 33.4 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | +| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 37.7 | 37.2 | 32.2 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | *Models with * are converted from the [official repo](https://github.com/Walter0807/MotionBERT). The config files of these models are only for validation. 
We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* diff --git a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml index 7257fea5a6..7b16d063be 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml +++ b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml @@ -15,8 +15,8 @@ Models: Results: - Dataset: Human3.6M Metrics: - MPJPE: 35.3 - P-MPJPE: 27.7 + MPJPE: 34.5 + P-MPJPE: 27.1 Task: Body 3D Keypoint Weights: https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth - Config: configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert_8xb32-120e_h36m.py @@ -28,7 +28,7 @@ Models: Results: - Dataset: Human3.6M Metrics: - MPJPE: 27.5 - P-MPJPE: 21.6 + MPJPE: 26.9 + P-MPJPE: 21.0 Task: Body 3D Keypoint Weights: https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth From 2e97a63c7c2971ae5ed53161dc23b98fe5af56ed Mon Sep 17 00:00:00 2001 From: LareinaM Date: Thu, 20 Jul 2023 12:51:41 +0800 Subject: [PATCH 6/6] add markdown --- .../pose_lift/h36m/motionbert_h36m.md | 8 +- .../pose_lift/h36m/motionbert_h36m.yml | 2 +- ...ionbert-243frm_8xb32-120e_h36m-original.py | 137 +++++++++++++++++ ...nbert-ft-243frm_8xb32-60e_h36m-original.py | 142 ++++++++++++++++++ ...ift_motionbert-ft-243frm_8xb32-60e_h36m.py | 141 +++++++++++++++++ 5 files changed, 426 insertions(+), 4 deletions(-) create mode 100644 configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m-original.py create mode 100644 configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-ft-243frm_8xb32-60e_h36m-original.py create mode 100644 configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-ft-243frm_8xb32-60e_h36m.py diff --git a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md index 8e7bea451a..93cd29eddd 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md +++ b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md @@ -43,11 +43,13 @@ Testing results on Human3.6M dataset with ground truth 2D detections | [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 34.5 | 34.6 | 27.1 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | | [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 26.9 | 26.8 | 21.0 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | -Testing results on Human3.6M dataset from the [official repo](https://github.com/Walter0807/MotionBERT) with ground truth 2D detections +Testing results on Human3.6M dataset converted from the [official repo](https://github.com/Walter0807/MotionBERT)1 with ground truth 2D detections | Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | | :-------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------------------------------------------------------------------------------------: | -| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 39.8 | 39.2 | 33.4 | 
[ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | -| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 37.7 | 37.2 | 32.2 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | +| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m-original.py) | 39.8 | 39.2 | 33.4 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | +| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m-original.py) | 37.7 | 37.2 | 32.2 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | + +1 To test with the dataset from official repo, please download the [test annotation file](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/h36m_test_original.npz), [train annotation file](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/h36m_train_original.npz) and [factors](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/h36m_factors.npy) under `$MMPOSE/data/h36m/annotation_body3d/fps50`. *Models with * are converted from the [official repo](https://github.com/Walter0807/MotionBERT). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* diff --git a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml index 7b16d063be..11ab4bb382 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml +++ b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml @@ -19,7 +19,7 @@ Models: P-MPJPE: 27.1 Task: Body 3D Keypoint Weights: https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth -- Config: configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert_8xb32-120e_h36m.py +- Config: configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-ft_8xb32-120e_h36m.py In Collection: MotionBERT Metadata: Architecture: *id001 diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m-original.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m-original.py new file mode 100644 index 0000000000..032188f389 --- /dev/null +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m-original.py @@ -0,0 +1,137 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = dict(max_epochs=120, val_interval=10) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.01)) + +# learning policy +param_scheduler = [ + dict(type='ExponentialLR', gamma=0.99, end=120, by_epoch=True) +] + +auto_scale_lr = dict(base_batch_size=512) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +train_codec = dict( + type='MotionBERTLabel', 
num_keypoints=17, concat_vis=True, mode='train') +val_codec = dict( + type='MotionBERTLabel', num_keypoints=17, concat_vis=True, rootrel=True) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='DSTFormer', + in_channels=3, + feat_size=512, + depth=5, + num_heads=8, + mlp_ratio=2, + seq_len=243, + att_fuse=True, + ), + head=dict( + type='MotionRegressionHead', + in_channels=512, + out_channels=3, + embedding_size=512, + loss=dict(type='MPJPEVelocityJointLoss'), + decoder=val_codec, + ), + test_cfg=dict(flip_test=True)) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +train_pipeline = [ + dict(type='GenerateTarget', encoder=train_codec), + dict( + type='RandomFlipAroundRoot', + keypoints_flip_cfg=dict(center_mode='static', center_x=0.), + target_flip_cfg=dict(center_mode='static', center_x=0.), + flip_label=True), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'factor', 'camera_param')) +] +val_pipeline = [ + dict(type='GenerateTarget', encoder=val_codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'factor', 'camera_param')) +] + +# data loaders +train_dataloader = dict( + batch_size=32, + prefetch_factor=4, + pin_memory=True, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_train_original.npz', + seq_len=1, + multiple_target=243, + multiple_target_step=81, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=train_pipeline, + )) + +val_dataloader = dict( + batch_size=32, + prefetch_factor=4, + pin_memory=True, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test_original.npz', + factor_file='annotation_body3d/fps50/h36m_factors.npy', + seq_len=1, + seq_step=1, + multiple_target=243, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +skip_list = [ + 'S9_Greet', 'S9_SittingDown', 'S9_Wait_1', 'S9_Greeting', 'S9_Waiting_1' +] +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe', skip_list=skip_list), + dict(type='MPJPE', mode='p-mpjpe', skip_list=skip_list) +] +test_evaluator = val_evaluator diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-ft-243frm_8xb32-60e_h36m-original.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-ft-243frm_8xb32-60e_h36m-original.py new file mode 100644 index 0000000000..9c2aa3697a --- /dev/null +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-ft-243frm_8xb32-60e_h36m-original.py @@ -0,0 +1,142 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = dict(max_epochs=60, val_interval=10) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.01)) + +# learning policy +param_scheduler = [ + dict(type='ExponentialLR', gamma=0.99, end=60, by_epoch=True) +] + +auto_scale_lr = dict(base_batch_size=512) + +# hooks 
+default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +train_codec = dict( + type='MotionBERTLabel', num_keypoints=17, concat_vis=True, mode='train') +val_codec = dict( + type='MotionBERTLabel', num_keypoints=17, concat_vis=True, rootrel=True) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='DSTFormer', + in_channels=3, + feat_size=512, + depth=5, + num_heads=8, + mlp_ratio=2, + seq_len=243, + att_fuse=True, + ), + head=dict( + type='MotionRegressionHead', + in_channels=512, + out_channels=3, + embedding_size=512, + loss=dict(type='MPJPEVelocityJointLoss'), + decoder=val_codec, + ), + test_cfg=dict(flip_test=True), + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/' + 'pose_lift/h36m/motionbert_pretrain_h36m-29ffebf5_20230719.pth'), +) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +train_pipeline = [ + dict(type='GenerateTarget', encoder=train_codec), + dict( + type='RandomFlipAroundRoot', + keypoints_flip_cfg=dict(center_mode='static', center_x=0.), + target_flip_cfg=dict(center_mode='static', center_x=0.), + flip_label=True), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'factor', 'camera_param')) +] +val_pipeline = [ + dict(type='GenerateTarget', encoder=val_codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'factor', 'camera_param')) +] + +# data loaders +train_dataloader = dict( + batch_size=32, + prefetch_factor=4, + pin_memory=True, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_train_original.npz', + seq_len=1, + multiple_target=243, + multiple_target_step=81, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=train_pipeline, + )) + +val_dataloader = dict( + batch_size=32, + prefetch_factor=4, + pin_memory=True, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test_original.npz', + factor_file='annotation_body3d/fps50/h36m_factors.npy', + seq_len=1, + seq_step=1, + multiple_target=243, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +skip_list = [ + 'S9_Greet', 'S9_SittingDown', 'S9_Wait_1', 'S9_Greeting', 'S9_Waiting_1' +] +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe', skip_list=skip_list), + dict(type='MPJPE', mode='p-mpjpe', skip_list=skip_list) +] +test_evaluator = val_evaluator diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-ft-243frm_8xb32-60e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-ft-243frm_8xb32-60e_h36m.py new file mode 100644 index 0000000000..5c42e62a60 --- /dev/null +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-ft-243frm_8xb32-60e_h36m.py @@ -0,0 +1,141 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + 
type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = dict(max_epochs=60, val_interval=10) + +# optimizer +optim_wrapper = dict( + optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.01)) + +# learning policy +param_scheduler = [ + dict(type='ExponentialLR', gamma=0.99, end=60, by_epoch=True) +] + +auto_scale_lr = dict(base_batch_size=512) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +train_codec = dict( + type='MotionBERTLabel', num_keypoints=17, concat_vis=True, mode='train') +val_codec = dict( + type='MotionBERTLabel', num_keypoints=17, concat_vis=True, rootrel=True) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='DSTFormer', + in_channels=3, + feat_size=512, + depth=5, + num_heads=8, + mlp_ratio=2, + seq_len=243, + att_fuse=True, + ), + head=dict( + type='MotionRegressionHead', + in_channels=512, + out_channels=3, + embedding_size=512, + loss=dict(type='MPJPEVelocityJointLoss'), + decoder=val_codec, + ), + test_cfg=dict(flip_test=True), + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/' + 'pose_lift/h36m/motionbert_pretrain_h36m-29ffebf5_20230719.pth'), +) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +train_pipeline = [ + dict(type='GenerateTarget', encoder=train_codec), + dict( + type='RandomFlipAroundRoot', + keypoints_flip_cfg=dict(center_mode='static', center_x=0.), + target_flip_cfg=dict(center_mode='static', center_x=0.), + flip_label=True), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'factor', 'camera_param')) +] +val_pipeline = [ + dict(type='GenerateTarget', encoder=val_codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'factor', 'camera_param')) +] + +# data loaders +train_dataloader = dict( + batch_size=32, + prefetch_factor=4, + pin_memory=True, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_train.npz', + seq_len=1, + multiple_target=243, + multiple_target_step=81, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=train_pipeline, + )) + +val_dataloader = dict( + batch_size=32, + prefetch_factor=4, + pin_memory=True, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test.npz', + seq_len=1, + seq_step=1, + multiple_target=243, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +skip_list = [ + 'S9_Greet', 'S9_SittingDown', 'S9_Wait_1', 'S9_Greeting', 'S9_Waiting_1' +] +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe', skip_list=skip_list), + dict(type='MPJPE', mode='p-mpjpe', skip_list=skip_list) +] +test_evaluator = val_evaluator
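
Note on the flip logic introduced in this series: the `flip_label=True`
training augmentation in RandomFlipAroundRoot (PATCH 2/6) and the
`flip_test=True` inference path in PoseLifter / MotionRegressionHead
(PATCHES 3-4/6) both reduce to the same primitive — mirror the x
coordinate around a static center (`center_x=0.` in the configs above)
and swap left/right joints. The sketch below is a minimal standalone
illustration under stated assumptions, not mmpose API: it assumes
keypoints already normalized to a zero-centered frame (as the
MotionBERTLabel codec produces), shaped (..., K, C) with x in channel 0,
and `model` is any hypothetical callable standing in for the lifter
backbone plus head.

    import torch

    # Left/right joint correspondence for the 17 Human3.6M keypoints,
    # taken from the flip_indices list in the unit test above.
    FLIP_INDICES = [0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13]

    def flip_around_root(kpts: torch.Tensor) -> torch.Tensor:
        """Mirror keypoints around x=0 and swap left/right joints."""
        flipped = kpts[..., FLIP_INDICES, :]  # advanced indexing returns a copy
        flipped[..., 0] = -flipped[..., 0]    # static center_x = 0, so x -> -x
        return flipped

    @torch.no_grad()
    def flip_test(model, inputs: torch.Tensor) -> torch.Tensor:
        """Flip-test TTA: run the model on the original and the mirrored
        inputs, flip the second prediction back, and average — the same
        (_batch_coords + _batch_coords_flip) * 0.5 step used in
        MotionRegressionHead.predict."""
        preds = model(inputs)
        preds_flip = flip_around_root(model(flip_around_root(inputs)))
        return (preds + preds_flip) * 0.5

Because GenerateTarget now runs before RandomFlipAroundRoot in the train
pipeline, the flip is applied to the codec's normalized, root-relative
labels, which is why the static `center_x=0.` setting replaces the removed
image-width-based `flip_image` logic from PATCH 1/6. At inference the
averaging is enabled simply via `test_cfg=dict(flip_test=True)` in the
model config.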