In [2]:
import jax.numpy as jnp
from jax import lax, jit
import numpy as np
from scipy.spatial.transform import Rotation

In [92]:
def hat(omg):
    """Map a 3-vector to its 3x3 skew-symmetric (so(3)) matrix.

    ``hat(omg) @ v`` equals the cross product ``omg x v``; ``vee`` undoes this map.
    """
    wx, wy, wz = omg[0], omg[1], omg[2]
    return jnp.array([[0, -wz, wy],
                      [wz, 0, -wx],
                      [-wy, wx, 0]])
def vee(so3mat):
    """Inverse of ``hat``: read the 3-vector back out of a skew-symmetric matrix.

    Only entries (2,1), (0,2) and (1,0) are read; skew-symmetry is not checked.
    """
    components = (so3mat[2, 1], so3mat[0, 2], so3mat[1, 0])
    return jnp.array(components)
    
@jit
def ec(R):
    """Exponential coordinates of a 3x3 rotation matrix (vee of the SO(3) matrix log).

    Returns the 3-vector ``theta * omg`` (unit axis ``omg``, angle ``theta`` in
    ``[0, pi]``), i.e. ``expm(hat(ec(R))) == R`` up to round-off for a proper rotation.

    Three regimes are dispatched with ``lax.switch`` on ``(trace(R) - 1) / 2``:
      * ``>= 1``  -> identity (incl. round-off above 1): zero vector;
      * ``<= -1`` -> rotation by pi: axis recovered from the symmetric part of R;
      * otherwise -> ``theta / (2 sin theta) * vee(R - R^T)``.

    The function is jit-compiled; ``R`` may be any array-like 3x3 matrix.
    """
    def zero_angle(acosinput, R):
        # Use the same float dtype as the other branches so lax.switch accepts them.
        return jnp.zeros(3, dtype=acosinput.dtype)

    def pi_angle(acosinput, R):
        # At theta = pi, R = 2 omg omg^T - I, so column k of (R + I) is 2 omg_k omg and
        # 1 + R[k, k] = 2 omg_k**2. Picking the largest diagonal entry guarantees
        # 1 + R[k, k] >= 2/3 (the omg_k**2 sum to 1), so the division is well
        # conditioned. A fixed "1 + R[k, k] is nonzero" tolerance (1e-10) sits below
        # float32 round-off and can divide by noise or sqrt a tiny negative -> NaN.
        k = jnp.argmax(jnp.diag(R))
        col = R[:, k] + jnp.eye(3, dtype=R.dtype)[:, k]
        return jnp.pi * col / jnp.sqrt(2 * (1 + R[k, k]))

    def normal_case(acosinput, R):
        angle = jnp.arccos(acosinput)
        return angle / 2. / jnp.sin(angle) * vee(R - R.T)

    acosinput = (jnp.trace(R) - 1.) / 2.0
    is_zero_angle = acosinput >= 1.
    is_pi_angle = acosinput <= -1.
    # Branch index: 0 = general case, 1 = zero angle, 2 = pi angle (never both 1 and 2).
    cond = (is_zero_angle + is_pi_angle*2).astype(int)
    return lax.switch(cond, (normal_case, zero_angle, pi_angle), acosinput, R)

In [25]:
from jax import lax
def mat2rpy(R):
    """Roll, pitch, yaw of a rotation matrix, for R = Rz(yaw) @ Ry(pitch) @ Rx(roll).

    This is the matrix convention written out in ``rpy2mat``. Pitch lies in
    ``[-pi/2, pi/2]`` because its cosine is taken as non-negative. Near gimbal lock
    (``|cos(pitch)| < 1e-6``) roll and yaw are coupled, so yaw is fixed to 0.

    Returns:
        ``jnp.array([roll, pitch, yaw])``.
    """
    cos_pitch = jnp.sqrt(R[2, 1]**2 + R[2, 2]**2)  # |cos(pitch)|

    def near_gimbal_lock(R, cos_pitch):
        # With yaw = 0: R[1, 1] = cos(roll), R[1, 2] = -sin(roll).
        roll = jnp.arctan2(-R[1, 2], R[1, 1])
        pitch = jnp.arctan2(-R[2, 0], cos_pitch)
        return jnp.array([roll, pitch, 0.])

    def regular(R, cos_pitch):
        roll = jnp.arctan2(R[2, 1], R[2, 2])
        pitch = jnp.arctan2(-R[2, 0], cos_pitch)
        yaw = jnp.arctan2(R[1, 0], R[0, 0])
        return jnp.array([roll, pitch, yaw])

    return lax.cond(cos_pitch < 1e-6, near_gimbal_lock, regular, R, cos_pitch)

