In [4]:
import jax.numpy as np

from jax.api import jit, grad
from jax import lax
from jax import random

from jax_md import space, energy, simulate, minimize, quantity
#from jax_md.colab_tools import renderer

In [5]:
from jax_md import space
from jax import custom_jvp
from jax import lax

# Local aliases for the jax_md.space helpers called by periodic_general.
periodic_displacement = space.periodic_displacement
pairwise_displacement = space.pairwise_displacement
periodic_shift = space.periodic_shift

# Shorthand for the float32 scalar type of jax.numpy (imported here as `np`).
f32 = np.float32

def inverse(box):
  """Invert a box specification.

  Args:
    box: A scalar, a rank-1 array or a rank-2 array.

  Returns:
    `1 / box` when `box` is a scalar, has a single element, or is rank-1
    (elementwise reciprocal); the matrix inverse `np.linalg.inv(box)` when
    `box` is rank-2.

  Raises:
    ValueError: If `box` is an array with more than one element whose rank
      is neither 1 nor 2.
  """
  # Scalars, single-element arrays and per-axis side lengths are all inverted
  # elementwise, so they share one branch.
  if np.isscalar(box) or box.size == 1 or box.ndim == 1:
    return 1 / box
  if box.ndim == 2:
    return np.linalg.inv(box)

  raise ValueError(
      'box must be a scalar, a rank-1 array or a rank-2 array; '
      f'got an array with ndim={box.ndim}.')

def get_free_indices(n):
  """Return a string of `n` consecutive characters starting at 'a'.

  For example, `n = 3` yields 'abc'; a non-positive `n` yields ''. The result
  is suitable for use as einsum subscript labels. The sequence is not
  filtered, so from `n = 9` onwards it contains the letter 'i'.
  """
  first = ord('a')
  return ''.join(map(chr, range(first, first + n)))

@custom_jvp
def transform(box, R):
  """Apply the transformation described by `box` to the last axis of `R`.

  Args:
    box: A scalar, a rank-1 array or a rank-2 array.
    R: Array whose last axis is acted on by `box`; any number of leading
      axes is allowed.

  Returns:
    * scalar or single-element `box`: `R * box`;
    * rank-1 `box`: `out[..., i] = box[i] * R[..., i]`;
    * rank-2 `box`: `out[..., i] = sum_j box[i, j] * R[..., j]`.

  Raises:
    ValueError: If `box` is an array with more than one element whose rank
      is neither 1 nor 2.
  """
  if np.isscalar(box) or box.size == 1:
    return R * box
  # Ellipsis subscripts cover every leading axis of R, so the result does not
  # depend on generating unique letters (which would clash with 'i'/'j' for
  # high-rank R).
  if box.ndim == 1:
    return np.einsum('i,...i->...i', box, R)
  if box.ndim == 2:
    return np.einsum('ij,...j->...i', box, R)
  raise ValueError(
      'box must be a scalar, a rank-1 array or a rank-2 array; '
      f'got an array with ndim={box.ndim}.')


@transform.defjvp
def transform_jvp(primals, tangents):
  """Custom JVP rule registered for `transform`.

  Args:
    primals: The pair `(box, R)` at which `transform` is evaluated.
    tangents: The pair `(dbox, dR)` of tangents for `box` and `R`.

  Returns:
    A pair `(primal_out, tangent_out)` with `primal_out = transform(box, R)`
    and `tangent_out = dR + transform(dbox, R)`.
  """
  box, R = primals
  dbox, dR = tangents

  # NOTE(review): dR is passed through unchanged instead of being mapped by
  # `box` (the exact derivative w.r.t. R would be transform(box, dR)). This
  # looks deliberate -- presumably so derivatives taken with respect to R are
  # not rescaled by the box -- confirm against the JAX-MD documentation.
  return transform(box, R), dR + transform(dbox, R)

