Advanced Automatic Differentiation in JAX

https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html

In [None]:
import jax
import jax.numpy as jnp

In [None]:
def f(x):
    """Evaluate the cubic polynomial x^3 + 2x^2 - 3x + 1."""
    return x**3 + 2*x**2 - 3*x + 1


# jax.grad maps a scalar function to its derivative function, so applying it
# repeatedly yields successively higher-order derivatives of f.
dfdx = jax.grad(f)            # first derivative
d2fdx2 = jax.grad(dfdx)       # second derivative
d3fdx3 = jax.grad(d2fdx2)     # third derivative
d4fdx4 = jax.grad(d3fdx3)     # fourth derivative

In [None]:
def hessian(f):
    """Return a function that computes the Hessian of the scalar function ``f``.

    The Hessian is the matrix of second-order derivatives, i.e. the Jacobian
    of the gradient. The gradient is taken with ``jax.grad`` and its Jacobian
    with forward-mode ``jax.jacfwd``.
    """
    gradient_fn = jax.grad(f)
    return jax.jacfwd(gradient_fn)

In [None]:
def f(x):
    """Dot product of ``x`` with itself (squared Euclidean norm for a vector)."""
    return jnp.dot(x, x)


# Float32 input vector at which the Hessian is evaluated.
x = jnp.array([1, 2, 3], dtype=jnp.float32)
hessian(f)(x)

In [None]:
# MAML loss function
def loss(params, data):
    """Mean squared error of a linear model.

    The previous version ignored ``params`` entirely, so its gradient with
    respect to ``params`` was identically zero and the inner update in
    ``meta_loss`` did nothing. The loss now depends on ``params`` through a
    linear prediction.

    Args:
        params: Weights of the linear model.
        data: Indexable pair where ``data[0]`` holds the inputs and
            ``data[1]`` the targets. Predictions are
            ``jnp.dot(data[0], params)``.

    Returns:
        Scalar mean of the squared differences between predictions and targets.
    """
    predictions = jnp.dot(data[0], params)
    return jnp.mean((predictions - data[1]) ** 2)


def meta_loss(params, data, lr=1e-5):
    """Loss evaluated after one gradient-descent step on ``loss``.

    Args:
        params: Parameters before the inner update.
        data: Data passed unchanged to ``loss`` for both the inner gradient
            and the final evaluation.
        lr: Step size of the inner gradient-descent update.

    Returns:
        ``loss`` at ``params - lr * grad(loss)(params, data)``.
    """
    grads = jax.grad(loss)(params, data)
    return loss(params - lr * grads, data)


# Concrete example inputs: differentiating with ``None`` placeholders fails as
# soon as ``loss`` indexes into ``data``.
params = jnp.array([0.5, -0.5])
data = (
    jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),  # inputs
    jnp.array([1.0, 2.0, 3.0]),                       # targets
)
# Differentiating meta_loss differentiates through the inner gradient step.
meta_grads = jax.grad(meta_loss)(params, data)

In [None]:
def value_fn(theta, state):
    """Linear value function: dot product of parameters ``theta`` and ``state``."""
    return jnp.dot(theta, state)


# Parameters of the value function model.
theta = jnp.array([0.1, -0.1, 0.0])

# A single transition: state at time t-1, state at time t, and the reward.
s_tm1 = jnp.array([1.0, 2.0, -1.0])
s_t = jnp.array([2.0, 1.0, 0.0])
r_t = jnp.array(1.0)

def td_loss(theta, s_tm1, r_t, s_t):
    """Squared temporal-difference error used as the value-function loss.

    The bootstrap target ``r_t + value_fn(theta, s_t)`` is wrapped in
    ``jax.lax.stop_gradient``, so gradients with respect to ``theta`` flow
    only through the prediction ``value_fn(theta, s_tm1)``.
    """
    # Treat the target as a constant during differentiation.
    bootstrap_target = jax.lax.stop_gradient(r_t + value_fn(theta, s_t))
    td_error = bootstrap_target - value_fn(theta, s_tm1)
    return td_error ** 2

def update(theta, s_tm1, r_t, s_t, lr=1e-3):
    """Take one gradient-descent step on ``td_loss`` with step size ``lr``.

    Returns:
        A pair ``(updated_theta, theta)``: the parameters after the step,
        followed by the unmodified input parameters.
    """
    grad_fn = jax.grad(td_loss)
    step = lr * grad_fn(theta, s_tm1, r_t, s_t)
    return theta - step, theta

In [None]:
# Build a batch of two identical transitions by stacking along a new leading axis.
batch_of_s_tm1 = jnp.stack([s_tm1, s_tm1])
print(f's_tm1 {s_tm1.shape}, batched {batch_of_s_tm1.shape}')
batch_of_s_t = jnp.stack([s_t, s_t])
batch_of_r_t = jnp.stack([r_t, r_t])

# Per-example gradients: jax.grad on its own does not handle a batch
# dimension, so it is vectorised with vmap. theta is shared across examples
# (in_axes None); the remaining arguments are mapped over their leading axis.
per_example_grads = jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0))

per_example_grads(theta, batch_of_s_tm1, batch_of_r_t, batch_of_s_t)

In [None]:
# Gradient of the TD loss with respect to its first argument (theta),
# evaluated on the single, unbatched transition.
dtd_lossdtheta = jax.grad(td_loss, argnums=0)
dtd_lossdtheta(theta, s_tm1, r_t, s_t)

In [None]:
no_jit = per_example_grads
with_jit = jax.jit(per_example_grads)

%timeit no_jit(theta, batch_of_s_tm1, batch_of_r_t, batch_of_s_t)
%timeit with_jit(theta, batch_of_s_tm1, batch_of_r_t, batch_of_s_t)

In [None]:
# You can use stop_gradient to shield the auto-diff from
# non-differentiable functions

def f(x):
    """Round ``x`` to the nearest integer value (the non-differentiable example)."""
    rounded = jnp.round(x)
    return rounded

def straight_through_f(x):
    """Apply ``f`` to ``x`` while letting gradients pass through unchanged.

    ``x - stop_gradient(x)`` contributes zero to the value but carries the
    dependence on ``x``; ``stop_gradient(f(x))`` supplies the value of
    ``f(x)`` while blocking differentiation through ``f``.
    """
    zero_valued_passthrough = x - jax.lax.stop_gradient(x)
    frozen_output = jax.lax.stop_gradient(f(x))
    return zero_valued_passthrough + frozen_output