In [1]:
import matplotlib.pyplot as plt
import jax.numpy as np
from jax_md.space import *
from jax import grad
from jax.config import config
# Double precision everywhere. JAX recommends setting this flag at startup,
# before any arrays are created, which is why it lives in the first cell.
ENABLE_X64 = True
config.update("jax_enable_x64", ENABLE_X64)


In [85]:

def _check_transform_shapes(T: Array, v: Array=None):
  """Check whether a transform and collection of vectors have valid shape."""
  if len(T.shape) != 2:
    raise ValueError(
        ('Transform has invalid rank.'
         ' Found rank {}, expected rank 2.'.format(len(T.shape))))

  if T.shape[0] != T.shape[1]:
    raise ValueError('Found non-square transform.')

  if v is not None and v.shape[-1] != T.shape[1]:
    raise ValueError(
        ('Transform and vectors have incommensurate spatial dimension. '
         'Found {} and {} respectively.'.format(T.shape[1], v.shape[-1])))


def _small_inverse(T: Array) -> Array:
  """Compute the inverse of a small square matrix.

  Args:
    T: Square matrix; ndarray(shape=[spatial_dim, spatial_dim]).

  Returns:
    The inverse of `T` as computed by `jnp.linalg.inv`.

  Raises:
    ValueError: if `T` is not a rank-2 square matrix.
  """
  _check_transform_shapes(T)
  # NOTE(review): singular matrices are not detected here; presumably
  # `jnp.linalg.inv` yields non-finite values rather than raising -- verify.
  # TODO(schsam): Check whether matrices are singular. @ErrorChecking
  return jnp.linalg.inv(T)

def transform(T: Array, v: Array) -> Array:
  """Apply a linear transformation, T, to a collection of vectors, v.

  Computes `v @ T`: each vector is treated as a row vector and multiplied on
  the right by `T`.

  NOTE: in jax_md.space this function is wrapped in `@custom_jvp` with a JVP
  of `(transform(T, v), dv)`, i.e. it acts as the identity on tangents and
  drops the derivative w.r.t. `T`. That decorator is intentionally NOT applied
  here, so this is ordinary autodiff of `v @ T`: gradients flow through `T`
  (which is what lets this notebook differentiate the energy w.r.t. a strain
  applied to the box), and tangents of `v` are transformed by `T` as well.
  To restore the jax_md behavior, decorate with `@custom_jvp` and register a
  JVP returning `(transform(T, v), dv)`.

  Args:
    T: Transformation; ndarray(shape=[spatial_dim, spatial_dim]).
    v: Collection of vectors; ndarray(shape=[..., spatial_dim]).

  Returns:
    Transformed vectors; ndarray(shape=[..., spatial_dim]).

  Raises:
    ValueError: if `T` is not a square rank-2 matrix or its size does not
      match the trailing dimension of `v`.
  """
  _check_transform_shapes(T, v)
  return jnp.dot(v, T)

def periodic_general(T: Union[Array, Callable[..., Array]],
                     wrapped: bool=True) -> Space:
  """Periodic boundary conditions on a parallelepiped.

  This function defines a simulation on a parallelepiped formed by applying a
  linear transformation to the unit hypercube [0, 1]^spatial_dimension.

  When using periodic_general, particle positions should be stored in the unit
  hypercube (fractional coordinates). To get real positions from the
  simulation you should call R_sim = transform(T, R_unit_cube).

  The transformation can feature time dependence (if T is a function instead
  of an array). In this case the resulting space will also be time dependent.
  This can be useful for simulating systems under mechanical strain. Every
  keyword argument passed to the returned displacement / shift functions is
  forwarded to T, so T may be parameterized by e.g. a `strain` keyword too.

  Args:
    T: A linear transformation.
       Either:
         1) An ndarray of shape [spatial_dim, spatial_dim].
         2) A function that takes keyword arguments (e.g. a time `t`) and
            produces ndarrays of shape [spatial_dim, spatial_dim].
    wrapped: A boolean specifying whether or not particle positions are
      remapped back into the unit hypercube after each shift.

  Returns:
    (displacement_fn, shift_fn) tuple.
    `displacement_fn(Ra, Rb, **kwargs)` takes fractional positions, wraps
    Ra - Rb into [-0.5, 0.5) per component and maps the result to real space
    with T. `shift_fn(R, dR, **kwargs)` takes fractional positions R and a
    real-space step dR (mapped back with the inverse of T) and returns
    updated fractional positions.

  Raises:
    ValueError: if T is an array that is not a square rank-2 matrix.
  """
  if callable(T):
    def displacement(Ra: Array, Rb: Array, **kwargs) -> Array:
      dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
      return transform(T(**kwargs), dR)
    # T (and hence its inverse) is recomputed on every call because it may
    # depend on kwargs. Can we cache the inverse? @Optimization
    if wrapped:
      def shift(R: Array, dR: Array, **kwargs) -> Array:
        return periodic_shift(f32(1.0),
                              R,
                              transform(_small_inverse(T(**kwargs)), dR))
    else:
      def shift(R: Array, dR: Array, **kwargs) -> Array:
        return R + transform(_small_inverse(T(**kwargs)), dR)
  else:
    # Static box: invert once, up front (this also validates T's shape).
    T_inv = _small_inverse(T)
    def displacement(Ra: Array, Rb: Array, **unused_kwargs) -> Array:
      dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
      return transform(T, dR)
    if wrapped:
      def shift(R: Array, dR: Array, **unused_kwargs) -> Array:
        return periodic_shift(f32(1.0), R, transform(T_inv, dR))
    else:
      def shift(R: Array, dR: Array, **unused_kwargs) -> Array:
        return R + transform(T_inv, dR)
  return displacement, shift



