Some simple illustrations of calculations of increasing complexity using `einsum`, inspired by Tim Rocktäschel's blog post "Einsum is All you Need — Einstein Summation in NumPy".

In [15]:
import jax.numpy as jnp
import jax


def random_array(*shape, num: int = 1) -> jax.Array | list[jax.Array]:
    key = jax.random.split(jax.random.key(0), num=num)
    arr = [jax.random.uniform(key[i], shape=shape) for i in range(num)]
    return arr[0] if num == 1 else arr

In [16]:
# Plain matrix product: "ij,jk->ik" sums over the index j shared by both operands.
A, B = random_array(4, 4, num=2)
AB_matmul = A @ B
AB_einsum = jnp.einsum("ij,jk->ik", A, B)
assert jnp.allclose(AB_einsum, AB_matmul)

In [17]:
# No explicit transpose is needed with einsum: "ij,kj->ik" contracts the
# second index of both operands, which is exactly A @ B.T.
ABt_matmul = A @ B.T
ABt_einsum = jnp.einsum('ij,kj->ik', A, B)
assert jnp.allclose(ABt_einsum, ABt_matmul)

In [19]:
# Attention scores contain the bilinear term X W_Q W_K^T X^T.
# Can einsum compute it as a one-liner?
# (Use T rather than B for the row count: the rows of X are tokens, and reusing
# B would clobber the matrix B from the cells above.)
T = 5   # sequence length (number of tokens = rows of X)
E = 10  # embedding dim
D = 3   # hidden (query/key) dim
X = random_array(T, E)
W_Q, W_K = random_array(E, D, num=2)

v1 = X @ W_Q @ W_K.T @ X.T  # chained matmuls, left to right
v2 = X @ W_Q @ (X @ W_K).T  # same product grouped as Q @ K^T
# t indexes query tokens and s indexes key tokens; i and k (embedding) and
# j (hidden) are summed out.
v3 = jnp.einsum('ti, ij, kj, sk -> ts', X, W_Q, W_K, X)
assert v3.shape == (T, T)
assert jnp.allclose(v1, v3)
assert jnp.allclose(v2, v3)

In [20]:
import flax.nnx as nnx

# Multi-head self-attention written with einsum. Name suffixes give each array's
# axis order: B=batch, T=query position, S=key position, E=embedding,
# D=per-head hidden dim, H=head.
B = 3  # batch size
T = 4  # seq length
E = 5  # embedding dim
D = 6  # hidden (query/key) dim per head
H = 7  # number of heads

X_BTE = random_array(B, T, E)
# Draw W_Q and W_K in a single call: random_array uses a fixed seed, so two
# separate calls with the same shape would return identical matrices (Q == K).
WQ_HED, WK_HED = random_array(H, E, D, num=2)
# Per-head E x E matrix applied to the attended embeddings (plays the role of
# the combined value/output projection W_V W_O).
WH_HEE = random_array(H, E, E)

Q_HBTD = jnp.einsum("bti, hid -> hbtd", X_BTE, WQ_HED)
K_HBTD = jnp.einsum("bti, hid -> hbtd", X_BTE, WK_HED)
# Scores: within each head h and sequence b, dot every query position t with
# every key position s over the hidden dim d.
S_HBTS = jnp.einsum("hbtd, hbsd -> hbts", Q_HBTD, K_HBTD)
# Scale by 1/sqrt(D), then softmax over the key axis s so that each query's
# weights sum to 1. No causal mask is applied: every position sees every other.
A_HBTS = nnx.softmax(S_HBTS / jnp.sqrt(D), axis=-1)
# Y[b, t] = sum_h sum_s A[h, b, t, s] * (X[b, s] @ WH[h])
Y_BTE = jnp.einsum("hbts, bse, hef -> btf", A_HBTS, X_BTE, WH_HEE)

# Is it actually correct? Check against an explicit per-head, per-sequence loop.
Y_ref = jnp.zeros((B, T, E))
for h in range(H):
    for b in range(B):
        q_TD = X_BTE[b] @ WQ_HED[h]
        k_TD = X_BTE[b] @ WK_HED[h]
        attn_TT = nnx.softmax(q_TD @ k_TD.T / jnp.sqrt(D), axis=-1)
        Y_ref = Y_ref.at[b].add(attn_TT @ X_BTE[b] @ WH_HEE[h])

assert A_HBTS.shape == (H, B, T, T)
assert jnp.allclose(A_HBTS.sum(axis=-1), 1.0, atol=1e-5)
assert Y_BTE.shape == (B, T, E)
assert jnp.allclose(Y_BTE, Y_ref, atol=1e-4)