# Neighbor list issues

- This notebook reproduces some issues that I've been having with neighbor lists.
- We use ASE for initialization and apply a few small tricks to make the result compatible with JAX-MD.
- To make sure that the following errors do not originate from our special setup, I also tried to reproduce them with plain JAX-MD code. As a starting point, I used your LJ benchmark from `jax-md/notebooks/lj_benchmark.ipynb`.

## 1. Neighbor list crashes for small `n`
- This is essentially what you do in `jax-md/notebooks/lj_benchmark.ipynb`.
- Everything works fine for `n = 1000`, but the neighbor list (NL) crashes for `n = 64`.
- I first noticed this in our own setting, where we use ASE to initialize Lennard-Jones argon.
- There, the error is more pronounced and appears for all systems with `n < 500`.

In [1]:
import jax.numpy as jnp
import numpy as np
from jax import jit, grad
from jax import random
from jax import lax
from jax import config  # `from jax.config import config` is removed in newer JAX releases
config.update('jax_enable_x64', True)
from jax_md import space, energy, simulate, quantity

In [2]:
def create_system(N_rep):
    lattice_constant = 1.37820
    box_size = N_rep * lattice_constant
    # Using float32 for positions / velocities, but float64 for reductions.
    dtype = jnp.float32

    R = []
    for i in range(N_rep):
      for j in range(N_rep):
        for k in range(N_rep):
          R += [[i, j, k]]
    R = jnp.array(R, dtype=dtype) * lattice_constant

    N = R.shape[0]
    phi = N / (lattice_constant * N_rep) ** 3
    print(f'Created a system of {N} LJ particles with number density {phi:.3f}')
    
    displacement, shift = space.periodic(box_size)
    return R, box_size, displacement, shift
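Since `create_system` places one particle per cubic cell of side $a$ = `lattice_constant`, the printed number density does not depend on `N_rep`. This is why both runs below report the same value:

$$\rho = \frac{N}{L^3} = \frac{N_\text{rep}^3}{(N_\text{rep}\,a)^3} = \frac{1}{a^3} = \frac{1}{1.3782^3} \approx 0.382$$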

In [3]:
def run_nvt(R, box_size, displacement, shift):
    
    def step(i, state_and_nbrs):
        state, nbrs = state_and_nbrs
        nbrs = neighbor_fn(state.position, nbrs)
        return apply(state, neighbor=nbrs), nbrs

    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement,
                                                            box_size, 
                                                            r_cutoff=3.0,
                                                            dr_threshold=1.)
    
    init, apply = simulate.nvt_nose_hoover(energy_fn, shift, 5e-3, kT=1.2)
    # init, apply = simulate.nve(energy_fn, shift, dt=5e-3)
    
    # nbrs = neighbor_fn(R, extra_capacity=55)
    nbrs = neighbor_fn(R)

    key = random.PRNGKey(0)
    state = init(key, R, kT=1.2, neighbor=nbrs)

    new_state, new_nbrs = lax.fori_loop(0, 1000, step, (state, nbrs))
    new_state.position.block_until_ready()
    
    # Check to make sure the neighbor list didn't overflow.  
    print("NL overflow = {}".format(new_nbrs.did_buffer_overflow))

Run a simulation with 1000 particles; everything looks good here.

In [13]:
R, box_size, displacement, shift = create_system(N_rep=10)
run_nvt(R, box_size, displacement, shift)

Created a system of 1000 LJ particles with number density 0.382
NL overflow = False


However, the neighbor list update fails for a very small system (`n = 64`): the two branches of a `lax.cond` return neighbor lists of different shapes. With the different Lennard-Jones parameters we use for argon, the same error already appears for all systems with `n < 500`.

In [14]:
R, box_size, displacement, shift = create_system(N_rep=4)
run_nvt(R, box_size, displacement, shift)

Created a system of 64 LJ particles with number density 0.382


TypeError: true_fun and false_fun output must have identical types, got
NeighborList(idx=ShapedArray(int32[64,64]), reference_position=ShapedArray(float32[64,3]), did_buffer_overflow=ShapedArray(bool[]), max_occupancy=126, cell_list_fn=None)
and
NeighborList(idx=ShapedArray(int32[64,126]), reference_position=ShapedArray(float32[64,3]), did_buffer_overflow=ShapedArray(bool[]), max_occupancy=126, cell_list_fn=None).
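Reading the error, the two branches differ only in the second dimension of `idx`: 64 (the particle count) versus 126 (`max_occupancy`). Separately, the small box also violates the minimum image convention, which requires the neighbor list range `r_cutoff + dr_threshold` to be at most half the box length. Whether this is what triggers the crash is an assumption on our part, not a confirmed diagnosis. The sketch below uses a hypothetical helper `fits_minimum_image` (not part of JAX-MD) with the values from `create_system` and `run_nvt`; it may be worth running the same check for the argon setup.

```python
# Minimum-image sanity check for a cubic periodic box.
# fits_minimum_image / smallest_safe_n_rep are hypothetical helpers, not JAX-MD API.
# Values are taken from create_system() and run_nvt() above.
LATTICE_CONSTANT = 1.37820
R_CUTOFF = 3.0
DR_THRESHOLD = 1.0


def fits_minimum_image(box_size: float, r_cutoff: float, dr_threshold: float) -> bool:
    """True if the neighbor-list range fits within half of the box length."""
    return r_cutoff + dr_threshold <= box_size / 2


def smallest_safe_n_rep(r_cutoff: float, dr_threshold: float) -> int:
    """Smallest N_rep whose cubic box passes fits_minimum_image."""
    n_rep = 1
    while not fits_minimum_image(n_rep * LATTICE_CONSTANT, r_cutoff, dr_threshold):
        n_rep += 1
    return n_rep


for n_rep in (4, 10):
    box_size = n_rep * LATTICE_CONSTANT
    ok = fits_minimum_image(box_size, R_CUTOFF, DR_THRESHOLD)
    print(f"N_rep={n_rep:2d}  N={n_rep**3:4d}  L/2={box_size / 2:.3f}  fits={ok}")

n_min = smallest_safe_n_rep(R_CUTOFF, DR_THRESHOLD)
print(f"smallest safe N_rep = {n_min} ({n_min**3} particles)")
```

For `N_rep = 4` half the box is about 2.756, which is smaller than even `r_cutoff = 3.0`; with these parameters the smallest cubic system that passes is `N_rep = 6` (216 particles).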

## 2. Getting float32 for positions and velocities to work 

- This is about passing positions and velocities in single precision while letting reductions run in `float64`.
- It turns out that `get_initial_nve_state()` was the culprit: explicitly casting its outputs with `jnp.float32()` solved the problem.
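A plausible explanation for why the cast matters (our assumption, not checked against the JAX-MD source): `lax.fori_loop` requires the loop carry to keep exactly the same dtypes on every iteration. If the initial state holds float64 arrays from ASE while the neighbor list was built from float32 positions, the `neighbor_fn(state.position, nbrs)` call inside the loop could change a dtype in the carry. The sketch below reproduces that failure mode with a toy carry, independent of JAX-MD; `body_changes_dtype` and `body_keeps_dtype` are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)  # as in the notebook


def body_changes_dtype(i, carry):
    # Mimics a step that silently promotes the float32 carry to float64.
    return carry.astype(jnp.float64) + 1.0


def body_keeps_dtype(i, carry):
    # Adding a Python float keeps the carry's float32 dtype (weak typing).
    return carry + 1.0


x32 = jnp.zeros(3, dtype=jnp.float32)

try:
    lax.fori_loop(0, 10, body_changes_dtype, x32)
    mismatch_raised = False
except TypeError as err:
    # JAX rejects a loop body whose output types differ from its input types.
    mismatch_raised = True
    print("TypeError:", str(err).splitlines()[0])

out = lax.fori_loop(0, 10, body_keeps_dtype, x32)
print(out.dtype, out)
```

Casting every array in the initial state to a single dtype, as `get_initial_nve_state()` now does, is consistent with keeping the carry types stable across iterations.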


In [19]:
def run_argon_nve(R, box_size, displacement, shift, state):
    print("n = {}".format(R.shape[0]))
    
    def step(i, state_and_nbrs):
        state, nbrs = state_and_nbrs
        nbrs = neighbor_fn(state.position, nbrs)
        return apply(state, neighbor=nbrs), nbrs

    sigma=2.0
    epsilon=1.5
    rc=10.0
    ro=6.0
    
    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement,
                                                            box_size, 
                                                            sigma=sigma,
                                                            epsilon=epsilon,
                                                            r_cutoff=rc/sigma,
                                                            r_onset=ro/sigma,
                                                            dr_threshold=jnp.float32(1.0))
    
    init, apply = simulate.nve(energy_fn, shift, dt=5e-3)
        
    # nbrs = neighbor_fn(R, extra_capacity=55)
    nbrs = neighbor_fn(R)
    
    if state is None:
        key = random.PRNGKey(0)
        state = init(key, R, kT=1.2, neighbor=nbrs)
    
    new_state, new_nbrs = lax.fori_loop(0, 1000, step, (state, nbrs))
    new_state.position.block_until_ready()
    
    # Check to make sure the neighbor list didn't overflow.  
    print("NL overflow = {}".format(new_nbrs.did_buffer_overflow))

In [20]:
from ase import Atoms
from ase.build import bulk
from ase.calculators.lj import LennardJones
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from jax_md.simulate import NVEState

def create_cubic_argon(multiplier=5, sigma=2.0, epsilon=1.5, rc=10.0, ro=6.0, temperature_K: int=30):
    atoms = bulk("Ar", cubic=True) * [multiplier, multiplier, multiplier]
    MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)
    Stationary(atoms)
    atoms.calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, ro=ro, smooth=True)
    return atoms

def get_initial_nve_state(atoms: Atoms) -> NVEState:
    # important: explicitly cast everything here to float32
    R = jnp.float32(atoms.get_positions())
    V = jnp.float32(atoms.get_velocities())  # Å / ASE time unit
    forces = jnp.float32(atoms.get_forces())
    # single species (Ar): all masses are equal, so a scalar mass suffices
    masses = jnp.float32(atoms.get_masses())[0]
    return NVEState(R, V, forces, masses)

In [None]:
atoms = create_cubic_argon(multiplier=6)
state = get_initial_nve_state(atoms)

# the box dtype does not seem to matter: this works with both float32 and float64
box = jnp.float32(atoms.get_cell().array)

# float32 positions: used to crash before the explicit casts in get_initial_nve_state()
R = jnp.float32(atoms.get_positions())
# R = jnp.array(atoms.get_positions(), dtype=jnp.float32)

# float64 positions: works
# R = jnp.float64(atoms.get_positions())
# R = jnp.array(atoms.get_positions(), dtype=jnp.float64)

displacement, shift = space.periodic_general(box, fractional_coordinates=False)
run_argon_nve(R, box, displacement, shift, state)