def periodic_general(box, wrapped=True):
  """Build displacement and shift functions for a periodic box.

  NOTE(review): positions appear to be handled in unit-cell (fractional)
  coordinates -- the periodic helpers are always called with a side length
  of 1.0 and `box` is only applied afterwards -- confirm against callers.

  Args:
    box: Default box; any value accepted by `inverse` and `transform`.
    wrapped: If True, `shift_fn` moves positions with
      `periodic_shift(f32(1.0), R, dR)`; otherwise it returns `R + dR`.

  Returns:
    A pair `(displacement_fn, shift_fn)`. Both accept an optional `box`
    keyword argument that overrides the default `box` for that call.
  """
  # Inverse of the default box, computed once and reused by shift_fn.
  inv_box = inverse(box)

  def displacement_fn(Ra, Rb, **kwargs):
    """Displacement between Ra and Rb, mapped through the (overridable) box."""
    current_box = kwargs.get('box', box)
    dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
    return transform(current_box, dR)

  def u(R, dR):
    """Add dR to R, using periodic_shift with side 1.0 when `wrapped`."""
    if wrapped:
      return periodic_shift(f32(1.0), R, dR)
    return R + dR

  def shift_fn(R, dR, **kwargs):
    """Move R by dR after mapping dR through the inverse of the box."""
    # Only recompute the inverse when the caller supplies a different box.
    if 'box' in kwargs:
      current_inv_box = inverse(kwargs['box'])
    else:
      current_inv_box = inv_box
    return u(R, transform(current_inv_box, dR))

  return displacement_fn, shift_fn

In [10]:
# System parameters, named once so the box, the positions and the density
# calculation cannot drift out of sync.
N = 1024
dimension = 2
number_density = 1.2

box_size = quantity.box_size_at_number_density(N, number_density, dimension)
# Diagonal box matrix with `box_size` on the diagonal.
box = box_size * np.eye(dimension)
displacement, shift = periodic_general(box)

# Fixed seed so the initial configuration is reproducible.
key = random.PRNGKey(0)
# Initial positions drawn with random.uniform's default bounds.
R = random.uniform(key, (N, dimension))

box

DeviceArray([[29.211868,  0.      ],
             [ 0.      , 29.211868]], dtype=float32)

In [None]:
# Soft-sphere pair energy built on the periodic displacement function.
energy_fn = energy.soft_sphere_pair(displacement)
# FIRE descent minimizer: init_fn builds the state, step_fn advances it.
init_fn, step_fn = minimize.fire_descent(energy_fn, shift)

# Initial minimizer state created from the starting positions R.
state = init_fn(R)

  lax._check_user_dtype_supported(dtype, name)


In [None]:
def _force_above_tolerance(state):
  """True while the largest absolute force component exceeds 1e-3."""
  return np.max(np.abs(state.force)) > 1e-3


# Apply step_fn repeatedly until the predicate above becomes False.
state = lax.while_loop(_force_above_tolerance, step_fn, state)

  lax._check_user_dtype_supported(dtype, name)


In [None]:
# The top-of-notebook import of `renderer` is commented out, so import it
# here; otherwise this cell fails with NameError on a fresh kernel.
from jax_md.colab_tools import renderer

# Positions are mapped through `box` with transform() before being drawn.
renderer.render(
    box_size,
    {'particles': renderer.Disk(transform(box, state.position))}
)

In [None]:
def box_energy(epsilon, R):
  """Energy of `R` in a perturbed box, divided by det of the reference box.

  Args:
    epsilon: Perturbation added entrywise to the global `box`.
    R: Positions forwarded to `energy_fn`.

  Returns:
    `energy_fn(R, box=box + epsilon) / np.linalg.det(box)`.
  """
  # NOTE(review): epsilon is added to `box` directly rather than applied as a
  # relative strain of the box -- confirm this is the intended convention.
  perturbed_box = box + epsilon
  reference_volume = np.linalg.det(box)
  return energy_fn(R, box=perturbed_box) / reference_volume


# Compiled gradient of box_energy (jax.grad differentiates with respect to
# the first argument, `epsilon`, by default).
stress = jit(grad(box_energy))

In [None]:
# Evaluate box_energy at zero box perturbation for the minimized positions.
box_energy(np.zeros((2, 2)), state.position)

  lax._check_user_dtype_supported(dtype, name)


DeviceArray(0.00147612, dtype=float32)

In [None]:
# Evaluate the jit-compiled gradient `stress` at zero box perturbation.
stress(np.zeros((2, 2)), state.position)

  lax._check_user_dtype_supported(dtype, name)


DeviceArray([[-1.4491812e-03, -6.8001609e-05],
             [-6.8001609e-05, -1.5729750e-03]], dtype=float32)