In [2]:
import numpy as np
import math

# Online normalizer calculation for softmax

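This is a CPU (pure Python) implementation of [Online normalizer calculation for softmax](https://arxiv.org/pdf/1805.02867). It compares three variants:
- the standard softmax;
- the numerically safe softmax, which subtracts the maximum first;
- the online softmax, which fuses the max and normalizer passes into a single pass.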

In [34]:
# the input vector: 100 standard-normal samples drawn from a seeded generator,
# so Restart & Run All reproduces exactly the same numbers
softmax_rng = np.random.default_rng(seed=0)
x = softmax_rng.standard_normal(100).tolist()

In [35]:
def standard_softmax(x):
    """
    Naive softmax, exp(x_i) / sum_j exp(x_j), computed in two passes.

    Every element is read twice (once for the denominator, once for the
    outputs) and every output is written once. No max shift is applied, so
    math.exp raises OverflowError for large inputs (e.g. 1000.0); that is
    the problem safe_softmax addresses.

    Args:
        x: sequence of numbers (iterated twice, so it must be re-iterable)
    Returns:
        list of floats of the same length as x; [] for an empty x
    """
    denominator = 0
    # first pass: load each element once to accumulate the sum of exponentials
    for value in x:
        denominator += math.exp(value)
    # second pass: load each element again; each output is stored exactly once
    return [math.exp(value) / denominator for value in x]

In [36]:
# reference result from the naive two-pass softmax
standard_softmax_res = standard_softmax(x)

In [37]:
def safe_softmax(x):
    """
    Numerically safe softmax computed in three passes.

    The maximum is subtracted before exponentiating. Every exponent is then
    <= 0, so math.exp cannot overflow on large finite inputs. The shift
    cancels in the ratio, so the result matches the naive softmax.

    Args:
        x: sequence of numbers (iterated three times, so it must be re-iterable)
    Returns:
        list of floats of the same length as x; [] for an empty x
    """
    # first pass: load each element to find the maximum
    peak = float('-inf')
    for value in x:
        peak = max(peak, value)
    # second pass: load each element to accumulate the shifted sum of exponentials
    denominator = 0.0
    for value in x:
        denominator += math.exp(value - peak)
    # third pass: load each element a final time; each output is stored once
    return [math.exp(value - peak) / denominator for value in x]

In [38]:
# same input through the max-shifted (overflow-safe) three-pass softmax
safe_softmax_res = safe_softmax(x)

In [39]:
# safe and standard softmax should be the same; the assert makes
# Restart & Run All fail loudly instead of silently displaying False
softmax_variants_match = np.allclose(standard_softmax_res, safe_softmax_res)
assert softmax_variants_match, "safe_softmax disagrees with standard_softmax"
softmax_variants_match

True

In [40]:
def online_softmax(x):
    """
    Softmax whose maximum and normalizer are computed together in one pass.

    A running maximum and a running sum of exp(v - running_max) are
    maintained. Whenever the maximum grows, the sum accumulated so far is
    rescaled by exp(previous_max - new_max). This moves it onto the new
    maximum, so one read of the input gives the same normalizer that
    safe_softmax needs two reads for.

    Args:
        x: sequence of numbers (iterated twice, so it must be re-iterable)
    Returns:
        list of floats of the same length as x; [] for an empty x
    """
    running_sum = 0.0
    running_max = float('-inf')
    # first pass: fused max + sum of exponentials
    for value in x:
        previous_max = running_max
        running_max = max(running_max, value)
        # rescale the old sum to the new max, then add this element's term
        running_sum = running_sum * math.exp(previous_max - running_max) + math.exp(value - running_max)
    # second pass: load each element again; each output is stored once
    return [math.exp(value - running_max) / running_sum for value in x]

In [41]:
online_softmax_res = online_softmax(x)
# online softmax computes the same normalizer as safe softmax in one fused pass,
# so the results must agree; the assert fails Restart & Run All on a regression
online_matches_safe = np.allclose(online_softmax_res, safe_softmax_res)
assert online_matches_safe, "online_softmax disagrees with safe_softmax"
online_matches_safe

True

# Flash Attention

Flash attention paper: [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/pdf/2205.14135)

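Flash attention builds on the online normalizer calculation for softmax, so a good understanding of the online softmax will help us understand flash attention. It keeps the same running row-wise maximum and running sum of exponentials, but updates them one block of the attention score matrix at a time. It also rescales the partially accumulated output whenever the maximum changes.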

In [7]:
# problem size: N rows (tokens) with d columns (head dimension)
N, d = 9, 9
# seeded generator so Restart & Run All reproduces the same q, k, v
attention_rng = np.random.default_rng(seed=1)
q = attention_rng.standard_normal((N, d))
k = attention_rng.standard_normal((N, d))
v = attention_rng.standard_normal((N, d))

In [6]:
# block sizes: b_r rows of q per query block, b_c rows of k/v per key/value block
b_r = 3
b_c = 3

In [20]:
def flash_attention(q, k, v, b_r, b_c):
    """
    Calculates the attention with the flash attention method (forward pass),
    simulated on the CPU with NumPy blocks.

    The FLOPs is O(N_q * N_k * d), i.e. O(N^2 * d) for self-attention. The
    full N_q x N_k score matrix is never materialized; only one b_r x b_c
    tile of scores exists at a time.

    The additional memory required is for m and l, which is O(N_q).

    Args:
        q: Query matrix (shape: (N_q, d))
        k: Key matrix (shape: (N_k, d)); N_k may differ from N_q
        v: Value matrix (shape: (N_k, d_v)); d_v may differ from d
        b_r: The block size (rows per block) for q, must be positive
        b_c: The block size (rows per block) for k and v, must be positive
        N_q and N_k need not be divisible by b_r / b_c: the last block along
        each axis is simply shorter.
    Returns:
        attention_output: Output of the attention mechanism (shape: (N_q, d_v))
    Raises:
        AssertionError: on non-positive block sizes, mismatched shapes, or
            an empty k / v.
    """
    assert b_r > 0 and b_c > 0, "the block sizes b_r and b_c must be positive"
    assert q.shape[1] == k.shape[1], "the number of columns of q and k must be the same"
    assert k.shape[0] == v.shape[0], "the number of rows of k and v must be the same"

    n_q = q.shape[0]
    # k and v are split along their own row count, which may differ from q's
    n_k = k.shape[0]
    d_v = v.shape[1]
    assert n_k > 0, "k and v must have at least one row"

    # the output of the flash attention
    o = np.zeros((n_q, d_v))
    # the sum of exponential vector for each row
    l = np.zeros((n_q, 1))
    # the max value for each row
    m = np.full((n_q, 1), float('-inf'))
    if n_q == 0:
        # no queries: nothing to attend with, and np.concatenate([]) would fail
        return o

    # divide q, o, l, m into blocks of (at most) b_r rows
    q_blocks = [q[i:i + b_r, :] for i in range(0, n_q, b_r)]
    o_blocks = [o[i:i + b_r, :] for i in range(0, n_q, b_r)]
    l_blocks = [l[i:i + b_r, :] for i in range(0, n_q, b_r)]
    m_blocks = [m[i:i + b_r, :] for i in range(0, n_q, b_r)]
    # divide k and v into blocks of (at most) b_c rows
    k_blocks = [k[i:i + b_c, :] for i in range(0, n_k, b_c)]
    v_blocks = [v[i:i + b_c, :] for i in range(0, n_k, b_c)]

    n_q_blocks = len(q_blocks)
    n_k_blocks = len(k_blocks)

    # FLOPs are dominated by the two matmuls per tile (~b_r * b_c * d and
    # ~b_r * b_c * d_v multiply-adds), so the total is
    # n_k_blocks * n_q_blocks * b_r * b_c * d = O(N_q * N_k * d)
    for j in range(n_k_blocks):
        # load k_j and v_j from HBM to on-chip SRAM, line 6
        k_block = k_blocks[j]
        v_block = v_blocks[j]
        for i in range(n_q_blocks):
            # load q_i, m_i, l_i, o_i from HBM to on-chip SRAM, line 8
            q_block, m_block, l_block, o_block = q_blocks[i], m_blocks[i], l_blocks[i], o_blocks[i]
            # scores for this tile, b_r x b_c, line 9, FLOPs: b_r * b_c * d
            s_i_j = np.matmul(q_block, k_block.T)
            # row-wise max of the tile, b_r x 1, line 10, FLOPs: b_r * b_c
            m_i_j = np.max(s_i_j, axis=1, keepdims=True)
            # numerator of the tile-local softmax, b_r x b_c, line 10, FLOPs: b_r * b_c
            p_i_j = np.exp(s_i_j - m_i_j)
            # row-wise sum of the numerator, b_r x 1, line 10, FLOPs: b_r * b_c
            l_i_j = np.sum(p_i_j, axis=1, keepdims=True)
            # new running max for each row, line 11, FLOPs: b_r
            m_i_new = np.maximum(m_block, m_i_j)
            # factors that move the old running stats and the tile stats onto m_i_new
            old_scale = np.exp(m_block - m_i_new)
            tile_scale = np.exp(m_i_j - m_i_new)
            # new running sum of exponentials for each row, line 11, FLOPs: O(b_r)
            l_i_new = l_block * old_scale + l_i_j * tile_scale
            # line 12: O_i <- diag(l_new)^-1 (diag(l_i) e^(m_i - m_new) O_i + e^(m_ij - m_new) P_ij V_j)
            # the tile contribution, b_r x d_v, FLOPs: b_r * b_c * d_v
            current_o_block = (tile_scale * p_i_j) @ v_block
            # o_block was normalized by the old l_i, so multiply it back before rescaling
            updated_old_o_block = l_block * old_scale * o_block
            # update output matrix block and store to HBM, line 12, FLOPs: O(b_r * d_v)
            o_blocks[i] = (current_o_block + updated_old_o_block) / l_i_new
            # update the max value for each row and store to HBM, line 13
            m_blocks[i] = m_i_new
            # update the sum of exponential vector for each row and store to HBM, line 13
            l_blocks[i] = l_i_new

    o = np.concatenate(o_blocks, axis=0)
    return o


In [21]:
# blockwise (flash) attention on the q, k, v and block sizes b_r, b_c set up in earlier cells
flash_attention_res = flash_attention(q, k, v, b_r, b_c)

In [25]:
# reference implementation used to verify that flash attention is doing the right thing
def standard_attention(Q, K, V):
    """
    Calculates softmax dot-product attention by materializing the full score matrix.

    The FLOPs is O(N^2 * d): Q @ K.T and attention_weights @ V each cost
    about N * N * d multiply-adds, and the softmax touches all N^2 scores.
    This is the same order as flash attention; the difference between the
    two is memory, not FLOPs.

    The additional memory required is O(N^2) for the scores / attention weights.

    Args:
        Q: Query matrix (shape: (N_q, d))
        K: Key matrix (shape: (N_k, d))
        V: Value matrix (shape: (N_k, d_v))

    Returns:
        attention_output: Output of the attention mechanism (shape: (N_q, d_v))
    """
    # attention scores, (N_q, N_k)
    scores = np.matmul(Q, K.T)
    # numerically safe softmax over the key axis: subtract each row's max first
    attention_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights /= attention_weights.sum(axis=-1, keepdims=True)
    # weighted sum of the value rows, (N_q, d_v)
    attention_output = np.matmul(attention_weights, V)
    return attention_output

In [26]:
# reference output computed with the full N x N score matrix
standard_attention_res = standard_attention(q, k, v)

In [28]:
# flash attention is exact, so it must reproduce standard attention; the assert
# makes Restart & Run All fail loudly instead of silently displaying False
flash_matches_standard = np.allclose(flash_attention_res, standard_attention_res)
assert flash_matches_standard, "flash_attention disagrees with standard_attention"
flash_matches_standard

True