From 5c52f995c634c96ec8d3e5fa2961db8785da7366 Mon Sep 17 00:00:00 2001 From: "wenchao.w" Date: Thu, 14 Jul 2022 17:24:57 +0800 Subject: [PATCH] add single depth to 3d hand keypoints, add nyu dataset, awr network --- configs/_base_/datasets/nyu.py | 92 ++++ .../awr/nyu/res50_nyu_all_128x128.py | 177 +++++++ mmpose/core/evaluation/top_down_eval.py | 35 ++ mmpose/datasets/datasets/base/__init__.py | 12 +- ...kpt_3d_sview_depth_img_top_down_dataset.py | 400 +++++++++++++++ mmpose/datasets/datasets/hand/__init__.py | 3 +- .../datasets/datasets/hand/nyuhand_dataset.py | 168 +++++++ mmpose/datasets/pipelines/hand_transform.py | 87 ++++ .../datasets/pipelines/top_down_transform.py | 45 +- mmpose/models/backbones/__init__.py | 4 +- mmpose/models/backbones/awr_resnet.py | 33 ++ mmpose/models/detectors/__init__.py | 4 +- mmpose/models/detectors/depthhand_3d.py | 303 ++++++++++++ mmpose/models/heads/__init__.py | 3 +- mmpose/models/heads/awr_head.py | 458 ++++++++++++++++++ mmpose/models/losses/__init__.py | 9 +- mmpose/models/losses/regression_loss.py | 43 ++ tests/test_models/test_awr_3d_head.py | 91 ++++ .../test_models/test_depthhand_3d_forward.py | 108 +++++ 19 files changed, 2060 insertions(+), 15 deletions(-) create mode 100644 configs/_base_/datasets/nyu.py create mode 100644 configs/hand/3d_kpt_sview_depth_img/awr/nyu/res50_nyu_all_128x128.py create mode 100644 mmpose/datasets/datasets/base/kpt_3d_sview_depth_img_top_down_dataset.py create mode 100644 mmpose/datasets/datasets/hand/nyuhand_dataset.py create mode 100644 mmpose/models/backbones/awr_resnet.py create mode 100644 mmpose/models/detectors/depthhand_3d.py create mode 100644 mmpose/models/heads/awr_head.py create mode 100644 tests/test_models/test_awr_3d_head.py create mode 100644 tests/test_models/test_depthhand_3d_forward.py diff --git a/configs/_base_/datasets/nyu.py b/configs/_base_/datasets/nyu.py new file mode 100644 index 0000000000..cf1acb0a87 --- /dev/null +++ b/configs/_base_/datasets/nyu.py @@ -0,0 +1,92 @@ +dataset_info = dict( + dataset_name='nyu', + paper_info=dict( + author='Jonathan Tompson and Murphy Stein and Yann Lecun and ' + 'Ken Perlin', + title='Real-Time Continuous Pose Recovery of Human Hands ' + 'Using Convolutional Networks', + container='ACM Transactions on Graphics', + year='2014', + homepage='https://jonathantompson.github.io/NYU_Hand_Pose_Dataset.htm', + ), + keypoint_info={ + 0: dict(name='F1_KNU3_A', id=0, color=[255, 128, 0], type='', swap=''), + 1: dict(name='F1_KNU3_B', id=1, color=[255, 128, 0], type='', swap=''), + 2: dict(name='F1_KNU2_A', id=2, color=[255, 128, 0], type='', swap=''), + 3: dict(name='F1_KNU2_B', id=3, color=[255, 128, 0], type='', swap=''), + 4: + dict(name='F1_KNU1_A', id=4, color=[255, 153, 255], type='', swap=''), + 5: + dict(name='F1_KNU1_B', id=5, color=[255, 153, 255], type='', swap=''), + 6: + dict(name='F2_KNU3_A', id=6, color=[255, 153, 255], type='', swap=''), + 7: + dict(name='F2_KNU3_B', id=7, color=[255, 153, 255], type='', swap=''), + 8: + dict(name='F2_KNU2_A', id=8, color=[102, 178, 255], type='', swap=''), + 9: + dict(name='F2_KNU2_B', id=9, color=[102, 178, 255], type='', swap=''), + 10: + dict(name='F2_KNU1_A', id=10, color=[102, 178, 255], type='', swap=''), + 11: + dict(name='F2_KNU1_B', id=11, color=[102, 178, 255], type='', swap=''), + 12: + dict(name='F3_KNU3_A', id=12, color=[255, 51, 51], type='', swap=''), + 13: + dict(name='F3_KNU3_B', id=13, color=[255, 51, 51], type='', swap=''), + 14: + dict(name='F3_KNU2_A', id=14, color=[255, 51, 51], type='', swap=''), + 15: + dict(name='F3_KNU2_B', id=15, color=[255, 51, 51], type='', swap=''), + 16: dict(name='F3_KNU1_A', id=16, color=[0, 255, 0], type='', swap=''), + 17: dict(name='F3_KNU1_B', id=17, color=[0, 255, 0], type='', swap=''), + 18: dict(name='F4_KNU3_A', id=18, color=[0, 255, 0], type='', swap=''), + 19: dict(name='F4_KNU3_B', id=19, color=[0, 255, 0], type='', swap=''), + 20: + dict(name='F4_KNU2_A', id=20, color=[255, 255, 255], type='', swap=''), + 21: + dict(name='F4_KNU2_B', id=21, color=[255, 128, 0], type='', swap=''), + 22: + dict(name='F4_KNU1_A', id=22, color=[255, 128, 0], type='', swap=''), + 23: + dict(name='F4_KNU1_B', id=23, color=[255, 128, 0], type='', swap=''), + 24: + dict(name='TH_KNU3_A', id=24, color=[255, 128, 0], type='', swap=''), + 25: + dict(name='TH_KNU3_B', id=25, color=[255, 153, 255], type='', swap=''), + 26: + dict(name='TH_KNU2_A', id=26, color=[255, 153, 255], type='', swap=''), + 27: + dict(name='TH_KNU2_B', id=27, color=[255, 153, 255], type='', swap=''), + 28: + dict(name='TH_KNU1_A', id=28, color=[255, 153, 255], type='', swap=''), + 29: + dict(name='TH_KNU1_B', id=29, color=[102, 178, 255], type='', swap=''), + 30: + dict(name='PALM_1', id=30, color=[102, 178, 255], type='', swap=''), + 31: + dict(name='PALM_2', id=31, color=[102, 178, 255], type='', swap=''), + 32: + dict(name='PALM_3', id=32, color=[102, 178, 255], type='', swap=''), + 33: dict(name='PALM_4', id=33, color=[255, 51, 51], type='', swap=''), + 34: dict(name='PALM_5', id=34, color=[255, 51, 51], type='', swap=''), + 35: dict(name='PALM_6', id=35, color=[255, 51, 51], type='', swap=''), + }, + skeleton_info={ + 0: dict(link=('PALM_3', 'F1_KNU2_B'), id=0, color=[255, 128, 0]), + 1: dict(link=('F1_KNU2_B', 'F1_KNU3_A'), id=1, color=[255, 128, 0]), + 2: dict(link=('PALM_3', 'F2_KNU2_B'), id=2, color=[255, 128, 0]), + 3: dict(link=('F2_KNU2_B', 'F2_KNU3_A'), id=3, color=[255, 128, 0]), + 4: dict(link=('PALM_3', 'F3_KNU2_B'), id=4, color=[255, 153, 255]), + 5: dict(link=('F3_KNU2_B', 'F3_KNU3_A'), id=5, color=[255, 153, 255]), + 6: dict(link=('PALM_3', 'F4_KNU2_B'), id=6, color=[255, 153, 255]), + 7: dict(link=('F4_KNU2_B', 'F4_KNU3_A'), id=7, color=[255, 153, 255]), + 8: dict(link=('PALM_3', 'TH_KNU2_B'), id=8, color=[102, 178, 255]), + 9: dict(link=('TH_KNU2_B', 'TH_KNU3_B'), id=9, color=[102, 178, 255]), + 10: + dict(link=('TH_KNU3_B', 'TH_KNU3_A'), id=10, color=[102, 178, 255]), + 11: dict(link=('PALM_3', 'PALM_1'), id=11, color=[102, 178, 255]), + 12: dict(link=('PALM_3', 'PALM_2'), id=12, color=[255, 51, 51]), + }, + joint_weights=[1.] * 36, + sigmas=[]) diff --git a/configs/hand/3d_kpt_sview_depth_img/awr/nyu/res50_nyu_all_128x128.py b/configs/hand/3d_kpt_sview_depth_img/awr/nyu/res50_nyu_all_128x128.py new file mode 100644 index 0000000000..a6d78345ae --- /dev/null +++ b/configs/hand/3d_kpt_sview_depth_img/awr/nyu/res50_nyu_all_128x128.py @@ -0,0 +1,177 @@ +_base_ = [ + '../../../../_base_/default_runtime.py', + '../../../../_base_/datasets/nyu.py' +] +checkpoint_config = dict(interval=1) +# TODO: metric +evaluation = dict( + interval=1, + metric=['MRRPE', 'MPJPE', 'Handedness_acc'], + save_best='MPJPE_all') + +optimizer = dict( + type='Adam', + lr=2e-4, +) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[15, 17]) +total_epochs = 20 +log_config = dict( + interval=20, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +load_from = '/root/mmpose/data/ckpt/new_res50.pth' +used_keypoints_index = [0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 27, 30, 31, 32] + +channel_cfg = dict( + num_output_channels=14, + dataset_joints=36, + dataset_channel=used_keypoints_index, + inference_channel=used_keypoints_index) + +# model settings +model = dict( + type='Depthhand3D', # pretrained=None + backbone=dict( + type='AWRResNet', + depth=50, + frozen_stages=-1, + zero_init_residual=False, + in_channels=1), + keypoint_head=dict( + type='AdaptiveWeightingRegression3DHead', + offset_head_cfg=dict( + in_channels=256, + out_channels_vector=42, + out_channels_scalar=14, + heatmap_kernel_size=1.0, + ), + deconv_head_cfg=dict( + in_channels=2048, + out_channels=256, + depth_size=64, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=dict(final_conv_kernel=0, )), + loss_offset=dict(type='AWRSmoothL1Loss', use_target_weight=False), + loss_keypoint=dict(type='AWRSmoothL1Loss', use_target_weight=True), + ), + train_cfg=dict(use_img_for_head=True), + test_cfg=dict(use_img_for_head=True, flip_test=False)) + +data_cfg = dict( + image_size=[128, 128], + heatmap_size=[64, 64, 56], + cube_size=[300, 300, 300], + heatmap_size_root=64, + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel']) + +train_pipeline = [ + dict(type='LoadImageFromFile', color_type='unchanged'), + dict(type='TopDownGetBboxCenterScale', padding=1.0), + dict(type='TopDownAffine'), + dict(type='DepthToTensor'), + dict( + type='MultitaskGatherTarget', + pipeline_list=[ + [ + dict( + type='TopDownGenerateTargetRegression', + use_zero_mean=True, + joint_indices=used_keypoints_index, + is_3d=True, + normalize_depth=True, + ), + dict( + type='HandGenerateJointToOffset', + heatmap_kernel_size=1.0, + ) + ], + [ + dict( + type='TopDownGenerateTargetRegression', + use_zero_mean=True, + joint_indices=used_keypoints_index, + is_3d=True, + normalize_depth=True, + ) + ], + ], + pipeline_indices=[0, 1], + ), + dict( + type='Collect', + keys=['img', 'target', 'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'flip_pairs', 'cube_size', 'center_depth', 'focal', + 'princpt', 'image_size', 'joints_cam', 'dataset_channel', + 'joints_uvd' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile', color_type='unchanged'), + dict(type='TopDownGetBboxCenterScale', padding=1.0), + dict(type='TopDownAffine'), + dict(type='DepthToTensor'), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'flip_pairs', 'cube_size', 'center_depth', 'focal', + 'princpt', 'image_size', 'joints_cam', 'dataset_channel', + 'joints_uvd' + ]) +] + +test_pipeline = val_pipeline + +data_root = 'data/nyu' +data = dict( + samples_per_gpu=4, + workers_per_gpu=0, + shuffle=False, + train=dict( + type='NYUHandDataset', + ann_file=f'{data_root}/annotations/nyu_test_data.json', + camera_file=f'{data_root}/annotations/nyu_test_camera.json', + joint_file=f'{data_root}/annotations/nyu_test_joint_3d.json', + img_prefix=f'{data_root}/images/test/', + data_cfg=data_cfg, + use_refined_center=False, + align_uvd_xyz_direction=True, + pipeline=train_pipeline, + dataset_info={{_base_.dataset_info}}), + val=dict( + type='NYUHandDataset', + ann_file=f'{data_root}/annotations/nyu_test_data.json', + camera_file=f'{data_root}/annotations/nyu_test_camera.json', + joint_file=f'{data_root}/annotations/nyu_test_joint_3d.json', + img_prefix=f'{data_root}/images/test/', + data_cfg=data_cfg, + use_refined_center=False, + align_uvd_xyz_direction=True, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), + test=dict( + type='NYUHandDataset', + ann_file=f'{data_root}/annotations/nyu_test_data.json', + camera_file=f'{data_root}/annotations/nyu_test_camera.json', + joint_file=f'{data_root}/annotations/nyu_test_joint_3d.json', + img_prefix=f'{data_root}/images/test/', + data_cfg=data_cfg, + use_refined_center=False, + align_uvd_xyz_direction=True, + pipeline=test_pipeline, + dataset_info={{_base_.dataset_info}}), +) diff --git a/mmpose/core/evaluation/top_down_eval.py b/mmpose/core/evaluation/top_down_eval.py index ee6a2501cf..21c6b8c6b3 100644 --- a/mmpose/core/evaluation/top_down_eval.py +++ b/mmpose/core/evaluation/top_down_eval.py @@ -655,6 +655,41 @@ def keypoints_from_heatmaps3d(heatmaps, center, scale): return preds, maxvals +def keypoints_from_joint_uvd(joint_uvd, center, scale, image_size): + """Get final keypoint predictions from 3d heatmaps and transform them back + to the image. + + Note: + - batch size: N + - num keypoints: K + - heatmap depth size: D + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps. + center (np.ndarray[N, 2]): Center of the bounding box (x, y). + scale (np.ndarray[N, 2]): Scale of the bounding box + wrt height/width. + + Returns: + tuple: A tuple containing keypoint predictions and scores. + + - preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location \ + in images. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + N, K, D = joint_uvd.shape + preds = joint_uvd + maxvals = np.ones((N, K, 1), dtype=np.float32) + # Transform back to the image + for i in range(N): + preds[i, :, :2] = transform_preds( + (preds[i, :, :2] + 1) * image_size[i] / 2, center[i], scale[i], + [image_size[i, 1], image_size[i, 0]]) + return preds, maxvals + + def multilabel_classification_accuracy(pred, gt, mask, thr=0.5): """Get multi-label classification accuracy. diff --git a/mmpose/datasets/datasets/base/__init__.py b/mmpose/datasets/datasets/base/__init__.py index e5f9a0899c..75d6901f2a 100644 --- a/mmpose/datasets/datasets/base/__init__.py +++ b/mmpose/datasets/datasets/base/__init__.py @@ -6,12 +6,18 @@ from .kpt_2d_sview_rgb_vid_top_down_dataset import \ Kpt2dSviewRgbVidTopDownDataset from .kpt_3d_mview_rgb_img_direct_dataset import Kpt3dMviewRgbImgDirectDataset +from .kpt_3d_sview_depth_img_top_down_dataset import \ + Kpt3dSviewDepthImgTopDownDataset from .kpt_3d_sview_kpt_2d_dataset import Kpt3dSviewKpt2dDataset from .kpt_3d_sview_rgb_img_top_down_dataset import \ Kpt3dSviewRgbImgTopDownDataset __all__ = [ - 'Kpt3dMviewRgbImgDirectDataset', 'Kpt2dSviewRgbImgTopDownDataset', - 'Kpt3dSviewRgbImgTopDownDataset', 'Kpt2dSviewRgbImgBottomUpDataset', - 'Kpt3dSviewKpt2dDataset', 'Kpt2dSviewRgbVidTopDownDataset' + 'Kpt3dMviewRgbImgDirectDataset', + 'Kpt2dSviewRgbImgTopDownDataset', + 'Kpt3dSviewRgbImgTopDownDataset', + 'Kpt2dSviewRgbImgBottomUpDataset', + 'Kpt3dSviewKpt2dDataset', + 'Kpt2dSviewRgbVidTopDownDataset', + 'Kpt3dSviewDepthImgTopDownDataset', ] diff --git a/mmpose/datasets/datasets/base/kpt_3d_sview_depth_img_top_down_dataset.py b/mmpose/datasets/datasets/base/kpt_3d_sview_depth_img_top_down_dataset.py new file mode 100644 index 0000000000..73a72e8c7e --- /dev/null +++ b/mmpose/datasets/datasets/base/kpt_3d_sview_depth_img_top_down_dataset.py @@ -0,0 +1,400 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from abc import ABCMeta, abstractmethod + +import json_tricks as json +import numpy as np +from torch.utils.data import Dataset +from xtcocotools.coco import COCO + +from mmpose.core.evaluation.top_down_eval import (keypoint_auc, keypoint_epe, + keypoint_pck_accuracy) +from mmpose.datasets import DatasetInfo +from mmpose.datasets.pipelines import Compose + + +class Kpt3dSviewDepthImgTopDownDataset(Dataset, metaclass=ABCMeta): + """Base class for keypoint 3D top-down pose estimation with single-view + depth image as the input. + + All depth-based datasets should subclass it. + All subclasses should overwrite: + Methods:`_get_db`, 'evaluate' + + Args: + ann_file (str): Path to the annotation file. + img_prefix (str): Path to a directory where images are held. + Default: None. + data_cfg (dict): config + pipeline (list[dict | callable]): A sequence of data transforms. + dataset_info (DatasetInfo): A class containing all dataset info. + coco_style (bool): Whether the annotation json is coco-style. + Default: True + test_mode (bool): Store True when building test or + validation dataset. Default: False. + """ + + def __init__(self, + ann_file, + img_prefix, + data_cfg, + pipeline, + dataset_info=None, + coco_style=True, + test_mode=False): + + self.image_info = {} + self.ann_info = {} + + self.ann_file = ann_file + self.img_prefix = img_prefix + self.pipeline = pipeline + self.test_mode = test_mode + + self.ann_info['image_size'] = np.array(data_cfg['image_size']) + self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size']) + self.ann_info['num_joints'] = data_cfg['num_joints'] + + self.ann_info['inference_channel'] = data_cfg['inference_channel'] + self.ann_info['num_output_channels'] = data_cfg['num_output_channels'] + self.ann_info['dataset_channel'] = data_cfg['dataset_channel'] + + if dataset_info is None: + raise ValueError( + 'Check https://github.com/open-mmlab/mmpose/pull/663 ' + 'for details.') + + dataset_info = DatasetInfo(dataset_info) + + assert self.ann_info['num_joints'] == dataset_info.keypoint_num + self.ann_info['flip_pairs'] = dataset_info.flip_pairs + self.ann_info['flip_index'] = dataset_info.flip_index + self.ann_info['upper_body_ids'] = dataset_info.upper_body_ids + self.ann_info['lower_body_ids'] = dataset_info.lower_body_ids + self.ann_info['joint_weights'] = dataset_info.joint_weights + self.ann_info['skeleton'] = dataset_info.skeleton + self.sigmas = dataset_info.sigmas + self.dataset_name = dataset_info.dataset_name + + if coco_style: + self.coco = COCO(ann_file) + if 'categories' in self.coco.dataset: + cats = [ + cat['name'] + for cat in self.coco.loadCats(self.coco.getCatIds()) + ] + self.classes = ['__background__'] + cats + self.num_classes = len(self.classes) + self._class_to_ind = dict( + zip(self.classes, range(self.num_classes))) + self._class_to_coco_ind = dict( + zip(cats, self.coco.getCatIds())) + self._coco_ind_to_class_ind = dict( + (self._class_to_coco_ind[cls], self._class_to_ind[cls]) + for cls in self.classes[1:]) + self.img_ids = self.coco.getImgIds() + self.num_images = len(self.img_ids) + self.id2name, self.name2id = self._get_mapping_id_name( + self.coco.imgs) + + self.db = [] + + self.pipeline = Compose(self.pipeline) + + @staticmethod + def _cam2pixel(cam_coord, f, c): + """Transform the joints from their camera coordinates to their pixel + coordinates. + + Note: + N: number of joints + + Args: + cam_coord (ndarray[N, 3]): 3D joints coordinates + in the camera coordinate system + f (ndarray[2]): focal length of x and y axis + c (ndarray[2]): principal point of x and y axis + + Returns: + img_coord (ndarray[N, 3]): the coordinates (x, y, 0) + in the image plane. + """ + x = cam_coord[:, 0] / (cam_coord[:, 2] + 1e-8) * f[0] + c[0] + y = cam_coord[:, 1] / (cam_coord[:, 2] + 1e-8) * f[1] + c[1] + z = np.zeros_like(x) + img_coord = np.concatenate((x[:, None], y[:, None], z[:, None]), 1) + return img_coord + + @staticmethod + def _world2cam(world_coord, R, T): + """Transform the joints from their world coordinates to their camera + coordinates. + + Note: + N: number of joints + + Args: + world_coord (ndarray[3, N]): 3D joints coordinates + in the world coordinate system + R (ndarray[3, 3]): camera rotation matrix + T (ndarray[3, 1]): camera position (x, y, z) + + Returns: + cam_coord (ndarray[3, N]): 3D joints coordinates + in the camera coordinate system + """ + cam_coord = np.dot(R, world_coord - T) + return cam_coord + + @staticmethod + def _pixel2cam(pixel_coord, f, c): + """Transform the joints from their pixel coordinates to their camera + coordinates. + + Note: + N: number of joints + + Args: + pixel_coord (ndarray[N, 3]): 3D joints coordinates + in the pixel coordinate system + f (ndarray[2]): focal length of x and y axis + c (ndarray[2]): principal point of x and y axis + + Returns: + cam_coord (ndarray[N, 3]): 3D joints coordinates + in the camera coordinate system + """ + x = (pixel_coord[:, 0] - c[0]) / f[0] * pixel_coord[:, 2] + y = (pixel_coord[:, 1] - c[1]) / f[1] * pixel_coord[:, 2] + z = pixel_coord[:, 2] + cam_coord = np.concatenate((x[:, None], y[:, None], z[:, None]), 1) + return cam_coord + + @staticmethod + def _xyz2uvd(xyz, f, c): + """Transform the joints from their 3d xyz camera coordinates to their + 2.5D uvd coordinates. + + Note: + N: number of joints + + Args: + xyz (ndarray[N, 3]): 3D joints coordinates + in the camera coordinate system + f (ndarray[2]): focal length of x and y axis + c (ndarray[2]): principal point of x and y axis + + Returns: + uvd (ndarray[N, 3]): the 2.5D coordinates (u, v, d) in the spatial. + """ + u = xyz[:, 0] / (xyz[:, 2] + 1e-8) * f[0] + c[0] + v = xyz[:, 1] / (xyz[:, 2] + 1e-8) * f[1] + c[1] + d = xyz[:, 2] + uvd = np.concatenate((u[:, None], v[:, None], d[:, None]), 1) + return uvd + + @staticmethod + def _uvd2xyz(uvd, f, c): + """Transform the joints from their 2.5D uvd coordinates to their 3D xyz + camera coordinates. + + Note: + N: number of joints + + Args: + uvd (ndarray[N, 3]): 3D joints coordinates + in the pixel coordinate system + f (ndarray[2]): focal length of x and y axis + c (ndarray[2]): principal point of x and y axis + + Returns: + xyz (ndarray[N, 3]): 3D joints coordinates + in the camera coordinate system + """ + x = (uvd[:, 0] - c[0]) / f[0] * uvd[:, 2] + y = (uvd[:, 1] - c[1]) / f[1] * uvd[:, 2] + z = uvd[:, 2] + xyz = np.concatenate((x[:, None], y[:, None], z[:, None]), 1) + return xyz + + @staticmethod + def _center2bounds(center_uvd, cube_size, f): + """ + + Args: + center_uvd (ndarray[1, 3]): + cube_size (ndarray[3]): + f (ndarray[2]): focal length of x and y axis + + Returns: + bounds (ndarray[1, 6]): 2.5D bounds + """ + + ustart = center_uvd[:, + 0] - (cube_size[0] / 2.) / center_uvd[:, 2] * f[0] + vstart = center_uvd[:, + 1] - (cube_size[1] / 2.) / center_uvd[:, 2] * f[1] + uend = center_uvd[:, 0] + (cube_size[0] / 2.) / center_uvd[:, 2] * f[0] + vend = center_uvd[:, 1] + (cube_size[1] / 2.) / center_uvd[:, 2] * f[1] + dstart = center_uvd[:, 2] - cube_size[2] / 2. + dend = center_uvd[:, 2] + cube_size[2] / 2. + bounds = np.concatenate( + (ustart[:, None], uend[:, None], vstart[:, None], vend[:, None], + dstart[:, None], dend[:, None]), 1) + return bounds + + @staticmethod + def _get_mapping_id_name(imgs): + """ + Args: + imgs (dict): dict of image info. + + Returns: + tuple: Image name & id mapping dicts. + + - id2name (dict): Mapping image id to name. + - name2id (dict): Mapping image name to id. + """ + id2name = {} + name2id = {} + for image_id, image in imgs.items(): + file_name = image['file_name'] + id2name[image_id] = file_name + name2id[file_name] = image_id + + return id2name, name2id + + def _xywh2cs(self, x, y, w, h, padding=1.25): + """This encodes bbox(x,y,w,h) into (center, scale) + + Args: + x, y, w, h (float): left, top, width and height + padding (float): bounding box padding factor + + Returns: + center (np.ndarray[float32](2,)): center of the bbox (x, y). + scale (np.ndarray[float32](2,)): scale of the bbox w & h. + """ + aspect_ratio = self.ann_info['image_size'][0] / self.ann_info[ + 'image_size'][1] + center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32) + + if (not self.test_mode) and np.random.rand() < 0.3: + center += 0.4 * (np.random.rand(2) - 0.5) * [w, h] + + if w > aspect_ratio * h: + h = w * 1.0 / aspect_ratio + elif w < aspect_ratio * h: + w = h * aspect_ratio + + # pixel std is 200.0 + scale = np.array([w / 200.0, h / 200.0], dtype=np.float32) + # padding to include proper amount of context + scale = scale * padding + + return center, scale + + @abstractmethod + def _get_db(self): + """Load dataset.""" + raise NotImplementedError + + @abstractmethod + def evaluate(self, results, *args, **kwargs): + """Evaluate keypoint results.""" + + @staticmethod + def _write_keypoint_results(keypoints, res_file): + """Write results into a json file.""" + + with open(res_file, 'w') as f: + json.dump(keypoints, f, sort_keys=True, indent=4) + + def __len__(self): + """Get the size of the dataset.""" + return len(self.db) + + def __getitem__(self, idx): + """Get the sample given index.""" + results = copy.deepcopy(self.db[idx]) + results['ann_info'] = self.ann_info + return self.pipeline(results) + + def _sort_and_unique_bboxes(self, kpts, key='bbox_id'): + """sort kpts and remove the repeated ones.""" + kpts = sorted(kpts, key=lambda x: x[key]) + num = len(kpts) + for i in range(num - 1, 0, -1): + if kpts[i][key] == kpts[i - 1][key]: + del kpts[i] + + return kpts + + def _report_metric(self, + res_file, + metrics, + pck_thr=0.2, + pckh_thr=0.7, + auc_nor=30): + """Keypoint evaluation. + + Args: + res_file (str): Json file stored prediction results. + metrics (str | list[str]): Metric to be performed. + Options: 'PCK', 'PCKh', 'AUC', 'EPE', 'NME'. + pck_thr (float): PCK threshold, default as 0.2. + pckh_thr (float): PCKh threshold, default as 0.7. + auc_nor (float): AUC normalization factor, default as 30 pixel. + + Returns: + List: Evaluation results for evaluation metric. + """ + info_str = [] + + with open(res_file, 'r') as fin: + preds = json.load(fin) + assert len(preds) == len(self.db) + + outputs = [] + gts = [] + masks = [] + box_sizes = [] + threshold_bbox = [] + threshold_head_box = [] + + for pred, item in zip(preds, self.db): + + self.ann_info['image_size'] + + # pred_joint_xyz = self._uvd2xyz( + # np.array(pred['keypoints'], dtype=np.float32), item['focal'], + # item['princpt']) + outputs.append(np.array(pred['keypoints'])[:, :-1]) + gts.append(np.array(item['joints_3d'])[:, :-1]) + masks.append((np.array(item['joints_3d_visible'])[:, 0]) > 0) + if 'PCK' in metrics: + bbox = np.array(item['bbox']) + bbox_thr = np.max(bbox[2:]) + threshold_bbox.append(np.array([bbox_thr, bbox_thr])) + box_sizes.append(item.get('box_size', 1)) + + outputs = np.array(outputs) + gts = np.array(gts) + masks = np.array(masks) + threshold_bbox = np.array(threshold_bbox) + threshold_head_box = np.array(threshold_head_box) + box_sizes = np.array(box_sizes).reshape([-1, 1]) + + if 'PCK' in metrics: + _, pck, _ = keypoint_pck_accuracy(outputs, gts, masks, pck_thr, + threshold_bbox) + info_str.append(('PCK', pck)) + + if 'AUC' in metrics: + info_str.append(('AUC', keypoint_auc(outputs, gts, masks, + auc_nor))) + + if 'EPE' in metrics: + info_str.append(('EPE', keypoint_epe(outputs, gts, masks))) + + return info_str diff --git a/mmpose/datasets/datasets/hand/__init__.py b/mmpose/datasets/datasets/hand/__init__.py index 49159afa60..1fe5848032 100644 --- a/mmpose/datasets/datasets/hand/__init__.py +++ b/mmpose/datasets/datasets/hand/__init__.py @@ -3,6 +3,7 @@ from .hand_coco_wholebody_dataset import HandCocoWholeBodyDataset from .interhand2d_dataset import InterHand2DDataset from .interhand3d_dataset import InterHand3DDataset +from .nyuhand_dataset import NYUHandDataset from .onehand10k_dataset import OneHand10KDataset from .panoptic_hand2d_dataset import PanopticDataset from .rhd2d_dataset import Rhd2DDataset @@ -10,5 +11,5 @@ __all__ = [ 'FreiHandDataset', 'InterHand2DDataset', 'InterHand3DDataset', 'OneHand10KDataset', 'PanopticDataset', 'Rhd2DDataset', - 'HandCocoWholeBodyDataset' + 'HandCocoWholeBodyDataset', 'NYUHandDataset' ] diff --git a/mmpose/datasets/datasets/hand/nyuhand_dataset.py b/mmpose/datasets/datasets/hand/nyuhand_dataset.py new file mode 100644 index 0000000000..6c54557366 --- /dev/null +++ b/mmpose/datasets/datasets/hand/nyuhand_dataset.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings + +import json_tricks as json +import numpy as np +from mmcv import Config, deprecated_api_warning + +from mmpose.datasets.builder import DATASETS +from ..base import Kpt3dSviewDepthImgTopDownDataset + + +@DATASETS.register_module() +class NYUHandDataset(Kpt3dSviewDepthImgTopDownDataset): + """TODO, add more detail doc. + + Args: + ann_file (str): Path to the annotation file. + camera_file (str): Path to the camera file. + joint_file (str): Path to the joint file. + img_prefix (str): Path to a directory where images are held. + Default: None. + data_cfg (dict): config + pipeline (list[dict | callable]): A sequence of data transforms. + use_refined_center (bool): Using refined bbox center. + dataset_info (DatasetInfo): A class containing all dataset info. + test_mode (str): Store True when building test or + validation dataset. Default: False. + """ + + def __init__(self, + ann_file, + camera_file, + joint_file, + img_prefix, + data_cfg, + pipeline, + use_refined_center=False, + align_uvd_xyz_direction=True, + dataset_info=None, + test_mode=False): + + if dataset_info is None: + warnings.warn( + 'dataset_info is missing. ' + 'Check https://github.com/open-mmlab/mmpose/pull/663 ' + 'for details.', DeprecationWarning) + cfg = Config.fromfile('configs/_base_/datasets/nyu.py') + dataset_info = cfg._cfg_dict['dataset_info'] + + super().__init__( + ann_file, + img_prefix, + data_cfg, + pipeline, + dataset_info=dataset_info, + test_mode=test_mode) + + self.ann_info['cube_size'] = np.array(data_cfg['cube_size']) + self.ann_info['use_different_joint_weights'] = False + + self.camera_file = camera_file + self.joint_file = joint_file + self.align_uvd_xyz_direction = align_uvd_xyz_direction + self.use_refined_center = use_refined_center + if self.align_uvd_xyz_direction: + self.flip_y = -1 + else: + self.flip_y = 1 + self.meter2millimeter = 1 / 1000. + + self.db = self._get_db() + + print(f'=> num_images: {self.num_images}') + print(f'=> load {len(self.db)} samples') + + def _get_db(self): + """Load dataset.""" + with open(self.camera_file, 'r') as f: + cameras = json.load(f) + with open(self.joint_file, 'r') as f: + joints = json.load(f) + + gt_db = [] + bbox_id = 0 + for img_id in self.img_ids: + num_joints = self.ann_info['num_joints'] + + ann_id = self.coco.getAnnIds(imgIds=img_id, iscrowd=False) + ann = self.coco.loadAnns(ann_id)[0] + img = self.coco.loadImgs(img_id)[0] + + frame_idx = str(img['frame_idx']) + image_file = osp.join(self.img_prefix, self.id2name[img_id]) + + focal = np.array([cameras['fx'], cameras['fy']], dtype=np.float32) + principal_pt = np.array([cameras['cx'], cameras['cy']], + dtype=np.float32) + + joint_uvd = np.array( + joints[frame_idx]['joint_uvd'], dtype=np.float32) + joint_xyz = np.array( + joints[frame_idx]['joint_xyz'], dtype=np.float32) + joint_xyz[:, 1] *= self.flip_y + + # calculate bbox online + # using center_xyz and cube_size, then project to 2D as bbox + if self.use_refined_center: + center_xyz = np.array( + ann['center_refined_xyz'], + dtype=np.float32).reshape(-1, 1) + else: + center_xyz = np.mean(joint_xyz, axis=0, keepdims=True) + center_depth = center_xyz[0, 2] + center_uvd = self._xyz2uvd(center_xyz, focal, principal_pt) + + if self.test_mode and img_id >= 2440: + cube_size = np.array( + self.ann_info['cube_size'], dtype=np.float32) * 5.0 / 6.0 + else: + cube_size = np.array( + self.ann_info['cube_size'], dtype=np.float32) + + bounds_uvd = self._center2bounds(center_uvd, cube_size, focal) + bbox = np.array([ + bounds_uvd[0, 0], bounds_uvd[0, 2], bounds_uvd[0, 1] - + bounds_uvd[0, 0], bounds_uvd[0, 3] - bounds_uvd[0, 2] + ], + dtype=np.float32) + + valid_joints_idx = self.ann_info['dataset_channel'] + joint_valid = np.zeros(joint_xyz.shape[0], dtype=np.float32) + joint_valid[valid_joints_idx] = 1.0 + + # joint_3d will be normalized in pre-processing pipeline + # uv are processed by TopDownAffine + # depth are processed by DepthToTensor + joints_3d = np.zeros((num_joints, 3), dtype=np.float32) + joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32) + joints_3d[:, :2] = joint_uvd[:, :2] + joints_3d[:, 2] = joint_uvd[:, 2] + + joints_3d_visible[...] = np.minimum(1, joint_valid.reshape(-1, 1)) + + gt_db.append({ + 'image_file': image_file, + 'rotation': 0, + 'joints_3d': joints_3d, + 'joints_3d_visible': joints_3d_visible, + 'joints_cam': joint_xyz, + 'joints_uvd': joint_uvd, + 'cube_size': cube_size, + 'center_depth': center_depth, + 'focal': focal, + 'princpt': principal_pt, + 'dataset': self.dataset_name, + 'bbox': bbox, + 'bbox_score': 1, + 'bbox_id': bbox_id + }) + bbox_id = bbox_id + 1 + gt_db = sorted(gt_db, key=lambda x: x['bbox_id']) + + return gt_db + + @deprecated_api_warning(name_dict=dict(outputs='results')) + def evaluate(self, results, res_folder=None, metric='EPE', **kwargs): + raise NotImplementedError diff --git a/mmpose/datasets/pipelines/hand_transform.py b/mmpose/datasets/pipelines/hand_transform.py index b83e399c4e..e84ec87e8a 100644 --- a/mmpose/datasets/pipelines/hand_transform.py +++ b/mmpose/datasets/pipelines/hand_transform.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +import mmcv import numpy as np +from torchvision.transforms import functional as F from mmpose.datasets.builder import PIPELINES from .top_down_transform import TopDownRandomFlip @@ -61,3 +63,88 @@ def __call__(self, results): results['target'] = target * np.ones(1, dtype=np.float32) results['target_weight'] = target_weight * np.ones(1, dtype=np.float32) return results + + +@PIPELINES.register_module() +class DepthToTensor: + """Transform depth image to Tensor. + TODO: add reference from AWR github + + Required key: 'img', 'cube_size', 'center_depth'. + + Modifies key: 'img'. + """ + + def __init__(self): + pass + + def __call__(self, results): + if isinstance(results['img'], (list, tuple)): + results['img'] = [ + F.to_tensor(self._process_depth(img, results)) + for img in results['img'] + ] + else: + depth = self._process_depth(results['img'], results) + results['img'] = F.to_tensor(depth) + return results + + @staticmethod + def _process_depth(img, results): + depth = np.asarray(img[:, :, 0] + img[:, :, 1] * 256, dtype=np.float32) + img_max = np.max(depth) + depth_max = results['center_depth'] + (results['cube_size'][2] / 2.) + depth_min = results['center_depth'] - (results['cube_size'][2] / 2.) + depth[depth == img_max] = depth_max + depth[depth == 0] = depth_max + depth = np.clip(depth, depth_min, depth_max) + depth = (depth - results['center_depth']) / ( + results['cube_size'][2] / 2.) + return depth + + +@PIPELINES.register_module() +class HandGenerateJointToOffset: + """""" + + def __init__(self, heatmap_kernel_size): + self.heatmap_kernel_size = heatmap_kernel_size + + def __call__(self, results): + cfg = results['ann_info'] + feature_size = cfg['heatmap_size'] + joint_uvd = results['target'] # UV -1,1 + num_joints = joint_uvd.shape[0] + + img = results['img'] + depth = img.numpy()[0] # it is a hack + + coord_x = (2.0 * (np.arange(feature_size[0]) + 0.5) / feature_size[0] - + 1.0).astype(np.float32) + coord_y = (2.0 * (np.arange(feature_size[1]) + 0.5) / feature_size[1] - + 1.0).astype(np.float32) + xv, yv = np.meshgrid(coord_x, coord_y) + coord = np.stack((xv, yv), 0) + depth_resize = mmcv.imresize( + depth, (feature_size[0], feature_size[1]), interpolation='nearest') + depth_resize = np.expand_dims(depth_resize, 0) + coord_with_depth = np.expand_dims( + np.concatenate((coord, depth_resize), 0), 0) + jt_ft = np.broadcast_to(joint_uvd[:, :, np.newaxis, np.newaxis], + (joint_uvd.shape[0], joint_uvd.shape[1], + feature_size[0], feature_size[1])) + offset = jt_ft - coord_with_depth # [jt_num, 3, F, F] + dis = np.linalg.norm(offset + 1e-8, axis=1) # [jt_num, F, F] + offset_norm = offset / dis[:, np.newaxis, ...] # value in [-1, 1] + heatmap = (self.heatmap_kernel_size - + dis) / self.heatmap_kernel_size # [jt_num, F, F] + mask = (heatmap > 0).astype(np.float32) * (depth_resize < 0.99).astype( + np.float32) # [jt_num, F, F] + offset_norm_mask = (offset_norm * mask[:, None, ...]).reshape( + -1, feature_size[0], feature_size[1]) + heatmap_mask = heatmap * mask + offset_field = np.concatenate((offset_norm_mask, heatmap_mask), + axis=0) # [jt_num*4, F, F] + results['target'] = offset_field + results['target_weight'] = np.ones(num_joints) + return results diff --git a/mmpose/datasets/pipelines/top_down_transform.py b/mmpose/datasets/pipelines/top_down_transform.py index c230870eaf..886d15e9ac 100644 --- a/mmpose/datasets/pipelines/top_down_transform.py +++ b/mmpose/datasets/pipelines/top_down_transform.py @@ -726,10 +726,23 @@ class TopDownGenerateTargetRegression: Required key: 'joints_3d', 'joints_3d_visible', 'ann_info'. Modified key: 'target', and 'target_weight'. + + Args: + use_zero_mean: (bool) If set to True, target normalize to [-1, 1], + otherwise [0,1] + joint_indices: (list): Indices of joints used for heatmap generation. + If None (default) is given, all joints will be used. """ - def __init__(self): - pass + def __init__(self, + use_zero_mean=False, + joint_indices=None, + is_3d=False, + normalize_depth=False): + self.use_zero_mean = use_zero_mean + self.joint_indices = joint_indices + self.is_3d = is_3d + self.normalize_depth = normalize_depth def _generate_target(self, cfg, joints_3d, joints_3d_visible): """Generate the target regression vector. @@ -746,20 +759,41 @@ def _generate_target(self, cfg, joints_3d, joints_3d_visible): joint_weights = cfg['joint_weights'] use_different_joint_weights = cfg['use_different_joint_weights'] + # only preserve used joint if joint_indices is given + if self.joint_indices is not None and len(self.joint_indices) > 0: + joint_weights = joint_weights[self.joint_indices] + joints_3d = joints_3d[self.joint_indices] + joints_3d_visible = joints_3d_visible[self.joint_indices] + mask = (joints_3d[:, 0] >= 0) * ( joints_3d[:, 0] <= image_size[0] - 1) * (joints_3d[:, 1] >= 0) * ( joints_3d[:, 1] <= image_size[1] - 1) - target = joints_3d[:, :2] / image_size + keypoints_dim = 3 if self.is_3d else 2 + + if self.use_zero_mean: + target = joints_3d[:, :keypoints_dim] + target_2d = target[:, :2] / image_size + target_2d = 2 * target_2d - 1 + target[:, :2] = target_2d + else: + target = joints_3d[:, :keypoints_dim] + target_2d = target[:, :2] / image_size + target[:, :2] = target_2d target = target.astype(np.float32) - target_weight = joints_3d_visible[:, :2] * mask[:, None] + target_weight = joints_3d_visible[:, :keypoints_dim] * mask[:, None] if use_different_joint_weights: target_weight = np.multiply(target_weight, joint_weights) return target, target_weight + def _normalize_target(self, joints_3d, center_depth, cube_size): + joints_3d[:, 2] = (joints_3d[:, 2] - center_depth) / ( + cube_size[2] / 2.0) + return joints_3d + def __call__(self, results): """Generate the target heatmap.""" joints_3d = results['joints_3d'] @@ -769,6 +803,9 @@ def __call__(self, results): joints_3d, joints_3d_visible) + if self.is_3d and self.normalize_depth: + target = self._normalize_target(target, results['center_depth'], + results['cube_size']) results['target'] = target results['target_weight'] = target_weight diff --git a/mmpose/models/backbones/__init__.py b/mmpose/models/backbones/__init__.py index 2fc64a8af3..5ea977f824 100644 --- a/mmpose/models/backbones/__init__.py +++ b/mmpose/models/backbones/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .alexnet import AlexNet +from .awr_resnet import AWRResNet from .cpm import CPM from .hourglass import HourglassNet from .hourglass_ae import HourglassAENet @@ -35,5 +36,6 @@ 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN', 'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3', 'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer', - 'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D', 'TCFormer' + 'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D', 'TCFormer', + 'AWRResNet' ] diff --git a/mmpose/models/backbones/awr_resnet.py b/mmpose/models/backbones/awr_resnet.py new file mode 100644 index 0000000000..752725b532 --- /dev/null +++ b/mmpose/models/backbones/awr_resnet.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer + +from ..builder import BACKBONES +from .resnet import ResNet + + +@BACKBONES.register_module() +class AWRResNet(ResNet): + """AWR ResNet backbone. + + Using a specialized stem scheme. + """ + + def __init__(self, **kwargs): + super().__init__(deep_stem=False, **kwargs) + + def _make_stem_layer(self, in_channels, stem_channels): + """Make stem layer for depth.""" + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=5, + stride=1, + padding=2, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) diff --git a/mmpose/models/detectors/__init__.py b/mmpose/models/detectors/__init__.py index d94d8b8aab..bf5d1f629a 100644 --- a/mmpose/models/detectors/__init__.py +++ b/mmpose/models/detectors/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .associative_embedding import AssociativeEmbedding +from .depthhand_3d import Depthhand3D from .gesture_recognizer import GestureRecognizer from .interhand_3d import Interhand3D from .mesh import ParametricMesh @@ -13,5 +14,6 @@ __all__ = [ 'TopDown', 'AssociativeEmbedding', 'ParametricMesh', 'MultiTask', 'PoseLifter', 'Interhand3D', 'PoseWarper', 'DetectAndRegress', - 'VoxelCenterDetector', 'VoxelSinglePose', 'GestureRecognizer' + 'VoxelCenterDetector', 'VoxelSinglePose', 'GestureRecognizer', + 'Depthhand3D' ] diff --git a/mmpose/models/detectors/depthhand_3d.py b/mmpose/models/detectors/depthhand_3d.py new file mode 100644 index 0000000000..ac33ab8749 --- /dev/null +++ b/mmpose/models/detectors/depthhand_3d.py @@ -0,0 +1,303 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import numpy as np +from mmcv.utils.misc import deprecated_api_warning + +from mmpose.core import imshow_keypoints, imshow_keypoints_3d +from ..builder import POSENETS +from .top_down import TopDown + + +@POSENETS.register_module() +class Depthhand3D(TopDown): + """Top-down depth-based 3d keypoints detector.""" + + def forward(self, + img, + target=None, + target_weight=None, + img_metas=None, + return_loss=True, + **kwargs): + """Calls either forward_train or forward_test depending on whether + return_loss=True. Note this setting will change the expected inputs. + When `return_loss=True`, img and img_meta are single-nested (i.e. + Tensor and List[dict]), and when `resturn_loss=False`, img and img_meta + should be double nested (i.e. list[Tensor], list[list[dict]]), with + the outer list indicating test time augmentations. + + Note: + - batch_size: N + - num_keypoints: K + - num_img_channel: C (Default: 3) + - img height: imgH + - img width: imgW + - heatmaps height: H + - heatmaps weight: W + + Args: + img (torch.Tensor[NxCximgHximgW]): Input images. + target (list[torch.Tensor]): Target heatmaps, relative hand + root depth and hand type. + target_weight (list[torch.Tensor]): Weights for target + heatmaps, relative hand root depth and hand type. + img_metas (list(dict)): Information about data augmentation + By default this includes: + + - "image_file: path to the image file + - "center": center of the bbox + - "scale": scale of the bbox + - "rotation": rotation of the bbox + - "bbox_score": score of bbox + - "heatmap3d_depth_bound": depth bound of hand keypoint 3D + heatmap + - "root_depth_bound": depth bound of relative root depth 1D + heatmap + return_loss (bool): Option to `return loss`. `return loss=True` + for training, `return loss=False` for validation & test. + + Returns: + dict|tuple: if `return loss` is true, then return losses. \ + Otherwise, return predicted poses, boxes, image paths, \ + heatmaps, relative hand root depth and hand type. + """ + if return_loss: + return self.forward_train(img, target, target_weight, img_metas, + **kwargs) + return self.forward_test(img, img_metas, **kwargs) + + def forward_train(self, img, target, target_weight, img_metas, **kwargs): + """Defines the computation performed at every call when training.""" + features = self.backbone(img) + if self.with_neck: + features = self.neck(features) + if self.with_keypoint: + if self.train_cfg['use_img_for_head']: + output = self.keypoint_head((features, img)) + else: + output = self.keypoint_head(features) + + # if return loss + losses = dict() + if self.with_keypoint: + keypoint_losses = self.keypoint_head.get_loss( + output, target, target_weight) + losses.update(keypoint_losses) + + # import ipdb + # ipdb.set_trace() + # + # from mmpose.core.post_processing import (affine_transform, + # get_affine_transform, + # get_warp_matrix, + # warp_affine_joints) + # + # used_index = img_metas[0]['dataset_channel'] + # jt_xyz_gt = img_metas[0]['joints_cam'][used_index] + # center = img_metas[0]['center'] + # scale = img_metas[0]['scale'] + # rotation = img_metas[0]['rotation'] + # image_size = img_metas[0]['image_size'] + # trans = get_affine_transform(center, scale, rotation, image_size) + # inv_trans = get_affine_transform( + # center, scale, rotation, image_size, inv=True) + # + # jt_uvd_pred = output[1][0].detach().cpu().numpy() + # jt_uvd_pred[:, :2] = (jt_uvd_pred[:, :2] + + # 1) * img_metas[0]['image_size'] / 2. + # jt_uvd_pred[:, 2] = jt_uvd_pred[:, 2] * img_metas[0]['cube_size'][ + # 2] / 2 + img_metas[0]['center_depth'] + # + # jt_uvd_gt = target[1][0].detach().cpu().numpy() + # jt_uvd_gt[:, :2] = (jt_uvd_gt[:, :2] + + # 1) * img_metas[0]['image_size'] / 2. + # jt_uvd_gt[:, 2] = jt_uvd_gt[:, 2] *\ + # img_metas[0]['cube_size'][2] / 2\ + # + img_metas[0]['center_depth'] + # + # for i in range(len(img_metas[0]['dataset_channel'])): + # jt_uvd_gt[i, 0:2] = affine_transform(jt_uvd_gt[i, 0:2].copy(), + # inv_trans) + # + # import ipdb + # ipdb.set_trace() + # from mmpose.datasets.datasets.base import \ + # Kpt3dSviewDepthImgTopDownDataset + # jt_xyz_gt_from_uvd = Kpt3dSviewDepthImgTopDownDataset._uvd2xyz( + # jt_uvd_gt, f=img_metas[0]['focal'], c=img_metas[0]['princpt']) + + return losses + + def forward_test(self, img, img_metas, **kwargs): + """Defines the computation performed at every call when testing.""" + assert img.size(0) == len(img_metas) + batch_size, _, img_height, img_width = img.shape + if batch_size > 1: + assert 'bbox_id' in img_metas[0] + + features = self.backbone(img) + if self.with_neck: + features = self.neck(features) + if self.with_keypoint: + if self.train_cfg['use_img_for_head']: + output = self.keypoint_head.inference_model((features, img), + flip_pairs=None) + else: + output = self.keypoint_head.inference_model( + features, flip_pairs=None) + + if self.test_cfg.get('flip_test', True): + raise NotImplementedError + + if self.with_keypoint: + result = self.keypoint_head.decode( + img_metas, output, img_size=[img_width, img_height]) + else: + result = {} + return result + + def forward_dummy(self, img): + """Used for computing network FLOPs. + + See ``tools/get_flops.py``. + + Args: + img (torch.Tensor): Input image. + + Returns: + Tensor: Output heatmaps. + """ + output = self.backbone(img) + if self.with_neck: + output = self.neck(output) + if self.with_keypoint: + if self.train_cfg['use_img_for_head']: + output = self.keypoint_head((output, img)) + else: + output = self.keypoint_head(output) + return output + + @deprecated_api_warning({'pose_limb_color': 'pose_link_color'}, + cls_name='Depthhand3D') + def show_result( + self, # TODO: NotImplement + result, + img=None, + skeleton=None, + kpt_score_thr=0.3, + radius=8, + bbox_color='green', + thickness=2, + pose_kpt_color=None, + pose_link_color=None, + vis_height=400, + num_instances=-1, + win_name='', + show=False, + wait_time=0, + out_file=None): + """Visualize 3D pose estimation results. + + Args: + result (list[dict]): The pose estimation results containing: + + - "keypoints_3d" ([K,4]): 3D keypoints + - "keypoints" ([K,3] or [T,K,3]): Optional for visualizing + 2D inputs. If a sequence is given, only the last frame + will be used for visualization + - "bbox" ([4,] or [T,4]): Optional for visualizing 2D inputs + - "title" (str): title for the subplot + img (str or Tensor): Optional. The image to visualize 2D inputs on. + skeleton (list of [idx_i,idx_j]): Skeleton described by a list of + links, each is a pair of joint indices. + kpt_score_thr (float, optional): Minimum score of keypoints + to be shown. Default: 0.3. + radius (int): Radius of circles. + bbox_color (str or tuple or :obj:`Color`): Color of bbox lines. + thickness (int): Thickness of lines. + pose_kpt_color (np.array[Nx3]`): Color of N keypoints. + If None, do not draw keypoints. + pose_link_color (np.array[Mx3]): Color of M limbs. + If None, do not draw limbs. + vis_height (int): The image height of the visualization. The width + will be N*vis_height depending on the number of visualized + items. + num_instances (int): Number of instances to be shown in 3D. If + smaller than 0, all the instances in the pose_result will be + shown. Otherwise, pad or truncate the pose_result to a length + of num_instances. + win_name (str): The window name. + show (bool): Whether to show the image. Default: False. + wait_time (int): Value of waitKey param. + Default: 0. + out_file (str or None): The filename to write the image. + Default: None. + + Returns: + Tensor: Visualized img, only if not `show` or `out_file`. + """ + if num_instances < 0: + assert len(result) > 0 + result = sorted(result, key=lambda x: x.get('track_id', 0)) + + # draw image and 2d poses + if img is not None: + img = mmcv.imread(img) + + bbox_result = [] + pose_2d = [] + for res in result: + if 'bbox' in res: + bbox = np.array(res['bbox']) + if bbox.ndim != 1: + assert bbox.ndim == 2 + bbox = bbox[-1] # Get bbox from the last frame + bbox_result.append(bbox) + if 'keypoints' in res: + kpts = np.array(res['keypoints']) + if kpts.ndim != 2: + assert kpts.ndim == 3 + kpts = kpts[-1] # Get 2D keypoints from the last frame + pose_2d.append(kpts) + + if len(bbox_result) > 0: + bboxes = np.vstack(bbox_result) + mmcv.imshow_bboxes( + img, + bboxes, + colors=bbox_color, + top_k=-1, + thickness=2, + show=False) + if len(pose_2d) > 0: + imshow_keypoints( + img, + pose_2d, + skeleton, + kpt_score_thr=kpt_score_thr, + pose_kpt_color=pose_kpt_color, + pose_link_color=pose_link_color, + radius=radius, + thickness=thickness) + img = mmcv.imrescale(img, scale=vis_height / img.shape[0]) + + img_vis = imshow_keypoints_3d( + result, + img, + skeleton, + pose_kpt_color, + pose_link_color, + vis_height, + axis_limit=300, + axis_azimuth=-115, + axis_elev=15, + kpt_score_thr=kpt_score_thr, + num_instances=num_instances) + + if show: + mmcv.visualization.imshow(img_vis, win_name, wait_time) + + if out_file is not None: + mmcv.imwrite(img_vis, out_file) + + return img_vis diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py index 459c20b8bd..e72763ab9a 100644 --- a/mmpose/models/heads/__init__.py +++ b/mmpose/models/heads/__init__.py @@ -2,6 +2,7 @@ from .ae_higher_resolution_head import AEHigherResolutionHead from .ae_multi_stage_head import AEMultiStageHead from .ae_simple_head import AESimpleHead +from .awr_head import AdaptiveWeightingRegression3DHead from .deconv_head import DeconvHead from .deeppose_regression_head import DeepposeRegressionHead from .hmr_head import HMRMeshHead @@ -21,5 +22,5 @@ 'AEHigherResolutionHead', 'AESimpleHead', 'AEMultiStageHead', 'DeepposeRegressionHead', 'TemporalRegressionHead', 'Interhand3DHead', 'HMRMeshHead', 'DeconvHead', 'ViPNASHeatmapSimpleHead', 'CuboidCenterHead', - 'CuboidPoseHead', 'MultiModalSSAHead' + 'CuboidPoseHead', 'MultiModalSSAHead', 'AdaptiveWeightingRegression3DHead' ] diff --git a/mmpose/models/heads/awr_head.py b/mmpose/models/heads/awr_head.py new file mode 100644 index 0000000000..e29664c89d --- /dev/null +++ b/mmpose/models/heads/awr_head.py @@ -0,0 +1,458 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_upsample_layer, constant_init, normal_init + +from mmpose.core.camera import SimpleCamera +from mmpose.core.evaluation.top_down_eval import keypoints_from_joint_uvd +from mmpose.models.builder import build_loss +from ..builder import HEADS + + +class OffsetHead(nn.Module): + + def __init__(self, + in_channels, + out_channels_vector, + out_channels_scalar, + heatmap_kernel_size, + dummy_args=None): + + super().__init__() + + self.heatmap_kernel_size = heatmap_kernel_size + assert out_channels_vector == out_channels_scalar * 3 + self.vector_offset = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels_vector, + kernel_size=1, + stride=1, + padding=0) + self.scalar_offset = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels_scalar, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + """Forward function.""" + vec = self.vector_offset(x) + ht = self.scalar_offset(x) + # N, C, H, W = x.shape + offset_field = torch.cat((vec, ht), dim=1) + return offset_field + + def init_weights(self): + """Initialize model weights.""" + for m in self.vector_offset.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.001) + nn.init.constant_(m.bias, 0) + for m in self.scalar_offset.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.001) + nn.init.constant_(m.bias, 0) + + +class UpsampleHead(nn.Module): + """UpsampleHead is a sub-module of AWR Head, and outputs 3D heatmaps. + UpsampleHead is composed of (>=0) number of deconv layers. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + depth_size (int): Number of depth discretization size + num_deconv_layers (int): Number of deconv layers. + num_deconv_layers should >= 0. Note that 0 means no deconv layers. + num_deconv_filters (list|tuple): Number of filters. + num_deconv_kernels (list|tuple): Kernel sizes. + extra (dict): Configs for extra conv layers. Default: None + """ + + def __init__(self, + in_channels, + out_channels, + depth_size=64, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=None): + + super().__init__() + + assert out_channels % depth_size == 0 + self.depth_size = depth_size + self.in_channels = in_channels + + if extra is not None and not isinstance(extra, dict): + raise TypeError('extra should be dict or None.') + + if num_deconv_layers > 0: + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + num_deconv_kernels, + ) + elif num_deconv_layers == 0: + self.deconv_layers = nn.Identity() + else: + raise ValueError( + f'num_deconv_layers ({num_deconv_layers}) should >= 0.') + + identity_final_layer = False + if extra is not None and 'final_conv_kernel' in extra: + assert extra['final_conv_kernel'] in [0] + identity_final_layer = True + + if identity_final_layer: + self.final_layer = nn.Identity() + else: + # TODO: do not support this type of layer configuration + raise NotImplementedError + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """Make deconv layers.""" + if num_layers != len(num_filters): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_filters({len(num_filters)})' + raise ValueError(error_msg) + if num_layers != len(num_kernels): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_kernels({len(num_kernels)})' + raise ValueError(error_msg) + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + self.in_channels = planes + + return nn.Sequential(*layers) + + @staticmethod + def _get_deconv_cfg(deconv_kernel): + """Get configurations for deconv layers.""" + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Not supported num_kernels ({deconv_kernel}).') + + return deconv_kernel, padding, output_padding + + def forward(self, x): + """Forward function.""" + x = self.deconv_layers(x) + x = self.final_layer(x) + # N, C, H, W = x.shape + # # reshape the 2D heatmap to 3D heatmap + # x = x.reshape(N, C // self.depth_size, self.depth_size, H, W) + return x + + def init_weights(self): + """Initialize model weights.""" + for _, m in self.deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + normal_init(m, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + for m in self.final_layer.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001, bias=0) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + + +@HEADS.register_module() +class AdaptiveWeightingRegression3DHead(nn.Module): + """ + + Args: + deconv_head_cfg (dict): Configs of UpsampleHead for hand + keypoint estimation. + offset_head_cfg (dict): Configs of OffsetHead for hand + keypoint offset field estimation. + loss_keypoint (dict): Config for keypoint loss. Default: None. + loss_offset (dict): Config for offset field loss. Default: None. + """ + + def __init__(self, + deconv_head_cfg, + offset_head_cfg, + loss_keypoint=None, + loss_offset=None, + train_cfg=None, + test_cfg=None): + super().__init__() + + self.deconv_head_cfg = deconv_head_cfg + self.offset_head_cfg = offset_head_cfg + + # build sub-module heads + # dense head + self.offset_head = OffsetHead(**offset_head_cfg) + # regression head + self.upsample_feature_head = UpsampleHead(**deconv_head_cfg) + + # build losses + self.keypoint_loss = build_loss(loss_keypoint) + self.offset_loss = build_loss(loss_offset) + self.train_cfg = {} if train_cfg is None else train_cfg + self.test_cfg = {} if test_cfg is None else test_cfg + + def init_weights(self): + self.upsample_feature_head.init_weights() + self.offset_head.init_weights() + + @staticmethod + def offset2joint_softmax(offset, img, kernel_size): + batch_size, feature_num, feature_size, _ = offset.size() + jt_num = int(feature_num / 4) + img = F.interpolate( + img, size=[feature_size, feature_size]) # (B, 1, F, F) + # unit directional vector + offset_vec = offset[:, :jt_num * 3].contiguous() # (B, jt_num*3, F, F) + # closeness heatmap + offset_ht = offset[:, jt_num * 3:].contiguous() # (B, jt_num, F, F) + + mesh_x = 2.0 * (torch.arange(feature_size).unsqueeze(0).expand( + feature_size, feature_size).float() + 0.5) / feature_size - 1.0 + mesh_y = 2.0 * (torch.arange(feature_size).unsqueeze(1).expand( + feature_size, feature_size).float() + 0.5) / feature_size - 1.0 + coords = torch.stack((mesh_x, mesh_y), dim=0) + coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, + 1).to(offset.device) + coords = torch.cat((coords, img), + dim=1).repeat(1, jt_num, 1, + 1) # (B, jt_num*3, F, F) + coords = coords.view(batch_size, jt_num, 3, -1) # (B, jt_num, 3, F*F) + + mask = img.lt(0.99).float() # (B, 1, F, F) + offset_vec_mask = (offset_vec * mask).view(batch_size, jt_num, 3, + -1) # (B, jt_num, 3, F*F) + offset_ht_mask = (offset_ht * mask).view(batch_size, jt_num, + -1) # (B, jt_num, F*F) + offset_ht_norm = F.softmax( + offset_ht_mask * 30, dim=-1) # (B, jt_num, F*F) + dis = kernel_size - offset_ht_mask * kernel_size # (B, jt_num, F*F) + + jt_uvd = torch.sum( + (offset_vec_mask * dis.unsqueeze(2) + coords) * + offset_ht_norm.unsqueeze(2), + dim=-1) + + return jt_uvd.float() + + @staticmethod + def joint2offset(jt_uvd, img, kernel_size, feature_size): + """ + :params joint: hand joint coordinates, shape (B, joint_num, 3) + :params img: depth image, shape (B, C, H, W) + :params kernel_size + :params feature_size: size of generated offsets feature + """ + batch_size, jt_num, _ = jt_uvd.size() + img = F.interpolate(img, size=[feature_size, feature_size]) + jt_ft = jt_uvd.view(batch_size, -1, 1, + 1).repeat(1, 1, feature_size, + feature_size) # (B, joint_num*3, F, F) + + mesh_x = 2.0 * (torch.arange(feature_size).unsqueeze(0).expand( + feature_size, feature_size).float() + 0.5) / feature_size - 1.0 + mesh_y = 2.0 * (torch.arange(feature_size).unsqueeze(1).expand( + feature_size, feature_size).float() + 0.5) / feature_size - 1.0 + coords = torch.stack((mesh_x, mesh_y), dim=0) + coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1).to( + jt_uvd.device) # (B, 2, F, F) + coords = torch.cat((coords, img), + dim=1).repeat(1, jt_num, 1, + 1) # (B, jt_num*3, F, F) + + offset = jt_ft - coords # (B, jt_num*3, F, F) + offset = offset.view(batch_size, jt_num, 3, feature_size, + feature_size) # (B, jt_num, 3, F, F) + dis = torch.sqrt(torch.sum(torch.pow(offset, 2), dim=2) + + 1e-8) # (B, jt_num, F, F) + + offset_norm = offset / dis.unsqueeze(2) # (B, jt_num, 3, F, F) + heatmap = (kernel_size - dis) / kernel_size # (B, jt_num, F, F) + mask = heatmap.ge(0).float() * img.lt( + 0.99).float() # (B, jt_num, F, F) + + offset_norm_mask = (offset_norm * + mask.unsqueeze(2)).view(batch_size, -1, + feature_size, feature_size) + heatmap_mask = heatmap * mask.float() + return torch.cat((offset_norm_mask, heatmap_mask), dim=1).float() + + def get_loss(self, output, target, target_weight): + """Calculate loss for hand keypoint heatmaps, relative root depth and + hand type. + + Args: + output (list[Tensor]): a list of outputs from multiple heads. + target (list[Tensor]): a list of targets for multiple heads. + target_weight (list[Tensor]): a list of targets weight for + multiple heads. + """ + losses = dict() + + # hand keypoint offset field loss, dense loss + assert not isinstance(self.keypoint_loss, nn.Sequential) + out, tar, tar_weight = output[0], target[0], target_weight[0] + assert tar.dim() == 4 and tar_weight.dim() in [1, 2] + losses['offset_loss'] = self.offset_loss(out, tar) + # hand keypoint joint loss, regression loss + assert not isinstance(self.offset_loss, nn.Sequential) + out, tar, tar_weight = output[1], target[1], target_weight[1] + assert tar.dim() == 3 and tar_weight.dim() == 3 + losses['joint_loss'] = self.keypoint_loss(out, tar, tar_weight) + + return losses + + def forward(self, x): + """Forward function.""" + backbone_feature, img = x + feature = self.upsample_feature_head(backbone_feature) + offset_field = self.offset_head(feature) + jt_uvd = self.offset2joint_softmax( + offset_field, img, self.offset_head_cfg['heatmap_kernel_size']) + outputs = [offset_field, jt_uvd] + return outputs + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output (list[np.ndarray]): list of output hand keypoint + heatmaps, relative root depth and hand type. + + Args: + x (torch.Tensor[N,K,H,W]): Input features. + flip_pairs (None | list[tuple()): + Pairs of keypoints which are mirrored. + """ + + output = self.forward(x) + + if flip_pairs is not None: + raise NotImplementedError + else: + output = [out.detach().cpu().numpy() for out in output] + + return output + + def decode(self, img_metas, output, **kwargs): + """Decode hand keypoint and offset field. + + Args: + img_metas (list(dict)): Information about data augmentation + By default this includes: + + - "image_file: path to the image file + - "center": center of the bbox + - "scale": scale of the bbox + - "rotation": rotation of the bbox + - "bbox_score": score of bbox + - "heatmap3d_depth_bound": depth bound of hand keypoint + 3D heatmap + - "root_depth_bound": depth bound of relative root depth + 1D heatmap + output (list[np.ndarray]): model predicted 3D heatmaps, relative + root depth and hand type. + """ + + batch_size = len(img_metas) + result = {} + + center = np.zeros((batch_size, 2), dtype=np.float32) + scale = np.zeros((batch_size, 2), dtype=np.float32) + image_size = np.zeros((batch_size, 2), dtype=np.float32) + image_paths = [] + score = np.ones(batch_size, dtype=np.float32) + if 'bbox_id' in img_metas[0]: + bbox_ids = [] + else: + bbox_ids = None + + for i in range(batch_size): + center[i, :] = img_metas[i]['center'] + scale[i, :] = img_metas[i]['scale'] + image_size[i, :] = img_metas[i]['image_size'] + image_paths.append(img_metas[i]['image_file']) + + if 'bbox_score' in img_metas[i]: + score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1) + if bbox_ids is not None: + bbox_ids.append(img_metas[i]['bbox_id']) + + all_boxes = np.zeros((batch_size, 6), dtype=np.float32) + all_boxes[:, 0:2] = center[:, 0:2] + all_boxes[:, 2:4] = scale[:, 0:2] + # scale is defined as: bbox_size / 200.0, so we + # need multiply 200.0 to get bbox size + all_boxes[:, 4] = np.prod(scale * 200.0, axis=1) + all_boxes[:, 5] = score + result['boxes'] = all_boxes + result['image_paths'] = image_paths + result['bbox_ids'] = bbox_ids + + # transform keypoint depth to camera space + joint_uvd = output[1] + preds, maxvals = keypoints_from_joint_uvd(joint_uvd, center, scale, + image_size) + keypoints_3d = np.zeros((batch_size, joint_uvd.shape[1], 4), + dtype=np.float32) + keypoints_3d[:, :, 0:3] = preds[:, :, 0:3] + keypoints_3d[:, :, 3:4] = maxvals + + center_depth = np.array( + [img_metas[i]['center_depth'] for i in range(len(img_metas))], + dtype=np.float32) + cube_size = np.array( + [img_metas[i]['cube_size'] for i in range(len(img_metas))], + dtype=np.float32) + keypoints_3d[:, :, 2] = \ + keypoints_3d[:, :, 2] * cube_size[:, 2:] / 2 \ + + center_depth[:, np.newaxis] + + result['preds'] = keypoints_3d + # joint uvd to joint xyz + cam_param = { + 'R': np.eye(3, dtype=np.float32), + 'T': np.zeros((3, 1), dtype=np.float32), + 'f': img_metas[0]['focal'].reshape(2, 1), + 'c': img_metas[0]['princpt'].reshape(2, 1), + } + single_view_camera = SimpleCamera(param=cam_param) + keypoints_xyz_list = [] + for batch_idx in range(batch_size): + keypoints_xyz_list.append( + single_view_camera.pixel_to_camera( + keypoints_3d[batch_idx, :, :3])) + result['preds_xyz'] = np.stack(keypoints_xyz_list, 0) + + return result diff --git a/mmpose/models/losses/__init__.py b/mmpose/models/losses/__init__.py index 9a491fbf76..34c2eee3e5 100644 --- a/mmpose/models/losses/__init__.py +++ b/mmpose/models/losses/__init__.py @@ -4,13 +4,14 @@ from .mesh_loss import GANLoss, MeshLoss from .mse_loss import JointsMSELoss, JointsOHKMMSELoss from .multi_loss_factory import AELoss, HeatmapLoss, MultiLossFactory -from .regression_loss import (BoneLoss, L1Loss, MPJPELoss, MSELoss, RLELoss, - SemiSupervisionLoss, SmoothL1Loss, SoftWingLoss, - WingLoss) +from .regression_loss import (AWRSmoothL1Loss, BoneLoss, L1Loss, MPJPELoss, + MSELoss, RLELoss, SemiSupervisionLoss, + SmoothL1Loss, SoftWingLoss, WingLoss) __all__ = [ 'JointsMSELoss', 'JointsOHKMMSELoss', 'HeatmapLoss', 'AELoss', 'MultiLossFactory', 'MeshLoss', 'GANLoss', 'SmoothL1Loss', 'WingLoss', 'MPJPELoss', 'MSELoss', 'L1Loss', 'BCELoss', 'BoneLoss', - 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss' + 'SemiSupervisionLoss', 'SoftWingLoss', 'AdaptiveWingLoss', 'RLELoss', + 'AWRSmoothL1Loss' ] diff --git a/mmpose/models/losses/regression_loss.py b/mmpose/models/losses/regression_loss.py index fc7aa33847..25ed066198 100644 --- a/mmpose/models/losses/regression_loss.py +++ b/mmpose/models/losses/regression_loss.py @@ -528,3 +528,46 @@ def forward(self, output, target): losses['bone_loss'] = loss_bone return losses + + +@LOSSES.register_module() +class AWRSmoothL1Loss(nn.Module): + """L1Loss loss .""" + + def __init__(self, use_target_weight=False, loss_weight=1.): + super().__init__() + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + + Args: + output (torch.Tensor[N, K, 3]): Output regression. + target (torch.Tensor[N, K, 3]): Target regression. + target_weight (torch.Tensor[N, K, 3]): + Weights across different joint types. + """ + assert (output.shape == target.shape) + if self.use_target_weight: + assert target_weight is not None + z = (output * target_weight - target * target_weight) + else: + z = (output - target) + mse_mask = (torch.abs(z) < 0.01).to(dtype=z.dtype, device=z.device) + l1_mask = (torch.abs(z) >= 0.01).to(dtype=z.dtype, device=z.device) + mse = mse_mask * z + l1 = l1_mask * z + loss = torch.mean(self._calculate_MSE(mse) * mse_mask) + torch.mean( + self._calculate_L1(l1) * l1_mask) + return loss + + def _calculate_MSE(self, z): + return 0.5 * (torch.pow(z, 2)) + + def _calculate_L1(self, z): + return 0.01 * (torch.abs(z) - 0.005) diff --git a/tests/test_models/test_awr_3d_head.py b/tests/test_models/test_awr_3d_head.py new file mode 100644 index 0000000000..e5325ee7eb --- /dev/null +++ b/tests/test_models/test_awr_3d_head.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + +from mmpose.models import AdaptiveWeightingRegression3DHead + + +def test_awr_3d_head(): + N = 4 + input_shape = (N, 2048, 8, 8) + inputs = torch.rand(input_shape, dtype=torch.float32) + + img_input_shape = (N, 1, 128, 128) + img_inputs = torch.rand(img_input_shape, dtype=torch.float32) + + target = [ + inputs.new_ones(N, 14 * 4, 64, 64), + inputs.new_ones(N, 14, 3), + ] + target_weight = [ + inputs.new_ones(N, 14), + inputs.new_ones(N, 14, 3), + ] + + cameras = {'fx': 588.03, 'fy': 587.07, 'cx': 320.0, 'cy': 240.0} + + img_metas = [{ + 'img_shape': (128, 128, 3), + 'center': np.array([112, 112]), + 'scale': np.array([0.5, 0.5]), + 'bbox_score': 1.0, + 'bbox_id': 0, + 'flip_pairs': [], + 'inference_channel': np.arange(14), + 'cube_size': np.array([300, 300, 300]), + 'center_depth': 1.0, + 'focal': np.array([cameras['fx'], cameras['fy']]), + 'princpt': np.array([cameras['cx'], cameras['cy']]), + 'image_file': '.png', + } for _ in range(N)] + + print('fake input OK') + + head = AdaptiveWeightingRegression3DHead( + deconv_head_cfg=dict( + in_channels=2048, + out_channels=256, + depth_size=64, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=dict(final_conv_kernel=0, )), + offset_head_cfg=dict( + in_channels=256, + out_channels_vector=42, + out_channels_scalar=14, + heatmap_kernel_size=0.4, + ), + loss_keypoint=dict(type='AWRSmoothL1Loss', use_target_weight=True), + loss_offset=dict(type='AWRSmoothL1Loss', use_target_weight=False), + train_cfg=dict(use_img_for_head=True), + test_cfg=dict(use_img_for_head=True, flip_test=False)) + + print('init OK') + + head.init_weights() + + # test forward + inputs_with_img = (inputs, img_inputs) + output = head(inputs_with_img) + assert isinstance(output, list) + assert len(output) == 2 + assert output[0].shape == (N, 14 * 4, 64, 64) + assert output[1].shape == (N, 14, 3) + + # test loss computation + losses = head.get_loss(output, target, target_weight) + assert 'joint_loss' in losses + assert 'offset_loss' in losses + + # test inference model + output = head.inference_model(inputs_with_img, flip_pairs=None) + assert isinstance(output, list) + assert len(output) == 2 + assert output[0].shape == (N, 14 * 4, 64, 64) + assert output[1].shape == (N, 14, 3) + + # test decode + result = head.decode(img_metas, output) + assert 'preds' in result + assert 'preds_xyz' in result diff --git a/tests/test_models/test_depthhand_3d_forward.py b/tests/test_models/test_depthhand_3d_forward.py new file mode 100644 index 0000000000..241d4ef1f4 --- /dev/null +++ b/tests/test_models/test_depthhand_3d_forward.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + +from mmpose.models import build_posenet + + +def test_interhand3d_forward(): + # model settings + model_cfg = dict( + type='Depthhand3D', # pretrained=None + backbone=dict( + type='AWRResNet', + depth=50, + frozen_stages=-1, + zero_init_residual=False, + in_channels=1), + keypoint_head=dict( + type='AdaptiveWeightingRegression3DHead', + offset_head_cfg=dict( + in_channels=256, + out_channels_vector=42, + out_channels_scalar=14, + heatmap_kernel_size=0.4, + ), + deconv_head_cfg=dict( + in_channels=2048, + out_channels=256, + depth_size=64, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=dict(final_conv_kernel=0, )), + loss_offset=dict(type='AWRSmoothL1Loss', use_target_weight=False), + loss_keypoint=dict(type='AWRSmoothL1Loss', use_target_weight=True), + ), + train_cfg=dict(use_img_for_head=True), + test_cfg=dict(use_img_for_head=True, flip_test=False)) + + detector = build_posenet(model_cfg) + detector.init_weights() + + input_shape = (2, 1, 128, 128) + mm_inputs = _demo_mm_inputs(input_shape) + + imgs = mm_inputs.pop('imgs') + target = mm_inputs.pop('target') + target_weight = mm_inputs.pop('target_weight') + img_metas = mm_inputs.pop('img_metas') + + # Test forward train + losses = detector.forward( + imgs, target, target_weight, img_metas, return_loss=True) + assert isinstance(losses, dict) + + # Test forward test + with torch.no_grad(): + _ = detector.forward(imgs, img_metas=img_metas, return_loss=False) + _ = detector.forward_dummy(imgs) + + +def _demo_mm_inputs(input_shape=(1, 1, 128, 128), num_outputs=None): + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + """ + (N, C, H, W) = input_shape + + rng = np.random.RandomState(0) + + imgs = rng.rand(*input_shape) + imgs = torch.FloatTensor(imgs) + + target = [ + imgs.new_ones(N, 14 * 4, 64, 64), + imgs.new_ones(N, 14, 3), + ] + target_weight = [ + imgs.new_ones(N, 14), + imgs.new_ones(N, 14, 3), + ] + + cameras = {'fx': 588.03, 'fy': 587.07, 'cx': 320.0, 'cy': 240.0} + + img_metas = [{ + 'img_shape': (128, 128, 3), + 'center': np.array([112, 112]), + 'scale': np.array([0.5, 0.5]), + 'bbox_score': 1.0, + 'bbox_id': 0, + 'flip_pairs': [], + 'inference_channel': np.arange(14), + 'cube_size': np.array([300, 300, 300]), + 'center_depth': 1.0, + 'focal': np.array([cameras['fx'], cameras['fy']]), + 'princpt': np.array([cameras['cx'], cameras['cy']]), + 'image_file': '.png', + } for _ in range(N)] + + mm_inputs = { + 'imgs': imgs.requires_grad_(True), + 'target': target, + 'target_weight': target_weight, + 'img_metas': img_metas + } + return mm_inputs