In the Fisher information matrix (FIM) implementation, we need to compute the inverse covariance matrix

$$\Sigma^{-1} = - \frac{\partial^2 \log p}{\partial \vec \phi_\text{img} \partial \vec \phi_\text{img}}$$
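As a sanity check on this formula, here is a minimal sketch that assumes a Gaussian log-density as a stand-in for the real image likelihood, which is not shown here. For $\log p(\vec\phi) = -\tfrac12 (\vec\phi-\vec\mu)^T P (\vec\phi-\vec\mu) + \text{const}$, the negative Hessian should return the precision matrix $P = \Sigma^{-1}$ at any point. The names `mu`, `Sigma`, `P`, `log_p` and `phi0` are hypothetical and chosen for illustration only.

```python
import jax.numpy as jnp
from jax import hessian

# Hypothetical Gaussian parameters (illustration only, not the image model)
mu = jnp.array([0.5, -1.0])
Sigma = jnp.array([[2.0, 0.3],
                   [0.3, 1.0]])
P = jnp.linalg.inv(Sigma)  # precision matrix, i.e. Sigma^{-1}

def log_p(phi):
    """Gaussian log-density up to an additive constant."""
    d = phi - mu
    return -0.5 * d @ P @ d

# Sigma^{-1} = -d^2 log p / (dphi dphi), evaluated at an arbitrary point
phi0 = jnp.array([1.0, 2.0])
fisher = -hessian(log_p)(phi0)

print(jnp.allclose(fisher, P, atol=1e-5))
```

Because the log-density is exactly quadratic, the result does not depend on `phi0`; for a non-Gaussian likelihood the matrix would vary with the evaluation point.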

Here we test a JAX implementation on a simpler scalar function of $\vec x = (x_1, x_2)$:

$$\Sigma^{-1} = - \frac{\partial^2 f(\vec x)}{\partial \vec x \partial \vec x}$$

where 
$$f(\vec x) = x_1^3 + x_1^2 + x_2^2 + 2 x_1 x_2$$

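In this case, the Hessian is

$$\frac{\partial^2 f(\vec x)}{\partial \vec x \partial \vec x} = \begin{pmatrix}
2+6 x_1 & 2\\
2 & 2
\end{pmatrix},$$

so $\Sigma^{-1}$ is this matrix with its sign flipped. The code below compares the Hessian itself, without the minus sign.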
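Differentiating $f$ term by term gives each entry of the Hessian:

$$\begin{aligned}
\frac{\partial f}{\partial x_1} &= 3x_1^2 + 2x_1 + 2x_2, & \frac{\partial f}{\partial x_2} &= 2x_2 + 2x_1,\\
\frac{\partial^2 f}{\partial x_1^2} &= 6x_1 + 2, & \frac{\partial^2 f}{\partial x_2^2} &= 2,\\
\frac{\partial^2 f}{\partial x_1 \partial x_2} &= \frac{\partial^2 f}{\partial x_2 \partial x_1} = 2. & &
\end{aligned}$$

At $\vec x = (1, 2)$ the top-left entry is $6 \cdot 1 + 2 = 8$, and the off-diagonal entries agree, as they must for a smooth $f$.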

```python
import numpy as np
import jax.numpy as jnp
from jax import hessian, jacobian, grad, jit
```

```python
def f(x):
    """Scalar test function f(x) = x1^3 + x1^2 + x2^2 + 2 x1 x2."""
    return x[0]**3 + x[0]**2 + x[1]**2 + 2*x[0]*x[1]

def hessian_analytical(x):
    """Hand-derived Hessian of f, for comparison with JAX."""
    return jnp.array([[6*x[0] + 2, 2],
                      [2, 2]])
```

```python
# Build the Hessian of f with JAX automatic differentiation (returns a function of x):
hessian_f = hessian(f)
```
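According to the JAX documentation, `jax.hessian(f)` is computed as forward-over-reverse differentiation, `jacfwd(jacrev(f))`. The sketch below checks that equivalence on the same `f`, and also builds the Hessian from the `jacobian` and `grad` functions the notebook already imports (`jacobian` is JAX's alias for `jacrev`, so this is reverse-over-reverse). The variable names are illustrative.

```python
import jax.numpy as jnp
from jax import hessian, jacfwd, jacrev, jacobian, grad

def f(x):
    return x[0]**3 + x[0]**2 + x[1]**2 + 2*x[0]*x[1]

x = jnp.array([1., 2.])

H_builtin = hessian(f)(x)         # JAX's built-in Hessian
H_fwd_rev = jacfwd(jacrev(f))(x)  # the same forward-over-reverse composition, written out
H_rev_rev = jacobian(grad(f))(x)  # reverse-over-reverse: also valid

print(jnp.allclose(H_builtin, H_fwd_rev), jnp.allclose(H_builtin, H_rev_rev))
```

All three give the same matrix; JAX's autodiff cookbook suggests forward-over-reverse is typically the most efficient choice, which is why `hessian` uses it.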

```python
# Print out the hessian at the point (1, 2) using both the analytical and jax methods
x = jnp.array([1., 2.])
print("Analytical Hessian at (1, 2):")
print(hessian_analytical(x))
print("JAX Hessian at (1, 2):")
print(hessian_f(x))
print("Difference between analytical and JAX Hessian:")
print(hessian_analytical(x) - hessian_f(x))
```

```
Analytical Hessian at (1, 2):
[[8. 2.]
 [2. 2.]]
JAX Hessian at (1, 2):
[[8. 2.]
 [2. 2.]]
Difference between analytical and JAX Hessian:
[[0. 0.]
 [0. 0.]]
```
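Agreement at a single point is a weak test, since only the top-left entry depends on $\vec x$. A slightly stronger check, assuming random Gaussian test points are representative, evaluates both Hessians on a batch using `jax.vmap`, compiled with `jax.jit`. The batch size, random seed and the names `batched_jax`, `batched_exact` and `max_err` are arbitrary choices for this sketch.

```python
import jax
import jax.numpy as jnp
from jax import hessian, jit, vmap

def f(x):
    return x[0]**3 + x[0]**2 + x[1]**2 + 2*x[0]*x[1]

def hessian_analytical(x):
    return jnp.array([[6*x[0] + 2, 2.],
                      [2., 2.]])

# Vectorise over a leading batch axis; jit compiles the autodiff Hessian once
batched_jax = jit(vmap(hessian(f)))
batched_exact = vmap(hessian_analytical)

key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (100, 2))  # 100 random test points, shape (100, 2)

# Largest absolute entrywise difference over all points, shapes (100, 2, 2)
max_err = jnp.max(jnp.abs(batched_jax(xs) - batched_exact(xs)))
print(max_err)
```

Any remaining difference should be at the level of float32 rounding, since JAX uses single precision by default.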
