In [7]:
import jax.numpy as np
import numpy as onp
from jax import jit
from jax import random
from jax import lax

from jax.config import config
config.update('jax_enable_x64', True)

from jax_md import space, energy, simulate, quantity

## Initialization logic
Taken from the original LJ benchmark notebook at `jax-md/notebooks/lj_benchmark.ipynb`.

In [8]:
def initialize_structure(lattice_constant=1.37820, N_rep=40):
    """Build a simple-cubic lattice of particle positions in a cubic box.

    Args:
        lattice_constant: Spacing between neighbouring lattice sites; the
            integer site indices are multiplied by this value.
        N_rep: Number of lattice sites along each box edge, so the system
            holds ``N_rep ** 3`` particles.

    Returns:
        A tuple ``(R, box_size)``. ``R`` is a ``(N_rep ** 3, 3)`` float32 array
        of positions, ordered with the last coordinate varying fastest.
        ``box_size`` is the box edge length ``N_rep * lattice_constant``.

    Raises:
        ValueError: If ``N_rep`` is smaller than 1.
    """
    if N_rep < 1:
        raise ValueError(f'N_rep must be a positive integer, got {N_rep}')

    box_size = N_rep * lattice_constant

    # Positions are stored as float32.
    dtype = np.float32

    # Integer site indices (i, j, k) of every lattice point, generated in one
    # vectorized call instead of a triple Python loop. C-order flattening keeps
    # the ordering of the nested loops: i varies slowest, k fastest.
    sites = onp.indices((N_rep,) * 3).reshape(3, -1).T
    R = np.array(sites, dtype=dtype) * lattice_constant

    N = R.shape[0]
    # Number density: particles per unit volume of the cubic box.
    phi = N / box_size ** 3
    print(f'Created a system of {N} LJ particles with number density {phi:.3f}')

    return R, box_size

In [9]:
# Lattice positions and the edge length of the cubic box that contains them.
R, box_size = initialize_structure()
# space.periodic returns a pair. The second element is bound to a real name
# rather than `_`, because `_` is IPython's "last output" variable and
# assigning it stops IPython from updating it for the rest of the session.
displacement, shift = space.periodic(box_size)

Created a system of 64000 LJ particles with number density 0.382


## Benchmark scenario from `jax-md/notebooks/lj_benchmark.ipynb`

This works for 64,000 atoms.

In [13]:
# Lennard-Jones parameters; apart from r_cutoff, these are default values.
sigma, epsilon = 1.0, 1.0
r_onset, r_cutoff = 2.0, 2.5

# Set up the neighbor-list function and the energy function for this potential.
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
    displacement,
    box_size,
    sigma=sigma,
    epsilon=epsilon,
    r_onset=r_onset,
    r_cutoff=r_cutoff,
    dr_threshold=1.,
)

# Wrap both callables with jit; the right-hand side is evaluated before rebinding.
energy_fn, displacement = jit(energy_fn), jit(displacement)

# The original benchmark sets extra_capacity=55 when building the neighbor
# list; it doesn't seem to matter here.
nbrs = neighbor_fn(R)
energy_fn(R, neighbor=nbrs)

DeviceArray(-124052.76, dtype=float32)

## Modeling elemental Argon with Lennard-Jones

- This is where we run out of memory.
- Because `energy.lennard_jones_neighbor_list()` multiplies the cutoff and onset radii by $\sigma$, we divide both by $\sigma$ here so that the results match `ASE`.

In [12]:
# Lennard-Jones parameters for argon.
sigma, epsilon = 3.4, 10.42
# The radii are divided by sigma because lennard_jones_neighbor_list multiplies
# cutoff and onset by sigma again (effective values: onset 8, cutoff 10.54).
r_onset, r_cutoff = 8 / sigma, 10.54 / sigma

# NOTE(review): R and box_size still come from the lattice with spacing 1.37820
# and are not rescaled by sigma, so the effective cutoff of 10.54 spans many
# more lattice sites than the 2.5 used in the benchmark cell. This is presumably
# what makes the neighbor list exhaust memory -- confirm.
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
    displacement,
    box_size,
    sigma=sigma,
    epsilon=epsilon,
    r_onset=r_onset,
    r_cutoff=r_cutoff,
    dr_threshold=1.,
)

# Wrap both callables with jit; the right-hand side is evaluated before rebinding.
energy_fn, displacement = jit(energy_fn), jit(displacement)

nbrs = neighbor_fn(R)
energy_fn(R, neighbor=nbrs)

RuntimeError: Resource exhausted: Out of memory while trying to allocate 23708160000 bytes.