In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

## SO(3) parameterization conversions

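Functions that convert the `rotation_10d` ([Peretroukhin et al., 2021](https://arxiv.org/abs/2006.01031)) and `quaternion_adjugate` ([Hanson and Hanson, 2022](https://arxiv.org/abs/2205.09116)) parameterizations of SO(3) into unit quaternions. Both take a batch of 10-vectors, each of which is unpacked into a symmetric $4 \times 4$ matrix. Because $q$ and $-q$ represent the same rotation, the sign of the returned quaternion is arbitrary. The quaternions can then be converted to rotation matrices with PyTorch3D (e.g. `pytorch3d.transforms.quaternion_to_matrix`).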

In [None]:
#| export
import torch

In [None]:
#| exporti
def _10vec_to_4x4symmetric(vec):
    """Convert a 10-vector to a symmetric 4x4 matrix."""
    b = len(vec)
    A = torch.zeros(b, 4, 4, device=vec.device)
    idx, jdx = torch.triu_indices(4, 4)
    A[..., idx, jdx] = vec
    A[..., jdx, idx] = vec
    return A

In [None]:
#| export
def rotation_10d_to_quaternion(rotations: torch.Tensor) -> torch.Tensor:
    """
    Convert 10-vectors into unit quaternions.

    Each 10-vector is unpacked into a symmetric 4x4 matrix A, with the upper
    triangle filled in row-major order and mirrored. The quaternion is the unit
    eigenvector of A belonging to its *smallest* eigenvalue, i.e. the minimizer
    of q^T A q over unit-norm q. This is the minimum algebraic eigenvalue, not
    the eigenvalue of minimum modulus: `torch.linalg.eigh` returns eigenvalues
    in ascending order, and the first eigenvector is taken.

    The sign of the result is arbitrary, since q and -q encode the same rotation.

    Args:
        rotations: batch of 10-vectors, expected shape (batch, 10).

    Returns:
        Tensor of shape (batch, 4) holding unit-norm quaternions.

    Source: https://arxiv.org/abs/2006.01031
    """
    A = _10vec_to_4x4symmetric(rotations)  # A is a symmetric data matrix
    _, eigenvectors = torch.linalg.eigh(A)
    # Eigenvectors are stored column-wise; column 0 pairs with the smallest eigenvalue.
    return eigenvectors[..., 0]

In [None]:
#| export
def quaternion_adjugate_to_quaternion(rotations: torch.Tensor) -> torch.Tensor:
    """
    Convert 10-vectors parameterizing the quaternion adjugate into unit quaternions.

    Each 10-vector is unpacked into a symmetric 4x4 matrix A, the quaternion
    adjugate. Up to scale this is the rank-1 matrix q q^T, so every non-zero
    row (or column) of A is proportional to q. Rather than computing an
    eigendecomposition, this picks the column with the largest Euclidean norm,
    which avoids dividing by a near-zero vector, and divides it by that norm.

    The returned vector always has unit norm, because it is divided by its own
    norm. Its sign is arbitrary, since q and -q encode the same rotation. An
    all-zero A yields NaNs (0 / 0).

    Args:
        rotations: batch of 10-vectors, expected shape (batch, 10).

    Returns:
        Tensor of shape (batch, 4) holding unit-norm quaternions.

    Source: https://arxiv.org/abs/2205.09116
    """
    A = _10vec_to_4x4symmetric(rotations)  # A is the quaternion adjugate
    col_norms = A.norm(dim=-2)                     # (..., 4), computed once
    best = col_norms.argmax(dim=-1, keepdim=True)  # (..., 1)
    best_norm = col_norms.gather(-1, best)         # (..., 1)
    # Select column `best` of every matrix; it equals row `best` because A is symmetric.
    index = best.unsqueeze(-2).expand(*A.shape[:-1], 1)  # (..., 4, 1)
    quaternion = A.gather(-1, index).squeeze(-1)         # (..., 4)
    return quaternion / best_norm

In [None]:
#| hide
from nbdev import nbdev_export

# Write the `#| export`-marked cells of this notebook out to the library module.
nbdev_export()