In [1]:
import os
from typing import List

from ase.atoms import Atoms
from ase.build import bulk
from ase.calculators.lj import LennardJones
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from ase.md import VelocityVerlet
from ase import units
from ase.io import read, write
from asax.lj import LennardJones as AsaxLennardJones
import time
import matplotlib.pyplot as plt
import sklearn
import numpy as np
from jax import lax, config, jit
from jax_md import simulate, space, energy
from jax_md.simulate import NVEState
# Enable 64-bit floats in JAX (it defaults to 32-bit). Presumably required so that
# JAX-MD results can be compared with ASE (float64) at the ~1e-10 level checked below.
config.update("jax_enable_x64", True)

In [56]:
def initialize_cubic_argon(multiplier=5, sigma=2.0, epsilon=1.5, rc=10.0, ro=6.0, temperature_K: int = 30) -> Atoms:
    """Build a ``multiplier`` x ``multiplier`` x ``multiplier`` supercell of cubic argon.

    Velocities are drawn from a Maxwell-Boltzmann distribution at ``temperature_K``
    (no explicit seed is passed, so every call gives different velocities), the total
    momentum is removed, and a smooth-cutoff ASE ``LennardJones`` calculator is attached.

    Returns:
        The prepared ``Atoms`` object.
    """
    repeats = (multiplier, multiplier, multiplier)
    atoms = bulk("Ar", cubic=True).repeat(repeats)

    MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)
    # Zero the centre-of-mass momentum so the whole system does not drift.
    Stationary(atoms)

    atoms.calc = LennardJones(epsilon=epsilon, sigma=sigma, ro=ro, rc=rc, smooth=True)
    return atoms

def read_cubic_argon(path: str = "git/gknet-benchmarks/geometries/geometry.in",
                     sigma: float = 2.0, epsilon: float = 1.5,
                     rc: float = 10.0, ro: float = 6.0) -> Atoms:
    """Read the cubic argon geometry (FHI-aims format) and attach an ASE LJ calculator.

    Args:
        path: FHI-aims geometry file. The default is the file ``write_cubic_argon`` writes.
        sigma: Lennard-Jones length parameter.
        epsilon: Lennard-Jones well depth.
        rc: Cutoff radius of the potential.
        ro: Onset of the smooth cutoff (used because ``smooth=True``).

    Returns:
        The ``Atoms`` read from ``path`` with a smooth-cutoff ``LennardJones`` calculator.
    """
    atoms = read(path, format="aims")
    atoms.calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, ro=ro, smooth=True)
    return atoms

def write_cubic_argon(path: str = "git/gknet-benchmarks/geometries/geometry.in", **init_kwargs) -> None:
    """Generate a fresh cubic argon system and write it, including velocities, in FHI-aims format.

    Args:
        path: Output geometry file. The default is the file ``read_cubic_argon`` reads.
        **init_kwargs: Forwarded to ``initialize_cubic_argon`` (e.g. ``multiplier``,
            ``temperature_K``). With no kwargs the original defaults are used.

    Note:
        ``initialize_cubic_argon`` draws velocities without an explicit seed, so repeated
        calls write different files.
    """
    atoms = initialize_cubic_argon(**init_kwargs)
    write(path, atoms, velocities=True, format="aims")
    
def get_initial_nve_state(atoms: Atoms) -> NVEState:
    """Build the initial JAX-MD ``NVEState`` from an ASE ``Atoms`` object.

    Assumes the jax_md release whose ``NVEState`` fields are
    ``(position, velocity, acceleration, mass)``, i.e. the one exposing ``state.velocity``
    (TODO confirm if jax_md is upgraded). Its velocity-Verlet step uses the third field as
    the acceleration from the previous step, so forces are divided by the mass here.

    Args:
        atoms: ASE atoms with velocities and an attached calculator that provides forces.

    Returns:
        ``NVEState`` with positions wrapped into the cell.
    """
    # ASE's keyword is ``wrap``. ``wrapped=True`` would be swallowed as an unused
    # wrap option and the positions would come back unwrapped.
    R = atoms.get_positions(wrap=True)
    V = atoms.get_velocities()
    forces = atoms.get_forces()
    masses = atoms.get_masses()
    if np.all(masses == masses[0]):
        # Single species: keep the scalar mass as before.
        mass = masses[0]
    else:
        # Per-atom masses as an (n, 1) column so they broadcast against (n, 3) forces.
        mass = masses[:, None]
    acceleration = forces / mass
    return NVEState(R, V, acceleration, mass)

def get_milliseconds(start_time: float) -> float:
    """Return the time elapsed since ``start_time`` in milliseconds, rounded to 2 decimals.

    ``start_time`` must be a value previously obtained from ``time.monotonic()``.
    """
    elapsed_seconds = time.monotonic() - start_time
    elapsed_ms = elapsed_seconds * 1000
    return round(elapsed_ms, 2)

def get_mean_step_time(batch_times: List[float], batch_size: int) -> float:
    """Mean time per MD step, rounded to 2 decimals.

    Each entry of ``batch_times`` is the wall time of one batch of ``batch_size`` steps.
    """
    per_step_times = np.asarray(batch_times) / batch_size
    return round(per_step_times.mean(), 2)

def print_difference_metrics(positions_1: np.ndarray, positions_2: np.ndarray):
    """Print the element-wise difference of two arrays and two summary metrics.

    Prints ``positions_1 - positions_2``, the largest *absolute* element-wise difference
    and the sum of absolute differences. It works for any broadcast-compatible pair of
    arrays (positions, velocities, ...).
    """
    diff = positions_1 - positions_2
    abs_diff = np.abs(diff)
    # Take the max of the absolute values: a signed max would hide large negative deviations.
    max_diff = np.max(abs_diff)
    sad = np.sum(abs_diff)
    print(diff)
    print("max diff: {}".format(max_diff))
    print("Sum of absolute differences: {}".format(sad))

In [16]:
# Lennard-Jones parameters shared by the simulations below
# (ASE units: sigma, rc and ro in Å, epsilon in eV).
sigma, epsilon = 2.0, 1.5
# Smooth cutoff: the potential is switched off between ro (onset) and rc (cutoff).
rc, ro = 10.0, 6.0

# MD time step of 5 fs, expressed in ASE time units.
dt = 5 * units.fs

# 1. NVE Simulations
## 1.1 NVE in ASE

In [20]:
def run_ase_nve(atoms, steps, batch_size):
    """Run NVE (velocity-Verlet) dynamics in ASE and record one snapshot per batch.

    Uses the module-level time step ``dt`` and the calculator attached to ``atoms``.

    Args:
        atoms: ASE ``Atoms`` with a calculator attached. It is propagated in place.
        steps: Requested number of MD steps. NOTE: if ``steps`` is not a multiple of
            ``batch_size``, the final batch overshoots ``steps``.
        batch_size: Number of steps per ``dyn.run`` call. One snapshot and one timing
            are recorded per batch.

    Returns:
        Tuple ``(positions, velocities)`` of arrays with one entry per batch. Positions
        are wrapped into the simulation cell.
    """
    print("n = {}".format(len(atoms)))

    positions = []
    velocities = []
    batch_times = []

    dyn = VelocityVerlet(atoms, timestep=dt)

    i = 0
    while i < steps:
        batch_start_time = time.monotonic()
        dyn.run(batch_size)

        batch_times.append(get_milliseconds(batch_start_time))
        # ASE's keyword is ``wrap``. ``wrapped=True`` was swallowed as an unused wrap
        # option and silently returned unwrapped positions.
        positions.append(atoms.get_positions(wrap=True))
        velocities.append(atoms.get_velocities())
        i += batch_size

    mean_step_time = get_mean_step_time(batch_times, batch_size)
    print("Average ms/step: {}".format(mean_step_time))

    return np.array(positions), np.array(velocities)

## 1.2 NVE in JAX-MD

In [79]:
def run_jaxmd_nve(atoms, steps, batch_size):
    """Run NVE dynamics with JAX-MD and a neighbor-list Lennard-Jones potential.

    Uses the module-level ``sigma``, ``epsilon``, ``rc``, ``ro`` and ``dt``. Steps run in
    jitted batches of ``batch_size`` via ``lax.fori_loop``, and one snapshot and one
    timing are recorded per batch.

    Args:
        atoms: ASE ``Atoms`` providing the cell, positions, velocities, masses and
            initial forces (from its attached calculator). Its positions are overwritten
            with the JAX-MD positions after every batch.
        steps: Requested number of MD steps. NOTE: if ``steps`` is not a multiple of
            ``batch_size``, the final batch overshoots ``steps``.
        batch_size: Number of MD steps per jitted batch.

    Returns:
        Tuple ``(positions, velocities)`` of arrays with one entry per batch.
    """
    print("n = {}".format(len(atoms)))

    positions = []
    velocities = []
    batch_times = []

    # setup displacement
    box = atoms.get_cell().array
    displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False)

    # normalize LJ parameters (r_onset / r_cutoff are passed in units of sigma)
    # and setup the neighbor-list energy function
    normalized_ro = ro / sigma
    normalized_rc = rc / sigma
    neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement_fn, box,
                                                                sigma=sigma,
                                                                epsilon=epsilon,
                                                                r_onset=normalized_ro,
                                                                r_cutoff=normalized_rc,
                                                                dr_threshold=1 * units.Angstrom)
    energy_fn = jit(energy_fn)

    # compute initial neighbor list
    R = atoms.get_positions()
    neighbors = neighbor_fn(R)

    # compute initial state & setup NVE
    state = get_initial_nve_state(atoms)
    _, apply_fn = simulate.nve(energy_fn, shift_fn, dt=dt)

    @jit
    def step_fn(i, carry):
        # One MD step: refresh the neighbor list, then advance the integrator.
        md_state, nbrs = carry
        nbrs = neighbor_fn(md_state.position, nbrs)
        md_state = apply_fn(md_state, neighbor=nbrs)
        return md_state, nbrs

    # run MD loop
    i = 0
    while i < steps:
        batch_start_time = time.monotonic()
        new_state, new_neighbors = lax.fori_loop(0, batch_size, step_fn, (state, neighbors))

        if new_neighbors.did_buffer_overflow:
            # Forces in this batch came from a truncated neighbor list, so its result is
            # invalid. Discard it, reallocate the list from the last *valid* state and
            # rerun the batch from there.
            neighbors = neighbor_fn(state.position)
            print("Steps {}/{}: Neighbor list overflow, recomputing...".format(i, steps))
            continue

        state, neighbors = new_state, new_neighbors
        # Reading ``did_buffer_overflow`` above waits for the batch to finish, so this
        # timing covers the computation (and tracing/compilation on the first batch).
        batch_times.append(get_milliseconds(batch_start_time))

        atoms.set_positions(state.position)
        positions.append(atoms.get_positions())
        velocities.append(state.velocity)
        i += batch_size

    mean_step_time = get_mean_step_time(batch_times, batch_size)
    print("Average ms/step: {}".format(mean_step_time))

    return np.array(positions), np.array(velocities)

## 1.3 NVE in ASAX

In [93]:
def run_asax_nve(atoms, steps, batch_size):
    """Run NVE (velocity-Verlet) dynamics in ASE with the ASAX Lennard-Jones calculator.

    Uses the module-level ``epsilon``, ``sigma``, ``rc``, ``ro`` and ``dt``. Any calculator
    already attached to ``atoms`` is replaced.

    Args:
        atoms: ASE ``Atoms``. It is propagated in place.
        steps: Requested number of MD steps. NOTE: if ``steps`` is not a multiple of
            ``batch_size``, the final batch overshoots ``steps``.
        batch_size: Number of steps per ``dyn.run`` call. One snapshot and one timing
            are recorded per batch.

    Returns:
        Tuple ``(positions, velocities)`` of arrays with one entry per batch. Positions
        are wrapped into the simulation cell.
    """
    print("n = {}".format(len(atoms)))

    positions = []
    velocities = []
    batch_times = []

    # NOTE(review): positional order assumed to be (epsilon, sigma, rc, ro), matching
    # asax.lj.LennardJones. Confirm if the ASAX signature changes.
    atoms.calc = AsaxLennardJones(epsilon, sigma, rc, ro, stress=False)
    dyn = VelocityVerlet(atoms, timestep=dt)

    i = 0
    while i < steps:
        batch_start_time = time.monotonic()
        dyn.run(batch_size)

        batch_times.append(get_milliseconds(batch_start_time))
        # Wrap into the cell so the returned positions are comparable across NVE runners.
        positions.append(atoms.get_positions(wrap=True))
        velocities.append(atoms.get_velocities())
        i += batch_size

    mean_step_time = get_mean_step_time(batch_times, batch_size)
    print("Average ms/step: {}".format(mean_step_time))

    return np.array(positions), np.array(velocities)

# 2. Comparing atomic positions after 1 NVE step
- Positions and velocities should be equal or very similar.
- As a sanity check, we'll run two ASE simulations and compare their results first.

In [94]:
# Sanity check: two independent ASE runs started from the same geometry file
# should produce identical positions after one step.
ase_atoms = read_cubic_argon()
ase_atoms_2 = read_cubic_argon()
run_ase_nve(ase_atoms, steps=1, batch_size=1)
run_ase_nve(ase_atoms_2, steps=1, batch_size=1)

print(ase_atoms.get_positions() - ase_atoms_2.get_positions())

n = 500
Average ms/step: 464.96
n = 500
Average ms/step: 549.61
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 ...
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


Looks good! No positional differences after a single step.

## 2.1 JAX-MD

In [95]:
# One JAX-MD step from the same starting geometry. Keep only the single snapshot.
jax_atoms = read_cubic_argon()
jax_positions, jax_velocities = (
    snapshots[0] for snapshots in run_jaxmd_nve(jax_atoms, steps=1, batch_size=1)
)

jax_atoms.set_positions(jax_positions)

n = 500
Average ms/step: 550.7


### 2.1.1 Positions
#### No position wrapping

In [96]:
# JAX-MD positions after one step vs. the first ASE run, with no explicit wrapping applied.
print_difference_metrics(jax_positions, ase_atoms.get_positions())
print()


[[-1.07388265e-11 -4.21484386e-12 -9.25449457e-11]
 [ 2.63000000e+01  3.36957129e-11 -3.37250228e-11]
 [-9.66862146e-11  2.63000000e+01 -4.32764935e-12]
 ...
 [ 9.04947228e-11 -6.74553746e-11  1.03490549e-11]
 [-2.06838990e-11  1.01938014e-10 -5.26370059e-11]
 [-6.49151843e-11  1.14479093e-10  1.09949383e-10]]
max diff: 26.300000000113826
Sum of absolute differences: 1709.5000000624268



#### Wrap ASE positions

In [97]:
# Wrap only the ASE positions into the cell before comparing.
print_difference_metrics(jax_positions, ase_atoms.get_positions(wrap=True))
print()

[[-1.07388266e-11 -4.21484386e-12 -9.25449457e-11]
 [ 2.16147100e-11  3.36957129e-11 -3.37250228e-11]
 [-9.66862146e-11  4.55138149e-11 -4.32764935e-12]
 ...
 [ 9.04947228e-11 -6.74553746e-11  1.03490549e-11]
 [-2.06874518e-11  1.01938014e-10 -5.26370059e-11]
 [-6.49151843e-11  1.14475540e-10  1.09949383e-10]]
max diff: 26.300000000000026
Sum of absolute differences: 26.300000062427042



#### Wrap JAX positions

In [98]:
# Wrap only the JAX-MD positions (through the ASE Atoms object) before comparing.
print_difference_metrics(jax_atoms.get_positions(wrap=True), ase_atoms.get_positions(wrap=False))
print()

[[-1.07388265e-11 -4.21484386e-12 -9.25449457e-11]
 [ 2.63000000e+01  3.36957129e-11 -3.37250228e-11]
 [-9.66862146e-11  2.63000000e+01 -4.32764935e-12]
 ...
 [ 9.04947228e-11 -6.74553746e-11  1.03490549e-11]
 [-2.06803463e-11  1.01938014e-10 -5.26370059e-11]
 [-6.49151843e-11  1.14479093e-10  1.09952936e-10]]
max diff: 26.300000000113826
Sum of absolute differences: 1683.2000000624269



#### Wrap both

In [99]:
# Wrap both position sets into the cell before comparing.
print_difference_metrics(jax_atoms.get_positions(wrap=True), ase_atoms.get_positions(wrap=True))
print()

[[-1.07388266e-11 -4.21484386e-12 -9.25449457e-11]
 [ 2.16147100e-11  3.36957129e-11 -3.37250228e-11]
 [-9.66862146e-11  4.55138149e-11 -4.32764935e-12]
 ...
 [ 9.04947228e-11 -6.74553746e-11  1.03490549e-11]
 [-2.06838990e-11  1.01938014e-10 -5.26370059e-11]
 [-6.49151843e-11  1.14475540e-10  1.09952936e-10]]
max diff: 1.7140244779056957e-10
Sum of absolute differences: 6.242711912594389e-08



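- The large entries (26.3 Å, presumably one edge length of the 5 × 5 × 5 cubic argon cell, i.e. 5 × 5.26 Å) are periodic images rather than physical discrepancies. The same atom sits on opposite faces of the box in the two position sets.
- Wrapping only the ASE positions removes 64 of the 65 mismatched coordinates (the sum drops from 1709.5 to 26.3), so most of the mismatch comes from ASE. Its `VelocityVerlet` does not wrap positions, so atoms that drift to slightly negative coordinates stay outside the box.
- JAX-MD's periodic `shift_fn` keeps positions inside the box. However, one coordinate presumably ends up at (or within rounding of) the upper box edge, and ASE's `get_positions(wrap=True)` maps it back to ≈ 0. This last mismatch is removed by feeding the JAX-MD positions into an ASE `Atoms` object and calling `jax_atoms.get_positions(wrap=True)`.
- **Open question:** What is the JAX-MD way to achieve this?
- Both the ASE and the JAX-MD positions have to be wrapped to obtain a sum of absolute differences on the order of $10^{-8}$.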

### 2.1.2 Velocities

In [100]:
# JAX-MD vs. ASE velocities after one step (velocities need no wrapping).
print_difference_metrics(jax_velocities, ase_atoms.get_velocities())

[[ 1.85839842e-13  1.02639858e-12 -1.37477182e-12]
 [ 7.41110732e-13 -5.02375051e-13  9.57654095e-15]
 [-1.41735061e-12  1.70618727e-14  3.74410572e-13]
 ...
 [ 1.79854742e-12 -6.36067587e-13 -9.26911542e-13]
 [-3.60068746e-13  2.36036884e-12 -3.92441288e-13]
 [-9.07685385e-13  1.30098710e-12  1.82709403e-12]]
max diff: 3.130797704420374e-12
Sum of absolute differences: 1.1145476759733911e-09


Everything looks good here.


## 2.2 ASAX

In [103]:
# One ASAX step from the same starting geometry. Keep only the single snapshot.
asax_atoms = read_cubic_argon()
asax_positions, asax_velocities = (
    snapshots[0] for snapshots in run_asax_nve(asax_atoms, steps=1, batch_size=1)
)

n = 500
Average ms/step: 803.28


### 2.2.1 Positions

In [107]:
# ASAX vs. ASE positions, both wrapped into the cell.
print_difference_metrics(asax_atoms.get_positions(wrap=True), ase_atoms.get_positions(wrap=True))

[[4.44522891e-18 5.74627151e-18 4.33680869e-18]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00]
 ...
 [0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00]]
max diff: 8.239936510889834e-18
Sum of absolute differences: 1.3249416415893854e-15


Again, we have to wrap positions of both `Atoms` objects back into the box to get practically equivalent positions.

### 2.2.2 Velocities

In [108]:
# ASAX vs. ASE velocities after one step.
print_difference_metrics(asax_velocities, ase_atoms.get_velocities())

[[ 1.66967135e-17  1.05167611e-17  3.46944695e-18]
 [ 5.20417043e-18  8.67361738e-18  1.73472348e-18]
 [-5.20417043e-18  6.93889390e-18 -4.11996826e-18]
 ...
 [ 1.21430643e-17  3.46944695e-18 -8.02309608e-18]
 [ 3.90312782e-18 -6.93889390e-18  1.21430643e-17]
 [-8.67361738e-18  1.38777878e-17  0.00000000e+00]]
max diff: 3.122502256758253e-17
Sum of absolute differences: 1.1971228451894064e-14
[[ 1.66967135e-17  1.05167611e-17  3.46944695e-18]
 [ 5.20417043e-18  8.67361738e-18  1.73472348e-18]
 [-5.20417043e-18  6.93889390e-18 -4.11996826e-18]
 ...
 [ 1.21430643e-17  3.46944695e-18 -8.02309608e-18]
 [ 3.90312782e-18 -6.93889390e-18  1.21430643e-17]
 [-8.67361738e-18  1.38777878e-17  0.00000000e+00]]
max diff: 3.122502256758253e-17
Sum of absolute differences: 1.1971228451894064e-14


Everything is fine here.

