[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/implicit_differentiation.ipynb)

In [None]:
#@title Import & Util
!pip install -q git+https://www.github.com/google/jax
!pip install -q git+https://www.github.com/google/jax-md
!pip install jaxopt

import jax
import jax.numpy as jnp
# `jax.config` is the configuration object on every JAX version; the old
# `jax.config` *module* (``from jax.config import config``) was removed in
# recent releases, so bind the object directly under the same name.
config = jax.config
# Use 64-bit floats so the minimizations can reach tiny force thresholds.
config.update('jax_enable_x64', True)

from jax import random, jit, lax

from jax_md import space, energy, minimize, quantity

import jaxopt

# Short aliases for the floating-point dtypes used throughout the notebook.
f32, f64 = jnp.float32, jnp.float64

# Energy minimization functions

In [None]:
# Energy minimization functions with and without neighbor lists.
def run_minimization_while(
    energy_fn, R_init, shift, max_grad_thresh=1e-12, max_num_steps=1000000, **kwargs
):
    """Minimize with FIRE descent inside a ``lax.while_loop``.

    The loop stops once the largest absolute force component is no longer
    above ``max_grad_thresh`` or after ``max_num_steps`` iterations. Because
    the trip count is data dependent, this routine cannot be differentiated
    in reverse mode.

    Args:
        energy_fn: function handed (jitted) to ``minimize.fire_descent``; the
            callers in this notebook pass a force function.
        R_init: initial positions.
        shift: shift function of the simulation space.
        max_grad_thresh: convergence threshold on ``max |force|``.
        max_num_steps: hard cap on the number of FIRE steps.
        **kwargs: forwarded to ``minimize.fire_descent`` (e.g. ``dt_start``).

    Returns:
        ``(positions, max_abs_force, num_iterations + 1)``.
    """
    fire_init, fire_step = minimize.fire_descent(jit(energy_fn), shift, **kwargs)
    fire_step = jit(fire_step)

    def max_abs_force(fire_state):
        return jnp.max(jnp.abs(fire_state.force))

    def keep_going(carry):
        fire_state, step = carry
        not_converged = max_abs_force(fire_state) > max_grad_thresh
        return jnp.logical_and(not_converged, step < max_num_steps)

    def advance(carry):
        fire_state, step = carry
        return fire_step(fire_state), step + 1

    final_state, n_iter = lax.while_loop(keep_going, advance, (fire_init(R_init), 0))

    return final_state.position, max_abs_force(final_state), n_iter + 1


def run_minimization_while_nl(
    neighbor_fn,
    energy_fn,
    R_init,
    forced_rebuilding,
    shift,
    max_grad_thresh=1e-12,
    max_num_steps=1000000,
    **kwargs
):
    """Minimize with FIRE descent using a neighbor list, in a Python loop.

    The neighbor list is updated before every step. If the update overflows,
    the list is reallocated and the same step is retried. This loop is not
    jit-able because reallocation changes array shapes.

    Args:
        neighbor_fn: neighbor-list function with ``allocate`` and ``update``.
        energy_fn: function handed to ``minimize.fire_descent``; it is called
            with a ``neighbor`` keyword argument (callers here pass a force
            function).
        R_init: initial positions.
        forced_rebuilding: if True, reallocate the neighbor list once at
            step 10 to exercise the rebuilding path.
        shift: shift function of the simulation space.
        max_grad_thresh: convergence threshold on ``max |force|``.
        max_num_steps: hard cap on the number of accepted FIRE steps.
        **kwargs: forwarded to ``minimize.fire_descent`` (e.g. ``dt_start``).

    Returns:
        ``(positions, nbrs, steps + 1)``. Here ``nbrs`` has been updated at the
        returned positions, and ``steps`` counts accepted FIRE steps.
    """
    init_fn, apply_fn = minimize.fire_descent(energy_fn, shift, **kwargs)
    apply_fn = jit(apply_fn)

    nbrs = neighbor_fn.allocate(R_init)
    state = init_fn(R_init, neighbor=nbrs)

    @jit
    def get_maxgrad(state):
        return jnp.amax(jnp.abs(state.force))

    @jit
    def cond_fn(state, i):
        return jnp.logical_and(get_maxgrad(state) > max_grad_thresh, i < max_num_steps)

    @jit
    def update_nbrs(R, nbrs):
        return neighbor_fn.update(R, nbrs)

    steps = 0
    while cond_fn(state, steps):
        nbrs = update_nbrs(state.position, nbrs)
        new_state = apply_fn(state, neighbor=nbrs)
        # Check the list that was actually used to compute `new_state`. Checking
        # a freshly allocated list instead (as after a forced rebuild) would
        # hide an overflow and accept a step computed with missing neighbors.
        if nbrs.did_buffer_overflow:
            print("Rebuilding neighbor_list.")
            nbrs = neighbor_fn.allocate(state.position)
            continue  # retry the same step with the larger list
        if forced_rebuilding and steps == 10:
            print("Forced rebuilding of neighbor_list.")
            nbrs = neighbor_fn.allocate(state.position)
        state = new_state
        steps += 1

    # The list above was built before the last step; refresh it so callers
    # evaluating energies/forces at the returned positions use a matching list.
    nbrs = update_nbrs(state.position, nbrs)
    if nbrs.did_buffer_overflow:
        print("Rebuilding neighbor_list.")
        nbrs = neighbor_fn.allocate(state.position)

    return state.position, nbrs, steps + 1

