## Context
- Comparing ASE and JAX-MD output to verify correctness.
- The Lennard-Jones pair energies agree, but the forces do not. The stress is probably mismatched as well.
- Unknown: did the error exist before, or was it introduced by the new `periodic_general()`? Probably the latter.
- Since JAX-MD is initialized from ASE, the error most likely lies in `asax.utils.get_displacement()`, which converts real (Cartesian) coordinates (ASE) to relative (scaled) coordinates (JAX-MD).
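When energies agree but forces do not, one thing worth ruling out is differentiating the energy with respect to scaled coordinates instead of real ones. The sketch below is an illustration, not the `asax` code: it assumes a plain two-atom LJ energy (no cutoff, no periodic images), the ASE row convention `real = scaled @ cell`, and uses central finite differences as a stand-in for `jax.grad`. The gradient with respect to scaled coordinates equals the Cartesian gradient multiplied by `cell.T`, so it has to be mapped back before it can be compared with ASE forces.

```python
import numpy as np

SIGMA, EPSILON = 2.0, 1.5  # same LJ parameters as the notebook

# ASE-style cell: rows are lattice vectors, so real = scaled @ CELL
CELL = np.array([[0.0, 5.26, 5.26],
                 [5.26, 0.0, 5.26],
                 [5.26, 5.26, 0.0]])


def lj_energy_real(R):
    """Plain LJ energy (no cutoff, no periodic images) of real positions R, shape (N, 3)."""
    energy = 0.0
    for i in range(len(R)):
        for j in range(i + 1, len(R)):
            r = np.linalg.norm(R[i] - R[j])
            sr6 = (SIGMA / r) ** 6
            energy += 4.0 * EPSILON * (sr6 ** 2 - sr6)
    return energy


def lj_energy_scaled(S):
    """The same energy, expressed as a function of scaled coordinates S."""
    return lj_energy_real(S @ CELL)


def numerical_grad(fn, X, h=1e-6):
    """Central finite-difference gradient of a scalar function of an (N, 3) array."""
    grad = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += h
        Xm[idx] -= h
        grad[idx] = (fn(Xp) - fn(Xm)) / (2.0 * h)
    return grad


R = np.array([[0.0, 0.0, 0.0],
              [2.63, 2.63, 0.1]])
S = R @ np.linalg.inv(CELL)

forces_real = -numerical_grad(lj_energy_real, R)     # true Cartesian forces
forces_naive = -numerical_grad(lj_energy_scaled, S)  # -dE/dS: NOT a Cartesian force
# Chain rule: dE/dS = dE/dR @ CELL.T, hence dE/dR = dE/dS @ inv(CELL).T
forces_fixed = forces_naive @ np.linalg.inv(CELL).T

print(np.allclose(forces_fixed, forces_real, atol=1e-6))  # True
print(np.allclose(forces_naive, forces_real, atol=1e-6))  # False
```

Because the energy itself is unchanged by the change of variables, an error of this kind would leave energies identical while the forces differ.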

In [1]:
import numpy as np
import jax
import jax.numpy as jnp

# enable double precision before any JAX arrays are created
jax.config.update("jax_enable_x64", True)

from ase import Atoms
from ase.build import bulk
from ase.calculators.lj import LennardJones
# from asax.utils import get_displacement
from jax_md import space, energy
from periodic_general import periodic_general, inverse, transform  # local module

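sigma = 2.0     # LJ length scale (Å)
epsilon = 1.5   # LJ well depth (eV)
rc = 11.0       # cutoff radius (Å)
ro = 6.0        # onset of the cutoff function (Å); ASE only uses it when smooth=True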

In [2]:
atoms = bulk('Ar') * [2, 2, 2]
atoms.calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, ro=ro)

In [3]:
# Why is this not a diagonal matrix?
# --> bulk('Ar') returns the primitive fcc cell, which is a parallelepiped, not a cube
real_lattices = atoms.get_cell().array
print(real_lattices)

[[0.   5.26 5.26]
 [5.26 0.   5.26]
 [5.26 5.26 0.  ]]
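The printed matrix is consistent with ASE's primitive fcc vectors a/2·(0, 1, 1), a/2·(1, 0, 1), a/2·(1, 1, 0) with a = 5.26 Å, doubled by the `* [2, 2, 2]` repetition. The small NumPy check below (the lattice constant is read off the output above, not queried from ASE) reconstructs the cell and shows why it is not diagonal: the lattice vectors meet at 60°. It also shows that this particular cell is symmetric, `cell == cell.T`, so a row/column (transpose) mix-up in a coordinate conversion would go unnoticed with it.

```python
import numpy as np

a = 5.26  # fcc lattice constant of Ar in Å, read off the printed cell

# Primitive fcc lattice vectors as rows: a/2 * (0,1,1), (1,0,1), (1,1,0)
primitive = 0.5 * a * (np.ones((3, 3)) - np.eye(3))

# bulk('Ar') * [2, 2, 2] doubles every lattice vector
supercell = 2.0 * primitive
print(supercell)

# Angle between the first two lattice vectors (all three pairs are equivalent here)
lengths = np.linalg.norm(supercell, axis=1)
cos_01 = supercell[0] @ supercell[1] / (lengths[0] * lengths[1])
angle_deg = np.degrees(np.arccos(cos_01))
print(angle_deg)

# Cell volume = |det(cell)|: 8 primitive cells of a**3 / 4 each, i.e. 2 * a**3
volume = abs(np.linalg.det(supercell))
print(volume, 2 * a**3)

# The cell is symmetric, so cell and cell.T cannot be told apart with this system
print(np.allclose(supercell, supercell.T))
```

In ASE, `atoms.cell.cellpar()` reports the same lengths and angles directly.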


In [4]:
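# Shouldn't these be all zeros, since the perfect lattice is in equilibrium?
# --> They are, up to floating-point round-off (every entry is below 1e-15).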
atoms.get_forces()

array([[-1.75098651e-17,  1.74556550e-17, -3.08997619e-17],
       [ 3.85298347e-17,  3.25802753e-17, -5.49853132e-16],
       [-2.83112292e-17,  1.61383493e-16, -6.17182087e-16],
       [ 2.77555756e-17, -2.60466019e-16, -4.39318720e-16],
       [ 5.80861314e-17, -3.33229538e-16, -3.33175328e-16],
       [-3.89635156e-17,  7.37528528e-17, -4.05112142e-16],
       [ 4.28490251e-16, -3.98444298e-17, -6.67841433e-16],
       [ 3.10949183e-16, -4.27175656e-17, -7.13405029e-17]])
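The entries above are at round-off level (largest magnitude about 6.7e-16 eV/Å), so they are zero for practical purposes. When comparing ASE and JAX-MD outputs, exact equality is therefore the wrong test, and a tolerance-based comparison such as `np.allclose` is the meaningful one. The tolerance `atol=1e-12` below is an arbitrary choice for illustration.

```python
import numpy as np

# Forces printed by atoms.get_forces() above, in eV/Å
forces_ase = np.array([
    [-1.75098651e-17,  1.74556550e-17, -3.08997619e-17],
    [ 3.85298347e-17,  3.25802753e-17, -5.49853132e-16],
    [-2.83112292e-17,  1.61383493e-16, -6.17182087e-16],
    [ 2.77555756e-17, -2.60466019e-16, -4.39318720e-16],
    [ 5.80861314e-17, -3.33229538e-16, -3.33175328e-16],
    [-3.89635156e-17,  7.37528528e-17, -4.05112142e-16],
    [ 4.28490251e-16, -3.98444298e-17, -6.67841433e-16],
    [ 3.10949183e-16, -4.27175656e-17, -7.13405029e-17],
])

# Largest deviation from zero: round-off level, not a physical force
max_abs = np.abs(forces_ase).max()
print(max_abs)

# Exact equality fails, a tolerance-based comparison succeeds
exactly_zero = np.all(forces_ase == 0.0)
within_tol = np.allclose(forces_ase, 0.0, atol=1e-12)
print(exactly_zero, within_tol)
```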

In [5]:
real_atom_positions = atoms.get_positions()
print(real_atom_positions)

[[0.   0.   0.  ]
 [2.63 2.63 0.  ]
 [2.63 0.   2.63]
 [5.26 2.63 2.63]
 [0.   2.63 2.63]
 [2.63 5.26 2.63]
 [2.63 2.63 5.26]
 [5.26 5.26 5.26]]


### Atomic distances in ASE
- Real-valued input & output

In [6]:
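# displacement vector between atoms 0 and 1
atoms.get_distance(1, 0, vector=True)
# ASE returns positions[a1] - positions[a0], whereas a JAX-MD displacement_fn(Ra, Rb)
# returns Ra - Rb, so the index order is swapped (1, 0) to get matching signs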

array([-2.63, -2.63,  0.  ])
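The index swap in the call above follows from two sign conventions: `Atoms.get_distance(a0, a1, vector=True)` returns `positions[a1] - positions[a0]` (with `mic=False`, the default), while JAX-MD displacement functions return `Ra - Rb`. A NumPy-only sketch of both conventions, with hypothetical helper names, using the first two positions printed in In [5]:

```python
import numpy as np

# First two atomic positions from In [5]
positions = np.array([[0.0, 0.0, 0.0],
                      [2.63, 2.63, 0.0]])


def ase_style_vector(pos, a0, a1):
    """Hypothetical mirror of ASE's get_distance(a0, a1, vector=True) without MIC."""
    return pos[a1] - pos[a0]


def jaxmd_style_displacement(Ra, Rb):
    """Hypothetical mirror of a free-space JAX-MD displacement function: Ra - Rb."""
    return Ra - Rb


print(ase_style_vector(positions, 1, 0))
print(jaxmd_style_displacement(positions[0], positions[1]))
```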

### Atomic distances in JAX-MD
- Real-valued input taken from `ase.Atoms`
- Transformation to scaled coordinates inside `displacement_fn`
- Real-valued output
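Conceptually, a displacement function for a general (triclinic) box takes the difference in scaled coordinates, wraps it into [-0.5, 0.5) and maps it back to real space. The sketch below shows that idea in NumPy with the ASE row convention (`real = scaled @ cell`); it is not the `periodic_general` implementation, and the helper names are hypothetical. For strongly skewed cells this simple wrap does not always return the shortest periodic image.

```python
import numpy as np

# ASE-style cell from In [3]: rows are lattice vectors, so real = scaled @ cell
cell = np.array([[0.0, 5.26, 5.26],
                 [5.26, 0.0, 5.26],
                 [5.26, 5.26, 0.0]])
inv_cell = np.linalg.inv(cell)


def to_scaled(R):
    """Real (Cartesian) row vector(s) -> scaled (fractional) row vector(s)."""
    return R @ inv_cell


def displacement_real(Ra, Rb):
    """Periodic displacement Ra - Rb for a general cell, returned in real coordinates."""
    d_scaled = to_scaled(Ra) - to_scaled(Rb)
    # wrap each scaled component into [-0.5, 0.5)
    d_scaled = np.mod(d_scaled + 0.5, 1.0) - 0.5
    # back to real coordinates
    return d_scaled @ cell


Ra = np.array([0.0, 0.0, 0.0])
Rb = np.array([2.63, 2.63, 0.0])
print(to_scaled(Rb))
print(displacement_real(Ra, Rb))
```

Atom 1 sits at exactly half of the third lattice vector (scaled coordinates close to (0, 0, 0.5)), so the wrapped component lands on the ±0.5 boundary. The length of the result is unambiguous, but its sign can depend on floating-point round-off, which is worth keeping in mind when this pair serves as the reference case.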

In [7]:
# define two real-valued atomic position vectors (atoms 0 and 1 from In [5])
Ra = np.array([0.0, 0.0, 0.0])
Rb = np.array([2.63, 2.63, 0.0])

What `asax` does: map real coordinates to scaled coordinates, then call the scaled-coordinate displacement function returned by `periodic_general`.
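One detail of that mapping worth checking is the matrix convention. ASE stores lattice vectors as rows (`positions = scaled @ cell`, which is what `get_scaled_positions` inverts), whereas `jax_md.space.transform(T, R)` applies `T` to each position as a column vector (`T @ r`), so a JAX-MD-style box corresponds to `cell.T`. Assuming the local `transform` follows the jax_md convention, the NumPy sketch below (with a made-up sheared cell and a hypothetical helper that mimics `transform` via `einsum`) shows that passing the ASE cell unchanged only works when the cell is symmetric, as the Ar cell above happens to be.

```python
import numpy as np


def transform_like_jaxmd(T, R):
    """Hypothetical stand-in assumed to match jax_md.space.transform for a 2D box:
    out[..., i] = sum_j T[i, j] * R[..., j], i.e. T @ r for every position r."""
    return np.einsum('ij,...j->...i', T, R)


# Made-up sheared cell in ASE convention (rows are lattice vectors); not symmetric
cell = np.array([[4.0, 0.0, 0.0],
                 [1.0, 4.0, 0.0],
                 [0.5, 0.5, 4.0]])
scaled = np.array([[0.25, 0.5, 0.75]])

real_ase = scaled @ cell                                   # ASE: real = scaled @ cell
real_box_is_cell = transform_like_jaxmd(cell, scaled)      # ASE cell passed unchanged
real_box_is_cell_T = transform_like_jaxmd(cell.T, scaled)  # transposed cell

print(np.allclose(real_ase, real_box_is_cell))    # False for this cell
print(np.allclose(real_ase, real_box_is_cell_T))  # True
```

The same reasoning applies to the inverse: `scaled = real @ inv(cell)` in ASE's convention corresponds to applying `inv(cell).T` as a JAX-MD-style transform. For the symmetric Ar cell both choices coincide, so In [9] cannot distinguish them; repeating the comparison with a non-symmetric cell could.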

In [16]:
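cell = atoms.get_cell().array  # ASE's real-space parallelepiped; rows are lattice vectors

# note: `* np.eye(3)` is an element-wise product that keeps only the diagonal,
# which is all zeros for this cell; `box` is not used below
box = atoms.get_cell().array * np.eye(3)
print(cell)

# inverse of the cell matrix (named inv_cell to avoid shadowing the imported `inverse`)
inv_cell = space._small_inverse(cell)


scaled_coordinates_displacement_fn, _ = periodic_general(cell)

def displacement_fn(Ra_real: space.Array, Rb_real: space.Array, **unused_kwargs) -> space.Array:
    # real -> scaled coordinates, then the displacement in the general periodic box
    Ra_scaled = transform(inv_cell, Ra_real)
    Rb_scaled = transform(inv_cell, Rb_real)
    return scaled_coordinates_displacement_fn(Ra_scaled, Rb_scaled)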

[[0.   5.26 5.26]
 [5.26 0.   5.26]
 [5.26 5.26 0.  ]]


In [9]:
displacement_real = displacement_fn(Ra, Rb)
print(displacement_real)

[-2.63 -2.63  0.  ]
