# Automatic differentiation

In this exercise you will use automatic differentiation in JAX and estimagic to solve the previous problem.

## Resources

- https://jax.readthedocs.io/en/latest/jax.numpy.html
- https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

```python
import jax
import jax.numpy as jnp
import estimagic as em

# Use 64-bit floats; JAX defaults to 32-bit.
jax.config.update("jax_enable_x64", True)
```
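
JAX creates 32-bit floating-point arrays by default. The `jax_enable_x64` flag switches the default to 64-bit, which matches NumPy's default precision. A quick check, as a minimal sketch:

```python
import jax
import jax.numpy as jnp

# Set this at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # float64 once x64 is enabled (float32 otherwise)
```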

## Solution, Task 1: Switch to JAX

- Use the code from exercise 2, task 2, and convert the criterion function and the parameters to JAX. Hint: look at the [`jax.numpy` documentation](https://jax.readthedocs.io/en/latest/jax.numpy.html) and the slides if you have any questions.

```python
def criterion(x):
    first = (x["a"] - jnp.pi) ** 2
    second = jnp.sum((x["b"] - jnp.arange(3)) ** 2)
    third = jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    return first + second + third


start_params = {
    "a": 1.,
    "b": jnp.ones(3).astype(float),
    "c": jnp.ones((2, 2)).astype(float)
}
```

```python
criterion(start_params)
```

```
DeviceArray(8.58641909, dtype=float64)
```
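
Each of the three terms is a sum of squares. The criterion is therefore non-negative and reaches its minimum of 0 at $a = \pi$, $b = (0, 1, 2)^\top$ and $C = I_2$. Evaluating the JAX criterion there is a cheap sanity check of the conversion. The dictionary `optimum` below is introduced only for this illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def criterion(x):
    first = (x["a"] - jnp.pi) ** 2
    second = jnp.sum((x["b"] - jnp.arange(3)) ** 2)
    third = jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    return first + second + third


# Hypothetical helper dict: the known minimizer of the criterion.
optimum = {"a": jnp.pi, "b": jnp.arange(3.0), "c": jnp.eye(2)}

value = criterion(optimum)
print(value)  # 0.0
```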

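## Solution, Task 1 (Windows):

The Windows variants below use plain NumPy instead of JAX, for setups where JAX cannot be installed.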

```python
import numpy as np


def criterion_windows(x):
    # Plain NumPy only: use np.pi rather than jnp.pi, since JAX may be unavailable.
    first = (x["a"] - np.pi) ** 2
    second = np.sum((x["b"] - np.arange(3)) ** 2)
    third = np.sum((x["c"] - np.eye(2)) ** 2)
    return first + second + third


start_params_windows = {
    "a": 1.,
    "b": np.ones(3).astype(float),
    "c": np.ones((2, 2)).astype(float)
}
```

## Solution, Task 2: Gradient

- Compute the gradient of the criterion (the whole function). Hint: look at the [`autodiff_cookbook` documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) and the slides if you have any questions.

```python
gradient = jax.grad(criterion)
gradient(start_params)
```

```
{'a': DeviceArray(-4.28318531, dtype=float64, weak_type=True),
 'b': DeviceArray([ 2.,  0., -2.], dtype=float64),
 'c': DeviceArray([[0., 2.],
              [2., 0.]], dtype=float64)}
```
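
Optimizers often need the criterion value and its gradient at the same point. `jax.value_and_grad` returns both from a single call, so the criterion and `jax.grad(criterion)` do not have to be called separately. A minimal sketch using the same criterion. The returned gradient has the same dict structure as the parameters:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def criterion(x):
    first = (x["a"] - jnp.pi) ** 2
    second = jnp.sum((x["b"] - jnp.arange(3)) ** 2)
    third = jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    return first + second + third


start_params = {"a": 1.0, "b": jnp.ones(3), "c": jnp.ones((2, 2))}

# Returns (criterion value, gradient); the gradient is a dict with keys "a", "b", "c".
value, grad = jax.value_and_grad(criterion)(start_params)
print(value)      # approximately 8.5864
print(grad["b"])  # [ 2.  0. -2.]
```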

```python
%timeit gradient(start_params)
```

```
11.5 ms ± 2.05 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


```python
jitted_gradient = jax.jit(gradient)
%timeit jitted_gradient(start_params)
```

```
17.2 µs ± 7.57 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
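
The speed-up comes from `jax.jit`. It traces the gradient once and compiles it with XLA, and later calls with arguments of the same structure, shapes and dtypes reuse the compiled version. JAX also dispatches work asynchronously, so a fair manual timing should wait for the results with `block_until_ready()`. The sketch below is written under those assumptions. The helpers `block` and `time_call` are hypothetical, and actual timings depend on your machine:

```python
import time

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def criterion(x):
    first = (x["a"] - jnp.pi) ** 2
    second = jnp.sum((x["b"] - jnp.arange(3)) ** 2)
    third = jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    return first + second + third


start_params = {"a": 1.0, "b": jnp.ones(3), "c": jnp.ones((2, 2))}
gradient = jax.grad(criterion)
jitted_gradient = jax.jit(gradient)


def block(tree):
    """Hypothetical helper: wait until every array in the output pytree is computed."""
    return jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), tree)


def time_call(func, params, n=100):
    """Hypothetical helper: average wall-clock seconds per call."""
    block(func(params))  # warm-up call; triggers compilation for jitted functions
    start = time.perf_counter()
    for _ in range(n):
        block(func(params))
    return (time.perf_counter() - start) / n


print(f"grad:      {time_call(gradient, start_params):.2e} s")
print(f"jit(grad): {time_call(jitted_gradient, start_params):.2e} s")
```

The first call to `jitted_gradient` includes compilation, which is why the helper excludes it from the measurement.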


## Solution, Task 2 (Windows):

The analytical gradient of the function is given by:

- $\partial_a f(a, b, C) = 2 (a - \pi)$
- $\partial_b f(a, b, C) = 2 \left(b - \begin{pmatrix}0 & 1 & 2\end{pmatrix}^\top\right)$
- $\partial_C f(a, b, C) = 2 (C - I_2)$
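
For reference, the partial derivatives above follow from writing the criterion implemented earlier in the same notation. Each squared term contributes a factor of 2:

$$
f(a, b, C) = (a - \pi)^2 + \sum_{i=1}^{3} \bigl(b_i - (i - 1)\bigr)^2 + \sum_{j=1}^{2} \sum_{k=1}^{2} \bigl(C_{jk} - (I_2)_{jk}\bigr)^2
$$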

---

- Implement the analytical gradient
    - return the gradient in the form of `{"a": ..., "b": ..., "c": ...}`, using the same keys as the parameters

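```python
def gradient(params):
    # Hand-coded gradient of the NumPy criterion. Note that this
    # redefines the name `gradient` from Task 2 in the same session.
    return {
        "a": 2 * (params["a"] - np.pi),
        "b": 2 * (params["b"] - np.array([0, 1, 2])),
        "c": 2 * (params["c"] - np.eye(2))
    }
```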
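
A hand-written gradient is easy to get wrong, so it is worth comparing it with a numerical approximation before passing it to an optimizer. The sketch below uses central finite differences on each parameter entry. The helper `numerical_gradient` is hypothetical and not part of estimagic:

```python
import numpy as np


def criterion_windows(x):
    first = (x["a"] - np.pi) ** 2
    second = np.sum((x["b"] - np.arange(3)) ** 2)
    third = np.sum((x["c"] - np.eye(2)) ** 2)
    return first + second + third


def gradient(params):
    return {
        "a": 2 * (params["a"] - np.pi),
        "b": 2 * (params["b"] - np.array([0, 1, 2])),
        "c": 2 * (params["c"] - np.eye(2)),
    }


def numerical_gradient(func, params, eps=1e-6):
    """Hypothetical helper: central finite differences for a dict of floats/arrays."""
    grad = {}
    for key, value in params.items():
        value = np.asarray(value, dtype=float)
        g = np.zeros_like(value)
        for idx in np.ndindex(value.shape):
            # Perturb one entry up and down, leaving all other parameters unchanged.
            up, down = value.copy(), value.copy()
            up[idx] += eps
            down[idx] -= eps
            g[idx] = (func({**params, key: up}) - func({**params, key: down})) / (2 * eps)
        grad[key] = g
    return grad


start_params_windows = {"a": 1.0, "b": np.ones(3), "c": np.ones((2, 2))}
analytical = gradient(start_params_windows)
numerical = numerical_gradient(criterion_windows, start_params_windows)
for key in analytical:
    print(key, np.allclose(analytical[key], numerical[key], atol=1e-6))
```

Because the criterion is quadratic, central differences are exact up to rounding error, so any mismatch here points to a bug in the analytical formula.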

## Solution, Task 3: Minimize

- Use estimagic to minimize the criterion
    - pass the gradient function you computed above to the `minimize` call via its `derivative` argument.
    - use the `"scipy_lbfgsb"` algorithm.

```python
res = em.minimize(
    criterion=criterion,
    derivative=jitted_gradient,
    params=start_params,
    algorithm="scipy_lbfgsb",
)

res.params
```

```
{'a': 3.141592653589793,
 'b': DeviceArray([3.33066907e-16, 1.00000000e+00, 2.00000000e+00], dtype=float64),
 'c': DeviceArray([[1.00000000e+00, 3.33066907e-16],
              [3.33066907e-16, 1.00000000e+00]], dtype=float64)}
```

```python
res = em.minimize(
    criterion=criterion_windows,
    derivative=gradient,
    params=start_params_windows,
    algorithm="scipy_lbfgsb",
)

res.params
```

```
{'a': 3.141592653589793,
 'b': array([3.33066907e-16, 1.00000000e+00, 2.00000000e+00]),
 'c': array([[1.00000000e+00, 3.33066907e-16],
        [3.33066907e-16, 1.00000000e+00]])}
```