In [125]:
import torch
import math

In [126]:

# Seed the global RNG first so the random input below is the same on every
# fresh kernel run.
torch.manual_seed(0)
# 1-D input vector of 256 standard-normal samples used by the softmax demos.
X = torch.randn(256)

In [127]:
# Reference softmax over the whole vector: subtract the global max before
# exponentiating, then divide by the sum of the shifted exponentials.
exp_shifted = torch.exp(X - X.max())
ground_truth = exp_shifted / exp_shifted.sum()
print(ground_truth)

# 2-pass online-softmax
只需要扫描两次，就可以得到softmax的结果。第一次是计算分母部分的求和。第二次是计算每个元素的softmax值。

In [128]:
# Cut X into consecutive chunks of 4 elements along its first dimension.
X_block = X.split(4)

def dfs(X_block, l, r):
    """Reduce softmax statistics over blocks ``l..r`` by divide and conquer.

    Each leaf computes the max of one block and the sum of ``exp(x - max)``
    over that block. Two partial results are merged by taking the larger max
    and rescaling both partial sums to it.

    Args:
        X_block: Indexable sequence of tensors (the chunks of the input).
        l: Index of the first block to include (inclusive).
        r: Index of the last block to include (inclusive).

    Returns:
        Tuple ``(block_mx, block_sum)``: the maximum over all elements of
        blocks ``l..r`` and the sum of ``exp(x - block_mx)`` over the same
        elements, both as 0-dim tensors.

    Raises:
        ValueError: If ``l > r`` (empty range); without this guard the
            recursion would never reach the ``l == r`` base case.
    """
    if l > r:
        raise ValueError(f"empty block range: l={l}, r={r}")
    if l == r:
        block = X_block[l]
        block_mx = block.max()
        block_sum = torch.exp(block - block_mx).sum()
        return block_mx, block_sum
    mid = (l + r) // 2
    left_mx, left_sum = dfs(X_block, l, mid)
    right_mx, right_sum = dfs(X_block, mid + 1, r)
    # Stay in tensor arithmetic (no round trip through Python floats) when
    # merging the two halves.
    block_mx = torch.maximum(left_mx, right_mx)
    # Re-express each partial sum relative to the merged max before adding.
    block_sum = (
        left_sum * torch.exp(left_mx - block_mx)
        + right_sum * torch.exp(right_mx - block_mx)
    )
    return block_mx, block_sum

# Pass 1: reduce over every block to obtain the global max and the normalizer.
last_block = len(X_block) - 1
block_mx, block_sum = dfs(X_block, 0, last_block)
# Pass 2: turn each element into its softmax value using those two statistics.
block_softmax = (X - block_mx).exp() / block_sum
assert torch.allclose(block_softmax, ground_truth)


# 2-pass self-attention

这里结合一些KV cache的trick，可以得到一个2-pass的self-attention。

In [130]:
# Dimensions: b = batch size, s = sequence length, h = hidden size.
# Only a single attention head is considered here.
b, s, h = 2, 10, 4
K = torch.randn(b, s, h)
V = torch.randn(b, s, h)
# Only a single query position is considered here.
Q = torch.randn(b, 1, h)

# Reference result: softmax of the scaled dot-product scores, applied to V.
scores = Q @ K.transpose(-1, -2) / math.sqrt(h)
ground_truth = torch.softmax(scores, dim=-1) @ V
print(ground_truth.shape)
print(ground_truth)

In [131]:
# * First pass: visit the keys one at a time, recording each scaled logit and
#   maintaining a running max and a running softmax denominator.
# x:  (b, 1, s) scaled logits
# mx: (b, 1, 1) running maximum of the logits seen so far
# d:  (b, 1, 1) running sum of exp(logit - mx)
x = torch.zeros(b, 1, s)
mx = d = None
for i in range(s):
    col = slice(i, i + 1)
    # (b, 1, h) @ (b, h, 1) -> (b, 1, 1)
    x[:, :, col] = Q @ K[..., col, :].transpose(-1, -2) / math.sqrt(h)
    x_i = x[:, :, col]
    if i == 0:
        # A single logit is its own max, so its shifted exponential is 1.
        mx, d = x_i, torch.ones_like(x_i)
        continue
    new_mx = torch.max(mx, x_i)
    # Rescale the old denominator to the new max, then add the new term.
    d = d * torch.exp(mx - new_mx) + torch.exp(x_i - new_mx)
    mx = new_mx

# * Second pass: normalize all logits with the final statistics and mix V.
# att:    (b, 1, s)
# V:      (b, s, h)
# output: (b, 1, h)
att = torch.exp(x - mx) / d
output = att @ V
print(output)
assert torch.allclose(output, ground_truth)