# This version is fully differentiable and jit-able
def run_minimization_scan(force_fn, R_init, shift, num_steps=5000, **kwargs):
    """Run exactly ``num_steps`` FIRE steps with ``lax.scan``.

    The trip count is fixed, so reverse-mode differentiation works. The cost
    is that the backward pass has to store every step.

    Args:
        force_fn: force function handed (jitted) to ``minimize.fire_descent``.
        R_init: initial positions.
        shift: shift function of the simulation space.
        num_steps: number of FIRE steps to take.
        **kwargs: forwarded to ``minimize.fire_descent`` (e.g. ``dt_start``).

    Returns:
        ``(positions, max_abs_force)`` after ``num_steps`` steps.
    """
    init, apply = minimize.fire_descent(jit(force_fn), shift, **kwargs)
    apply = jit(apply)

    def scan_fn(state, _):
        # No per-step output is needed; returning None avoids stacking a
        # dummy array of length num_steps.
        return apply(state), None

    state = init(R_init)
    state, _ = lax.scan(scan_fn, state, xs=None, length=num_steps)
    return state.position, jnp.amax(jnp.abs(force_fn(state.position)))

# Meta Optimization

In this notebook we'll have another look at differentiating through an energy minimization routine. But this time we'll look at how a technique called implicit differentiation lets us do so much more efficiently than before. Let us first set up our system and see for ourselves what goes wrong when we aren't using implicit differentiation.

We'll work with a system of `N` soft spheres and we are interested in first computing the energy minimum of this system and then to compute the gradient of the energy with respect to the parameters `sigma` and/or `alpha`.

We'll start off with a very small and rather dilute system.

In [None]:
# Small, dilute 2D system of soft spheres in a periodic box.
N = 7
density = 0.3
dimension = 2

box_size = quantity.box_size_at_number_density(N, density, dimension)
displacement, shift = space.periodic(box_size)

# One split of PRNGKey(5); the first returned key seeds the positions.
key, split = random.split(random.PRNGKey(5))
R_init = random.uniform(key, (N, dimension), minval=0.0, maxval=box_size, dtype=f64)

sigma = jnp.full((N, N), 2.)
alpha = jnp.full((N, N), 2.)
param_dict = {'sigma': sigma, 'alpha': alpha}

Note that we are using a scan for the `explicit` function in order to use `jax.grad`. Here the number of steps is fixed, and we need to choose it large enough to reach the energy minimum!

In [None]:
def explicit_diff(params, R_init, displacement, num_steps):
    """Energy and max |force| after a fixed-length FIRE minimization.

    The minimization is a ``lax.scan``, so ``jax.grad`` differentiates
    through every one of the ``num_steps`` steps.

    Returns:
        ``(energy, max_abs_force)`` at the final positions.
    """
    energy_fn = energy.soft_sphere_pair(displacement, **params)
    force_fn = jit(quantity.force(energy_fn))

    # A scan (rather than a while loop) is required for reverse-mode jax.grad.
    R_final, _ = run_minimization_scan(
        force_fn, R_init, shift, num_steps=num_steps, dt_start=0.001, dt_max=0.005
    )

    max_force = jnp.amax(jnp.abs(force_fn(R_final)))
    return energy_fn(R_final), max_force

