Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add RTMPose3D-Wholebody #3037

Open
wants to merge 16 commits into
base: dev-1.x
Choose a base branch
from
256 changes: 128 additions & 128 deletions configs/_base_/datasets/h3wb.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions mmpose/datasets/datasets/base/base_coco_style_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def load_data_list(self) -> List[dict]:
data_list = self._get_bottomup_data_infos(
instance_list, image_list)

if hasattr(self, 'coco'):
del self.coco
return data_list

def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
Expand Down
3 changes: 2 additions & 1 deletion mmpose/datasets/datasets/body/mpii_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,5 +221,6 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:

instance_list.append(instance_info)
ann_id = ann_id + 1

del self.anns
self.coco = None
return instance_list, image_list
3 changes: 2 additions & 1 deletion mmpose/datasets/datasets/wholebody/coco_wholebody_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,13 @@ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
'bbox_score': np.ones(1, dtype=np.float32),
'num_keypoints': num_keypoints,
'keypoints': keypoints,
'keypoints_3d': None,
'keypoints_visible': keypoints_visible,
'iscrowd': ann['iscrowd'],
'segmentation': ann['segmentation'],
'area': area,
'id': ann['id'],
'category_id': np.array(ann['category_id']),
'category_id': ann['category_id'],
# store the raw annotation of the instance
# it is useful for evaluation without providing ann_file
'raw_ann_info': copy.deepcopy(ann),
Expand Down
43 changes: 35 additions & 8 deletions mmpose/datasets/datasets/wholebody3d/h3wb_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Tuple

import numpy as np
Expand Down Expand Up @@ -106,6 +107,7 @@ def _load_ann_file(self, ann_file: str) -> dict:

self.ann_data = data['train_data'].item()
self.camera_data = data['metadata'].item()
self.bboxes = data['bbox'].item()

def get_sequence_indices(self) -> List[List[int]]:
return []
Expand All @@ -132,19 +134,26 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
'K': camera_param['K'][0, :2, ...],
'R': camera_param['R'][0],
'T': camera_param['T'].reshape(3, 1),
'Distortion': camera_param['Distortion'][0]
'Distortion': camera_param['Distortion'][0],
}
camera_param['f'] = (camera_param['K'][0, 0] * 1000,
camera_param['K'][1, 1] * 1000)
camera_param['c'] = (camera_param['K'][0, 2] * 1000,
camera_param['K'][1, 2] * 1000)

seq_step = 1
_len = (self.seq_len - 1) * seq_step + 1
_indices = list(
range(len(self.ann_data[subject][act]['frame_id'])))

seq_indices = [
_indices[i:(i + _len):seq_step]
for i in list(range(0,
len(_indices) - _len + 1))
]

frames = self.ann_data[subject][act]['frame_id']

for idx, frame_ids in enumerate(seq_indices):
expected_num_frames = self.seq_len
if self.multiple_target:
Expand All @@ -163,6 +172,21 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
if self.multiple_target > 0:
target_idx = list(range(self.multiple_target))

bbox = self.bboxes[(subject, act, cam,
frames[frame_ids[-1]])]
bbox = np.array([[
bbox['x_min'], bbox['y_min'], bbox['x_max'],
bbox['y_max']
]],
dtype=np.float32)

img_paths = [
osp.join(self.data_root, 'original', subject,
'Images', f'{act}.{cam}',
f'frame_{frames[i]}.jpg') # noqa
for i in frame_ids
]

instance_info = {
'num_keypoints':
num_keypoints,
Expand All @@ -174,6 +198,10 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
np.ones_like(_kpts_2d[..., 0], dtype=np.float32),
'keypoints_3d_visible':
np.ones_like(_kpts_2d[..., 0], dtype=np.float32),
'bbox':
bbox,
'bbox_score':
np.ones((len(frame_ids), )),
'scale':
np.zeros((1, 1), dtype=np.float32),
'center':
Expand All @@ -186,12 +214,11 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
1,
'iscrowd':
0,
'camera_param':
camera_param,
'img_paths': [
f'{subject}/{act}/{cam}/{i:06d}.jpg'
for i in frame_ids
],
'camera_param': [camera_param],
'img_paths':
img_paths,
'img_path':
img_paths[-1],
'img_ids':
frame_ids,
'lifting_target':
Expand All @@ -209,5 +236,5 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
image_list.append(img_info)

instance_id += 1

del self.ann_data
return instance_list, image_list
68 changes: 47 additions & 21 deletions mmpose/datasets/datasets/wholebody3d/ubody3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self,

super().__init__(multiple_target=multiple_target, **kwargs)

METAINFO: dict = dict(from_file='configs/_base_/datasets/ubody3d.py')
METAINFO: dict = dict(from_file='configs/_base_/datasets/h3wb.py')

def _load_ann_file(self, ann_file: str) -> dict:
"""Load annotation file."""
Expand Down Expand Up @@ -167,7 +167,7 @@ def _parse_image_name(self, image_path: str) -> Tuple[str, int]:

def _load_annotations(self):
"""Load data from annotations in COCO format."""
num_keypoints = self.metainfo['num_keypoints']
num_keypoints = 133
self._metainfo['CLASSES'] = self.ann_data.loadCats(
self.ann_data.getCatIds())

Expand All @@ -184,23 +184,37 @@ def _load_annotations(self):
f'got {len(_ann_ids)} ')

