In [1]:
import jax
import jax.numpy as jnp
import einops as ein


## Reshape

In [18]:
# Fixed seed so the benchmark input is the same on every run.
rng = jax.random.PRNGKey(seed=0)

# Cubic array that the rearrange timings in this section operate on.
x = jax.random.normal(key=rng, shape=(128, 128, 128))

In [19]:
%%timeit
# Swap the first two axes: (a, b, c) -> (b, a, c).
# block_until_ready() is called on the result so the timed statement waits
# for the computation to finish instead of returning after dispatch.
ein.rearrange(x, 'a b c -> b a c').block_until_ready()

108 µs ± 403 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [20]:
%%timeit
# Swap the last two axes: (a, b, c) -> (a, c, b).
ein.rearrange(x, 'a b c -> a c b').block_until_ready()

111 µs ± 2.04 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [21]:
%%timeit
# Rotate the axes so the last one comes first: (a, b, c) -> (c, a, b).
ein.rearrange(x, 'a b c -> c a b').block_until_ready()

120 µs ± 1.83 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [22]:
%%timeit
# Permute and flatten: move c in front of a, merge them into one axis,
# and put b last -> result has shape (c * a, b).
ein.rearrange(x, 'a b c -> (c a) b').block_until_ready()

179 µs ± 536 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [23]:
%%timeit
# Permute and flatten: c becomes the leading axis and (a, b) are merged
# into the trailing axis -> result has shape (c, a * b).
ein.rearrange(x, 'a b c -> c (a b)').block_until_ready()

179 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
%%timeit
# Stack five copies of x on a new leading axis d, then merge d into a with d
# as the outer index (equivalent layout to concatenating the copies).
# NOTE: jnp.stack is part of the timed statement, so this number covers the
# stack plus the rearrange, not the rearrange alone.
ein.rearrange(jnp.stack([x] * 5), 'd a b c -> (d a) b c', d=5).block_until_ready()

346 µs ± 8.17 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
%%timeit
# Same stack as the '(d a)' case, but merged as (a d): a is the outer index
# and d the inner one, so the five copies end up interleaved along the
# merged axis.
# NOTE: jnp.stack is part of the timed statement, so this number covers the
# stack plus the rearrange, not the rearrange alone.
ein.rearrange(jnp.stack([x] * 5), 'd a b c -> (a d) b c', d=5).block_until_ready()

407 µs ± 23.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Indexing vs. Multiply

In [38]:
# Re-seed and replace x with a non-cubic array for the gather benchmarks.
rng = jax.random.PRNGKey(seed=0)

x = jax.random.normal(key=rng, shape=(501, 49, 47))

# Take the max over axis 1 (giving a (501, 47) array), then for every row keep
# the positions along the last axis of its 5 smallest maxima, in no
# particular order -> ii has shape (501, 5).
ii = jnp.argpartition(jnp.max(x, axis=1), 5, axis=-1)[..., :5]

In [41]:
@jax.jit
def normal_take(x, i):
    """Gather entries of ``x`` along its last axis with per-row indices.

    Args:
        x: Rank-3 array, indexed here as ``(a, b, c)``.
        i: Integer array of shape ``(a, k)`` holding, for each leading-axis
            entry, the positions to pick along the last axis of ``x``.

    Returns:
        Array of shape ``(a, b, k)`` where ``out[a, b, k] == x[a, b, i[a, k]]``.
    """
    # Index with the argument `i`, not a module-level array: a global captured
    # under jit is baked in at trace time, so the function would silently
    # ignore whatever indices the caller passes.
    # The inserted middle axis lets the same indices broadcast over axis 1.
    return jnp.take_along_axis(x, i[:, None, :], axis=-1)

# Sanity check: display the shape of the gathered result.
normal_take(x, ii).shape

(501, 49, 5)

In [43]:
%%timeit

# Time the jitted take_along_axis variant, waiting for the result to be ready.
normal_take(x, ii).block_until_ready()

47.3 µs ± 303 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [35]:
import einops as ein

@jax.jit
def oh_ind(x, i):
    """Select entries along the last axis of ``x`` via a one-hot contraction.

    Rows of a boolean identity matrix are looked up with ``i`` to form a
    one-hot mask, which is then contracted against the last axis of ``x``
    with an einsum instead of using a gather.

    Args:
        x: Rank-3 array, labelled ``a b c`` in the einsum pattern.
        i: Integer index array labelled ``a k``; values index the ``c`` axis.

    Returns:
        Array labelled ``a b k``.
    """
    num_choices = x.shape[2]
    # Indexing the identity by `i` yields one boolean one-hot row per index,
    # i.e. an array labelled 'a k c'.
    one_hot = jnp.eye(num_choices, dtype=jnp.bool_)[i]
    # Summing over c against a one-hot row picks out a single element of x.
    return ein.einsum(x, one_hot, 'a b c, a k c -> a b k')

# Sanity check: display the shape of the one-hot/einsum result.
# NOTE(review): with x of shape (501, 49, 47) and ii of shape (501, 5) as
# defined in this notebook, the 'a b k' result should be (501, 49, 5). The
# recorded output shows a leading dimension of 51 and looks like it came from
# an earlier run with a different x -- re-run the notebook top to bottom to
# confirm.
oh_ind(x, ii).shape

(51, 49, 5)

In [37]:
%%timeit

# Time the jitted one-hot + einsum variant, waiting for the result to be ready.
oh_ind(x, ii).block_until_ready()

42 µs ± 619 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
