In [1]:
import jax.numpy as jnp
# NOTE(review): jacfwd, random and ops appear unused in the visible cells.
from jax import jit, grad, jacfwd, random, ops
# NOTE(review): duplicates the `import jax.numpy as jnp` above; one of them can go.
from jax import numpy as jnp
# NOTE(review): simulate and partition appear unused in the visible cells.
from jax_md import space, energy, quantity, simulate, partition
# NOTE(review): newer JAX releases deprecate importing `jax.config` as a module;
# `import jax; jax.config.update(...)` is the current spelling.
from jax.config import config
# Enable 64-bit (double-precision) arrays in JAX.
config.update("jax_enable_x64", True)

# ASE supplies the test structure and the reference Lennard-Jones calculator.
from ase import Atoms
from ase.build import bulk
from ase.calculators.lj import LennardJones
import numpy as np

In [2]:
# 2x2x2 repetition of ASE's cubic argon bulk cell.
atoms = bulk('Ar', cubic=True) * [2, 2, 2]

R_real = atoms.get_positions()
# Length of the longest row of ASE's cell matrix (the longest lattice vector).
max_box_length = max(np.linalg.norm(lattice_vector) for lattice_vector in atoms.get_cell().array)

# Lennard-Jones parameters, shared by the ASE reference calculator below and
# the jax_md energy built later in the notebook.
sigma = 2.0
epsilon = 1.5
r_cutoff = 0.4 * max_box_length  # cutoff as a fraction of the longest box edge
r_onset = 0.9 * r_cutoff         # smoothing begins at 90 % of the cutoff

atoms.calc = LennardJones(
    sigma=sigma,
    epsilon=epsilon,
    rc=r_cutoff,
    ro=r_onset,
    smooth=True,
)

In [3]:
# Reference total potential energy from the ASE LennardJones calculator attached to `atoms`.
atoms.get_potential_energy()

-27.175665458230192

In [4]:
from periodic_general import periodic_general as new_periodic_general
from periodic_general import inverse as new_inverse
from periodic_general import transform as new_transform

def new_get_displacement(atoms):
    """Build a real-space displacement function for ``atoms`` from the new ``periodic_general``.

    Mirrors what ``asax.utils.get_displacement()`` does, but uses the functions
    from the local ``periodic_general`` module. Positions are mapped into scaled
    coordinates with the inverse cell, and the displacement is evaluated there.

    Args:
      atoms: ASE ``Atoms`` object whose cell defines the periodic box.

    Returns:
      ``(displacement, shift)``.
      ``displacement(Ra, Rb, **kwargs)`` takes positions in real coordinates.
      Its ``**kwargs`` (used to feed through box information) are forwarded to
      the scaled-coordinate displacement.
      ``shift`` is the scaled-coordinate shift from ``periodic_general``,
      returned unchanged.
      NOTE(review): ``shift`` is not adapted to real coordinates, so it does not
      pair with ``displacement``. The callers in this notebook discard it.
    """
    cell = atoms.get_cell().array
    inverse_cell = new_inverse(cell)

    displacement_in_scaled_coordinates, shift_in_scaled_coordinates = new_periodic_general(cell)

    # Already jitted here, so the function is returned as-is (no second jit).
    @jit
    def displacement(Ra: space.Array, Rb: space.Array, **kwargs) -> space.Array:
        # Convert both real-space positions to scaled coordinates first.
        Ra_scaled = new_transform(inverse_cell, Ra)
        Rb_scaled = new_transform(inverse_cell, Rb)
        return displacement_in_scaled_coordinates(Ra_scaled, Rb_scaled, **kwargs)

    # TODO: a real-space shift would take R in real coordinates and dR from
    # `displacement` (which should already be in real coordinates). It would
    # presumably need to bring *both* into scaled coordinates (via inverse_cell)
    # before calling shift_in_scaled_coordinates, then map the result back to
    # real space.
    return displacement, shift_in_scaled_coordinates

In [5]:
# Convert ASE's NumPy positions to a JAX array up front; with the raw NumPy
# array an index-out-of-bounds error showed up.
R_real = jnp.array(atoms.get_positions())

displacement_fn, _ = new_get_displacement(atoms)

# r_onset / r_cutoff are passed in units of sigma (jax_md rescales them by sigma).
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
    displacement_fn,
    max_box_length,
    sigma=sigma,
    epsilon=epsilon,
    r_onset=r_onset / sigma,
    r_cutoff=r_cutoff / sigma,
    per_particle=True,  # energy_fn yields one energy per atom
)

nbrs = neighbor_fn(R_real)
energy_fn = jit(energy_fn)

In [10]:
%timeit energy_fn(R_real, neighbor=nbrs)

172 µs ± 265 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [11]:
total_energy_fn = lambda R, *args, **kwargs: jnp.sum(energy_fn(R, *args, **kwargs))

if True: total_energy_fn = jit(total_energy_fn)

%timeit total_energy_fn(R_real, neighbor=nbrs)

177 µs ± 376 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [9]:
force_fn = quantity.force(total_energy_fn)
force_fn = jit(force_fn)
%timeit force_fn(R_real, neighbor=nbrs)
# 273 µs ± 95.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

258 µs ± 91.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
@jit
def compute_properties(R: space.Array, neighbor=None):
    """Compute total energy, per-particle energies and forces in one jitted call.

    Args:
      R: particle positions in real coordinates.
      neighbor: neighbor list to evaluate with. Defaults to the notebook-level
        ``nbrs`` for backward compatibility. Pass it explicitly after rebuilding
        a neighbor list: when it is omitted, ``nbrs`` is captured as a constant
        at trace time, and later reassignments of ``nbrs`` are not seen.

    Returns:
      ``(total_energy, per_particle_energies, forces)``.
    """
    if neighbor is None:
        neighbor = nbrs

    def total_energy(positions, **kwargs):
        return jnp.sum(energy_fn(positions, **kwargs))

    # Everything here is already traced by the outer @jit, so the inner jits
    # were redundant. Evaluate the per-particle energies once and sum them,
    # rather than calling energy_fn twice.
    per_particle_energies = energy_fn(R, neighbor=neighbor)
    total = jnp.sum(per_particle_energies)
    forces = quantity.force(total_energy)(R, neighbor=neighbor)
    return total, per_particle_energies, forces

In [8]:
%timeit compute_properties(R_real)

# no jit:     4.68 ms ± 559 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# decorator:  294 µs ± 166 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# individual: 

230 µs ± 92.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
# Rebuild the LJ energy with the same setup as the earlier cell. The neighbor
# list is re-allocated too, so `energy_fn` and `nbrs` always come from the
# same `neighbor_fn`.
R_real = jnp.array(atoms.get_positions())

displacement_fn, _ = new_get_displacement(atoms)
neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
    displacement_fn, max_box_length, sigma=sigma, epsilon=epsilon,
    r_onset=r_onset / sigma, r_cutoff=r_cutoff / sigma, per_particle=True)
energy_fn = jit(energy_fn)
nbrs = neighbor_fn(R_real)

# NOTE: energy_fn is per-particle (per_particle=True), so gradients must be
# taken of its sum, e.g. grad(lambda R: jnp.sum(energy_fn(R, neighbor=nbrs))).
# Earlier force timings (taken without block_until_ready):
#   no jit: 2.59 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
#   w/ jit: 214 µs ± 64.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

16 ns ± 0.171 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)
