In [2]:
import jax
import jax.numpy as jnp
import jax.random as jr

In [3]:
def cdist(x, y):
    '''
    Computes the pairwise euclidean distance between two arrays

    Inputs:
    ------------
    x: Array (N, d)
    y: Array (N, d)

    Returns:
    ------------
    dists: Array (N, N)
    '''
    # insert dimension checking
    dists = jnp.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))
    return dists

In [4]:
def pdist(x):
    return cdist(x, x)

def pdist_single_sim(x):
    return jax.vmap(pdist, 0)(x)

def pdist_all_sim(x):
    return jax.vmap(pdist_single_sim, 0)(x)

In [22]:
def pdist_sims_and_times(x):
    return jax.vmap(jax.vmap(pdist, 0), 0)(x)

In [5]:
x = jr.randint(jr.PRNGKey(0), shape=(10, 100, 8, 2), minval=0, maxval=5)
all_distances = jax.vmap(jax.vmap(pdist, 0), 0)(x)

In [6]:
print(jnp.shape(all_distances))

(10, 100, 8, 8)


In [8]:
a = jnp.array([[1, 2], [3, 4]])
print(pdist(a))

[[0.       2.828427]
 [2.828427 0.      ]]


In [16]:
aa = jnp.tile(a, (10, 100, 1, 1))

In [17]:
adists = jax.vmap(jax.vmap(pdist, 0), 0)(aa)

In [18]:
print(jnp.shape(adists))

(10, 100, 2, 2)


In [21]:
print(adists[5, 46, :, :])

[[0.       2.828427]
 [2.828427 0.      ]]


In [30]:
tt = pdist_sims_and_times(aa)
print(jnp.shape(tt))
print(tt[1, 45, :, :])

(10, 100, 2, 2)
[[0.       2.828427]
 [2.828427 0.      ]]


In [17]:
from distances import *
xx = jnp.array([[0.1, 0.05], [0.75, 0.95]])

In [9]:
delta = jnp.abs(xx[:, None, :] - xx[None, :, :])

In [14]:
print(jnp.shape(delta))
print(-delta+1)

(2, 2, 2)
[[[1.         1.        ]
  [0.35000002 0.10000002]]

 [[0.35000002 0.10000002]
  [1.         1.        ]]]


In [27]:
def periodic1(p1, p2, L):
    """
    p1: [x1, y1]
    p2: [x2, y2]
    L: [Lx, Ly]
    """
    _r = p1 - p2
    _r = _r - L*jnp.rint(_r/L)
    return jnp.sqrt(jnp.sum(_r*_r))

- periodic1 takes as input two points of shape (d,) and sidelengths of shape (d,). It returns the nearest image euclidean distance between these points on a square/cube etc. of dimensions L_x x L_y x ...
- next we need to vmap this over arrays of shape (N, d) to compute a matrix of shape (N, N) storing the pairwise nearest image euclidean distances
- then double vmap this over the timesteps and simulations

In [37]:
p1 = jnp.array([0.1, 0.05])
p2 = jnp.array([0.75, 0.95])
print(periodic1(p1, p2, 1))
p = jnp.vstack((p1, p2))
print(p)

0.36400554
[[0.1  0.05]
 [0.75 0.95]]


In [58]:
def periodic2(p, L):
    return jax.vmap(lambda p1: jax.vmap(lambda p2: periodic1(p1, p2, L))(p)
    )(p)

In [59]:
print(periodic2(p, 1))

[[0.         0.36400554]
 [0.36400554 0.        ]]


In [68]:
def periodic3(p, L):
    return jax.vmap(jax.vmap(periodic2, (0, None)), (0, None))(p, L)

In [69]:
pp = jnp.tile(p, (10, 100, 1, 1))

In [70]:
print(periodic3(pp, 1))

[[[[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  ...

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]]


 [[[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  ...

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]]


 [[[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  ...

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.36400554]
   [0.36400554 0.        ]]

  [[0.         0.364005

In [71]:
print(jnp.shape(periodic3(pp, 1)))

(10, 100, 2, 2)
