In [5]:
import torch
from torch.nn import functional as F
import numpy as np

[FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135)

In [161]:
'''
    Q: (batch, seq_len, dim)
    K: (batch, seq_len, dim)
    V: (batch, seq_len, dim)
    tau: softmax scaling
    is_causal: whether to use causal mask
    dropout_p: dropout probability
    
    return: (batch, seq_len, dim)
'''

def flash_attention(Q, K, V, M = 100_000, tau = 1, dropout_p = 0, is_causal = False):
    b, N, d = Q.shape
    Bc = M // (4 * d)
    Br = min(Bc,d)
    O = torch.zeros_like(Q)
    l = torch.zeros(b, N)
    m = torch.ones(b, N) * -float('inf')
    Tr = N // Br + (1 if N % Br != 0 else 0)
    Tc = N // Bc + (1 if N % Bc != 0 else 0)

    for b_i in range(b):
        for j in range(Tc):
            N_c = Bc * j
            K_j = K[b_i,N_c:N_c + Bc]
            V_j = V[b_i,N_c:N_c + Bc]
            
            for i in range(Tr):
                N_r = Br * i
                Q_i = Q[b_i, N_r:N_r + Br]
                O_i = O[b_i, N_r:N_r + Br]
                l_i = l[b_i, N_r:N_r + Br]
                m_i = m[b_i, N_r:N_r + Br]

                S_ij = torch.matmul(Q_i, K_j.T) * tau
                
                if is_causal:
                    # Apply causal mask
                    r_max = N_c + Bc - N_r - 1
                    c_start = Bc - r_max
                    for r in range(min(r_max, S_ij.shape[0])):
                        S_ij[r, max(c_start + r, 0):] = torch.tensor(-float('inf'))

                m_ij = torch.max(S_ij, dim=1).values
                S_ij = S_ij - m_ij.unsqueeze(1)
                # Need to make sure that the exp is not nan, value of nan could be caused by applying causal mask
                S_ij[S_ij.isnan()] = 0
                P_ij = torch.exp(S_ij)
                # P_ij = torch.exp(torch.where(S_ij - m_ij.unsqueeze(1) != float('nan'), ))
                l_ij = torch.sum(P_ij, dim=1)                
                m_i_new = torch.max(m_i, m_ij)
                l_i_new = torch.exp(m_i - m_i_new) * l_i + torch.exp(m_ij - m_i_new) * l_ij
                if dropout_p > 0:
                    P_ij = F.dropout(P_ij, p = dropout_p, training = True)
                
                O_i = torch.diag(1 / l_i_new) @ (torch.diag(l_i * torch.exp(m_i - m_i_new)) @ O_i
                                                 + torch.diag(torch.exp(m_ij - m_i_new)) @ torch.matmul(P_ij, V_j))

                O[b_i, N_r:N_r + Br] = O_i
                l[b_i, N_r:N_r + Br] = l_i_new
                m[b_i, N_r:N_r + Br] = m_i_new
               
    return O


In [163]:
%%time


## Test

b, n, d = 16, 1024, 512
# b, n, d = 8, 128, 32

q=torch.randn(b, n, d)
k=torch.randn(b, n, d)
v=torch.randn(b, n, d)

# o_a = flash_attention(q, k, v, tau=1/np.sqrt(d))
o_a = flash_attention(q, k, v, tau=1/np.sqrt(d), is_causal = True)

# o = torch.softmax(q @ k.transpose(1,2) / np.sqrt(d), dim=-1) @ v
s = (q @ k.transpose(1,2)) / np.sqrt(d)
s = torch.where(torch.tril(torch.ones_like(s)) != 0, s, torch.tensor(-float('inf')))  # apply causal mask
p = torch.softmax(s, dim=-1)
o = p @ v

print(o_a[0][0][:30])
print(o[0][0][:30])
print(f'Total difference: {torch.sum(torch.abs(o_a - o))}')
print(o_a.allclose(o, atol=1e-4))

o_a = flash_attention(q, k, v, tau=1/np.sqrt(d), is_causal = True, dropout_p = 0.1)
print(o_a[0][0][:30])

tensor([-1.6445, -1.4075,  1.9288,  0.7326,  0.8834,  0.8957, -0.1964,  0.8984,
         0.5274, -0.9962,  1.7195, -1.2960, -0.6398,  1.2130, -0.0071,  1.2505,
        -0.0959,  0.2520,  0.6756, -1.5047,  1.3196, -0.3428, -0.6408, -0.1185,
         0.5393,  0.1485, -0.8046,  1.1012,  0.0594, -0.7007])
tensor([-1.6445, -1.4075,  1.9288,  0.7326,  0.8834,  0.8957, -0.1964,  0.8984,
         0.5274, -0.9962,  1.7195, -1.2960, -0.6398,  1.2130, -0.0071,  1.2505,
        -0.0959,  0.2520,  0.6756, -1.5047,  1.3196, -0.3428, -0.6408, -0.1185,
         0.5393,  0.1485, -0.8046,  1.1012,  0.0594, -0.7007])
Total difference: 0.31099173426628113
True
tensor([-1.8273, -1.5639,  2.1431,  0.8140,  0.9816,  0.9952, -0.2182,  0.9982,
         0.5859, -1.1069,  1.9106, -1.4400, -0.7109,  1.3477, -0.0079,  1.3894,
        -0.1065,  0.2800,  0.7506, -1.6719,  1.4662, -0.3809, -0.7120, -0.1317,
         0.5992,  0.1650, -0.8940,  1.2236,  0.0660, -0.7786])
CPU times: user 58.6 s, sys: 109 ms, total: 58.7

In [166]:
N = 1024
A = torch.randn(N, N)
B = A.clone()

Br=2
Bc=2
Tr = N // Br + (1 if N % Br != 0 else 0)
Tc = N // Bc + (1 if N % Bc != 0 else 0)
for j in range(Tc):
    N_c = Bc * j
    for i in range(Tr):                
        N_r = Br * i
        S_ij = A[N_r:N_r + Br, N_c:N_c + Bc]
        
        r_max = N_c + Bc - N_r - 1
        i_start = Bc - r_max
        for r in range(min(r_max, Br)):
            S_ij[r, max(i_start + r, 0):] = torch.tensor(-float('inf'))

print(A[20][:30], B[20][:30])

tensor([ 4.8389e-01,  8.9228e-01, -8.7868e-01,  5.4902e-01,  2.2318e-01,
        -6.5148e-01,  1.0080e+00,  8.6370e-01, -4.8014e-01, -2.0496e-01,
         2.2228e+00,  1.4096e+00, -1.2294e-01, -2.4882e-01, -1.9225e-01,
         1.1560e+00, -5.0234e-02,  3.0937e-01, -3.8720e-01,  1.2193e-03,
        -8.5482e-03,        -inf,        -inf,        -inf,        -inf,
               -inf,        -inf,        -inf,        -inf,        -inf]) tensor([ 4.8389e-01,  8.9228e-01, -8.7868e-01,  5.4902e-01,  2.2318e-01,
        -6.5148e-01,  1.0080e+00,  8.6370e-01, -4.8014e-01, -2.0496e-01,
         2.2228e+00,  1.4096e+00, -1.2294e-01, -2.4882e-01, -1.9225e-01,
         1.1560e+00, -5.0234e-02,  3.0937e-01, -3.8720e-01,  1.2193e-03,
        -8.5482e-03,  7.2774e-01,  1.3160e-01,  1.2285e+00, -2.4422e-01,
         4.7846e-01, -5.2039e-01, -7.3353e-01, -1.6904e+00,  1.0902e+00])
