In [1]:
from numba import cuda
from numba.cuda import as_cuda_array as ca
import numpy as np
import math
import torch

def cdiv(a,b):
    "Int ceiling division of `a` over `b`"
    return (a+b-1)//b

import os


# Pure python

In [73]:
import torch, math
# within one block.


def flash_attention_numpy(Q, K, V, O, N_inp, N_out, d) -> None:
    """Forward algo from https://arxiv.org/pdf/2307.08691
    """

    B_c = min(16, N_inp)
    B_r = min(16, N_out)
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r

    scale_factor = 1 / math.sqrt(Q.size(-1))

    # Q and O L split into T_r; K, V in T_c blocks
    for i in range(T_r):
        Q_i = Q[i * B_r : (i + 1) * B_r]
        O_i = torch.zeros(B_r, d)
        L_i = torch.zeros(B_r, 1)
        m_i = torch.full((B_r, 1), -math.inf)
        last_m_i = m_i
        for j in range(T_c):
            K_j = K[j * B_c : (j + 1) * B_c]
            V_j = V[j * B_c : (j + 1) * B_c]
            S_i = scale_factor * (Q_i @ K_j.T)
            m_i = torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)
            P_i = torch.exp(S_i - m_i)
            L_i = torch.exp(last_m_i - m_i) * L_i + P_i.sum(dim=-1, keepdim=True)
            O_i = torch.exp(last_m_i - m_i) * O_i + P_i @ V_j
            last_m_i = m_i
        O_i = (1.0 / L_i) * O_i
        L_i = m_i + torch.log(L_i)
        O[i * B_r : (i + 1) * B_r] = O_i
        L[i * B_r : (i + 1) * B_r] = L_i

In [76]:

N_inp = 16
N_out = 16
d = 32
Q = torch.randn(N_out, d)
K = torch.randn(N_inp, d)
V = torch.randn(N_inp, d)
O = torch.zeros(N_out, d)
L = torch.zeros(N_out, 1)

flash_attention_numpy(Q, K, V, O, N_inp, N_out, d)

# print(O) 
torch_fa = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1) @ V


if not bool(torch.isclose(O, torch_fa, atol=1e-6).all()):
    print("Mismatch")
    # print(f"\n{O}")
    # print(f"\n{torch_fa}")
print("Max absolute difference: ", (O-torch_fa).abs().max())

# expected = torch.nn.functional.scaled_dot_product_attention(Q[:, :], K[:, :], V[:, :])
# (O-expected).abs().max()

Max absolute difference:  tensor(3.5763e-07)


## Numba

In [8]:
# @cuda.jit
# def flash_attention_k_numba(Q, K, V, O, N, d):
#     # 1D block of size d for Q
#     # K and V are not transponse
#     # Block is B_r x d (while for K, V we want B_c x d)
#     cbi, cbd, tid = cuda.blockIdx, cuda.blockDim, cuda.threadIdx
#     tc = tid.x
#     tr = tid.y
#     r = cbi.y * cbd.y + tr
#     c = cbi.x * cbd.x + tc

#     B_r: int = 1
#     T_r = math.ceil(N/B_r)
#     B_c: int = 1
#     T_c = math.ceil(N/B_c)


#     shar = cuda.shared.array(0, dtype=np.float32)
#     Qs = shar[: d * B_r]
#     Ks = shar[d * B_r: (d * B_r) + d * B_c]
#     Vs = shar[(d * B_r) + d * B_c: (d * B_r) + (d * B_c)*2]
#     Ss = shar[(d * B_r) + (d * B_c)*2: (d * B_r) + (d * B_c)*2 + B_r * B_c]
#     Ps = shar[(d * B_r) + (d * B_c)*2 + B_r * B_c:  2*(d * B_r) + (d * B_c)*2 + B_r * B_c]
#     Os = shar[2*(d * B_r) + (d * B_c)*2 + B_r * B_c: 3*(d * B_r) + (d * B_c)*2 + B_r * B_c]

#     p = np.float32(0.0)
#     m_cur = np.float32(-np.inf)
#     m_new = np.float32(-np.inf)
#     l_cur = np.float32(0.0)
#     l_new = np.float32(0.0)

#     # Load Qs on chip
#     # TODO 1d asusmption here
#     Qs[tc] = Q[r, c] if r < N and c < d else 0.
#     cuda.syncthreads()

