In [1]:
import math
from typing import Callable, Tuple

import numpy as np
import equinox
import jax
import jax.numpy as jnp
import jax.scipy.ndimage
import tensorial.distances

In [2]:
def get_num_plane_repetitions_to_bound_sphere(radius: float, volume: float, cross_len: float) -> float:
    # The vector normal to the plane
    return radius / volume * cross_len

In [92]:
class PeriodicNeighbourCalculator(equinox.Module):
    _cell: jax.Array
    _cutoff: float
    _cutoff_sq: float
    _max_cell_multiples: int
    _volume: float   
    _always_inside : jax.Array
    _always_inside_cells: jax.Array
    _halo: jax.Array
    _halo_cells: jax.Array

    def __init__(self, cell: jax.Array, cutoff: float, max_cell_multiples=100_000):
        self._cell = cell
        self._cutoff = cutoff
        self._cutoff_sq = cutoff * cutoff
        self._max_cell_multiples = max_cell_multiples
        # Unit cell volume
        self._volume = jnp.abs(jnp.dot(cell[0], jnp.cross(cell[1], cell[2])))

        max0 = self._get_max(0)
        max1 = self._get_max(1)
        max2 = self._get_max(2)
        
        if max0 > max_cell_multiples:
            max0 = max_cell_multiples
        if max1 > max_cell_multiples:
            max1 = max_cell_multiples
        if max2 > max_cell_multiples:
            max2 = max_cell_multiples
            
        cell_grid = jnp.array(jnp.meshgrid(
            jnp.arange(-max0, max0 + 1),
            jnp.arange(-max1, max1 + 1),
            jnp.arange(-max2, max2 + 1)
        ))
        reshaped = cell_grid.T.reshape(-1, 3)
        
        grid_points = jax.vmap(
            lambda grid_coords: jnp.multiply(grid_coords, cell).sum(axis=1))(reshaped)
        
        norms_sq = (grid_points**2).sum(axis=1).reshape(cell_grid.shape[1:])
        
        # Calculate a mask, and put back into the original [l, m, n] shape so we can do a convolution
        mask = (norms_sq < self._cutoff_sq)
        convolved = jax.scipy.signal.convolve(mask, jnp.ones((2, 2, 2)), mode='same').reshape(-1)
        
        # If a unit cell and all of its eight surrounding neighbour grid points (that can be arrived at by adding all possible
        # combinations of lattice vectors) are inside the cutoff, then we never need to check this unit cell again as no vector
        # inside it can ever leave the cutoff sphere
        always_inside_indices = (convolved == 8.)
        self._always_inside = grid_points[always_inside_indices]
        self._always_inside_cells = reshaped[always_inside_indices]
        
        # The halo region is where a grid point has one or more neighbouring grid points (or itself) that are within the cutoff
        # sphere, and therefore part of the unit cell is inside.  These will need to be checked depending on the actual
        # displacement vector within the cell
        in_halo = (convolved > 0.) & ~always_inside_indices
        self._halo = grid_points[in_halo]
        self._halo_cells = reshaped[in_halo]

    def _get_max(self, cell_vector: int) -> int:
        vec1 = (cell_vector + 1) % 3
        vec2 = (cell_vector + 2) % 3
        vec1_cross_vec2_len = jnp.linalg.norm(jnp.cross(self._cell[vec1], self._cell[vec2]))
        return math.ceil(
            get_num_plane_repetitions_to_bound_sphere(
                self._cutoff, self._volume, vec1_cross_vec2_len,
            )
        )
    
    def __call__(self, r2: jax.Array, r1: jax.Array):
        dr = r2 - r1
        get_pts = jax.vmap(lambda grid_point: grid_point +  dr)
        neighbour_vecs = get_pts(self._always_inside)
        
        halo_vecs = get_pts(self._halo)
        norms_sq = (halo_vecs**2).sum(axis=1)
        halo_indices = norms_sq < self._cutoff_sq
        halo_vecs = halo_vecs[halo_indices]
        halo_cells = self._halo_cells[halo_indices]
        
        return jnp.concat((neighbour_vecs, halo_vecs)), jnp.concat((self._always_inside_cells, halo_cells))

In [93]:
cell = 5. * np.eye(3, 3)
cutoff = 8.7
faster_calculator = PeriodicNeighbourCalculator(cell, cutoff)
slower_calculator = tensorial.distances.Periodic(cell, cutoff)

8


In [94]:
vecs = jnp.array([0., 0., 0]), jnp.array([0.05, 0., 0.])

In [96]:
pts = faster_calculator(*vecs)[0]
jnp.linalg.norm(pts, axis=1)

Array([0.05     , 5.00025  , 4.95     , 7.0358014, 5.00025  , 7.0712447,
       7.0358014, 8.631483 , 8.689218 , 7.106511 , 8.689218 , 7.0712447,
       5.00025  , 7.0712447, 8.631483 , 7.0358014, 8.631483 , 7.106511 ,
       5.05     , 7.106511 , 5.00025  , 7.0358014, 8.689218 , 7.106511 ,
       8.689218 , 7.0712447, 8.631483 ], dtype=float32)

In [97]:
pts = slower_calculator(*vecs)[0]
jnp.linalg.norm(pts, axis=1)

Array([8.631483 , 7.0358014, 8.631483 , 7.0358014, 4.95     , 7.0358014,
       8.631483 , 7.0358014, 8.631483 , 7.0712447, 5.00025  , 7.0712447,
       5.00025  , 0.05     , 5.00025  , 7.0712447, 5.00025  , 7.0712447,
       8.689218 , 7.106511 , 8.689218 , 7.106511 , 5.05     , 7.106511 ,
       8.689218 , 7.106511 , 8.689218 ], dtype=float32)

