In [1]:
import numpy as np
import jax.numpy as jnp
from jax import jit, grad
from jax import numpy as jnp
from ase import Atoms
from ase.build import bulk
from ase.calculators.lj import LennardJones
from jax_md import space, energy, quantity
from jax.config import config
config.update("jax_enable_x64", True)

import asax.utils

In [2]:
# Lennard-Jones parameters used throughout the notebook.
sigma = 2.0
epsilon = 1.5

# Build a 5x5x5 supercell of cubic argon and expand it isotropically by 5 %,
# rescaling the atomic positions together with the cell.
atoms = bulk('Ar', cubic=True).repeat([5, 5, 5])
atoms.set_cell(1.05 * atoms.get_cell(), scale_atoms=True)

# Cartesian positions of all atoms after the expansion.
R_real = atoms.get_positions()

# Length of the longest cell vector; the cutoff radii below are derived from it.
# NOTE(review): for a non-cubic cell the shortest extent would presumably be the
# safe bound for a minimum-image cutoff -- confirm before reusing this elsewhere.
cell_vector_lengths = [np.linalg.norm(vec) for vec in atoms.get_cell().array]
max_box_length = np.max(cell_vector_lengths)

# Cutoff at half the box length, with the onset of the smoothing at 90 % of it.
rc = 0.5 * max_box_length
ro = 0.9 * rc

print("Cell:\t\t", atoms.get_cell())
print("max_box_length:\t", max_box_length)
print("ro:\t\t", ro)
print("rc:\t\t", rc)
print()
print("R_real:\n", R_real )

Cell:		 Cell([27.615, 27.615, 27.615])
max_box_length:	 27.615
ro:		 12.42675
rc:		 13.8075

R_real:
 [[ 0.      0.      0.    ]
 [ 0.      2.7615  2.7615]
 [ 2.7615  0.      2.7615]
 ...
 [22.092  24.8535 24.8535]
 [24.8535 22.092  24.8535]
 [24.8535 24.8535 22.092 ]]


## New `periodic_general()` with stress

In [28]:
from periodic_general import periodic_general as new_periodic_general
from periodic_general import inverse as new_inverse
from periodic_general import transform as new_transform

# what asax.utils.get_displacement() does, only with functions from the new periodic_general()
def new_get_displacement(atoms):
    """Build a displacement function that accepts real-space coordinates.

    The cell of ``atoms`` is read once. The returned function maps both
    position arguments into scaled coordinates with the inverse of that cell
    and delegates to the displacement function obtained from
    ``new_periodic_general``.

    Args:
        atoms: object providing ``get_cell().array`` (here an ASE ``Atoms``).

    Returns:
        A callable ``displacement(Ra, Rb, **kwargs)``; any keyword arguments
        (e.g. ``box``) are forwarded unchanged to the scaled-coordinate
        displacement function.
    """
    box_matrix = atoms.get_cell().array
    # Inverse is computed once, from the cell at construction time; a ``box``
    # passed later through **kwargs does not affect this real -> scaled mapping.
    to_scaled = new_inverse(box_matrix)

    # NOTE(review): presumably new_periodic_general expects the same row/column
    # layout as the ASE cell array -- confirm for non-orthorhombic cells.
    scaled_displacement = new_periodic_general(box_matrix)[0]

    # **kwargs are now used to feed through the box information
    def displacement(Ra: space.Array, Rb: space.Array, **kwargs) -> space.Array:
        return scaled_displacement(
            new_transform(to_scaled, Ra),
            new_transform(to_scaled, Rb),
            **kwargs,
        )

    return displacement

In [32]:
# Displacement function operating on real-space positions of `atoms`.
displacement_fn = new_get_displacement(atoms)

# Lennard-Jones pair energy; onset and cutoff radii are passed in units of sigma.
energy_fn = energy.lennard_jones_pair(
    displacement_fn,
    sigma=sigma,
    epsilon=epsilon,
    r_onset=ro / sigma,
    r_cutoff=rc / sigma,
)

mapping real to scaled coordinates


In [33]:
# Evaluate the energy for the original box and for a box shrunk to a quarter
# of its edge length, passing the box dynamically via the `box` keyword.
box = atoms.get_cell().array
quarter_box = box / 4

# correct: shrinking the box should increase the energy
print(energy_fn(R_real, box=box))
print(energy_fn(R_real, box=quarter_box))

mapping real to scaled coordinates
-380.8736576477754
mapping real to scaled coordinates
97759717.42890728


In [31]:
# computing stress
def deformation_energy_fn(deformation, R):
    """Energy of positions ``R`` evaluated in the box ``box + deformation``.

    Args:
        deformation: array with the shape of ``box``, added element-wise to it.
        R: real-space positions passed on to ``energy_fn``.

    Returns:
        The value of ``energy_fn`` for the perturbed box.
    """
    return energy_fn(R, box=box + deformation)


deformation = jnp.zeros_like(box)

# correct: should be very close to the untouched box
print(deformation_energy_fn(deformation, R_real))

# The cell volume is |det(box)|: the determinant itself is signed, so a
# left-handed cell would otherwise flip the sign of the whole tensor.
volume = jnp.abs(jnp.linalg.det(box))

# NOTE(review): the derivative is taken with respect to an additive box
# perturbation (box + deformation), not a strain applied as (I + eps) @ box.
# For a conventional strain-based stress the result would presumably need an
# extra factor of the box matrix -- confirm the intended definition.

# correct: 3x3 tensor, non-zero
stress = grad(deformation_energy_fn, argnums=0)(deformation, R_real) / volume
print(stress)

mapping real to scaled coordinates
-380.87365764777536
mapping real to scaled coordinates
[[ 1.29529915e-03  1.03998362e-21  9.83046596e-23]
 [ 4.63359039e-22  1.29529915e-03 -1.36336787e-21]
 [ 7.85298791e-22 -7.79985964e-22  1.29529915e-03]]
