In [126]:
import numpy as np
import jax.numpy as jp
import jax

import jax
jax.config.update('jax_enable_x64', True)

In [127]:
def mat2quat_jax(mat):
    """ Convert Rotation Matrix to Quaternion using JAX """
    mat = jp.asarray(mat, dtype=jp.float64)
    assert mat.shape[-2:] == (3, 3), f"Invalid shape matrix {mat.shape}"

    Qxx, Qyx, Qzx, Qxy, Qyy, Qzy, Qxz, Qyz, Qzz = mat.ravel()

    # Construct the symmetric K matrix
    K = jp.array([
        [Qxx - Qyy - Qzz, 0,               0,               0              ],
        [Qyx + Qxy,       Qyy - Qxx - Qzz, 0,               0              ],
        [Qzx + Qxz,       Qzy + Qyz,       Qzz - Qxx - Qyy, 0              ],
        [Qyz - Qzy,       Qzx - Qxz,       Qxy - Qyx,       Qxx + Qyy + Qzz]
    ]) / 3.0

    # Compute eigenvalues and eigenvectors
    vals, vecs = jp.linalg.eigh(K)

    # Select eigenvector corresponding to the largest eigenvalue
    max_idx = jp.argmax(vals)
    q = vecs[[3, 0, 1, 2], max_idx]

    # Ensure quaternion has positive w (sign consistency)
    q = jp.where(q[0] < 0, -q, q)

    return q


In [128]:
def mat2quat(mat):
    """Convert Rotation Matrix to Quaternion"""
    mat = np.asarray(mat, dtype=np.float64)
    assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat)

    Qxx, Qyx, Qzx, Qxy, Qyy, Qzy, Qxz, Qyz, Qzz = mat.flat
    # Fill only lower half of symmetric matrix
    K = np.array([
        [Qxx - Qyy - Qzz, 0,               0,               0              ],
        [Qyx + Qxy,       Qyy - Qxx - Qzz, 0,               0              ],
        [Qzx + Qxz,       Qzy + Qyz,       Qzz - Qxx - Qyy, 0              ],
        [Qyz - Qzy,       Qzx - Qxz,       Qxy - Qyx,       Qxx + Qyy + Qzz]]
        ) / 3.0
    # Use Hermitian eigenvectors, values for speed
    vals, vecs = np.linalg.eigh(K)
    # Select largest eigenvector, reorder to w,x,y,z quaternion
    q = vecs[[3, 0, 1, 2], np.argmax(vals)]
    # Prefer quaternion with positive w
    # (q * -1 corresponds to same rotation as q)
    if q[0] < 0:
        q *= -1
        
    return q

def mat2quat_fast(mat):
    """ Convert a 3x3 normalized rotation matrix to a quaternion, ensuring consistent signs. """
    mat = np.asarray(mat, dtype=np.float64)
    assert mat.shape == (3, 3), f"Invalid shape matrix {mat.shape}"

    q = np.zeros(4, dtype=np.float64)

    if mat[2, 2] < 0.0:
        if mat[0, 0] > mat[1, 1]:
            trace = 1.0 + mat[0, 0] - mat[1, 1] - mat[2, 2]
            s = 2.0 * np.sqrt(trace)
            if mat[1, 2] < mat[2, 1]:
                s = -s
            q[1] = 0.25 * s
            s = 1.0 / s
            q[0] = (mat[1, 2] - mat[2, 1]) * s
            q[2] = (mat[0, 1] + mat[1, 0]) * s
            q[3] = (mat[2, 0] + mat[0, 2]) * s
        else:
            trace = 1.0 - mat[0, 0] + mat[1, 1] - mat[2, 2]
            s = 2.0 * np.sqrt(trace)
            if mat[2, 0] < mat[0, 2]:
                s = -s
            q[2] = 0.25 * s
            s = 1.0 / s
            q[0] = (mat[2, 0] - mat[0, 2]) * s
            q[1] = (mat[0, 1] + mat[1, 0]) * s
            q[3] = (mat[1, 2] + mat[2, 1]) * s
    else:
        if mat[0, 0] < -mat[1, 1]:
            trace = 1.0 - mat[0, 0] - mat[1, 1] + mat[2, 2]
            s = 2.0 * np.sqrt(trace)
            if mat[0, 1] < mat[1, 0]:
                s = -s
            q[3] = 0.25 * s
            s = 1.0 / s
            q[0] = (mat[0, 1] - mat[1, 0]) * s
            q[1] = (mat[2, 0] + mat[0, 2]) * s
            q[2] = (mat[1, 2] + mat[2, 1]) * s
        else:
            trace = 1.0 + mat[0, 0] + mat[1, 1] + mat[2, 2]
            s = 2.0 * np.sqrt(trace)
            q[0] = 0.25 * s
            s = 1.0 / s
            q[1] = (mat[1, 2] - mat[2, 1]) * s
            q[2] = (mat[2, 0] - mat[0, 2]) * s
            q[3] = (mat[0, 1] - mat[1, 0]) * s
    
    return q

