[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/implicit_differentiation.ipynb)

# Implicit Differentiation

This cookbook was contributed by Maxi Lechner.

In [1]:
#@title Import & Util
!pip install -q git+https://www.github.com/google/jax-md
!pip install -q jaxopt

import jax
import jax.numpy as jnp
from jax.config import config

config.update("jax_enable_x64", True)

from jax import random, jit, lax

from jax_md import space, energy, minimize, quantity

from jaxopt.implicit_diff import custom_root

f32 = jnp.float32
f64 = jnp.float64

# Energy minimization using a while loop.
def run_minimization_while(
  energy_fn, R_init, shift, max_grad_thresh=1e-12, max_steps=1000000, **kwargs
):
  init, apply = minimize.fire_descent(
    jit(energy_fn), shift, dt_start=0.001, dt_max=0.005, **kwargs
  )
  apply = jit(apply)

  @jit
  def get_maxgrad(state):
    return jnp.amax(jnp.abs(state.force))

  @jit
  def cond_fn(val):
    state, i = val
    return (get_maxgrad(state) > max_grad_thresh) & (i < max_steps)

  @jit
  def body_fn(val):
    state, i = val
    return apply(state), i + 1

  state = init(R_init)
  state, num_iterations = lax.while_loop(cond_fn, body_fn, (state, 0))

  return state.position, get_maxgrad(state), num_iterations + 1


# Energy minimization using both a while loop and neighbor lists.
# We add the parameter forced_rebuilding to make sure
# that it is possible to construct new neighbor lists during the optimization.
def run_minimization_while_nl(
  neighbor_fn,
  energy_fn,
  R_init,
  shift,
  forced_rebuilding=True,
  max_grad_thresh=1e-12,
  max_num_steps=1000000,
  **kwargs
):
  init_fn, apply_fn = minimize.fire_descent(
    energy_fn, shift, dt_start=0.001, dt_max=0.005, **kwargs
  )
  apply_fn = jit(apply_fn)

  nbrs = neighbor_fn.allocate(R_init)
  state = init_fn(R_init, neighbor=nbrs)

  @jit
  def get_maxgrad(state):
    return jnp.amax(jnp.abs(state.force))

  @jit
  def cond_fn(state, i):
    return (get_maxgrad(state) > max_grad_thresh) & (i < max_num_steps)

  @jit
  def update_nbrs(R, nbrs):
    return neighbor_fn.update(R, nbrs)

  steps = 0
  while cond_fn(state, steps):
    nbrs = update_nbrs(state.position, nbrs)
    new_state = apply_fn(state, neighbor=nbrs)
    if forced_rebuilding and steps == 10:
      print("Forced rebuilding of neighbor_list.")
      nbrs = neighbor_fn.allocate(state.position)
    if nbrs.did_buffer_overflow:
      print("Rebuilding neighbor_list.")
      nbrs = neighbor_fn.allocate(state.position)
    else:
      state = new_state
      steps += 1

  return state.position, nbrs, steps + 1


# jax.grad compatible version using a scan.
def run_minimization_scan(force_fn, R_init, shift, steps=5000, **kwargs):
  init, apply = minimize.fire_descent(
    jit(force_fn), shift, dt_start=0.001, dt_max=0.005, **kwargs
  )
  apply = jit(apply)

  @jit
  def scan_fn(state, i):
    return apply(state), 0.0

  state = init(R_init)
  state, _ = lax.scan(scan_fn, state, jnp.arange(steps))
  return state.position, jnp.amax(jnp.abs(force_fn(state.position)))



# Meta Optimization

In this notebook we'll have another look at differentiating through an energy minimization routine. But this time we'll look at how a technique called implicit differentiation lets us do so much more efficiently than before. We will use the great [jaxopt](https://github.com/google/jaxopt) package for that purpose.

Let us first set up our system and see for ourselves what goes wrong when we aren't using implicit differentiation.

We'll work with a system of `N` soft spheres. We are interested in first computing the energy minimum of this system and then computing the gradient of the energy with respect to the parameters `sigma` and/or `alpha`. We take both to be matrices, so that `sigma` and `alpha` can take on a different value for every pair of particles.

We'll start off with a very small and rather dilute system.


In [2]:
N = 7
density = 0.3

dimension = 2
box_size = quantity.box_size_at_number_density(N, density, dimension)
displacement, shift = space.periodic(box_size)

key = random.PRNGKey(5)
key, split = random.split(key)
R_init = random.uniform(key, (N, dimension), maxval=box_size, dtype=f64)

sigma = jnp.full((N, N), 2.0)
alpha = jnp.full((N, N), 2.0)
param_dict = {"sigma": sigma, "alpha": alpha}

Note that we are using a scan for the `explicit_diff` function in order to use `jax.grad`. Here the number of steps is fixed in advance, so we need to select a large enough number of steps in order to reach the energy minimum!
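
To see why the scan matters, here is a minimal sketch on a toy problem. The quadratic `toy_energy` and the plain gradient descent loops are hypothetical stand-ins for the soft sphere energy and FIRE, chosen only to keep the example self-contained. Reverse mode differentiation goes through `lax.scan`, while the same loop written with `lax.while_loop` is expected to raise an error as soon as we ask for a gradient.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical toy energy with its minimum at z = a.
def toy_energy(a, z):
  return 0.5 * jnp.sum((z - a) ** 2)

toy_grad = jax.grad(toy_energy, argnums=1)  # dE/dz


def minimize_scan(a, z0, num_steps=200, lr=0.1):
  # A fixed number of plain gradient descent steps.
  def step(z, _):
    return z - lr * toy_grad(a, z), None

  z, _ = lax.scan(step, z0, None, length=num_steps)
  return z


def minimize_while(a, z0, tol=1e-4, lr=0.1):
  # Iterate until the largest gradient component drops below tol.
  def cond(z):
    return jnp.max(jnp.abs(toy_grad(a, z))) > tol

  def body(z):
    return z - lr * toy_grad(a, z)

  return lax.while_loop(cond, body, z0)


a = jnp.array([1.0, 2.0, 3.0])
z0 = jnp.zeros(3)

# Reverse mode works through the scan.
grad_scan = jax.grad(lambda a: jnp.sum(minimize_scan(a, z0)))(a)
print(grad_scan)

# Reverse mode through the while loop is expected to fail.
try:
  jax.grad(lambda a: jnp.sum(minimize_while(a, z0)))(a)
  while_loop_differentiable = True
except Exception as err:
  while_loop_differentiable = False
  print(type(err).__name__)
```

Since the minimum of the toy energy sits at `z = a`, every entry of `grad_scan` should be very close to `1`.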

In [3]:
def explicit_diff(params, R_init, displacement, num_steps):
  energy_fn = energy.soft_sphere_pair(displacement, **params)

  force_fn = quantity.force(energy_fn)
  force_fn = jit(force_fn)

  # We need to use a scan instead of a while loop in order to use jax.grad.
  solver = lambda f, x: run_minimization_scan(f, x, shift, steps=num_steps)[0]
  R_final = solver(force_fn, R_init)

  return energy_fn(R_final), jnp.amax(jnp.abs(force_fn(R_final)))


(exp_e, exp_f), exp_g = jax.value_and_grad(explicit_diff, has_aux=True)(
  param_dict, R_init, displacement, 19400
)
print("Energy        : ", exp_e)
print("Max_grad_force: ", exp_f)
print("Gradient of the energy:")
print(exp_g["sigma"][0])
print(exp_g["alpha"][0])

Energy        :  0.0615943027340254
Max_grad_force:  9.886709506634617e-13
Gradient of the energy:
[0.         0.00297984 0.00736627 0.015354   0.         0.015354
 0.00297984]
[ 0.         -0.00017895 -0.00092201 -0.00348099  0.         -0.00348099
 -0.00017895]


This being plain `jax` code we can easily compute gradients with respect to a whole dictionary of parameters in one go.

Now what's the problem here?\
The answer to that is quite simple and easily demonstrated by increasing the system size and/or by working with a system where we need to take more optimization steps to reach the energy minimum.

In [4]:
# Increase the number of particles
N = 55
# and the density.
density = 0.5

dimension = 2
box_size = quantity.box_size_at_number_density(N, density, dimension)
displacement, shift = space.periodic(box_size)

key = random.PRNGKey(5)
key, split = random.split(key)
R_init = random.uniform(key, (N, dimension), maxval=box_size, dtype=f64)

sigma = jnp.full((N, N), 2.0)
alpha = jnp.full((N, N), 2.0)
param_dict = {"sigma": sigma, "alpha": alpha}

For this system we now have to take more than 8 times as many steps (163954 instead of 19400).

In [5]:
# Do not run this cell locally! This cell reliably crashes my mac.
(exp_e_large, exp_f_large), exp_g_large = jax.value_and_grad(
  explicit_diff, has_aux=True
)(param_dict, R_init, displacement, 163954)
print("Energy        : ", exp_e_large)
print("Max_grad_force: ", exp_f_large)
print("Gradient of the energy:")
print(exp_g_large["sigma"][0])
print(exp_g_large["alpha"][0])

RuntimeError: ignored

As you can see our computation fails due to running out of memory when trying to compute the gradient. The reason for that is that for reverse mode differentiation (e.g. ```jax.grad```) the memory consumption grows linearly with respect to the number of optimization steps we have to take since reverse mode differentiation needs to store the whole forward pass in order to compute the gradient in the backwards pass.

We could reduce the memory requirements by using a technique called [gradient rematerialization/checkpointing](https://github.com/google/jax/blob/f3c4ae3d8918de3a35cec74fdf24231a77ef0e92/jax/_src/api.py#L2806) for a corresponding increase in computation time. While this strategy works, there exists a better solution for our problem called implicit differentiation.
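
As a rough sketch of what rematerialization looks like in code, we can wrap the update of a single optimization step in `jax.checkpoint`. The toy energy, the step size and the number of steps below are hypothetical and only serve to make the example runnable. With this wrapping JAX should only keep the inputs of each step and recompute everything inside the step during the backward pass, so the memory still grows with the number of steps, just with a smaller constant. The gradient itself is unchanged.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical toy energy with its minimum at z = a.
def toy_energy(a, z):
  return 0.25 * jnp.sum((z - a) ** 4) + 0.5 * jnp.sum((z - a) ** 2)

toy_grad = jax.grad(toy_energy, argnums=1)  # dE/dz


def minimize(a, z0, num_steps, lr, remat):
  def update(a, z):
    # One gradient descent step.
    return z - lr * toy_grad(a, z)

  if remat:
    # Recompute the intermediates of each step in the backward pass
    # instead of storing them during the forward pass.
    update = jax.checkpoint(update)

  def step(z, _):
    return update(a, z), None

  z, _ = lax.scan(step, z0, None, length=num_steps)
  return z


def loss(a, remat):
  z0 = jnp.zeros_like(a)
  return jnp.sum(minimize(a, z0, 50, 0.05, remat) ** 2)


a = jnp.array([1.0, -0.5, 2.0])
g_plain = jax.grad(loss)(a, False)
g_remat = jax.grad(loss)(a, True)
print(jnp.allclose(g_plain, g_remat))
```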

# Implicit Differentiation

Implicit differentiation gets its name from the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem). This theorem roughly states that when we want to differentiate through a root finding procedure, e.g. find $z$ such that $F(a,z) = 0$, it does not matter how we computed the solution. Instead it is possible to directly differentiate through the solution $z^*$ of our root finding problem using the following formula. Here $\partial_i$ means we differentiate with respect to the $i$'th argument, and $z_0 = z^*(a_0)$ is the solution at the parameter value $a_0$.

$$
\partial z^*(a_0) = -[\partial_1 F(a_0,z_0)]^{-1} \partial_0 F(a_0,z_0)
$$

This expression can be efficiently solved. If you are interested in a derivation and how one could implement this in `jax` then I can highly recommend chapter $2$ of the NeurIPS 2020 [Deep Implicit Layers](https://implicit-layers-tutorial.org/) tutorial.
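
To make the formula concrete, here is a small numerical check on a scalar toy problem. The residual `F(a, z) = z**3 + z - a` and the Newton solver are hypothetical choices made for illustration, they do not appear anywhere else in this notebook. We evaluate the formula at the solution and compare it with the gradient obtained by differentiating through all the Newton iterations.

```python
import jax

# Hypothetical residual: the root z*(a) satisfies z**3 + z = a.
def F(a, z):
  return z ** 3 + z - a


def newton_solve(a, z_init=0.0, num_steps=30):
  # Plain Newton iterations, unrolled in Python.
  z = z_init
  for _ in range(num_steps):
    z = z - F(a, z) / jax.grad(F, argnums=1)(a, z)
  return z


a0 = 2.0
z0 = newton_solve(a0)  # z*(2) = 1, since 1**3 + 1 = 2

# Implicit function theorem: dz*/da = -[d_1 F]^{-1} d_0 F at (a0, z0).
dF_da = jax.grad(F, argnums=0)(a0, z0)
dF_dz = jax.grad(F, argnums=1)(a0, z0)
implicit_grad = -dF_da / dF_dz

# Reference: differentiate through every Newton iteration.
unrolled_grad = jax.grad(newton_solve)(a0)

print(z0, implicit_grad, unrolled_grad)
```

For this toy problem $\partial_0 F = -1$ and $\partial_1 F = 3 z^2 + 1$, so both gradients should be close to $1/4$. Note that the implicit gradient only needs the solution `z0` and never looks at how the solver got there.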

Alternatively we can also phrase this as a fixed point equation `g(a,z) = z`. Since the force is $0$ at the minimum, the minimum is a fixed point of gradient descent, i.e. after having reached the energy minimum $z^*$ any additional gradient descent step will just return $z^*$.
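
The fixed point view is easy to check numerically. In the sketch below the toy energy and the step size are hypothetical, and `g` is one step of plain gradient descent rather than the FIRE update used in this notebook. Once we have reached the minimum, applying `g` once more should leave the position unchanged up to floating point error.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy energy with its minimum at z = a.
def toy_energy(a, z):
  return 0.5 * jnp.sum((z - a) ** 2) + 0.25 * jnp.sum((z - a) ** 4)

toy_grad = jax.grad(toy_energy, argnums=1)  # dE/dz


def g(a, z, lr=0.1):
  # One gradient descent step.
  return z - lr * toy_grad(a, z)


a = jnp.array([0.5, -1.0])
z = jnp.zeros(2)
for _ in range(200):
  z = g(a, z)
z_star = z

# At the minimum the gradient vanishes, so g(a, z*) = z*.
fixed_point_residual = jnp.max(jnp.abs(g(a, z_star) - z_star))
grad_at_minimum = jnp.max(jnp.abs(toy_grad(a, z_star)))
print(fixed_point_residual, grad_at_minimum)
```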

In order to use implicit differentiation we first need to rewrite our optimization problem as a root finding procedure. This is easily done since the force is $0$ at the energy minimum. 

I'll first define a new function that uses implicit differentiation and then I'll highlight the differences.

In [6]:
# Let's go back to our small system in order
# to compare explicit and implicit differentiation.
N = 7
density = 0.3

dimension = 2
box_size = quantity.box_size_at_number_density(N, density, dimension)
displacement, shift = space.periodic(box_size)

key = random.PRNGKey(5)
key, split = random.split(key)
R_init = random.uniform(key, (N, dimension), maxval=box_size, dtype=f64)

sigma = jnp.full((N, N), 2.0)
alpha = jnp.full((N, N), 2.0)
param_dict = {"sigma": sigma, "alpha": alpha}

In [7]:
def implicit_diff(params, R_init, displacement):
  energy_fn = energy.soft_sphere_pair(displacement, **params)
  force_fn = jit(quantity.force(energy_fn))

  # Wrap force_fn with a lax.stop_gradient to prevent a CustomVJPException.
  no_grad_force_fn = jit(lambda x: lax.stop_gradient(force_fn(x)))

  # Make the dependence on the variables we want to differentiate explicit.
  explicit_force_fn = jit(lambda R, p: force_fn(R, **p))

  def solver(params, x):
    # params are unused
    del params
    # need to use no_grad_force_fn!
    return run_minimization_while(no_grad_force_fn, x, shift)[0]

  decorated_solver = custom_root(explicit_force_fn)(solver)

  R_final = decorated_solver(None, R_init)

  # Here we can just use our original energy_fn/force_fn.
  return energy_fn(R_final), jnp.amax(jnp.abs(force_fn(R_final)))


(imp_e, imp_f), imp_g = jax.value_and_grad(implicit_diff, has_aux=True)(
  param_dict, R_init, displacement
)
print("Energy        : ", imp_e)
print("Max_grad_force: ", imp_f)
print("Gradient of the energy:")
print(imp_g["sigma"][0])
print(imp_g["alpha"][0])

Energy        :  0.0615943027340254
Max_grad_force:  9.964598590705975e-13
Gradient of the energy:
[ 0.          0.00297984  0.00736627  0.015354   -0.          0.015354
  0.00297984]
[ 0.         -0.00017895 -0.00092201 -0.00348099  0.         -0.00348099
 -0.00017895]


In [8]:
# Implicit and explicit differentiation gives the same result.
print(jax.tree_map(jnp.allclose,exp_e,imp_e))
print(jax.tree_map(jnp.allclose,exp_f,imp_f))
print(jax.tree_map(jnp.allclose,exp_g,imp_g))

True
True
{'alpha': DeviceArray(True, dtype=bool), 'sigma': DeviceArray(True, dtype=bool)}


Let us now look at the changes we had to make in order to use implicit differentiation.

We use [jaxopt](https://github.com/google/jaxopt) to define implicit gradients for our energy minimization routine. To this end we need to define two new force functions from our original `force_fn`. This is necessary for two separate reasons.

First, [jaxopt.implicit_diff.custom_root](https://jaxopt.github.io/stable/_autosummary/jaxopt.implicit_diff.custom_root.html#jaxopt.implicit_diff.custom_root) requires a function that takes the parameters that we want to differentiate with respect to as an explicit input. To this end we define a new `explicit_force_fn` as
```python 
explicit_force_fn = jit(lambda R, p: force_fn(R,**p))
```

Furthermore we cannot pass our original force_fn to our solver as this will cause a ```CustomVJPException```. We can just wrap the output of our ```force_fn``` with ```lax.stop_gradient``` in order to fix this issue.
```python
no_grad_force_fn = jit(lambda x: lax.stop_gradient(force_fn(x)))
```

Having done that we can define our ```solver```, which takes two inputs: a dummy variable ```params```, which we can just delete as we do not need it, and a variable ```x```, which is our set of initial positions. It is necessary that we pass our newly defined ```no_grad_force_fn``` to our solver.

Now we can put it all together and define a ```decorated_solver``` using ```jaxopt``` in order to be able to efficiently differentiate through it:
```python
decorated_solver = custom_root(explicit_force_fn)(solver)
```
Here we have to use our ```explicit_force_fn```. It is also possible to use a different linear solver for the ```vjp``` computation using the ```solve``` argument of `custom_root`. JAX comes with a handful of sparse linear solvers in ```jax.scipy.sparse.linalg```.
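
To get a feeling for the linear system that such a solver has to handle, here is a hand-rolled sketch of the ```vjp``` computation using ```jax.scipy.sparse.linalg.cg```. It does not call jaxopt at all. The matrix `A`, the residual `F` and the loss `sum(z*)` are hypothetical choices for a toy problem in which the answer is known in closed form. Conjugate gradients is a reasonable choice here because the Jacobian of `F` with respect to `z` is the Hessian of an energy, which is symmetric and, at a strict minimum, positive definite.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical symmetric positive definite matrix defining the toy energy
# E(a, z) = 0.5 * z^T A z - a^T z.
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])


def F(a, z):
  # Optimality condition dE/dz = 0 of the toy energy.
  return A @ z - a


a0 = jnp.array([1.0, 2.0, 3.0])
z_star = jnp.linalg.solve(A, a0)  # stand-in for the energy minimizer

# Cotangent of the loss sum(z*) with respect to z*.
v = jnp.ones(3)

# VJPs of F with respect to each argument, evaluated at the solution.
_, vjp_a = jax.vjp(lambda a: F(a, z_star), a0)
_, vjp_z = jax.vjp(lambda z: F(a0, z), z_star)

# Solve (d_1 F)^T u = v without ever forming the Jacobian.
u, _ = cg(lambda x: vjp_z(x)[0], v)

# Gradient of the loss with respect to a: -(d_0 F)^T u.
implicit_grad = -vjp_a(u)[0]

# Reference: differentiate directly through the closed form solution.
reference = jax.grad(lambda a: jnp.sum(jnp.linalg.solve(A, a)))(a0)
print(implicit_grad, reference)
```

The only access to the Jacobian is through the matrix-vector product `vjp_z`, which is why matrix-free iterative solvers are a natural fit for this step.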

# A small simplification

We can slightly simplify our code as constructing the `explicit_force_fn` is in fact unnecessary! While jax requires this explicit version, jaxopt is in fact able to construct it automatically. This should work for all energy functions in jax-md as long as they and the solver do not take a catch-all `**kwargs` parameter.

In [9]:
def implicit_diff(params, R_init, displacement):
  energy_fn = energy.soft_sphere_pair(displacement, **params)
  force_fn = jit(quantity.force(energy_fn))

  # Wrap force_fn with a lax.stop_gradient to prevent a CustomVJPException.
  no_grad_force_fn = jit(lambda x: lax.stop_gradient(force_fn(x)))


  def solver(params, x):
    # params are unused
    del params
    # need to use no_grad_force_fn!
    return run_minimization_while(no_grad_force_fn, x, shift)[0]

  decorated_solver = custom_root(force_fn)(solver)

  R_final = decorated_solver(None, R_init)

  # Here we can just use our original energy_fn/force_fn.
  return energy_fn(R_final), jnp.amax(jnp.abs(force_fn(R_final)))


(imp_e, imp_f), imp_g = jax.value_and_grad(implicit_diff, has_aux=True)(
  param_dict, R_init, displacement
)
print("Energy        : ", imp_e)
print("Max_grad_force: ", imp_f)
print("Gradient of the energy:")
print(imp_g["sigma"][0])
print(imp_g["alpha"][0])

Energy        :  0.0615943027340254
Max_grad_force:  9.964598590705975e-13
Gradient of the energy:
[ 0.          0.00297984  0.00736627  0.015354   -0.          0.015354
  0.00297984]
[ 0.         -0.00017895 -0.00092201 -0.00348099  0.         -0.00348099
 -0.00017895]


In [10]:
# Implicit and explicit differentiation gives the same result.
print(jax.tree_map(jnp.allclose,exp_e,imp_e))
print(jax.tree_map(jnp.allclose,exp_f,imp_f))
print(jax.tree_map(jnp.allclose,exp_g,imp_g))

True
True
{'alpha': DeviceArray(True, dtype=bool), 'sigma': DeviceArray(True, dtype=bool)}


# Adding neighbor lists and other auxiliary output to our solver

Until now we have only worked with solvers that return one parameter `R_final`. Now we'll see how we have to change our `implicit_diff` function in order to be able to also have our solver return a neighbor list `nbrs` and the number of steps `num_steps` needed to reach the minimum.

There is only one thing we have to change besides the normal stuff we have to change when using neighbor lists. We simply pass the ```has_aux=True``` keyword to ```custom_root```.

There's one more funny thing to note about the neighbor list version. Instead of getting a ```CustomVJPException``` when we forget to wrap our ```force_fn``` with ```lax.stop_gradient```, we get no explicit error message. We get a memory leak instead, which drastically slows down the computation. So make sure that you are passing the correct version of your ```force_fn``` to your ```solver```!

In [11]:
def implicit_diff_nl(params, R_init, box_size):
  neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
    displacement, box_size, **params
  )

  force_fn = jit(quantity.force(energy_fn))

  # wrap force_fn with a lax.stop_gradient to prevent a memory leak
  no_grad_force_fn = jit(lambda x, neighbor: lax.stop_gradient(force_fn(x, neighbor)))


  def solver(params, x):
    # params are unused
    del params
    # need to use no_grad_force_fn!
    return run_minimization_while_nl(neighbor_fn, no_grad_force_fn, x, shift)

  # We need to use has_aux=True.
  decorated_solver = custom_root(force_fn, has_aux=True)(solver)

  R_final, nbrs, num_steps = decorated_solver(None, R_init)

  # Here we can just use our original energy_fn/force_fn.
  return (
    energy_fn(R_final, neighbor=nbrs),
    jnp.amax(jnp.abs(force_fn(R_final, neighbor=nbrs))),
  )


out_nl = jax.value_and_grad(implicit_diff_nl, has_aux=True)(
  param_dict, R_init, box_size
)
(imp_nl_e, imp_nl_f), imp_nl_g = out_nl
print("Energy        : ", imp_nl_e)
print("Max_grad_force: ", imp_nl_f)
print("Gradient of the energy:")
print(imp_nl_g["sigma"][0])
print(imp_nl_g["alpha"][0])

Forced rebuilding of neighbor_list.
Rebuilding neighbor_list.
Energy        :  0.06159430273402543
Max_grad_force:  9.96494553540117e-13
Gradient of the energy:
[0.         0.00595969 0.01473254 0.03070799 0.         0.03070799
 0.00595969]
[ 0.         -0.0003579  -0.00184401 -0.00696197  0.         -0.00696197
 -0.0003579 ]


In [12]:
# Using neighbor lists also gives the same results.
print(jax.tree_map(jnp.allclose,exp_e,imp_nl_e))
print(jax.tree_map(jnp.allclose,exp_f,imp_nl_f))

# We cannot directly compare the gradients because by using neighbor lists we 
# silently assume that the input matrix is symmetric. Thus we compare the upper
# triangular part of exp_g to imp_nl_g. Due to this symmetry we also have to 
# divide exp_g by 2 in order to prevent double counting.
exp_g_triu = jax.tree_map(jnp.triu,exp_g)
imp_nl_g_2 = jax.tree_map(lambda x: x/2, imp_nl_g)
print(jax.tree_map(jnp.allclose,exp_g_triu, imp_nl_g_2))

True
True
{'alpha': DeviceArray(True, dtype=bool), 'sigma': DeviceArray(True, dtype=bool)}
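
The factor of 2 can be reproduced on a toy model. The two energy functions below are hypothetical and are not how jax-md implements its pair or neighbor list energies. They only mimic the bookkeeping: one sums over all ordered pairs and halves the result, the other reads each pair only once from the upper triangle.

```python
import jax
import jax.numpy as jnp

# Hypothetical contribution of a single pair to the energy.
def pair_term(s):
  return s ** 2


def energy_all_pairs(sigma):
  # Sum over (i, j) and (j, i), then halve to avoid double counting.
  n = sigma.shape[0]
  off_diagonal = 1.0 - jnp.eye(n)
  return 0.5 * jnp.sum(off_diagonal * pair_term(sigma))


def energy_upper_only(sigma):
  # Read each pair only once, from the entries with i < j.
  n = sigma.shape[0]
  upper = jnp.triu(jnp.ones((n, n)), k=1)
  return jnp.sum(upper * pair_term(sigma))


sigma = jnp.full((4, 4), 2.0)
g_all = jax.grad(energy_all_pairs)(sigma)
g_upper = jax.grad(energy_upper_only)(sigma)

# Same energy for a symmetric sigma, but the gradient of the second version
# is concentrated in the upper triangle and is twice as large there.
print(jnp.allclose(energy_all_pairs(sigma), energy_upper_only(sigma)))
print(jnp.allclose(jnp.triu(g_all), g_upper / 2))
```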


# A cautionary tale

Jaxopt uses the [custom derivatives machinery of jax](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) to implement implicit differentiation, which I've found to be rather brittle compared to the rest of jax.
In order to demonstrate this let us now make a seemingly trivial change to our code. We'll move the setup of our system into our `run_implicit` function and we make the `box_size` of our system depend on the particle diameter `D`. When we now try to call `grad` on `run_implicit` we get a `NotImplementedError: Differentiation rule for 'custom_lin' not implemented`.

In [13]:
def run_implicit(D, key, N=128):
  box_size = 4.5 * D

  # box_size = lax.stop_gradient(box_size)

  displacement, shift = space.periodic(box_size)

  R_init = random.uniform(key, (N, 2), minval=0.0, maxval=box_size, dtype=f64)

  energy_fn = jit(energy.soft_sphere_pair(displacement, sigma=D))

  force_fn = jit(quantity.force(energy_fn))

  # wrap force_fn with a lax.stop_gradient to prevent a CustomVJPException
  no_grad_force_fn = jit(lambda x: lax.stop_gradient(force_fn(x)))

  def solver(params, x):
    # params are unused
    del params
    # need to use no_grad_force_fn!
    return run_minimization_while(no_grad_force_fn, x, shift)[0]

  decorated_solver = custom_root(force_fn)(solver)

  R_final = decorated_solver(None, R_init)

  # Here we can just use our original energy_fn/force_fn
  return energy_fn(R_final)

In [14]:
key = random.PRNGKey(5)
jax.grad(run_implicit)(1.3, key, N = 32)

NotImplementedError: ignored

This problem is currently being tracked in JAX issue #8557. In the meantime we can avoid this error by wrapping `box_size` with a call to `lax.stop_gradient`. Indeed it's my experience that a call to `lax.stop_gradient` is often the only thing you need to add when you are working with custom derivatives.
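
Keep in mind what `lax.stop_gradient` does to the result: the wrapped value is treated as a constant, so any dependence of the output on `D` through `box_size` no longer contributes to the gradient. The following sketch uses a hypothetical function `scaled`, unrelated to the soft sphere system, to show the difference.

```python
import jax
from jax import lax

def scaled(D, block):
  box_size = 4.5 * D
  if block:
    # Treat box_size as a constant during differentiation.
    box_size = lax.stop_gradient(box_size)
  # Hypothetical quantity that depends on D directly and through box_size.
  return box_size * D


g_full = jax.grad(scaled)(1.3, False)     # d/dD (4.5 * D**2) = 9 * D
g_blocked = jax.grad(scaled)(1.3, True)   # box_size held fixed: 4.5 * D
print(g_full, g_blocked)
```

Whether dropping this contribution is acceptable depends on the quantity you are actually interested in.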