In [1]:
import jax.numpy as jnp
import numpy as np

import jax

In [2]:
def rotationX(q):
    """Return the 4x4 homogeneous transform for a rotation of ``q`` about x.

    Args:
        q: rotation angle in radians (Python float, 0-d array, or JAX tracer,
            so the function works under ``jax.jit`` / ``jax.jacobian``).
            Expected to be a scalar: the matrix is assembled from a nested
            list of entries, so a batched ``q`` would give ragged rows
            (use ``jax.vmap`` for batches).

    Returns:
        A (4, 4) ``jnp`` array with the rotation in the upper-left 3x3 block,
        a zero translation column, and ``[0, 0, 0, 1]`` as the last row.
    """
    c, s = jnp.cos(q), jnp.sin(q)
    # Start from a 4x4 identity (nested Python floats) and overwrite only the
    # entries the x-rotation touches.
    rows = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
    rows[1][1], rows[1][2] = c, -s
    rows[2][1], rows[2][2] = s, c
    return jnp.array(rows)

def rotationY(q):
    """Return the 4x4 homogeneous transform for a rotation of ``q`` about y.

    Args:
        q: rotation angle in radians (Python float, 0-d array, or JAX tracer,
            so the function works under ``jax.jit`` / ``jax.jacobian``).
            Expected to be a scalar: the matrix is assembled from a nested
            list of entries, so a batched ``q`` would give ragged rows
            (use ``jax.vmap`` for batches).

    Returns:
        A (4, 4) ``jnp`` array with the rotation in the upper-left 3x3 block,
        a zero translation column, and ``[0, 0, 0, 1]`` as the last row.
    """
    c, s = jnp.cos(q), jnp.sin(q)
    # Start from a 4x4 identity (nested Python floats) and overwrite only the
    # entries the y-rotation touches; note the sign flip relative to x/z.
    rows = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
    rows[0][0], rows[0][2] = c, s
    rows[2][0], rows[2][2] = -s, c
    return jnp.array(rows)

def rotationZ(q):
    """Return the 4x4 homogeneous transform for a rotation of ``q`` about z.

    Args:
        q: rotation angle in radians (Python float, 0-d array, or JAX tracer,
            so the function works under ``jax.jit`` / ``jax.jacobian``).
            Expected to be a scalar: the matrix is assembled from a nested
            list of entries, so a batched ``q`` would give ragged rows
            (use ``jax.vmap`` for batches).

    Returns:
        A (4, 4) ``jnp`` array with the rotation in the upper-left 3x3 block,
        a zero translation column, and ``[0, 0, 0, 1]`` as the last row.
    """
    c, s = jnp.cos(q), jnp.sin(q)
    # Start from a 4x4 identity (nested Python floats) and overwrite only the
    # entries the z-rotation touches.
    rows = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
    rows[0][0], rows[0][1] = c, -s
    rows[1][0], rows[1][1] = s, c
    return jnp.array(rows)

def translationX(q):
    """Return the 4x4 homogeneous transform for a translation of ``q`` along x.

    Args:
        q: translation offset (scalar Python float, 0-d array, or JAX tracer).
            It is placed into the matrix unscaled.

    Returns:
        A (4, 4) ``jnp`` array equal to the identity except that row 0,
        column 3 (the x entry of the translation column) holds ``q``.
    """
    rows = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
    rows[0][3] = q  # translation lives in the last column
    return jnp.array(rows)

def translationY(q):
    """Return the 4x4 homogeneous transform for a translation of ``q`` along y.

    Args:
        q: translation offset (scalar Python float, 0-d array, or JAX tracer).
            It is placed into the matrix unscaled.

    Returns:
        A (4, 4) ``jnp`` array equal to the identity except that row 1,
        column 3 (the y entry of the translation column) holds ``q``.
    """
    rows = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
    rows[1][3] = q  # translation lives in the last column
    return jnp.array(rows)

def translationZ(q):
    """Return the 4x4 homogeneous transform for a translation of ``q`` along z.

    Args:
        q: translation offset (scalar Python float, 0-d array, or JAX tracer).
            It is placed into the matrix unscaled.

    Returns:
        A (4, 4) ``jnp`` array equal to the identity except that row 2,
        column 3 (the z entry of the translation column) holds ``q``.
    """
    rows = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
    rows[2][3] = q  # translation lives in the last column
    return jnp.array(rows)

In [3]:
# `jax.grad` is only defined for scalar-output functions. `rotationX` returns a
# 4x4 matrix, so `jax.grad(rotationX)(q)` would raise TypeError when called.
# `jax.jacobian` instead differentiates every matrix entry with respect to q.
parXparQ = jax.jacobian(rotationX)