# One-pass self-attention
注意，在one-pass self-attention中，每次只处理K上的一个新元素，但是其实可以处理K上的一个block。

In [200]:

# * Single pass: besides the running max and denominator, the normalized
#   output is carried along and rescaled at every step, so no second sweep
#   over the logits is required.
# x:  (b, 1, s) scaled logits
# mx: (b, 1, 1) running maximum of the logits seen so far
# d:  (b, 1, 1) running sum of exp(logit - mx)
# o:  (b, 1, h) running attention output
x = torch.zeros(b, 1, s)
mx = d = o = None
for i in range(s):
    col = slice(i, i + 1)
    # (b, 1, h) @ (b, h, 1) -> (b, 1, 1)
    x[:, :, col] = Q @ K[..., col, :].transpose(-1, -2) / math.sqrt(h)
    x_i = x[:, :, col]
    v_i = V[..., col, :]
    if i == 0:
        # With one key the softmax weight is 1, so the output is that value row.
        mx, d, o = x_i, torch.ones_like(x_i), v_i
        continue
    new_mx = torch.max(mx, x_i)
    old_scale = torch.exp(mx - new_mx)   # correction for everything seen so far
    cur_scale = torch.exp(x_i - new_mx)  # weight of the current key
    new_d = d * old_scale + cur_scale
    # Undo the old normalization, apply the new one, and add the new value row.
    o = o * d / new_d * old_scale + v_i * cur_scale / new_d
    mx, d = new_mx, new_d

output = o
print(output)
# output: (b, 1, h)
assert torch.allclose(output, ground_truth)


# Flash attention V1


In [211]:
def ceil_div(numerator, denominator):
    """Return ceil(numerator / denominator) using integer arithmetic only.

    Intended for a non-negative ``numerator`` and a positive ``denominator``.
    """
    return (numerator + denominator - 1) // denominator


# With q_len=1 this is exactly the KV-cache decoding situation.
batch_size, num_heads, q_len, kv_len, head_dim = 2, 4, 8, 16, 10
q_block_size = 4
kv_block_size = 8
# Round up so that a trailing partial tile is counted too: torch.split yields
# a shorter last chunk when the length is not a multiple of the block size,
# and floor division would leave that chunk unprocessed.
num_row_tiles = ceil_div(q_len, q_block_size)
num_col_tiles = ceil_div(kv_len, kv_block_size)

print(f"num_row_tiles: {num_row_tiles}")
print(f"num_col_tiles: {num_col_tiles}")


In [216]:

# Random multi-head inputs; gradients are tracked on Q, K and V.
q_shape = (batch_size, num_heads, q_len, head_dim)
kv_shape = (batch_size, num_heads, kv_len, head_dim)
Q = torch.randn(q_shape, requires_grad=True)
K = torch.randn(kv_shape, requires_grad=True)
V = torch.randn(kv_shape, requires_grad=True)
O = torch.zeros_like(Q)

# Per-query-row statistics with a trailing singleton dimension:
# m starts at -inf and d starts at 0; both have shape (b, h, q_len, 1).
stat_shape = (*Q.shape[:-1], 1)
m = torch.full(stat_shape, -float('inf'))
d = torch.zeros(stat_shape)

# The mask is omitted here. With the KV-cache trick only q_len=1 has to be
# handled, and in that case the masking is already implicitly accounted for.
scores = Q @ K.transpose(-1, -2) / math.sqrt(head_dim)
ground_truth = torch.softmax(scores, dim=-1) @ V
print(Q.shape)
print(ground_truth.shape)
assert Q.shape == ground_truth.shape

In [217]:
# Tiled attention, V1 style: the running output is kept normalized after every
# key/value tile.
Q_blocks = torch.split(Q, split_size_or_sections=q_block_size, dim=-2)
K_blocks = torch.split(K, split_size_or_sections=kv_block_size, dim=-2)
V_blocks = torch.split(V, split_size_or_sections=kv_block_size, dim=-2)
m_blocks = torch.split(m, split_size_or_sections=q_block_size, dim=-2)
d_blocks = torch.split(d, split_size_or_sections=q_block_size, dim=-2)
O_blocks = []  # one finished output tile per query tile

