# Flash attention in python, numba and cuda

In [2]:
from numba import cuda
from numba.cuda import as_cuda_array as ca
import numpy as np
import math
import torch

def cdiv(a,b):
    "Int ceiling division of `a` over `b`"
    return (a+b-1)//b

import os


# Pure python

In [3]:
import torch, math

def flash_attention_numpy(Q, K, V, O, N_inp, N_out, d) -> None:
    """Forward algo from https://arxiv.org/pdf/2307.08691
    """

    B_c = min(16, N_inp)
    B_r = min(16, N_out)
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r

    scale_factor = 1 / math.sqrt(Q.size(-1))

    # Q and O L split into T_r; K, V in T_c blocks
    for i in range(T_r):
        Q_i = Q[i * B_r : (i + 1) * B_r]
        O_i = torch.zeros(B_r, d)
        L_i = torch.zeros(B_r, 1)
        m_i = torch.full((B_r, 1), -math.inf)
        last_m_i = m_i
        for j in range(T_c):
            K_j = K[j * B_c : (j + 1) * B_c]
            V_j = V[j * B_c : (j + 1) * B_c]
            S_i = scale_factor * (Q_i @ K_j.T)
            m_i = torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)
            P_i = torch.exp(S_i - m_i)
            L_i = torch.exp(last_m_i - m_i) * L_i + P_i.sum(dim=-1, keepdim=True)
            O_i = torch.exp(last_m_i - m_i) * O_i + P_i @ V_j
            last_m_i = m_i
        O_i = (1.0 / L_i) * O_i
        L_i = m_i + torch.log(L_i)
        O[i * B_r : (i + 1) * B_r] = O_i
        L[i * B_r : (i + 1) * B_r] = L_i

In [6]:
# Run numpy 
N_inp = 16
N_out = 16
d = 32
Q = torch.randn(N_out, d)
K = torch.randn(N_inp, d)
V = torch.randn(N_inp, d)
O = torch.zeros(N_out, d)
L = torch.zeros(N_out, 1)

flash_attention_numpy(Q, K, V, O, N_inp, N_out, d)

expected = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1) @ V

print("Max absolute difference: ", (O-expected).abs().max())

Max absolute difference:  tensor(3.5763e-07)


## Numba

In [57]:
import numba

@cuda.jit
def attention(Q, K, V, scale_factor: numba.float32, L, O):
    B_c = 16
    B_r = 16
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r
    inp_dtype = K.dtype
    tid_x = numba.cuda.threadIdx.x
    tid_y = numba.cuda.threadIdx.y

    Q_i = numba.cuda.shared.array((B_r, d), inp_dtype)
    K_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    V_j = numba.cuda.shared.array((B_c, d), inp_dtype)

    O_i = numba.cuda.shared.array((B_r, d), inp_dtype)
    l_i = numba.cuda.shared.array((B_r,), inp_dtype)
    m_i = numba.cuda.shared.array((B_r,), inp_dtype)
    S_i = numba.cuda.shared.array((B_c,), inp_dtype)

    for i in range(T_r):
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                Q_i[ii, dd] = Q[ii + i * B_r, dd]
                O_i[ii, dd] = 0
            l_i[ii] = 0
            m_i[ii] = -math.inf
        numba.cuda.syncthreads()

        for j in range(T_c):
            for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    K_j[jj, dd] = K[jj + j * B_c, dd]
                    V_j[jj, dd] = V[jj + j * B_c, dd]

            # S_i = scale_factor * (O_i @ K_j.T)
            for ii in range(tid_x, B_r, numba.cuda.blockDim.x):
                numba.cuda.syncthreads()
                for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                    S_ij = 0
                    for dd in range(d):
                        S_ij += Q_i[ii, dd] * K_j[jj, dd]
                    S_ij = scale_factor * S_ij
                    S_i[jj] = S_ij

                # torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)
                # this needs to use the parallel reduction pattern
                numba.cuda.syncthreads()
                m = m_i[ii]
                last_m = m
                for jj in range(B_c):
                    m = max(m, S_i[jj])
                m_i[ii] = m
                l = math.exp(last_m - m) * l_i[ii]

                for dd in range(d):
                    O_i[ii, dd] *= math.exp(last_m - m)
                for jj in range(B_c):
                    S_i[jj] = math.exp(S_i[jj] - m)  # Cache...
                    l += S_i[jj]
                    for dd in range(d):
                        O_i[ii, dd] += S_i[jj] * V_j[jj, dd]
                l_i[ii] = l
                
                for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
                    for dd in range(tid_x, d, numba.cuda.blockDim.x):
                        O[ii + i * B_r, dd] = O_i[ii, dd] / l_i[ii]
                    L[ii + i * B_r] = l_i[ii]    


