In [16]:
import time 
import numpy as np
import jax.numpy as jp
import jax

import jax
jax.config.update('jax_enable_x64', True)

In [17]:
@jax.jit
def mat2quat_jax(mat):
    """ Convert Rotation Matrix to Quaternion using JAX """
    mat = jp.asarray(mat, dtype=jp.float64)
    assert mat.shape[-2:] == (3, 3), f"Invalid shape matrix {mat.shape}"

    Qxx, Qyx, Qzx, Qxy, Qyy, Qzy, Qxz, Qyz, Qzz = mat.ravel()

    # Construct the symmetric K matrix
    K = jp.array([
        [Qxx - Qyy - Qzz, 0,               0,               0              ],
        [Qyx + Qxy,       Qyy - Qxx - Qzz, 0,               0              ],
        [Qzx + Qxz,       Qzy + Qyz,       Qzz - Qxx - Qyy, 0              ],
        [Qyz - Qzy,       Qzx - Qxz,       Qxy - Qyx,       Qxx + Qyy + Qzz]
    ]) / 3.0

    # Compute eigenvalues and eigenvectors
    vals, vecs = jp.linalg.eigh(K)

    # Select eigenvector corresponding to the largest eigenvalue
    max_idx = jp.argmax(vals)
    q = vecs[[3, 0, 1, 2], max_idx]

    # Ensure quaternion has positive w (sign consistency)
    q = jp.where(q[0] < 0, -q, q)

    return q

@jax.jit
def mat2quat_fast_jax(mat):
    """ Convert a 3x3 normalized rotation matrix to a quaternion using JAX. """
    mat = jp.asarray(mat, dtype=jp.float64)
    assert mat.shape == (3, 3), f"Invalid shape matrix {mat.shape}"

    def case_1(mat):
        trace = 1.0 + mat[0, 0] - mat[1, 1] - mat[2, 2]
        s = 2.0 * jp.sqrt(trace)
        s = jp.where(mat[1, 2] < mat[2, 1], -s, s)
        q1 = 0.25 * s
        s = 1.0 / s
        q0 = (mat[1, 2] - mat[2, 1]) * s
        q2 = (mat[0, 1] + mat[1, 0]) * s
        q3 = (mat[2, 0] + mat[0, 2]) * s
        return jp.array([q0, q1, q2, q3])

    def case_2(mat):
        trace = 1.0 - mat[0, 0] + mat[1, 1] - mat[2, 2]
        s = 2.0 * jp.sqrt(trace)
        s = jp.where(mat[2, 0] < mat[0, 2], -s, s)
        q2 = 0.25 * s
        s = 1.0 / s
        q0 = (mat[2, 0] - mat[0, 2]) * s
        q1 = (mat[0, 1] + mat[1, 0]) * s
        q3 = (mat[1, 2] + mat[2, 1]) * s
        return jp.array([q0, q1, q2, q3])

    def case_3(mat):
        trace = 1.0 - mat[0, 0] - mat[1, 1] + mat[2, 2]
        s = 2.0 * jp.sqrt(trace)
        s = jp.where(mat[0, 1] < mat[1, 0], -s, s)
        q3 = 0.25 * s
        s = 1.0 / s
        q0 = (mat[0, 1] - mat[1, 0]) * s
        q1 = (mat[2, 0] + mat[0, 2]) * s
        q2 = (mat[1, 2] + mat[2, 1]) * s
        return jp.array([q0, q1, q2, q3])

    def case_4(mat):
        trace = 1.0 + mat[0, 0] + mat[1, 1] + mat[2, 2]
        s = 2.0 * jp.sqrt(trace)
        q0 = 0.25 * s
        s = 1.0 / s
        q1 = (mat[1, 2] - mat[2, 1]) * s
        q2 = (mat[2, 0] - mat[0, 2]) * s
        q3 = (mat[0, 1] - mat[1, 0]) * s
        return jp.array([q0, q1, q2, q3])

    # Conditional execution for efficiency
    q = jax.lax.cond(
        mat[2, 2] < 0.0,
        lambda mat: jax.lax.cond(
            mat[0, 0] > mat[1, 1], case_1, case_2, mat
        ),
        lambda mat: jax.lax.cond(
            mat[0, 0] < -mat[1, 1], case_3, case_4, mat
        ),
        mat
    )

    q = q.at[1:].set(-q[1:])
    return q