In [86]:
def get_T(basis):
    """Generate a callable T(strain) for use with `periodic_general`.

    We want to apply the strain transformation (1 + strain) to the unit
    cell. Since jax_md works in scaled (fractional) coordinates for
    periodic_general, straining the cell automatically applies the
    transformation to *all* real-space coordinates.

    Here, we generate a callable T that has the keyword argument `strain`.
    Note that we have to add a dummy `t` argument since jax_md internally
    assumes that T is a function of time `t`.

    Args:
      basis: Cell matrix whose rows are the cell vectors (with zero strain,
        row i is the real-space image of the i-th unit vector); shape
        (dim, dim). The spatial dimension is taken from `basis`.

    Returns:
      A function `T(strain=zeros((dim, dim)), t=0)` returning
      `basis @ (I + strain)`.
    """
    dim = basis.shape[-1]

    def T(strain: Array = np.zeros((dim, dim), dtype=np.double), t=0) -> Array:
        # `t` is accepted and ignored so T can be called with a time keyword.
        strain_transformation = np.eye(N=dim, M=dim, dtype=np.double) + strain
        return transform(strain_transformation, basis)

    return T
    

In [87]:
# Orthorhombic box: length 1 along x and z, length 2 along y.
basis = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 1]], dtype=np.double)
T = get_T(basis)

# Only the displacement function is used below. The shift function is bound to
# a real name rather than `_`: in IPython, assigning to `_` shadows the
# last-output history variable and IPython stops updating it.
disp, shift_fn = periodic_general(T)

In [95]:
# Three test particles in fractional (unit-cube) coordinates.
a = np.array([0, 0, 0], dtype=np.double)
b = np.array([0.5, 0.5, 0.5], dtype=np.double)
c = np.array([0.25, 0.5, 0.75], dtype=np.double)

print(disp(a, b))
print(disp(a, c))
# Pin the hand-computed values: the minimum-image fractional displacement is
# mapped through the basis, so the y component is scaled by the box length 2.
# Components at exactly half a box length land on -0.5 (the image is
# ambiguous there: +0.5 would be equally valid).
assert np.allclose(disp(a, b), np.array([-0.5, -1.0, -0.5]))
assert np.allclose(disp(a, c), np.array([-0.25, -1.0, 0.25]))

[-0.5 -1.  -0.5]
[-0.25 -1.    0.25]


In [92]:
from jax_md import energy
from functools import partial

# Lennard-Jones pair energy built on the strain-aware displacement `disp`. No LJ
# parameters are passed here, so jax_md's defaults apply. Keyword arguments given
# to `lj` (e.g. `strain=` below) are expected to be forwarded to `disp` and from
# there to `T` -- presumably via jax_md's kwargs plumbing; confirm if jax_md changes.
lj = energy.lennard_jones_pair(disp)

def my_lj(R: Array, strain: Array) -> Array:
    """Total LJ energy as a function of positions and strain, both positional.

    `jax.grad` / `value_and_grad` choose what to differentiate by position
    (`argnums`), so the `strain` keyword of `lj` is exposed here as a second
    positional argument.

    Args:
      R: Particle positions (fractional coordinates, as used by `disp`).
      strain: Strain matrix, passed to `lj` as the `strain` keyword.

    Returns:
      Whatever `lj` returns for these inputs (the total pair energy).
    """
    total_energy = lj(R, strain=strain)
    return total_energy

# Stack the three test particles into an (N, 3) array; start from zero strain.
R = np.stack([a, b, c])
strain = np.zeros((3, 3), dtype=np.double)

print(my_lj(R, strain))
# NOTE: `stress` returns dE/d(strain). A stress tensor would additionally be
# normalized by the cell volume (sign convention varies). The huge magnitudes
# come from particles b and c being only ~0.35 apart in real space, far inside
# the repulsive core for unit sigma (presumably the jax_md default).
stress = grad(my_lj, argnums=1)
print(stress(R, strain))

1046526.3297336849
[[-6.28531190e+06  9.65706447e-01  6.28531286e+06]
 [ 9.65706447e-01 -4.13169979e+00  9.65706447e-01]
 [ 6.28531286e+06  9.65706447e-01 -6.28531190e+06]]


In [93]:
from jax import value_and_grad
# `value_and_grad` returns (E, (dE/dR, dE/d(strain))). The first gradient is
# NOT the force: it is dE/dR, i.e. the negative force, and R is in fractional
# (unit-cube) coordinates, so it is taken w.r.t. scaled rather than Cartesian
# positions. The second gradient is the strain derivative from the cell above.
energy_and_grads = value_and_grad(my_lj, argnums=(0, 1))
# Misleading legacy name, kept so any later cell that uses it keeps working.
energy_and_forces_and_stress = energy_and_grads
energy_and_grads(R, strain)

(DeviceArray(1046526.32973368, dtype=float64),
 (DeviceArray([[ 1.51577817e+00,  0.00000000e+00, -1.51577817e+00],
               [-2.51412480e+07,  0.00000000e+00,  2.51412480e+07],
               [ 2.51412465e+07,  0.00000000e+00, -2.51412465e+07]],            dtype=float64),
  DeviceArray([[-6.28531190e+06,  9.65706447e-01,  6.28531286e+06],
               [ 9.65706447e-01, -4.13169979e+00,  9.65706447e-01],
               [ 6.28531286e+06,  9.65706447e-01, -6.28531190e+06]],            dtype=float64)))