# Flash attention in Python, Numba and CUDA

**TODO:** calculate register usage, show register spills, and run both versions.

In [2]:
import numba
from numba.cuda import as_cuda_array as ca
import numpy as np
import math
import torch
import sys, os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

sys.path.insert(0, '../..')
from utils import load_cuda, cuda_begin, cdiv, get_sig

def cdiv(a, b):
    """Integer ceiling division of `a` by `b` (exact for a positive divisor `b`).

    Redefines the `cdiv` imported from `utils` in this cell.
    """
    # Adding b - 1 before flooring rounds any remainder up to the next multiple.
    numerator = a + b - 1
    return numerator // b

import os


In [3]:
# Test tensors
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the printed differences
N_inp = 32
N_out = 32
d = 128
Q = torch.randn(N_out, d).contiguous()
K = torch.randn(N_inp, d).contiguous()
V = torch.randn(N_inp, d).contiguous()
Kc = K.to("cuda")
Qc = Q.to("cuda")
Vc = V.to("cuda")
scaling = 1.0 / math.sqrt(d)

# Reference scaled scores, computed once and reused below
S = (Q @ K.T) * scaling  # shape: (N_out, N_inp)
# Expected O: row-wise softmax of the scores applied to V
O_expected = torch.softmax(S, dim=-1) @ V
# Expected L: per-row softmax denominator after subtracting the row max
# (this is what the kernel writes to L)
max_per_row, _ = torch.max(S, dim=1, keepdim=True)  # shape: (N_out, 1)
exp_shifted = torch.exp(S - max_per_row)  # shape: (N_out, N_inp)
L_expected = torch.sum(exp_shifted, dim=1)  # shape: (N_out,)

def check_diff(O, L):
    """Print the largest absolute error of `O` and `L` against the notebook references.

    Compares against the globals `O_expected` and `L_expected`; both arguments
    must be on the same device as those references (the callers pass `.cpu()`).
    """
    comparisons = (
        ("Max absolute difference O: ", O, O_expected),
        ("Max absolute difference L: ", L, L_expected),
    )
    for message, actual, reference in comparisons:
        print(message, (actual - reference).abs().max())


# Pure torch

## Numba

