In [1]:
from eins import EinsOp, Reductions as Red

import jax
import jax.numpy as jnp
import numpy as np

In [2]:
# Seeded generator so the inputs (and therefore the error reported at the end of
# this cell) are reproducible after a kernel restart. float32 is drawn directly
# because that is the dtype the JAX arrays end up with anyway.
rng = np.random.default_rng(0)
x = jnp.array(rng.standard_normal((1024, 256, 3), dtype=np.float32))
y = jnp.array(rng.standard_normal((1024, 256, 3), dtype=np.float32))

# Pairwise Euclidean distances with eins: `combine='add'` applied to x and the
# negated y forms the difference, and the `d` axis (absent from the output
# pattern) is reduced with an L2 norm.
z4 = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce=Red.l2_norm)(x, -y)

# Version without eins. Note how easy it would be to write
# x[:, None, ...] - y[:, :, None, ...] instead, which would give the transposed
# version of the pairwise distances you want.
z5 = jnp.sqrt(jnp.sum(jnp.square(x[:, :, None, ...] - y[:, None, ...]), axis=-1))

# Largest absolute disagreement between the two versions (shown as the cell output).
jnp.max(jnp.abs(z4 - z5))

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  return lax_numpy.astype(arr, dtype)


Array(9.536743e-07, dtype=float32)

In [3]:
@jax.jit
def jnp_pairwise_dist(x, y):
    """Pairwise Euclidean distances written with plain ``jax.numpy`` broadcasting.

    A new axis is inserted at position 2 of ``x`` and at position 1 of ``y`` so
    that their difference broadcasts over every (row of ``x``, row of ``y``)
    pair. The last axis is then reduced to an L2 norm.
    """
    x_rows = x[:, :, None, ...]
    y_rows = y[:, None, ...]
    squared_gap = jnp.square(x_rows - y_rows)
    return jnp.sqrt(jnp.sum(squared_gap, axis=-1))

@jax.jit
def ein_pairwise_dist(x, y):
    """Pairwise Euclidean distances expressed as a single eins operation.

    ``y`` is negated and combined with ``x`` by addition, which forms the
    difference. The ``d`` axis, missing from the output pattern, is reduced
    with ``Red.l2_norm``.
    """
    pairwise_l2 = EinsOp('b n1 d, b n2 d -> b n1 n2', combine='add', reduce=Red.l2_norm)
    return pairwise_l2(x, -y)

# Call each jitted function once so tracing and compilation happen here rather
# than inside the timing cells. block_until_ready() waits for JAX's
# asynchronously dispatched computation to actually finish.
d1 = jnp_pairwise_dist(x, y).block_until_ready()
d2 = ein_pairwise_dist(x, y).block_until_ready()

# The two implementations should agree up to float32 round-off.
np.testing.assert_allclose(d1, d2, rtol=1e-5, atol=1e-4)

In [4]:
%%timeit
# block_until_ready() makes the timing cover the computation itself: JAX
# dispatches work asynchronously, so the call may return before the result exists.
jnp_pairwise_dist(x, y).block_until_ready()

22.3 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
%%timeit
# block_until_ready() makes the timing cover the computation itself: JAX
# dispatches work asynchronously, so the call may return before the result exists.
ein_pairwise_dist(x, y).block_until_ready()

250 ms ± 16.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
# Show the jaxpr (the traced sequence of JAX primitives) of the plain jax.numpy version.
jax.make_jaxpr(jnp_pairwise_dist)(x, y)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1024,256,3][39m b[35m:f32[1024,256,3][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[1024,256,256][39m = pjit[
      jaxpr={ [34m[22m[1mlambda [39m[22m[22m; d[35m:f32[1024,256,3][39m e[35m:f32[1024,256,3][39m. [34m[22m[1mlet
          [39m[22m[22mf[35m:f32[1024,256,1,3][39m = broadcast_in_dim[
            broadcast_dimensions=(0, 1, 3)
            shape=(1024, 256, 1, 3)
          ] d
          g[35m:f32[1024,1,256,3][39m = broadcast_in_dim[
            broadcast_dimensions=(0, 2, 3)
            shape=(1024, 1, 256, 3)
          ] e
          h[35m:f32[1024,256,256,3][39m = sub f g
          i[35m:f32[1024,256,256,3][39m = integer_pow[y=2] h
          j[35m:f32[1024,256,256][39m = reduce_sum[axes=(3,)] i
          k[35m:f32[1024,256,256][39m = sqrt j
        [34m[22m[1min [39m[22m[22m(k,) }
      name=jnp_pairwise_dist
    ] a b
  [34m[22m[1min [39m[22m[22m(c,) }

In [7]:
# Show the jaxpr (the traced sequence of JAX primitives) of the eins version.
jax.make_jaxpr(ein_pairwise_dist)(x, y)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1024,256,3][39m b[35m:f32[1024,256,3][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[1024,256,256][39m = pjit[
      jaxpr={ [34m[22m[1mlambda [39m[22m[22m; d[35m:f32[1024,256,3][39m e[35m:f32[1024,256,3][39m. [34m[22m[1mlet
          [39m[22m[22mf[35m:f32[1024,256,3][39m = neg e
          g[35m:f32[1024,256,3,1][39m = broadcast_in_dim[
            broadcast_dimensions=(0, 1, 2)
            shape=(1024, 256, 3, 1)
          ] d
          h[35m:f32[1024,3,256][39m = transpose[permutation=(0, 2, 1)] f
          i[35m:f32[1024,1,3,256][39m = broadcast_in_dim[
            broadcast_dimensions=(0, 2, 3)
            shape=(1024, 1, 3, 256)
          ] h
          j[35m:f32[1024,256,3,256][39m = add g i
          k[35m:f32[1024,256,3,256][39m = integer_pow[y=2] j
          l[35m:f32[1024,256,256][39m = reduce_sum[axes=(2,)] k
          m[35m:f32[1024,256,256][39m = sqrt l
        [34m[22m[1min [39