In [2]:
import jax 
import jax.numpy as jnp 
from jax import grad, jit, vmap 
from jax import random

key = random.key(0)

#### Higher-Order Derivatives 

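JAX provides two transformations for computing the Jacobian of a function, ``jax.jacfwd()`` and ``jax.jacrev()``, corresponding to forward-mode and reverse-mode automatic differentiation. Because JAX transformations compose, nesting them yields higher-order derivatives such as the Hessian.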

In [None]:
def hessian(f):
    return jax.jacfwd(jax.grad(f))

def f(x):
    return jnp.dot(x, x)
hessian(f)(jnp.array([1., 2., 3.]))

Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)
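For a scalar function of a scalar input, higher-order derivatives can be obtained simply by applying ``jax.grad`` repeatedly. The snippet below is a minimal sketch; the polynomial ``poly`` and the evaluation point ``x0`` are illustrative assumptions, not part of the original notebook.

```python
import jax

# Hypothetical example function: poly(x) = x^3 + 2x^2 - 3x + 1
def poly(x):
    return x**3 + 2 * x**2 - 3 * x + 1

dfdx = jax.grad(poly)     # analytically 3x^2 + 4x - 3
d2fdx = jax.grad(dfdx)    # analytically 6x + 4
d3fdx = jax.grad(d2fdx)   # analytically 6

x0 = 1.0  # assumed evaluation point
first, second, third = dfdx(x0), d2fdx(x0), d3fdx(x0)
print(first, second, third)
```

Each call to ``jax.grad`` returns an ordinary Python function, which is why the transformations can be stacked this way.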

JAX also makes it easy to differentiate through gradient updates, as required by __Model-Agnostic Meta-Learning (MAML)__.

In [None]:
def meta_loss_fn(params, data):
    """compute the loss after one step of SGD"""

    grads = jax.grad(loss_fn)(params, data)
    return loss_fn(params - lr * grads, data)

meta_grads = jax.grad(meta_loss_fn)(params, data)
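The cell above leaves ``loss_fn``, ``params``, ``data`` and ``lr`` undefined. The following is a self-contained sketch of the same idea, assuming a toy linear-regression loss, a hand-picked learning rate and a tiny made-up dataset; all of these are placeholders chosen only so the code runs.

```python
import jax
import jax.numpy as jnp

lr = 0.1  # assumed inner-loop learning rate

def loss_fn(params, data):
    # Assumed toy model: linear regression with mean squared error.
    x, y = data
    return jnp.mean((x @ params - y) ** 2)

def meta_loss_fn(params, data):
    """Compute the loss after one step of SGD."""
    grads = jax.grad(loss_fn)(params, data)
    return loss_fn(params - lr * grads, data)

# Made-up data and initial parameters, for illustration only.
x = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
y = jnp.array([1., 2., 3.])
params = jnp.zeros(2)

# Differentiates through the inner gradient step (a second-order computation).
meta_grads = jax.grad(meta_loss_fn)(params, (x, y))
print(meta_grads)
```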

#### Stopping Gradients

Sometimes, you might want to avoid backpropagating gradients through some subset of the computational graph.

One example is the __temporal-difference TD(0)__ update, which is used to learn to estimate the value of a state in an environment from experience of interacting with that environment.

The TD(0) update to the network parameters is:
$$\Delta\theta = (r_{t} + v_{\theta}(s_{t}) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1})$$

This update is not the gradient of any loss function. However, it can be written as the gradient of the pseudo-loss function

$$L_{\theta} = -\frac{1}{2}\left[r_{t} + v_{\theta}(s_{t}) - v_{\theta}(s_{t-1})\right]^2$$

if the __dependency of the target $r_t + v_{θ}(s_t)$ on the parameter $\theta$ is ignored__.

We can use ``jax.lax.stop_gradient()`` to force JAX to ignore the dependency of the target on $\theta$.

In [4]:
## Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])

#An example transition 
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

In [5]:
def td_loss(theta, s_tm1, r_t, s_t):

    v_tm1 = value_fn(theta, s_tm1)
    target = r_t + value_fn(theta, s_t)
    return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta

Array([ 1.2,  2.4, -1.2], dtype=float32)

In [6]:
#cross-checking
s_grad = jax.grad(value_fn)(theta, s_tm1)
delta_theta = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad
delta_theta

Array([ 1.2,  2.4, -1.2], dtype=float32)
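To see why ``stop_gradient`` matters here, it may help to differentiate the same pseudo-loss without it. The sketch below repeats the definitions from the cells above so that it runs on its own; ``naive_td_loss`` is a hypothetical name introduced for this comparison.

```python
import jax
import jax.numpy as jnp

value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

def naive_td_loss(theta, s_tm1, r_t, s_t):
    v_tm1 = value_fn(theta, s_tm1)
    # No stop_gradient: the target's dependence on theta is differentiated too.
    target = r_t + value_fn(theta, s_t)
    return -0.5 * (target - v_tm1) ** 2

naive_update = jax.grad(naive_td_loss)(theta, s_tm1, r_t, s_t)
print(naive_update)
```

Without ``stop_gradient`` the gradient also flows through the target, so the result is proportional to $s_t - s_{t-1}$ rather than to $s_{t-1}$, and it no longer matches the TD(0) update.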

Straight-through estimator using ``stop_gradient``

The straight-through estimator is a trick for defining a _gradient_ of a function that is otherwise non-differentiable.

Given a non-differentiable function $f: \mathbb{R}^n → \mathbb{R}^n$ that is used as part of a larger function whose gradient we want, we simply pretend during the backward pass that $f$ is the identity function.

In [7]:
def f(x):
    return jnp.round(x)

def straight_through_f(x):
    # Create an exactly-zero expression with Sterbenz lemma that has
    # an exactly-one gradient.

    zero = x - jax.lax.stop_gradient(x)
    return zero + jax.lax.stop_gradient(f(x))

print("f(x):", f(3.2))
print("straight_through_f(3.2):", straight_through_f(3.2))

print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))

f(x): 3.0
straight_through_f(3.2): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0
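The estimator is mostly useful inside a larger differentiable computation. The sketch below applies it to a vector input within a squared-error loss; ``quantized_loss``, the input ``x`` and the ``target`` values are assumptions made up for illustration.

```python
import jax
import jax.numpy as jnp

def straight_through_round(x):
    # Forward pass: round(x). Backward pass: identity.
    zero = x - jax.lax.stop_gradient(x)
    return zero + jax.lax.stop_gradient(jnp.round(x))

def quantized_loss(x, target):
    # Hypothetical loss on the rounded values.
    return jnp.sum((straight_through_round(x) - target) ** 2)

x = jnp.array([0.4, 1.7, 3.2])
target = jnp.array([1., 1., 1.])

g = jax.grad(quantized_loss)(x, target)
print(g)
```

Because rounding is treated as the identity on the backward pass, the gradient equals $2(\mathrm{round}(x) - \mathrm{target})$ instead of being zero everywhere.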


__Hessian-vector__ products with ``jax.grad`` of ``jax.grad``

A Hessian-vector product function can be useful for minimizing smooth convex functions, or for studying the curvature of neural network training objectives. It avoids materializing the full Hessian.

In [10]:
def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) 
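As a sanity check, the product can be compared against an explicitly formed Hessian. The sketch below also shows a forward-over-reverse variant built on ``jax.jvp``; the test function ``g`` and the vectors ``x`` and ``v`` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax import grad

def hvp(f, x, v):
    # Reverse-over-reverse: gradient of the scalar <grad f(x), v>.
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

def hvp_fwd_over_rev(f, x, v):
    # Forward-over-reverse: JVP of the gradient function along v.
    return jax.jvp(grad(f), (x,), (v,))[1]

def g(x):
    # Hypothetical smooth test function.
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([1., 2., 3.])
v = jnp.array([0.5, -1., 2.])

dense = jax.hessian(g)(x) @ v  # reference: full Hessian times v
print(hvp(g, x, v))
print(hvp_fwd_over_rev(g, x, v))
print(dense)
```

All three results should agree up to floating-point error. Only the last one builds the full $n \times n$ Hessian.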

Jacobian and Hessian using ``jax.jacfwd`` and ``jax.jacrev``

In [12]:
from jax import jacfwd, jacrev 

#define a sigmoid function 

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.) + 1)

#outputs probability of a label being true
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W)+b)

#build a toy dataset   

inputs = jnp.array([[0.52, 1.12,  0.77],
                     [0.88, -1.08, 0.15],
                     [0.52, 0.06, -1.30],
                     [0.74, -2.49, 1.39]])

#initialize random model parameters
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

#isolate the function from the weights to the predictions
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("Jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("Jacrev result, with shape", J.shape)
print(J)

Jacfwd result, with shape (4, 3)
[[ 0.10874642  0.23422305  0.16102834]
 [ 0.08166757 -0.10022839  0.01392061]
 [ 0.08774893  0.01012488 -0.21937233]
 [ 0.04643593 -0.15625063  0.08722425]]
Jacrev result, with shape (4, 3)
[[ 0.10874641  0.23422305  0.16102834]
 [ 0.08166758 -0.1002284   0.01392061]
 [ 0.08774893  0.01012488 -0.21937232]
 [ 0.04643593 -0.15625063  0.08722425]]


Both ``jax.jacfwd()`` and ``jax.jacrev()`` compute the same values, up to floating-point error.

However,

- ``jax.jacfwd()`` is more efficient for _tall_ Jacobians, i.e., more outputs than inputs
- ``jax.jacrev()`` is more efficient for _wide_ Jacobians, i.e., more inputs than outputs

For matrices that are near-square, ``jax.jacfwd()`` probably has an edge over ``jax.jacrev()``.
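The cost difference becomes clearer when the Jacobian is assembled by hand: forward mode needs one Jacobian-vector product per input dimension, while reverse mode needs one vector-Jacobian product per output dimension. The sketch below is a simplified illustration for functions of a single 1-D array, not how the library functions are necessarily implemented; ``our_jacfwd``, ``our_jacrev`` and ``toy`` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import jvp, vjp, vmap

def our_jacfwd(f):
    def jacfun(x):
        push = lambda tangent: jvp(f, (x,), (tangent,))[1]
        # One JVP per input basis vector; each yields a column of J.
        cols = vmap(push)(jnp.eye(x.shape[0]))  # shape (n_in, n_out)
        return cols.T
    return jacfun

def our_jacrev(f):
    def jacfun(x):
        y, pull = vjp(f, x)
        # One VJP per output basis vector; each yields a row of J.
        (rows,) = vmap(pull)(jnp.eye(y.shape[0]))  # shape (n_out, n_in)
        return rows
    return jacfun

def toy(x):
    # Hypothetical function from R^3 to R^4.
    A = jnp.arange(12.).reshape(4, 3)
    return jnp.tanh(A @ x)

x = jnp.array([0.1, -0.2, 0.3])
J_fwd = our_jacfwd(toy)(x)
J_rev = our_jacrev(toy)(x)
print(J_fwd.shape, J_rev.shape)
```

Here the forward-mode version runs 3 JVPs and the reverse-mode version runs 4 VJPs, which is the intuition behind the tall-versus-wide rule of thumb.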

In [13]:
#Both functions can be used with CONTAINER types

def predict_dic(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dic)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)

Jacobian from W to logits is
[[ 0.10874641  0.23422305  0.16102834]
 [ 0.08166758 -0.1002284   0.01392061]
 [ 0.08774893  0.01012488 -0.21937232]
 [ 0.04643593 -0.15625063  0.08722425]]
Jacobian from b to logits is
[0.20912772 0.09280407 0.16874795 0.06275126]
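An alternative to packing parameters into a dict is the ``argnums`` argument, which selects the positional arguments to differentiate with respect to. The sketch below repeats ``sigmoid`` and ``predict`` so it runs standalone; the values of ``W`` and ``b`` are fixed placeholders rather than the random draws used above.

```python
import jax
import jax.numpy as jnp
from jax import jacrev

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.) + 1)

def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])

# Placeholder parameters (assumed values, not the notebook's random ones).
W = jnp.array([0.5, -0.3, 0.8])
b = jnp.array(0.1)

# Differentiate with respect to arguments 0 (W) and 1 (b) at once.
J_W, J_b = jacrev(predict, argnums=(0, 1))(W, b, inputs)

# Same computation with a dict container, for comparison.
J_dict = jacrev(lambda p, x: predict(p['W'], p['b'], x))({'W': W, 'b': b}, inputs)
print(J_W.shape, J_b.shape)
```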


Composing two of these functions gives us a way to compute dense Hessian matrices.

To implement ``hessian``, either ``jacfwd(jacrev(f))`` or ``jacrev(jacfwd(f))`` would work, but __forward-over-reverse__ is typically the most efficient.

In [14]:
def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("Hessian with Shape", H.shape)
print(H)

Hessian with Shape (4, 3, 3)
[[[-0.02286455 -0.04924672 -0.03385712]
  [-0.04924672 -0.10606987 -0.07292303]
  [-0.03385712 -0.07292303 -0.05013458]]

 [[ 0.05698795 -0.06993975  0.00971385]
  [-0.06993975  0.08583514 -0.01192155]
  [ 0.00971385 -0.01192155  0.00165577]]

 [[ 0.02601311  0.00300151 -0.06503277]
  [ 0.00300151  0.00034633 -0.00750378]
  [-0.06503277 -0.00750378  0.16258194]]

 [[ 0.02973893 -0.10006747  0.05586096]
  [-0.10006747  0.33671352 -0.18796457]
  [ 0.05586096 -0.18796457  0.10492801]]]
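The shape ``(4, 3, 3)`` arises because ``f`` has four outputs, each with its own $3 \times 3$ Hessian with respect to ``W``. For a scalar-valued function, the two compositions and the built-in ``jax.hessian`` can be checked against each other, as in the sketch below; the function ``g`` and the point ``x`` are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev

def g(x):
    # Hypothetical scalar function with non-zero mixed second derivatives.
    return jnp.sum(jnp.tanh(x)) ** 2

x = jnp.array([0.3, -0.5, 1.2])

H_fwd_over_rev = jacfwd(jacrev(g))(x)
H_rev_over_fwd = jacrev(jacfwd(g))(x)
H_builtin = jax.hessian(g)(x)

print(H_fwd_over_rev.shape)
print(H_fwd_over_rev)
```

All three should agree up to floating-point error, and the result should be symmetric since ``g`` is smooth.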