In [58]:
N_inp = 64
N_out = 64
d = 128
Q = torch.randn(N_out, d)
K = torch.randn(N_inp, d)
V = torch.randn(N_inp, d)
O2 = torch.zeros(N_out, d, device="cuda")
L2 = torch.zeros(N_out, device="cuda")
Kc = K.to("cuda")
Qc = Q.to("cuda")
Vc = V.to("cuda")
scale_factor = 1/math.sqrt(d)
attention[(32, 4), (1,)](Qc, Kc, Vc, scale_factor, L2, O2)

In [59]:
expected = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1) @ V
print("Max absolute difference: ", (O2.cpu()-expected).abs().max())

Max absolute difference:  tensor(5.9605e-07)


In [78]:
@cuda.jit
def flash_attention_k_numba(Q, K, V, O, N_inp, N_out, d):
    # 1D block of size d for Q
    # K and V are not transponse
    # Block is B_r x d (while for K, V we want B_c x d)
    cbi, cbd, tid = cuda.blockIdx, cuda.blockDim, cuda.threadIdx
    tc = tid.x
    tr = tid.y
    r = cbi.y * cbd.y + tr
    c = cbi.x * cbd.x + tc

    B_r: int = 1
    T_r = math.ceil(N_out/B_r)
    B_c: int = 1
    T_c = math.ceil(N_inp/B_c)


    shar = cuda.shared.array(0, dtype=np.float32)
    Qs = shar[: B_r * d]
    Ks = shar[B_r * d: (B_r * d) + d * B_c]
    Vs = shar[(B_r * d) + d * B_c: (B_r * d) + (d * B_c)*2]
    Ss = shar[(B_r * d) + (d * B_c) * 2: (B_r * d) + (d * B_c)*2 + B_r * B_c]
    Ps = shar[(B_r * d) + (d * B_c) * 2 + B_r * B_c:  2*(B_r * d) + (d * B_c) * 2 + B_r * B_c]
    Os = shar[2 * (B_r * d) + (d * B_c) * 2 + B_r * B_c: 3 * (B_r * d) + (d * B_c) * 2 + B_r * B_c]


    L_i = torch.zeros(B_r, 1)
    m_i = torch.full((B_r, 1), -math.inf)
    last_m_i = torch.full((B_r, 1), -math.inf)

    # Load Qs on chip
    Qs[tr * cbd.x + tc ] = Q[r, c] if r < N and c < d else 0.
    cuda.syncthreads()

    for j in range(T_c):
        # Columns loaded contigously (column major/transpose)
        Ks[j*d + tc] = K[c, j * B_c + tc] if c < d and j * B_c + tc < d else 0.
        Vs[j*d + tc] = V[c, j * B_c + tc] if c < d and j * B_c + tc < d else 0.
        cuda.syncthreads()
        # Compute S_ij = Q_i K_j^T
        for _d in range(d): 
            Ss[tr*d + tc] += Qs[tr*d + _d] * Ks[j*d + _d]
        cuda.syncthreads()
        Ss_max = torch.full((B_r, 1), -math.inf)

        # todo must be by row
        for idx in range(B_r * B_c):  # or actual length of Ss
            if Ss[idx] > Ss_max:
                Ss_max = Ss[idx]
                m_i[Idx] = ...
        m_new = max(m_cur, Ss_max) 

        if j == tc:
            for idx in range(B_r * B_c):  # or the actual length of Ss
                Ps[idx] = math.exp(Ss[idx] - m_new)

        cuda.syncthreads()
        l_new = math.exp(m_cur - m_new) * l_cur
        for idx in range(B_r * B_c):
            l_new +=  Ps[idx]
        l_cur = l_new

        Os[c] = Os[c] * math.exp(m_cur - m_new)
        
        for _d in range(d): 
            Os[c] += Ps[_d] * Vs[_d]

        cuda.syncthreads()

  
    if r < N and c < d: O[r, c] = Os[c] / l_new