In [129]:
mat2quat_test_cases = [
            # Identity matrix (no rotation)
            np.eye(3, dtype=np.float64),
            
            # 90-degree rotations around principal axes
            np.array([[1., 0., 0.],
                     [0., 0., -1.],
                     [0., 1., 0.]], dtype=np.float64),  # 90° around x
            
            np.array([[0., 0., 1.],
                     [0., 1., 0.],
                     [-1., 0., 0.]], dtype=np.float64),  # 90° around y
            
            np.array([[0., -1., 0.],
                     [1., 0., 0.],
                     [0., 0., 1.]], dtype=np.float64),  # 90° around z
            
            # 45-degree rotations
            np.array([[1., 0., 0.],
                     [0., 0.7071068, -0.7071068],
                     [0., 0.7071068, 0.7071068]], dtype=np.float64),  # 45° around x
            
            np.array([[0.7071068, 0., 0.7071068],
                     [0., 1., 0.],
                     [-0.7071068, 0., 0.7071068]], dtype=np.float64),  # 45° around y
            
            np.array([[0.7071068, -0.7071068, 0.],
                     [0.7071068, 0.7071068, 0.],
                     [0., 0., 1.]], dtype=np.float64),  # 45° around z
            
            # Combined rotation
            np.array([[0.3536, -0.6124, 0.7071],
                     [0.866007, 0.500033, 0.],
                     [-0.353576, 0.612359, 0.707107]], dtype=np.float64),
        ]

In [130]:
for idx in range(7):
    print ("Current Implementation", mat2quat(mat2quat_test_cases[idx]))
    print ("JAX as it is", mat2quat_jax(mat2quat_test_cases[idx]))
    print ("From Blender", mat2quat_fast(mat2quat_test_cases[idx]))
    print ("-----------------------------")

Current Implementation [1. 0. 0. 0.]
JAX as it is [1. 0. 0. 0.]
From Blender [1. 0. 0. 0.]
-----------------------------
Current Implementation [ 0.70710678  0.70710678 -0.         -0.        ]
JAX as it is [ 0.70710678  0.70710678 -0.         -0.        ]
From Blender [ 0.70710678 -0.70710678  0.          0.        ]
-----------------------------
Current Implementation [ 0.70710678 -0.          0.70710678 -0.        ]
JAX as it is [ 0.70710678 -0.          0.70710678 -0.        ]
From Blender [ 0.70710678  0.         -0.70710678  0.        ]
-----------------------------
Current Implementation [0.70710678 0.         0.         0.70710678]
JAX as it is [0.70710678 0.         0.         0.70710678]
From Blender [ 0.70710678  0.          0.         -0.70710678]
-----------------------------
Current Implementation [ 0.92387953  0.38268343 -0.         -0.        ]
JAX as it is [ 0.97324899  0.22975292 -0.         -0.        ]
From Blender [ 0.92387954 -0.38268344  0.          0.        ]
-