In [43]:
from ase.atoms import Atoms
from ase.build import bulk
from ase.calculators.lj import LennardJones
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from ase.md import VelocityVerlet
from ase import units
from ase.io import read, write
from asax.lj import LennardJones as AsaxLennardJones
import time
import matplotlib.pyplot as plt
import sklearn
import numpy as np
from jax import lax, config, jit
from jax_md import simulate, space, energy
from jax_md.simulate import NVEState
config.update("jax_enable_x64", True)

In [44]:
def initialize_cubic_argon(multiplier=5, sigma=2.0, epsilon=1.5, rc=10.0, ro=6.0, temperature_K: int = 30) -> Atoms:
    atoms = bulk("Ar", cubic=True) * [multiplier, multiplier, multiplier]
    MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)
    Stationary(atoms)

    atoms.calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, ro=ro, smooth=True)
    return atoms

def read_cubic_argon():
    atoms = read("geometry.in", format="aims")
    atoms.calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, ro=ro, smooth=True)
    return atoms

def write_cubic_argon():
    atoms = initialize_cubic_argon()
    write("geometry.in", atoms, velocities=True, format="aims")
    
def get_initial_nve_state(atoms: Atoms) -> NVEState:
    R = atoms.get_positions(wrapped=True)
    V = atoms.get_velocities()
    forces = atoms.get_forces()
    masses = atoms.get_masses()[0]
    return NVEState(R, V, forces, masses)

In [45]:
sigma = 2.0
epsilon = 1.5
rc = 10.0
ro = 6.0

dt = 5 * units.fs

## NVE in ASE

In [46]:
def run_ase_nve(atoms, steps, batch_size):
    print("n = {}".format(len(atoms)))
    
    positions = []
    velocities = []
    batch_times = []
    
    dyn = VelocityVerlet(atoms, timestep=dt)
    
    i = 0
    while i < steps:
        i += batch_size
        batch_start_time = time.monotonic()
        dyn.run(batch_size)

        # elapsed time for simulating the last batch (in milliseconds)
        batch_times += [round((time.monotonic() - batch_start_time) * 1000, 2)]
        # print("Steps {}/{} took {} ms".format(i, steps, batch_times[-1]))
        
        positions += [atoms.get_positions(wrapped=True)]
        velocities += [atoms.get_velocities()]
        
    mean_step_time = round(np.mean([bt/batch_size for bt in batch_times]), 2)
    print("Average ms/step: {}".format(mean_step_time))
    
    return np.array(positions), np.array(velocities)

## NVE in JAX-MD

In [70]:
def run_jaxmd_nve(atoms, steps, batch_size):
    print("n = {}".format(len(atoms)))
    
    def step_fn(i, state):
        state, neighbors = state
        neighbors = neighbor_fn(state.position, neighbors)
        state = apply_fn(state, neighbor=neighbors)
        return state, neighbors

    positions = []
    velocities = []
    batch_times = []
    
    # setup displacement
    box = atoms.get_cell().array
    displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False)
    
    # normalize LJ parameters and setup NL energy function
    normalized_ro = ro / sigma
    normalized_rc = rc / sigma
    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement_fn, box,
                                                                sigma=sigma,
                                                                epsilon=epsilon,
                                                                r_onset=normalized_ro,
                                                                r_cutoff=normalized_rc,
                                                                dr_threshold=1 * units.Angstrom)
    energy_fn = jit(energy_fn)
    
    # compute initial neighbor list
    R = atoms.get_positions(wrapped=True)
    neighbors = neighbor_fn(R)
    
    # compute initial state & setup NVE
    state = get_initial_nve_state(atoms)
    _, apply_fn = simulate.nve(energy_fn, shift_fn, dt=dt)
 
    # run MD loop
    i = 0
    
    while i < steps:
        batch_start_time = time.monotonic()
        
        state, neighbors = lax.fori_loop(0, batch_size, step_fn, (state, neighbors))

        if neighbors.did_buffer_overflow:
            neighbors = neighbor_fn(state.position)
            print("Steps {}/{}: Neighbor list overflow, recomputing...".format(i, steps))
            continue
            
        batch_times += [round((time.monotonic() - batch_start_time) * 1000, 2)]
        i += batch_size
        # print("Steps {}/{} took {} ms".format(i, steps, batch_times[-1]))
        
        
        positions += [state.position.block_until_ready()]
        velocities += [state.velocity.block_until_ready()]
        
    mean_step_time = round(np.mean([bt/batch_size for bt in batch_times]), 2)
    print("Average ms/step: {}".format(mean_step_time))
    
    return np.array(positions), np.array(velocities)

## NVE in ASAX

