In [1]:
import jax.numpy as jnp
import numpy as np

import jax

In [3]:
t = jnp.array([1, 2, 3])
x, y, z = t
print x
print y
print z

1
2
3


In [4]:
def make_pose(translation, rotation, axis):
    cq = jnp.cos(0.5 * rotation)
    sq = jnp.sin(0.5 * rotation)
    return jnp.array([translation[0], translation[1], translation[2], cq, sq * axis[0], sq * axis[1], sq * axis[2]])

def multiply(p_left, p_right):
    ltx, lty, ltz, lqw, lqx, lqy, lqz = p_left
    rtx, rty, rtz, rqw, rqx, rqy, rqz = p_right
    tw = -lqx*rtx - lqy*rty - lqz*rtz
    tx = lqw*rtx + lqy*rtz - lqz*rty
    ty = lqw*rty - lqx*rtz + lqz*rtx
    tz = lqw*rtz + lqx*rty - lqy*rtx
    
    tx =-tw*lqx + tx*lqw - ty*lqz + tz*lqy + ltx
    ty =-tw*lqy + tx*lqz + ty*lqw - tz*lqx + lty
    tz =-tw*lqz - tx*lqy + ty*lqx + tz*lqw + ltz
    
    qw = lqw*rqw - lqx*rqx - lqy*rqy - lqz*rqz
    qx = lqw*rqx + lqx*rqw + lqy*rqz - lqz*rqy
    qy = lqw*rqy - lqx*rqz + lqy*rqw + lqz*rqx
    qz = lqw*rqz + lqx*rqy - lqy*rqx + lqz*rqw
    return jnp.array([tx, ty, tz, qw, qx, qy, qz])

In [5]:
def SmallTree(q):
    t0 = make_pose([q[0], 0, 0], 0, [1, 0, 0])
    r0 = make_pose([0, 0, 0], q[1], [0, 0, 1])
    t1 = make_pose([0, q[2], q[2]], 0, [1, 0, 0])
    m = 0.5 * np.sqrt(2.)
    r2 = make_pose([0, 0, 0], q[3], [m, 0, m])

    p0 = multiply(r0, t0)
    p1 = multiply(p0, t1)
    p2 = multiply(p1, r2)
    return p2

v = [0.1, 0.2, 0.3, 0.1]
SmallTree(v)

DeviceArray([0.03840585, 0.32439619, 0.30000001, 0.99023253, 0.03516405,
             0.00352817, 0.1348727 ], dtype=float32)

In [29]:
jax.vmap(SmallTree)(jnp.array([v,v,v]))

DeviceArray([[0.03840585, 0.32439619, 0.30000001, 0.99023253, 0.03516405,
              0.00352817, 0.1348727 ],
             [0.03840585, 0.32439619, 0.30000001, 0.99023253, 0.03516405,
              0.00352817, 0.1348727 ],
             [0.03840585, 0.32439619, 0.30000001, 0.99023253, 0.03516405,
              0.00352817, 0.1348727 ]], dtype=float32)

In [8]:
m = jax.jit(jax.jacobian(SmallTree))(v)
np.array(m).shape

(4, 7)