In the FIM implementation, we need to compute the inverse covariance matrix

$$\Sigma^{-1} = - \frac{\partial^2 \log p}{\partial \vec \phi_\text{img} \partial \vec \phi_\text{img}}$$
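As a sanity check of this definition: for a Gaussian, the negative Hessian of $\log p$ equals the precision matrix $\Sigma^{-1}$ exactly, at every point. The sketch below assumes a hypothetical 2D Gaussian whose mean `mu` and precision `P` are illustrative values. The argument `phi` stands in for $\vec \phi_\text{img}$.

```python
import jax.numpy as jnp
from jax import hessian

# Hypothetical Gaussian "posterior": mean mu and precision matrix P (illustrative values)
mu = jnp.array([1.0, -0.5])
P = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])

def log_p(phi):
    """Gaussian log-density of phi, up to an additive constant."""
    d = phi - mu
    return -0.5 * d @ P @ d

# Sigma^{-1} = -d^2 log p / (dphi dphi); for a Gaussian this is P at every point
fisher = -hessian(log_p)(mu)
print(fisher)
```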

Here we implement the same computation in JAX for a simpler scalar function $f(\vec x)$:

$$\Sigma^{-1} = - \frac{\partial^2 f(\vec x)}{\partial \vec x \partial \vec x}$$

where 
$$f(\vec x) = (x_1-1)^3 + (x_1-1)^2 + x_2^2 + 2 (x_1-1) x_2$$

In this case, the Hessian of $f$ (i.e. $-\Sigma^{-1}$ in the notation above) is:

$$\frac{\partial^2 f(\vec x)}{\partial \vec x \partial \vec x} = \begin{pmatrix}
6 x_1-4 & 2\\
2 & 2
\end{pmatrix} $$


The gradient of $f$ vanishes at $\vec x = (1, 0)$, so this stationary point plays the role of the max "posterior" point.
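The behaviour of $f$ at $\vec x = (1, 0)$ can be checked numerically. The gradient is zero there, but the Hessian $\begin{pmatrix} 2 & 2 \\ 2 & 2 \end{pmatrix}$ has eigenvalues 0 and 4, so it is singular. Along the null direction $(1, -1)$ the function reduces to $f(1+t, -t) = t^3$, so strictly speaking the point is a stationary point rather than a strict extremum. A minimal sketch:

```python
import jax.numpy as jnp
from jax import grad, hessian

def f(x):
    """Toy function from the text, with x[0] = x_1 and x[1] = x_2."""
    return (x[0] - 1)**3 + (x[0] - 1)**2 + x[1]**2 + 2 * (x[0] - 1) * x[1]

x_star = jnp.array([1.0, 0.0])
g = grad(f)(x_star)               # gradient vanishes at the stationary point
H = hessian(f)(x_star)            # [[2, 2], [2, 2]]
eigs = jnp.linalg.eigvalsh(H)     # eigenvalues in ascending order, approximately [0, 4]
print(g, H, eigs)

# Along the null direction (1, -1): f(1 + t, -t) = t**3
t = 0.1
f_along = f(x_star + t * jnp.array([1.0, -1.0]))
print(f_along)
```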

In [251]:
import numpy as np
import jax.numpy as jnp
from jax import hessian, jacobian, grad, jit

In [252]:
def f(x):
    """Function to compute the value of f."""
    return (x[0]-1)**3 + (x[0]-1)**2 + x[1]**2 + 2 * (x[0]-1) * x[1]


def hessian_analytical(x):
    """Analytical Hessian of the function f."""
    return jnp.array([[6*x[0] - 4, 2],
                     [2, 2]])

In [253]:
# Take the hessian of f using jax:
hessian_f = hessian(f)
# Take the gradient of f using jax:
grad_f = grad(f)
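`jax.hessian` is built as forward-mode differentiation over reverse-mode differentiation, `jacfwd(jacrev(f))`. Other compositions give the same matrix. The sketch below checks this on the toy $f$ at the point $(2, 5)$, where the analytical Hessian is $\begin{pmatrix} 8 & 2 \\ 2 & 2 \end{pmatrix}$.

```python
import jax.numpy as jnp
from jax import hessian, jacfwd, jacrev

def f(x):
    """Toy function from the text."""
    return (x[0] - 1)**3 + (x[0] - 1)**2 + x[1]**2 + 2 * (x[0] - 1) * x[1]

x_test = jnp.array([2.0, 5.0])
H_builtin = hessian(f)(x_test)
# Forward-over-reverse: differentiate the reverse-mode gradient with forward mode
H_fwd_rev = jacfwd(jacrev(f))(x_test)
# Reverse-over-reverse gives the same matrix
H_rev_rev = jacrev(jacrev(f))(x_test)
print(H_builtin)
```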

In [254]:
# Print out the hessian at the point (2, 5) using both the analytical and jax methods
x = jnp.array([2., 5.])
print("Analytical Hessian at (2, 5):")
print(hessian_analytical(x))
print("JAX Hessian at (2, 5):")
print(hessian_f(x))
print("Difference between analytical and JAX Hessian:")
print(hessian_analytical(x) - hessian_f(x))

Analytical Hessian at (2, 5):
[[8. 2.]
 [2. 2.]]
JAX Hessian at (2, 5):
[[8. 2.]
 [2. 2.]]
Difference between analytical and JAX Hessian:
[[0. 0.]
 [0. 0.]]


In [255]:
# Compute the gradient at the point (1,0) using only jax
print("Gradient at (1,0):")
print(grad_f(jnp.array([1., 0.])))

Gradient at (1,0):
[0. 0.]


Now repeat the calculation with the parameters stored in a dictionary instead of an array, because herculens uses dictionaries.

In [256]:
def f_using_dictionaries(x):
    """Function to compute the value of f using dictionaries."""
    return (x['x0']-1)**3 + (x['x0']-1)**2 + x['x1']**2 + 2*(x['x0']-1)*x['x1']

def hessian_analytical_using_dictionaries(x):
    """Analytical Hessian of the function f using dictionaries."""
    return jnp.array([[6*x['x0'] - 4, 2],
                     [2, 2]])
    
# Take the hessian of f using jax with dictionaries:
hessian_f_dict = hessian(f_using_dictionaries)

In [257]:
# Print out the hessian at the point (1, 1) using both the analytical and jax methods
x_dict = {'x0': 1., 'x1': 1.}
print("Analytical Hessian at (1, 1) using dictionaries:")
print(hessian_analytical_using_dictionaries(x_dict))
print("JAX Hessian at (1, 1) using dictionaries:")
print(hessian_f_dict(x_dict))

Analytical Hessian at (1, 1) using dictionaries:
[[2. 2.]
 [2. 2.]]
JAX Hessian at (1, 1) using dictionaries:
{'x0': {'x0': Array(2., dtype=float32, weak_type=True), 'x1': Array(2., dtype=float32, weak_type=True)}, 'x1': {'x0': Array(2., dtype=float32, weak_type=True), 'x1': Array(2., dtype=float32, weak_type=True)}}


In [258]:
# Convert the dictionary-based Hessian to a matrix form
# using the keys of the dictionary
hessian_dictionary_form = hessian_f_dict(x_dict)
keys = list(x_dict.keys())
hessian_matrix_form = jnp.array([[hessian_dictionary_form[ki][kj] for kj in keys]
                                 for ki in keys])

print(hessian_matrix_form)
print("Difference between analytical and JAX Hessian using dictionaries:")
print(hessian_analytical_using_dictionaries(x_dict) - hessian_matrix_form)

[[2. 2.]
 [2. 2.]]
Difference between analytical and JAX Hessian using dictionaries:
[[0. 0.]
 [0. 0.]]
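Instead of reassembling the nested dictionary by hand, `jax.flatten_util.ravel_pytree` can flatten the parameter dictionary into a 1D array and return a function that maps such an array back to a dictionary. Taking the Hessian of the flattened function then yields an ordinary matrix directly. In the sketch below, the row and column order follows JAX's pytree flattening, which sorts dictionary keys (here `'x0'`, `'x1'`). The wrapper `f_flat` is a hypothetical helper, not part of herculens.

```python
import jax.numpy as jnp
from jax import hessian
from jax.flatten_util import ravel_pytree

def f_using_dictionaries(x):
    """Toy function from the text, with parameters stored in a dict."""
    return (x['x0'] - 1)**3 + (x['x0'] - 1)**2 + x['x1']**2 + 2 * (x['x0'] - 1) * x['x1']

x_dict = {'x0': jnp.array(1.0), 'x1': jnp.array(1.0)}
# Flatten the dict into a 1D array; `unravel` maps a flat array back to a dict
flat_x, unravel = ravel_pytree(x_dict)

def f_flat(v):
    """Same function, evaluated on the flattened parameter vector."""
    return f_using_dictionaries(unravel(v))

H_flat = hessian(f_flat)(flat_x)   # plain 2x2 array, ordered like flat_x
print(H_flat)
```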


Now compute the hessian using another set of parameters:

$$\vec y = (y_1,y_2,y_3)$$

where $$y_i = y_i(\vec x)$$

and, in particular,

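$$x_1 = y_1$$

$$x_2 = y_2 + g(y_2,x_1)$$

$$x_2 = y_3 + g(y_3,x_1)$$

with $g(y,x_1) = - y^2/x_1$. Both expressions give the same $x_2$ because $y_2$ and $y_3$ are the two roots of $y^2 - x_1 y + x_1 x_2 = 0$: $y_2=(x_1+\sqrt{x_1^2-4 x_1 x_2})/2$ and $y_3=(x_1-\sqrt{x_1^2-4 x_1 x_2})/2$. In the code below, $x_2$ is computed as the average of the two expressions.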

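With this reparametrization, we can alternatively evaluate the Hessian from the Hessian with respect to $\vec y$:

$$\Sigma^{-1} = - \left[\left( \frac{\partial \vec x}{\partial \vec y}\right)^{+}\right]^{T} \frac{\partial^2 f(\vec x(\vec y))}{\partial \vec y \partial \vec y} \left( \frac{\partial \vec x}{\partial \vec y}\right)^{+}$$

where $^{+}$ denotes the Moore-Penrose pseudo-inverse. The Jacobian $\partial \vec x/\partial \vec y$ is $2\times 3$, so it has no ordinary inverse. This relation holds at the max posterior point $\vec x_\text{maxP} = (1,0)$, where the gradient of $f$ vanishes.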
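The formula follows from the second-order chain rule applied to $f(\vec x(\vec y))$:

$$\frac{\partial^2 f}{\partial y_i \partial y_j} = \sum_{a,b} \frac{\partial x_a}{\partial y_i} \frac{\partial^2 f}{\partial x_a \partial x_b} \frac{\partial x_b}{\partial y_j} + \sum_k \frac{\partial f}{\partial x_k} \frac{\partial^2 x_k}{\partial y_i \partial y_j}$$

The second sum vanishes wherever the gradient of $f$ is zero, which is why the transformation is applied only at the max-posterior point. The sketch below checks both statements numerically. `x_of_y` is a renamed copy of the notebook's `x(y)`, `hessian_y_chain_rule` is a hypothetical helper, and `y_generic` is an arbitrary test point.

```python
import jax.numpy as jnp
from jax import grad, hessian, jacobian

def f(x):
    """Toy function from the text."""
    return (x[0] - 1)**3 + (x[0] - 1)**2 + x[1]**2 + 2 * (x[0] - 1) * x[1]

def g(s, x1):
    """g(y, x_1) = -y**2 / x_1."""
    return -s**2 / x1

def x_of_y(y):
    """Map y -> x (x_2 averaged over its two expressions)."""
    return jnp.array([y[0], (y[1] + g(y[1], y[0]) + y[2] + g(y[2], y[0])) / 2.])

def f_of_y(y):
    return f(x_of_y(y))

def hessian_y_chain_rule(y):
    """Second-order chain rule: J^T H_x J + sum_k (df/dx_k) d^2x_k/(dy dy)."""
    xv = x_of_y(y)
    J = jacobian(x_of_y)(y)        # dx/dy, shape (2, 3)
    H_x = hessian(f)(xv)           # d^2f/(dx dx), shape (2, 2)
    g_x = grad(f)(xv)              # df/dx, shape (2,)
    d2x = hessian(x_of_y)(y)       # d^2x_k/(dy dy), shape (2, 3, 3)
    return J.T @ H_x @ J + jnp.einsum('k,kij->ij', g_x, d2x)

# At an arbitrary point the gradient of f is nonzero, so the second term is needed
y_generic = jnp.array([2.0, 1.5, 0.3])
print(hessian(f_of_y)(y_generic) - hessian_y_chain_rule(y_generic))

# At y_star = y(1, 0) the gradient of f vanishes, so only J^T H_x J remains
y_star = jnp.array([1.0, 1.0, 0.0])
J_star = jacobian(x_of_y)(y_star)
print(hessian(f_of_y)(y_star) - J_star.T @ hessian(f)(x_of_y(y_star)) @ J_star)
```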

In [259]:
def g(y_scalar, x_1):
    return -y_scalar**2/x_1

def x(y):
    return jnp.array([y[0], 
                      (y[1]+g(y[1], y[0])+y[2]+g(y[2], y[0]))/2.
                    ])

def y(x):
    """Inverse map x -> y; y[1] and y[2] are the two roots of y**2 - x[0]*y + x[0]*x[1] = 0."""
    return jnp.array([x[0], 
                      (x[0]+jnp.sqrt(x[0]**2-4*x[0]*x[1]))/2.,
                      (x[0]-jnp.sqrt(x[0]**2-4*x[0]*x[1]))/2.
                    ])

def f_as_function_of_y(y):
    """Compute f after mapping the parameters y back to x."""
    x_val = x(y)  # Convert y to x
    return f(x_val)
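A quick round trip checks that the two maps are consistent wherever $x_1 \neq 0$ and $x_1^2 - 4x_1x_2 \ge 0$, so that the square root is real. To keep the sketch self-contained, the functions are copies of the notebook's `x(y)` and `y(x)`, renamed to the hypothetical names `x_of_y` and `y_of_x`. The second test point is arbitrary.

```python
import jax.numpy as jnp

def g(y_scalar, x_1):
    """g(y, x_1) = -y**2 / x_1, as defined in the text."""
    return -y_scalar**2 / x_1

def x_of_y(y):
    """Map y -> x, taking x_2 as the average of its two equivalent expressions."""
    return jnp.array([y[0],
                      (y[1] + g(y[1], y[0]) + y[2] + g(y[2], y[0])) / 2.])

def y_of_x(x):
    """Map x -> y using the two roots of y**2 - x_1*y + x_1*x_2 = 0."""
    disc = jnp.sqrt(x[0]**2 - 4 * x[0] * x[1])
    return jnp.array([x[0], (x[0] + disc) / 2., (x[0] - disc) / 2.])

# Test points with x_1**2 - 4*x_1*x_2 >= 0
points = [jnp.array([1.0, 0.0]), jnp.array([2.0, 0.3])]
for pt in points:
    print(pt, y_of_x(pt), x_of_y(y_of_x(pt)))
```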

In [267]:
# Compute the Hessian of f as a function of y:
hessian_f_y = hessian(f_as_function_of_y)
# Print out the hessian at the point y(x) with x=(1, 0)
x0 = jnp.array([1., 0.])
y0 = y(x0)
print("Hessian at y(x) with x=(1,0):")
print(hessian_f_y(y0))

Hessian at y(x) with x=(1,0):
[[ 4.5 -1.5  1.5]
 [-1.5  0.5 -0.5]
 [ 1.5 -0.5  0.5]]


In [283]:
# Compute the Jacobian dx/dy:
jac = jacobian(x)(y0)              # shape (2, 3)
jac_pinv = jnp.linalg.pinv(jac)    # Moore-Penrose pseudo-inverse, shape (3, 2)
print("Jacobian dx/dy:")
print(jac)

Jacobian dx/dy:
[[ 1.   0.   0. ]
 [ 0.5 -0.5  0.5]]
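The pseudo-inverse recovers the $\vec x$-space Hessian because this Jacobian $J$ has full row rank. That gives $J J^{+} = I_2$, so at the stationary point, where the $\vec y$-space Hessian equals $J^T H_x J$, we get $(J^{+})^T (J^T H_x J) J^{+} = (J J^{+})^T H_x (J J^{+}) = H_x$. The reverse product $J^{+} J$ is only a rank-2 projector in $\mathbb{R}^3$. A small check using the Jacobian values printed above:

```python
import jax.numpy as jnp

# Jacobian dx/dy at y0 = y(1, 0), copied from the printed output
J = jnp.array([[1.0, 0.0, 0.0],
               [0.5, -0.5, 0.5]])
J_pinv = jnp.linalg.pinv(J)   # shape (3, 2)

# Full row rank (2), so J @ J_pinv is the 2x2 identity ...
print(J @ J_pinv)
# ... while J_pinv @ J is a rank-2 projector in R^3, not the identity
print(J_pinv @ J)
```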


In [285]:
# Compare the hessian computed using hessian_f(x) and using hessian_f_y(y) with jacobians:
print("Hessian using hessian_f(x):")
y0 = y(x0)
print( hessian_f(x0) )
print("Hessian using hessian_f_y(y):")
print( jac_pinv.T @ hessian_f_y(y0) @ jac_pinv )

Hessian using hessian_f(x):
[[2. 2.]
 [2. 2.]]
Hessian using hessian_f_y(y):
[[1.9999993 2.0000002]
 [1.9999999 2.000001 ]]
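The two results agree only up to float32 round-off (differences of about 1e-6 here), so comparisons of this kind are best made with a tolerance rather than exact equality. The sketch below packages the pseudo-inverse transformation into a helper. `hessian_x_from_y` is a hypothetical name, and the helper is only valid under the stated assumptions: vanishing gradient of $f$ and a full-row-rank Jacobian. `x_of_y` is a renamed copy of the notebook's `x(y)`.

```python
import jax.numpy as jnp
from jax import hessian, jacobian

def hessian_x_from_y(f_of_y, x_of_y, y_star):
    """Recover d^2f/(dx dx) from d^2f/(dy dy) at y_star.

    Assumes grad_x f = 0 at x_of_y(y_star) and that dx/dy has full row rank.
    """
    J_pinv = jnp.linalg.pinv(jacobian(x_of_y)(y_star))
    return J_pinv.T @ hessian(f_of_y)(y_star) @ J_pinv

def f(x):
    """Toy function from the text."""
    return (x[0] - 1)**3 + (x[0] - 1)**2 + x[1]**2 + 2 * (x[0] - 1) * x[1]

def g(s, x1):
    return -s**2 / x1

def x_of_y(y):
    return jnp.array([y[0], (y[1] + g(y[1], y[0]) + y[2] + g(y[2], y[0])) / 2.])

y_star = jnp.array([1.0, 1.0, 0.0])   # y(1, 0)
H_recovered = hessian_x_from_y(lambda y: f(x_of_y(y)), x_of_y, y_star)
H_direct = hessian(f)(x_of_y(y_star))
# Compare with a tolerance: float32 round-off is of order 1e-6 here
print(jnp.max(jnp.abs(H_recovered - H_direct)))
print(jnp.allclose(H_recovered, H_direct, atol=1e-4))
```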
