In [1]:
import jax.numpy as np
import numpy as onp
from jax import jit
from jax import random
from jax import lax
import time

from jax.config import config
config.update('jax_enable_x64', True)

from jax_md import space, energy, simulate, quantity

## Initialization logic
Adapted from the original LJ benchmark notebook, `jax-md/notebooks/lj_benchmark.ipynb`. Particles are placed on a simple cubic lattice that fills a periodic cubic box.

In [2]:
def initialize_structure(N_rep=40, lattice_constant=1.37820, dtype=np.float64):
    """Place particles on a simple cubic lattice that fills a periodic box.

    Particles sit at ``(i, j, k) * lattice_constant`` for ``0 <= i, j, k < N_rep``.
    They are ordered with ``k`` varying fastest, the same order as nested
    ``for i`` / ``for j`` / ``for k`` loops.

    Args:
      N_rep: Number of lattice sites along each box edge. The system has
        ``N_rep ** 3`` particles. Must be at least 1.
      lattice_constant: Spacing between neighbouring lattice sites.
      dtype: Floating-point dtype of the returned positions. 64-bit dtypes
        only take effect in JAX when ``jax_enable_x64`` is enabled.

    Returns:
      ``(R, box_size)``. ``R`` is an ``(N_rep ** 3, 3)`` array of positions.
      ``box_size = N_rep * lattice_constant`` is the edge length of the
      cubic box.

    Raises:
      ValueError: If ``N_rep`` is smaller than 1 (the box would be empty and
        the density undefined).
    """
    if N_rep < 1:
        raise ValueError(f'N_rep must be at least 1, got {N_rep}')

    box_size = N_rep * lattice_constant

    # Vectorized replacement for a triple Python loop over (i, j, k).
    # np.indices in C order flattens with k fastest, matching the loop order.
    # The integer grid is built on the host and converted to a device array once.
    sites = onp.indices((N_rep, N_rep, N_rep)).reshape(3, -1).T
    R = np.array(sites, dtype=dtype) * lattice_constant

    N = R.shape[0]
    # Number density: particles per unit volume of the cubic box.
    phi = N / box_size ** 3
    print(f'Created a system of {N} LJ particles with number density {phi:.3f}')

    return R, box_size

In [3]:
# 60 sites per edge -> 60**3 = 216,000 particles on a simple cubic lattice,
# wrapped in periodic boundaries matching the lattice box.
R, box_size = initialize_structure(60)
displacement_fn, shift_fn = space.periodic(box_size)

Created a system of 216000 LJ particles with number density 0.382


## Benchmark potential

In [4]:
# Lennard-Jones parameters in units where sigma = epsilon = 1.
# Other cutoffs that appear commented out in earlier versions of this cell:
# 2.5 and 4.2.
sigma, epsilon = 1.0, 1.0
r_onset = 2.0
r_cutoff = 5

neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
    displacement_fn,
    box_size,
    sigma=sigma,
    epsilon=epsilon,
    r_cutoff=r_cutoff,
    r_onset=r_onset,
    dr_threshold=1.0,
)

energy_fn, displacement_fn = jit(energy_fn), jit(displacement_fn)

# The original benchmark allocated the list with extra_capacity=55; that made
# no visible difference here, so the default allocation is used.
# NOTE(review): in the recorded run this allocation ran out of memory at
# 216,000 particles with r_cutoff=5. A smaller cutoff shrinks the neighbor list.
nbrs = neighbor_fn(R)

RuntimeError: Resource exhausted: Out of memory while trying to allocate 10005984336 bytes.

In [5]:
#%%timeit -n 10000
energy_fn(R, neighbor=nbrs).block_until_ready()



DeviceArray(-452505.97, dtype=float32)

## Benchmark NVE

In [17]:
# Constant-energy (NVE) integrator with timestep dt = 1e-3.
# Both the state initializer and the single-step update are jit-compiled.
init_fn, apply_fn = map(jit, simulate.nve(energy_fn, shift_fn, dt=1e-3))

In [18]:
# Build the integrator state from the lattice positions at kT=10. The fixed
# PRNG seed makes the random part of the initialization reproducible.
state = init_fn(random.PRNGKey(0), R, neighbor=nbrs, kT=10)

In [19]:
@jit
def step_fn(i, state):
    """``lax.fori_loop`` body: update the neighbor list, then take one NVE step.

    ``i`` is the loop index supplied by ``fori_loop`` and is unused. ``state``
    is the loop carry, a ``(simulation_state, neighbor_list)`` pair. The
    updated pair is returned in the same structure.
    """
    sim_state, neighbors = state
    neighbors = neighbor_fn(sim_state.position, neighbors)
    return apply_fn(sim_state, neighbor=neighbors), neighbors

In [20]:
%%time

step = 0
while step < 2000:
    new_state, nbrs = lax.fori_loop(0, 5, step_fn, (state, nbrs))
    if nbrs.did_buffer_overflow:
        nbrs = neighbor_fn(state.position)
    else:
        state = new_state
        step += 1
        
new_state.position.block_until_ready()

CPU times: user 5min 1s, sys: 1min 49s, total: 6min 50s
Wall time: 6min 33s


DeviceArray([[2.72203297e+00, 7.91638639e+01, 8.05392589e+01],
             [8.10163721e+01, 2.67793808e-02, 5.06349043e+00],
             [8.11568458e+01, 7.74385037e+01, 8.11570601e+01],
             ...,
             [8.10261684e+01, 8.74069029e+00, 7.34692893e+01],
             [8.84427235e-01, 8.22364651e+01, 7.85193114e+01],
             [8.14395714e+01, 8.24457489e+01, 1.02035859e+01]],            dtype=float64)