Some simple illustrations of various calculations of increasing complexity using einsum. Inspired by Tim Rocktäschel's blog post.

In [3]:
import jax.numpy as jnp
import jax
from einops import einsum


def random_array(*shape, num: int = 1) -> jax.Array | list[jax.Array]:
    key = jax.random.split(jax.random.key(0), num=num)
    arr = [jax.random.uniform(key[i], shape=shape) for i in range(num)]
    return arr[0] if num == 1 else arr

In [4]:
A, B = random_array(4, 4, num=2)

# Matrix product: the shared index j is summed out.
# The same contraction spelled three ways: `A @ B`, jnp.einsum and einops.einsum.
assert jnp.allclose(jnp.einsum("ij,jk->ik", A, B), A @ B)
assert jnp.allclose(einsum(A, B, 'i j, j k -> i k'), A @ B)

In [5]:
# Product with the transpose: both operands are contracted over their second axis.
assert jnp.allclose(jnp.einsum('ij,kj->ik', A, B), A @ B.T)

In [6]:
# attention has a term like X W_Q W_K^T X^T - can we do it as a one liner with einsum?
B = 5  # batch size
E = 10 # embedding dim
D = 3  # hidden dim
X = random_array(B, E)
W_Q, W_K = random_array(E, D, num=2)

v1 = X @ W_Q @ W_K.T @ X.T                                # chained matmuls
v2 = X @ W_Q @ (X @ W_K).T                                # queries times transposed keys
v3 = jnp.einsum('bi, ij, kj, lk -> bl', X, W_Q, W_K, X)   # single einsum
# All three spellings must agree (v2 was previously computed but never checked).
assert jnp.allclose(v1, v2)
assert jnp.allclose(v1, v3)

In [7]:
import flax.nnx as nnx
from einops import einsum, rearrange
from functools import partial
from jax import vmap

B = 32  # batch size
T = 512  # seq length
E = 128  # embedding dim
D = 8  # hidden dim
H = 16  # number of heads


X_BTE = random_array(B, T, E)
# Draw the three same-shaped projections in ONE call: separate random_array calls
# with an identical shape reuse the same PRNG key and would return identical
# arrays (i.e. WQ == WK == WV).
WQ_HED, WK_HED, WV_HED = random_array(H, E, D, num=3)
WH_HEE = random_array(H, E, E) # combined value/output matrix across heads - technically full-rank so more expressive

# output projection used by the implementations that take WV_HED
WO_GE = random_array(H * D, E)

@jax.jit
def attn_eins(X_BTE, WQ_HED, WK_HED, WH_HEE):
    """Multi-head attention with a per-head combined value/output matrix, via einops.einsum.

    Args:
        X_BTE: inputs, indexed (batch, time, embedding).
        WQ_HED: per-head query projections, indexed (head, embedding, head_dim).
        WK_HED: per-head key projections, same layout as WQ_HED.
        WH_HEE: per-head (embedding, embedding) matrix applied to the attended inputs.

    Returns:
        Array indexed (batch, time, embedding): attention-weighted inputs pushed
        through WH_HEE and summed over heads.
    """
    # Take the scaling dimension from the weights rather than the notebook-global D:
    # D is rebound by other cells, and jit would freeze whatever value it held at trace time.
    head_dim = WQ_HED.shape[-1]

    Q_BTHD = einsum(X_BTE, WQ_HED, 'b t e, h e d -> b t h d')
    K_BTHD = einsum(X_BTE, WK_HED, 'b t e, h e d -> b t h d')
    A_BHTT = einsum(Q_BTHD, K_BTHD, 'b tq h d, b tk h d -> b h tq tk')
    A_BHTT = nnx.softmax(A_BHTT / jnp.sqrt(head_dim), axis=-1)

    # sum over heads and compute full-rank output
    Y_BTE = einsum(A_BHTT, X_BTE, WH_HEE, 'b h tq tk, b tk e1, h e1 e2 -> b tq e2')
    return Y_BTE

attn_eins(X_BTE, WQ_HED, WK_HED, WH_HEE);

