In [2]:
import torch
from torch.nn import functional as F

## [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135)

### Forward Pass

- To do:
  - Swith Q and KV loops (seems clearer)
  - Try [FlashDecoding++](https://arxiv.org/abs/2311.01282) aynchronized softmax ideas

In [5]:
'''
    Q: (batch, seq_len, dim)
    K: (batch, seq_len, dim)
    V: (batch, seq_len, dim)
    tau: softmax scaling
    is_causal: whether to use causal mask
    dropout_p: dropout probability
    
    return: (batch, seq_len, dim)
'''

def flash_attention(Q, K, V, M = 100_000, tau = 1, is_causal = False, dropout_p = 0, seed = 1234):
    b, N, d = Q.shape
    Bc = M // (4 * d)
    Br = min(Bc,d)
    O = torch.zeros_like(Q)
    l = torch.zeros(b, N)
    m = torch.ones(b, N) * -float('inf')
    Tr = N // Br + (1 if N % Br != 0 else 0)
    Tc = N // Bc + (1 if N % Bc != 0 else 0)

    for b_i in range(b):
        for j in range(Tc):
            N_c = Bc * j
            K_j = K[b_i,N_c:N_c + Bc]
            V_j = V[b_i,N_c:N_c + Bc]
            
            for i in range(Tr):
                N_r = Br * i
                Q_i = Q[b_i, N_r:N_r + Br]
                O_i = O[b_i, N_r:N_r + Br]
                l_i = l[b_i, N_r:N_r + Br]
                m_i = m[b_i, N_r:N_r + Br]

                S_ij = tau * (Q_i @ K_j.T)
                
                if is_causal:
                    # Apply causal mask
                    r_max = N_c + Bc - N_r - 1
                    c_start = Bc - r_max
                    for r in range(min(r_max, S_ij.shape[0])):
                        S_ij[r, max(c_start + r, 0):] = torch.tensor(-float('inf'))

                m_ij = torch.max(S_ij, dim=1).values
                S_ij = S_ij - m_ij.unsqueeze(1)
                # Need to make sure that the exp is not nan, value of nan could be caused by applying causal mask
                S_ij[S_ij.isnan()] = 0
                P_ij = torch.exp(S_ij)
                # P_ij = torch.exp(torch.where(S_ij - m_ij.unsqueeze(1) != float('nan'), ))
                l_ij = torch.sum(P_ij, dim=1)
                
                m_i_new = torch.max(m_i, m_ij)
                l_i_new = torch.exp(m_i - m_i_new) * l_i + torch.exp(m_ij - m_i_new) * l_ij
                if dropout_p > 0:
                    torch.manual_seed(seed)
                    P_ij = F.dropout(P_ij, p = dropout_p, training = True)
                
                O_i = torch.diag(1 / l_i_new) @ (torch.diag(l_i * torch.exp(m_i - m_i_new)) @ O_i
                                                 + torch.diag(torch.exp(m_ij - m_i_new)) @ (P_ij @ V_j))

                O[b_i, N_r:N_r + Br] = O_i
                l[b_i, N_r:N_r + Br] = l_i_new
                m[b_i, N_r:N_r + Br] = m_i_new
               
    return O, l, m

def attention(Q, K, V, tau = 1, dropout_p = 0, is_causal = False):
    S = tau * (Q @ K.transpose(1,2))
    if is_causal:
        S = torch.where(torch.tril(torch.ones_like(S)) != 0, S, torch.tensor(-float('inf')))  # apply causal mask
    P = torch.softmax(S, dim=-1)
    if dropout_p > 0:
        P = F.dropout(P, p = dropout_p, training = True)

    O = P @ V

    return O, P

#### Test

In [6]:
%%time

import math

# b, n, d = 16, 1024, 512
b, n, d = 8, 128, 32

Q = torch.randn(b, n, d)
K = torch.randn(b, n, d)
V = torch.randn(b, n, d)

O, P = attention(Q, K, V, tau=1/math.sqrt(d), is_causal = True)
O_flash, _, _ = flash_attention(Q, K, V, tau=1/math.sqrt(d), is_causal = True)


print(O[0][0][:5])
print(O_flash[0][0][:5])
print(f'Total difference: {torch.sum(torch.abs(O_flash - O))}')
print(O_flash.allclose(O, atol=1e-4))

O_flash, _, _ = flash_attention(Q, K, V, tau=1/math.sqrt(d), is_causal = True, dropout_p = 0.1)
print(O_flash[0][0][:5])

tensor([-0.0093, -0.1454,  1.0259, -0.1156, -0.2714])
tensor([-0.0093, -0.1454,  1.0259, -0.1156, -0.2714])
Total difference: 0.0009422217262908816
True
tensor([-0.0104, -0.1616,  1.1398, -0.1285, -0.3015])
CPU times: user 393 ms, sys: 2.33 ms, total: 396 ms
Wall time: 57.6 ms


### Backward Pass

In [7]:
def flash_attention_grad(Q, K, V, O, dO, l, m, M = 100_000, tau = 1, is_causal = False, dropout_p = 0, seed = 1234):
    b, N, d = Q.shape
    Bc = M // (4 * d)
    Br = min(Bc,d)
    Tr = N // Br + (1 if N % Br != 0 else 0)
    Tc = N // Bc + (1 if N % Bc != 0 else 0)
    
    dQ = torch.zeros_like(Q)
    dK = torch.zeros_like(K)
    dV = torch.zeros_like(V)

    for b_i in range(b):
        for j in range(Tc):
            N_c = Bc * j
            K_j = K[b_i,N_c:N_c + Bc]
            V_j = V[b_i,N_c:N_c + Bc]
            dK_j = dK[b_i,N_c:N_c + Bc]
            dV_j = dV[b_i,N_c:N_c + Bc]
            
            for i in range(Tr):
                N_r = Br * i
                Q_i = Q[b_i, N_r:N_r + Br]
                O_i = O[b_i, N_r:N_r + Br]
                dQ_i = dQ[b_i, N_r:N_r + Br]
                dO_i = dO[b_i, N_r:N_r + Br]
                l_i = l[b_i, N_r:N_r + Br]
                m_i = m[b_i, N_r:N_r + Br]

                S_ij = tau * (Q_i @ K_j.T)

                if is_causal:
                    # Apply causal mask
                    r_max = N_c + Bc - N_r - 1
                    c_start = Bc - r_max
                    for r in range(min(r_max, S_ij.shape[0])):
                        S_ij[r, max(c_start + r, 0):] = torch.tensor(-float('inf'))
                
                S_ij = S_ij - m_i.unsqueeze(1)
                # Need to make sure that the exp is not nan, value of nan could be caused by applying causal mask
                S_ij[S_ij.isnan()] = 0
                P_ij = torch.diag(1 / l_i) @ torch.exp(S_ij)
                
                Z_ij = torch.ones_like(P_ij)
                if dropout_p > 0:
                    torch.manual_seed(seed)
                    Z_ij = F.dropout(Z_ij, p = dropout_p, training = True)
                    P_ij = Z_ij * P_ij

                dV_j += P_ij.T @ dO_i

                dP_ij = (dO_i @ V_j.T) * Z_ij
                D_i = torch.sum(dO_i * O_i, dim=1)
                dS_ij = P_ij * (dP_ij - D_i.unsqueeze(1))
                
                dQ[b_i, N_r:N_r + Br] = dQ_i + tau * (dS_ij @ K_j )
                
                dK_j += tau * (dS_ij.T @ Q_i)
            
        dK[b_i,N_c:N_c + Bc] = dK_j
        dV[b_i,N_c:N_c + Bc] = dV_j

    return dQ, dK, dV


def attention_grad(Q, K, V, P, dO, tau = 1):
    dV = P.transpose(1,2) @ dO
    dP = dO @ V.transpose(1, 2)
    dS = P * (dP - (P * dP).sum(dim=-1).unsqueeze(-1))
    # dS = torch.zeros_like(dP)
    # for i in range(dP.shape[1]):
    #     for j in range(dP.shape[2]):
    #         dS[:, i, j] = P[:, i, j] * (dP[:, i, j] - (P[:,i] * dP[:,i]).sum(dim=1))
    dQ = tau * dS @ K
    dK = tau * dS.transpose(1,2) @ Q

    return dQ, dK, dV

#### Test

In [9]:
%%time

import math

torch.manual_seed(0)
# b, n, d = 16, 1024, 512
b, n, d = 8, 128, 32

Q = torch.randn(b, n, d, requires_grad=True)
K = torch.randn(b, n, d, requires_grad=True)
V = torch.randn(b, n, d, requires_grad=True)

O, P = attention(Q, K, V, tau=1/math.sqrt(d), is_causal = True)
O_flash, l_flash, m_flash = flash_attention(Q, K, V, tau = 1/math.sqrt(d), is_causal = True)
dO = O * 0.1

O.backward(dO)
with torch.no_grad():
    dQ, dK, dV = attention_grad(Q, K, V, P, dO, tau=1/math.sqrt(d))
    dQ_flash, dK_flash, dV_flash = flash_attention_grad(Q, K, V, O, dO, l_flash, m_flash, tau = 1/math.sqrt(d), is_causal = True)

print(f'--dQ--\n{dQ[0][0][:5]}\n{Q.grad[0][0][:5]}\n{dQ_flash[0][0][:5]}')
print(f'--dK--\n{dK[0][0][:5]}\n{K.grad[0][0][:5]}\n{dK_flash[0][0][:5]}')
print(f'--dV--\n{dV[0][0][:5]}\n{V.grad[0][0][:5]}\n{dV_flash[0][0][:5]}')
print(dQ.allclose(dQ_flash, atol = 1e-4), Q.grad.allclose(dQ_flash, atol = 1e-4))
print(dK.allclose(dK_flash, atol = 1e-4), K.grad.allclose(dK_flash, atol = 1e-4))
print(dV.allclose(dV_flash, atol = 1e-4), V.grad.allclose(dV_flash, atol = 1e-4))

with torch.no_grad():
    dQ_flash, dK_flash, dV_flash = flash_attention_grad(Q, K, V, O, dO, l_flash, m_flash, tau = 1/math.sqrt(d), is_causal = True, dropout_p = 0.1)
print(f'\n{dQ_flash[0][0][:5]}\n{dK_flash[0][0][:5]}\n{dV_flash[0][0][:5]}')

--dQ--
tensor([0., 0., 0., 0., 0.])
tensor([0., 0., 0., 0., 0.])
tensor([ 4.1220e-08, -2.5450e-08,  7.2401e-08,  1.1534e-09,  1.0810e-07])
--dK--
tensor([-0.0022,  0.0445,  0.0121,  0.0750, -0.0995])
tensor([-0.0022,  0.0445,  0.0121,  0.0750, -0.0995])
tensor([-0.0022,  0.0445,  0.0121,  0.0750, -0.0995])
--dV--
tensor([ 0.0896,  0.1473, -0.0259, -0.0246,  0.0578])
tensor([ 0.0896,  0.1473, -0.0259, -0.0246,  0.0578])
tensor([ 0.0896,  0.1473, -0.0259, -0.0246,  0.0578])
True True
True True
True True

tensor([-0.0443,  0.0274, -0.0779, -0.0012, -0.1163])
tensor([-0.1187,  0.0812,  0.0141,  0.0924, -0.0488])
tensor([ 0.0874,  0.1377, -0.0279,  0.0033,  0.0627])
CPU times: user 610 ms, sys: 13.5 ms, total: 623 ms
Wall time: 72.9 ms
