In [2]:
import matplotlib.pyplot as plt
import jax.numpy as np
from jax_md.space import *
from jax import grad, jit
from jax.config import config
config.update("jax_enable_x64", True)

from jax_md.space import DisplacementOrMetricFn, Array
from jax_md import smap, space
from jax_md.energy import multiplicative_isotropic_cutoff, lennard_jones

from typing import Callable

In [24]:

def _check_transform_shapes(T: Array, v: Array=None):
  """Check whether a transform and collection of vectors have valid shape."""
  if len(T.shape) != 2:
    raise ValueError(
        ('Transform has invalid rank.'
         ' Found rank {}, expected rank 2.'.format(len(T.shape))))

  if T.shape[0] != T.shape[1]:
    raise ValueError('Found non-square transform.')

  if v is not None and v.shape[-1] != T.shape[1]:
    raise ValueError(
        ('Transform and vectors have incommensurate spatial dimension. '
         'Found {} and {} respectively.'.format(T.shape[1], v.shape[-1])))


def _small_inverse(T: Array) -> Array:
  """Compute the inverse of a small matrix."""
  _check_transform_shapes(T)
  dim = T.shape[0]
  # TODO(schsam): Check whether matrices are singular. @ErrorChecking
  return jnp.linalg.inv(T)

# @custom_jvp
def transform(T: Array, v: Array) -> Array:
  """Apply a linear transformation, T, to a collection of vectors, v.
  Transform is written such that it acts as the identity during gradient
  backpropagation.
  Args:
    T: Transformation; ndarray(shape=[spatial_dim, spatial_dim]).
    v: Collection of vectors; ndarray(shape=[..., spatial_dim]).
  Returns:
    Transformed vectors; ndarray(shape=[..., spatial_dim]).
  """
  _check_transform_shapes(T, v)
  return jnp.dot(v, T)


# @transform.defjvp
#def transform_jvp(primals: Tuple[Array, Array],
#                   tangents: Tuple[Array, Array]) -> Tuple[Array, Array]:
#  T, v = primals
#  dT, dv = tangents
#  return transform(T, v), dv

def periodic_general(T: Union[Array, Callable[..., Array]],
                     wrapped: bool=True) -> Space:
  """Periodic boundary conditions on a parallelepiped.
  This function defines a simulation on a parellelepiped formed by applying an
  affine transformation to the unit hypercube [0, 1]^spatial_dimension.
  When using periodic_general, particles positions should be stored in the unit
  hypercube. To get real positions from the simulation you should call
  R_sim = space.transform(T, R_unit_cube).
  The affine transformation can feature time dependence (if T is a function
  instead of a scalar). In this case the resulting space will also be time
  dependent. This can be useful for simulating systems under mechanical strain.
  Args:
    T: An affine transformation.
       Either:
         1) An ndarray of shape [spatial_dim, spatial_dim].
         2) A function that takes floating point times and produces ndarrays of
            shape [spatial_dim, spatial_dim].
    wrapped: A boolean specifying whether or not particle positions are
      remapped back into the box after each step
  Returns:
    (displacement_fn, shift_fn) tuple.
  """
  if callable(T):
    def displacement(Ra: Array, Rb: Array, **kwargs) -> Array:
      print("Tracing displacement (callable)")
      dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
      return transform(T(**kwargs), dR)
    # Can we cache the inverse? @Optimization
    if wrapped:
      def shift(R: Array, dR: Array, **kwargs) -> Array:
        return periodic_shift(f32(1.0),
                              R,
                              transform(_small_inverse(T(**kwargs)), dR))
    else:
      def shift(R: Array, dR: Array, **kwargs) -> Array:
        return R + transform(_small_inverse(T(**kwargs)), dR)
  else:
    print("Tracing displacement (no callable)")
    T_inv = _small_inverse(T)
    def displacement(Ra: Array, Rb: Array, **unused_kwargs) -> Array:
      dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
      return transform(T, dR)
    if wrapped:
      def shift(R: Array, dR: Array, **unused_kwargs) -> Array:
        return periodic_shift(f32(1.0), R, transform(T_inv, dR))
    else:
      def shift(R: Array, dR: Array, **unused_kwargs) -> Array:
        return R + transform(T_inv, dR)
  return displacement, shift



In [28]:
def get_T():
    """Generate callable T(strain)
    
    We want to apply the strain transformation (1 + strain) to the
    unit cell T. Since jax_md works in scaled coordinates
    for periodic_general, this will automatically apply the
    trafo to *all* coordinates.
    
    Here, we generate a callable T that has the keyword argument `strain`.
    Note that we have to add a dummy `t` argument since jax_md internally
    assumes that T is a function of time `t`
    """
    # def T(strain: Array = np.zeros((3, 3), dtype=np.double), t=0) -> Array:
    def T() -> Array:
        print("Tracing T")
        basis = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 1]], dtype=np.double)
        strain = np.zeros((3, 3), dtype=np.double) 
        strain_transformation = np.eye(N=3, M=3, dtype=np.double) + strain
        return transform(strain_transformation, basis)

    return T
    

### Define a transformed displacement - no JIT

**Both** `displacement_fn` and `T` can be correctly traced on every execution.

In [25]:
basis = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 1]], dtype=np.double)
T = get_T(basis)

displacement_fn, _ = periodic_general(T)

a = np.array([0, 0, 0], dtype=np.double)
b = np.array([0.5, 0.5, 0.5], dtype=np.double)
c = np.array([0.25, 0.5, 0.75], dtype=np.double)