In [8]:
# now trying to do MHA with plenty of vmaps
@jax.jit
def attn_vmap(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE):
    """Multi-head attention written as a 2-D single-head kernel lifted with two vmaps.

    Args:
        X_BTE: inputs, indexed (batch, time, embedding).
        WQ_HED, WK_HED, WV_HED: per-head query/key/value projections,
            indexed (head, embedding, head_dim).
        WO_GE: output projection, indexed (head * head_dim, embedding).

    Returns:
        Array indexed (batch, time, embedding).
    """
    @partial(vmap, in_axes=(None, 0, 0, 0))       # over the head dimension
    @partial(vmap, in_axes=(0, None, None, None)) # over the batch dimension
    def _attn_2d(X_TE, WQ_ED, WK_ED, WV_ED):
        # Single example, single head: plain 2-D matrices.
        Q_TD = X_TE @ WQ_ED
        K_TD = X_TE @ WK_ED
        V_TD = X_TE @ WV_ED

        # Scale by the head size of the weights actually passed in, not the
        # notebook-global D (rebound elsewhere and frozen by jit at trace time).
        head_dim = WQ_ED.shape[-1]
        A_TT = nnx.softmax(Q_TD @ K_TD.T / jnp.sqrt(head_dim), axis=-1)
        H_TD = A_TT @ V_TD
        return H_TD

    # The outer vmap (heads) contributes the leading axis, the inner one (batch) the second.
    H_HBTD = _attn_2d(X_BTE, WQ_HED, WK_HED, WV_HED)
    # Concatenate the heads along the feature axis: G = H * D.
    H_BTG = rearrange(H_HBTD, 'h b t d -> b t (h d)')
    Y_BTE = H_BTG @ WO_GE
    return Y_BTE

attn_vmap(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE);

In [9]:
@jax.jit
def attn_nnx(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE):
    """Multi-head attention that delegates the attention step to nnx.dot_product_attention.

    Args:
        X_BTE: inputs, indexed (batch, time, embedding).
        WQ_HED, WK_HED, WV_HED: per-head query/key/value projections,
            indexed (head, embedding, head_dim).
        WO_GE: output projection, indexed (head * head_dim, embedding).

    Returns:
        Array indexed (batch, time, embedding).
    """
    # The same projection pattern is applied to each of the three weight stacks.
    project = 'b t e, h e d -> b t h d'
    Q_BTHD, K_BTHD, V_BTHD = (
        einsum(X_BTE, W_HED, project) for W_HED in (WQ_HED, WK_HED, WV_HED)
    )
    H_BTHD = nnx.dot_product_attention(Q_BTHD, K_BTHD, V_BTHD)

    # Concatenate the heads (G = H * D) and apply the output projection.
    H_BTG = rearrange(H_BTHD, 'b t h d -> b t (h d)')
    return H_BTG @ WO_GE

attn_nnx(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE);

In [None]:
%%timeit 
attn_nnx(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE)

157 ms ± 1.42 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [121]:
%%timeit 
attn_eins(X_BTE, WQ_HED, WK_HED, WH_HEE)

256 ms ± 4.58 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%%timeit 
attn_vmap(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE)

251 ms ± 7.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
@jax.jit
def attn_fast_chatgpt(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE):
    """Multi-head attention with the per-head projections packed into three large matmuls.

    Args:
        X_BTE: inputs, indexed (batch, time, embedding).
        WQ_HED, WK_HED, WV_HED: per-head query/key/value projections,
            indexed (head, embedding, head_dim).
        WO_GE: output projection, indexed (head * head_dim, embedding).

    Returns:
        Array indexed (batch, time, embedding).
    """
    B, T, E = X_BTE.shape
    H, _, D = WQ_HED.shape

    # Pack heads -> 3 big GEMMs.
    # The head axis has to be moved behind the embedding axis BEFORE flattening:
    # a bare reshape of an (H, E, D) array to (E, H*D) only reinterprets the
    # buffer and mixes up which entries belong to which embedding row and head.
    def _pack(W_HED):
        return jnp.transpose(W_HED, (1, 0, 2)).reshape(E, H * D)

    WQ_p = _pack(WQ_HED)
    WK_p = _pack(WK_HED)
    WV_p = _pack(WV_HED)

    Q = (X_BTE @ WQ_p).reshape(B, T, H, D)  # (B,T,H,D)
    K = (X_BTE @ WK_p).reshape(B, T, H, D)  # (B,T,H,D)
    V = (X_BTE @ WV_p).reshape(B, T, H, D)  # (B,T,H,D)

    # Explicit attention (no lax SDPA)
    scale = 1.0 / jnp.sqrt(D)
    scores = jnp.einsum('b t h d, b s h d -> b h t s', Q, K) * scale   # (B,H,T,S)
    probs  = jax.nn.softmax(scores, axis=-1)                           # (B,H,T,S)
    H_out  = jnp.einsum('b h t s, b s h d -> b t h d', probs, V)       # (B,T,H,D)

    Y_BTE = H_out.reshape(B, T, H * D) @ WO_GE                          # (B,T,E)
    return Y_BTE

