In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

## SO(3) parameterization conversions

Functions that convert the `rotation_10d` ([Peretroukhin et al., 2021](https://arxiv.org/abs/2006.01031)) and `quaternion_adjugate` ([Hanson and Hanson, 2022](https://arxiv.org/abs/2205.09116)) parameterizations of SO(3) to quaternions, which `PyTorch3D` can then turn into rotation matrices.

In [None]:
#| export
import torch

In [None]:
#| exporti
# Parameterizations with a PyTorch3D conversion to/from rotation matrices
# ("matrix" itself needs no conversion).
PARAMETERIZATIONS = ["axis_angle", "euler_angles", "matrix", "quaternion", "rotation_6d"]
# 10-dimensional parameterizations implemented in this module. They are
# accepted only as *inputs*: there is no unique mapping from SO(3) back to them.
PARAMETERIZATIONS += ["rotation_10d", "quaternion_adjugate"]

In [None]:
#| exporti
def _10vec_to_4x4symmetric(vec):
    """Convert a 10-vector to a symmetric 4x4 matrix."""
    b = len(vec)
    A = torch.zeros(b, 4, 4, device=vec.device)
    idx, jdx = torch.triu_indices(4, 4)
    A[..., idx, jdx] = vec
    A[..., jdx, idx] = vec
    return A

In [None]:
#| export
def rotation_10d_to_quaternion(rotation: torch.Tensor) -> torch.Tensor:
    """
    Convert a 10-vector into a symmetric 4x4 matrix A and return, as the
    resulting quaternion, the eigenvector of A with the *smallest* eigenvalue.
    This is the algebraically smallest (most negative) eigenvalue, not the one
    of smallest modulus. Equivalently, the quaternion is the unit vector q
    that minimizes q^T A q.

    Args:
        rotation: batch of 10-vectors, presumably of shape (batch, 10) -- TODO
            confirm. The entries fill the upper triangle of A.

    Returns:
        Unit-norm quaternions, one per input 10-vector. The overall sign is
        arbitrary: q and -q are both eigenvectors and encode the same rotation.

    Source: https://arxiv.org/abs/2006.01031
    """
    A = _10vec_to_4x4symmetric(rotation)  # A is a symmetric data matrix
    # `eigh` returns eigenvalues in ascending order and stores the orthonormal
    # eigenvectors as columns, so column 0 belongs to the smallest eigenvalue.
    return torch.linalg.eigh(A).eigenvectors[..., 0]

In [None]:
#| export
def quaternion_adjugate_to_quaternion(rotation: torch.Tensor) -> torch.Tensor:
    """
    Convert a 10-vector representing the quaternion adjugate to a unit
    quaternion.

    The 10-vector is unpacked into a symmetric 4x4 matrix A. For an exact
    adjugate, A is rank one and proportional to q q^T, so every nonzero row of
    A is a multiple of the quaternion q. That row is the eigenvector of the
    eigenvalue of maximum modulus. Rather than computing an
    eigendecomposition, this picks the row with the largest norm, which avoids
    dividing by a near-zero row, and normalizes it.

    Args:
        rotation: tensor whose last dimension has size 10, presumably
            (batch, 10) -- TODO confirm for other batch shapes.

    Returns:
        Unit-norm quaternions with shape (..., 4). The sign is arbitrary
        (q and -q encode the same rotation). An all-zero input row yields NaNs
        (0 / 0).

    Source: https://arxiv.org/abs/2205.09116
    """
    A = _10vec_to_4x4symmetric(rotation)  # A is the quaternion adjugate
    # A is symmetric, so the norm of column i (reduced over dim -2) equals the
    # norm of row i. Compute these norms once and reuse them.
    row_norms = A.norm(dim=-2)  # (..., 4)
    best = row_norms.argmax(dim=-1, keepdim=True)  # (..., 1)
    # Select row `best` of A for every batch element: the index has shape
    # (..., 1, 4), so gathering along dim -2 yields (..., 1, 4).
    index = best.unsqueeze(-1).expand(*best.shape[:-1], 1, A.shape[-1])
    best_row = A.gather(-2, index).squeeze(-2)  # (..., 4)
    return best_row / row_norms.gather(-1, best)

In [None]:
#| export
def convert(
    rotation,
    input_parameterization,
    output_parameterization,
    input_convention=None,
    output_convention=None,
):
    """
    Convert a rotation in SO(3) from some parameterization to another.
    The conversion goes through a temporary rotation matrix.

    If the input or output parameterization is `euler_angles`, the matching
    `input_convention` or `output_convention` must be given. This is the
    Euler-axis convention passed through to PyTorch3D.

    Note: cannot convert to `rotation_10d` or `quaternion_adjugate` because
    there is no unique mapping from SO(3) to these representations.

    Raises:
        ValueError: if the output is a 10-dimensional parameterization, if an
            `euler_angles` parameterization is used without its convention, or
            if a parameterization name is not recognized.
    """
    if output_parameterization in ["rotation_10d", "quaternion_adjugate"]:
        raise ValueError(
            f"Cannot convert {input_parameterization} to a unique 10-dimensional representation"
        )
    # Fail fast with an explicit message; PyTorch3D cannot work with a `None`
    # Euler convention.
    if input_parameterization == "euler_angles" and input_convention is None:
        raise ValueError(
            "input_convention must be specified when input_parameterization is 'euler_angles'"
        )
    if output_parameterization == "euler_angles" and output_convention is None:
        raise ValueError(
            "output_convention must be specified when output_parameterization is 'euler_angles'"
        )

    matrix = _convert_to_rotation_matrix(
        rotation, input_parameterization, input_convention
    )
    return _convert_from_rotation_matrix(
        matrix, output_parameterization, output_convention
    )

In [None]:
#| exporti
from pytorch3d.transforms import (
    axis_angle_to_matrix,
    euler_angles_to_matrix,
    quaternion_to_matrix,
    rotation_6d_to_matrix,
)


def _convert_to_rotation_matrix(rotation, parameterization, convention):
    """Convert any parameterization of a rotation to a matrix representation."""
    if parameterization == "axis_angle":
        R = axis_angle_to_matrix(rotation)
    elif parameterization == "euler_angles":
        R = euler_angles_to_matrix(rotation, convention)
    elif parameterization == "matrix":
        R = rotation
    elif parameterization == "quaternion":
        R = quaternion_to_matrix(rotation)
    elif parameterization == "rotation_6d":
        R = rotation_6d_to_matrix(rotation)
    elif parameterization == "rotation_10d":
        R = quaternion_to_matrix(rotation_10d_to_quaternion(rotation))
    elif parameterization == "quaternion_adjugate":
        R = quaternion_to_matrix(quaternion_adjugate_to_quaternion(rotation))
    else:
        raise ValueError(
            f"parameterization must be in {PARAMETERIZATIONS}, not {parameterization}"
        )
    return R

In [None]:
#| exporti
from pytorch3d.transforms import (
    matrix_to_axis_angle,
    matrix_to_euler_angles,
    matrix_to_quaternion,
    matrix_to_rotation_6d,
)


def _convert_from_rotation_matrix(matrix, parameterization, convention=None):
    "Convert a rotation matrix to any allowed parameterization."
    if parameterization == "axis_angle":
        rotation = matrix_to_axis_angle(matrix)
    elif parameterization == "euler_angles":
        rotation = matrix_to_euler_angles(matrix, convention)
    elif parameterization == "matrix":
        rotation = matrix
    elif parameterization == "quaternion":
        rotation = matrix_to_quaternion(matrix)
    elif parameterization == "rotation_6d":
        rotation = matrix_to_rotation_6d(matrix)
    elif parameterization in ["rotation_10d", "quaternion_adjugate"]:
        raise ValueError(
            "Cannot convert an element in SO(3) to a unique 10-dimensional representation"
        )
    else:
        raise ValueError(
            f"parameterization must be in {PARAMETERIZATIONS}, not {parameterization}"
        )
    return rotation

In [None]:
#| hide
# Run nbdev's export step so the `#| export` / `#| exporti` cells above end up
# in the `utils` module named by the `#| default_exp utils` directive.
# NOTE(review): called without arguments, this presumably exports every
# notebook in the project, not only this one -- confirm against nbdev config.
import nbdev

nbdev.nbdev_export()