In [18]:
def mat2quat(mat):
    """Convert Rotation Matrix to Quaternion"""
    mat = np.asarray(mat, dtype=np.float64)
    assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat)

    Qxx, Qyx, Qzx, Qxy, Qyy, Qzy, Qxz, Qyz, Qzz = mat.flat
    # Fill only lower half of symmetric matrix
    K = np.array([
        [Qxx - Qyy - Qzz, 0,               0,               0              ],
        [Qyx + Qxy,       Qyy - Qxx - Qzz, 0,               0              ],
        [Qzx + Qxz,       Qzy + Qyz,       Qzz - Qxx - Qyy, 0              ],
        [Qyz - Qzy,       Qzx - Qxz,       Qxy - Qyx,       Qxx + Qyy + Qzz]]
        ) / 3.0
    # Use Hermitian eigenvectors, values for speed
    vals, vecs = np.linalg.eigh(K)
    # Select largest eigenvector, reorder to w,x,y,z quaternion
    q = vecs[[3, 0, 1, 2], np.argmax(vals)]
    # Prefer quaternion with positive w
    # (q * -1 corresponds to same rotation as q)
    if q[0] < 0:
        q *= -1
        
    return q

def mat2quat_fast(mat):
    """ Convert a 3x3 normalized rotation matrix to a quaternion, ensuring consistent signs. """
    mat = np.asarray(mat, dtype=np.float64)
    assert mat.shape == (3, 3), f"Invalid shape matrix {mat.shape}"

    q = np.zeros(4, dtype=np.float64)

    if mat[2, 2] < 0.0:
        if mat[0, 0] > mat[1, 1]:
            trace = 1.0 + mat[0, 0] - mat[1, 1] - mat[2, 2]
            s = 2.0 * np.sqrt(trace)
            if mat[1, 2] < mat[2, 1]:
                s = -s
            q[1] = 0.25 * s
            s = 1.0 / s
            q[0] = (mat[1, 2] - mat[2, 1]) * s
            q[2] = (mat[0, 1] + mat[1, 0]) * s
            q[3] = (mat[2, 0] + mat[0, 2]) * s
        else:
            trace = 1.0 - mat[0, 0] + mat[1, 1] - mat[2, 2]
            s = 2.0 * np.sqrt(trace)
            if mat[2, 0] < mat[0, 2]:
                s = -s
            q[2] = 0.25 * s
            s = 1.0 / s
            q[0] = (mat[2, 0] - mat[0, 2]) * s
            q[1] = (mat[0, 1] + mat[1, 0]) * s
            q[3] = (mat[1, 2] + mat[2, 1]) * s
    else:
        if mat[0, 0] < -mat[1, 1]:
            trace = 1.0 - mat[0, 0] - mat[1, 1] + mat[2, 2]
            s = 2.0 * np.sqrt(trace)
            if mat[0, 1] < mat[1, 0]:
                s = -s
            q[3] = 0.25 * s
            s = 1.0 / s
            q[0] = (mat[0, 1] - mat[1, 0]) * s
            q[1] = (mat[2, 0] + mat[0, 2]) * s
            q[2] = (mat[1, 2] + mat[2, 1]) * s
        else:
            trace = 1.0 + mat[0, 0] + mat[1, 1] + mat[2, 2]
            s = 2.0 * np.sqrt(trace)
            q[0] = 0.25 * s
            s = 1.0 / s
            q[1] = (mat[1, 2] - mat[2, 1]) * s
            q[2] = (mat[2, 0] - mat[0, 2]) * s
            q[3] = (mat[0, 1] - mat[1, 0]) * s

    q[1:] = q[1:] * -1
    return q

In [19]:
mat2quat_test_cases = [
            # Identity matrix (no rotation)
            np.eye(3, dtype=np.float64),
            
            # 90-degree rotations around principal axes
            np.array([[1., 0., 0.],
                     [0., 0., -1.],
                     [0., 1., 0.]], dtype=np.float64),  # 90° around x
            
            np.array([[0., 0., 1.],
                     [0., 1., 0.],
                     [-1., 0., 0.]], dtype=np.float64),  # 90° around y
            
            np.array([[0., -1., 0.],
                     [1., 0., 0.],
                     [0., 0., 1.]], dtype=np.float64),  # 90° around z
            
            # 45-degree rotations
            np.array([[1., 0., 0.],
                     [0., 0.7071068, -0.7071068],
                     [0., 0.7071068, 0.7071068]], dtype=np.float64),  # 45° around x
            
            np.array([[0.7071068, 0., 0.7071068],
                     [0., 1., 0.],
                     [-0.7071068, 0., 0.7071068]], dtype=np.float64),  # 45° around y
            
            np.array([[0.7071068, -0.7071068, 0.],
                     [0.7071068, 0.7071068, 0.],
                     [0., 0., 1.]], dtype=np.float64),  # 45° around z
            
            # Combined rotation
            np.array([[0.3536, -0.6124, 0.7071],
                     [0.866007, 0.500033, 0.],
                     [-0.353576, 0.612359, 0.707107]], dtype=np.float64),
        ]

