Given $q \in \mathbb{R}^d$, $K \in \mathbb{R}^{s \times d}$, and $V \in \mathbb{R}^{s \times d}$, with $K_i$ and $V_i$ denoting the $i$-th rows of $K$ and $V$, we define
$$\alpha_i = q^TK_i$$
$$\text{Attn}(q, K, V) = \frac{e^{\alpha_1} V_1 + \cdots + e^{\alpha_s} V_s}{e^{\alpha_1} + \cdots + e^{\alpha_s}}$$
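Dividing each term of the numerator by the denominator shows that this is the familiar softmax-weighted average of the rows of $V$:
$$\text{Attn}(q, K, V) = \sum_{i=1}^{s} \text{softmax}(\alpha)_i \, V_i, \qquad \text{softmax}(\alpha)_i = \frac{e^{\alpha_i}}{\sum_{j=1}^{s} e^{\alpha_j}}$$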
We can compute $\text{Attn}(q, K, V)$ iteratively:

In [37]:
import numpy as np
d = 5
s = 7

np.random.seed(0)
q = np.random.random(d)
K = np.random.random((s, d))
V = np.random.random((s, d))

num = 0
den = 0

for k, v in zip(K, V):
  alpha = (q * k).sum()
  num += np.exp(alpha) * v  # e^{alpha_i}V_i
  den += np.exp(alpha)  # e^{alpha_i}

attn_output = num / den
# e^{alpha_1}V_1 + ... + e^{alpha_s}V_s
# -------------------------------------
#    e^{alpha_1} + ... + e^{alpha_s}

This is equivalent to Blockwise Parallel applied to a single query vector, with a chunk size of 1 and without the $\max_i$ term. We can check that `attn_output` equals the result of the traditional $\text{Attn}(q, K, V)$ computation:

In [38]:
from scipy.special import softmax

attn_weights = softmax(np.einsum('d,sd -> s', q, K), -1)  # softmax over the scores q^T K_i
attn_output2 = np.einsum('s,sd -> d', attn_weights, V)  # weighted sum of the rows of V
assert np.allclose(attn_output, attn_output2)
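Before adding the max term, it helps to see the failure it prevents. The sketch below uses made-up scores (an assumption, chosen only to be large) for which $e^{\alpha_i}$ overflows float64. The unshifted computation from the loop above yields `nan`, while shifting every exponent by the maximum score gives the correct weights.

```python
import numpy as np

# Made-up scores, large enough that e^{score} overflows float64 (exp overflows above ~709.78)
scores = np.array([1000.0, 1001.0, 1002.0])
values = np.eye(3)  # one-hot rows, so the output equals the attention weights

with np.errstate(over="ignore", invalid="ignore"):
    # Unshifted: e^{1000} = inf, and inf * 0 and inf / inf produce nan
    naive = (np.exp(scores)[:, None] * values).sum(0) / np.exp(scores).sum()

# Shifted by the maximum score: every exponent is <= 0, so nothing overflows
shift = scores.max()
stable = (np.exp(scores - shift)[:, None] * values).sum(0) / np.exp(scores - shift).sum()

print(naive)   # [nan nan nan]
print(stable)  # approximately [0.09 0.245 0.665]
```

The shift does not change the result, because the common factor $e^{-\text{shift}}$ cancels between numerator and denominator.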

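Next we introduce $\max_i$ to improve floating-point stability: subtracting the running maximum score from every exponent keeps $e^{\alpha_i - \max_i} \le 1$, so the exponentials cannot overflow. To show the math more clearly, we will walk through the first two steps of the iteration, using numeric suffixes to prevent variables from shadowing each other.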

In [39]:
num0 = 0
den0 = 0
max0_i = -np.inf

This is the logic that we perform in the first iteration of the loop, comparing the query vector with the first (size-1) chunk of the key matrix.

In [40]:
v0 = V[0]
k0 = K[0]
alpha0 = (q * k0).sum()
max1_i = max(alpha0, max0_i)
num1 = num0 * np.exp(max0_i - max1_i) + np.exp(alpha0 - max1_i) * v0
assert np.all(num1 == np.exp(alpha0 - max1_i) * v0)

Since $\texttt{num0} = 0$, $\texttt{num1} = e^{\alpha_0 - \max1_i}v_0$.


In [41]:
den1 = den0 * np.exp(max0_i - max1_i) + np.exp(alpha0 - max1_i)
assert den1 == np.exp(alpha0 - max1_i)

Similarly, since $\texttt{den0} = 0$, $\texttt{den1} = e^{\alpha_0 - \max1_i}$.
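Both updates are instances of one recurrence, which the next step applies again. Writing $m_j$ for $\max j_i$ (so $m_0 = -\infty$ and $\texttt{num}_0 = \texttt{den}_0 = 0$), step $j+1$ processes key $k_j$ with score $\alpha_j$:
$$m_{j+1} = \max(m_j, \alpha_j)$$
$$\texttt{num}_{j+1} = \texttt{num}_j \, e^{m_j - m_{j+1}} + e^{\alpha_j - m_{j+1}} v_j$$
$$\texttt{den}_{j+1} = \texttt{den}_j \, e^{m_j - m_{j+1}} + e^{\alpha_j - m_{j+1}}$$
Since $m_{j+1} \ge m_j$, the factor $e^{m_j - m_{j+1}} \le 1$ rescales the running sums whenever the maximum increases.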

In [42]:
v1 = V[1]
k1 = K[1]
alpha1 = (q * k1).sum()
max_i2 = max(alpha1, max1_i)
num2 = num1 * np.exp(max1_i - max_i2) + np.exp(alpha1 - max_i2) * v1

Substituting in the value of $\texttt{num1}$, and writing $\max2_i$ for `max_i2`, we get:
$$\texttt{num2} = e^{\alpha_0 - \max1_i}v_0 \times e^{\max1_i - \max2_i} + e^{\alpha_1 - \max2_i}v_1$$
And simplifying the exponents:
$$ = e^{\alpha_0 - \max2_i}v_0 + e^{\alpha_1 - \max2_i}v_1$$

In [43]:
assert np.all(
    num2 == np.exp(alpha0 - max_i2) * v0 + np.exp(alpha1 - max_i2) * v1
)

In [44]:
den2 = den1 * np.exp(max1_i - max_i2) + np.exp(alpha1 - max_i2)
assert den2 == np.exp(alpha0 - max_i2) + np.exp(alpha1 - max_i2)

This second equality holds because, when we substitute in the value of $\texttt{den1}$, we get:
$$\texttt{den2} = e^{\alpha_0 - \max1_i} \times e^{\max1_i - \max2_i} + e^{\alpha_1 - \max2_i}$$
And again simplifying the exponents:
$$ = e^{\alpha_0 - \max2_i} + e^{\alpha_1 - \max2_i}$$

When we take the ratio of $\texttt{num2}$ and $\texttt{den2}$, the common factor $e^{-\max2_i}$ cancels out:
 $$\frac{\texttt{num2}}{\texttt{den2}} 
 = \frac{e^{\alpha_0 - \max2_i}v_0 + e^{\alpha_1 - \max2_i}v_1}{e^{\alpha_0 - \max2_i} + e^{\alpha_1 - \max2_i}}
 = \frac{e^{\alpha_0} v_0 + e^{\alpha_1} v_1}{e^{\alpha_0} + e^{\alpha_1}}$$

In [45]:
assert np.allclose(
    num2 / den2,
    (np.exp(alpha0) * v0 + np.exp(alpha1) * v1) / (np.exp(alpha0) + np.exp(alpha1))
)

This expression is the traditional attention formula restricted to the first two keys, i.e. a softmax over $(\alpha_0, \alpha_1)$ used to weight $(v_0, v_1)$. We can confirm the equivalence as follows:

In [46]:
alpha = np.array([alpha0, alpha1])
v = np.array([v0, v1])
assert np.allclose(
    num2 / den2,
    (softmax(alpha)[..., None] * v).sum(0)
)

$$\frac{\texttt{num2}}{\texttt{den2}} = \text{softmax}(\alpha_0, \alpha_1)^T \begin{bmatrix} v_0 \\ v_1 \end{bmatrix}$$
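The two-step pattern extends by induction. After $j$ steps, $\texttt{num}_j = \sum_{i<j} e^{\alpha_i - \max j_i} v_i$ and $\texttt{den}_j = \sum_{i<j} e^{\alpha_i - \max j_i}$, so the common factor again cancels in the ratio. Below is a minimal sketch of the full single-query loop built on that invariant. The function names `online_attention` and `reference_attention` are hypothetical, and the reference uses the usual max-shifted softmax.

```python
import numpy as np

def online_attention(q_vec, keys, values):
    """Single-query attention with chunk size 1 and a running max (hypothetical helper)."""
    num = np.zeros(values.shape[1])
    den = 0.0
    running_max = -np.inf
    for k, v in zip(keys, values):
        alpha = float(q_vec @ k)
        new_max = max(running_max, alpha)
        scale = np.exp(running_max - new_max)  # e^{max_j - max_{j+1}}; 0.0 on the first step
        num = num * scale + np.exp(alpha - new_max) * v
        den = den * scale + np.exp(alpha - new_max)
        running_max = new_max
    return num / den

def reference_attention(q_vec, keys, values):
    """Direct softmax(K q)^T V, shifted by the max score for stability."""
    scores = keys @ q_vec
    weights = np.exp(scores - scores.max())
    return (weights / weights.sum()) @ values

rng = np.random.default_rng(0)
q_vec = rng.random(5)
K_mat = rng.random((7, 5))
V_mat = rng.random((7, 5))

print(np.allclose(online_attention(q_vec, K_mat, V_mat), reference_attention(q_vec, K_mat, V_mat)))  # True

# Scores in the thousands: unshifted e^alpha would overflow, the running max keeps this finite
big = online_attention(1000 * q_vec, K_mat, V_mat)
print(np.isfinite(big).all())  # True
```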

This simple example is intended to clarify the equivalence between Blockwise Parallel and the traditional $\text{Attn}(Q, K, V)$ computation. For the next version, we add the following logic:
 - chunk-size > 1
 - iteration over a query matrix
 - a batch dimension

We also put this logic into a loop:

In [47]:
n = 3  # number of chunks
b = 2  # batch dimension (could also include a head dimension, since heads are computed independently in self-attention)
s = 7  # chunk size: tokens per chunk, so the full sequence length is n * s
d = 5
Q = np.random.random((n, b, s, d))  # chunk-major layout: (chunk, batch, token, feature)
K = np.random.random((n, b, s, d))
V = np.random.random((n, b, s, d))

In [48]:
attn_outputs = []

q: np.ndarray
for i, q in enumerate(Q):
  assert list(q.shape) == [b, s, d]
  num = np.zeros((b,s,d))  # initialize numerator
  den = np.zeros((b,s))  # initialize denominator
  max_i = -np.inf * np.ones((b, s))  # initialize max_i

  k: np.ndarray
  v: np.ndarray
  for j, (k, v) in enumerate(zip(K, V)):
    assert list(k.shape) == [b, s, d]
    assert list(v.shape) == [b, s, d]
    alpha: np.ndarray = np.einsum('bqd,bkd -> bqk', q, k)  # q^T K
    prev = max_i
    max_i = np.maximum(alpha.max(-1), max_i)  # update max_i
    exp_values = np.einsum('bqk,bkd -> bqd', np.exp(alpha - max_i[..., None]), v)  # e^{alpha - max_i}^T v

    # update numerator and denominator
    num = num * np.exp(prev - max_i)[..., None] + exp_values  
    den = den * np.exp(prev - max_i) + np.exp(alpha - max_i[..., None]).sum(-1)

  attn_outputs.append(num / den[..., None])

attn_outputs = np.stack(attn_outputs)
attn_outputs.shape

(3, 2, 7, 5)
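The arrays above are stored chunk-major, with shape `(n, b, s, d)`, while a full-sequence reference computation wants `(b, n * s, d)`. Here is a small sketch of the merge used in the comparison below and its inverse, on hypothetical random data:

```python
import numpy as np

n, b, s, d = 3, 2, 7, 5
rng = np.random.default_rng(3)
chunked = rng.random((n, b, s, d))  # chunk-major: (chunk, batch, token-in-chunk, feature)

# Merge chunks: move batch first, then fold (chunk, token-in-chunk) into one axis of length n * s
merged = chunked.transpose(1, 0, 2, 3).reshape(b, n * s, d)

# Token t of chunk c lands at position c * s + t of the merged sequence
assert np.array_equal(merged[1, 2 * s + 4], chunked[2, 1, 4])

# Inverse: split the sequence axis back into (chunk, token) and restore chunk-major order
restored = merged.reshape(b, n, s, d).transpose(1, 0, 2, 3)
print(restored.shape)  # (3, 2, 7, 5)
```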

We can now compare this to a traditional Attn(Q,K,V) computation to verify that the two are equivalent:

In [49]:
Q1 = Q.transpose([1, 0, 2, 3]).reshape(b, -1, d)
K1 = K.transpose([1, 0, 2, 3]).reshape(b, -1, d)
V1 = V.transpose([1, 0, 2, 3]).reshape(b, -1, d)
attn_weights: np.ndarray = softmax(np.einsum('bqd,bkd -> bqk', Q1, K1), -1)  # softmax over keys of the scores Q K^T
assert list(attn_weights.shape) == [b, s * n, s * n]
attn_outputs2 = np.einsum('bqk,bkd -> bqd', attn_weights, V1)  # softmax(Q K^T) V
attn_outputs = attn_outputs.transpose(1, 0, 2, 3).reshape(b, n*s, d)  # merge blocks for comparison
assert np.allclose(attn_outputs, attn_outputs2)
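Because every update rescales the running sums to the current maximum, the result does not depend on the order in which key/value chunks are visited. The ring implementation further below relies on this, since each host sees the chunks in a different order. Here is a quick check, using a hypothetical helper `blockwise_attention` that repackages the inner loop above:

```python
import numpy as np

def blockwise_attention(q_blk, k_chunks, v_chunks):
    """Running-max attention of one query chunk (b, s, d) over key/value chunks (hypothetical helper)."""
    b, s, d = q_blk.shape
    num = np.zeros((b, s, d))
    den = np.zeros((b, s))
    max_i = np.full((b, s), -np.inf)
    for k, v in zip(k_chunks, v_chunks):
        alpha = np.einsum("bqd,bkd->bqk", q_blk, k)            # scores for this chunk
        prev, max_i = max_i, np.maximum(alpha.max(-1), max_i)  # update the running max
        p = np.exp(alpha - max_i[..., None])                   # e^{alpha - max_i}
        scale = np.exp(prev - max_i)                           # rescale earlier sums to the new max
        num = num * scale[..., None] + np.einsum("bqk,bkd->bqd", p, v)
        den = den * scale + p.sum(-1)
    return num / den[..., None]

rng = np.random.default_rng(1)
q_blk = rng.random((2, 7, 5))        # one query chunk: (b, s, d)
k_chunks = rng.random((3, 2, 7, 5))  # n = 3 key chunks
v_chunks = rng.random((3, 2, 7, 5))  # n = 3 value chunks

forward = blockwise_attention(q_blk, k_chunks, v_chunks)
order = [2, 1, 0]  # e.g. the order in which host 2 of a 3-host ring receives the chunks
permuted = blockwise_attention(q_blk, k_chunks[order], v_chunks[order])
print(np.allclose(forward, permuted))  # True
```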

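We can complete the implementation of a blockwise transformer layer by adding a feedforward network with a residual connection and layer normalization. Because these operations act on each position independently, they can be applied to each query chunk as soon as its attention output is complete: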

In [50]:
w1 = np.random.standard_normal((d, d))
b1 = np.random.standard_normal(d)
w2 = np.random.standard_normal((d, d))
b2 = np.random.standard_normal(d)

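def layer_norm(x: np.ndarray, eps: float = 1e-5):
    # normalize over the feature axis (no learned scale/shift, for simplicity);
    # eps guards against division by zero for constant inputs
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + eps)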

def relu(x: np.ndarray): 
    return np.maximum(0, x)

def linear(x: np.ndarray, w: np.ndarray, b: np.ndarray): 
    return np.einsum('bqd,dw -> bqw', x, w) + b[None, None]

def postprocess(x: np.ndarray):
    x0 = x
    x = layer_norm(x)

    # 2-layer feedforward network
    x = linear(x, w1, b1)
    x = relu(x)
    x = linear(x, w2, b2)

    # residual connection + layer normalization
    x = x0 + x
    x = layer_norm(x)
    return x
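
Every operation in `postprocess` (layer normalization over the last axis, the two linear layers, the ReLU, and the residual connection) acts on each position independently, which is what lets the blockwise version apply it one query chunk at a time. The following self-contained check of that property uses a hypothetical `ffn_block` with the same structure as `postprocess` but its own independently drawn random weights (an assumption, not the cell's weights):

```python
import numpy as np

rng = np.random.default_rng(2)
dim = 5
# Hypothetical weights, drawn independently of w1, b1, w2, b2 above
fw1, fb1 = rng.standard_normal((dim, dim)), rng.standard_normal(dim)
fw2, fb2 = rng.standard_normal((dim, dim)), rng.standard_normal(dim)

def ln(x, eps=1e-5):
    """Layer norm over the feature axis, without learned scale/shift."""
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def ffn_block(x):
    """Same structure as postprocess: LN -> linear -> ReLU -> linear -> residual -> LN."""
    h = ln(x)
    h = np.maximum(0, h @ fw1 + fb1) @ fw2 + fb2
    return ln(x + h)

seq = rng.random((2, 3 * 7, dim))  # (batch, n * s, d): stand-in for merged attention outputs
whole = ffn_block(seq)
per_chunk = np.concatenate([ffn_block(c) for c in np.split(seq, 3, axis=1)], axis=1)
print(np.allclose(whole, per_chunk))  # True
```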


In [51]:
outputs = []

q: np.ndarray
for i, q in enumerate(Q):
    assert list(q.shape) == [b, s, d]
    num = np.zeros((b,s,d))  # initialize numerator
    den = np.zeros((b,s))  # initialize denominator
    max_i = -np.inf * np.ones((b, s))  # initialize max_i

    k: np.ndarray
    v: np.ndarray
    for j, (k, v) in enumerate(zip(K, V)):
        assert list(k.shape) == [b, s, d]
        assert list(v.shape) == [b, s, d]
        alpha: np.ndarray = np.einsum('bqd,bkd -> bqk', q, k)  # q^T K
        prev = max_i
        max_i = np.maximum(alpha.max(-1), max_i)  # update max_i
        exp_values = np.einsum('bqk,bkd -> bqd', np.exp(alpha - max_i[..., None]), v)  # e^{alpha - max_i}^T v

        # update numerator and denominator
        num = num * np.exp(prev - max_i)[..., None] + exp_values  
        den = den * np.exp(prev - max_i) + np.exp(alpha - max_i[..., None]).sum(-1)

    chunk_attn_output = num / den[..., None]
    x = postprocess(chunk_attn_output)
    outputs.append(x)

outputs = np.stack(outputs)
outputs.shape

(3, 2, 7, 5)

Confirm equivalence with vanilla Transformer block:

In [52]:
attn_weights: np.ndarray = softmax(np.einsum('bqd,bkd -> bqk', Q1, K1), -1)  # softmax over keys of the scores Q K^T
assert list(attn_weights.shape) == [b, s * n, s * n]
attn_outputs = np.einsum('bqk,bkd -> bqd', attn_weights, V1)  # softmax(Q K^T) V
outputs2 = postprocess(attn_outputs)

assert np.allclose(
    outputs.transpose(1, 0, 2, 3).reshape(b, n*s, d),  # merge blocks for comparison
    outputs2
)

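Finally, we demonstrate a ring transformer: each query chunk is assigned to its own host, and the key/value chunks are passed from host to host around a ring of queues. A process-based version generally needs the worker function defined in a separate, importable module (functions defined in a notebook cannot always be pickled for child processes), so here the hosts are simulated with threads. We can see that the logic is still equivalent: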

In [53]:
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Queue

def start_host(index: int, q: np.ndarray, recv: Queue, send: Queue):
    print(f"Starting host {index}")
    b, s, d = q.shape
    num = np.zeros((b, s, d))  # initialize numerator
    den = np.zeros((b, s))  # initialize denominator
    max_i = -np.inf * np.ones((b, s))  # initialize max_i

    for _ in range(n):
        k: np.ndarray
        v: np.ndarray
        k, v = recv.get()  # Receive k, v from the input queue (previous host)
        assert k.shape == (b, s, d) and v.shape == (b, s, d)

        alpha: np.ndarray = np.einsum("bqd,bkd -> bqk", q, k)  # q^T K
        prev = max_i
        max_i = np.maximum(alpha.max(-1), max_i)  # update max_i
        exp_values = np.einsum("bqk,bkd -> bqd", np.exp(alpha - max_i[..., None]), v)

        num = num * np.exp(prev - max_i)[..., None] + exp_values
        den = den * np.exp(prev - max_i) + np.exp(alpha - max_i[..., None]).sum(-1)

        send.put((k, v))  # Send (k, v) to the output queue for the next host

    x = num / den[..., None]
    x = postprocess(x)
    print(f"Host {index} done")
    return index, x


def ring_transformer():
    num_hosts = len(Q)
    # One queue per host: host i reads from queues[i] and writes to queues[(i + 1) % num_hosts],
    # so the modulo closes the ring and no extra queue is needed
    queues = [Queue() for _ in range(num_hosts)]

    # Initialize the first set of (k, v) pairs in the queues
    for queue, k, v in zip(queues, K, V):
        queue.put((k, v))

    with ThreadPoolExecutor(max_workers=num_hosts) as executor:
        futures = [
            executor.submit(start_host, i, q, queues[i], queues[(i + 1) % num_hosts])
            for i, q in enumerate(Q)
        ]
        outputs = [future.result() for future in futures]

    # Ensure outputs are sorted by index to maintain deterministic order
    return np.stack([x for _, x in sorted(outputs)])

In [54]:
outputs3 = ring_transformer()
assert np.allclose(
    outputs,
    outputs3
)

Starting host 0Starting host 1

Starting host 2
Host 2 done
Host 0 done
Host 1 done
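The queue wiring fixes the order in which each host sees the key/value chunks. Host $i$ starts with its own chunk and then receives chunks from host $i-1$, so it processes chunks $i, i-1, i-2, \ldots$ (modulo the number of hosts). The sketch below reproduces that order with a single-threaded, lockstep simulation using a hypothetical helper `ring_schedule`. In the threaded version the hosts run asynchronously, but FIFO queues still fix each host's sequence. Every host sees every chunk exactly once, and since the running-max update does not depend on visiting order, the threaded result matches.

```python
def ring_schedule(num_hosts: int) -> list:
    """Chunk indices each host processes, in order, for the queue wiring above."""
    held = list(range(num_hosts))  # host i starts with chunk i (queues[i] is pre-filled)
    seen = [[] for _ in range(num_hosts)]
    for _ in range(num_hosts):
        for host, chunk in enumerate(held):
            seen[host].append(chunk)  # host processes the chunk it currently holds
        # every host forwards its chunk to host (i + 1) % num_hosts,
        # so host i next holds the chunk host i - 1 just used
        held = [held[(host - 1) % num_hosts] for host in range(num_hosts)]
    return seen

schedule = ring_schedule(3)
print(schedule)  # [[0, 2, 1], [1, 0, 2], [2, 1, 0]]
```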