def rpy2mat(rpy):
    """Rotation matrix from roll-pitch-yaw angles.

    Args:
        rpy: ``[roll, pitch, yaw]`` in radians, e.g. ``jnp.array([r, p, y])``.

    Returns:
        3x3 matrix ``Rz(yaw) @ Ry(pitch) @ Rx(roll)``, written out in closed form.
    """
    roll, pitch, yaw = rpy
    cr, sr = jnp.cos(roll), jnp.sin(roll)
    cp, sp = jnp.cos(pitch), jnp.sin(pitch)
    cw, sw = jnp.cos(yaw), jnp.sin(yaw)
    row0 = [cp*cw, sr*sp*cw - sw*cr, sr*sw + sp*cr*cw]
    row1 = [sw*cp, sr*sp*sw + cr*cw, -sr*cw + sp*sw*cr]
    row2 = [-sp, sr*cp, cr*cp]
    return jnp.array([row0, row1, row2])

In [27]:
# Random 3x3 rotation matrix (NumPy array) for the round-trip check in the next cell.
# NOTE(review): unseeded, so the printed matrix changes on every run; consider
# Rotation.random(random_state=SEED) for reproducibility.
R = Rotation.random().as_matrix()
R

array([[ 0.93302866,  0.35612203,  0.05132852],
       [ 0.18378387, -0.34906532, -0.91889983],
       [-0.30932347,  0.86679323, -0.39113743]])

In [28]:
# Round trip R -> (roll, pitch, yaw) -> R; should match R up to float32 round-off.
rpy2mat(mat2rpy(R))

Array([[ 0.93302864,  0.35612202,  0.05132856],
       [ 0.18378389, -0.34906527, -0.9188999 ],
       [-0.30932346,  0.8667933 , -0.39113742]], dtype=float32)

In [22]:
# SciPy reference: lowercase "xyz" means extrinsic x, y, z rotations, i.e.
# R = Rz(yaw) Ry(pitch) Rx(roll), the same convention as rpy2mat / mat2rpy.
# NOTE(review): this cell ran as In [22], before R was regenerated in In [27], so the
# displayed angles appear to belong to a different R; re-run to compare with mat2rpy(R).
Rotation.from_matrix(R).as_euler("xyz")

array([ 1.50506082, -0.84967868, -2.73464721])

In [35]:
# Test cases: leave exactly one `R = ...` line active.
#R = jnp.eye(3) #zero angle
# Only the first ("z") angle is nonzero, so this is a rotation by pi about z.
R = Rotation.from_euler("zxy", [np.pi,0,0]).as_matrix() # pi angle case
# R = jnp.array(Rotation.random().as_matrix()) #random angle
R

array([[-1.0000000e+00, -1.2246468e-16,  0.0000000e+00],
       [ 1.2246468e-16, -1.0000000e+00,  0.0000000e+00],
       [ 0.0000000e+00,  0.0000000e+00,  1.0000000e+00]])

In [36]:
# Round trip on the pi-about-z rotation. The ~8.74e-8 off-diagonal terms are float32
# round-off (sin of float32 pi is about -8.74e-8), not an error in mat2rpy/rpy2mat.
rpy2mat(mat2rpy(R))

Array([[-1.000000e+00,  8.742278e-08,  0.000000e+00],
       [-8.742278e-08, -1.000000e+00,  0.000000e+00],
       [ 0.000000e+00,  0.000000e+00,  1.000000e+00]], dtype=float32)