#     for j in range(T_c):
#         # TODO 1d asusmption here
#         Ks[tc] = K[c, j] if c < d and j < d else 0.
#         Vs[tc] = V[c, j] if c < d and j < d else 0.
#         cuda.syncthreads()
#         if j == tc:
#             # Compute Q_i K_i 
#             # TODO 1d asusmption here
#             for _d in range(d): Ss[0] += Qs[_d] * Ks[_d]
#         cuda.syncthreads()
#         Ss_max = -np.inf
#         for idx in range(B_r * B_c):  # or actual length of Ss
#             if Ss[idx] > Ss_max:
#                 Ss_max = Ss[idx]
#         m_new = max(m_cur, Ss_max) 

#         if j == tc:
#             for idx in range(B_r * B_c):  # or the actual length of Ss
#                 Ps[idx] = math.exp(Ss[idx] - m_new)

#         cuda.syncthreads()
#         l_new = math.exp(m_cur - m_new) * l_cur
#         for idx in range(B_r * B_c):
#             l_new +=  Ps[idx]
#         l_cur = l_new

#         Os[c] = Os[c] * math.exp(m_cur - m_new)
        
#         for _d in range(d): 
#             Os[c] += Ps[_d] * Vs[_d]

#         cuda.syncthreads()

  
#     if r < N and c < d: O[r, c] = Os[c] / l_new

In [78]:
@cuda.jit
def flash_attention_k_numba(Q, K, V, O, N_inp, N_out, d):
    # 1D block of size d for Q
    # K and V are not transponse
    # Block is B_r x d (while for K, V we want B_c x d)
    cbi, cbd, tid = cuda.blockIdx, cuda.blockDim, cuda.threadIdx
    tc = tid.x
    tr = tid.y
    r = cbi.y * cbd.y + tr
    c = cbi.x * cbd.x + tc

    B_r: int = 1
    T_r = math.ceil(N_out/B_r)
    B_c: int = 1
    T_c = math.ceil(N_inp/B_c)


    shar = cuda.shared.array(0, dtype=np.float32)
    Qs = shar[: d * B_r]
    Ks = shar[d * B_r: (d * B_r) + d * B_c]
    Vs = shar[(d * B_r) + d * B_c: (d * B_r) + (d * B_c)*2]
    Ss = shar[(d * B_r) + (d * B_c) * 2: (d * B_r) + (d * B_c)*2 + B_r * B_c]
    Ps = shar[(d * B_r) + (d * B_c) * 2 + B_r * B_c:  2*(d * B_r) + (d * B_c) * 2 + B_r * B_c]
    Os = shar[2*(d * B_r) + (d * B_c) * 2 + B_r * B_c: 3 * (d * B_r) + (d * B_c) * 2 + B_r * B_c]

    p = np.float32(0.0)
    m_cur = np.float32(-np.inf)
    m_new = np.float32(-np.inf)
    l_cur = np.float32(0.0)
    l_new = np.float32(0.0)

    # Load Qs on chip
    Qs[tr * cbd.x + tc ] = Q[r, c] if r < N and c < d else 0.
    cuda.syncthreads()

    for j in range(T_c):
        # TODO 1d asusmption here
        Ks[tc] = K[c, j*B_c + tc] if c < d and j*B_c + tc < d else 0.
        Vs[tc] = V[c, j*B_c + tc] if c < d and j*B_c + tc < d else 0.
        cuda.syncthreads()
        if j == tc:
            # Compute Q_i K_i 
            # TODO 1d asusmption here
            for _d in range(d): Ss[0] += Qs[_d] * Ks[_d]
        cuda.syncthreads()
        Ss_max = -np.inf
        for idx in range(B_r * B_c):  # or actual length of Ss
            if Ss[idx] > Ss_max:
                Ss_max = Ss[idx]
        m_new = max(m_cur, Ss_max) 

        if j == tc:
            for idx in range(B_r * B_c):  # or the actual length of Ss
                Ps[idx] = math.exp(Ss[idx] - m_new)

        cuda.syncthreads()
        l_new = math.exp(m_cur - m_new) * l_cur
        for idx in range(B_r * B_c):
            l_new +=  Ps[idx]
        l_cur = l_new

        Os[c] = Os[c] * math.exp(m_cur - m_new)
        
        for _d in range(d): 
            Os[c] += Ps[_d] * Vs[_d]

        cuda.syncthreads()

  
    if r < N and c < d: O[r, c] = Os[c] / l_new

