In [3]:
from jax import jacfwd, jacrev
import jax.numpy as jnp

### Hessians with JAX

https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev

In [33]:
# jacfwd and jacrev both compute the same Jacobian.
# jacfwd uses forward-mode differentiation.
# jacrev uses reverse-mode differentiation (backpropagation).

def hessian(f):
    """Return a function that evaluates the Hessian of ``f``.

    The Hessian is built as the Jacobian of the Jacobian: ``jacrev``
    differentiates ``f`` once (reverse mode), and ``jacfwd`` differentiates
    that result again (forward mode).

    Args:
        f: The function to differentiate twice.

    Returns:
        A callable with the same arguments as ``f`` that returns the second
        derivatives of ``f`` at the given point.
    """
    first_derivative = jacrev(f)
    second_derivative = jacfwd(first_derivative)
    return second_derivative

The Hessian matrix $\mathrm{H}_f$ of the function $f(x, y) = 2x + y$ is

$$
\mathrm{H}_f =
\begin{pmatrix}
    \frac{\partial^2 f}{\partial x^2} & \frac{\partial^2 f}{\partial x \partial y } \\
    \frac{\partial^2 f}{\partial y \partial x} & \frac{\partial^2 f}{\partial y^2} \\
\end{pmatrix}
=
\begin{pmatrix}
    0 & 0 \\
    0 & 0 \\
\end{pmatrix}
$$

In [34]:
def f(x):
    """Evaluate f(x, y) = 2x + y, with (x, y) passed as ``x[0]``, ``x[1]``."""
    first, second = x[0], x[1]
    linear_in_first = 2 * first
    return linear_in_first + second

# Evaluate the Hessian of f at the point (x, y) = (2, 4).
x = jnp.array([2.0, 4.0])
hessian(f)(x)

DeviceArray([[0., 0.],
             [0., 0.]], dtype=float32)

The Hessian matrix $\mathrm{H}_g$ of the function $g(x, y) = 2x^3 + 4y^2$ is

$$
\mathrm{H}_g =
\begin{pmatrix}
    \frac{\partial^2 g}{\partial x^2} & \frac{\partial^2 g}{\partial x \partial y } \\
    \frac{\partial^2 g}{\partial y \partial x} & \frac{\partial^2 g}{\partial y^2} \\
\end{pmatrix}
=
\begin{pmatrix}
    12x & 0 \\
    0 & 8 \\
\end{pmatrix}
$$

In [35]:
def g(x):
    """Evaluate g(x, y) = 2x^3 + 4y^2, with (x, y) passed as ``x[0]``, ``x[1]``."""
    cubic_term = 2 * x[0] ** 3
    quadratic_term = 4 * x[1] ** 2
    return cubic_term + quadratic_term

# Evaluate the Hessian of g at the point (x, y) = (2, 4).
x = jnp.array([2.0, 4.0])
hessian(g)(x)

DeviceArray([[24.,  0.],
             [ 0.,  8.]], dtype=float32)

In [36]:
# Evaluate the Hessian of g again, this time at (x, y) = (5, 1000).
x = jnp.array([5.0, 1000.0])
hessian(g)(x)

DeviceArray([[60.,  0.],
             [ 0.,  8.]], dtype=float32)

The Hessian matrix $\mathrm{H}_q$ of the function $q(x, y) = xy$ is

$$
\mathrm{H}_q =
\begin{pmatrix}
    \frac{\partial^2 q}{\partial x^2} & \frac{\partial^2 q}{\partial x \partial y } \\
    \frac{\partial^2 q}{\partial y \partial x} & \frac{\partial^2 q}{\partial y^2} \\
\end{pmatrix}
=
\begin{pmatrix}
    0 & 1 \\
    1 & 0 \\
\end{pmatrix}
$$

In [37]:
def q(x):
    """Evaluate q(x, y) = xy, with (x, y) passed as ``x[0]``, ``x[1]``."""
    first, second = x[0], x[1]
    return first * second

# Evaluate the Hessian of q at the point (x, y) = (5, 1000).
x = jnp.array([5.0, 1000.0])
hessian(q)(x)

DeviceArray([[0., 1.],
             [1., 0.]], dtype=float32)