(exp_e, exp_f), exp_g = jax.value_and_grad(explicit_diff, has_aux=True)(
    param_dict, R_init, displacement, 19400
)
print('Energy        : ', exp_e)
print('Max_grad_force: ', exp_f)
print('Gradient of the energy:')
for param_name in ('sigma', 'alpha'):
    print(exp_g[param_name][0])

This being plain `jax` code we can easily compute gradients with respect to different parameters in one go.

Now what's the problem here?\
The answer to that is quite simple and easily demonstrated by working with a larger system and/or a system where we need to take more optimization steps to compute the energy minimum.

In [None]:
# A larger (N = 55) and denser (0.5) system than before.
N = 55
density = 0.5
dimension = 2

box_size = quantity.box_size_at_number_density(N, density, dimension)
displacement, shift = space.periodic(box_size)

# Same PRNG handling as for the small system.
key, split = random.split(random.PRNGKey(5))
R_init = random.uniform(key, (N, dimension), minval=0.0, maxval=box_size, dtype=f64)

sigma = jnp.full((N, N), 2.0)
alpha = jnp.full((N, N), 2.)
param_dict = {'sigma': sigma, 'alpha': alpha}

For this system we now have to take more than eight times as many steps (163954 instead of 19400).

In [None]:
# Do not run this cell locally! This cell realiably crashes my mac.
(exp_e_large, exp_f_large), exp_g_large = jax.value_and_grad(
    explicit_diff, has_aux=True
)(param_dict, R_init, displacement, 163954)
print('Energy        : ', exp_e_large)
print('Max_grad_force: ', exp_f_large)
print('Gradient of the energy:')
for param_name in ('sigma', 'alpha'):
    print(exp_g_large[param_name][0])

As you can see our computation fails due to running out of memory when trying to compute the gradient. The reason for that is that for reverse mode differentiation (e.g. ```jax.grad```) the memory consumption grows linearly with respect to the number of optimization steps we have to take since reverse mode differentiation needs to store the whole forward pass in order to compute the gradient in the backwards pass.

We could lessen the memory requirements by using a technique called [gradient rematerialization/checkpointing](https://github.com/google/jax/blob/f3c4ae3d8918de3a35cec74fdf24231a77ef0e92/jax/_src/api.py#L2806) for a corresponding increase in computation time. While this strategy works, there exists a better solution for our problem called implicit differentiation.

# Implicit Differentiation

Implicit differentiation gets its name from the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem). This theorem roughly states that when we want to differentiate through a root finding procedure, e.g. find $z$ such that $F(a,z) = 0$, it does not matter how we computed the solution. Instead, we can directly differentiate the solution $z^*$ of our root finding problem using the following formula. Here $\partial_i$ means we differentiate with respect to the $i$'th argument (counting from $0$), and $z_0 = z^*(a_0)$ is the root at the parameters $a_0$.

$$
\partial z^*(a_0) = -\left[\partial_1 F(a_0,z_0)\right]^{-1} \partial_0 F(a_0,z_0)
$$

