In [15]:
from jax import jacfwd, jacrev, grad
import numpy as np
import jax.numpy as jnp

# JVPs (forward mode) and VJPs (reverse mode) compute J(x) @ v or J(x)^T @ u directly, without materializing the full m x n Jacobian, so they use far less memory than building the Jacobian explicitly.
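To make the memory claim concrete, here is a minimal sketch. The elementwise map `f` (built on `jnp.sin`) and the size `n` are illustrative choices, not from the notebook. The Jacobian of this map is diagonal, and materializing it with `jacfwd` would mean 10,000 × 10,000 float32 entries (about 400 MB), whereas `jax.jvp` and `jax.vjp` only ever handle length-`n` vectors.

```python
import jax
import jax.numpy as jnp

# Illustrative elementwise map f: R^n -> R^n; its Jacobian is diag(cos(x)).
f = lambda x: jnp.sin(x)

n = 10_000
x = jnp.linspace(0.0, 1.0, n)
v = jnp.ones(n)

# JVP: returns (f(x), J(x) @ v) without forming the n x n Jacobian.
y, jv = jax.jvp(f, (x,), (v,))

# VJP: returns f(x) and a pullback that maps a cotangent u to J(x)^T @ u.
y2, vjp_fun = jax.vjp(f, x)
(vj,) = vjp_fun(v)

# For a diagonal Jacobian, both products reduce to cos(x) * v.
print(jnp.allclose(jv, jnp.cos(x) * v), jnp.allclose(vj, jnp.cos(x) * v))
```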

In [None]:
# Reference: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

`grad(f)` is the gradient function $\nabla f$ (for a scalar-valued $f$).

`grad(f)(x)` is that gradient evaluated at $x$, i.e. $\nabla f(x)$.
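A minimal sketch of the two notations, using an illustrative scalar function `sq_norm` (not from the notebook), $f(x) = \sum_i x_i^2$, whose gradient is $2x$:

```python
import jax
import jax.numpy as jnp

# Illustrative scalar-valued function of a vector: f(x) = sum_i x_i^2.
def sq_norm(x):
    return jnp.sum(x ** 2)

grad_f = jax.grad(sq_norm)       # grad(f): a new function computing the gradient
x = jnp.array([1.0, 2.0, 3.0])
g = grad_f(x)                    # grad(f)(x): the gradient evaluated at x
print(g)                         # [2. 4. 6.]
```

`jax.grad` requires the function to return a scalar; for vector-valued functions, use `jax.jvp`, `jax.vjp`, or the Jacobian helpers instead.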

## `jax.vjp`

**A bit about math**

Mathematically, suppose we have a function $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$. The Jacobian matrix of $f$ at a particular point $x$, denoted $J(x) \in \mathbb{R}^{m \times n}$, is the matrix of partial derivatives:
$$J(x) = 
\left(\begin{matrix} 
\frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_n} \\
\vdots & \ddots & \vdots \\
\frac{\partial f_m}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_n}
\end{matrix} \right)$$

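You can think of it as a linear map $J(x): \mathbb{R}^n \rightarrow \mathbb{R}^m$ that maps $v$ to $J(x)v$. Evaluating this map for a given $v$ is a Jacobian-vector product (JVP), which `jax.jvp` computes without forming $J(x)$.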
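A minimal sketch of $v \mapsto J(x)v$ with `jax.jvp`, checked against the explicit Jacobian from `jax.jacfwd`. The map `g` and the point `x` are illustrative assumptions chosen so the Jacobian, $J(x) = \begin{pmatrix} x_1 & x_0 \\ 1 & 6x_1 \end{pmatrix}$, is easy to verify by hand:

```python
import jax
import jax.numpy as jnp

# Illustrative map g: R^2 -> R^2 (not from the notebook).
# Its Jacobian is J(x) = [[x1, x0], [1, 6*x1]].
def g(x):
    return jnp.array([x[0] * x[1], x[0] + 3.0 * x[1] ** 2])

x = jnp.array([2.0, 1.0])
v = jnp.array([1.0, 1.0])

# Forward mode: jax.jvp returns (g(x), J(x) v) without building J(x).
y, Jv = jax.jvp(g, (x,), (v,))

# Reference: materialize J(x) = [[1, 2], [1, 6]] with jacfwd and multiply explicitly.
J = jax.jacfwd(g)(x)
print(Jv)        # [3. 7.]
print(J @ v)     # [3. 7.]
```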

A vector-Jacobian product (VJP) computes $v^\top J(x)$ for a vector $v \in \mathbb{R}^m$, or equivalently $J(x)^\top v$. `jax.vjp` is the JAX API for computing VJPs; it takes:
- first argument: the callable $f$
- remaining positional arguments: the primal values at which the Jacobian is evaluated, one per positional argument of $f$ (each can be an array or a pytree, e.g. a tuple or list)

It returns a pair: the value $f(x)$ and a function implementing the linear map $J(x)^\top: \mathbb{R}^m \rightarrow \mathbb{R}^n$, which maps $v$ to $J(x)^\top v$. That function returns a tuple containing one cotangent per primal argument.
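A minimal sketch of `jax.vjp`, including the multi-argument case. The functions `g` and `h` and all input values are illustrative assumptions; `g` has Jacobian $J(x) = \begin{pmatrix} x_1 & x_0 \\ 1 & 6x_1 \end{pmatrix}$, so at $x = (2, 1)$ and $u = (1, 1)$ we expect $J(x)^\top u = (2, 8)$.

```python
import jax
import jax.numpy as jnp

# Illustrative map g: R^2 -> R^2 with Jacobian J(x) = [[x1, x0], [1, 6*x1]].
def g(x):
    return jnp.array([x[0] * x[1], x[0] + 3.0 * x[1] ** 2])

x = jnp.array([2.0, 1.0])

# jax.vjp(fun, *primals) returns g(x) and the pullback u -> J(x)^T u.
y, vjp_fun = jax.vjp(g, x)

u = jnp.array([1.0, 1.0])
(JTu,) = vjp_fun(u)   # one cotangent per primal argument, returned as a tuple
print(y)              # [2. 5.]
print(JTu)            # [2. 8.]

# Cross-check against the explicit Jacobian built by reverse-mode jacrev.
print(jnp.allclose(JTu, jax.jacrev(g)(x).T @ u))   # True

# With several positional primals, the pullback returns one cotangent per argument.
h = lambda a, b: a * jnp.sin(b)
a0, b0 = jnp.array(2.0), jnp.array(0.0)
out, h_vjp = jax.vjp(h, a0, b0)
da, db = h_vjp(jnp.ones_like(out))   # (sin(b0), a0 * cos(b0)) = (0.0, 2.0)
```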

In [28]:
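import jax
import jax.numpy as jnp

# Polynomial and its analytic derivative, used to check jax.grad
f = lambda x: x**3 + 2*x**2 - 3*x + 1
df = lambda x: 3*x**2 + 4*x - 3

# Repeated application of jax.grad gives higher-order derivatives
dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

# Logistic sigmoid and its derivative
sigmoid = lambda x: 1 / (1 + jnp.exp(-x))
dsig = jax.grad(sigmoid)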

In [26]:
print(dfdx(1.))
print(df(1.))

4.0
4.0
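The higher-order derivatives defined above can be checked against their closed forms in the same way. A minimal self-contained sketch (redefining `f` and a logistic `sigmoid`, the latter an assumption about what the notebook intends): analytically $f''(x) = 6x + 4$, $f'''(x) = 6$, $f''''(x) = 0$, and $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, so $\sigma'(0) = 0.25$.

```python
import jax
import jax.numpy as jnp

f = lambda x: x**3 + 2*x**2 - 3*x + 1

# Nesting jax.grad gives higher-order derivatives of a scalar function.
d2fdx = jax.grad(jax.grad(f))
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

# Analytically at x = 1: f''(1) = 10, f'''(1) = 6, f''''(1) = 0.
print(d2fdx(1.), d3fdx(1.), d4fdx(1.))

# Logistic sigmoid: s'(x) = s(x) * (1 - s(x)), so s'(0) = 0.5 * 0.5 = 0.25.
sigmoid = lambda x: 1.0 / (1.0 + jnp.exp(-x))
dsig = jax.grad(sigmoid)
print(dsig(0.))
```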


In [78]:
def ff(x):
    # ff: R^2 -> R^2
    x1, x2 = x[0], x[1]
    return jnp.array([x1**4 + 3 * x2**2 * x1, 5 * x2**2 - 2 * x1 * x2 + 1])


# Jacobian via reverse-mode autodiff
J = jacrev(ff)

x = jnp.array([1.0, 2.0])
print(J(x))

[[16. 12.]
 [-4. 18.]]
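As a check on this output, the same Jacobian can be assembled from the primitives discussed earlier: row $i$ of $J(x)$ is the VJP with the basis vector $e_i$, and column $j$ is the JVP with $e_j$. A minimal sketch (the loop-based assembly is for illustration only; `jacfwd` and `jacrev` vectorize this internally rather than looping in Python):

```python
import jax
import jax.numpy as jnp

def ff(x):
    x1, x2 = x[0], x[1]
    return jnp.array([x1**4 + 3 * x2**2 * x1, 5 * x2**2 - 2 * x1 * x2 + 1])

x = jnp.array([1.0, 2.0])

# Forward mode and reverse mode must agree on the same Jacobian.
J_fwd = jax.jacfwd(ff)(x)
J_rev = jax.jacrev(ff)(x)

# Row by row: row i is e_i^T J(x), i.e. the VJP with basis vector e_i.
_, vjp_fun = jax.vjp(ff, x)
J_rows = jnp.stack([vjp_fun(e)[0] for e in jnp.eye(2)])

# Column by column: column j is J(x) e_j, i.e. the JVP with basis vector e_j.
J_cols = jnp.stack([jax.jvp(ff, (x,), (e,))[1] for e in jnp.eye(2)], axis=1)

print(jnp.allclose(J_fwd, J_rev), jnp.allclose(J_rows, J_rev), jnp.allclose(J_cols, J_rev))
```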