In [11]:
# NOTE(review): this repeats the attn_eins definition from an earlier cell; once
# this cell runs, this definition is the one bound to the name.
@jax.jit
def attn_eins(X_BTE, WQ_HED, WK_HED, WH_HEE):
    """Multi-head attention with a per-head combined value/output matrix, via einops.einsum.

    Args:
        X_BTE: inputs, indexed (batch, time, embedding).
        WQ_HED: per-head query projections, indexed (head, embedding, head_dim).
        WK_HED: per-head key projections, same layout as WQ_HED.
        WH_HEE: per-head (embedding, embedding) matrix applied to the attended inputs.

    Returns:
        Array indexed (batch, time, embedding): attention-weighted inputs pushed
        through WH_HEE and summed over heads.
    """
    # Take the scaling dimension from the weights rather than the notebook-global D:
    # D is rebound by other cells, and jit would freeze whatever value it held at trace time.
    head_dim = WQ_HED.shape[-1]

    Q_BTHD = einsum(X_BTE, WQ_HED, 'b t e, h e d -> b t h d')
    K_BTHD = einsum(X_BTE, WK_HED, 'b t e, h e d -> b t h d')
    A_BHTT = einsum(Q_BTHD, K_BTHD, 'b tq h d, b tk h d -> b h tq tk')
    A_BHTT = nnx.softmax(A_BHTT / jnp.sqrt(head_dim), axis=-1)

    # sum over heads and compute full-rank output
    Y_BTE = einsum(A_BHTT, X_BTE, WH_HEE, 'b h tq tk, b tk e1, h e1 e2 -> b tq e2')
    return Y_BTE

attn_eins(X_BTE, WQ_HED, WK_HED, WH_HEE);

@jax.jit
def attn_eins2(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE):
    """Multi-head attention with the head projections packed into (E, H*D) matrices.

    Args:
        X_BTE: inputs, indexed (batch, time, embedding).
        WQ_HED, WK_HED, WV_HED: per-head query/key/value projections,
            indexed (head, embedding, head_dim).
        WO_GE: output projection, indexed (head * head_dim, embedding).

    Returns:
        Array indexed (batch, time, embedding).
    """
    # Read the head count and head size from the weights instead of the notebook
    # globals H and D, which other cells rebind and jit would freeze at trace time.
    num_heads, _, head_dim = WQ_HED.shape

    # G ~ H * D
    WQ_EG = rearrange(WQ_HED, 'h e d -> e (h d)')
    WK_EG = rearrange(WK_HED, 'h e d -> e (h d)')
    WV_EG = rearrange(WV_HED, 'h e d -> e (h d)')

    # One large matmul per projection, then split the packed axis back into heads.
    Q_BTHD = rearrange(X_BTE @ WQ_EG, 'b t (h d) -> b t h d', h=num_heads)
    K_BTHD = rearrange(X_BTE @ WK_EG, 'b t (h d) -> b t h d', h=num_heads)
    V_BTHD = rearrange(X_BTE @ WV_EG, 'b t (h d) -> b t h d', h=num_heads)
    A_BHTT = einsum(Q_BTHD, K_BTHD, 'b tq h d, b tk h d -> b h tq tk')
    A_BHTT = nnx.softmax(A_BHTT / jnp.sqrt(head_dim), axis=-1)

    H_BTHD = einsum(A_BHTT, V_BTHD, 'b h tq tk, b tk h d -> b tq h d')
    H_BTG = rearrange(H_BTHD, 'b t h d -> b t (h d)')
    return H_BTG @ WO_GE

attn_eins2(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE);

