## Performance hit due to gradient multiplication
- While running my own NVE benchmarks, I accidentally implemented the interatomic forces as the energy gradient rather than the negative energy gradient.
- I noticed that this mistake speeds up NVE simulations by approx. 40%.
- I tried various implementations of the force function to pinpoint the source of the slowdown.
- The following reproduction uses `jax-md/notebooks/lj_benchmark.ipynb` as its starting point.
- I'm wondering how a single multiplication can cause such a drastic performance hit.
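To make the difference between the two variants concrete, here is a minimal sketch in plain JAX. `toy_energy` is a hypothetical stand-in for the Lennard-Jones `energy_fn` used in this notebook, not the real potential. It only shows that the force version is the gradient version plus one elementwise negation in the traced program.

```python
import jax
import jax.numpy as jnp


def toy_energy(R):
    # Hypothetical stand-in for energy_fn: a harmonic well around the origin.
    return 0.5 * jnp.sum(R ** 2)


# Positive gradient (the accidental variant) and the physically correct force.
grad_fn = jax.grad(toy_energy)
force_fn = lambda R: -jax.grad(toy_energy)(R)

R = jnp.arange(6.0).reshape(2, 3)

# Print both traced programs; the force version carries an extra negation.
print(jax.make_jaxpr(grad_fn)(R))
print(jax.make_jaxpr(force_fn)(R))
```

Comparing the two printed jaxprs for the real `energy_fn` in the same way may help to see whether anything other than the negation differs.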

In [6]:
import jax.numpy as jnp
import numpy as np
from jax import jit
from jax import random
from jax import lax, grad
import jax
# Enable float64 support; `jax.config.update` avoids the `jax.config` submodule import.
jax.config.update('jax_enable_x64', True)
from jax_md import space, energy, simulate, quantity, util
import time

### Prepare the system

In [13]:
lattice_constant = 1.37820
N_rep = 20
box_size = N_rep * lattice_constant
# Using float32 for positions / velocities, but float64 for reductions.
dtype = jnp.float32

displacement, shift = space.periodic(box_size)

In [14]:
R = []
for i in range(N_rep):
  for j in range(N_rep):
    for k in range(N_rep):
      R += [[i, j, k]]
R = jnp.array(R, dtype=dtype) * lattice_constant

In [15]:
N = R.shape[0]
phi = N / (lattice_constant * N_rep) ** 3
print(f'Created a system of {N} LJ particles with number density {phi:.3f}')

Created a system of 8000 LJ particles with number density 0.382


### NVE routine

In [None]:
def run_nve(neighbor_fn, energy_or_force_fn):
    
    def step(i, state_and_nbrs):
      state, nbrs = state_and_nbrs
      nbrs = neighbor_fn(state.position, nbrs)
      return apply(state, neighbor=nbrs), nbrs

    # Integrator as in the benchmark notebook: Nose-Hoover (NVT), dt=5e-3, kT=1.2.
    init, apply = simulate.nvt_nose_hoover(energy_or_force_fn, shift, 5e-3, kT=1.2)
    
    key = random.PRNGKey(0)
    nbrs = neighbor_fn(R, extra_capacity=55)
    state = init(key, R, neighbor=nbrs)

    # Run once so the JIT cache is populated and compile time is not measured.
    new_state, new_nbrs = lax.fori_loop(0, 10000, step, (state, nbrs))
    new_state.position.block_until_ready()

    start = time.monotonic()

    new_state, new_nbrs = lax.fori_loop(0, 10000, step, (state, nbrs))
    new_state.position.block_until_ready()

    print("Elapsed: {} seconds".format(time.monotonic() - start))

### Pass energy function

In [None]:
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement,
                                                            box_size, 
                                                            r_cutoff=3.0,
                                                            dr_threshold=1.)

run_nve(neighbor_fn, energy_fn)

### Pass force function
- The following functions correctly return interatomic forces.
- I was wondering whether the factor's data type might affect performance - it does not.
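As a quick cross-check of the dtype question, the sketch below looks at the result dtype of each kind of factor multiplied with a float32 array. The array `g` is a hypothetical stand-in for `grad(energy_fn)(R, **kwargs)`, and `jnp.multiply` is used in place of `*` so that JAX's promotion rules apply explicitly. `util.maybe_downcast` is left out because its behavior is not assumed here.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the float32 gradient array.
g = jnp.ones((4, 3), dtype=jnp.float32)

# The scalar factors tried in the cells of this section.
factors = {
    'python float': -1.0,
    'np.float32': np.float32(-1.0),
    'jnp.float32': jnp.float32(-1.0),
    'np.short': np.short(-1),
}

# Result dtype of factor * gradient under JAX's type promotion.
result_dtypes = {name: jnp.multiply(f, g).dtype for name, f in factors.items()}
for name, dtype in result_dtypes.items():
    print(f'{name:>12}: {dtype}')
```

If every factor yields float32, as expected from the promotion rules, the multiplication never upcasts the forces, which would be consistent with the dtype having no effect on runtime.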

In [None]:
force_fn_1 = quantity.force(energy_fn)
run_nve(neighbor_fn, force_fn_1)

In [None]:
force_fn_2 = lambda R, **kwargs: -1.0 * grad(energy_fn)(R, **kwargs)
run_nve(neighbor_fn, force_fn_2)

In [None]:
force_fn_3 = lambda R, **kwargs: np.float32(-1.0) * grad(energy_fn)(R, **kwargs)
run_nve(neighbor_fn, force_fn_3)

In [None]:
force_fn_4 = lambda R, **kwargs: jnp.float32(-1.0) * grad(energy_fn)(R, **kwargs)
run_nve(neighbor_fn, force_fn_4)

In [None]:
force_fn_5 = lambda R, **kwargs: util.maybe_downcast(-1.0) * grad(energy_fn)(R, **kwargs)
run_nve(neighbor_fn, force_fn_5)

In [None]:
force_fn_6 = lambda R, **kwargs: np.short(-1) * grad(energy_fn)(R, **kwargs)
run_nve(neighbor_fn, force_fn_6)

TODO: Sum up

### Pass gradient function
- This is what I accidentally implemented, and it is what made me notice the seemingly decreased runtime.
- Obviously, the positive gradient yields incorrect forces.
- But a simple multiplication shouldn't affect performance that drastically, should it?
- I couldn't find any optimizations within `simulate.velocity_verlet()` that might disregard incorrect forces and cause speed-ups. Could this still be the case?
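One thing that might be worth ruling out is that the flipped sign changes the trajectory itself, so the two runs do not do the same work per step (for example, how far particles move between neighbor list updates). This is only a hypothesis, not a confirmed cause. The sketch below uses a hypothetical harmonic `toy_energy` and a hand-written velocity Verlet loop in plain JAX, not JAX-MD's integrator, to show how differently the two sign conventions evolve.

```python
import jax
import jax.numpy as jnp
from jax import lax


def toy_energy(x):
    # Hypothetical stand-in potential: a harmonic well around the origin.
    return 0.5 * jnp.sum(x ** 2)


def max_displacement(force_fn, x0, dt=1e-2, steps=1000):
    """Largest per-coordinate displacement from x0 during a velocity Verlet run."""
    def step(i, carry):
        x, v, f, dmax = carry
        v = v + 0.5 * dt * f      # first half kick
        x = x + dt * v            # drift
        f = force_fn(x)           # forces at the new positions
        v = v + 0.5 * dt * f      # second half kick
        dmax = jnp.maximum(dmax, jnp.max(jnp.abs(x - x0)))
        return x, v, f, dmax

    init = (x0, jnp.zeros_like(x0), force_fn(x0), jnp.zeros((), x0.dtype))
    return lax.fori_loop(0, steps, step, init)[3]


x0 = jnp.array([1.0, -0.5])
correct = max_displacement(lambda x: -jax.grad(toy_energy)(x), x0)
flipped = max_displacement(jax.grad(toy_energy), x0)
print(f'correct sign: {correct:.2f}, flipped sign: {flipped:.2f}')
```

With the correct sign the toy particle oscillates in the well, while with the flipped sign it runs away. Logging a similar displacement or neighbor list statistic in `run_nve` for both variants could show whether the speed-up comes from the dynamics rather than from the multiplication.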

In [None]:
grad_fn_1 = grad(energy_fn)
run_nve(neighbor_fn, grad_fn_1)

In [None]:
grad_fn_2 = lambda R, **kwargs: grad(energy_fn)(R, **kwargs)
run_nve(neighbor_fn, grad_fn_2)

Note: Passing a (positive) grad function but causing the negation by setting LJ sigma < 0 causes the same slowdown.

### Only applies to NVE simulations

- I was wondering whether this is a general JAX phenomenon or related to JAX-MD simulations.
- When force and grad functions are compared without running NVE, the effect disappears.
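For the standalone comparison, note that JAX dispatches work asynchronously, so a timing that does not wait for the result may mostly measure dispatch overhead. The helper below is a minimal sketch of a timing routine that compiles first and then blocks on every result. `time_fn` and `toy_energy` are hypothetical names introduced for illustration and are not part of JAX or JAX-MD.

```python
import time

import jax
import jax.numpy as jnp


def toy_energy(R):
    # Hypothetical stand-in for energy_fn.
    return 0.5 * jnp.sum(R ** 2)


def time_fn(fn, *args, repeats=100):
    """Average wall-clock seconds per call of the jitted fn, excluding compilation."""
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()      # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        jitted(*args).block_until_ready()  # wait until the result is computed
    return (time.perf_counter() - start) / repeats


R = jnp.ones((8000, 3), dtype=jnp.float32)
t_grad = time_fn(jax.grad(toy_energy), R)
t_force = time_fn(lambda R: -jax.grad(toy_energy)(R), R)
print(f'grad: {t_grad:.2e} s, force: {t_force:.2e} s')
```

Applied to the functions in this notebook, the neighbor list would be passed through a wrapper such as `lambda R: force_fn_1(R, neighbor=nbrs)`.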

In [None]:
nbrs = neighbor_fn(R)

In [None]:
%timeit force_fn_1(R, neighbor=nbrs).block_until_ready()

In [None]:
%timeit grad_fn_1(R, neighbor=nbrs).block_until_ready()