In [71]:
def run_asax_nve(atoms, steps, batch_size):
    print("n = {}".format(len(atoms)))
    
    positions = []
    velocities = []
    batch_times = []
    
    atoms.calc = AsaxLennardJones(epsilon, sigma, rc, ro, stress=False)
    dyn = VelocityVerlet(atoms, timestep=dt)
    
    i = 0
    while i < steps:
        i += batch_size
        batch_start_time = time.monotonic()
        dyn.run(batch_size)

        # elapsed time for simulating the last batch (in milliseconds)
        batch_times += [round((time.monotonic() - batch_start_time) * 1000, 2)]
        # print("Steps {}/{} took {} ms".format(i, steps, batch_times[-1]))
        
        positions += [atoms.get_positions(wrapped=True)]
        velocities += [atoms.get_velocities()]
        
    mean_step_time = round(np.mean([bt/batch_size for bt in batch_times]), 2)
    print("Average ms/step: {}".format(mean_step_time))
    
    return np.array(positions), np.array(velocities)

## Does NVE compute the same thing?
- After 1 MD step, positions and velocities should be equal or very similar.
- As a sanity check, we'll run two ASE simulations and compare their results first.

In [72]:
ase_atoms = read_cubic_argon()
run_ase_nve(ase_atoms, steps=1, batch_size=1)

ase_atoms_2 = read_cubic_argon()
run_ase_nve(ase_atoms_2, steps=1, batch_size=1)

print(ase_atoms.get_positions() - ase_atoms_2.get_positions())

n = 500
Average ms/step: 521.96
n = 500
Average ms/step: 474.13
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 ...
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


Looks good! No positional differences after a single step. 

### JAX-MD

In [117]:
jax_atoms = read_cubic_argon()
jax_positions, jax_velocities = run_jaxmd_nve(jax_atoms, steps=1, batch_size=1)

jax_positions = jax_positions[0]

print(np.sum(jax_positions - ase_atoms.get_positions()))
print()
print(jax_positions - ase_atoms.get_positions(wrapped=True))

n = 500
Average ms/step: 542.54
1709.4999999999998

[[-3.97338213e-11  2.63000000e+01  2.63000000e+01]
 [-3.26640884e-11 -3.06887848e-11 -1.87085902e-11]
 [-5.48978640e-11  2.63000000e+01  4.05822043e-11]
 ...
 [-2.88018498e-11 -1.28963507e-11  1.83995041e-11]
 [ 1.95754524e-12 -8.87503404e-11 -1.32168054e-10]
 [ 1.12997611e-10  8.62563354e-11  9.39763822e-11]]


In [118]:
ase_atoms.get_positions()

array([[ 3.02012126e-03, -3.97189777e-03, -1.72837188e-03],
       [ 2.48275914e-03,  2.63233260e+00,  2.63142202e+00],
       [ 2.63417272e+00, -1.53259100e-03,  2.62691538e+00],
       ...,
       [ 2.10421890e+01,  2.36709802e+01,  2.36686017e+01],
       [ 2.36698514e+01,  2.10467458e+01,  2.36800460e+01],
       [ 2.36614112e+01,  2.36634439e+01,  2.10328570e+01]])

In [119]:
jax_positions

array([[3.02012122e-03, 2.62960281e+01, 2.62982716e+01],
       [2.48275911e-03, 2.63233260e+00, 2.63142202e+00],
       [2.63417272e+00, 2.62984674e+01, 2.62691538e+00],
       ...,
       [2.10421890e+01, 2.36709802e+01, 2.36686017e+01],
       [2.36698514e+01, 2.10467458e+01, 2.36800460e+01],
       [2.36614112e+01, 2.36634439e+01, 2.10328570e+01]])

In [122]:
foo = Atoms(positions=jax_positions, cell=ase_atoms.get_cell(), pbc=True)
foo.wrap()
foo.get_positions()

array([[3.02012122e-03, 2.62960281e+01, 2.62982716e+01],
       [2.48275911e-03, 2.63233260e+00, 2.63142202e+00],
       [2.63417272e+00, 2.62984674e+01, 2.62691538e+00],
       ...,
       [2.10421890e+01, 2.36709802e+01, 2.36686017e+01],
       [2.36698514e+01, 2.10467458e+01, 2.36800460e+01],
       [2.36614112e+01, 2.36634439e+01, 2.10328570e+01]])

In [103]:
jax_atoms.get_cell()[0][0]

26.299999999999997

In [100]:
# position mismatch at indices [0][1]

# ase value             box      jax value
# -3.97189777e-03 * x * 26.29 =  2.62960281e+01
# x ≈ -251

# -3.97189777e-03 * (-251) * 26.29 = 26.2097192857

26.11522783775

## ASAX

In [67]:
asax_atoms = read_cubic_argon()
asax_positions, asax_velocities = run_asax_nve(asax_atoms, steps=1, batch_size=1)

print(np.sum(asax_positions - ase_atoms.get_positions()))
print()
print(asax_positions - ase_atoms.get_positions(wrapped=True))

n = 500
Average ms/step: 865.92
-2.790072318147041e-15

[[[4.77048956e-18 6.07153217e-18 5.63785130e-18]
  [2.60208521e-18 0.00000000e+00 0.00000000e+00]
  [0.00000000e+00 4.33680869e-18 0.00000000e+00]
  ...
  [0.00000000e+00 0.00000000e+00 0.00000000e+00]
  [0.00000000e+00 0.00000000e+00 0.00000000e+00]
  [0.00000000e+00 0.00000000e+00 0.00000000e+00]]]
