<a href="https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/customizing_potentials_cookbook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Customizing Potentials in JAX MD

This cookbook was contributed by Carl Goodrich.

In [None]:
#@title Imports & Utils
!pip install -q git+https://www.github.com/google/jax-md


import numpy as onp

import jax.numpy as np
from jax import random
from jax import jit, grad, vmap, value_and_grad
from jax import lax
from jax import ops

from jax import config
config.update("jax_enable_x64", True)

from jax_md import space, smap, energy, minimize, quantity, simulate, partition

from functools import partial
import time

f32 = np.float32
f64 = np.float64

import matplotlib
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 16})
#import seaborn as sns 
#sns.set_style(style='white')

def format_plot(x, y):  
  plt.grid(True)
  plt.xlabel(x, fontsize=20)
  plt.ylabel(y, fontsize=20)
  
def finalize_plot(shape=(1, 0.7)):
  plt.gcf().set_size_inches(
    shape[0] * 1.5 * plt.gcf().get_size_inches()[1], 
    shape[1] * 1.5 * plt.gcf().get_size_inches()[1])
  plt.tight_layout()

def calculate_bond_data(displacement_or_metric, R, dr_cutoff, species=None):
  if( not(species is None)):
    assert(False)
    
  # Use the function argument (not a global) and infer N from R.
  metric = space.map_product(
      space.canonicalize_displacement_or_metric(displacement_or_metric))
  dr = metric(R,R)
  N = R.shape[0]

  dr_include = np.triu(np.where(dr<dr_cutoff, 1, 0)) - np.eye(N,dtype=np.int32)
  index_list=np.dstack(np.meshgrid(np.arange(N), np.arange(N), indexing='ij'))

  i_s = np.where(dr_include==1, index_list[:,:,0], -1).flatten()
  j_s = np.where(dr_include==1, index_list[:,:,1], -1).flatten()
  ij_s = np.transpose(np.array([i_s,j_s]))

  bonds = ij_s[(ij_s!=np.array([-1,-1]))[:,1]]
  lengths = dr.flatten()[(ij_s!=np.array([-1,-1]))[:,1]]

  return bonds, lengths

def plot_system(R,box_size,species=None,ms=20):
  R_plt = onp.array(R)

  if(species is None):
    plt.plot(R_plt[:, 0], R_plt[:, 1], 'o', markersize=ms)
  else:
    for ii in range(np.amax(species)+1):
      Rtemp = R_plt[species==ii]
      plt.plot(Rtemp[:, 0], Rtemp[:, 1], 'o', markersize=ms)

  plt.xlim([0, box_size])
  plt.ylim([0, box_size])
  plt.xticks([], [])
  plt.yticks([], [])

  finalize_plot((1,1))
  
key = random.PRNGKey(0)

## Prerequisites

This cookbook assumes a working knowledge of Python and Numpy. The concept of broadcasting is particularly important both in this cookbook and in JAX MD. 

We also assume a basic knowledge of [JAX](https://github.com/google/jax/), which JAX MD is built on top of. Here we briefly review a few JAX basics that are important for us:


*   ```jax.vmap``` allows for automatic vectorization of a function. What this means is that if you have a function that takes an input ```x``` and returns an output ```y```, i.e. ```y = f(x)```, then ```vmap``` will transform this function to act on an array of ```x```'s and return an array of ```y```'s, i.e. ```Y = vmap(f)(X)```, where ```X=np.array([x1,x2,...,xn])``` and ```Y=np.array([y1,y2,...,yn])```. 

*   ```jax.grad``` employs automatic differentiation to transform a function into a new function that calculates its gradient, for example: ```dydx = grad(f)(x)```. 

*   ```jax.lax.scan``` allows for efficient for-loops that can be compiled and differentiated over. See [here](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) for more details.

*   [Random numbers are different in JAX.](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers) The details aren't necessary for this cookbook, but if things look a bit different, this is why.
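
As a quick refresher, the three transformations listed above can be sketched on a toy function. This is a minimal illustration in plain JAX; the function `f`, the array `X`, and the helper `step` are made up for this example and are not used elsewhere in the cookbook.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function y = x^2.
    return x ** 2

X = jnp.array([1.0, 2.0, 3.0])

# vmap: apply f to every element of X.
Y = jax.vmap(f)(X)

# grad: derivative of f evaluated at x = 3.
dydx = jax.grad(f)(3.0)

def step(carry, x):
    # One loop iteration: update the running sum and also record it.
    carry = carry + x
    return carry, carry

# scan: a compiled for-loop over X that carries the running sum.
total, partial_sums = jax.lax.scan(step, jnp.array(0.0), X)
```

Here `Y` holds the squares of `X`, `dydx` is the slope $2x$ at $x=3$, and `partial_sums` holds the running sums produced by the loop.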






## The basics of user-defined potentials

### Create a user-defined potential function to use throughout this cookbook



Here we create a custom potential that has a short-ranged, non-diverging repulsive interaction and a medium-ranged Morse-like attractive interaction. It takes the following form:
\begin{equation}
V(r) =
\begin{cases}
    \frac{1}{2} k (r-r_0)^2 - D_0,&  r < r_0\\
    D_0\left( e^{-2\alpha (r-r_0)} -2 e^{-\alpha(r-r_0)}\right),              & r \geq r_0
\end{cases}
\end{equation}
and has 4 parameters: $D_0$, $\alpha$, $r_0$, and $k$.
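
Both branches give $V(r_0) = -D_0$ and zero slope at $r_0$, so the energy and the force are continuous there (the second derivative generally is not, since it is $k$ on the left and $2\alpha^2 D_0$ on the right). The following is a minimal sketch that checks this numerically with a scalar, pure-Python version of the formula; the names `harmonic_morse_scalar` and `one_sided_slope` are introduced only for this check.

```python
import math

def harmonic_morse_scalar(r, D0=5.0, alpha=5.0, r0=1.0, k=50.0):
    """Scalar version of V(r), written directly from the piecewise definition."""
    if r < r0:
        return 0.5 * k * (r - r0) ** 2 - D0
    x = r - r0
    return D0 * (math.exp(-2.0 * alpha * x) - 2.0 * math.exp(-alpha * x))

def one_sided_slope(f, r, h):
    """Finite-difference slope; h < 0 probes the left branch, h > 0 the right."""
    return (f(r + h) - f(r)) / h

r0 = 1.0
value_at_r0 = harmonic_morse_scalar(r0)
left_slope = one_sided_slope(harmonic_morse_scalar, r0, -1e-6)
right_slope = one_sided_slope(harmonic_morse_scalar, r0, 1e-6)
```

With the default parameters, `value_at_r0` equals $-D_0$ and both one-sided slopes vanish up to the finite-difference error.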


In [None]:
def harmonic_morse(dr, D0=5.0, alpha=5.0, r0=1.0, k=50.0, **kwargs):
  U = np.where(dr < r0, 
               0.5 * k * (dr - r0)**2 - D0,
               D0 * (np.exp(-2. * alpha * (dr - r0)) - 2. * np.exp(-alpha * (dr - r0)))
               )
  return np.array(U, dtype=dr.dtype)

Plot $V(r)$.

In [None]:
drs = np.arange(0,3,0.01)
U = harmonic_morse(drs)
plt.plot(drs,U)
format_plot(r'$r$', r'$V(r)$')
finalize_plot()

### Calculate the energy of a system of interacting particles

We now want to calculate the energy of a system of $N$ spheres in $d$ dimensions, where each particle interacts with every other particle via our user-defined function $V(r)$. The total energy is
\begin{equation}
E_\text{total} = \sum_{i<j}V(r_{ij}),
\end{equation}
where $r_{ij}$ is the distance between particles $i$ and $j$. 

Our first task is to set up the system by specifying $N$, $d$, and the size of the simulation box. We then use JAX's internal random number generator to pick positions for each particle. 

In [None]:
N = 50
dimension = 2
box_size = 6.8

key, split = random.split(key)
R = random.uniform(split, (N,dimension), minval=0.0, maxval=box_size, dtype=f64) 

plot_system(R,box_size)

At this point, we could manually loop over all particle pairs and calculate the energy, keeping track of boundary conditions, etc. Fortunately, JAX MD has machinery to automate this. 
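
For reference, the manual approach would look roughly like the sketch below: a double loop over pairs $i<j$ that applies the minimum-image convention in a square periodic box. This is plain Python written for illustration only; `minimum_image`, `total_energy`, and the toy `harmonic` potential are hypothetical helpers, not part of JAX MD.

```python
import math

def minimum_image(dx, box_size):
    """Wrap one displacement component into [-box_size/2, box_size/2)."""
    return (dx + 0.5 * box_size) % box_size - 0.5 * box_size

def total_energy(positions, box_size, pair_potential):
    """Sum pair_potential(r_ij) over all pairs i < j in a periodic box."""
    energy = 0.0
    n = len(positions)
    for i in range(n):
        for j in range(i + 1, n):
            dr2 = 0.0
            for a, b in zip(positions[i], positions[j]):
                dr2 += minimum_image(a - b, box_size) ** 2
            energy += pair_potential(math.sqrt(dr2))
    return energy

# Toy example: three particles in a box of side 10.
positions = [(0.5, 0.5), (9.5, 0.5), (5.0, 5.0)]
harmonic = lambda r: 0.5 * (r - 1.0) ** 2
E = total_energy(positions, 10.0, harmonic)
```

The first two particles sit on opposite sides of the box, yet their periodic separation is 1, so that pair contributes zero energy to `E`.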

First, we must define two functions, ```displacement``` and ```shift```, which contain all the information of the simulation box, boundary conditions, and underlying metric. ```displacement``` is used to calculate the vector displacement between particles, and ```shift``` is used to move particles. For most cases, it is recommended to use JAX MD's built in functions, which can be called using:
*   ``` displacement, shift = space.free()```
*   ``` displacement, shift = space.periodic(box_size)```
*   ``` displacement, shift = space.periodic_general(T)```

For demonstration purposes, we will define these manually for a square periodic box, though without proper error handling, etc. The following should have the same functionality as ```displacement, shift = space.periodic(box_size)```.

In [None]:
def setup_periodic_box(box_size):
  def displacement_fn(Ra, Rb, **unused_kwargs):
    dR = Ra - Rb
    return np.mod(dR + box_size * f32(0.5), box_size) - f32(0.5) * box_size

  def shift_fn(R, dR, **unused_kwargs):
    return np.mod(R + dR, box_size)

  return displacement_fn, shift_fn
  
displacement, shift = setup_periodic_box(box_size)

We now set up a function to calculate the total energy of the system. The JAX MD function ```smap.pair``` takes a given potential and promotes it to act on all particle pairs in a system. ```smap.pair``` does not actually return an energy, rather it returns a function that can be used to calculate the energy. 

For convenience and readability, we wrap ```smap.pair``` in a new function called ```harmonic_morse_pair```. For now, ignore the species keyword; we will return to it later.

In [None]:
def harmonic_morse_pair(
    displacement_or_metric, species=None, D0=5.0, alpha=10.0, r0=1.0, k=50.0): 
  D0 = np.array(D0, dtype=f32)
  alpha = np.array(alpha, dtype=f32)
  r0 = np.array(r0, dtype=f32)
  k = np.array(k, dtype=f32)
  return smap.pair(
      harmonic_morse,
      space.canonicalize_displacement_or_metric(displacement_or_metric),
      species=species,
      D0=D0,
      alpha=alpha,
      r0=r0,
      k=k)

Our helper function can be used to construct a function to compute the energy of the entire system as follows.



In [None]:
# Create a function to calculate the total energy with specified parameters
energy_fn = harmonic_morse_pair(displacement,D0=5.0,alpha=10.0,r0=1.0,k=500.0)

# Use this to calculate the total energy
print(energy_fn(R))

# Use grad to calculate the net force
force = -grad(energy_fn)(R)
print(force[:5])

We are now in a position to use our energy function to manipulate the system. As an example, we perform energy minimization using JAX MD's implementation of the FIRE algorithm. 

We start by defining a function that takes an energy function, a set of initial positions, and a shift function and runs a specified number of steps of the minimization algorithm. The function returns the final set of positions and the maximum absolute value component of the force. We will use this function throughout this cookbook. 

In [None]:
def run_minimization(energy_fn, R_init, shift, num_steps=5000):
  dt_start = 0.001
  dt_max   = 0.004
  init,apply=minimize.fire_descent(jit(energy_fn),shift,dt_start=dt_start,dt_max=dt_max)
  apply = jit(apply)

  @jit
  def scan_fn(state, i):
    return apply(state), 0.

  state = init(R_init)
  state, _ = lax.scan(scan_fn,state,np.arange(num_steps))

  return state.position, np.amax(np.abs(-grad(energy_fn)(state.position)))

Now run the minimization with our custom energy function.

In [None]:
Rfinal, max_force_component = run_minimization(energy_fn, R, shift)
print('largest component of force after minimization = {}'.format(max_force_component))
plot_system( Rfinal, box_size )

### Create a truncated potential

It is often desirable to have a potential that is strictly zero beyond a well-defined cutoff distance. In addition, MD simulations require the energy and force (i.e. first derivative) to be continuous. To easily modify an existing potential $V(r)$ to have this property, JAX MD follows the approach [taken by HOOMD Blue](https://hoomd-blue.readthedocs.io/en/stable/module-md-pair.html#hoomd.md.pair.pair). 

Consider the function 
\begin{equation}
S(r) =
\begin{cases}
    1,& r<r_\mathrm{on} \\
    \frac{(r_\mathrm{cut}^2-r^2)^2 (r_\mathrm{cut}^2 + 2r^2 - 3 r_\mathrm{on}^2)}{(r_\mathrm{cut}^2-r_\mathrm{on}^2)^3},&  r_\mathrm{on} \leq r < r_\mathrm{cut}\\
    0,& r \geq r_\mathrm{cut}
\end{cases}
\end{equation}

Here we plot both $S(r)$ and $\frac{dS(r)}{dr}$, both of which are smooth and strictly zero above $r_\mathrm{cut}$.





In [None]:
dr = np.arange(0,3,0.01)
S = energy.multiplicative_isotropic_cutoff(lambda dr: 1, r_onset=1.5, r_cutoff=2.0)(dr)
ngradS = vmap(grad(energy.multiplicative_isotropic_cutoff(lambda dr: 1, r_onset=1.5, r_cutoff=2.0)))(dr)
plt.plot(dr,S,label=r'$S(r)$')
plt.plot(dr,ngradS,label=r'$\frac{dS(r)}{dr}$')
plt.legend()
format_plot(r'$r$','')
finalize_plot()

We then use $S(r)$ to create a new function 
\begin{equation}\tilde V(r) = V(r) S(r),
\end{equation} 
which is exactly $V(r)$ below $r_\mathrm{on}$, strictly zero above $r_\mathrm{cut}$ and is continuous in its first derivative.
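
To make the piecewise definition of $S(r)$ concrete, here is a minimal pure-Python sketch written directly from the formula above. The name `smooth_cutoff` is ours, and this is not JAX MD's implementation; it only checks the limiting values and that the slope vanishes at both ends of the switching region.

```python
def smooth_cutoff(r, r_onset=1.5, r_cutoff=2.0):
    """Switching function S(r): 1 below r_onset, 0 at and above r_cutoff."""
    if r < r_onset:
        return 1.0
    if r >= r_cutoff:
        return 0.0
    rc2, ro2, r2 = r_cutoff ** 2, r_onset ** 2, r ** 2
    return (rc2 - r2) ** 2 * (rc2 + 2.0 * r2 - 3.0 * ro2) / (rc2 - ro2) ** 3

h = 1e-6
# Finite-difference slopes just inside the switching region.
slope_at_onset = (smooth_cutoff(1.5 + h) - smooth_cutoff(1.5)) / h
slope_at_cutoff = (smooth_cutoff(2.0) - smooth_cutoff(2.0 - h)) / h
```

Because $S$ and $dS/dr$ match the constant branches at $r_\mathrm{on}$ and $r_\mathrm{cut}$, the product $V(r)S(r)$ inherits a continuous first derivative whenever $V$ has one.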

This is implemented in JAX MD through ```energy.multiplicative_isotropic_cutoff```, which takes in a potential function $V(r)$ (e.g. our ```harmonic_morse``` function) and returns a new function $\tilde V(r)$.

In [None]:
harmonic_morse_cutoff = energy.multiplicative_isotropic_cutoff(
    harmonic_morse, r_onset=1.5, r_cutoff=2.0)

dr = np.arange(0,3,0.01)
V = harmonic_morse(dr)
V_cutoff = harmonic_morse_cutoff(dr)
F = -vmap(grad(harmonic_morse))(dr)
F_cutoff = -vmap(grad(harmonic_morse_cutoff))(dr)
plt.plot(dr,V, label=r'$V(r)$')
plt.plot(dr,V_cutoff, label=r'$\tilde V(r)$')
plt.plot(dr,F, label=r'$-\frac{d}{dr} V(r)$')
plt.plot(dr,F_cutoff, label=r'$-\frac{d}{dr} \tilde V(r)$')
plt.legend()
format_plot('$r$', '')
plt.ylim(-13,5)
finalize_plot()

As before, we can use ```smap.pair``` to promote this to act on an entire system.

In [None]:
def harmonic_morse_cutoff_pair(
    displacement_or_metric, D0=5.0, alpha=5.0, r0=1.0, k=50.0,
    r_onset=1.5, r_cutoff=2.0): 
  D0 = np.array(D0, dtype=f32)
  alpha = np.array(alpha, dtype=f32)
  r0 = np.array(r0, dtype=f32)
  k = np.array(k, dtype=f32)
  return smap.pair(
      energy.multiplicative_isotropic_cutoff(
          harmonic_morse, r_onset=r_onset, r_cutoff=r_cutoff),
      space.canonicalize_displacement_or_metric(displacement_or_metric),
      D0=D0,
      alpha=alpha,
      r0=r0,
      k=k)

This is used in the same way as before:

In [None]:
# Create a function to calculate the total energy
energy_fn = harmonic_morse_cutoff_pair(displacement, D0=5.0, alpha=10.0, r0=1.0, 
                                       k=500.0, r_onset=1.5, r_cutoff=2.0)

# Use this to calculate the total energy
print(energy_fn(R))

# Use grad to calculate the net force
force = -grad(energy_fn)(R)
print(force[:5])

# Minimize the energy using the FIRE algorithm
Rfinal, max_force_component = run_minimization(energy_fn, R, shift)
print('largest component of force after minimization = {}'.format(max_force_component))
plot_system( Rfinal, box_size )

## Specifying parameters

### Dynamic parameters

In the above examples, the strategy is to create a function ```energy_fn``` that takes a set of positions and calculates the energy of the system with all the parameters (e.g. ```D0```, ```alpha```, etc.) baked in. However, JAX MD allows you to override these baked-in values dynamically, i.e. when ```energy_fn``` is called. 
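
The general pattern of baked-in defaults with call-time overrides can be sketched in a few lines of plain Python. This is only an illustration of the idea, not how JAX MD implements it; `make_energy_fn` and `shifted_spring` are hypothetical names.

```python
def make_energy_fn(pair_potential, **baked_in):
    """Return an energy function with default parameters baked in."""
    def energy_fn(distances, **overrides):
        # Call-time keyword arguments take precedence over the baked-in ones.
        params = {**baked_in, **overrides}
        return sum(pair_potential(r, **params) for r in distances)
    return energy_fn

def shifted_spring(r, r0=1.0, k=1.0, D0=0.0):
    """Toy pair potential used only for this illustration."""
    return 0.5 * k * (r - r0) ** 2 - D0

toy_energy_fn = make_energy_fn(shifted_spring, r0=1.0, k=500.0, D0=5.0)
pair_distances = [1.0, 1.1, 0.9]

baked = toy_energy_fn(pair_distances)                   # uses D0=5.0
no_attraction = toy_energy_fn(pair_distances, D0=0.0)   # D0 overridden
```

The two results differ by exactly $D_0$ per pair, since only the constant offset was changed at call time.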

For example, we can print out the minimized energy and force of the above system with the truncated potential:

In [None]:
print(energy_fn(Rfinal))
print(-grad(energy_fn)(Rfinal)[:5])

This uses the baked-in values of the 4 parameters: ```D0=5.0,alpha=10.0,r0=1.0,k=500.0```. If, for example, we want to dynamically turn off the attractive part of the potential, we simply pass ```D0=0``` to ```energy_fn```:

In [None]:
print(energy_fn(Rfinal, D0=0))

Since changing the potential moves the minimum, the force will not be zero:

In [None]:
print(-grad(energy_fn)(Rfinal, D0=0)[:5])

This ability to dynamically pass parameters is very powerful. For example, if you want to shrink particles each step during a simulation, you can simply specify a different ```r0``` each step. 

This is demonstrated below, where we run a Brownian dynamics simulation at zero temperature with continuously decreasing ```r0```. The details of ```simulate.brownian``` are beyond the scope of this cookbook, but the idea is that we pass a new value of ```r0``` to the function ```apply``` each time it is called. The function ```apply``` takes a step of the simulation, and internally it passes any extra parameters like ```r0``` to ```energy_fn```.

In [None]:
def run_brownian(energy_fn, R_init, shift, key, num_steps):
  init, apply = simulate.brownian(energy_fn, shift, 
                                  dt=0.00001, kT=0.0, gamma=0.1)
  apply = jit(apply)

  # Define how r0 changes for each step
  r0_initial = 1.0
  r0_final = .5
  def get_r0(t):
    return r0_final + (r0_initial-r0_final)*(num_steps-t)/num_steps

  @jit
  def scan_fn(state, t):
    # Dynamically pass r0 to apply, which passes it on to energy_fn
    return apply(state, r0=get_r0(t)), 0

  key, split = random.split(key)
  state = init(split, R_init)

  state, _ = lax.scan(scan_fn,state,np.arange(num_steps))
  return state.position, np.amax(np.abs(-grad(energy_fn)(state.position)))

If we use the previous result as the starting point for the Brownian dynamics simulation, we find exactly what we would expect: the system contracts into a finite cluster, held together by the attractive part of the potential.

In [None]:
key, split = random.split(key)
Rfinal2, max_force_component = run_brownian(energy_fn, Rfinal, shift, split, 
                                            num_steps=6000)
plot_system( Rfinal2, box_size )

### Particle-specific parameters

Our example potential has 4 parameters: ```D0```, ```alpha```, ```r0```, and ```k```. The usual way to pass these parameters is as a scalar (e.g. ```D0=5.0```), in which case that parameter is fixed for every particle pair. However, NumPy-style broadcasting allows these parameters to be specified separately for every particle pair by passing an $(N,N)$ array rather than a scalar. 

As an example, let's do this for the parameter ```r0```, which is an effective way of generating a system with continuous polydispersity in particle size. Note that the polydispersity disrupts the crystalline order after minimization.

In [None]:
# Draw the radii from a uniform distribution
key, split = random.split(key)
radii = random.uniform(split, (N,), minval=1.0, maxval=2.0, dtype=f64)

# Rescale to match the initial volume fraction
radii = np.array([radii * np.sqrt(N/(4.*np.dot(radii,radii)))])

# Turn this into a matrix of sums
r0_matrix = radii+radii.transpose()

# Create the energy function using r0_matrix
energy_fn = harmonic_morse_pair(displacement, D0=5.0, alpha=10.0, r0=r0_matrix, 
                                       k=500.0)

# Minimize the energy using the FIRE algorithm
Rfinal, max_force_component = run_minimization(energy_fn, R, shift)
print('largest component of force after minimization = {}'.format(max_force_component))
plot_system( Rfinal, box_size )

In addition to standard NumPy-style broadcasting, JAX MD allows for the special case of additive parameters. If a parameter is passed as an (N,) array ```p_vector```, JAX MD will convert this into an (N,N) array ```p_matrix``` where ```p_matrix[i,j] = 0.5 * (p_vector[i] + p_vector[j])```. This is a JAX MD specific ability and not a feature of broadcasting.

As it turns out, our above polydisperse example falls into this category. Therefore, we could achieve the same result by passing ```r0=2.0*radii```.
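
The equivalence is easy to verify by hand. The sketch below uses plain Python lists and a hypothetical helper `additive_matrix` that applies the averaging rule stated above; it is not JAX MD code.

```python
def additive_matrix(p_vector):
    """Build p_matrix[i][j] = 0.5 * (p_vector[i] + p_vector[j])."""
    n = len(p_vector)
    return [[0.5 * (p_vector[i] + p_vector[j]) for j in range(n)]
            for i in range(n)]

toy_radii = [0.4, 0.5, 0.7]

# Explicit matrix of sums of radii.
r0_from_sums = [[ri + rj for rj in toy_radii] for ri in toy_radii]

# The averaging rule applied to the diameters 2 * radii.
r0_from_rule = additive_matrix([2.0 * r for r in toy_radii])
```

Averaging two diameters gives the sum of the two radii, so both matrices agree entry by entry.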

In [None]:
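# Create the energy function using the radii array.
# radii was built above with shape (1, N), so flatten it to shape (N,)
# for the additive-parameter rule to apply.
energy_fn = harmonic_morse_pair(displacement, D0=5.0, alpha=10.0,
                                r0=2.*radii.flatten(), k=500.0)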

# Minimize the energy using the FIRE algorithm
Rfinal, max_force_component = run_minimization(energy_fn, R, shift)
print('largest component of force after minimization = {}'.format(max_force_component))
plot_system( Rfinal, box_size )

### Species

It is often important to specify parameters differently for different particle pairs, but doing so with full ($N$,$N$) matrices is both inefficient and obnoxious. JAX MD allows users to create species, i.e. $N_s$ groups of particles that are identical to each other, so that parameters can be passed as much smaller ($N_s$,$N_s$) matrices.
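
Conceptually, a species-indexed parameter is a lookup: the value for the pair $(i,j)$ is the matrix entry for the species of $i$ and the species of $j$. A minimal plain-Python sketch of that lookup, with toy values and a hypothetical helper `pair_parameter`:

```python
toy_species = [0, 0, 1, 1, 1]      # species index of each of 5 particles
toy_r0_species = [[0.8, 1.0],      # (N_s, N_s) matrix with N_s = 2
                  [1.0, 1.2]]

def pair_parameter(i, j):
    """Parameter for the pair (i, j), looked up by species."""
    return toy_r0_species[toy_species[i]][toy_species[j]]

n = len(toy_species)
# The equivalent full (N, N) matrix, which never needs to be written by hand.
full_matrix = [[pair_parameter(i, j) for j in range(n)] for i in range(n)]
```

Here 4 numbers describe all 25 entries of the full matrix, and the saving grows quadratically with $N$.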

First, create an array that specifies which particles belong in which species. We will divide our system into two species.

In [None]:
N_0 = N // 2  # Half the particles in species 0
N_1 = N - N_0 # The rest in species 1
species = np.array([0] * N_0 + [1] * N_1, dtype=np.int32)
print(species)

Next, create the $(2,2)$ matrix of ```r0```'s, which are set so that the overall volume fraction matches our monodisperse case. 
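
The value of the small radius follows from matching the total covered area. The sketch below assumes, as in the monodisperse examples above, that ```r0=1.0``` is a particle diameter (radius 0.5), that the two species have equal numbers, and that the size ratio is 1.4; the variable names are ours.

```python
import math

N = 50
N_0 = N // 2
N_1 = N - N_0
size_ratio = 1.4
r_mono = 0.5  # assumption: r0 = 1.0 is the diameter in the monodisperse case

# Total area in the monodisperse case, divided by pi.
target = N * r_mono ** 2

# Solve N_0 * rsmall^2 + N_1 * (size_ratio * rsmall)^2 = target for rsmall.
rsmall = math.sqrt(target / (N_0 + N_1 * size_ratio ** 2))
rlarge = size_ratio * rsmall
```

Under these assumptions `rsmall` agrees with the hard-coded value used in the next cell to the digits shown there.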

In [None]:
rsmall=0.41099747 # Match the total volume fraction
rlarge=1.4*rsmall
r0_species_matrix = np.array([[2*rsmall, rsmall+rlarge],
                              [rsmall+rlarge, 2*rlarge]])
print(r0_species_matrix)

In [None]:
energy_fn = harmonic_morse_pair(displacement, species=species, D0=5.0, 
                                alpha=10.0, r0=r0_species_matrix, k=500.0)

Rfinal, max_force_component = run_minimization(energy_fn, R, shift)
print('largest component of force after minimization = {}'.format(max_force_component))

plot_system(Rfinal, box_size, species=species )

### Dynamic Species

Just like standard parameters, the species list can be passed dynamically as well. However, unlike standard parameters, you have to tell `smap.pair` that the species will be specified dynamically. To do this, set `species` to the total number of particle types (here, `species=2`) when creating your energy function.

The following sets up an energy function where the attractive part of the interaction only exists between members of the first species, but where the species will be defined dynamically.

In [None]:
D0_species_matrix = np.array([[ 5.0, 0.0],
                              [0.0,  0.0]])

energy_fn = harmonic_morse_pair(displacement, 
                                species=2, 
                                D0=D0_species_matrix, 
                                alpha=10.0,
                                r0=0.5, 
                                k=500.0)

Now we set up a finite temperature Brownian Dynamics simulation where, at every step, particles on the left half of the simulation box are assigned to species 0, while particles on the right half are assigned to species 1.

In [None]:
def run_brownian(energy_fn, R_init, shift, key, num_steps):
  init, apply = simulate.brownian(energy_fn, shift, dt=0.00001, kT=1.0, gamma=0.1)
  # apply = jit(apply)

  # Define a function to recalculate the species each step
  def get_species(R):
    return np.where(R[:,0] < box_size / 2, 0, 1)

  @jit
  def scan_fn(state, t):
    # Recalculate the species list
    species = get_species(state.position)
    # Dynamically pass species to apply, which passes it on to energy_fn
    return apply(state, species=species, species_count=2), 0

  key, split = random.split(key)
  state = init(split, R_init)

  state, _ = lax.scan(scan_fn,state,np.arange(num_steps))
  return state.position,np.amax(np.abs(-grad(energy_fn)(state.position,
                                                        species=get_species(state.position), 
                                                        species_count=2)))

When we run this, we see that particles on the left side form clusters while particles on the right side do not.

In [None]:
key, split = random.split(key)
Rfinal, max_force_component = run_brownian(energy_fn, R, shift, split, num_steps=10000)
plot_system( Rfinal, box_size )

## Efficiently calculating neighbors

The most computationally expensive part of most MD programs is calculating the force between all pairs of particles. Generically, this scales with $N^2$. However, for systems with isotropic pairwise interactions that are strictly zero beyond a cutoff, there are techniques to dramatically improve the efficiency. The two most common methods are cell lists and neighbor lists.

**Cell lists**

The technique here is to divide space into small cells that are just larger than the largest interaction range in the system. Thus, if particle $i$ is in cell $c_i$ and particle $j$ is in cell $c_j$, $i$ and $j$ can only interact if $c_i$ and $c_j$ are neighboring cells. Rather than searching all $N^2$ combinations of particle pairs for non-zero interactions, you only have to search the particles in the neighboring cells. 
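
The idea can be sketched in plain Python for a square periodic box. This is an illustration of the technique only, not JAX MD's implementation; all function names are hypothetical, and the brute-force version is included just to confirm that both searches find the same pairs.

```python
import itertools
import math
import random

def periodic_distance(a, b, box):
    """Minimum-image distance between points a and b in a square box."""
    d2 = 0.0
    for xa, xb in zip(a, b):
        dx = (xa - xb + 0.5 * box) % box - 0.5 * box
        d2 += dx * dx
    return math.sqrt(d2)

def brute_force_pairs(points, box, r_cut):
    """All pairs (i, j), i < j, closer than r_cut: O(N^2) search."""
    return {(i, j)
            for i, j in itertools.combinations(range(len(points)), 2)
            if periodic_distance(points[i], points[j], box) < r_cut}

def cell_list_pairs(points, box, r_cut):
    """Same pairs, but only searching each cell and its neighbors."""
    n_cells = int(box // r_cut)          # cells are at least r_cut wide
    assert n_cells >= 3, "need at least 3 cells per side"
    cell_size = box / n_cells
    cells = {}
    for idx, p in enumerate(points):
        key = tuple(int(x / cell_size) % n_cells for x in p)
        cells.setdefault(key, []).append(idx)
    pairs = set()
    for cell, members in cells.items():
        for offset in itertools.product((-1, 0, 1), repeat=len(cell)):
            neighbor = tuple((c + o) % n_cells for c, o in zip(cell, offset))
            for i in members:
                for j in cells.get(neighbor, ()):
                    if i < j and periodic_distance(points[i], points[j], box) < r_cut:
                        pairs.add((i, j))
    return pairs

rng = random.Random(0)
box, r_cut = 6.8, 2.0
points = [(rng.uniform(0, box), rng.uniform(0, box)) for _ in range(50)]
pairs_cells = cell_list_pairs(points, box, r_cut)
pairs_brute = brute_force_pairs(points, box, r_cut)
```

The requirement of at least 3 cells per side is the same condition under which a cell list stops being useful in a small periodic box.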

**Neighbor lists**

Here, for each particle $i$, we make a list of *potential* neighbors: particles $j$ that are within some threshold distance $r_\mathrm{threshold}$. If $r_\mathrm{threshold} = r_\mathrm{cutoff} + \Delta r_\mathrm{threshold}$ (where $r_\mathrm{cutoff}$ is the largest interaction range in the system and $\Delta r_\mathrm{threshold}$ is an appropriately chosen buffer size), then all interacting particles will appear in this list as long as no particle moves by more than $\Delta r_\mathrm{threshold}/2$. There is a tradeoff here: smaller $\Delta r_\mathrm{threshold}$ means fewer particles to search over each MD step but the list must be recalculated more often, while larger $\Delta r_\mathrm{threshold}$ means slower force calculations but less frequent neighbor list calculations. 

In practice, the most efficient technique is often to use cell lists to calculate neighbor lists. In JAX MD, this occurs under the hood, and so only calls to neighbor-list functionality are necessary.

To implement neighbor lists, we need two functions: 1) a function to create and update the neighbor list, and 2) an energy function that uses a neighbor list rather than operating on all particle pairs. We create these functions with ```partition.neighbor_list``` and ```smap.pair_neighbor_list```, respectively. 

```partition.neighbor_list``` takes basic box information as well as the maximum interaction range ```r_cutoff``` and the buffer size ```dr_threshold```. 

In [None]:
def harmonic_morse_cutoff_neighbor_list(
    displacement_or_metric,
    box_size,
    species=None,
    D0=5.0, 
    alpha=5.0, 
    r0=1.0, 
    k=50.0,
    r_onset=1.0,
    r_cutoff=1.5, 
    dr_threshold=2.0,
    format=partition.OrderedSparse,
    **kwargs): 

  D0 = np.array(D0, dtype=np.float32)
  alpha = np.array(alpha, dtype=np.float32)
  r0 = np.array(r0, dtype=np.float32)
  k = np.array(k, dtype=np.float32)
  r_onset = np.array(r_onset, dtype=np.float32)
  r_cutoff = np.array(r_cutoff, np.float32)
  dr_threshold = np.float32(dr_threshold)

  neighbor_fn = partition.neighbor_list(
        displacement_or_metric, 
        box_size, 
        r_cutoff, 
        dr_threshold,
        format=format)

  energy_fn = smap.pair_neighbor_list(
    energy.multiplicative_isotropic_cutoff(harmonic_morse, r_onset, r_cutoff),
    space.canonicalize_displacement_or_metric(displacement_or_metric),
    species=species,
    D0=D0,
    alpha=alpha,
    r0=r0,
    k=k)

  return neighbor_fn, energy_fn

To test this, we generate our new ```neighbor_fn``` and ```energy_fn```, as well as a comparison energy function using the default approach.

In [None]:
r_onset  = 1.5
r_cutoff = 2.0
dr_threshold = 1.0

neighbor_fn, energy_fn = harmonic_morse_cutoff_neighbor_list(
    displacement, box_size, D0=5.0, alpha=10.0, r0=1.0, k=500.0,
    r_onset=r_onset, r_cutoff=r_cutoff, dr_threshold=dr_threshold)

energy_fn_comparison = harmonic_morse_cutoff_pair(
    displacement, D0=5.0, alpha=10.0, r0=1.0, k=500.0,
    r_onset=r_onset, r_cutoff=r_cutoff)

Next, we use ```neighbor_fn.allocate``` and the current set of positions to populate the neighbor list.

In [None]:
nbrs = neighbor_fn.allocate(R)

To calculate the energy, we pass `nbrs` to `energy_fn`. The energy matches the comparison.

In [None]:
print(energy_fn(R, neighbor=nbrs))
print(energy_fn_comparison(R))

Note that by default ```neighbor_fn``` uses a cell list internally to populate the neighbor list. This approach fails when the box size in any dimension is less than 3 times $r_\mathrm{threshold} = r_\mathrm{cutoff} + \Delta r_\mathrm{threshold}$. In this case, ```neighbor_fn``` automatically turns off the use of cell lists, and instead searches over all particle pairs. This can also be done manually by passing ```disable_cell_list=True``` to ```partition.neighbor_list```. This can be useful for debugging or for small systems where the overhead of cell lists outweighs the benefit. 

### Updating neighbor lists

A neighbor list is used in two different ways. When created as above, i.e. ```nbrs = neighbor_fn.allocate(R)```, a new neighbor list is generated from scratch. Internally, JAX MD uses the given positions ```R``` to estimate a maximum capacity, i.e. the maximum number of neighbors any particle will have at any point during the use of the neighbor list. This estimate can be adjusted by passing a value of  ```capacity_multiplier``` to ```partition.neighbor_list```, which defaults to ```capacity_multiplier=1.25```.

Since the maximum capacity is not known ahead of time, this construction of the neighbor list cannot be compiled. However, once a neighbor list is created in this way, repopulating the list with the same maximum capacity is a simpler operation that *can* be compiled. This is done by calling ```nbrs = nbrs.update(R)```. Internally, this checks if any particle has moved more than $\Delta r_\mathrm{threshold}/2$ and, if so, recomputes the neighbor list. If the new neighbor list exceeds the maximum capacity for any particle, the boolean variable ```nbrs.did_buffer_overflow``` is set to ```True```. 
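
The factor of $1/2$ in the rebuild criterion comes from a worst-case argument: two particles that were farther apart than $r_\mathrm{cutoff} + \Delta r_\mathrm{threshold}$ at the last build can approach each other by at most $\Delta r_\mathrm{threshold}$ if each moves by at most $\Delta r_\mathrm{threshold}/2$, so they cannot yet interact. A minimal plain-Python sketch of such a check follows; `needs_rebuild` is a hypothetical helper, and for simplicity it ignores periodic wrapping.

```python
import math

def needs_rebuild(positions, reference_positions, dr_threshold):
    """True if any particle moved more than dr_threshold / 2 since the last build.

    Assumption: displacements are measured in free space (no periodic wrap).
    """
    for p, q in zip(positions, reference_positions):
        if math.dist(p, q) > 0.5 * dr_threshold:
            return True
    return False

reference = [(0.0, 0.0), (1.0, 1.0)]
small_move = [(0.1, 0.0), (1.0, 1.0)]
large_move = [(0.6, 0.0), (1.0, 1.0)]

keep_list = needs_rebuild(small_move, reference, dr_threshold=1.0)
rebuild_list = needs_rebuild(large_move, reference, dr_threshold=1.0)
```

In the first case the existing list is still guaranteed to contain every interacting pair; in the second it must be recomputed.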

These two uses together allow for safe and efficient neighbor list calculations. The example below demonstrates a typical simulation loop that uses neighbor lists. 



In [None]:
def run_brownian_neighbor_list(energy_fn, neighbor_fn, R_init, shift, key, num_steps):
  nbrs = neighbor_fn.allocate(R_init)

  init, apply = simulate.brownian(energy_fn, shift, dt=0.00001, kT=1.0, gamma=0.1)

  def body_fn(state, t):
    state, nbrs = state
    nbrs = nbrs.update(state.position)
    state = apply(state, neighbor=nbrs)
    return (state, nbrs), 0

  key, split = random.split(key)
  state = init(split, R_init)

  step = 0
  step_inc=100
  while step < num_steps/step_inc:
    rtn_state, _ = lax.scan(body_fn, (state, nbrs), np.arange(step_inc))
    new_state, nbrs = rtn_state
    # If the neighbor list overflowed, rebuild it and repeat part of 
    # the simulation.
    if nbrs.did_buffer_overflow:
      print('Buffer overflow.')
      nbrs = neighbor_fn.allocate(state.position)
    else:
      state = new_state
      step += 1

  return state.position

To run this, we consider a much larger system than we have to this point. Warning: running this may take a few minutes.

In [None]:
Nlarge = 100*N
box_size_large = 10*box_size
displacement_large, shift_large = setup_periodic_box(box_size_large)

key, split1, split2 = random.split(key,3)
Rlarge = random.uniform(split1, (Nlarge,dimension), minval=0.0, maxval=box_size_large, dtype=f64) 

dr_threshold = 1.5
neighbor_fn, energy_fn = harmonic_morse_cutoff_neighbor_list(
    displacement_large, box_size_large, D0=5.0, alpha=10.0, r0=1.0, k=500.0,
    r_onset=r_onset, r_cutoff=r_cutoff, dr_threshold=dr_threshold)
energy_fn = jit(energy_fn)

start_time = time.process_time()
Rfinal = run_brownian_neighbor_list(energy_fn, neighbor_fn, Rlarge, shift_large, split2, num_steps=4000)
end_time = time.process_time()
print('run time = {}'.format(end_time-start_time))

plot_system( Rfinal, box_size_large, ms=2 )

## Bonds

Bonds are a way of specifying potentials between specific pairs of particles that are "on" regardless of separation. For example, it is common to employ a two-sided spring potential between specific particle pairs, but JAX MD allows the user to specify arbitrary potentials with static or dynamic parameters. 

### Create and implement a bond potential

We start by creating a custom potential that corresponds to a bistable spring, taking the form
\begin{equation}
V(r) = a_4(r-r_0)^4 - a_2(r-r_0)^2.
\end{equation}
$V(r)$ has two minima, at $r = r_0 \pm \sqrt{\frac{a_2}{2a_4}}$.
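
The location of the minima follows from setting $V'(r) = 4a_4(r-r_0)^3 - 2a_2(r-r_0) = 0$, which gives $r=r_0$ (a local maximum) and $(r-r_0)^2 = a_2/(2a_4)$, with depth $V = -a_2^2/(4a_4)$. A short pure-Python check of these statements, using scalar helper functions named only for this sketch:

```python
import math

def bistable_spring_scalar(r, r0=1.0, a2=2.0, a4=5.0):
    """Scalar version of V(r) = a4 (r - r0)^4 - a2 (r - r0)^2."""
    return a4 * (r - r0) ** 4 - a2 * (r - r0) ** 2

def bistable_spring_slope(r, r0=1.0, a2=2.0, a4=5.0):
    """Analytic derivative dV/dr."""
    return 4.0 * a4 * (r - r0) ** 3 - 2.0 * a2 * (r - r0)

r0, a2, a4 = 1.0, 2.0, 5.0
offset = math.sqrt(a2 / (2.0 * a4))

slope_left = bistable_spring_slope(r0 - offset)
slope_right = bistable_spring_slope(r0 + offset)
depth = bistable_spring_scalar(r0 + offset)
```

With these parameter values the two wells sit symmetrically about $r_0$, and `depth` matches $-a_2^2/(4a_4)$.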

In [None]:
def bistable_spring(dr, r0=1.0, a2=2, a4=5, **kwargs):
  return a4*(dr-r0)**4 - a2*(dr-r0)**2

Plot $V(r)$.

In [None]:
drs = np.arange(0,2,0.01)
U = bistable_spring(drs)
plt.plot(drs,U)
format_plot(r'$r$', r'$V(r)$')
finalize_plot()

The next step is to promote this function to act on a set of bonds. This is done via ```smap.bond```, which takes our ```bistable_spring``` function, our displacement function, and a list of the bonds. It returns a function that calculates the energy for a given set of positions.
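
In other words, the bond energy is a sum of the potential over an explicit list of index pairs, optionally with a different parameter for each bond. A minimal plain-Python sketch of that sum in free space (no periodic boundaries); `bond_energy` and the toy data are hypothetical and only illustrate the idea.

```python
import math

def bistable_spring_scalar(r, r0=1.0, a2=2.0, a4=5.0):
    """Scalar version of V(r) = a4 (r - r0)^4 - a2 (r - r0)^2."""
    return a4 * (r - r0) ** 4 - a2 * (r - r0) ** 2

def bond_energy(positions, bonds, r0_per_bond):
    """Sum the bond potential over the listed (i, j) pairs only."""
    total = 0.0
    for (i, j), r0 in zip(bonds, r0_per_bond):
        r = math.dist(positions[i], positions[j])
        total += bistable_spring_scalar(r, r0=r0)
    return total

toy_positions = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)]
toy_bonds = [(0, 1), (1, 2)]   # particles 0 and 2 are not bonded
toy_lengths = [math.dist(toy_positions[i], toy_positions[j])
               for i, j in toy_bonds]

# Each bond sits exactly at r = r0, the local maximum of V.
energy_at_r0 = bond_energy(toy_positions, toy_bonds, toy_lengths)
# The same bonds with a shorter preferred length for every bond.
energy_shifted = bond_energy(toy_positions, toy_bonds, [0.5, 0.5])
```

Note that the pair (0, 2) never enters the sum, regardless of how close those two particles are.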

In [None]:
def bistable_spring_bond(
    displacement_or_metric, bond, bond_type=None, r0=1, a2=2, a4=5):
  """Convenience wrapper to compute energy of particles bonded by springs."""
  r0 = np.array(r0, f32)
  a2 = np.array(a2, f32)
  a4 = np.array(a4, f32)
  return smap.bond(
    bistable_spring,
    space.canonicalize_displacement_or_metric(displacement_or_metric),
    bond,
    bond_type,
    r0=r0,
    a2=a2,
    a4=a4)

However, in order to implement this, we need a list of bonds. We will do this by taking a system minimized under our original ```harmonic_morse``` potential:

In [None]:
R_temp, max_force_component = run_minimization(harmonic_morse_pair(displacement,D0=5.0,alpha=10.0,r0=1.0,k=500.0), R, shift)
print('largest component of force after minimization = {}'.format(max_force_component))
plot_system( R_temp, box_size )

We now place a bond between all particle pairs that are separated by less than 1.3. ```calculate_bond_data``` returns a list of such bonds, as well as a list of the corresponding current length of each bond.  

In [None]:
bonds, lengths = calculate_bond_data(displacement, R_temp, 1.3)

print(bonds[:5])   # list of particle index pairs that form bonds
print(lengths[:5]) # list of the current length of each bond

We use this length as the ```r0``` parameter, meaning that initially each bond is at the unstable local maximum $r=r_0$.
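To see why $r=r_0$ is an unstable point, differentiate the potential defined above:
\begin{equation}
V'(r) = 4a_4(r-r_0)^3 - 2a_2(r-r_0), \qquad V''(r) = 12a_4(r-r_0)^2 - 2a_2.
\end{equation}
At $r=r_0$ the slope vanishes and $V''(r_0) = -2a_2 < 0$ for $a_2>0$, so the bond sits at a local maximum. At the two minima, where $(r-r_0)^2 = \frac{a_2}{2a_4}$, the curvature is $V'' = 4a_2 > 0$.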

In [None]:
bond_energy_fn = bistable_spring_bond(displacement, bonds, r0=lengths)

We now use our new ```bond_energy_fn``` to minimize the energy of the system. The expectation is that nearby particles should either move closer together or further apart, and the choice of which to do should be made collectively due to the constraint of constant volume. This is exactly what we see.

In [None]:
Rfinal, max_force_component = run_minimization(bond_energy_fn, R_temp, shift)
print('largest component of force after minimization = {}'.format(max_force_component))
plot_system( Rfinal, box_size )

### Specifying bonds dynamically

As with species or parameters, bonds can be specified dynamically, i.e. when the energy function is called. Importantly, note that this does NOT override bonds that were specified statically in ```smap.bond```.

In [None]:
# Specifying the bonds dynamically ADDS additional bonds. 
#  Here, we dynamically pass the same bonds that were passed statically, which 
#  has the effect of doubling the energy
print(bond_energy_fn(R))
print(bond_energy_fn(R,bonds=bonds, r0=lengths))

We won't go through a further example, as the implementation is exactly the same as specifying species or parameters dynamically, but the ability to employ bonds both statically and dynamically is a very powerful and general framework.

## Combining potentials 



Most JAX MD functionality (e.g. simulations, energy minimizations) relies on a function that calculates energy for a set of positions. Importantly, while this cookbook focuses on simple and robust ways of defining such functions, JAX MD is not limited to these methods; users can implement energy functions however they like. 

As an important example, here we consider the case where the energy includes both a pair potential and a bond potential. Specifically, we combine ```harmonic_morse_pair``` with ```bistable_spring_bond```. 

In [None]:
# Note, the code in the "Bonds" section must be run prior to this.
energy_fn = harmonic_morse_pair(displacement,D0=0.,alpha=10.0,r0=1.0,k=1.0)
bond_energy_fn = bistable_spring_bond(displacement, bonds, r0=lengths)
def combined_energy_fn(R):
  return energy_fn(R) + bond_energy_fn(R)

Here, we have set $D_0=0$, so the pair potential is just a one-sided repulsive harmonic potential. For particles connected with a bond, this raises the energy of the "contracted" minimum relative to the "extended" minimum.
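A rough numerical check of this statement, in plain NumPy. It assumes, as described above, that with $D_0=0$ the pair potential reduces to $\frac{1}{2}k(r-r_0)^2$ for $r<r_0$ and zero otherwise. The function `one_sided_harmonic` is a stand-in written for this sketch, not the notebook's `harmonic_morse`.

```python
import numpy as onp

def bistable_spring_onp(dr, r0=1.0, a2=2.0, a4=5.0):
  return a4 * (dr - r0)**4 - a2 * (dr - r0)**2

def one_sided_harmonic(dr, r0=1.0, k=1.0):
  # Assumed D0 = 0 limit: repulsive below r0, zero above it.
  return onp.where(dr < r0, 0.5 * k * (dr - r0)**2, 0.0)

drs = onp.linspace(0.0, 2.0, 20001)
U = one_sided_harmonic(drs) + bistable_spring_onp(drs)

contracted = U[drs < 1.0].min()   # well at r < r0, lifted by the repulsion
extended = U[drs > 1.0].min()     # well at r > r0, unchanged
print(contracted, extended)
```

Under these assumptions the contracted well is about $-0.1125$ deep and the extended well stays at $-0.2$, so the extended state is energetically favored.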

In [None]:
drs = np.arange(0,2,0.01)
U = harmonic_morse(drs,D0=0.,alpha=10.0,r0=1.0,k=1.0)+bistable_spring(drs)
plt.plot(drs,U)
format_plot(r'$r$', r'$V(r)$')
finalize_plot()

This new energy function can be passed to the minimization routine (or any other JAX MD simulation routine) in the usual way.

In [None]:
Rfinal, max_force_component = run_minimization(combined_energy_fn, R_temp, shift)
print('largest component of force after minimization = {}'.format(max_force_component))
plot_system( Rfinal, box_size )

## Specifying forces instead of energies

So far, we have defined functions that calculate the energy of the system, which we then pass to JAX MD. Internally, JAX MD uses automatic differentiation to convert these into functions that calculate forces, which are necessary to evolve a system under a given dynamics. However, JAX MD also accepts force functions directly in place of energy functions. This creates additional flexibility because some forces cannot be represented as the gradient of a potential.

As a simple example, we create a custom force function that zeros out the force of some particles. During energy minimization, where there is no stochastic noise, this has the effect of fixing the position of these particles.

First, we break the system up into two species, as before.

In [None]:
N_0 = N // 2  # Half the particles in species 0
N_1 = N - N_0 # The rest in species 1
species = np.array([0]*N_0 + [1]*N_1, dtype=np.int32)
print(species)

Next, we create our custom force function. Starting with our ```harmonic_morse``` pair potential, we obtain the force function explicitly with ```quantity.force``` (which uses automatic differentiation), and then multiply the force on each particle by its species id. Particles of species 0 therefore feel zero force, while particles of species 1 are unaffected. 
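The effect of the multiplication can be seen in a small NumPy sketch with made-up force values: particles of species 0 receive zero force, so a descent-style update leaves them where they are. The arrays below are illustrative and are not taken from the simulation.

```python
import numpy as onp

species_demo = onp.array([0, 0, 1, 1])
R_demo = onp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
forces_demo = onp.array([[0.3, -0.1], [0.2, 0.4], [-0.5, 0.1], [0.0, -0.4]])

# Broadcasting the species id over the spatial axis zeros the rows of species 0.
masked_forces = forces_demo * species_demo[:, None]

# One step of a simple descent update: only species 1 moves.
R_next = R_demo + 0.01 * masked_forces
print(masked_forces)
print(R_next - R_demo)
```

The trick relies only on the species id being 0 for the particles that should stay fixed and 1 for the rest.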

In [None]:
energy_fn = harmonic_morse_pair(displacement,D0=5.0,alpha=10.0,r0=1.0,k=500.0)
force_fn = quantity.force(energy_fn)

def custom_force_fn(R, **kwargs):
  return vmap(lambda a,b: a*b)(force_fn(R),species)

Running simulations with custom forces is as easy as passing this force function to the simulation. 

In [None]:
def run_minimization_general(energy_or_force, R_init, shift, num_steps=5000):
  dt_start = 0.001
  dt_max   = 0.004
  init,apply=minimize.fire_descent(jit(energy_or_force),shift,dt_start=dt_start,dt_max=dt_max)
  apply = jit(apply)

  @jit
  def scan_fn(state, i):
    return apply(state), 0.

  state = init(R_init)
  state, _ = lax.scan(scan_fn,state,np.arange(num_steps))

  return state.position, np.amax(np.abs(quantity.canonicalize_force(energy_or_force)(state.position)))

We run this as usual,

In [None]:
key, split = random.split(key)
Rfinal, _ = run_minimization_general(custom_force_fn, R, shift)
plot_system( Rfinal, box_size, species )

After the above minimization, the blue particles have the same positions as they did initially:

In [None]:
plot_system( R, box_size, species )

Note, this method for fixing particles only works when there is no stochastic noise (e.g. in Langevin or Brownian dynamics) because such noise affects particles whether or not they have a net force. A safer way to fix particles is to create a custom ```shift``` function.

## Coupled ensembles

For a final example that demonstrates the flexibility within JAX MD, let's do something that is particularly difficult in most standard MD packages. We will create a "coupled ensemble", i.e. a set of two identical systems that are connected via an $Nd$-dimensional spring. An extension of this idea is used, for example, in the Doubly Nudged Elastic Band method for finding transition states. 

Suppose the "normal" energy of each system is 
\begin{equation}
U(R) = \sum_{i,j} V( r_{ij} ),
\end{equation}
where $r_{ij}$ is the distance between the $i$th and $j$th particles in $R$ and $V(r)$ is a standard pair potential. The two sets of positions, $R_0$ and $R_1$, are coupled via the potential
\begin{equation}
U_\mathrm{spr}(R_0,R_1) = \frac 12 k_\mathrm{spr} \left| R_1 - R_0 \right|^2,
\end{equation}
so that the total energy of the system is 
\begin{equation}
U_\mathrm{total} = U(R_0) + U(R_1) + U_\mathrm{spr}(R_0,R_1).
\end{equation}
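As a sketch of the coupling term alone, the NumPy snippet below evaluates $U_\mathrm{spr}$ and the forces it exerts on each copy, $-\partial U_\mathrm{spr}/\partial R_0 = k_\mathrm{spr}(R_1-R_0)$ and $-\partial U_\mathrm{spr}/\partial R_1 = -k_\mathrm{spr}(R_1-R_0)$. It ignores periodic boundaries for simplicity, and the function names are hypothetical.

```python
import numpy as onp

def spring_energy_sketch(Rall, k_spr=50.0):
  # Rall has shape (2, N, d): one set of positions per copy of the system.
  diff = Rall[1] - Rall[0]
  return 0.5 * k_spr * onp.sum(diff**2)

def spring_forces_sketch(Rall, k_spr=50.0):
  # Equal and opposite forces pull the two copies toward each other.
  diff = Rall[1] - Rall[0]
  return onp.stack([k_spr * diff, -k_spr * diff])

Rall_demo = onp.array([[[0.0, 0.0], [1.0, 0.0]],
                       [[0.1, 0.0], [1.0, 0.2]]])
E_spr = spring_energy_sketch(Rall_demo)
F_spr = spring_forces_sketch(Rall_demo)
print(E_spr)
print(F_spr)
```

Because the forces on the two copies are equal and opposite, the spring by itself does not move the combined center of mass.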


In [None]:
energy_fn = harmonic_morse_pair(displacement,D0=5.0,alpha=10.0,r0=0.5,k=500.0)
def spring_energy_fn(Rall, k_spr=50.0, **kwargs):
  metric = vmap(space.canonicalize_displacement_or_metric(displacement), (0, 0), 0)
  dr = metric(Rall[0],Rall[1])
  return 0.5*k_spr*np.sum((dr)**2)
def total_energy_fn(Rall, **kwargs):
  return np.sum(vmap(energy_fn)(Rall)) + spring_energy_fn(Rall)

We now have to define a new shift function that can handle arrays of shape $(2,N,d)$. In addition, we make two copies of our initial positions ```R```, one for each system. 
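A minimal NumPy illustration of the shape bookkeeping, assuming a periodic shift that wraps positions back into a box of side `box_size_demo` (a hypothetical value chosen here). Applying the single-system shift to each copy in turn plays the role of `vmap(shift)`.

```python
import numpy as onp

box_size_demo = 4.0

def shift_demo(R, dR):
  # Single-system periodic shift: move, then wrap back into [0, box_size_demo).
  return onp.mod(R + dR, box_size_demo)

def shift_all_demo(Rall, dRall):
  # Apply the single-system shift independently to each copy.
  return onp.stack([shift_demo(R, dR) for R, dR in zip(Rall, dRall)])

R_demo = onp.array([[0.5, 0.5], [3.9, 2.0]])
Rall_demo = onp.array([R_demo, R_demo])   # shape (2, N, d)
dRall_demo = onp.array([[[0.0, 0.0], [0.2, 0.0]],
                        [[-0.6, 0.0], [0.0, 0.0]]])
Rall_next = shift_all_demo(Rall_demo, dRall_demo)
print(Rall_next.shape)
print(Rall_next)
```

Each copy is wrapped on its own, so a particle leaving the box in one copy has no effect on the other copy.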

In [None]:
def shift_all(Rall, dRall, **kwargs):
  return vmap(shift)(Rall, dRall)
Rall = np.array([R,R])

Now, all we have to do is pass our custom energy and shift functions, as well as the $(2,N,d)$ dimensional initial position, to JAX MD, and proceed as normal. 

As a demonstration, we define a simple and general Brownian Dynamics simulation function, similar to the simulation routines above except without the special cases (e.g. changing ```r0``` or species). 
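For orientation, one step of overdamped (Brownian) dynamics is commonly written as an Euler-Maruyama update, $R \leftarrow R + \frac{\Delta t}{\gamma}F + \sqrt{2 k_B T \Delta t/\gamma}\,\xi$, with $\xi$ standard normal noise. The NumPy sketch below implements that textbook update for unit mass in free space. It is an assumption about the general scheme, not a copy of `simulate.brownian`, and `brownian_step_sketch` is a hypothetical name.

```python
import numpy as onp

def brownian_step_sketch(R, force_fn, rng, dt=1e-5, kT=1.0, gamma=0.1):
  # Deterministic drift along the force plus Gaussian thermal noise.
  noise = rng.standard_normal(R.shape)
  return R + force_fn(R) * dt / gamma + onp.sqrt(2.0 * kT * dt / gamma) * noise

def trap_force(R):
  # Toy force: harmonic trap centered at the origin with unit stiffness.
  return -R

rng = onp.random.default_rng(0)
R_demo = onp.ones((2, 3, 2))   # works for any array shape, e.g. (2, N, d)
R_next = brownian_step_sketch(R_demo, trap_force, rng)
print(R_next.shape)

# With kT = 0 the noise term vanishes and only the drift remains.
R_det = brownian_step_sketch(R_demo, trap_force, rng, kT=0.0)
print(R_det[0, 0])
```

The update never inspects the shape of `R`, which is why the same routine can serve a single system or an ensemble of systems.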

In [None]:
def run_brownian_simple(energy_or_force, R_init, shift, key, num_steps):
  init, apply = simulate.brownian(energy_or_force, shift, dt=0.00001, kT=1.0, gamma=0.1)
  apply = jit(apply)

  @jit
  def scan_fn(state, t):
    return apply(state), 0

  key, split = random.split(key)
  state = init(split, R_init)

  state, _ = lax.scan(scan_fn, state, np.arange(num_steps))
  return state.position

Note that nowhere in this function is there any indication that we are simulating an ensemble of systems. This comes entirely from the inputs: i.e. the energy function, the shift function, and the set of initial positions. 

In [None]:
key, split = random.split(key)
Rall_final = run_brownian_simple(total_energy_fn, Rall, shift_all, split, num_steps=10000)

The output also has shape $(2,N,d)$. If we display the results, we see that the two systems are in similar, but not identical, positions, showing that we have succeeded in simulating a coupled ensemble. 

In [None]:
for Ri in Rall_final:
  plot_system( Ri, box_size )
finalize_plot((0.5,0.5))