In [4]:
# Sanity check: at q = 0 the rotation reduces to the 4x4 identity.
rotationX(0.0)



DeviceArray([[ 1.,  0.,  0.,  0.],
             [ 0.,  1., -0.,  0.],
             [ 0.,  0.,  1.,  0.],
             [ 0.,  0.,  0.,  1.]], dtype=float32)

In [5]:
# Jacobian of the matrix-valued rotationX with respect to its first (and only)
# argument; calling it returns a 4x4 array of d/dq entries.
parXparQ = jax.jacobian(rotationX, argnums=0)

In [6]:
# d(rotationX)/dq at q = 0: only the (1, 2) and (2, 1) entries are nonzero
# (-1 and +1), the derivatives of -sin(q) and sin(q) at zero.
parXparQ(0.0)

DeviceArray([[ 0.,  0.,  0.,  0.],
             [ 0.,  0., -1.,  0.],
             [ 0.,  1.,  0.,  0.],
             [ 0.,  0.,  0.,  0.]], dtype=float32)

In [21]:
def x(q):
    """Compose a rotation about x with a translation along y.

    Args:
        q: indexable pair of joint values. ``q[0]`` is the angle (radians)
            passed to ``rotationX`` and ``q[1]`` is the offset passed to
            ``translationY``. A two-element Python list works; JAX transforms
            treat it as a pytree of two scalars.

    Returns:
        The (4, 4) matrix product ``rotationX(q[0]) @ translationY(q[1])``.
    """
    # A `def` rather than a lambda bound to a name (PEP 8, E731), so the
    # function has a real __name__ and docstring in tracebacks and help().
    return jnp.dot(rotationX(q[0]), translationY(q[1]))

In [22]:
# JIT-compiled version of `x`; jax.jit traces `x` on the first call and reuses
# the compiled computation for later calls with the same input structure.
xv = jax.jit(x)

In [23]:
# q = [angle, offset]: rotation of 1.5 rad about x composed with a 1.0 y-offset.
xv([1.5, 1.0])

DeviceArray([[ 1.       ,  0.       ,  0.       ,  0.       ],
             [ 0.       ,  0.0707372, -0.997495 ,  0.0707372],
             [ 0.       ,  0.997495 ,  0.0707372,  0.997495 ],
             [ 0.       ,  0.       ,  0.       ,  1.       ]],
            dtype=float32)

In [24]:
# Jit-compiled Jacobian of `x` with respect to its list argument. Called with a
# two-element list, it returns a list of two 4x4 arrays: dx/dq[0] and dx/dq[1].
dx = jax.jit(jax.jacobian(x))

In [25]:
# Evaluate both partial derivatives of `x` at q = [1.5, 1.0].
dx([1.5, 1.0])

[DeviceArray([[ 0.       ,  0.       ,  0.       ,  0.       ],
              [ 0.       , -0.997495 , -0.0707372, -0.997495 ],
              [ 0.       ,  0.0707372, -0.997495 ,  0.0707372],
              [ 0.       ,  0.       ,  0.       ,  0.       ]],
             dtype=float32),
 DeviceArray([[0.       , 0.       , 0.       , 0.       ],
              [0.       , 0.       , 0.       , 0.0707372],
              [0.       , 0.       , 0.       , 0.997495 ],
              [0.       , 0.       , 0.       , 0.       ]], dtype=float32)]

In [28]:
# Forward-mode directional derivative: returns (x(q), J(q) . dq) for
# q = [1.5, 1.0] and tangent dq = [1.0, 1.0]. The tangent output is therefore
# the sum of the two Jacobian slices returned by dx at the same point.
jax.jvp(x, ([1.5, 1.0],), ([1.0, 1.0],))

(DeviceArray([[ 1.       ,  0.       ,  0.       ,  0.       ],
              [ 0.       ,  0.0707372, -0.997495 ,  0.0707372],
              [ 0.       ,  0.997495 ,  0.0707372,  0.997495 ],
              [ 0.       ,  0.       ,  0.       ,  1.       ]],
             dtype=float32),
 DeviceArray([[ 0.       ,  0.       ,  0.       ,  0.       ],
              [ 0.       , -0.997495 , -0.0707372, -0.92675781],
              [ 0.       ,  0.0707372, -0.997495 ,  1.06823218],
              [ 0.       ,  0.       ,  0.       ,  0.       ]],
             dtype=float32))