In [20]:
for idx in range(7):
    print ("Current Implementation", mat2quat(mat2quat_test_cases[idx]))
    print ("JAX as it is", mat2quat_jax(jp.array(mat2quat_test_cases[idx])))
    print ("From Blender", mat2quat_fast(mat2quat_test_cases[idx]))
    print ("JAX From Blender", mat2quat_fast_jax(jp.array(mat2quat_test_cases[idx])))
    print ("-----------------------------")

Current Implementation [1. 0. 0. 0.]
JAX as it is [1. 0. 0. 0.]
From Blender [ 1. -0. -0. -0.]
JAX From Blender [ 1. -0. -0. -0.]
-----------------------------
Current Implementation [ 0.70710678  0.70710678 -0.         -0.        ]
JAX as it is [ 0.70710678  0.70710678 -0.         -0.        ]
From Blender [ 0.70710678  0.70710678 -0.         -0.        ]
JAX From Blender [ 0.70710678  0.70710678 -0.         -0.        ]
-----------------------------
Current Implementation [ 0.70710678 -0.          0.70710678 -0.        ]
JAX as it is [ 0.70710678 -0.          0.70710678 -0.        ]
From Blender [ 0.70710678 -0.          0.70710678 -0.        ]
JAX From Blender [ 0.70710678 -0.          0.70710678 -0.        ]
-----------------------------
Current Implementation [0.70710678 0.         0.         0.70710678]
JAX as it is [0.70710678 0.         0.         0.70710678]
From Blender [ 0.70710678 -0.         -0.          0.70710678]
JAX From Blender [ 0.70710678 -0.         -0.          0.

In [21]:
# Generate 100 random rotation matrices using the QR decomposition method
def random_rotation_matrices(n=100):
    """Generate n random 3x3 rotation matrices."""
    matrices = []
    for _ in range(n):
        rand_mat = np.random.randn(3, 3)  # Random 3x3 matrix
        q, r = np.linalg.qr(rand_mat)  # QR decomposition
        q *= np.sign(np.linalg.det(q))  # Ensure determinant is +1 for valid rotation
        matrices.append(q)
    return np.array(matrices)

In [27]:


# Generate test matrices
test_matrices = random_rotation_matrices(10000)

# Benchmark NumPy Fast version
start_fast_numpy = time.time()
numpy_fast_results = np.array([mat2quat_fast(mat) for mat in test_matrices])
time_fast_numpy = time.time() - start_fast_numpy

# Benchmark JAX Fast version
start_fast_jax = time.time()
jax_fast_results = jax.vmap(mat2quat_fast_jax)(jp.array(test_matrices))
time_fast_jax = time.time() - start_fast_jax

# Benchmark NumPy version
start_numpy = time.time()
numpy_results = np.array([mat2quat(mat) for mat in test_matrices])
time_numpy = time.time() - start_numpy

# Benchmark JAX  version
start_jax = time.time()
jax_results = jax.vmap(mat2quat_jax)(jp.array(test_matrices))
time_jax = time.time() - start_jax

# Function to count mismatches
def count_mismatches(arr1, arr2, atol=1e-3):
    return np.sum(~np.isclose(arr1, arr2, atol=atol))

# Count mismatches between different implementations
mismatches_numpy_slow_numpy_fast = count_mismatches(numpy_results, numpy_fast_results)
mismatches_numpy_slow_jax_slow = count_mismatches(numpy_results, np.array(jax_results))
mismatches_numpy_slow_jax_fast = count_mismatches(numpy_results, np.array(jax_fast_results))

print("numpy fast time:", time_numpy)
print("jax fast time:", time_jax)
print("numpy time:", time_numpy)
print("jax time:", time_jax)

print("comparison_numpy_slow_numpy_fast:", mismatches_numpy_slow_numpy_fast)
print("comparison_numpy_slow_jax_slow:", mismatches_numpy_slow_jax_slow)
print("comparison_numpy_slow_jax_fast:", mismatches_numpy_slow_jax_fast)


numpy fast time: 0.09654593467712402
jax fast time: 0.03492927551269531
numpy time: 0.09654593467712402
jax time: 0.03492927551269531
comparison_numpy_slow_numpy_fast: 0
comparison_numpy_slow_jax_slow: 39747
comparison_numpy_slow_jax_fast: 0