@jax.jit
def attn_eins3(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE):
    """Multi-head attention with every contraction spelled as an einops.einsum.

    Args:
        X_BTE: inputs, indexed (batch, time, embedding).
        WQ_HED, WK_HED, WV_HED: per-head query/key/value projections,
            indexed (head, embedding, head_dim).
        WO_GE: output projection, indexed (head * head_dim, embedding).

    Returns:
        Array indexed (batch, time, embedding).
    """
    # Scale by the head size of the weights actually passed in, not the
    # notebook-global D (rebound elsewhere and frozen by jit at trace time).
    head_dim = WQ_HED.shape[-1]

    Q_BTHD = einsum(X_BTE, WQ_HED, 'b t e, h e d -> b t h d')
    K_BTHD = einsum(X_BTE, WK_HED, 'b t e, h e d -> b t h d')
    V_BTHD = einsum(X_BTE, WV_HED, 'b t e, h e d -> b t h d')

    A_BHTT = einsum(Q_BTHD, K_BTHD, 'b tq h d, b tk h d -> b h tq tk')
    A_BHTT = nnx.softmax(A_BHTT / jnp.sqrt(head_dim), axis=-1)

    H_BTHD = einsum(A_BHTT, V_BTHD, 'b h tq tk, b tk h d -> b tq h d')
    H_BTG = rearrange(H_BTHD, 'b t h d -> b t (h d)')
    return H_BTG @ WO_GE

attn_eins3(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE);

In [88]:
%%timeit 
attn_eins(X_BTE, WQ_HED, WK_HED, WH_HEE)

261 ms ± 1.92 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
%%timeit 
attn_eins2(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE)

134 ms ± 1.42 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [147]:
%%timeit 
attn_eins3(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE)

140 ms ± 2.81 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%%timeit 
attn_fast_chatgpt(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE)

134 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%%timeit 
attn_nnx(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE)

158 ms ± 1.37 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [94]:
# testing batched GEMMs
# Draw both operands in one call: two separate random_array calls with the same
# shape reuse the same PRNG key and would return identical arrays.
M1_ABCD, M2_ABCD = random_array(50, 50, 50, 50, num=2)

In [107]:
@jax.jit
def mult_1(m1, m2):
    """Batched matrix product with the two batch axes leading and the matrix axes trailing."""
    # (a, b) are batch axes; j is the contracted axis.
    pattern = 'a b i j, a b j k -> a b i k'
    return einsum(m1, m2, pattern)

@jax.jit
def mult_2(m1, m2):
    """Batched matrix product with the matrix axes leading and the two batch axes trailing."""
    # (a, b) are batch axes; j is the contracted axis.
    pattern = 'i j a b, j k a b -> i k a b'
    return einsum(m1, m2, pattern)

In [108]:
%%timeit
a = mult_1(M1_ABCD, M2_ABCD)

8.25 ms ± 33.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [109]:
%%timeit
a = mult_2(M1_ABCD, M2_ABCD)

17.7 ms ± 60 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
from torch.nn import functional as F 
import torch

In [29]:
# torch comparison
# Hand each JAX array to torch through the DLPack protocol.
x_torch, wq_torch, wk_torch, wv_torch, wo_torch = map(
    torch.from_dlpack, (X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE)
)

In [30]:
def attn_torch(X_BTE, WQ_HED, WK_HED, WV_HED, WO_GE):
    """Multi-head attention that delegates the attention step to F.scaled_dot_product_attention.

    Args:
        X_BTE: inputs, indexed (batch, time, embedding).
        WQ_HED, WK_HED, WV_HED: per-head query/key/value projections,
            indexed (head, embedding, head_dim).
        WO_GE: output projection, indexed (head * head_dim, embedding).

    Returns:
        Tensor indexed (batch, time, embedding).
    """
    # Project straight into the (batch, head, time, head_dim) layout passed to the torch kernel.
    project = 'b t e, h e d -> b h t d'
    Q_BHTD, K_BHTD, V_BHTD = (
        einsum(X_BTE, W_HED, project) for W_HED in (WQ_HED, WK_HED, WV_HED)
    )
    H_BHTD = F.scaled_dot_product_attention(Q_BHTD, K_BHTD, V_BHTD)

    # Concatenate the heads (G = H * D) and apply the output projection.
    H_BTG = rearrange(H_BHTD, 'b h t d -> b t (h d)')
    return H_BTG @ WO_GE

attn_torch(x_torch, wq_torch, wk_torch, wv_torch, wo_torch);

In [31]:
%%timeit
# NOTE(review): eager PyTorch call - if these tensors live on an accelerator, a device
# synchronisation is needed for the timing to be meaningful; confirm they are on CPU.
attn_torch(x_torch, wq_torch, wk_torch, wv_torch, wo_torch)

34.9 ms ± 1.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