def flash_attention_numba(Q, K, V):
    N_out ,d  = Q.shape
    N_inp, kw = K.shape
    vr, vw = V.shape
    assert d==kw, "Size mismatch!"
    assert d==vw, "Size mismatch!"
    assert N_inp==vr, "Size mismatch!"
    O = torch.zeros(N_out, d, dtype=Q.dtype, device=Q.device)

    B_r: int = 1
    B_c: int = 1
    
    dyn_shared_mem_size =  3*(d * B_r) + (d * B_c) * 2 + (B_r * B_c)
    tpb = d, B_r
    blocks = cdiv(d,tpb[0]), cdiv(d,tpb[1])
    flash_attention_k_numba[blocks, tpb, 0, dyn_shared_mem_size](
        ca(Q), ca(K), ca(V), ca(O), N_inp, N_out, d
    ) 
    return O

N_out, N_inp, d = 2, 3, 4

Q = torch.rand(N_out, d).contiguous().cuda()
K = torch.rand(N_inp, d).contiguous().cuda()
V = torch.rand(N_inp, d).contiguous().cuda()


custom_fa = flash_attention_numba(Q, K, V)
torch_fa = torch.softmax(Q @ K.T, dim=-1) @ V

if not torch.isclose(custom_fa,  torch_fa ).all():
    print("Mismatch")
    print(f"\n{custom_fa}")
    print(f"\n{torch_fa}")



Mismatch

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')

tensor([[0.4583, 0.2336, 0.3260, 0.7613],
        [0.4511, 0.2400, 0.3714, 0.7489]], device='cuda:0')


In [31]:
N, d = 2, 3

Q = torch.rand(N, d).contiguous().cuda()
K = torch.rand(N, d).contiguous().cuda()
V = torch.rand(N, d).contiguous().cuda()

torch.softmax(Q @ K.T, dim=-1) @ V

tensor([[0.9416, 0.5529, 0.7670],
        [0.9408, 0.5372, 0.7598]], device='cuda:0')

# Matmul delete

In [None]:
@cuda.jit
def matmul_k_numba(m, n, out, tw):
    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tc,tr = tid.x,tid.y
    r,c = cbi.y * cbd.y + tr, cbi.x * cbd.x + tc
    h,k  = m.shape
    k2,w = n.shape

    shar = cuda.shared.array(0, dtype=np.float32)
    ms,ns = shar[:tw*tw],shar[tw*tw:2*tw*tw]

    p = np.float32(0.0)
    for ph in range(math.ceil(k/tw)):
        idx = ph*tw
        ms[tr*tw+tc] = m[r, tc+idx] if r<h and idx+tc<k else 0.
        ns[tr*tw+tc] = n[tr+idx, c] if c<w and idx+tr<k else 0.
        cuda.syncthreads()
        for i in range(tw): p += ms[tr*tw+i] * ns[i*tw+tc]
        cuda.syncthreads()
    if r < h and c < w: out[r, c] = p

In [3]:
def matmul_2d_numba(m, n, tw=16):
    h,k  = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    dyn_shared_mem_size = 2 * tw * tw * 4
    tpb = tw,tw
    blocks = cdiv(w,tpb[0]), cdiv(h,tpb[1])
    matmul_k_numba[blocks, tpb, 0, dyn_shared_mem_size](ca(m), ca(n), ca(out), tw) 
    return out

In [4]:

N, L, M = 12, 34, 65

Q = torch.rand(N, L).contiguous().cuda()
K = torch.rand(L, M).contiguous().cuda()

torch.isclose(matmul_2d_numba(Q, K), Q@K).all()



tensor(True, device='cuda:0')

In [5]:
N, L, M = 12, 34, 65

Q = torch.rand(N, L).contiguous().cuda()

In [6]:
%%timeit -n 10
matmul_2d_numba(Q,K)
torch.cuda.synchronize()

262 μs ± 68.3 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
%%timeit -n 10
Q@K
torch.cuda.synchronize()

33.6 μs ± 20.9 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
