In [232]:
import torch
from torch.nn import functional as F
import numpy as np

In [233]:
# Reference: softmax over the concatenation [a, b], computed in one shot.
a = torch.tensor([1, 2, 3], dtype=torch.float32)
b = torch.tensor([4, 5, 6], dtype=torch.float32)
print(torch.softmax(torch.concat([a, b]), dim=0))

# The same softmax computed block by block ("online softmax"): each block keeps
# only its max m_x, its max-shifted exponentials f_x and their sum l_x.
m_a = torch.max(a)
f_a = torch.exp(a - m_a)
l_a = torch.sum(f_a)

m_b = torch.max(b)
f_b = torch.exp(b - m_b)
l_b = torch.sum(f_b)

# Merge: rescale both blocks to the global max m. The merged normaliser is built
# from the per-block sums l_a and l_b alone -- the same rescaling rule that
# flash_attention applies to its running sum l -- instead of re-summing f.
m = torch.max(m_a, m_b)
f = torch.concat([torch.exp(m_a - m) * f_a, torch.exp(m_b - m) * f_b])
l = torch.exp(m_a - m) * l_a + torch.exp(m_b - m) * l_b

print(f / l)
assert torch.allclose(f / l, torch.softmax(torch.concat([a, b]), dim=0)), "blockwise softmax disagrees with the reference"

tensor([0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337])
tensor([0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337])


[FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135)

