In [1]:
import numpy as np
import jax.numpy as jnp
from jax import jit, grad
from jax import numpy as jnp
from ase import Atoms
from ase.build import bulk
from ase.calculators.lj import LennardJones
from jax_md import space, energy, quantity
from jax.config import config
config.update("jax_enable_x64", True)

import asax.utils

In [2]:
atoms = bulk('Ar', cubic=True) * [5, 5, 5]
# atoms.set_cell(1.05 * atoms.get_cell(), scale_atoms=True)

R_real = atoms.get_positions()
max_box_length = np.max([np.linalg.norm(uv) for uv in atoms.get_cell().array])

sigma = 2.0
epsilon = 1.5
rc = 0.5 * max_box_length
ro = rc * 0.9

print("Cell:\t\t", atoms.get_cell())
print("max_box_length:\t", max_box_length)
print("ro:\t\t", ro)
print("rc:\t\t", rc)
print()
print("R_real:\n", R_real )

atoms.calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=rc, ro=ro, smooth=True)

Cell:		 Cell([26.299999999999997, 26.299999999999997, 26.299999999999997])
max_box_length:	 26.299999999999997
ro:		 11.834999999999999
rc:		 13.149999999999999

R_real:
 [[ 0.    0.    0.  ]
 [ 0.    2.63  2.63]
 [ 2.63  0.    2.63]
 ...
 [21.04 23.67 23.67]
 [23.67 21.04 23.67]
 [23.67 23.67 21.04]]


## New `periodic_general()` with stress

In [3]:
from periodic_general import periodic_general as new_periodic_general
from periodic_general import inverse as new_inverse
from periodic_general import transform as new_transform

# what asax.utils.get_displacement() does, only with functions from the new periodic_general()
def new_get_displacement(atoms):
    cell = atoms.get_cell().array
    inverse_cell = new_inverse(cell)

    displacement_in_scaled_coordinates, _ = new_periodic_general(cell)

    # **kwargs are now used to feed through the box information
    def displacement(Ra: space.Array, Rb: space.Array, **kwargs) -> space.Array:
        Ra_scaled = new_transform(inverse_cell, Ra)
        Rb_scaled = new_transform(inverse_cell, Rb)
        return displacement_in_scaled_coordinates(Ra_scaled, Rb_scaled, **kwargs)

    return displacement

In [4]:
displacement_fn = new_get_displacement(atoms)
energy_fn = energy.lennard_jones_pair(displacement_fn, sigma=sigma, epsilon=epsilon, r_onset=ro/sigma, r_cutoff=rc/sigma)
force_fn = quantity.force(energy_fn)

In [5]:
# dynamically changing the box size
box = atoms.get_cell().array

# correct: shrinking the box should increase the energy
print(energy_fn(R_real, box=box))
print(atoms.get_potential_energy())

print(energy_fn(R_real, box=box/4))

-507.70818181837706
-507.708181818377
176288289.11375165


In [41]:
# computing stress
deformation_energy_fn = lambda deformation, R: energy_fn(R, box=box+deformation)
deformation_force_fn = lambda deformation, R: grad(deformation_energy_fn, argnums=1)(deformation, R) * -1
stress_fn = lambda deformation, R: grad(deformation_energy_fn, argnums=0)(deformation, R) / jnp.linalg.det(box)

deformation = jnp.zeros_like(box)

# correct: should be very close to the untouched box
print(deformation_energy_fn(deformation, R_real))

# correct: 3x3 tensor, non-zero
stress = stress_fn(deformation, R_real)
print(stress)

-507.70818181837694
[[ 2.08734099e-03 -3.39717472e-21 -3.38783604e-21]
 [-1.77606678e-21  2.08734099e-03  7.87671640e-22]
 [-3.79236803e-21  7.15779675e-22  2.08734099e-03]]


In [56]:
# comparing to ASE output
def all_close(arr1, arr2, rtol=1e-7, atol=1e-15):
    np.testing.assert_allclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True)
    
# energy - with and without deformation
all_close(atoms.get_potential_energy(), energy_fn(R_real))
all_close(atoms.get_potential_energy(), deformation_energy_fn(deformation, R_real))

# forces - with and without deformation
all_close(atoms.get_forces(), force_fn(R_real), atol=1e-14)
all_close(atoms.get_forces(), deformation_force_fn(deformation, R_real), atol=1e-14)

# stress. atol=1e-1 !!
all_close(atoms.get_stress(voigt=False), stress_fn(deformation, R_real), atol=1e-1)