This expression can be efficiently solved. If you are interested in a derivation and how one could implement this in `jax` then I can highly recommend chapter $2$ of the NeurIPS 2020 [Deep Implicit Layers](https://implicit-layers-tutorial.org/) tutorial.

In order to use implicit differentiation we first need to rewrite our optimization problem as a root finding procedure. This is easily done since the force is $0$ at the energy minimum.

I'll first define a new function that uses implicit differentiation and then I'll highlight the differences.

In [None]:
# Back to the small system so explicit and implicit gradients can be compared.
N = 7
density = 0.3
dimension = 2

box_size = quantity.box_size_at_number_density(N, density, dimension)
displacement, shift = space.periodic(box_size)

# Same key handling as the first small-system cell, so R_init is identical.
key, split = random.split(random.PRNGKey(5))
R_init = random.uniform(key, (N, dimension), minval=0.0, maxval=box_size, dtype=f64)

sigma = jnp.full((N, N), 2.0)
alpha = jnp.full((N, N), 2.)
param_dict = {'sigma': sigma, 'alpha': alpha}

In [None]:
def implicit_diff(params,R_init,displacement):
    """Energy and max |force| at the FIRE minimum, differentiated implicitly.

    The minimization runs in a ``lax.while_loop`` and is never traced for
    reverse mode. Instead, ``jaxopt.implicit_diff.custom_root`` provides the
    derivative of the minimizer ``R*(params)``. It uses the optimality
    condition ``F(R*, params) = 0`` and the implicit function theorem.

    Args:
        params: dict of soft-sphere parameters (here ``sigma`` and ``alpha``).
        R_init: initial positions.
        displacement: displacement function of the simulation space.

    Returns:
        ``(energy, max_abs_force)`` at the minimized positions.
    """
    energy_fn = energy.soft_sphere_pair(displacement, **params)
    force_fn = jit(quantity.force(energy_fn))

    # Optimality condition with the differentiated parameters as an explicit
    # argument. custom_root evaluates it as optimality_fun(sol, *args).
    explicit_force_fn = jit(lambda R, p: force_fn(R, **p))

    def solver(x, p):
        # custom_root calls solver(init_params, *args). Here x is the initial
        # positions and p is the parameter dict we differentiate with respect to.
        # The forward solve itself must not be differentiated, so the force is
        # wrapped in lax.stop_gradient (this prevents a CustomVJPException).
        no_grad_force_fn = jit(lambda R: lax.stop_gradient(explicit_force_fn(R, p)))
        return run_minimization_while(no_grad_force_fn, x, shift, dt_start=0.001, dt_max=0.005)[0]

    decorated_solver = jaxopt.implicit_diff.custom_root(explicit_force_fn)(solver)

    # params must be passed explicitly. Otherwise custom_root has nothing to
    # differentiate with respect to, and its implicit VJP is never used.
    R_final = decorated_solver(R_init, params)

    # Here we can just use our original energy_fn/force_fn
    return energy_fn(R_final), jnp.amax(jnp.abs(force_fn(R_final)))

(imp_e, imp_f), imp_g = jax.value_and_grad(implicit_diff, has_aux=True)(
    param_dict, R_init, displacement
)
print('Energy        : ', imp_e)
print('Max_grad_force: ', imp_f)
print('Gradient of the energy:')
for param_name in ('sigma', 'alpha'):
    print(imp_g[param_name][0])

As you can see, we do not have to change anything compared to our explicit version when we want to compute gradients with respect to more than one parameter.

In [None]:
# As you can see we get the same results.
# jax.tree_util.tree_map is used because the jax.tree_map alias has been
# removed from recent JAX releases.
print(jax.tree_util.tree_map(jnp.allclose, exp_e, imp_e))
print(jax.tree_util.tree_map(jnp.allclose, exp_f, imp_f))
print(jax.tree_util.tree_map(jnp.allclose, exp_g, imp_g))

We use [jaxopt](https://github.com/google/jaxopt) to define implicit gradients for our energy minimization routine. To this end we need to define two new force functions from our original force_fn. This is necessary for two separate reasons.

For one, [jaxopt.implicit_diff.custom_root](https://jaxopt.github.io/stable/_autosummary/jaxopt.implicit_diff.custom_root.html#jaxopt.implicit_diff.custom_root) requires a function that takes the parameters we want to differentiate with respect to as an explicit input. To this end we define a new explicit_force_fn as
```python 
explicit_force_fn = jit(lambda R, p: force_fn(R,**p))
```

Furthermore we cannot pass our original force_fn to our solver as this will cause a ```CustomVJPException```. We can just wrap the output of our ```force_fn``` with ```lax.stop_gradient``` in order to fix this issue.
```python
no_grad_force_fn = jit(lambda x: lax.stop_gradient(force_fn(x)))
```

Having done that we can define our ```solver``` which takes two parameters as its input. A dummy variable ```params```, which we can just delete as we do not need it and a variable ```x``` which is our set of initial positions. It is necessary that we pass our newly defined ```no_grad_force_fn``` to our solver.

Now we can put it all together and define a ```decorated_solver``` using ```jaxopt``` in order to be able to efficiently differentiate through it:
```python
decorated_solver = implicit_diff.custom_root(explicit_force_fn)(solver)
```
Here we have to use our ```explicit_force_fn```. It is also possible to use a different linear solver for the ```vjp``` computation via the ```solve``` argument of `implicit_diff.custom_root`. JAX comes with a handful of iterative linear solvers in ```jax.scipy.sparse.linalg```.

# Adding neighbor lists and other auxiliary output to our solver

Until now we have only worked with solvers that return a single value, `R_final`. Now we'll see how we have to change our `implicit_diff` function so that our solver can also return a neighbor list `nbrs` and the number of steps `num_steps` needed to reach the minimum.

There is only one thing we have to change besides the normal stuff we have to change when using neighbor lists. We simply add the ```has_aux=True``` keyword to ```implicit_diff.custom_root```.

There's one more curious thing to note about the neighbor list version. If we forget to wrap our ```force_fn``` with ```lax.stop_gradient```, we do not get a ```CustomVJPException```. Instead there is no explicit error message at all, but we get a memory leak, which drastically slows down the computation. So pay attention that you are passing the correct version of your ```force_fn``` to your ```solver```!

In [None]:
# We add the parameter forced_rebuilding to make sure
# that it is possible to construct new neighbor lists during the optimization.

def implicit_diff_nl(params,R_init,box_size,forced_rebuilding):
    """Neighbor-list version of ``implicit_diff``.

    The optimality condition ``F(R, p) = 0`` needs a neighbor list. The list is
    only known after the minimization, so the non-differentiated solve runs
    first. ``custom_root`` then attaches the implicit VJP at the found root;
    by the implicit function theorem it does not matter how the root was found.

    Args:
        params: dict of soft-sphere parameters (here ``sigma`` and ``alpha``).
        R_init: initial positions.
        box_size: side length of the periodic box.
        forced_rebuilding: forwarded to ``run_minimization_while_nl``.

    Returns:
        ``(energy, max_abs_force)`` at the minimized positions.
    """
    neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(displacement, box_size, **params)

    force_fn = jit(quantity.force(energy_fn))

    # wrap force_fn with a lax.stop_gradient so the forward solve does not
    # track derivatives through every step (otherwise memory keeps growing)
    no_grad_force_fn = jit(lambda x, neighbor: lax.stop_gradient(force_fn(x, neighbor)))

    # 1) Non-differentiated solve; also yields the final neighbor list.
    R_min, nbrs, num_steps = run_minimization_while_nl(
        neighbor_fn, no_grad_force_fn, R_init, forced_rebuilding, shift,
        dt_start=0.001, dt_max=0.005)

    # 2) Optimality condition on the final neighbor list, with the parameters
    #    we differentiate with respect to as an explicit argument.
    explicit_force_fn = jit(lambda R, p: force_fn(R, nbrs, **p))

    def solver(x, p):
        # x is already the root. p is only consumed by custom_root for the VJP.
        del p
        return x, nbrs, num_steps

    # Need to use has_aux=True because the solver also returns nbrs and num_steps
    decorated_solver = jaxopt.implicit_diff.custom_root(explicit_force_fn, has_aux=True)(solver)

    # params must be passed explicitly so custom_root differentiates w.r.t. them
    R_final, nbrs, _ = decorated_solver(lax.stop_gradient(R_min), params)

    # Here we can just use our original energy_fn/force_fn
    return energy_fn(R_final, neighbor=nbrs), jnp.amax(jnp.abs(force_fn(R_final, neighbor=nbrs)))

(imp_nl_e, imp_nl_f), imp_nl_g = jax.value_and_grad(implicit_diff_nl, has_aux=True)(
    param_dict, R_init, box_size, True
)
print('Energy        : ', imp_nl_e)
print('Max_grad_force: ', imp_nl_f)
print('Gradient of the energy:')
for param_name in ('sigma', 'alpha'):
    print(imp_nl_g[param_name][0])

In [None]:
# We again get the same results, this time for the neighbor-list version.
# (Compare against imp_nl_*; comparing imp_* again would only repeat the
# previous cell.)
print(jax.tree_util.tree_map(jnp.allclose, exp_e, imp_nl_e))
print(jax.tree_util.tree_map(jnp.allclose, exp_f, imp_nl_f))
print(jax.tree_util.tree_map(jnp.allclose, exp_g, imp_nl_g))