In [234]:
import math
def flash_attention(Q, K, V, M=100_000, dropout_p=0, is_causal=False):
    """Exact attention computed tile by tile (FlashAttention, Algorithm 1).

    Evaluates ``softmax(Q @ K^T / sqrt(d)) @ V`` without materialising the full
    score matrix. K/V are visited in tiles of ``Bc`` rows and Q in tiles of
    ``Br`` rows. A running row maximum ``m`` and a running softmax denominator
    ``l`` let each output tile be rescaled as new scores arrive.

    Args:
        Q: (batch, N, d) queries.
        K: (batch, N_kv, d) keys.
        V: (batch, N_kv, d_v) values.
        M: memory budget (in elements) used to size the tiles, as in the paper:
            ``Bc = max(1, M // (4 * d))`` and ``Br = min(Bc, d)``.
        dropout_p: probability of zeroing each attention weight after the
            softmax (via ``F.dropout``, so survivors are scaled by
            ``1 / (1 - dropout_p)``). ``0`` disables dropout.
        is_causal: if True, query ``i`` attends only to keys ``j <= i``
            (top-left aligned causal mask).

    Returns:
        (batch, N, d_v) tensor with the dtype and device of ``Q``. A query row
        that has no visible key is returned as zeros.
    """
    b, N, d = Q.shape
    N_kv = K.shape[1]
    d_v = V.shape[-1]

    # Tile sizes from the paper; max(1, ...) avoids a zero-sized tile (and a
    # ZeroDivisionError) when M < 4 * d.
    Bc = max(1, M // (4 * d))
    Br = min(Bc, d)
    Tr = math.ceil(N / Br)
    Tc = math.ceil(N_kv / Bc)
    scale = 1.0 / math.sqrt(d)
    neg_inf = -float('inf')

    # Output and running statistics follow Q's dtype/device so that e.g.
    # float64 inputs do not mix dtypes in the updates below.
    O = Q.new_zeros((b, N, d_v))
    l = Q.new_zeros((b, N))
    m = Q.new_full((b, N), neg_inf)

    for b_i in range(b):
        for j in range(Tc):
            kv_lo, kv_hi = Bc * j, min(Bc * (j + 1), N_kv)
            K_j = K[b_i, kv_lo:kv_hi]
            V_j = V[b_i, kv_lo:kv_hi]

            for i in range(Tr):
                q_lo, q_hi = Br * i, min(Br * (i + 1), N)
                # Under the causal mask a tile whose first key comes after its
                # last query contributes nothing; skip it entirely.
                if is_causal and kv_lo > q_hi - 1:
                    continue

                Q_i = Q[b_i, q_lo:q_hi]
                O_i = O[b_i, q_lo:q_hi]
                l_i = l[b_i, q_lo:q_hi]
                m_i = m[b_i, q_lo:q_hi]

                S_ij = (Q_i @ K_j.T) * scale
                if is_causal:
                    rows = torch.arange(q_lo, q_hi, device=Q.device).unsqueeze(1)
                    cols = torch.arange(kv_lo, kv_hi, device=Q.device).unsqueeze(0)
                    S_ij = S_ij.masked_fill(cols > rows, neg_inf)

                m_i_new = torch.maximum(m_i, S_ij.max(dim=1).values)
                # A row that has seen no unmasked key keeps m = -inf; shift it by 0
                # so that exp(-inf - (-inf)) cannot produce NaN.
                shift = torch.where(m_i_new == neg_inf, torch.zeros_like(m_i_new), m_i_new)
                P_ij = torch.exp(S_ij - shift.unsqueeze(1))  # weights relative to the new max
                alpha = torch.exp(m_i - shift)               # rescales the old l and O
                l_i_new = alpha * l_i + P_ij.sum(dim=1)

                # Dropout acts on the weights that multiply V, not on the normaliser l.
                if dropout_p > 0:
                    P_ij = F.dropout(P_ij, p=dropout_p, training=True)

                # O_i is stored normalised by l_i: undo that, add this tile's
                # contribution, then renormalise by l_i_new (rows with l == 0 stay 0).
                denom = torch.where(l_i_new > 0, l_i_new, torch.ones_like(l_i_new))
                O[b_i, q_lo:q_hi] = ((alpha * l_i).unsqueeze(1) * O_i + P_ij @ V_j) / denom.unsqueeze(1)
                l[b_i, q_lo:q_hi] = l_i_new
                m[b_i, q_lo:q_hi] = m_i_new

    return O

In [240]:
%%time

# Sanity check: flash_attention must match the naive softmax(QK^T / sqrt(d)) V.
# A larger setting to try: b, n, d = 16, 1024, 512
torch.manual_seed(0)  # make the random inputs, and hence the printed values, reproducible
b, n, d = 8, 128, 32

q = torch.randn(b, n, d)
k = torch.randn(b, n, d)
v = torch.randn(b, n, d)

# Distinct names so the vectors `a` / `f_a` from the softmax demo cell are not shadowed.
out_flash = flash_attention(q, k, v)
out_ref = torch.softmax(q @ k.transpose(1, 2) / np.sqrt(d), dim=-1) @ v
print(out_flash[0][0][:30])
print(out_ref[0][0][:30])
print(out_flash.allclose(out_ref, atol=1e-4))
assert torch.allclose(out_flash, out_ref, atol=1e-4), "flash_attention disagrees with the naive reference"

tensor([-0.0196, -0.3990,  0.0781, -0.0791, -0.1644, -0.2502, -0.0676, -0.1285,
         0.3165,  0.0299, -0.1461,  0.1797,  0.0908,  0.1197,  0.1953,  0.2272,
         0.2118, -0.4319, -0.0455,  0.1400, -0.1682,  0.0430, -0.0486,  0.1379,
        -0.2002,  0.2272,  0.1945,  0.2185,  0.0459,  0.0734])
tensor([-0.0196, -0.3990,  0.0781, -0.0791, -0.1644, -0.2502, -0.0676, -0.1285,
         0.3165,  0.0299, -0.1461,  0.1797,  0.0908,  0.1197,  0.1953,  0.2272,
         0.2118, -0.4319, -0.0455,  0.1400, -0.1682,  0.0430, -0.0486,  0.1379,
        -0.2002,  0.2272,  0.1945,  0.2185,  0.0459,  0.0734])
True
CPU times: user 73.4 ms, sys: 9.6 ms, total: 83 ms
Wall time: 13.1 ms