In [72]:
faster_calculator._always_inside

Array([[0., 0., 0.],
       [0., 5., 0.],
       [5., 0., 0.],
       [5., 5., 0.],
       [0., 0., 5.],
       [0., 5., 5.],
       [5., 0., 5.],
       [5., 5., 5.]], dtype=float32)

In [30]:
max0 = max1 = max2 = 2
grid = jnp.array(jnp.meshgrid(
    jnp.arange(-max0, max0 + 1),
    jnp.arange(-max1, max1 + 1),
    jnp.arange(-max2, max2 + 1)
))
reshaped = grid.T.reshape(-1, 3)

In [32]:
grid_points = jax.vmap(
    lambda grid_coords: jnp.multiply(grid_coords, cell).sum(axis=1))(reshaped)
norms_sq = jnp.linalg.norm(grid_points, axis=1, ord=2).reshape(grid.shape[1:])

In [45]:
print(max0, max1, max2)

grid = jnp.array(jnp.meshgrid(
    jnp.arange(-max0, max0 + 1),
    jnp.arange(-max1, max1 + 1),
    jnp.arange(-max2, max2 + 1)
))
reshaped = grid.T.reshape(-1, 3)

2 2 2


In [47]:
reshaped

Array([[-2, -2, -2],
       [-2, -1, -2],
       [-2,  0, -2],
       [-2,  1, -2],
       [-2,  2, -2],
       [-1, -2, -2],
       [-1, -1, -2],
       [-1,  0, -2],
       [-1,  1, -2],
       [-1,  2, -2],
       [ 0, -2, -2],
       [ 0, -1, -2],
       [ 0,  0, -2],
       [ 0,  1, -2],
       [ 0,  2, -2],
       [ 1, -2, -2],
       [ 1, -1, -2],
       [ 1,  0, -2],
       [ 1,  1, -2],
       [ 1,  2, -2],
       [ 2, -2, -2],
       [ 2, -1, -2],
       [ 2,  0, -2],
       [ 2,  1, -2],
       [ 2,  2, -2],
       [-2, -2, -1],
       [-2, -1, -1],
       [-2,  0, -1],
       [-2,  1, -1],
       [-2,  2, -1],
       [-1, -2, -1],
       [-1, -1, -1],
       [-1,  0, -1],
       [-1,  1, -1],
       [-1,  2, -1],
       [ 0, -2, -1],
       [ 0, -1, -1],
       [ 0,  0, -1],
       [ 0,  1, -1],
       [ 0,  2, -1],
       [ 1, -2, -1],
       [ 1, -1, -1],
       [ 1,  0, -1],
       [ 1,  1, -1],
       [ 1,  2, -1],
       [ 2, -2, -1],
       [ 2, -1, -1],
       [ 2,  

In [48]:
grid_points = jax.vmap(
    lambda grid_coords: jnp.multiply(grid_coords, cell).sum(axis=1))(reshaped)
grid_points

Array([[-10., -10., -10.],
       [-10.,  -5., -10.],
       [-10.,   0., -10.],
       [-10.,   5., -10.],
       [-10.,  10., -10.],
       [ -5., -10., -10.],
       [ -5.,  -5., -10.],
       [ -5.,   0., -10.],
       [ -5.,   5., -10.],
       [ -5.,  10., -10.],
       [  0., -10., -10.],
       [  0.,  -5., -10.],
       [  0.,   0., -10.],
       [  0.,   5., -10.],
       [  0.,  10., -10.],
       [  5., -10., -10.],
       [  5.,  -5., -10.],
       [  5.,   0., -10.],
       [  5.,   5., -10.],
       [  5.,  10., -10.],
       [ 10., -10., -10.],
       [ 10.,  -5., -10.],
       [ 10.,   0., -10.],
       [ 10.,   5., -10.],
       [ 10.,  10., -10.],
       [-10., -10.,  -5.],
       [-10.,  -5.,  -5.],
       [-10.,   0.,  -5.],
       [-10.,   5.,  -5.],
       [-10.,  10.,  -5.],
       [ -5., -10.,  -5.],
       [ -5.,  -5.,  -5.],
       [ -5.,   0.,  -5.],
       [ -5.,   5.,  -5.],
       [ -5.,  10.,  -5.],
       [  0., -10.,  -5.],
       [  0.,  -5.,  -5.],
 

In [73]:
norms_sq = (grid_points**2).sum(axis=1).reshape(grid.shape[1:])

In [86]:
# Calculate a mask, and put back into the original [l, m, n] shape so we can do a convolution
mask = (norms_sq < 8.7**2)
convolved = jax.scipy.signal.convolve(mask, jnp.ones((2, 2, 2)), mode='same')
print(f"Convolved: {convolved.shape}, {reshaped.shape}")

Convolved: (5, 5, 5), (125, 3)


In [87]:
mask

Array([[[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]],

       [[False, False, False, False, False],
        [False,  True,  True,  True, False],
        [False,  True,  True,  True, False],
        [False,  True,  True,  True, False],
        [False, False, False, False, False]],

       [[False, False, False, False, False],
        [False,  True,  True,  True, False],
        [False,  True,  True,  True, False],
        [False,  True,  True,  True, False],
        [False, False, False, False, False]],

       [[False, False, False, False, False],
        [False,  True,  True,  True, False],
        [False,  True,  True,  True, False],
        [False,  True,  True,  True, False],
        [False, False, False, False, False]],

       [[False, False, False, False, False],
        [False, False, False, False, False],
  

In [81]:
convolved

Array([[[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0.],
        [0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0.],
        [0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]], dtype=float32)