
Add the hmr head and discriminator for SMPL parameters. Add test codes and test data.
zengwang430521 committed Sep 30, 2020
1 parent 3a35ed2 commit e068e7b
Showing 7 changed files with 533 additions and 0 deletions.
1 change: 1 addition & 0 deletions mmpose/models/__init__.py
@@ -3,6 +3,7 @@
from .detectors import * # noqa
from .keypoint_heads import * # noqa
from .losses import * # noqa
from .mesh_heads import * # noqa
from .registry import BACKBONES, HEADS, LOSSES, POSENETS

__all__ = [
3 changes: 3 additions & 0 deletions mmpose/models/mesh_heads/__init__.py
@@ -0,0 +1,3 @@
from .hmr_head import MeshHMRHead

__all__ = ['MeshHMRHead']
286 changes: 286 additions & 0 deletions mmpose/models/mesh_heads/discriminator.py
@@ -0,0 +1,286 @@
# ------------------------------------------------------------------------------
# Adapted from https://github.com/akanazawa/hmr
# Original licence: Copyright (c) 2018 akanazawa, under the MIT License.
# ------------------------------------------------------------------------------

from abc import abstractmethod

import torch
import torch.nn as nn

from .geometric_layers import batch_rodrigues


class BaseDiscriminator(nn.Module):
"""Base linear module for SMPL parameter discriminator.
Args:
fc_layers (Tuple): Tuple of neuron count,
such as (9, 32, 32, 1)
use_dropout (Tuple): Tuple of bool define use dropout or not
for each layer, such as (True, True, False)
drop_prob (Tuple): Tuple of float defined the drop prob,
such as (0.5, 0.5, 0)
use_activation(Tuple): Tuple of bool define use active function
or not, such as (True, True, False)
"""

def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
super().__init__()
self.fc_layers = fc_layers
self.use_dropout = use_dropout
self.drop_prob = drop_prob
self.use_activation = use_activation
self._check()
self.create_layers()

def _check(self):
"""Check input to avoid ValueError."""
        if not isinstance(self.fc_layers, tuple):
            raise TypeError(f'fc_layers requires a tuple, '
                            f'but got {type(self.fc_layers)}')

        if not isinstance(self.use_dropout, tuple):
            raise TypeError(f'use_dropout requires a tuple, '
                            f'but got {type(self.use_dropout)}')

        if not isinstance(self.drop_prob, tuple):
            raise TypeError(f'drop_prob requires a tuple, '
                            f'but got {type(self.drop_prob)}')

        if not isinstance(self.use_activation, tuple):
            raise TypeError(f'use_activation requires a tuple, '
                            f'but got {type(self.use_activation)}')

l_fc_layer = len(self.fc_layers)
l_use_drop = len(self.use_dropout)
l_drop_prob = len(self.drop_prob)
l_use_activation = len(self.use_activation)

pass_check = (
l_fc_layer >= 2 and l_use_drop < l_fc_layer
and l_drop_prob < l_fc_layer and l_use_activation < l_fc_layer
and l_drop_prob == l_use_drop)

if not pass_check:
msg = 'Wrong BaseDiscriminator parameters!'
raise ValueError(msg)

def create_layers(self):
"""Create layers."""
l_fc_layer = len(self.fc_layers)
l_use_drop = len(self.use_dropout)
l_use_activation = len(self.use_activation)

self.fc_blocks = nn.Sequential()

for i in range(l_fc_layer - 1):
self.fc_blocks.add_module(
name=f'regressor_fc_{i}',
module=nn.Linear(
in_features=self.fc_layers[i],
out_features=self.fc_layers[i + 1]))

if i < l_use_activation and self.use_activation[i]:
self.fc_blocks.add_module(
name=f'regressor_af_{i}', module=nn.ReLU())

if i < l_use_drop and self.use_dropout[i]:
self.fc_blocks.add_module(
name=f'regressor_fc_dropout_{i}',
module=nn.Dropout(p=self.drop_prob[i]))

@abstractmethod
def forward(self, inputs):
"""Forward function."""
msg = 'the base class [BaseDiscriminator] is not callable!'
raise NotImplementedError(msg)


class ShapeDiscriminator(BaseDiscriminator):
"""Discriminator for SMPL shape parameters, the inputs is (batch size x 10)
Args:
fc_layers (Tuple): Tuple of neuron count,
such as (10, 5, 1)
use_dropout (Tuple): Tuple of bool define use dropout or
not for each layer, such as (True, True, False)
drop_prob (Tuple): Tuple of float defined the drop prob,
such as (0.5, 0)
use_activation(Tuple): Tuple of bool define use active
function or not, such as (True, False)
"""

def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
if fc_layers[-1] != 1:
msg = f'the neuron count of the last layer ' \
f'must be 1, but got {fc_layers[-1]}'
raise ValueError(msg)

super().__init__(fc_layers, use_dropout, drop_prob, use_activation)

def forward(self, inputs):
"""Forward function."""
return self.fc_blocks(inputs)
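
As a quick illustration (not part of this commit), the shape discriminator can be exercised on random betas. The import path below assumes the module location added here, mmpose/models/mesh_heads/discriminator.py:

import torch

from mmpose.models.mesh_heads.discriminator import ShapeDiscriminator

# Two Linear layers (10 -> 5 -> 1) with a ReLU after the first one,
# mirroring the configuration that SMPLDiscriminator builds below.
shape_disc = ShapeDiscriminator(
    fc_layers=(10, 5, 1),
    use_dropout=(False, False),
    drop_prob=(0.5, 0.5),
    use_activation=(True, False))

betas = torch.randn(4, 10)      # a batch of 4 SMPL shape parameters
scores = shape_disc(betas)      # one realism score per sample, shape (4, 1)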


class PoseDiscriminator(nn.Module):
"""Discriminator for SMPL pose parameters of each joint. It is composed of
discriminators for each joints. The inputs is (batch size x joint_count x
9)
Args:
channels (Tuple): Tuple of channel number,
such as (9, 32, 32, 1)
joint_count (int): Joint number, such as 23
"""

def __init__(self, channels, joint_count):
super().__init__()
if channels[-1] != 1:
msg = f'the neuron count of the last layer ' \
f'must be 1, but got {channels[-1]}'
raise ValueError(msg)
self.joint_count = joint_count

self.conv_blocks = nn.Sequential()
len_channels = len(channels)
for idx in range(len_channels - 2):
self.conv_blocks.add_module(
name=f'conv_{idx}',
module=nn.Conv2d(
in_channels=channels[idx],
out_channels=channels[idx + 1],
kernel_size=1,
stride=1))

self.fc_layer = nn.ModuleList()
for idx in range(joint_count):
self.fc_layer.append(
nn.Linear(
in_features=channels[len_channels - 2], out_features=1))

def forward(self, inputs):
"""Forward function.
The input is (batch size x joint_count x 9)
"""
inputs = inputs.transpose(1, 2).\
unsqueeze(2).contiguous() # to N x 9 x 1 x joint_count
internal_outputs = self.conv_blocks(
inputs) # to N x c x 1 x joint_count
outputs = []
for idx in range(self.joint_count):
outputs.append(self.fc_layer[idx](internal_outputs[:, :, 0, idx]))

return torch.cat(outputs, 1), internal_outputs
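
A minimal sketch of the expected shapes (illustration only, using the default channel setting that SMPLDiscriminator passes in below):

import torch

from mmpose.models.mesh_heads.discriminator import PoseDiscriminator

pose_disc = PoseDiscriminator(channels=(9, 32, 32, 1), joint_count=23)

# One flattened 3x3 rotation matrix per joint for a batch of 4 samples.
rotmats = torch.randn(4, 23, 9)
per_joint_scores, features = pose_disc(rotmats)
# per_joint_scores: (4, 23), one score per joint
# features: (4, 32, 1, 23), shared features reused by FullPoseDiscriminator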


class FullPoseDiscriminator(BaseDiscriminator):
"""Discriminator for SMPL pose parameters of all joints.
Args:
fc_layers (Tuple): Tuple of neuron count,
such as (736, 1024, 1024, 1)
use_dropout (Tuple): Tuple of bool define use dropout or not
for each layer, such as (True, True, False)
drop_prob (Tuple): Tuple of float defined the drop prob,
such as (0.5, 0.5, 0)
use_activation(Tuple): Tuple of bool define use active
function or not, such as (True, True, False)
"""

def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
if fc_layers[-1] != 1:
msg = f'the neuron count of the last layer must be 1,' \
f' but got {fc_layers[-1]}'
raise ValueError(msg)

super().__init__(fc_layers, use_dropout, drop_prob, use_activation)

def forward(self, inputs):
"""Forward function."""
return self.fc_blocks(inputs)
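
For reference, a hedged usage sketch (not part of the commit) with the configuration that SMPLDiscriminator builds below, where 23 joints x 32 channels are flattened to 736 input features:

import torch

from mmpose.models.mesh_heads.discriminator import FullPoseDiscriminator

fc_layers = (23 * 32, 1024, 1024, 1)
full_pose_disc = FullPoseDiscriminator(
    fc_layers,
    use_dropout=(False, False, False),
    drop_prob=(0.5, 0.5, 0.5),
    use_activation=(True, True, False))

# Flattened per-joint features from PoseDiscriminator: (N, 32, 1, 23) -> (N, 736).
features = torch.randn(4, 23 * 32)
scores = full_pose_disc(features)   # shape (4, 1)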


class SMPLDiscriminator(nn.Module):
"""Discriminator for SMPL pose and shape parameters. It is composed of a
discriminator for SMPL shape parameters, a discriminator for SMPL pose
parameters of all joints and a discriminator for SMPL pose parameters of
each joint.
Args:
beta_channel (tuple of int): Tuple of neuron count of the
discriminator of shape parameters. Defaults to (10, 5, 1)
per_joint_channel (tuple of int): Tuple of neuron count of the
discriminator of each joint. Defaults to (9, 32, 32, 1)
full_pose_channel (tuple of int): Tuple of neuron count of the
discriminator of full pose. Defaults to (23*32, 1024, 1024, 1)
"""

def __init__(self,
beta_channel=(10, 5, 1),
per_joint_channel=(9, 32, 32, 1),
full_pose_channel=(23 * 32, 1024, 1024, 1)):
super().__init__()
self.joint_count = 23
# The count of SMPL shape parameter is 10.
assert beta_channel[0] == 10
# Use 3 x 3 rotation matrix as the pose parameters
# of each joint, so the input channel is 9.
assert per_joint_channel[0] == 9
assert self.joint_count * per_joint_channel[-2] \
== full_pose_channel[0]

self.beta_channel = beta_channel
self.per_joint_channel = per_joint_channel
self.full_pose_channel = full_pose_channel
self._create_sub_modules()

def _create_sub_modules(self):
"""Create sub discriminators."""

# create theta discriminator for each joint
self.pose_discriminator = PoseDiscriminator(self.per_joint_channel,
self.joint_count)

# create full pose discriminator for total joints
fc_layers = self.full_pose_channel
use_dropout = tuple([False] * (len(fc_layers) - 1))
drop_prob = tuple([0.5] * (len(fc_layers) - 1))
use_activation = tuple([True] * (len(fc_layers) - 2) + [False])

self.full_pose_discriminator = FullPoseDiscriminator(
fc_layers, use_dropout, drop_prob, use_activation)

# create shape discriminator for betas
fc_layers = self.beta_channel
use_dropout = tuple([False] * (len(fc_layers) - 1))
drop_prob = tuple([0.5] * (len(fc_layers) - 1))
use_activation = tuple([True] * (len(fc_layers) - 2) + [False])
self.shape_discriminator = ShapeDiscriminator(fc_layers, use_dropout,
drop_prob,
use_activation)

def forward(self, thetas):
"""Forward function."""
cams, poses, shapes = thetas

batch_size = poses.shape[0]
shape_disc_value = self.shape_discriminator(shapes)

# The first rotation matrix is global rotation
# and is NOT used in discriminator.
if poses.dim() == 2:
rotate_matrixs = \
batch_rodrigues(poses.contiguous().view(-1, 3)
).view(batch_size, 24, 9)[:, 1:, :]
else:
rotate_matrixs = poses.contiguous().view(batch_size, 24,
9)[:, 1:, :].contiguous()
pose_disc_value, pose_inter_disc_value \
= self.pose_discriminator(rotate_matrixs)
full_pose_disc_value = self.full_pose_discriminator(
pose_inter_disc_value.contiguous().view(batch_size, -1))
return torch.cat(
(pose_disc_value, full_pose_disc_value, shape_disc_value), 1)
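
Putting the pieces together, a sketch of a full forward pass through SMPLDiscriminator (illustration only; the camera parameters travel in thetas but are not scored by the discriminator):

import torch

from mmpose.models.mesh_heads.discriminator import SMPLDiscriminator

disc = SMPLDiscriminator()                # default channel settings

batch_size = 4
cams = torch.randn(batch_size, 3)         # camera parameters, unused by the discriminator
poses = torch.randn(batch_size, 24 * 3)   # axis-angle pose, 24 joints x 3
shapes = torch.randn(batch_size, 10)      # SMPL betas

scores = disc((cams, poses, shapes))
# scores: (4, 25) = 23 per-joint scores + 1 full-pose score + 1 shape score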
67 changes: 67 additions & 0 deletions mmpose/models/mesh_heads/geometric_layers.py
@@ -0,0 +1,67 @@
import torch
from torch.nn import functional as F


def rot6d_to_rotmat(x):
"""Convert 6D rotation representation to 3x3 rotation matrix.
Based on Zhou et al., "On the Continuity of Rotation
Representations in Neural Networks", CVPR 2019
Input:
(B,6) Batch of 6-D rotation representations
Output:
(B,3,3) Batch of corresponding rotation matrices
"""
x = x.view(-1, 3, 2)
a1 = x[:, :, 0]
a2 = x[:, :, 1]
b1 = F.normalize(a1)
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
b3 = torch.cross(b1, b2)
return torch.stack((b1, b2, b3), dim=-1)
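
A small sanity check (illustration only, not part of the commit): the stacked columns b1, b2, b3 are orthonormal, so R^T R should be close to the identity for non-degenerate inputs.

import torch

x = torch.randn(8, 6)                     # batch of 6-D rotation representations
R = rot6d_to_rotmat(x)                    # -> (8, 3, 3)

identity = torch.eye(3).expand(8, 3, 3)
print(torch.allclose(R.transpose(1, 2) @ R, identity, atol=1e-5))  # expected: True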


def batch_rodrigues(theta):
"""Convert axis-angle representation to rotation matrix.
Args:
theta: size = [B, 3]
Returns:
Rotation matrix corresponding to the quaternion
-- size = [B, 3, 3]
"""
l2norm = torch.norm(theta + 1e-8, p=2, dim=1)
angle = torch.unsqueeze(l2norm, -1)
normalized = torch.div(theta, angle)
angle = angle * 0.5
v_cos = torch.cos(angle)
v_sin = torch.sin(angle)
quat = torch.cat([v_cos, v_sin * normalized], dim=1)
return quat_to_rotmat(quat)


def quat_to_rotmat(quat):
"""Convert quaternion coefficients to rotation matrix.
Args:
quat: size = [B, 4] 4 <===>(w, x, y, z)
Returns:
Rotation matrix corresponding to the quaternion
-- size = [B, 3, 3]
"""
norm_quat = quat
norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
w, x, y, z = norm_quat[:, 0], norm_quat[:, 1],\
norm_quat[:, 2], norm_quat[:, 3]

B = quat.size(0)

w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
wx, wy, wz = w * x, w * y, w * z
xy, xz, yz = x * y, x * z, y * z

rotMat = torch.stack([
w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
w2 - x2 - y2 + z2
],
dim=1).view(B, 3, 3)
return rotMat
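
As a worked example (not part of the commit), a rotation of pi/2 about the z-axis in axis-angle form should map to the familiar planar rotation matrix via batch_rodrigues, which goes through quat_to_rotmat internally:

import math

import torch

theta = torch.tensor([[0.0, 0.0, math.pi / 2]])   # axis-angle: pi/2 about z
R = batch_rodrigues(theta)                        # -> (1, 3, 3)

expected = torch.tensor([[[0.0, -1.0, 0.0],
                          [1.0, 0.0, 0.0],
                          [0.0, 0.0, 1.0]]])
print(torch.allclose(R, expected, atol=1e-4))     # expected: True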