anns = self.ann_data.loadAnns(_ann_ids)
num_anns = len(anns)
img_ids = []
kpts = np.zeros((len(anns), num_keypoints, 2), dtype=np.float32)
kpts_3d = np.zeros((len(anns), num_keypoints, 3), dtype=np.float32)
keypoints_visible = np.zeros((len(anns), num_keypoints, 1),
kpts = np.zeros((num_anns, num_keypoints, 2), dtype=np.float32)
kpts_3d = np.zeros((num_anns, num_keypoints, 3), dtype=np.float32)
keypoints_visible = np.zeros((num_anns, num_keypoints),
dtype=np.float32)
scales = np.zeros((num_anns, 2), dtype=np.float32)
centers = np.zeros((num_anns, 2), dtype=np.float32)
bboxes = np.zeros((num_anns, 4), dtype=np.float32)
bbox_scores = np.zeros((num_anns, ), dtype=np.float32)
bbox_scales = np.zeros((num_anns, 2), dtype=np.float32)

for j, ann in enumerate(anns):
img_ids.append(ann['image_id'])
kpts[j] = np.array(ann['keypoints'], dtype=np.float32)
kpts_3d[j] = np.array(ann['keypoints_3d'], dtype=np.float32)
keypoints_visible[j] = np.array(
ann['keypoints_valid'], dtype=np.float32)
if 'scale' in ann:
scales[j] = np.array(ann['scale'])
if 'center' in ann:
centers[j] = np.array(ann['center'])
bboxes[j] = np.array(ann['bbox'], dtype=np.float32)
bbox_scores[j] = np.array([1], dtype=np.float32)
bbox_scales[j] = np.array([1, 1], dtype=np.float32)

imgs = self.ann_data.loadImgs(img_ids)
keypoints_visible = keypoints_visible.squeeze(-1)

scales = np.zeros(len(imgs), dtype=np.float32)
centers = np.zeros((len(imgs), 2), dtype=np.float32)
img_paths = np.array([img['file_name'] for img in imgs])
img_paths = np.array([
f'{self.data_root}/images/' + img['file_name'] for img in imgs
])
factors = np.zeros((kpts_3d.shape[0], ), dtype=np.float32)

target_idx = [-1] if self.causal else [int(self.seq_len // 2)]
Expand All @@ -212,6 +226,8 @@ def _load_annotations(self):
cam_param['w'] = 1000
cam_param['h'] = 1000

cam_param = {'f': cam_param['focal'], 'c': cam_param['princpt']}

instance_info = {
'num_keypoints': num_keypoints,
'keypoints': kpts,
Expand All @@ -223,25 +239,35 @@ def _load_annotations(self):
'category_id': 1,
'iscrowd': 0,
'img_paths': list(img_paths),
'img_path': img_paths[-1],
'img_ids': [img['id'] for img in imgs],
'lifting_target': kpts_3d[target_idx],
'lifting_target_visible': keypoints_visible[target_idx],
'target_img_paths': img_paths[target_idx],
'camera_param': cam_param,
'target_img_paths': list(img_paths[target_idx]),
'camera_param': [cam_param],
'factor': factors,
'target_idx': target_idx,
'bbox': bboxes,
'bbox_scales': bbox_scales,
'bbox_scores': bbox_scores
}

instance_list.append(instance_info)

for img_id in self.ann_data.getImgIds():
img = self.ann_data.loadImgs(img_id)[0]
img.update({
'img_id':
img_id,
'img_path':
osp.join(self.data_prefix['img'], img['file_name']),
})
image_list.append(img)

if self.data_mode == 'bottomup':
for img_id in self.ann_data.getImgIds():
img = self.ann_data.loadImgs(img_id)[0]
img.update({
'img_id':
img_id,
'img_path':
osp.join(self.data_prefix['img'], img['file_name']),
})
image_list.append(img)
del self.ann_data
return instance_list, image_list

def load_data_list(self) -> List[dict]:
data_list = super().load_data_list()
self.ann_data = None
return data_list
2 changes: 1 addition & 1 deletion mmpose/datasets/transforms/common_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ def transform(self, results: Dict) -> Optional[dict]:
# For single encoding, the encoded items will be directly added
# into results.
auxiliary_encode_kwargs = {
key: results[key]
key: results.get(key, None)
for key in self.encoder.auxiliary_encode_keys
}
encoded = self.encoder.encode(
Expand Down
3 changes: 3 additions & 0 deletions mmpose/datasets/transforms/topdown_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def transform(self, results: Dict) -> Optional[dict]:
transformed_keypoints[..., :2] = cv2.transform(
results['keypoints'][..., :2], warp_mat)
results['transformed_keypoints'] = transformed_keypoints
else:
results['transformed_keypoints'] = np.zeros([])
results['keypoints_visible'] = np.ones((1, 1, 1))

results['input_size'] = (w, h)
results['input_center'] = center
Expand Down
18 changes: 17 additions & 1 deletion mmpose/models/losses/regression_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,11 @@ class BoneLoss(nn.Module):
loss_weight (float): Weight of the loss. Default: 1.0.
"""

def __init__(self, joint_parents, use_target_weight=False, loss_weight=1.):
def __init__(self,
joint_parents,
use_target_weight: bool = False,
loss_weight: float = 1.,
loss_name: str = 'loss_bone'):
super().__init__()
self.joint_parents = joint_parents
self.use_target_weight = use_target_weight
Expand All @@ -584,6 +588,8 @@ def __init__(self, joint_parents, use_target_weight=False, loss_weight=1.):
if i != self.joint_parents[i]:
self.non_root_indices.append(i)

self._loss_name = loss_name

def forward(self, output, target, target_weight=None):
"""Forward function.

Expand All @@ -606,6 +612,7 @@ def forward(self, output, target, target_weight=None):
dim=-1)[:, self.non_root_indices]
if self.use_target_weight:
assert target_weight is not None
target_weight = target_weight[:, self.non_root_indices]
loss = torch.mean(
torch.abs((output_bone * target_weight).mean(dim=0) -
(target_bone * target_weight).mean(dim=0)))
Expand All @@ -615,6 +622,15 @@ def forward(self, output, target, target_weight=None):

return loss * self.loss_weight

@property
def loss_name(self):
"""Loss Name.

Returns:
str: The name of this loss item.
"""
return self._loss_name


@MODELS.register_module()
class SemiSupervisionLoss(nn.Module):
Expand Down
21 changes: 21 additions & 0 deletions projects/rtmpose3d/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# RTMPose3D: Real-Time 3D Pose Estimation toolkit based on RTMPose

## Abstract

RTMPose3D is a toolkit for real-time 3D pose estimation. It is based on the RTMPose model, which is a 2D pose estimation model that is capable of predicting 2D keypoints and body part associations in real-time. RTMPose3D extends RTMPose by adding a 3D pose estimation branch that can predict 3D keypoints from images directly.

## 🗂️ Model Zoo

| Model | AP on COCO-Wholebody | MPJPE on H3WB | Download |
| :--------------------------------------------------------- | :------------------: | :-----------: | :-----------------------------------------------------------------------------------------------------------: |
| [RTMW3D-L](./configs/rtmw3d-l_8xb64_cocktail14-384x288.py) | 0.678 | 0.052 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_coco-640x640-8db55a59_20231211.pth) |
| [RTMW3D-X](./configs/rtmw3d-x_8xb32_cocktail14-384x288.py) | 0.680 | 0.052 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_coco-640x640-8db55a59_20231211.pth) |

## Usage

👉🏼 TRY RTMPose3D NOW

```bash
cd /path/to/mmpose/projects/rtmpose3d
python body3d_img2pose_demo.py configs/rtmdet_m_640-8xb32_coco-person.py https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth configs\rtmw3d-l_8xb64_cocktail14-384x288.py rtmw3d-l_cock14-0d4ad840_20240422.pth --input /path/to/image --output-root /path/to/output
```
Loading
Loading