print(displacement_fn(a, b))
print()

print(displacement_fn(a, b))
print()

# first run - 20 ms.
%time displacement_fn(a, b).block_until_ready()
print()

# second run - 6 ms.
%time displacement_fn(a, c).block_until_ready()
print()

# print(displacement_fn(a, b))
# print(displacement_fn(a, c))
# => looks ok!

Tracing displacement (callable)
Tracing T
[-0.5 -1.  -0.5]

Tracing displacement (callable)
Tracing T
[-0.5 -1.  -0.5]

Tracing displacement (callable)
Tracing T
CPU times: user 6.59 ms, sys: 2.39 ms, total: 8.98 ms
Wall time: 4.86 ms

Tracing displacement (callable)
Tracing T
CPU times: user 8.07 ms, sys: 25 µs, total: 8.1 ms
Wall time: 4.69 ms



Looks like the non-jitted function!

In [38]:
print(displacement_fn)

T()

<function periodic_general.<locals>.displacement at 0x7fa38bddfdc0>
Tracing T


DeviceArray([[1., 0., 0.],
             [0., 2., 0.],
             [0., 0., 1.]], dtype=float64)

### Define a transformed displacement - jitted

`displacement_fn` and `T` are jitted on the first run, with a significant speedup on following calls. Tracer prints are correctly removed once jitted.

In [39]:
# basis = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 1]], dtype=np.double)
# T = jit(get_T(basis))
# T = get_T(basis)

def canonicalize_displacement_or_metric(displacement_or_metric):
  """Checks whether or not a displacement or metric was provided."""
  for dim in range(1, 4):
    try:
      R = ShapedArray((dim,), f32)
      dR_or_dr = eval_shape(displacement_or_metric, R, R, t=0)
      print("loop " + str(dim))  
        
      if len(dR_or_dr.shape) == 0:
        return displacement_or_metric
      else:
        return metric(displacement_or_metric)
    except TypeError:
      continue
    except ValueError:
      continue
  raise ValueError(
    'Canonicalize displacement not implemented for spatial dimension larger'
    'than 4.')




T = get_T()
displacement_fn, _ = periodic_general(T)
canonicalize_displacement_or_metric(displacement_fn)

# displacement_fn, _ = periodic_general(basis)
displacement_fn = jit(displacement_fn)

a = np.array([0, 0, 0], dtype=np.double)
b = np.array([0.5, 0.5, 0.5], dtype=np.double)
c = np.array([0.25, 0.5, 0.75], dtype=np.double)


# first run - jit overhead. 82 ms.
%time displacement_fn(a, b).block_until_ready()
print()

# second run - jitted function, different args (prevent caching). 1 ms, max.
%time displacement_fn(a, c).block_until_ready()
print()

# print(displacement_fn(a, b))
# print(displacement_fn(a, c))
# => looks ok!

Tracing displacement (callable)
Tracing displacement (callable)
Tracing displacement (callable)


ValueError: Canonicalize displacement not implemented for spatial dimension largerthan 4.

Looks like a jitted function!

In [30]:
print(displacement_fn)

<function api_boundary.<locals>.reraise_with_filtered_traceback at 0x7fa38be0e8b0>


## Jitted displacement and energy_fn

In the previous cell, we evaluate `displacement_fn` multiple times and, as expected, only saw the tracer on the first call.

Here, however, both `T` and the `displacement_fn` are traced on every execution, although we passed the jitted version to `lennard_jones_pair()`.

In [33]:
# default lj function, no changes
def lennard_jones_pair(displacement_or_metric: DisplacementOrMetricFn,
                       species: Array=None,
                       sigma: Array=1.0,
                       epsilon: Array=1.0,
                       r_onset: Array=2.0,
                       r_cutoff: Array=2.5,
                       per_particle: bool=False) -> Callable[[Array], Array]:
  """Convenience wrapper to compute Lennard-Jones energy over a system."""
  
  print("Tracing lennard_jones_pair")

  sigma = np.array(sigma, dtype=f32)  
  epsilon = np.array(epsilon, dtype=f32)
  r_onset = r_onset * np.max(sigma)
  r_cutoff = r_cutoff * np.max(sigma)
  return smap.pair(
    multiplicative_isotropic_cutoff(lennard_jones, r_onset, r_cutoff),
    space.canonicalize_displacement_or_metric(displacement_or_metric),
    species=species,
    sigma=sigma,
    epsilon=epsilon,
    reduce_axis=(1,) if per_particle else None)


lj = lennard_jones_pair(displacement_fn)

def energy_fn(R: Array, strain: Array) -> Array:
    print("Tracing energy_fn")
    """LJ with TWO arguments: coordinates and strain
    
    `jax.grad` doesn't work with keyword args, so we wrap it.
    """
    return jit(lj(R, strain=strain))
    # return jit(energy.lennard_jones_pair(displacement_fn)(R, strain=strain))

R = np.array([a, b, c])
strain = np.zeros((3, 3), dtype=np.double) 

Tracing lennard_jones_pair
Tracing displacement (callable)
Tracing displacement (callable)
Tracing displacement (callable)


ValueError: Canonicalize displacement not implemented for spatial dimension largerthan 4.

`displacement_fn` **still** looks like a jitted function! How are the tracers suddenly back in place? 

In [22]:
print(displacement_fn)

<function api_boundary.<locals>.reraise_with_filtered_traceback at 0x7fa38bff1550>
