In [None]:
# default_exp transforms

# Transforms

> Utilities to transform between a few of the many rotational formalisms.

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# imports that are only required for testing
import tempfile
import warnings
from itertools import permutations, chain
from llamass.core import unpack_body_models, AMASS

In [None]:
# export
import torch
import torch.nn.functional as F
import numpy as np
from scipy.spatial.transform import Rotation as R

# Copied Code

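This code is copied [from torchgeometry, as vendored in `nghorbani/human_body_prior`][hbptools]. It is copied into this library so it can be used and tested here; a later cell checks it against the original `human_body_prior` implementation.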

[hbptools]: https://github.com/nghorbani/human_body_prior/blob/0278cb45180992e4d39ba1a11601f5ecc53ee148/src/human_body_prior/tools/tgm_conversion.py

In [None]:
# exports
def rotation_matrix_to_angle_axis(rotation_matrix):
    """Convert 3x4 rotation matrices to Rodrigues (axis-angle) vectors.

    The conversion goes through a quaternion. The matrix is first turned into
    a quaternion by `rotation_matrix_to_quaternion`, and that quaternion is
    then turned into an axis-angle vector by `quaternion_to_angle_axis`.
    Only the top-left 3x3 block of each matrix is read.

    Args:
        rotation_matrix (Tensor): rotation matrix.
    Returns:
        Tensor: Rodrigues vector transformation.
    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 3)`
    Example:
        >>> input = torch.rand(2, 3, 4)  # Nx3x4
        >>> output = rotation_matrix_to_angle_axis(input)  # Nx3
    """
    # TODO: validate that the input really is a rotation matrix
    return quaternion_to_angle_axis(
        rotation_matrix_to_quaternion(rotation_matrix))

def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """Convert 3x4 rotation matrices to 4d quaternion vectors.

    This algorithm is based on the algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
    Only the top-left 3x3 block of each input matrix is read. The output is
    scalar-first, i.e. ``(w, x, y, z)``.

    Args:
        rotation_matrix (Tensor): the rotation matrix to convert.
        eps (float): threshold on entry [2, 2] of the transposed matrix.
            It selects between the two pairs of candidate branches below.
    Return:
        Tensor: the rotation in quaternion
    Raises:
        TypeError: if the input is not a tensor.
        ValueError: if the input is not a three dimensional N x 3 x 4 tensor.
    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 4)`
    Example:
        >>> input = torch.rand(4, 3, 4)  # Nx3x4
        >>> output = rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not torch.is_tensor(rotation_matrix):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(rotation_matrix)))

    # Require exactly three dims. A 2-D (3, 4) tensor would otherwise pass the
    # shape check below and then fail inside transpose(1, 2).
    if len(rotation_matrix.shape) != 3:
        raise ValueError(
            "Input size must be a three dimensional tensor. Got {}".format(
                rotation_matrix.shape))
    if not rotation_matrix.shape[-2:] == (3, 4):
        raise ValueError(
            "Input size must be a N x 3 x 4  tensor. Got {}".format(
                rotation_matrix.shape))

    rmat_t = torch.transpose(rotation_matrix, 1, 2)

    mask_d2 = rmat_t[:, 2, 2] < eps

    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    # Four candidate (unnormalised) quaternions, each paired with the t_i
    # that it is later divided by (sqrt(t_i)).
    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
                      t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
                      rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
    t0_rep = t0.repeat(4, 1).t()

    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
                      rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
                      t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
    t1_rep = t1.repeat(4, 1).t()

    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
                      rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
                      rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
    t2_rep = t2.repeat(4, 1).t()

    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
                      rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
                      rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
    t3_rep = t3.repeat(4, 1).t()

    # Exactly one mask is set per matrix. The case split is the same as in
    # pyquaternion and presumably picks a branch whose t_i is large enough
    # for the division to be well conditioned.
    mask_c0 = mask_d2 * mask_d0_d1
    mask_c1 = mask_d2 * torch.logical_not(mask_d0_d1)
    mask_c2 = torch.logical_not(mask_d2) * mask_d0_nd1
    mask_c3 = torch.logical_not(mask_d2) * torch.logical_not(mask_d0_nd1)
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 +  # noqa
                    t2_rep * mask_c2 + t3_rep * mask_c3)  # noqa
    q *= 0.5
    return q

def quaternion_to_angle_axis(quaternion) -> torch.Tensor:
    """Convert quaternions to axis-angle (rotation vector) form.

    Adapted from the ceres C++ library: ceres-solver/include/ceres/rotation.h
    The quaternion is read scalar-first, i.e. ``(w, x, y, z)``. The result is
    the rotation axis scaled by the rotation angle.

    Args:
        quaternion (torch.Tensor): tensor with quaternions.
    Return:
        torch.Tensor: tensor with angle axis of rotation.
    Raises:
        TypeError: if ``quaternion`` is not a tensor.
        ValueError: if the last dimension is not of size 4.
    Shape:
        - Input: :math:`(*, 4)` where `*` means, any number of dimensions
        - Output: :math:`(*, 3)`
    Example:
        >>> quaternion = torch.rand(2, 4)  # Nx4
        >>> angle_axis = quaternion_to_angle_axis(quaternion)  # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))
    if quaternion.shape[-1] != 4:
        raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
                         .format(quaternion.shape))

    w = quaternion[..., 0]
    vec = (quaternion[..., 1], quaternion[..., 2], quaternion[..., 3])
    vec_norm_sq = vec[0] * vec[0] + vec[1] * vec[1] + vec[2] * vec[2]
    vec_norm = torch.sqrt(vec_norm_sq)

    # For w < 0, use the sign-flipped quaternion (same rotation). This keeps
    # the half-angle in (-pi/2, pi/2], so the full angle stays in (-pi, pi].
    half_angle = torch.where(w < 0.0,
                             torch.atan2(-vec_norm, -w),
                             torch.atan2(vec_norm, w))
    angle = 2.0 * half_angle

    # angle / |vec| rescales the vector part into the rotation vector.
    # Where the vector part is exactly zero, fall back to the constant 2;
    # the product below is zero there anyway.
    fallback_scale = 2.0 * torch.ones_like(vec_norm)
    scale = torch.where(vec_norm_sq > 0.0, angle / vec_norm, fallback_scale)

    angle_axis = torch.zeros_like(quaternion)[..., :3]
    for axis, component in enumerate(vec):
        angle_axis[..., axis] += component * scale
    return angle_axis

def angle_axis_to_rotation_matrix(angle_axis):
    """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix

    Each output is a homogeneous 4x4 matrix. Its top-left 3x3 block is the
    rotation; all other entries come from the 4x4 identity. Rotations whose
    squared angle ``theta2`` exceeds 1e-6 use Rodrigues' formula. Smaller ones
    use the first-order approximation ``I + [r]_x``. Both branches are
    computed for every row and then blended with 0/1 masks.

    Args:
        angle_axis (Tensor): tensor of 3d vector of axis-angle rotations.
    Returns:
        Tensor: tensor of 4x4 rotation matrices.
    Shape:
        - Input: :math:`(N, 3)`
        - Output: :math:`(N, 4, 4)`
    Example:
        >>> input = torch.rand(1, 3)  # Nx3
        >>> output = angle_axis_to_rotation_matrix(input)  # Nx4x4
    """
    def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6):
        # Rodrigues' formula:
        #   R = cos(theta) I + (1 - cos(theta)) w w^T + sin(theta) [w]_x
        # with w = angle_axis / (theta + eps), approximately the unit axis.
        # We want to be careful to only evaluate the square root if the
        # norm of the angle_axis vector is greater than zero. Otherwise
        # we get a division by zero.
        # NOTE(review): the square root is in fact evaluated for every row.
        # The `+ eps` is what avoids dividing by zero, and small-angle rows
        # are replaced by the Taylor branch afterwards.
        k_one = 1.0
        theta = torch.sqrt(theta2)
        wxyz = angle_axis / (theta + eps)
        wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        r00 = cos_theta + wx * wx * (k_one - cos_theta)
        r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
        r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
        r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
        r11 = cos_theta + wy * wy * (k_one - cos_theta)
        r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
        r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
        r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
        r22 = cos_theta + wz * wz * (k_one - cos_theta)
        # concatenated in row-major order, then viewed as (N, 3, 3)
        rotation_matrix = torch.cat(
            [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1)
        return rotation_matrix.view(-1, 3, 3)

    def _compute_rotation_matrix_taylor(angle_axis):
        # First-order approximation for small angles:
        # I + [r]_x = [[1, -rz, ry], [rz, 1, -rx], [-ry, rx, 1]]
        rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
        k_one = torch.ones_like(rx)
        rotation_matrix = torch.cat(
            [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
        return rotation_matrix.view(-1, 3, 3)

    # stolen from ceres/rotation.h

    # squared norm per row: (N, 1, 3) @ (N, 3, 1) -> (N, 1, 1) -> (N, 1)
    _angle_axis = torch.unsqueeze(angle_axis, dim=1)
    theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
    theta2 = torch.squeeze(theta2, dim=1)

    # compute rotation matrices
    rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
    rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)

    # create mask to handle both cases
    # (mask_pos selects Rodrigues rows, mask_neg selects Taylor rows)
    eps = 1e-6
    mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
    mask_pos = (mask).type_as(theta2)
    mask_neg = (mask == False).type_as(theta2)  # noqa

    # create output pose matrix
    # (start from a batch of 4x4 identities, so the last row/column stay [0,0,0,1])
    batch_size = angle_axis.shape[0]
    rotation_matrix = torch.eye(4).to(angle_axis.device).type_as(angle_axis)
    rotation_matrix = rotation_matrix.view(1, 4, 4).repeat(batch_size, 1, 1)
    # fill output matrix with masked values
    rotation_matrix[..., :3, :3] = \
        mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor
    return rotation_matrix  # Nx4x4

In [None]:
# human_body_prior
# this cell tests that the above copied code is consistent with the code in the original library
import human_body_prior.tools.rotation_tools as hbp_rot

def matrot2aa(pose_matrot):
    '''
    Convert 3x3 rotation matrices to rotation vectors.

    :param pose_matrot: Nx3x3
    :return: Nx3
    '''
    # Pad the last axis 3 -> 4 to get the Nx3x4 layout that
    # rotation_matrix_to_angle_axis expects; the added column is not read.
    homogen_matrot = F.pad(pose_matrot, [0, 1])
    return rotation_matrix_to_angle_axis(homogen_matrot)

def aa2matrot(pose):
    '''
    Convert rotation vectors to 3x3 rotation matrices.

    :param pose: Nx3 tensor of rotation vectors
    :return: pose_matrot: Nx3x3
    '''
    # drop the homogeneous row/column added by angle_axis_to_rotation_matrix
    return angle_axis_to_rotation_matrix(pose)[:, :3, :3].contiguous()

with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 2)
    amass = AMASS(tmpdirname, overlapping=False, clip_length=1, transform=torch.tensor)
    # columns 3:66 of every sample's poses, flattened to one rotation vector per row
    poses = torch.stack([sample['poses'][:, 3:66] for sample in amass]).view(-1, 3)
    # the copied conversion must agree with human_body_prior's own aa2matrot
    ours, theirs = aa2matrot(poses), hbp_rot.aa2matrot(poses)
    assert torch.abs(ours - theirs).max() < 1e-4

  0%|          | 0/1 [00:00<?, ?it/s]

# Rotation Class

> A `scipy.spatial.transform`-like interface to transform Tensors.

In [None]:
# exports
class Rotation():
    """
    Class to give a scipy.spatial.transform-like interface
    for converting between rotational formalisms in PyTorch.
    Acts on trailing axes and maintains leading tensor shape.

    Rotations are packed along the last axis:
    - `from_rotvec` takes `(..., 3*K)` tensors (K rotation vectors per row).
    - `from_matrix` takes `(..., 9*K)` tensors (K row-major flattened 3x3
      matrices).
    The `as_*` methods return tensors with the same leading shape and the last
    axis resized for the target formalism.
    """
    def __init__(self, tensor, shape, formalism):
        # tensor: rotations flattened to (M, 3) or (M, 3, 3)
        # shape: size of the tensor given to from_rotvec / from_matrix
        # formalism: 'rotvec' or 'matrix', as set by those constructors
        self.tensor, self.shape = tensor, shape
        self.formalism = formalism

    @staticmethod
    def from_rotvec(x):
        """Wrap rotation vectors packed as `(..., 3*K)`."""
        # reshape (not view) so non-contiguous inputs are accepted as well
        return Rotation(x.reshape(-1, 3), x.size(), 'rotvec')

    @staticmethod
    def from_matrix(x):
        """Wrap rotation matrices packed as `(..., 9*K)` (row-major 3x3 blocks)."""
        return Rotation(x.reshape(-1, 3, 3), x.size(), 'matrix')

    def as_rotvec(self):
        """Return rotation vectors packed as `(..., 3*K)`; needs matrix input."""
        if self.formalism != 'matrix':
            raise NotImplementedError()
        s = self.shape
        # pad 3x3 -> 3x4, the layout rotation_matrix_to_angle_axis expects
        rotvec = rotation_matrix_to_angle_axis(F.pad(self.tensor, [0, 1]))
        return rotvec.reshape(*s[:-1], s[-1] // 3)

    def as_matrix(self):
        """Return rotation matrices packed as `(..., 9*K)`; needs rotvec input."""
        if self.formalism != 'rotvec':
            raise NotImplementedError()
        s = self.shape
        matrot = angle_axis_to_rotation_matrix(self.tensor)[:, :3, :3].contiguous()
        return matrot.view(*s[:-1], s[-1] * 3)

    @staticmethod
    def from_euler(*args, **kwargs):
        raise NotImplementedError()

    @staticmethod
    def from_quat(*args, **kwargs):
        raise NotImplementedError()

    @staticmethod
    def from_mrp(*args, **kwargs):
        raise NotImplementedError()

    def as_euler(self, degrees=False):
        """Return Euler angles `(e1, e2, e3)` for every rotation.

        The angles are computed from each 3x3 matrix R as
        - e2 = -asin(R[0, 2])
        - e1 = atan2(R[1, 2], R[2, 2])
        - e3 = atan2(R[0, 1], R[0, 0])
        For e1 and e3, both arguments are divided by cos(e2).

        When R[0, 2] is exactly +1 or -1:
        - e2 = -pi/2 or +pi/2 respectively
        - e3 = 0
        - e1 = atan2(R[0, 1], R[0, 2])
        This is not one of scipy's `as_euler` axis sequences.

        Args:
            degrees: return degrees instead of radians.
        Returns:
            Tensor packed as `(..., 3*K)`, with the dtype and device of the
            rotation matrices.
        """
        if self.formalism == 'rotvec':
            matrices = Rotation.from_matrix(self.as_matrix())
        elif self.formalism == 'matrix':
            matrices = self
        else:
            raise NotImplementedError()
        rs = matrices.tensor
        n_samples = rs.size(0)

        # outputs keep the dtype and device of the matrices
        e1 = rs.new_zeros(n_samples)
        e2 = rs.new_zeros(n_samples)
        e3 = rs.new_zeros(n_samples)

        # R[0, 2] == +-1 means cos(e2) == 0, so e1 and e3 cannot be separated.
        # e3 stays 0 and e1 carries the remaining rotation.
        is_one = rs[:, 0, 2] == 1
        is_minus_one = rs[:, 0, 2] == -1
        is_special = torch.logical_or(is_one, is_minus_one)

        e1[is_special] = torch.atan2(rs[is_special, 0, 1], rs[is_special, 0, 2])
        e2[is_minus_one] = np.pi / 2
        e2[is_one] = -np.pi / 2

        # regular case
        is_normal = torch.logical_not(is_special)
        # clamp guards arcsin against values slightly outside [-1, 1]
        in_ = torch.clamp(rs[is_normal, 0, 2], -1, 1)
        e2[is_normal] = -torch.arcsin(in_)
        e2_cos = torch.cos(e2[is_normal])
        e1[is_normal] = torch.atan2(rs[is_normal, 1, 2] / e2_cos,
                                    rs[is_normal, 2, 2] / e2_cos)
        e3[is_normal] = torch.atan2(rs[is_normal, 0, 1] / e2_cos,
                                    rs[is_normal, 0, 0] / e2_cos)

        eul = torch.stack([e1, e2, e3], dim=-1)
        s = matrices.shape
        eul = eul.reshape(*s[:-1], s[-1] // 3)
        if degrees:
            eul = eul * (180.0 / np.pi)
        return eul

    def as_quat(self):
        raise NotImplementedError()

    def as_mrp(self):
        raise NotImplementedError()

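The cells below test that the `Rotation` conversions compute the same rotation matrices and rotation vectors as scipy. These conversions are backed by the copied `angle_axis_to_rotation_matrix` and `rotation_matrix_to_angle_axis` functions. The rotation vector → matrix → rotation vector round trip also confirms that the data is encoded as rotation vectors. Euler angles from `Rotation.as_euler` are compared against `rotmat2euler`, a port of the H3.6M MATLAB reference defined in the next cell.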

In [None]:
"""
MIT License

Copyright (c) 2016 Julieta Martinez, Javier Romero

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

def rotmat2euler(R):
    """
    Convert a rotation matrix to Euler angles (E1, E2, E3).

    Python port of the MATLAB evaluation code:
    https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/mhmublv/Motion/RotMat2Euler.m#L1
    Args
    R: a 3x3 rotation matrix, indexable as R[i, j]
    Returns
    eul: numpy array [E1, E2, E3], in radians
    """
    r02 = R[0, 2]
    if r02 != 1 and r02 != -1:
        E2 = -np.arcsin(r02)
        cos_e2 = np.cos(E2)
        E1 = np.arctan2(R[1, 2] / cos_e2, R[2, 2] / cos_e2)
        E3 = np.arctan2(R[0, 1] / cos_e2, R[0, 0] / cos_e2)
        return np.array([E1, E2, E3])

    # cos(E2) == 0: E1 and E3 are not separately determined, so E3 is
    # fixed to 0 arbitrarily and E1 takes the remaining rotation.
    E3 = 0
    dlta = np.arctan2(R[0, 1], r02)
    if r02 == -1:
        E2, E1 = np.pi / 2, E3 + dlta
    else:
        E2, E1 = -np.pi / 2, -E3 + dlta
    return np.array([E1, E2, E3])

In [None]:
with tempfile.TemporaryDirectory() as tmpdirname:
    unpack_body_models("sample_data/", tmpdirname, 2)
    amass = AMASS(tmpdirname, overlapping=False, clip_length=1, transform=torch.tensor)
    poses = torch.stack([sample['poses'][:, 3:66] for sample in amass])
    flat_rotvecs = poses.view(-1, 3)

    # rotvec -> matrix, compared with scipy one rotation at a time
    m = Rotation.from_rotvec(poses).as_matrix()
    matrots = torch.tensor(np.stack(
        [R.from_rotvec(rv.numpy()).as_matrix() for rv in flat_rotvecs]))
    assert torch.abs(matrots - m.view(-1, 3, 3)).max() < 1e-5

    # matrix -> rotvec, compared with scipy and with the original poses
    aa = Rotation.from_matrix(m).as_rotvec()
    rotvecs = torch.tensor(np.stack(
        [R.from_matrix(mat.view(3, 3)).as_rotvec() for mat in m.view(-1, 3 * 3)]))
    assert torch.abs(rotvecs - aa.view(-1, 3)).max() < 1e-5
    assert torch.abs(poses - aa).max() < 1e-5

    # rotvec -> matrix -> euler: search scipy's conventions for a match
    e = Rotation.from_rotvec(poses).as_euler()
    for seq in chain(permutations('xyz'), permutations('XYZ')):
        scipy_euler = torch.tensor(np.stack(
            [R.from_rotvec(rv.numpy()).as_euler("".join(seq), degrees=False)
             for rv in flat_rotvecs]))
        err = torch.abs(scipy_euler - e.view(-1, 3)).max()
        if err < 1e-5:
            break
    if err > 1e-5:
        warnings.warn(f"Euler angle computation does not match any convention.")

    # the Euler angles must match the H3.6M reference implementation
    reference_euler = torch.tensor(np.stack([rotmat2euler(rm.numpy()) for rm in matrots]))
    err = torch.abs(reference_euler - e.view(-1, 3)).max()
    assert err < 1e-5

  0%|          | 0/1 [00:00<?, ?it/s]



In [None]:
# For the first test matrix, print rotmat2euler's angles minus scipy's angles
# for every axis sequence (in the output below, no sequence gives all zeros).
r = matrots[0].numpy()
e = rotmat2euler(r)
for seq in map("".join, chain(permutations('xyz'), permutations('XYZ'))):
    _e = R.from_matrix(r).as_euler(seq, degrees=False)
    print(seq, e - _e)

xyz [ 1.54809437  0.31419901 -0.26214211]
xzy [1.52761688 0.01639201 0.03872136]
yxz [ 0.91142335  0.98763906 -0.15936584]
yzx [0.97618798 0.14463208 0.71229793]
zxy [0.50030164 0.95476593 0.22065227]
zyx [0.69825729 0.42061958 0.69825729]
XYZ [ 1.53083387  0.42061958 -0.13431928]
XZY [1.5448745  0.14463208 0.1436114 ]
YXZ [ 1.05322884  0.95476593 -0.33227493]
YZX [0.87129794 0.01639201 0.6950403 ]
ZXY [0.67321073 0.98763906 0.07884678]
ZYX [0.57043446 0.31419901 0.7155178 ]


# Scipy Transforms

> Transforms for use during data loading, operating on numpy arrays via `scipy.spatial.transform`.

In [None]:
# exports
def scipy_aa_to_euler(x):
    """Convert rotation vectors to 'zyx' Euler angles (radians) using scipy.

    `x` is a numpy array whose size is a multiple of 3. It is processed as
    consecutive 3-vectors, and the result has the same shape and dtype as `x`.
    """
    original_shape = x.shape
    rotvecs = x.reshape(-1, 3)
    out = np.zeros_like(rotvecs)
    for row in range(rotvecs.shape[0]):
        out[row] = R.from_rotvec(rotvecs[row]).as_euler('zyx')
    return out.reshape(*original_shape)

def scipy_euler_to_aa(x, zero_center=True):
    """Convert 'zyx' Euler angles (radians) to rotation vectors using scipy.

    Args:
        x: numpy array whose size is a multiple of 3. It is processed as
            consecutive (z, y, x) Euler-angle triplets.
        zero_center: if True, subtract pi from every angle first. This
            presumably undoes a +pi offset applied elsewhere; confirm against
            callers.
    Returns:
        Array of rotation vectors with the same shape as `x`.
    """
    if zero_center:
        # numpy's pi: this module does not import `math`
        x = x - np.pi
    s = x.shape
    x = x.reshape(-1, 3)
    aa = np.zeros_like(x)
    for i, euler in enumerate(x):
        aa[i] = R.from_euler('zyx', euler).as_rotvec()
    return aa.reshape(*s)

In [None]:
# hide
# Regenerate the library modules from the project's notebooks (nbdev export of
# the `# export` / `# exports` cells); the output lists the converted notebooks.
from nbdev.export import notebook2script

notebook2script()

Converted 00_core.ipynb.
Converted 01_tqdm.ipynb.
Converted 02_features.ipynb.
Converted 03_transforms.ipynb.
Converted 05_losses.ipynb.
Converted index.ipynb.