In [75]:
@numba.cuda.jit
def attention_numba(Q, K, V, scale_factor: numba.float32, L, O):
    """Flash-attention forward pass (tiled, online softmax) run by one thread block.

    Computes O = softmax(scale_factor * Q @ K.T) @ V row by row. For every
    query row it also stores in L the softmax denominator
    sum_j exp(s_j - max_j s_j), i.e. the final running sum ``l``.

    Args:
        Q: query matrix, indexed as Q[row, dd] for row < N_out, dd < d.
        K, V: key / value matrices, indexed as K[col, dd] for col < N_inp, dd < d.
        scale_factor: multiplier applied to the Q @ K.T scores.
        L: output vector indexed by query row (softmax denominators).
        O: output matrix indexed as O[row, dd] (normalised attention output).

    N_out, N_inp and d are notebook globals; numba treats globals as
    compile-time constants, so they are frozen when the kernel is compiled.
    blockIdx is never read, so every block of the grid would repeat the whole
    computation: launch a single block.

    Launch shape: each thread owns ROWS_PER_THREAD x COLS_PER_THREAD output
    accumulators. The block therefore needs blockDim.y * ROWS_PER_THREAD >= B_r
    and blockDim.x * COLS_PER_THREAD >= d, e.g. (32, 16) threads for d = 128.
    Smaller blocks leave parts of O and L unwritten.
    """
    B_c = 16  # key/value rows per tile
    B_r = 16  # query rows per tile
    ROWS_PER_THREAD = 1  # query rows accumulated by each thread
    COLS_PER_THREAD = 4  # output columns accumulated by each thread
    T_c = (N_inp + B_c - 1) // B_c  # number of key/value tiles
    T_r = (N_out + B_r - 1) // B_r  # number of query tiles
    inp_dtype = K.dtype
    tid_x = numba.cuda.threadIdx.x
    tid_y = numba.cuda.threadIdx.y
    stride_x = numba.cuda.blockDim.x
    stride_y = numba.cuda.blockDim.y

    # Tiles shared by the whole block.
    Q_i = numba.cuda.shared.array((B_r, d), inp_dtype)
    K_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    V_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    S = numba.cuda.shared.array((B_r, B_c), inp_dtype)

    # Per-thread running max (m_i), running sum (l_i) and output accumulator (O_i).
    # They are indexed only by fixed-trip-count loop counters (r, c). Whether
    # the compiler keeps them in registers is still to be confirmed with the
    # register/spill report.
    l_i = numba.cuda.local.array((ROWS_PER_THREAD,), inp_dtype)
    m_i = numba.cuda.local.array((ROWS_PER_THREAD,), inp_dtype)
    O_i = numba.cuda.local.array((ROWS_PER_THREAD, COLS_PER_THREAD), inp_dtype)

    for i in range(T_r):
        # Stage this query tile in shared memory; rows past N_out are zero-padded.
        for ii in range(tid_y, B_r, stride_y):
            row = ii + i * B_r
            for dd in range(tid_x, d, stride_x):
                if row < N_out:
                    Q_i[ii, dd] = Q[row, dd]
                else:
                    Q_i[ii, dd] = 0
        # Reset this thread's running statistics and output accumulator.
        for r in range(ROWS_PER_THREAD):
            l_i[r] = 0
            m_i[r] = -math.inf
            for c in range(COLS_PER_THREAD):
                O_i[r, c] = 0
        numba.cuda.syncthreads()

        for j in range(T_c):
            # Stage the key/value tile; rows past N_inp are zero-padded.
            for jj in range(tid_y, B_c, stride_y):
                col = jj + j * B_c
                for dd in range(tid_x, d, stride_x):
                    if col < N_inp:
                        K_j[jj, dd] = K[col, dd]
                        V_j[jj, dd] = V[col, dd]
                    else:
                        K_j[jj, dd] = 0
                        V_j[jj, dd] = 0
            numba.cuda.syncthreads()

            # S = scale_factor * (Q_i @ K_j.T). Padded key columns get -inf, so
            # they contribute exp(-inf) = 0 to the softmax below.
            for ii in range(tid_x, B_r, stride_x):
                for jj in range(tid_y, B_c, stride_y):
                    if jj + j * B_c < N_inp:
                        S_ij = 0
                        for dd in range(d):
                            S_ij += Q_i[ii, dd] * K_j[jj, dd]
                        S[ii, jj] = scale_factor * S_ij
                    else:
                        S[ii, jj] = -math.inf
            numba.cuda.syncthreads()

            # Online-softmax update of the rows/columns owned by this thread.
            for r in range(ROWS_PER_THREAD):
                ii = tid_y + r * stride_y
                if ii < B_r:
                    # Row max over this tile (a parallel reduction would be faster).
                    m_prev = m_i[r]
                    m = m_prev
                    for jj in range(B_c):
                        m = max(m, S[ii, jj])
                    m_i[r] = m
                    # Rescale earlier partial results to the new max (0 on the first tile).
                    alpha = math.exp(m_prev - m)
                    l_new = alpha * l_i[r]
                    for c in range(COLS_PER_THREAD):
                        O_i[r, c] *= alpha
                    for jj in range(B_c):
                        P_ij = math.exp(S[ii, jj] - m)
                        l_new += P_ij
                        for c in range(COLS_PER_THREAD):
                            dd = tid_x + c * stride_x
                            if dd < d:
                                O_i[r, c] += P_ij * V_j[jj, dd]
                    l_i[r] = l_new
            # The next tile overwrites V_j (and S), which other threads may still
            # be reading in the update above, so everyone must finish first.
            numba.cuda.syncthreads()

        # Normalise and write this query tile; padded rows past N_out are skipped.
        for r in range(ROWS_PER_THREAD):
            ii = tid_y + r * stride_y
            row = ii + i * B_r
            if ii < B_r and row < N_out:
                for c in range(COLS_PER_THREAD):
                    dd = tid_x + c * stride_x
                    if dd < d:
                        O[row, dd] = O_i[r, c] / l_i[r]
                if tid_x == 0:
                    # All x-threads hold the same l for this row; one store suffices.
                    L[row] = l_i[r]
        # Keep Q_i untouched until every thread is done with this query tile.
        numba.cuda.syncthreads()
   


In [76]:
O2 = torch.zeros(N_out, d, device="cuda").contiguous()
L2 = torch.zeros(N_out, device="cuda")
# The kernel's per-thread accumulators are sized for a (32, 16) block with
# d = 128; other shapes (e.g. a single thread) give wrong results.
tpb = (32, 16)
grid = (1,)  # the kernel ignores blockIdx, so one block does all the work
torch.cuda.synchronize()  # make sure the output buffers are allocated/zeroed
attention_numba[grid, tpb](Qc, Kc, Vc, scaling, L2, O2)
# Wait for the kernel explicitly instead of relying on CUDA_LAUNCH_BLOCKING.
torch.cuda.synchronize()

check_diff(O2.cpu(), L2.cpu())



Max absolute difference O:  tensor(5.3644e-07)
Max absolute difference L:  tensor(7.6294e-06)


In [74]:
# NOTE(review): this cell's execution count (74) is lower than the kernel
# definition (75) and launch (76) above, so the output below was produced by
# an earlier run, not by the current kernel. Re-run it after In [76].
check_diff(O2.cpu(), L2.cpu())

Max absolute difference O:  tensor(1.0513)
Max absolute difference L:  tensor(7.6294e-06)


### Registers not spilling