## Computing stress in JAX-MD
The goal of this notebook is to demonstrate how to compute stress effectively in JAX-MD. We use `ASE` to initialize the system and obtain accurate atomic positions, and show how to use them in JAX-MD. We then compare the computed properties with `ASE` calculations, evaluating the energy function both with and without the strain transformation. Finally, we investigate JIT behavior with the provided new implementation of `periodic_general()`.

In [114]:
import jax.numpy as jnp
from jax import jit, grad, jacfwd, random, ops
from jax_md import space, energy, quantity, simulate
from jax.config import config
config.update("jax_enable_x64", True)

from ase import Atoms
from ase.build import bulk
from ase.calculators.lj import LennardJones
import numpy as np

key = random.PRNGKey(0)

## Initialization
- Use `ASE` to initialize a cubic Argon super cell
- Obtain atom coordinates in real space
- Set the Lennard-Jones cutoff to half the (cubic) box length, so that `ASE`'s neighbor-list implementation counts each pair at most once and therefore behaves like a minimum-image pair potential
- Setting `smooth=True` makes the energy go smoothly to 0 between `r_onset` and `r_cutoff`, replicating `JAX-MD`'s behavior.
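The switching function behind `smooth=True` can be sketched in a few lines. This is a minimal NumPy re-implementation, assuming the polynomial that JAX-MD's multiplicative isotropic cutoff evaluates on *squared* distances; it only illustrates the shape of the function and is not either library's code.

```python
import numpy as np

def smooth_cutoff(r, r_onset, r_cutoff):
    """Multiplicative switch: 1 below r_onset, 0 above r_cutoff.

    Sketch of a JAX-MD-style cutoff; the polynomial is written in terms of
    squared distances and has zero slope at both r_onset and r_cutoff.
    """
    r = np.asarray(r, dtype=float)
    r2, ro2, rc2 = r ** 2, r_onset ** 2, r_cutoff ** 2
    inner = (rc2 - r2) ** 2 * (rc2 + 2.0 * r2 - 3.0 * ro2) / (rc2 - ro2) ** 3
    return np.where(r < r_onset, 1.0, np.where(r < r_cutoff, inner, 0.0))

# Onset and cutoff of the Argon super cell below (rounded)
r_onset, r_cutoff = 11.835, 13.15
r = np.array([3.0, r_onset, 12.5, r_cutoff, 14.0])
print(smooth_cutoff(r, r_onset, r_cutoff))
```

Multiplying the Lennard-Jones energy by this factor leaves it untouched below `r_onset` and brings both energy and force continuously to zero at `r_cutoff`.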

In [101]:
atoms = bulk('Ar', cubic=True) * [5, 5, 5]

R_real = atoms.get_positions()
max_box_length = np.max([np.linalg.norm(uv) for uv in atoms.get_cell().array])

sigma = 2.0
epsilon = 1.5
r_cutoff = 0.5 * max_box_length
r_onset = r_cutoff * 0.9

print("Cell:\t\t", atoms.get_cell())
print("max_box_length:\t", max_box_length)
print("ro:\t\t", r_onset)
print("rc:\t\t", r_cutoff)
print()
print("R_real:\n", R_real )

atoms.calc = LennardJones(sigma=sigma, epsilon=epsilon, rc=r_cutoff, ro=r_onset, smooth=True)

Cell:		 Cell([26.299999999999997, 26.299999999999997, 26.299999999999997])
max_box_length:	 26.299999999999997
ro:		 11.834999999999999
rc:		 13.149999999999999

R_real:
 [[ 0.    0.    0.  ]
 [ 0.    2.63  2.63]
 [ 2.63  0.    2.63]
 ...
 [21.04 23.67 23.67]
 [23.67 21.04 23.67]
 [23.67 23.67 21.04]]


## New `periodic_general()` with stress
- `new_periodic_general()`, `new_inverse()` and `new_transform()` refer to a currently unreleased implementation of [general periodic boundary conditions](https://gist.github.com/sschoenholz/14944c4b9dd263c95c524f84cc1c4287).
- As `ASE` uses real-space coordinates, a custom displacement function converts them to the scaled (fractional) coordinates expected by `JAX-MD`.
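The conversion between real-space and scaled coordinates, combined with the minimum image convention, can be sketched in plain NumPy. This assumes the ASE convention that the rows of the cell are the lattice vectors (so real positions are `S @ cell` for fractional positions `S`); `make_displacement` is a hypothetical helper for illustration, not the API or convention of the unreleased `periodic_general`.

```python
import numpy as np

def make_displacement(cell):
    """Minimum-image displacement between real-space positions.

    Assumes rows of `cell` are lattice vectors: R = S @ cell.
    """
    cell = np.asarray(cell, dtype=float)
    inv_cell = np.linalg.inv(cell)

    def displacement(Ra, Rb):
        dS = (np.asarray(Ra) - np.asarray(Rb)) @ inv_cell  # real -> fractional
        dS = dS - np.round(dS)                             # nearest periodic image
        return dS @ cell                                   # fractional -> real
    return displacement

displacement = make_displacement(np.diag([10.0, 10.0, 10.0]))
# The image of x = 9 at x = -1 is closer to x = 1 than x = 9 itself
print(displacement(np.array([1.0, 0.0, 0.0]), np.array([9.0, 0.0, 0.0])))
```

Rounding the fractional displacement yields the true minimum image for orthorhombic and mildly skewed cells; strongly skewed cells can require a search over neighbouring images.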

In [1]:
from periodic_general import periodic_general as new_periodic_general
from periodic_general import inverse as new_inverse
from periodic_general import transform as new_transform

# what asax.utils.get_displacement() does, only with functions from the new periodic_general()
def new_get_displacement(atoms):
    cell = atoms.get_cell().array
    inverse_cell = new_inverse(cell)

    displacement_in_scaled_coordinates, shift_in_scaled_coordinates = new_periodic_general(cell)

    # **kwargs are now used to feed through the box information
    @jit
    def displacement(Ra: space.Array, Rb: space.Array, **kwargs) -> space.Array:
        Ra_scaled = new_transform(inverse_cell, Ra)
        Rb_scaled = new_transform(inverse_cell, Rb)
        return displacement_in_scaled_coordinates(Ra_scaled, Rb_scaled, **kwargs)
    
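    @jit
    def shift(R: space.Array, dR: space.Array, **kwargs) -> space.Array:
        R_scaled = new_transform(inverse_cell, R)
        # dR is an output of displacement and should already be in real coordinates.
        # Assumption: shift_in_scaled_coordinates returns scaled positions, so the
        # result is mapped back to real coordinates to stay consistent with displacement().
        return new_transform(cell, shift_in_scaled_coordinates(R_scaled, dR, **kwargs))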

    return displacement, shift

`JAX-MD`'s `energy.lennard_jones_pair()` expects `r_onset` and `r_cutoff` in units of `sigma` (it multiplies both by `sigma` internally), whereas `ASE` takes absolute distances. We therefore divide both by `sigma` to obtain `ASE`-equivalent results.
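With the values printed above, the conversion works out as follows (the absolute distances are rounded from the printed `ro` and `rc`):

```python
sigma = 2.0
# Absolute distances, as passed to ASE
r_onset_abs, r_cutoff_abs = 11.835, 13.15
# JAX-MD's lennard_jones_pair takes r_onset / r_cutoff in units of sigma ...
r_onset_jaxmd = r_onset_abs / sigma
r_cutoff_jaxmd = r_cutoff_abs / sigma
# ... and multiplies them by sigma internally, recovering the absolute cutoff
effective_cutoff = r_cutoff_jaxmd * sigma
print(r_onset_jaxmd, r_cutoff_jaxmd, effective_cutoff)
```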

In [103]:
displacement_fn, shift_fn = new_get_displacement(atoms)
energy_fn = jit(energy.lennard_jones_pair(displacement_fn, sigma=sigma, epsilon=epsilon, r_onset=r_onset/sigma, r_cutoff=r_cutoff/sigma, per_particle=True))
total_energy_fn = jit(lambda R, **kwargs: jnp.sum(energy_fn(R, **kwargs)))
force_fn = jit(quantity.force(total_energy_fn))

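`**kwargs` passed to `energy_fn()` are now forwarded to `displacement_fn`, which allows us to override the box manually when computing energies. Because the scaled coordinates stay fixed, shrinking the box by a factor of 4 pushes neighbouring atoms deep into each other's repulsive core, and the energy explodes accordingly.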
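The mechanism can be sketched without JAX-MD. The sketch assumes, as `new_get_displacement` suggests, that real positions are always converted with the fixed reference cell while the scaled displacement is mapped back with whatever `box` is passed in; the function below is a hypothetical NumPy stand-in, not the actual `periodic_general` code.

```python
import numpy as np

reference_cell = np.diag([26.3, 26.3, 26.3])   # the Argon super cell from above
inverse_reference = np.linalg.inv(reference_cell)

def displacement(Ra, Rb, box=reference_cell):
    # Real -> fractional always uses the fixed reference cell ...
    dS = (Ra - Rb) @ inverse_reference
    dS = dS - np.round(dS)
    # ... while fractional -> real uses whichever box is passed in.
    return dS @ box

Ra = np.array([0.0, 2.63, 2.63])   # a nearest neighbour of the atom at the origin
Rb = np.zeros(3)
r_ref = np.linalg.norm(displacement(Ra, Rb))
r_small = np.linalg.norm(displacement(Ra, Rb, box=reference_cell / 4))
print(r_ref, r_small)
```

The nearest-neighbour distance drops from about 3.72 to about 0.93, well inside `sigma = 2`, where the $r^{-12}$ term dominates.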

In [104]:
box = atoms.get_cell().array
print(total_energy_fn(R_real, box=box))
print(total_energy_fn(R_real, box=box/4))

-507.70818181837694
176288289.1137517


### Computing stress

#### Theory
In this context, the stress tensor $\sigma$ is defined as the derivative of the total energy $E$ with respect to a 3x3 infinitesimal strain tensor $\epsilon$, divided by the volume:

$\sigma_{ab} = \frac{1}{V} \left. \frac{\partial E}{\partial \epsilon_{ab}} \right|_{\epsilon = 0}$, where $V$ is the system's volume.

$\epsilon$ describes a strain transformation of all real-space coordinates $R$:

$R \mapsto (\mathbb{1} + \epsilon) \, R$

This transformation is applied to all coordinates, i.e. both to the atomic positions and to the cell's basis vectors.
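For a pair potential the derivative can be worked out by hand, which gives a useful cross-check. Write the pair energy as $\phi$ (to avoid a clash with the Lennard-Jones $\sigma$), $R_{ij} = R_i - R_j$ for the minimum-image pair vector and $r_{ij} = |R_{ij}|$ for its length. Since every pair vector transforms like a coordinate:

$$E = \sum_{i<j} \phi(r_{ij}), \qquad R_{ij} \mapsto (\mathbb{1} + \epsilon) R_{ij}$$

$$\left. \frac{\partial r_{ij}}{\partial \epsilon_{ab}} \right|_{\epsilon = 0} = \frac{R_{ij}^a R_{ij}^b}{r_{ij}}$$

$$\sigma_{ab} = \frac{1}{V} \sum_{i<j} \phi'(r_{ij}) \frac{R_{ij}^a R_{ij}^b}{r_{ij}}$$

This is the pair-virial expression, and it is symmetric in $a$ and $b$. If each pair energy is split equally between its two atoms, the per-atom stress of atom $i$ is $\frac{1}{2V} \sum_{j \neq i} \phi'(r_{ij}) R_{ij}^a R_{ij}^b / r_{ij}$, and these per-atom contributions sum to $\sigma_{ab}$.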

#### Implementation
To obtain the gradient with respect to $\epsilon$, we introduce it into the energy function in several steps.
First, we define a function that applies the strain transformation to the box. We also symmetrize the deformation tensor so that the stress tensor we obtain later is *exactly* symmetric.

In [105]:
strain_box_fn = lambda deformation, box: new_transform(jnp.eye(3) + (deformation + deformation.T) * 0.5, box)
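Why the symmetrization yields an exactly symmetric result can be seen on a toy function. The sketch uses a hypothetical energy that is linear in the strain with a non-symmetric coefficient matrix `A`; feeding it the symmetrized deformation, as `strain_box_fn` does, makes `jax.grad` return $(A + A^T)/2$.

```python
import jax
import jax.numpy as jnp

def symmetrize(deformation):
    # Same symmetrization as in strain_box_fn
    return 0.5 * (deformation + deformation.T)

# Hypothetical toy energy, linear in the strain; its raw gradient is A itself
A = jnp.array([[1.0, 2.0, 0.0],
               [0.0, 3.0, 4.0],
               [5.0, 0.0, 6.0]])
toy_energy = lambda strain: jnp.sum(A * strain)

zero = jnp.zeros((3, 3))
g_raw = jax.grad(toy_energy)(zero)
g_sym = jax.grad(lambda d: toy_energy(symmetrize(d)))(zero)
print(g_raw)
print(g_sym)
```

For a physical, rotation-invariant energy the antisymmetric part of the gradient already vanishes at $\epsilon = 0$, since an antisymmetric $\epsilon$ is an infinitesimal rotation. The symmetrization removes the remaining round-off residue.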

We can now push the strained box through the energy function and derive force and stress functions from it.

In [106]:
deformation_energy_fn = lambda deformation, R: energy_fn(R, box=strain_box_fn(deformation, box))
total_deformation_energy_fn = lambda deformation, R: jnp.sum(deformation_energy_fn(deformation, R))
deformation_force_fn = lambda deformation, R: grad(total_deformation_energy_fn, argnums=1)(deformation, R) * -1

The stress is the gradient of the strained box's energy with respect to the applied strain $\epsilon$, divided by the system's volume, i.e. the determinant of the box.

The per-atom stress**es** follow the same definition, applied to the atom-wise energy contributions instead of the total energy. Since this yields one energy per atom rather than a scalar, we use `jacfwd` instead of `grad`.

In [107]:
volume = jnp.linalg.det(box)
stress_fn = lambda deformation, R: grad(total_deformation_energy_fn, argnums=0)(deformation, R) / volume
stresses_fn = lambda deformation, R: jacfwd(deformation_energy_fn, argnums=0)(deformation, R) / volume
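As a cross-check of the autodiff approach, the sketch below computes the stress of a tiny Lennard-Jones system in two ways: with `jax.grad` with respect to a strain applied to the box, and with the hand-derived pair sum $\frac{1}{V}\sum_{i<j}\phi'(r_{ij}) R_{ij}^a R_{ij}^b / r_{ij}$. It uses neither JAX-MD nor the unreleased `periodic_general`; the triclinic box, the fractional positions, the row-vector (ASE-style) cell convention and the absence of a cutoff (minimum image only) are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def lj(r, sigma=2.0, epsilon=1.5):
    # 4 epsilon [(sigma/r)^12 - (sigma/r)^6]
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)

def pair_vectors(S, box):
    # Minimum-image pair vectors (each pair once) from fractional coordinates S;
    # rows of `box` are lattice vectors, so real = fractional @ box.
    i, j = jnp.triu_indices(S.shape[0], k=1)
    dS = S[i] - S[j]
    dS = dS - jnp.round(dS)
    return dS @ box

def total_energy(S, box):
    return jnp.sum(lj(jnp.linalg.norm(pair_vectors(S, box), axis=-1)))

def strained_box(deformation, box):
    # Apply (1 + sym(deformation)) to every lattice vector (row of box)
    strain = jnp.eye(3) + 0.5 * (deformation + deformation.T)
    return box @ strain.T

def stress_autodiff(S, box):
    volume = jnp.abs(jnp.linalg.det(box))
    energy_of_strain = lambda d: total_energy(S, strained_box(d, box))
    return jax.grad(energy_of_strain)(jnp.zeros((3, 3))) / volume

def stress_virial(S, box):
    # (1/V) sum over pairs of phi'(r) R^a R^b / r
    d = pair_vectors(S, box)
    r = jnp.linalg.norm(d, axis=-1)
    dphi = jax.vmap(jax.grad(lj))(r)
    volume = jnp.abs(jnp.linalg.det(box))
    return jnp.einsum('p,pa,pb->ab', dphi / r, d, d) / volume

box = jnp.array([[6.0, 0.0, 0.0], [0.5, 6.2, 0.0], [0.3, 0.4, 5.8]])
S = jnp.array([[0.0, 0.0, 0.0], [0.02, 0.49, 0.48],
               [0.51, 0.03, 0.47], [0.46, 0.52, 0.01]])
print(stress_autodiff(S, box))
print(stress_virial(S, box))
```

Keeping the fractional coordinates `S` fixed while straining the box is what moves every real position by $(\mathbb{1} + \epsilon)$. This mirrors how `new_get_displacement` converts positions with the unstrained cell and then applies the `box` passed through `**kwargs`.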

We define $\epsilon$ as a zero-valued 3x3 matrix and calculate both the stress...

In [108]:
deformation = jnp.zeros_like(box)
stress_fn(deformation, R_real)

DeviceArray([[ 5.48970680e-02, -6.72443916e-20, -1.43739739e-20],
             [-6.72443916e-20,  5.48970680e-02,  9.75034346e-21],
             [-1.43739739e-20,  9.75034346e-21,  5.48970680e-02]],            dtype=float64)

...and stress**es** tensor.

In [109]:
stresses_fn(deformation, R_real)

DeviceArray([[[ 1.09794136e-04,  0.00000000e+00,  1.52574865e-21],
              [ 0.00000000e+00,  1.09794136e-04, -3.05149729e-21],
              [ 1.52574865e-21, -3.05149729e-21,  1.09794136e-04]],

             [[ 1.09794136e-04, -1.52574865e-21,  7.62874323e-22],
              [-1.52574865e-21,  1.09794136e-04, -4.05276984e-21],
              [ 7.62874323e-22, -4.05276984e-21,  1.09794136e-04]],

             [[ 1.09794136e-04, -3.81437161e-22,  3.81437161e-22],
              [-3.81437161e-22,  1.09794136e-04,  0.00000000e+00],
              [ 3.81437161e-22,  0.00000000e+00,  1.09794136e-04]],

             ...,

             [[ 1.09794136e-04,  1.52574865e-21,  0.00000000e+00],
              [ 1.52574865e-21,  1.09794136e-04, -1.52574865e-21],
              [ 0.00000000e+00, -1.52574865e-21,  1.09794136e-04]],

             [[ 1.09794136e-04, -1.90718581e-22,  3.05149729e-21],
              [-1.90718581e-22,  1.09794136e-04,  7.62874323e-22],
              [ 3.05149729e-21,  7.
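Since the per-atom energies sum to the total energy, the per-atom stresses must sum to the total stress. Every atom shown above contributes the same diagonal entry, as expected for a perfect crystal. For the 500-atom super cell (4 atoms per conventional fcc cell, 5x5x5 cells) the check therefore reduces to arithmetic on the printed values:

```python
n_atoms = 4 * 5 ** 3                 # conventional fcc cell (4 atoms), 5x5x5 super cell
per_atom_diagonal = 1.09794136e-04   # diagonal entry of each per-atom stress above
total_diagonal = 5.48970680e-02      # diagonal entry of the total stress above
print(n_atoms, n_atoms * per_atom_diagonal, total_diagonal)
```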

### Verification
For verification purposes, let's compare our results with `ASE`.

In [110]:
def all_close(arr1, arr2, rtol=1e-7, atol=1e-15):
    np.testing.assert_allclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True)
    
# energy - with and without deformation
all_close(atoms.get_potential_energy(), total_energy_fn(R_real))
all_close(atoms.get_potential_energy(), total_deformation_energy_fn(deformation, R_real))

# forces - with and without deformation
all_close(atoms.get_forces(), force_fn(R_real), atol=1e-14)
all_close(atoms.get_forces(), deformation_force_fn(deformation, R_real), atol=1e-14)

# stress
# TODO: Test with double64 precision
all_close(atoms.get_stress(voigt=False), stress_fn(deformation, R_real), atol=1e-17)

# stresses
all_close(atoms.get_stresses(voigt=False), stresses_fn(deformation, R_real), atol=1e-17)
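Independently of `ASE`, the strain derivative can be checked with central finite differences of the energy under a small symmetric strain. The NumPy sketch below does this for a hypothetical four-atom Lennard-Jones system (minimum image, no cutoff, rows of the box are lattice vectors) and compares the result with the analytic pair sum. It is only a numerical sanity check, not a replacement for the comparison above.

```python
import numpy as np

def lj(r, sigma=2.0, epsilon=1.5):
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)

def lj_prime(r, sigma=2.0, epsilon=1.5):
    # d/dr of lj(r)
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (-12.0 * sr6 ** 2 + 6.0 * sr6) / r

def pair_vectors(S, box):
    # Minimum-image pair vectors, each pair once
    i, j = np.triu_indices(len(S), k=1)
    dS = S[i] - S[j]
    dS = dS - np.round(dS)
    return dS @ box

def energy(S, box):
    return lj(np.linalg.norm(pair_vectors(S, box), axis=1)).sum()

def stress_fd(S, box, h=1e-6):
    volume = abs(np.linalg.det(box))
    stress = np.zeros((3, 3))
    for a in range(3):
        for b in range(3):
            d = np.zeros((3, 3))
            d[a, b] += 0.5 * h
            d[b, a] += 0.5 * h  # symmetric perturbation; d[a, a] = h on the diagonal
            e_plus = energy(S, box @ (np.eye(3) + d).T)
            e_minus = energy(S, box @ (np.eye(3) - d).T)
            stress[a, b] = (e_plus - e_minus) / (2.0 * h * volume)
    return stress

def stress_virial(S, box):
    d = pair_vectors(S, box)
    r = np.linalg.norm(d, axis=1)
    volume = abs(np.linalg.det(box))
    return np.einsum('p,pa,pb->ab', lj_prime(r) / r, d, d) / volume

box = np.array([[6.0, 0.0, 0.0], [0.5, 6.2, 0.0], [0.3, 0.4, 5.8]])
S = np.array([[0.0, 0.0, 0.0], [0.02, 0.49, 0.48],
              [0.51, 0.03, 0.47], [0.46, 0.52, 0.01]])
print(np.max(np.abs(stress_fd(S, box) - stress_virial(S, box))))
```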

In [111]:
# TODO: investigate JIT behavior

## Molecular Dynamics

In [113]:
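dt = 5e-3
# jnp.where (not np.where) so that kT can be evaluated on traced times inside jit / fori_loop
kT = lambda t: jnp.where(t < 5000.0 * dt, 0.1, 0.01)
key, split = random.split(key)

# The integrator expects a scalar energy, so use the summed per-particle energies
init_fn, apply_fn = simulate.nvt_nose_hoover(total_energy_fn, shift_fn, dt, kT(0.), tau=0.1)
state = init_fn(split, R_real)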

In [None]:
def step_fn(i, state_and_log):
    state, log = state_and_log
    t = i * dt

    # Record the instantaneous temperature
    T = quantity.temperature(state.velocity)
    log['kT'] = log['kT'].at[i].set(T)

    # Record the Nose-Hoover invariant (needs the scalar total energy)
    H = simulate.nose_hoover_invariant(total_energy_fn, state, kT(t))
    log['H'] = log['H'].at[i].set(H)

    log['position'] = log['position'].at[i].set(state.position)

    state = apply_fn(state, kT=kT(t))

    return state, log
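The cell above defines `step_fn` but does not run it yet. A step function with this `(i, (state, log))` signature is typically driven by `jax.lax.fori_loop` over preallocated log arrays. The sketch below shows that pattern on its own; the dynamics are replaced by a hypothetical scalar update so it runs without JAX-MD, and `n_steps` and the switch time of the `kT` schedule are arbitrary choices.

```python
import jax.numpy as jnp
from jax import lax

dt = 5e-3
n_steps = 100                                     # hypothetical run length
kT = lambda t: jnp.where(t < 50 * dt, 0.1, 0.01)  # traceable temperature schedule

def step_fn(i, state_and_log):
    x, log = state_and_log
    t = i * dt
    # Preallocated log arrays are updated functionally with .at[i].set(...)
    log = {**log, 'kT': log['kT'].at[i].set(kT(t)), 'x': log['x'].at[i].set(x)}
    x = x * 0.99                                  # hypothetical stand-in for apply_fn
    return x, log

log = {'kT': jnp.zeros(n_steps), 'x': jnp.zeros(n_steps)}
x, log = lax.fori_loop(0, n_steps, step_fn, (jnp.array(1.0), log))
print(log['kT'][:3], log['x'][:3])
```

The loop carry (state and log dictionary) must keep the same structure, shapes and dtypes on every iteration, which is why the logs are allocated up front rather than appended to.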