Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Reform multiviewpose #1853

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
type='DetectAndRegress',
backbone=None,
pretrained=None,
keypoint_head=None,
human_detector=dict(
type='VoxelCenterDetector',
image_size=image_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
model = dict(
type='DetectAndRegress',
backbone=None,
keypoint_head=None,
pretrained=None,
human_detector=dict(
type='VoxelCenterDetector',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ Results on CMU Panoptic dataset.

| Arch | mAP | mAR | MPJPE | Recall@500mm | ckpt | log |
| :--------------------------------------------------------- | :---: | :---: | :---: | :----------: | :--------------------------------------------------------: | :-------------------------------------------------------: |
| [prn64_cpn80_res50](/configs/body/3d_kpt_mview_rgb_img/voxelpose/panoptic/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5.py) | 97.31 | 97.99 | 17.57 | 99.85 | [ckpt](https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5-545c150e_20211103.pth) | [log](https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5_20211103.log.json) |
| [prn64_cpn80_res50](/configs/body/3d_kpt_mview_rgb_img/voxelpose/panoptic/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5.py) | 97.15 | 97.70 | 17.09 | 99.25 | [ckpt](https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5-358648cb_20230118.pth) | [log](https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5_20230118.log.json) |
Original file line number Diff line number Diff line change
Expand Up @@ -65,44 +65,30 @@
subset='validation'))

# model settings
backbone = dict(
type='AssociativeEmbedding',
pretrained=None,
backbone=dict(type='ResNet', depth=50),
keypoint_head=dict(
type='DeconvHead',
in_channels=2048,
out_channels=num_joints,
num_deconv_layers=3,
num_deconv_filters=(256, 256, 256),
num_deconv_kernels=(4, 4, 4),
loss_keypoint=dict(
type='MultiLossFactory',
num_joints=15,
num_stages=1,
ae_loss_type='exp',
with_ae_loss=[False],
push_loss_factor=[0.001],
pull_loss_factor=[0.001],
with_heatmaps_loss=[True],
heatmaps_loss_factor=[1.0],
)),
train_cfg=dict(),
test_cfg=dict(
num_joints=num_joints,
nms_kernel=None,
nms_padding=None,
tag_per_joint=None,
max_num_people=None,
detection_threshold=None,
tag_threshold=None,
use_detection_val=None,
ignore_too_much=None,
backbone = dict(type='ResNet', depth=50)
keypoint_head = dict(
type='DeconvHead',
in_channels=2048,
out_channels=num_joints,
num_deconv_layers=3,
num_deconv_filters=(256, 256, 256),
num_deconv_kernels=(4, 4, 4),
loss_keypoint=dict(
type='MultiLossFactory',
num_joints=15,
num_stages=1,
ae_loss_type='exp',
with_ae_loss=[False],
push_loss_factor=[0.001],
pull_loss_factor=[0.001],
with_heatmaps_loss=[True],
heatmaps_loss_factor=[1.0],
))

model = dict(
type='DetectAndRegress',
backbone=backbone,
keypoint_head=keypoint_head,
pretrained='checkpoints/resnet_50_deconv.pth.tar',
human_detector=dict(
type='VoxelCenterDetector',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ Models:
Results:
- Dataset: CMU Panoptic
Metrics:
MPJPE: 17.57
mAP: 97.31
mAR: 97.99
MPJPE: 17.09
mAP: 97.15
mAR: 97.7
Task: Body 3D Keypoint
Weights: https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5-545c150e_20211103.pth
Weights: https://download.openmmlab.com/mmpose/body3d/voxelpose/voxelpose_prn64x64x64_cpn80x80x20_panoptic_cam5-358648cb_20230118.pth
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
model = dict(
type='DetectAndRegress',
backbone=None,
keypoint_head=None,
pretrained=None,
human_detector=dict(
type='VoxelCenterDetector',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
type='DetectAndRegress',
backbone=None,
pretrained=None,
keypoint_head=None,
human_detector=dict(
type='VoxelCenterDetector',
image_size=image_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import json_tricks as json
import numpy as np
from scipy.io import loadmat
from torch.utils.data import Dataset

from mmpose.datasets import DatasetInfo
Expand Down Expand Up @@ -249,8 +248,5 @@ def _load_files(self):

assert osp.exists(self.gt_pose_db_file), f'gt_pose_db_file ' \
f"{self.gt_pose_db_file} doesn't exist, please check again"
gt = loadmat(self.gt_pose_db_file)
self.gt_pose_db = np.array(np.array(
gt['actor3D'].tolist()).tolist()).squeeze()

self.gt_pose_db = np.load(self.gt_pose_db_file)
self.num_persons = len(self.gt_pose_db)
42 changes: 30 additions & 12 deletions mmpose/models/detectors/multiview_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mmpose.core.post_processing.post_transforms import (
affine_transform_torch, get_affine_transform)
from .. import builder
from ..builder import POSENETS
from ..builder import BACKBONES, HEADS, POSENETS
from ..utils.misc import torch_meshgrid_ij
from .base import BasePose

Expand Down Expand Up @@ -138,7 +138,9 @@ class DetectAndRegress(BasePose):
"""DetectAndRegress approach for multiview human pose detection.

Args:
backbone (ConfigDict): Dictionary to construct the 2D pose detector
backbone (ConfigDict): Dictionary to construct the backbone.
keypoint_head (ConfigDict): Dictionary to construct the 2d
keypoint head.
human_detector (ConfigDict): dictionary to construct human detector
pose_regressor (ConfigDict): dictionary to construct pose regressor
train_cfg (ConfigDict): Config for training. Default: None.
Expand All @@ -150,6 +152,7 @@ class DetectAndRegress(BasePose):

def __init__(self,
backbone,
keypoint_head,
human_detector,
pose_regressor,
train_cfg=None,
Expand All @@ -158,11 +161,16 @@ def __init__(self,
freeze_2d=True):
super(DetectAndRegress, self).__init__()
if backbone is not None:
self.backbone = builder.build_posenet(backbone)
if self.training and pretrained is not None:
load_checkpoint(self.backbone, pretrained)
self.backbone = BACKBONES.build(backbone)
else:
self.backbone = None
if keypoint_head is not None:
self.keypoint_head = HEADS.build(keypoint_head)
else:
self.keypoint_head = None

if self.training and pretrained is not None:
load_checkpoint(self, pretrained)

self.freeze_2d = freeze_2d
self.human_detector = builder.MODELS.build(human_detector)
Expand All @@ -188,8 +196,11 @@ def train(self, mode=True):
Module: self
"""
super().train(mode)
if mode and self.freeze_2d and self.backbone is not None:
self._freeze(self.backbone)
if mode and self.freeze_2d:
if self.backbone is not None:
self._freeze(self.backbone)
if self.keypoint_head is not None:
self._freeze(self.keypoint_head)

return self

Expand Down Expand Up @@ -283,6 +294,12 @@ def train_step(self, data_batch, optimizer, **kwargs):

return outputs

def predict_heatmap(self, img):
    """Run the 2D branch: the backbone followed by the keypoint head.

    Args:
        img: Input forwarded unchanged to ``self.backbone``.

    Returns:
        The value returned by ``self.keypoint_head`` when called on the
        backbone output. Callers in this module take element ``[0]`` of
        the result as the feature map for one view.
    """
    # Both sub-modules are invoked as plain callables; no extra processing
    # happens between them.
    return self.keypoint_head(self.backbone(img))

def forward_train(self,
img,
img_metas,
Expand Down Expand Up @@ -331,7 +348,7 @@ def forward_train(self,
feature_maps = []
assert isinstance(img, list)
for img_ in img:
feature_maps.append(self.backbone.forward_dummy(img_)[0])
feature_maps.append(self.predict_heatmap(img_)[0])

losses = dict()
human_candidates, human_loss = self.human_detector.forward_train(
Expand All @@ -351,8 +368,9 @@ def forward_train(self,
heatmaps_tensor = torch.cat(feature_maps, dim=0)
targets_tensor = torch.cat(targets, dim=0)
masks_tensor = torch.cat(masks, dim=0)
losses_2d_ = self.backbone.get_loss(heatmaps_tensor,
targets_tensor, masks_tensor)
losses_2d_ = self.keypoint_head.get_loss(heatmaps_tensor,
targets_tensor,
masks_tensor)
for k, v in losses_2d_.items():
losses_2d[k + '_2d'] = v
losses.update(losses_2d)
Expand Down Expand Up @@ -400,7 +418,7 @@ def forward_test(
feature_maps = []
assert isinstance(img, list)
for img_ in img:
feature_maps.append(self.backbone.forward_dummy(img_)[0])
feature_maps.append(self.predict_heatmap(img_)[0])

human_candidates = self.human_detector.forward_test(
None, img_metas, feature_maps)
Expand Down Expand Up @@ -506,7 +524,7 @@ def forward_dummy(self, img, input_heatmaps=None, num_candidates=5):
feature_maps = []
assert isinstance(img, list)
for img_ in img:
feature_maps.append(self.backbone.forward_dummy(img_)[0])
feature_maps.append(self.predict_heatmap(img_)[0])

_ = self.human_detector.forward_dummy(feature_maps)

Expand Down
2 changes: 1 addition & 1 deletion model-index.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ Import:
- configs/face/2d_kpt_sview_rgb_img/topdown_heatmap/wflw/hrnetv2_dark_wflw.yml
- configs/face/2d_kpt_sview_rgb_img/topdown_heatmap/wflw/hrnetv2_wflw.yml
- configs/fashion/2d_kpt_sview_rgb_img/deeppose/deepfashion/resnet_deepfashion.yml
- configs/fashion/2d_kpt_sview_rgb_img/topdown_heatmap/deepfashion/resnet_deepfashion.yml
- configs/fashion/2d_kpt_sview_rgb_img/topdown_heatmap/deepfashion2/resnet_deepfashion2.yml
- configs/fashion/2d_kpt_sview_rgb_img/topdown_heatmap/deepfashion/resnet_deepfashion.yml
- configs/hand/2d_kpt_sview_rgb_img/deeppose/onehand10k/resnet_onehand10k.yml
- configs/hand/2d_kpt_sview_rgb_img/deeppose/panoptic2d/resnet_panoptic2d.yml
- configs/hand/2d_kpt_sview_rgb_img/deeppose/rhd2d/resnet_rhd2d.yml
Expand Down
Binary file removed tests/data/campus/actorsGT.mat
Binary file not shown.
Binary file added tests/data/campus/actorsGT.npy
Binary file not shown.
Binary file removed tests/data/shelf/actorsGT.mat
Binary file not shown.
Binary file added tests/data/shelf/actorsGT.npy
Binary file not shown.
8 changes: 4 additions & 4 deletions tests/test_datasets/test_body3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def test_body3dmview_direct_campus_dataset():
cam_file=f'{data_root}/calibration_campus.json',
train_pose_db_file=f'{data_root}/panoptic_training_pose.pkl',
test_pose_db_file=f'{data_root}/pred_campus_maskrcnn_hrnet_coco.pkl',
gt_pose_db_file=f'{data_root}/actorsGT.mat',
gt_pose_db_file=f'{data_root}/actorsGT.npy',
)

test_data_cfg = dict(
Expand All @@ -398,7 +398,7 @@ def test_body3dmview_direct_campus_dataset():
cam_file=f'{data_root}/calibration_campus.json',
train_pose_db_file=f'{data_root}/panoptic_training_pose.pkl',
test_pose_db_file=f'{data_root}/pred_campus_maskrcnn_hrnet_coco.pkl',
gt_pose_db_file=f'{data_root}/actorsGT.mat',
gt_pose_db_file=f'{data_root}/actorsGT.npy',
)

# test when dataset_info is None
Expand Down Expand Up @@ -507,7 +507,7 @@ def test_body3dmview_direct_shelf_dataset():
cam_file=f'{data_root}/calibration_shelf.json',
train_pose_db_file=f'{data_root}/panoptic_training_pose.pkl',
test_pose_db_file=f'{data_root}/pred_shelf_maskrcnn_hrnet_coco.pkl',
gt_pose_db_file=f'{data_root}/actorsGT.mat',
gt_pose_db_file=f'{data_root}/actorsGT.npy',
)

test_data_cfg = dict(
Expand All @@ -526,7 +526,7 @@ def test_body3dmview_direct_shelf_dataset():
cam_file=f'{data_root}/calibration_shelf.json',
train_pose_db_file=f'{data_root}/panoptic_training_pose.pkl',
test_pose_db_file=f'{data_root}/pred_shelf_maskrcnn_hrnet_coco.pkl',
gt_pose_db_file=f'{data_root}/actorsGT.mat',
gt_pose_db_file=f'{data_root}/actorsGT.npy',
)

# test when dataset_info is None
Expand Down
1 change: 1 addition & 0 deletions tests/test_models/test_multiview_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_voxelpose_forward():
model_cfg = dict(
type='DetectAndRegress',
backbone=None,
keypoint_head=None,
human_detector=dict(
type='VoxelCenterDetector',
image_size=[960, 512],
Expand Down