In [1]:
import numpy as np
import jax.numpy as jnp
from jax import jit, grad
from jax import numpy as jnp
from ase import Atoms
from ase.build import bulk
from ase.calculators.lj import LennardJones
from jax_md import space, energy, quantity
from jax.config import config
config.update("jax_enable_x64", True)

import asax.utils

In [2]:
atoms = bulk('Ar', cubic=True) * [5, 5, 5]
# atoms.set_cell(1.05 * atoms.get_cell(), scale_atoms=True)

R_real = atoms.get_positions()
max_box_length = np.max([np.linalg.norm(uv) for uv in atoms.get_cell().array])

sigma = 2.0
epsilon = 1.5
rc = 0.5 * max_box_length
ro = rc * 0.9

print("Cell:\t\t", atoms.get_cell())
print("max_box_length:\t", max_box_length)
print("ro:\t\t", ro)
print("rc:\t\t", rc)
print()
print("R_real:\n", R_real )

Cell:		 Cell([26.299999999999997, 26.299999999999997, 26.299999999999997])
max_box_length:	 26.299999999999997
ro:		 11.834999999999999
rc:		 13.149999999999999

R_real:
 [[ 0.    0.    0.  ]
 [ 0.    2.63  2.63]
 [ 2.63  0.    2.63]
 ...
 [21.04 23.67 23.67]
 [23.67 21.04 23.67]
 [23.67 23.67 21.04]]


## Old `periodic_general()`

In [3]:
atoms.calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, ro=ro, smooth=True)
ase_epot = atoms.get_potential_energy()
ase_forces = atoms.get_forces()

displacement_fn = asax.utils.get_displacement(atoms)
energy_fn = energy.lennard_jones_pair(displacement_fn, sigma=sigma, epsilon=epsilon, r_onset=ro/sigma, r_cutoff=rc/sigma)
force_fn = quantity.force(energy_fn)

jaxmd_epot = energy_fn(R_real)
jaxmd_forces = force_fn(R_real)

# energies
np.testing.assert_allclose(ase_epot, jaxmd_epot, rtol=1e-7, atol=1e-15, equal_nan=True)
np.testing.assert_allclose(ase_forces, jaxmd_forces, rtol=1e-7, atol=1e-14, equal_nan=True)

# TODO: check force

print("ASE e_pot:\t", ase_epot)
print("JAX-MD e_pot:\t", jaxmd_epot)

ASE e_pot:	 -507.708181818377
JAX-MD e_pot:	 -507.70818181837694


## New `periodic_general()`

In [4]:
from periodic_general import periodic_general as new_periodic_general
from periodic_general import inverse as new_inverse
from periodic_general import transform as new_transform

In [9]:
# what asax.utils.get_displacement() does, only with functions from the new periodic_general()
def new_get_displacement(atoms):
    cell = atoms.get_cell().array
    inverse_cell = new_inverse(cell)

    displacement_in_scaled_coordinates, _ = new_periodic_general(cell)
    
    print("Tracing new_get_displacement")

    def displacement(Ra: space.Array, Rb: space.Array, **unused_kwargs) -> space.Array:
        print("Tracing displacement")

        Ra_scaled = new_transform(inverse_cell, Ra)
        Rb_scaled = new_transform(inverse_cell, Rb)
        return displacement_in_scaled_coordinates(Ra_scaled, Rb_scaled)

    return displacement

new_displacement_fn = new_get_displacement(atoms)
new_energy_fn = energy.lennard_jones_pair(new_displacement_fn, sigma=sigma, epsilon=epsilon, r_onset=ro/sigma, r_cutoff=rc/sigma)
new_force_fn = quantity.force(energy_fn)

# TODO: Investigate correct JIT behavior
new_energy_fn(R_real)

Tracing new_get_displacement
Tracing displacement
Tracing displacement


DeviceArray(-507.70818182, dtype=float64)

### Energies & Forces

In [10]:
new_jaxmd_epot = new_energy_fn(R_real)
new_jaxmd_forces = new_force_fn(R_real)

# energies
np.testing.assert_allclose(ase_epot, new_jaxmd_epot, rtol=1e-7, atol=1e-15, equal_nan=True)
np.testing.assert_allclose(jaxmd_epot, new_jaxmd_epot, rtol=1e-7, atol=1e-15, equal_nan=True)

# forces
np.testing.assert_allclose(ase_forces, new_jaxmd_forces, rtol=1e-7, atol=1e-14, equal_nan=True)
np.testing.assert_allclose(jaxmd_forces, new_jaxmd_forces, rtol=1e-7, atol=1e-14, equal_nan=True)

print("ASE e_pot:\t\t", ase_epot)
print("JAX-MD e_pot:\t\t", jaxmd_epot)
print("New JAX-MD e_pot:\t", new_jaxmd_epot)

Tracing displacement
Tracing displacement
ASE e_pot:		 -507.708181818377
JAX-MD e_pot:		 -507.70818181837694
New JAX-MD e_pot:	 -507.70818181837706


## New `periodic_general()` with stress

In [11]:
box = atoms.get_cell().array
strained_box_energy_fn = lambda epsilon, R: new_energy_fn(R, box=box + epsilon)
strained_box_force_fn = grad(strained_box_energy_fn, argnums=1)

In [12]:
epsilon = jnp.zeros_like(box)  # commonly (3, 3)
new_jaxmd_epot_strained = strained_box_energy_fn(epsilon, R_real)
new_jaxmd_forces_strained = strained_box_force_fn(epsilon, R_real)

# energies
np.testing.assert_allclose(ase_epot, new_jaxmd_epot_strained, rtol=1e-7, atol=1e-15, equal_nan=True)
np.testing.assert_allclose(jaxmd_epot, new_jaxmd_epot_strained, rtol=1e-7, atol=1e-15, equal_nan=True)
np.testing.assert_allclose(new_jaxmd_epot, new_jaxmd_epot_strained, rtol=1e-7, atol=1e-15, equal_nan=True)

# forces
np.testing.assert_allclose(ase_forces, new_jaxmd_forces_strained, rtol=1e-7, atol=1e-14, equal_nan=True)
np.testing.assert_allclose(jaxmd_forces, new_jaxmd_forces_strained, rtol=1e-7, atol=1e-14, equal_nan=True)
np.testing.assert_allclose(new_jaxmd_forces, new_jaxmd_forces_strained, rtol=1e-7, atol=1e-14, equal_nan=True)

Tracing displacement
Tracing displacement


In [13]:
displacement_fn = new_get_displacement(atoms)
energy_fn = energy.lennard_jones_pair(displacement_fn, sigma=sigma, epsilon=epsilon, r_onset=ro/sigma, r_cutoff=rc/sigma)

energy_fn(R_real)


# deformation_energy = lambda epsilon, R: strained_box_energy_fn(R, box=box + epsilon)
# deformation_energy(epsilon, R_real)


Tracing new_get_displacement
Tracing displacement
Tracing displacement


TypeError: mul got incompatible shapes for broadcasting: (3, 3), (500, 500).