Given $q \in \mathbb{R}^d, K \in \mathbb{R}^{s \times d}, V \in \mathbb{R}^{s \times d}$, with $K_i$ and $V_i$ denoting the $i$-th rows of $K$ and $V$, we define
$$\alpha_i = q^TK_i$$
$$\text{Attn}(q, K, V) = \frac{e^{\alpha_1} V_1 + \cdots + e^{\alpha_s} V_s}{e^{\alpha_1} + \cdots + e^{\alpha_s}}$$
We can compute $\text{Attn}(q, K, V)$ iteratively, accumulating the numerator and the denominator one key/value row at a time:

In [74]:
import numpy as np
d = 5
s = 7

q = np.random.random(d)
K = np.random.random((s, d))
V = np.random.random((s, d))

# Stream over the (key, value) rows, accumulating the numerator and the
# denominator of Attn(q, K, V) directly (no max-subtraction yet).
num = 0
den = 0
for k, v in zip(K, V):
    alpha = (q * k).sum()   # alpha_i = q^T K_i
    weight = np.exp(alpha)  # e^{alpha_i}
    num += weight * v       # adds e^{alpha_i} V_i
    den += weight           # adds e^{alpha_i}

#   e^{alpha_1} V_1 + ... + e^{alpha_s} V_s
#   ---------------------------------------
#      e^{alpha_1} + ... + e^{alpha_s}
attn_output = num / den

This is equivalent to Blockwise Parallel attention applied to a single query vector, with a chunk size of 1 and without the $\max_i$ term. We can check that `attn_output` equals the traditional softmax computation of $\text{Attn}(q, K, V)$:

In [75]:
from scipy.special import softmax

# Reference: ordinary softmax attention for the single query q.
attn_weights = softmax(np.einsum('d,sd -> s', q, K), -1)  # softmax of q^T K_i over the s keys
attn_output2 = np.einsum('s,sd -> d', attn_weights, V)  # sum_i attn_weights_i V_i
np.allclose(attn_output, attn_output2)

True

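Next we introduce $\max_i$, the running maximum of the scores seen so far, to improve floating-point stability: subtracting it before exponentiating keeps every exponent $\le 0$, so $e^{\alpha - \max_i} \le 1$ can never overflow. In order to show the math more clearly, we will walk through the first two steps of the iteration, using numeric suffixes to prevent variables from shadowing each other.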

In [89]:
# Accumulator state before any chunk has been seen: empty sums, and a
# running max of -inf so the first real score always replaces it.
num0, den0 = 0, 0
max0_i = -np.inf

This is the logic that we perform in the first iteration of the loop, comparing the query vector with the first (size-1) chunk of the key matrix.

In [77]:
# First iteration: fold in the first (size-1) key/value chunk.
v0 = V[0]
k0 = K[0]
alpha0 = (q * k0).sum()  # score alpha_0 = q . k_0
# Running max after one score; max0_i was initialised to -np.inf, so this is alpha0.
max1_i = max(alpha0, max0_i)
# Rescale the old numerator to the new max, then add the new shifted term.
# Here np.exp(max0_i - max1_i) is exp(-inf) == 0, so num0 contributes nothing.
num1 = num0 * np.exp(max0_i - max1_i) + np.exp(alpha0 - max1_i) * v0
# Exact elementwise check against the simplified form.
num1 == np.exp(alpha0 - max1_i) * v0

array([ True,  True,  True,  True,  True])

Since $\texttt{num0} = 0$ (and $e^{\max0_i - \max1_i} = e^{-\infty} = 0$), the first term vanishes and
$\texttt{num1} = e^{\alpha_0 - \max1_i}v_0$


In [78]:
# Denominator update for the first iteration, mirroring num1.
den1 = den0 * np.exp(max0_i - max1_i) + np.exp(alpha0 - max1_i)
# den0 is 0, so only the new shifted term remains.
den1 == np.exp(alpha0 - max1_i)

True

Similarly, since $\texttt{den0} = 0$, $\texttt{den1} = e^{\alpha_0 - \max1_i}$

In [79]:
# Second iteration: fold in the second key/value chunk.
v1 = V[1]
k1 = K[1]
alpha1 = (q * k1).sum()  # score alpha_1 = q . k_1
# Running max after two scores.
# NOTE(review): named max_i2 rather than max2_i, unlike max0_i / max1_i;
# later cells reference max_i2, so it is left as is.
max_i2 = max(alpha1, max1_i)
# Rescale num1 from the old max (max1_i) to the new one (max_i2), then add the new term.
num2 = num1 * np.exp(max1_i - max_i2) + np.exp(alpha1 - max_i2) * v1

Substituting in the value of $\texttt{num1}$, and writing $\max2_i$ for the running max after two steps (`max_i2` in the code), we get:
$$\texttt{num2} = e^{\alpha_0 - \max1_i}v_0 \times e^{\max1_i - \max2_i} + e^{\alpha_1 - \max2_i}v_1$$
And simplifying the exponents:
$$ = e^{\alpha_0 - \max2_i}v_0 + e^{\alpha_1 - \max2_i}v_1$$

In [82]:
# Check the expanded form of num2: both terms now share the same shift, max_i2.
# This uses exact == rather than np.allclose. That works here because num1 had a
# single term: either max_i2 == max1_i (rescale factor exp(0) == 1), or
# max_i2 == alpha1 > alpha0 == max1_i (then num1 == exp(0) * v0 == v0). In both
# cases the two sides perform the same roundings.
np.all(
    num2 == np.exp(alpha0 - max_i2) * v0 + np.exp(alpha1 - max_i2) * v1
)

True

In [86]:
# Denominator update for the second iteration, mirroring num2.
den2 = den1 * np.exp(max1_i - max_i2) + np.exp(alpha1 - max_i2)
# Expanded form: both exponentials are shifted by the same max, max_i2.
den2 == np.exp(alpha0 - max_i2) + np.exp(alpha1 - max_i2)

True

This second equality comes from the fact that when we substitute in the value of $\texttt{den1}$, we get (with $\max2_i$ denoting `max_i2`, the running max after two steps):
$$\texttt{den2} = e^{\alpha_0 - \max1_i} \times e^{\max1_i - \max2_i} + e^{\alpha_1 - \max2_i}$$
And again simplifying the exponents:
$$ = e^{\alpha_0 - \max2_i} + e^{\alpha_1 - \max2_i}$$

When we take the ratio of $\texttt{num2}$ and $\texttt{den2}$, the common factor $e^{-\max2_i}$ (where $\max2_i$ is `max_i2`) cancels out of the numerator and the denominator:
 $$\frac{\text{num2}}{\text{den2}} 
 = \frac{e^{\alpha_0 - \max2_i}v_0 + e^{\alpha_1 - \max2_i}v_1}{e^{\alpha_0 - \max2_i} + e^{\alpha_1 - \max2_i}}
 = \frac{e^{\alpha_0} v_0 + e^{\alpha_1} v_1}{e^{\alpha_0} + e^{\alpha_1}}$$

In [87]:
# Reference: the same two-term attention without the max shift, i.e. the
# e^{alpha}-weighted average of v0 and v1.
np.allclose(
    num2 / den2,
    np.average([v0, v1], axis=0, weights=np.exp([alpha0, alpha1])),
)

True

The expression above is exactly traditional softmax attention over the two keys seen so far. We can confirm the equivalence as follows:

In [88]:
alpha = np.array([alpha0, alpha1])  # the two scores seen so far
v = np.array([v0, v1])              # the matching value rows, shape (2, d)
# softmax(alpha)^T [v0, v1]: weight each value row and sum over the rows.
np.allclose(num2 / den2, softmax(alpha) @ v)

True

 $$\frac{\text{num2}}{\text{den2}} = \text{softmax}(\alpha_0, \alpha_1)^T[ v_0, v_1 ]$$

 Hopefully this suffices to show the equivalence between Blockwise Parallel and traditional Attn(Q,K,V) computation. We now add the following logic:
 - iteration over chunks of the query matrix
 - chunk size > 1
 - a batch dimension which might also include the head dimension

We also put this logic into a loop:

In [91]:
# Blocked setup: the sequence is split into n chunks of s tokens each.
# (The loop below initialises its own accumulators, so no num0/den0/max state is needed here.)
n = 3  # number of chunks
b = 2  # batch dimension (could also include head dimension, since heads are parallel for self-attention)
s = 7  # tokens per chunk
d = 5  # feature dimension
# Q, K, V are laid out as (n, b, s, d): chunk, batch, position within chunk, feature.
Q = np.random.random((n, b, s, d))
K = np.random.random((n, b, s, d))
V = np.random.random((n, b, s, d))

In [98]:
# Blockwise attention: for each query block, stream over every key/value block
# while keeping a running max, numerator and denominator per query row.
attn_outputs = []

for q in Q:  # q: one query block, (b, s, d)
    assert q.shape == (b, s, d)
    num = np.zeros((b, s, d))         # running sum of e^{alpha - max_i} v
    den = np.zeros((b, s))            # running sum of e^{alpha - max_i}
    max_i = np.full((b, s), -np.inf)  # running max of alpha per query row

    for k, v in zip(K, V):  # k, v: one key/value block, (b, s, d)
        assert k.shape == (b, s, d)
        assert v.shape == (b, s, d)
        alpha = np.einsum('bqd,bkd -> bqk', q, k)  # q^T K, (b, s, s)
        prev = max_i
        max_i = np.maximum(alpha.max(-1), max_i)  # update max_i
        # Factor that moves the old sums onto the new max; it is exp(-inf) == 0
        # on the first block, because prev starts at -inf.
        scale = np.exp(prev - max_i)
        weights = np.exp(alpha - max_i[..., None])  # e^{alpha - max_i}, computed once per step

        # update numerator and denominator
        num = num * scale[..., None] + np.einsum('bqk,bkd -> bqd', weights, v)
        den = den * scale + weights.sum(-1)

    attn_outputs.append(num / den[..., None])

attn_outputs = np.stack(attn_outputs)  # (n, b, s, d)
attn_outputs.shape

(3, 2, 7, 5)

We can now compare this to a traditional Attn(Q,K,V) computation to verify that the two are equivalent:

In [97]:
# Reference: ordinary (non-blocked) attention over the full sequence of n*s tokens.
# Move the chunk axis next to the position axis, then merge: (n, b, s, d) -> (b, n*s, d).
Q1 = Q.transpose([1, 0, 2, 3]).reshape(b, -1, d)
K1 = K.transpose([1, 0, 2, 3]).reshape(b, -1, d)
V1 = V.transpose([1, 0, 2, 3]).reshape(b, -1, d)
attn_weights = softmax(np.einsum('bqd,bkd -> bqk', Q1, K1), -1)  # softmax of Q K^T over keys
assert attn_weights.shape == (b, s * n, s * n)
attn_outputs2 = np.einsum('bqk,bkd -> bqd', attn_weights, V1)  # softmax(Q K^T) V
# Merge the blockwise result the same way, under a new name, so re-running this
# cell does not try to transpose an already-merged 3-D array.
attn_outputs_merged = attn_outputs.transpose(1, 0, 2, 3).reshape(b, n * s, d)
np.allclose(attn_outputs_merged, attn_outputs2)

True

In [103]:
# Weights of the 2-layer feedforward network applied after attention.
w1 = np.random.random((d, d))
w2 = np.random.random((d, d))


def relu(x):
    """Return the elementwise maximum of 0 and x."""
    return np.maximum(0, x)

In [105]:
# Blockwise attention followed by a 2-layer FFN and a residual connection,
# computed independently for each query block.
attn_outputs = []   # per-block attention outputs, each (b, s, d)
block_outputs = []  # per-block outputs after FFN + residual, each (b, s, d)

for q in Q:  # q: one query block, (b, s, d)
    assert q.shape == (b, s, d)
    num = np.zeros((b, s, d))         # running sum of e^{alpha - max_i} v
    den = np.zeros((b, s))            # running sum of e^{alpha - max_i}
    max_i = np.full((b, s), -np.inf)  # running max of alpha per query row

    for k, v in zip(K, V):  # k, v: one key/value block, (b, s, d)
        assert k.shape == (b, s, d)
        assert v.shape == (b, s, d)
        alpha = np.einsum('bqd,bkd -> bqk', q, k)  # q^T K, (b, s, s)
        prev = max_i
        max_i = np.maximum(alpha.max(-1), max_i)  # update max_i
        # Factor that moves the old sums onto the new max; 0 on the first block.
        scale = np.exp(prev - max_i)
        weights = np.exp(alpha - max_i[..., None])  # e^{alpha - max_i}, computed once per step

        # update numerator and denominator
        num = num * scale[..., None] + np.einsum('bqk,bkd -> bqd', weights, v)
        den = den * scale + weights.sum(-1)

    chunk_attn_output = num / den[..., None]

    ################## NEW CODE ##################

    # 2-layer feedforward network
    resid_attn = np.einsum('bqd,dw -> bqw', chunk_attn_output, w1)
    resid_attn = relu(resid_attn)
    resid_attn = np.einsum('bqd,dw -> bqw', resid_attn, w2)

    # residual connection (uses this block's own queries)
    chunk = chunk_attn_output + resid_attn + q

    ################ END NEW CODE ################

    # Keep both the attention-only output and the full block output, so that
    # `chunk` is actually collected rather than computed and thrown away.
    attn_outputs.append(chunk_attn_output)
    block_outputs.append(chunk)

attn_outputs = np.stack(attn_outputs)    # (n, b, s, d)
block_outputs = np.stack(block_outputs)  # (n, b, s, d)
attn_outputs.shape

(3, 2, 7, 5)

In [106]:
# Reference: full-sequence attention followed by the same FFN + residual.
attn_weights = softmax(np.einsum('bqd,bkd -> bqk', Q1, K1), -1)  # softmax of Q K^T over keys
assert attn_weights.shape == (b, s * n, s * n)
attn_outputs2 = np.einsum('bqk,bkd -> bqd', attn_weights, V1)  # softmax(Q K^T) V

# 2-layer feedforward network
resid_attn = np.einsum('bqd,dw -> bqw', attn_outputs2, w1)
resid_attn = relu(resid_attn)
resid_attn = np.einsum('bqd,dw -> bqw', resid_attn, w2)

# residual connection
chunk = attn_outputs2 + resid_attn + Q1

# Blockwise side: attn_outputs holds per-block attention outputs, (n, b, s, d).
# Apply the same FFN + residual to every block, each with its own query block,
# so the FFN/residual result is compared too, not just attention.
block_ffn = np.einsum(
    'nbqd,dw -> nbqw', relu(np.einsum('nbqd,dw -> nbqw', attn_outputs, w1)), w2
)
block_chunks = attn_outputs + block_ffn + Q

# Merge blocks for comparison, under new names so re-running this cell is safe.
attn_outputs_merged = attn_outputs.transpose(1, 0, 2, 3).reshape(b, n * s, d)
chunks_merged = block_chunks.transpose(1, 0, 2, 3).reshape(b, n * s, d)
np.allclose(attn_outputs_merged, attn_outputs2) and np.allclose(chunks_merged, chunk)

True