# Iterate over the blocks themselves so that a shorter trailing tile (when a
# length is not a multiple of its block size) is processed as well.
for Qi, mi, di in zip(Q_blocks, m_blocks, d_blocks):
    # Qi: (b, h, q_block_size, head_dim); mi, di: (b, h, q_block_size, 1)
    # Start every query tile from a fresh zero accumulator instead of reading
    # the global O, which this cell overwrites at the end.
    oi = torch.zeros_like(Qi)  # (b, h, q_block_size, head_dim)

    for Kj, Vj in zip(K_blocks, V_blocks):
        # Kj, Vj: (b, h, kv_block_size, head_dim)
        S_ij = Qi @ Kj.transpose(-1, -2) / math.sqrt(head_dim)  # (b, h, q_block_size, kv_block_size)
        m_ij = torch.max(S_ij, dim=-1, keepdim=True).values  # (b, h, q_block_size, 1)
        d_ij = torch.exp(S_ij - m_ij).sum(dim=-1, keepdim=True)  # (b, h, q_block_size, 1)
        P_ij = torch.exp(S_ij - m_ij) / d_ij  # tile-local softmax, (b, h, q_block_size, kv_block_size)
        o_ij = P_ij @ Vj  # tile-local output, (b, h, q_block_size, head_dim)

        new_mi = torch.maximum(mi, m_ij)  # (b, h, q_block_size, 1)
        # Bring both denominators onto the common max before adding them.
        new_di = di * torch.exp(mi - new_mi) + d_ij * torch.exp(m_ij - new_mi)
        # Undo each side's own normalization, rescale to the common max and
        # renormalize by the merged denominator.
        new_oi = (
            oi * di / new_di * torch.exp(mi - new_mi)
            + o_ij * d_ij / new_di * torch.exp(m_ij - new_mi)
        )  # (b, h, q_block_size, head_dim)

        mi = new_mi
        di = new_di
        oi = new_oi

    O_blocks.append(oi)  # (b, h, q_block_size, head_dim)

O = torch.cat(O_blocks, dim=-2)  # (b, h, q_len, head_dim)
print(O.shape)
assert torch.allclose(O, ground_truth, atol=1e-5)




# Flash attention V2

In [220]:
# Tiled attention, V2 style: the running output stays unnormalized inside the
# key/value loop and is divided by the denominator once per query tile.
Q_blocks = torch.split(Q, split_size_or_sections=q_block_size, dim=-2)
K_blocks = torch.split(K, split_size_or_sections=kv_block_size, dim=-2)
V_blocks = torch.split(V, split_size_or_sections=kv_block_size, dim=-2)
m_blocks = torch.split(m, split_size_or_sections=q_block_size, dim=-2)
d_blocks = torch.split(d, split_size_or_sections=q_block_size, dim=-2)
O_blocks = []  # one finished output tile per query tile

# Iterate over the blocks themselves so that a shorter trailing tile (when a
# length is not a multiple of its block size) is processed as well.
for Qi, mi, di in zip(Q_blocks, m_blocks, d_blocks):
    # Qi: (b, h, q_block_size, head_dim); mi, di: (b, h, q_block_size, 1)
    # Start from a fresh zero accumulator rather than from the global O, which
    # may already hold the result of a previous cell.
    oi = torch.zeros_like(Qi)  # (b, h, q_block_size, head_dim)

    for Kj, Vj in zip(K_blocks, V_blocks):
        # Kj, Vj: (b, h, kv_block_size, head_dim)
        S_ij = Qi @ Kj.transpose(-1, -2) / math.sqrt(head_dim)  # (b, h, q_block_size, kv_block_size)
        m_ij = torch.max(S_ij, dim=-1, keepdim=True).values  # (b, h, q_block_size, 1)
        # Unnormalized tile weights; the division by d_ij is deliberately
        # skipped because it would be multiplied straight back in.
        P_ij = torch.exp(S_ij - m_ij)  # (b, h, q_block_size, kv_block_size)
        d_ij = P_ij.sum(dim=-1, keepdim=True)  # (b, h, q_block_size, 1)
        o_ij = P_ij @ Vj  # (b, h, q_block_size, head_dim)

        new_mi = torch.maximum(mi, m_ij)  # (b, h, q_block_size, 1)
        # Rescale both sides to the common max, then accumulate.
        new_di = di * torch.exp(mi - new_mi) + d_ij * torch.exp(m_ij - new_mi)
        new_oi = oi * torch.exp(mi - new_mi) + o_ij * torch.exp(m_ij - new_mi)

        mi = new_mi
        di = new_di
        oi = new_oi

    # Single normalization per query tile.
    O_blocks.append(oi / di)  # (b, h, q_block_size, head_dim)

O = torch.cat(O_blocks, dim=-2)  # (b, h, q_len, head_dim)
print(O.shape)
assert torch.allclose(O, ground_truth, atol=1e-5)