def flash_attention_numba(Q, K, V):
    N_out ,d  = Q.shape
    N_inp, kw = K.shape
    vr, vw = V.shape
    assert d==kw, "Size mismatch!"
    assert d==vw, "Size mismatch!"
    assert N_inp==vr, "Size mismatch!"
    O = torch.zeros(N_out, d, dtype=Q.dtype, device=Q.device)

    B_r: int = 1
    B_c: int = 1
    
    dyn_shared_mem_size =  3*(B_r * d) + (d * B_c) * 2 + (B_r * B_c)
    tpb = d, B_r
    blocks = cdiv(d,tpb[0]), cdiv(d,tpb[1])
    flash_attention_k_numba[blocks, tpb, 0, dyn_shared_mem_size](
        ca(Q), ca(K), ca(V), ca(O), N_inp, N_out, d
    ) 
    return O

N_out, N_inp, d = 2, 3, 4

Q = torch.rand(N_out, d).contiguous().cuda()
K = torch.rand(N_inp, d).contiguous().cuda()
V = torch.rand(N_inp, d).contiguous().cuda()


custom_fa = flash_attention_numba(Q, K, V)
torch_fa = torch.softmax(Q @ K.T, dim=-1) @ V

if not torch.isclose(custom_fa,  torch_fa ).all():
    print("Mismatch")
    print(f"\n{custom_fa}")
    print(f"\n{torch_fa}")



Mismatch

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')

tensor([[0.4583, 0.2336, 0.3260, 0.7613],
        [0.4511, 0.2400, 0.3714, 0.7489]], device='cuda:0')


In [31]:
N, d = 2, 3

Q = torch.rand(N, d).contiguous().cuda()
K = torch.rand(N, d).contiguous().cuda()
V = torch.rand(N, d).contiguous().cuda()

torch.softmax(Q @ K.T, dim=-1) @ V

tensor([[0.9416, 0.5529, 0.7670],
        [0.9408, 0.5372, 0.7598]], device='cuda:0')

# Matmul delete

In [None]:
@cuda.jit
def matmul_k_numba(m, n, out, tw):
    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tc,tr = tid.x,tid.y
    r,c = cbi.y * cbd.y + tr, cbi.x * cbd.x + tc
    h,k  = m.shape
    k2,w = n.shape

    shar = cuda.shared.array(0, dtype=np.float32)
    ms,ns = shar[:tw*tw],shar[tw*tw:2*tw*tw]

    p = np.float32(0.0)
    for ph in range(math.ceil(k/tw)):
        idx = ph*tw
        ms[tr*tw+tc] = m[r, tc+idx] if r<h and idx+tc<k else 0.
        ns[tr*tw+tc] = n[tr+idx, c] if c<w and idx+tr<k else 0.
        cuda.syncthreads()
        for i in range(tw): p += ms[tr*tw+i] * ns[i*tw+tc]
        cuda.syncthreads()
    if r < h and c < w: out[r, c] = p

In [3]:
def matmul_2d_numba(m, n, tw=16):
    h,k  = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    dyn_shared_mem_size = 2 * tw * tw * 4
    tpb = tw,tw
    blocks = cdiv(w,tpb[0]), cdiv(h,tpb[1])
    matmul_k_numba[blocks, tpb, 0, dyn_shared_mem_size](ca(m), ca(n), ca(out), tw) 
    return out

In [4]:

N, L, M = 12, 34, 65

Q = torch.rand(N, L).contiguous().cuda()
K = torch.rand(L, M).contiguous().cuda()

torch.isclose(matmul_2d_numba(Q, K), Q@K).all()



tensor(True, device='cuda:0')

In [5]:
N, L, M = 12, 34, 65

Q = torch.rand(N, L).contiguous().cuda()

In [6]:
%%timeit -n 10
matmul_2d_numba(Q,K)
torch.cuda.synchronize()

262 μs ± 68.3 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
%%timeit -n 10
Q@K
torch.cuda.synchronize()

33.6 μs ± 20.9 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
