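# Automatic differentiation

This example computes the gradient of a criterion function with JAX (`jax.grad`) and passes it to estimagic's `minimize` as the `derivative`. The parameters are a dict (a JAX pytree) that mixes a scalar, a vector, and a matrix; `jax.grad` returns a gradient with the same dict structure.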

In [1]:
import jax.numpy as jnp
import jax 
import estimagic as em

def _safe_norm(diff):
    """Euclidean norm of ``diff`` whose JAX gradient stays finite at zero.

    Wherever ``diff`` is not all zeros, this returns exactly
    ``jnp.linalg.norm(diff)``, with the same gradient. At ``diff == 0``,
    ``jnp.linalg.norm`` differentiates through ``sqrt(0)`` and yields NaN. This
    helper returns 0.0 there, with a zero (sub)gradient.
    """
    is_zero = jnp.all(diff == 0)
    # Inner `where`: never feed an all-zero array into the norm. Otherwise the NaN
    # from sqrt'(0) would leak into the gradient even though the outer `where`
    # discards that branch.
    safe_diff = jnp.where(is_zero, jnp.ones_like(diff), diff)
    return jnp.where(is_zero, 0.0, jnp.linalg.norm(safe_diff))


def criterion(x):
    """Toy objective over a dict of parameters.

    Args:
        x: dict with keys "a" (scalar), "b" (length-3 array) and "c" (2x2 array).

    Returns:
        Scalar ``(a - pi)**4 + ||b - [0, 1, 2]|| + ||c - I||_F``. It is minimized,
        with value 0, at a = pi, b = [0, 1, 2], c = 2x2 identity.
    """
    first = (x["a"] - jnp.pi) ** 4
    second = _safe_norm(x["b"] - jnp.arange(3))
    third = _safe_norm(x["c"] - jnp.eye(2))
    return first + second + third
    
    
# Starting point: a scalar, a length-3 vector and a 2x2 matrix, all set to one.
# `dtype=float` resolves to JAX's default floating dtype (float32 unless
# 64-bit mode is enabled).
start_params = dict(
    a=1.0,
    b=jnp.ones(3, dtype=float),
    c=jnp.ones((2, 2), dtype=float),
)



In [2]:
# Differentiate w.r.t. the first positional argument (argnums=0, JAX's default).
# Because that argument is a dict, the gradient is returned as a dict with the
# same keys and leaf shapes.
gradient = jax.grad(criterion, argnums=0)
gradient(start_params)

{'a': DeviceArray(-39.28897, dtype=float32, weak_type=True),
 'b': DeviceArray([ 0.70710677,  0.        , -0.70710677], dtype=float32),
 'c': DeviceArray([[0.        , 0.70710677],
              [0.70710677, 0.        ]], dtype=float32)}

In [8]:
# Minimize with scipy's L-BFGS-B, handing estimagic the JAX gradient as the
# `derivative` of the criterion.
res = em.minimize(
    params=start_params,
    criterion=criterion,
    derivative=gradient,
    algorithm="scipy_lbfgsb",
)

In [9]:
# Parameters at the optimum, returned in the same dict structure as `start_params`.
# NOTE(review): in the saved output, `a` ends near 3.14696 rather than pi (~3.14159).
# The (a - pi)**4 term is extremely flat around its minimum, so the optimizer's
# stopping rule likely fires before `a` is pinned down tightly. Consider tighter
# convergence settings if the precision of `a` matters.
res.params

{'a': 3.146955731817973,
 'b': DeviceArray([1.2965498e-09, 1.0000000e+00, 2.0000000e+00], dtype=float32),
 'c': DeviceArray([[ 1.0000000e+00, -1.1685792e-09],
              [-1.1685792e-09,  1.0000000e+00]], dtype=float32)}