# Automatic differentiation

In this exercise you will use automatic differentiation in JAX and estimagic to solve the previous problem.

> Note. Because JAX cannot (yet) be installed on Windows, there will be extra exercises for Windows users.

## Resources

- https://jax.readthedocs.io/en/latest/jax.numpy.html
- https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

```python
import jax
import jax.numpy as jnp
import estimagic as em

# Use 64-bit floats; JAX defaults to 32-bit precision.
jax.config.update("jax_enable_x64", True)
```

## Task 1: Switch to JAX

- Use the code from exercise 2, task 2, and convert the criterion function and the start parameters to JAX. Look at the [`jax.numpy` documentation](https://jax.readthedocs.io/en/latest/jax.numpy.html) and slides if you have any questions.
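The conversion usually amounts to swapping `numpy` calls for their `jax.numpy` counterparts. A minimal sketch, assuming a hypothetical stand-in criterion (a shifted sum of squares); the names `criterion_np`, `criterion_jax`, `start_np` and `start_jax` are illustrative and are not the function or parameters from exercise 2:

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


# Hypothetical stand-in; replace it with the criterion from exercise 2, task 2.
def criterion_np(x):
    return np.sum((x - np.arange(len(x))) ** 2)


# Same computation written with jax.numpy so that JAX can trace it.
def criterion_jax(x):
    return jnp.sum((x - jnp.arange(len(x))) ** 2)


start_np = np.ones(3)
start_jax = jnp.ones(3)  # the start parameters become a JAX array as well
```

Both versions should return the same value at the start parameters, which is a quick way to check the conversion.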

---

## Task 1 (Windows): Copy functions

- Copy the criterion function and start parameters from exercise 2, task 2, here.

## Task 2: Gradient

- Compute the gradient of the criterion (the whole function). Look at the [`autodiff_cookbook` documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) and slides if you have any questions.
- Measure the runtime of a jitted and an unjitted version of the gradient (using `%timeit`).
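As a rough sketch of the pattern, `jax.grad` turns a scalar-valued function into its gradient function and `jax.jit` compiles it. The criterion below is a hypothetical stand-in, not the one from exercise 2:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


# Hypothetical stand-in criterion; replace it with your own JAX criterion.
def criterion(x):
    return jnp.sum((x - jnp.arange(len(x))) ** 2)


x0 = jnp.ones(3)

gradient = jax.grad(criterion)       # differentiates with respect to the first argument
jitted_gradient = jax.jit(gradient)  # compiled version of the same function

grad_value = gradient(x0)
jitted_grad_value = jitted_gradient(x0)  # the first call includes compilation time

# In a notebook cell, the two versions could then be timed with:
# %timeit gradient(x0).block_until_ready()
# %timeit jitted_gradient(x0).block_until_ready()
```

Calling the jitted function once before timing keeps the one-off compilation out of the measurement, and `block_until_ready()` makes sure the timing covers the actual computation rather than only the asynchronous dispatch.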

---

## Task 2 (Windows): Gradient

- Derive the gradient of the criterion (the whole function) analytically.
- Implement the analytical gradient.
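For illustration only, here is how an analytical gradient might look for a hypothetical stand-in criterion, $f(x) = \sum_i (x_i - i)^2$ with gradient $2(x_i - i)$. The helper `finite_difference_gradient` is an assumption added as a sanity check and is not part of the exercise:

```python
import numpy as np


# Hypothetical stand-in criterion; replace it with the one from exercise 2, task 2.
def criterion(x):
    return np.sum((x - np.arange(len(x))) ** 2)


# Gradient derived by hand: d/dx_i of (x_i - i)^2 is 2 * (x_i - i).
def analytical_gradient(x):
    return 2 * (x - np.arange(len(x)))


# Central finite differences, used only to check the hand-derived gradient.
def finite_difference_gradient(func, x, eps=1e-6):
    grad = np.zeros(len(x))
    for i in range(len(x)):
        step = np.zeros(len(x))
        step[i] = eps
        grad[i] = (func(x + step) - func(x - step)) / (2 * eps)
    return grad


x0 = np.ones(3)
```

Comparing `analytical_gradient(x0)` with `finite_difference_gradient(criterion, x0)` is a cheap way to catch sign or indexing mistakes in the derivation.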

## Task 3 (all systems): Minimize

- Use estimagic to minimize the criterion.
    - Pass the gradient function you computed above to the `minimize` call.
    - Use the `"scipy_lbfgsb"` algorithm or another gradient-based optimizer.