# Flash attention in python, numba and cuda

todo: registers calculation, show spills and run both versions

In [1]:
import numba
from numba.cuda import as_cuda_array as ca
import numpy as np
import math
import torch
import sys, os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

sys.path.insert(0, '../..')
from utils import load_cuda, cuda_begin, cdiv, get_sig

def cdiv(a,b):
    "Int ceiling division of `a` over `b`"
    return (a+b-1)//b

import os


In [2]:
# Test tensors
N_inp = 32
N_out = 32
d = 128
Q = torch.randn(N_out, d).contiguous()
K = torch.randn(N_inp, d).contiguous()
V = torch.randn(N_inp, d).contiguous()
Kc = K.to("cuda")
Qc = Q.to("cuda")
Vc = V.to("cuda")
scaling = 1.0 / math.sqrt(d)

# Get expected O
O_expected = torch.softmax(Q @ K.T * scaling, dim=-1) @ V
# Get expected L
S = (Q @ K.T) * scaling  # shape: (N_out, N_in)
max_per_row, _ = torch.max(S, dim=1, keepdim=True)  # shape: (N_out, 1)
exp_shifted = torch.exp(S - max_per_row)  # shape: (N_out, N_in)
L_expected = torch.sum(exp_shifted, dim=1)

def check_diff(O, L):
    print("Max absolute difference O: ", (O-O_expected).abs().max())
    print("Max absolute difference L: ", (L-L_expected).abs().max())


# O_expected2 = torch.nn.att

# Pure torch

In [3]:
def flash_attention_torch(Q, K, V, O, L, N_inp, N_out, d) -> None:
    """Forward algo from https://arxiv.org/pdf/2307.08691
    """

    B_c = min(16, N_inp)
    B_r = min(16, N_out)
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r

    scale_factor = 1 / math.sqrt(Q.size(-1))

    # Q and O L split into T_r; K, V in T_c blocks
    for i in range(T_r):
        Q_i = Q[i * B_r : (i + 1) * B_r]
        O_i = torch.zeros(B_r, d)
        L_i = torch.zeros(B_r, 1)
        m_i = torch.full((B_r, 1), -math.inf)
        last_m_i = m_i
        for j in range(T_c):
            K_j = K[j * B_c : (j + 1) * B_c]
            V_j = V[j * B_c : (j + 1) * B_c]
            S_i = scale_factor * (Q_i @ K_j.T)
            m_i = torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)
            P_i = torch.exp(S_i - m_i)
            L_i = torch.exp(last_m_i - m_i) * L_i + P_i.sum(dim=-1, keepdim=True)
            O_i = torch.exp(last_m_i - m_i) * O_i + P_i @ V_j
            last_m_i = m_i
        O_i = (1.0 / L_i) * O_i
        L_i = m_i + torch.log(L_i)
        O[i * B_r : (i + 1) * B_r] = O_i
        L[i * B_r : (i + 1) * B_r] = L_i

In [4]:
O = torch.zeros(N_out, d)
L = torch.zeros(N_out, 1)

flash_attention_torch(Q, K, V, O, L, N_inp, N_out, d)

check_diff(O, L)

Max absolute difference O:  tensor(2.3842e-07)
Max absolute difference L:  tensor(9.0901)


## Numba

In [5]:
@numba.cuda.jit
def attention_numba_spilling(Q, K, V, scale_factor: numba.float32, L, O):
    B_c = 16
    B_r = 16
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r
    inp_dtype = K.dtype
    tid_x = numba.cuda.threadIdx.x
    tid_y = numba.cuda.threadIdx.y

    Q_i = numba.cuda.shared.array((B_r, d), inp_dtype)
    K_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    V_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    S = numba.cuda.shared.array((B_r, B_c), inp_dtype)
    
    # These can be in registers
    l_i = numba.cuda.local.array((B_r,), inp_dtype)
    m_i = numba.cuda.local.array((B_r,), inp_dtype)
    O_i = numba.cuda.local.array((B_r, d), inp_dtype)

                 
    for i in range(T_r):
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                Q_i[ii, dd] = Q[ii + i * B_r, dd]
                O_i[ii, dd] = 0
            l_i[ii] = 0
            m_i[ii] = -math.inf
        numba.cuda.syncthreads()

        for j in range(T_c):
            for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    K_j[jj, dd] = K[jj + j * B_c, dd]
                    V_j[jj, dd] = V[jj + j * B_c, dd]

            # S[ii][jj] = scale_factor * (Q_i @ K_j.T)
            numba.cuda.syncthreads()
            for ii in range(tid_x, B_r, numba.cuda.blockDim.x):
                for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                    S_ij = 0
                    for dd in range(d):
                        S_ij += Q_i[ii, dd] * K_j[jj, dd]
                    S_ij = scale_factor * S_ij
                    S[ii][jj] = S_ij

            numba.cuda.syncthreads()
            for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
                # torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)
                # this needs to use the parallel reduction pattern
                m = m_i[ii]
                last_m = m
                for jj in range(B_c):
                    m = max(m, S[ii][jj])
                m_i[ii] = m
                l = math.exp(last_m - m) * l_i[ii]

                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    O_i[ii, dd] *= math.exp(last_m - m)
                for jj in range(B_c):
                    P_ij = math.exp(S[ii][jj] - m)  # Cache...
                    l += P_ij
                    for dd in range(tid_x, d, numba.cuda.blockDim.x):
                        O_i[ii, dd] += P_ij * V_j[jj, dd]
                l_i[ii] = l
        numba.cuda.syncthreads()
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):  
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                O[ii + i * B_r, dd] = O_i[ii, dd] / l_i[ii]
            L[ii + i * B_r] = l_i[ii]   
        numba.cuda.syncthreads() 
   


In [7]:
O_splilling = torch.zeros(N_out, d, device="cuda").contiguous()
L_splilling = torch.zeros(N_out, device="cuda")
tpb = (8, 16)
grid = (1,)
attention_numba_spilling[grid, tpb](Qc, Kc, Vc, scaling, L_splilling, O_splilling)
check_diff(O_splilling.cpu(), L_splilling.cpu())



Max absolute difference O:  tensor(5.6624e-07)
Max absolute difference L:  tensor(5.2452e-06)


In [9]:
block_dim_x = 32
block_dim_y = 4
B_r = 16
o_per_thread_x = d // block_dim_x
o_per_thread_y = B_r // block_dim_y

@numba.cuda.jit
def attention_numba_no_spilling(Q, K, V, scale_factor: numba.float32, L, O):
    B_c = 16
    
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r
    inp_dtype = K.dtype
    tid_x = numba.cuda.threadIdx.x
    tid_y = numba.cuda.threadIdx.y

    Q_i = numba.cuda.shared.array((B_r, d), inp_dtype)
    K_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    V_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    S = numba.cuda.shared.array((B_r, B_c), inp_dtype)
    
    # These can be in registers
    l_i = numba.cuda.local.array((o_per_thread_y,), inp_dtype)
    m_i = numba.cuda.local.array((o_per_thread_y,), inp_dtype)
    O_i = numba.cuda.local.array((o_per_thread_y, o_per_thread_x), inp_dtype)

                 
    for i in range(T_r):
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                Q_i[ii, dd] = Q[ii + i * B_r, dd]
                O_i[ii//block_dim_y, dd//block_dim_x] = 0
            l_i[ii//block_dim_y] = 0
            m_i[ii//block_dim_y] = -math.inf
        numba.cuda.syncthreads()

        for j in range(T_c):
            for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    K_j[jj, dd] = K[jj + j * B_c, dd]
                    V_j[jj, dd] = V[jj + j * B_c, dd]

            # S[ii][jj] = scale_factor * (Q_i @ K_j.T)
            numba.cuda.syncthreads()
            for ii in range(tid_x, B_r, numba.cuda.blockDim.x):
                for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                    S_ij = 0
                    for dd in range(d):
                        S_ij += Q_i[ii, dd] * K_j[jj, dd]
                    S_ij = scale_factor * S_ij
                    S[ii][jj] = S_ij

            numba.cuda.syncthreads()
            for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
                # torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)
                # this needs to use the parallel reduction pattern
                m = m_i[ii//block_dim_y]
                last_m = m
                for jj in range(B_c):
                    m = max(m, S[ii][jj])
                m_i[ii//block_dim_y] = m
                l = math.exp(last_m - m) * l_i[ii//block_dim_y]

                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    O_i[ii//block_dim_y, dd//block_dim_x] *= math.exp(last_m - m)
                for jj in range(B_c):
                    P_ij = math.exp(S[ii][jj] - m)  # Cache...
                    l += P_ij
                    for dd in range(tid_x, d, numba.cuda.blockDim.x):
                        O_i[ii//block_dim_y, dd//block_dim_x] += P_ij * V_j[jj, dd]
                l_i[ii//block_dim_y] = l
        numba.cuda.syncthreads()
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):  
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                O[ii + i * B_r, dd] = O_i[ii//block_dim_y, dd//block_dim_x] / l_i[ii//block_dim_y]
            L[ii + i * B_r] = l_i[ii//block_dim_y]   
        numba.cuda.syncthreads() 
   


In [12]:
O_no_spilling = torch.zeros(N_out, d, device="cuda").contiguous()
L_no_spilling = torch.zeros(N_out, device="cuda")
tpb = (block_dim_x, block_dim_y)
grid = (1,)
attention_numba_no_spilling[grid, tpb](Qc, Kc, Vc, scaling, L_no_spilling, O_no_spilling)
check_diff(O_no_spilling.cpu(), L_no_spilling.cpu())

Max absolute difference O:  tensor(5.6624e-07)
Max absolute difference L:  tensor(5.2452e-06)


## Use Cuda

### Register spilling

In [16]:
from pathlib import Path
# cuda_src_path = "./flash_attention.cu"
cuda_src_path = "./flash_attention.cu"
cuda_src = Path(cuda_src_path).read_text()
fname = 'flash_attention'
cpp_src = get_sig(fname, cuda_src)
module_cuda_spilling = load_cuda(cuda_src, cpp_src, [fname], verbose=True)


Using /home/sagemaker-user/.cache/torch_extensions/py312_cu126 as PyTorch extensions root...
No modifications detected for re-loaded extension module flash_attention_v1, skipping build step...
Loading extension module flash_attention_v1...


In [17]:
O_cuda_spilling, L_cuda_spilling = getattr(module_cuda_spilling, fname)(Kc, Qc, Vc)
check_diff(O_cuda_spilling.cpu(), L_cuda_spilling.cpu())

Max absolute difference O:  tensor(2.9802e-07)
Max absolute difference L:  tensor(1.9073e-06)


## Appendix

### Registers spilling

In [13]:
from pathlib import Path
cuda_src_path = "./flash_attention_spilling.cu"
cuda_src = Path(cuda_src_path).read_text()
fname = 'flash_attention_spilling'
cpp_src = get_sig(fname, cuda_src)
module_not_spilling = load_cuda(cuda_src, cpp_src, [fname], verbose=True)

Using /home/sagemaker-user/.cache/torch_extensions/py312_cu126 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/sagemaker-user/.cache/torch_extensions/py312_cu126/flash_attention/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module flash_attention...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] /opt/conda/bin/x86_64-conda-linux-gnu-c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=flash_attention -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -isystem /opt/conda/include -isystem /opt/conda/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.12/site-packages/torch/include -isystem /opt/conda/include -isystem /opt/conda/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -c /home/sagemaker-user/.cache/torch_extensions/py312_cu126/flash_attention/main.cpp -o main.o 
[2/3] /opt/conda/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /opt/conda/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=flash_attention -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -isystem /opt/conda/include -isystem /opt/conda/include/torch/csrc/api/include -isystem /o

Loading extension module flash_attention...


In [15]:
O4, L4 = getattr(module_not_spilling, fname)(Kc, Qc, Vc)
check_diff(O4.cpu(), L4.cpu())

Max absolute difference O:  tensor(1.8869)
Max absolute difference L:  tensor(11.8267)


### Official

In [None]:
from pathlib import Path
cuda_src_path = "./flash_attention_official.cu"
cuda_src = Path(cuda_src_path).read_text()
fname = 'flash_attention_official'
cpp_src = get_sig(fname, cuda_src)
module_not_spilling = load_cuda(cuda_src, cpp_src, [fname], verbose=True)

Using /home/sagemaker-user/.cache/torch_extensions/py312_cu126 as PyTorch extensions root...
The input conditions for extension module flash_attention_official have changed. Bumping to version 1 and re-building as flash_attention_official_v1...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/sagemaker-user/.cache/torch_extensions/py312_cu126/flash_attention_official/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module flash_attention_official_v1...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] /opt/conda/bin/x86_64-conda-linux-gnu-c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=flash_attention_official_v1 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -isystem /opt/conda/include -isystem /opt/conda/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.12/site-packages/torch/include -isystem /opt/conda/include -isystem /opt/conda/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -c /home/sagemaker-user/.cache/torch_extensions/py312_cu126/flash_attention_official/main.cpp -o main.o 
[2/3] /opt/conda/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /opt/conda/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=flash_attention_official_v1 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -isystem /opt/conda/include -isystem /opt/conda/include/t

Loading extension module flash_attention_official_v1...


In [None]:
O_official, L_official = getattr(module_not_spilling, fname)(Kc, Qc, Vc)
check_diff(O_official.cpu(), L_official.cpu())

Max absolute difference O:  tensor(0.9175)
Max absolute difference L:  tensor(11.3697)


## Get registers and shared memory info
To calculare block sizes

In [None]:
from numba import cuda
dev = cuda.get_current_device()
print(dir(dev))
print(dev.name)
print(dev.compute_capability)
print(dev.get_device_identity())
print(dev.get_primary_context())
print(dev.get_device_identity().get_info())


['COMPUTE_CAPABILITY_MAJOR', 'COMPUTE_CAPABILITY_MINOR', 'PCI_BUS_ID', 'PCI_DEVICE_ID', 'PCI_DOMAIN_ID', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'attributes', 'compute_capability', 'from_identity', 'get_device_identity', 'get_primary_context', 'id', 'name', 'primary_context', 'release_primary_context', 'reset', 'supports_float16', 'uuid']
b'NVIDIA A10G'
(8, 6)
{'pci_domain_id': 0, 'pci_bus_id': 0, 'pci_device_id': 30}
<CUDA context c_void_p(94262389834336) of device 0>


AttributeError: 'dict' object has no attribute 'get_info'

In [None]:
import torch
props = torch.cuda.get_device_properties(0)
print("Device:", props.name)
print("Max shared memory per block:", props.shared_memory_per_block, "bytes")
print("Max registers per block:", props.regs_per_block)
print("Warp size:", props.warp_size)


Device: NVIDIA A10G


AttributeError: 'torch._C._CudaDeviceProperties' object has no attribute 'shared_memory_per_block'

In [None]:
[mod for mod in dir(props)]

['L2_cache_size',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_pybind11_conduit_v1_',
 'gcnArchName',
 'is_integrated',
 'is_multi_gpu_board',
 'major',
 'max_threads_per_multi_processor',
 'minor',
 'multi_processor_count',
 'name',
 'regs_per_multiprocessor',
 'total_memory',
 'uuid',
 'warp_size']

## Appendix


In [None]:
@cuda.jit
def flash_attention_k_numba(Q, K, V, O, N_inp, N_out, d):
    # 1D block of size d for Q
    # K and V are not transponse
    # Block is B_r x d (while for K, V we want B_c x d)
    cbi, cbd, tid = cuda.blockIdx, cuda.blockDim, cuda.threadIdx
    tc = tid.x
    tr = tid.y
    r = cbi.y * cbd.y + tr
    c = cbi.x * cbd.x + tc

    B_r: int = 1
    T_r = math.ceil(N_out/B_r)
    B_c: int = 1
    T_c = math.ceil(N_inp/B_c)


    shar = cuda.shared.array(0, dtype=np.float32)
    Qs = shar[: B_r * d]
    Ks = shar[B_r * d: (B_r * d) + d * B_c]
    Vs = shar[(B_r * d) + d * B_c: (B_r * d) + (d * B_c)*2]
    Ss = shar[(B_r * d) + (d * B_c) * 2: (B_r * d) + (d * B_c)*2 + B_r * B_c]
    Ps = shar[(B_r * d) + (d * B_c) * 2 + B_r * B_c:  2*(B_r * d) + (d * B_c) * 2 + B_r * B_c]
    Os = shar[2 * (B_r * d) + (d * B_c) * 2 + B_r * B_c: 3 * (B_r * d) + (d * B_c) * 2 + B_r * B_c]


    L_i = torch.zeros(B_r, 1)
    m_i = torch.full((B_r, 1), -math.inf)
    last_m_i = torch.full((B_r, 1), -math.inf)

    # Load Qs on chip
    Qs[tr * cbd.x + tc ] = Q[r, c] if r < N and c < d else 0.
    cuda.syncthreads()

    for j in range(T_c):
        # Columns loaded contigously (column major/transpose)
        Ks[j*d + tc] = K[c, j * B_c + tc] if c < d and j * B_c + tc < d else 0.
        Vs[j*d + tc] = V[c, j * B_c + tc] if c < d and j * B_c + tc < d else 0.
        cuda.syncthreads()
        # Compute S_ij = Q_i K_j^T
        for _d in range(d): 
            Ss[tr*d + tc] += Qs[tr*d + _d] * Ks[j*d + _d]
        cuda.syncthreads()
        Ss_max = torch.full((B_r, 1), -math.inf)

        # todo must be by row
        for idx in range(B_r * B_c):  # or actual length of Ss
            if Ss[idx] > Ss_max:
                Ss_max = Ss[idx]
                m_i[Idx] = ...
        m_new = max(m_cur, Ss_max) 

        if j == tc:
            for idx in range(B_r * B_c):  # or the actual length of Ss
                Ps[idx] = math.exp(Ss[idx] - m_new)

        cuda.syncthreads()
        l_new = math.exp(m_cur - m_new) * l_cur
        for idx in range(B_r * B_c):
            l_new +=  Ps[idx]
        l_cur = l_new

        Os[c] = Os[c] * math.exp(m_cur - m_new)
        
        for _d in range(d): 
            Os[c] += Ps[_d] * Vs[_d]

        cuda.syncthreads()

  
    if r < N and c < d: O[r, c] = Os[c] / l_new

def flash_attention_numba(Q, K, V):
    N_out ,d  = Q.shape
    N_inp, kw = K.shape
    vr, vw = V.shape
    assert d==kw, "Size mismatch!"
    assert d==vw, "Size mismatch!"
    assert N_inp==vr, "Size mismatch!"
    O = torch.zeros(N_out, d, dtype=Q.dtype, device=Q.device)

    B_r: int = 1
    B_c: int = 1
    
    dyn_shared_mem_size =  3*(B_r * d) + (d * B_c) * 2 + (B_r * B_c)
    tpb = d, B_r
    blocks = cdiv(d,tpb[0]), cdiv(d,tpb[1])
    flash_attention_k_numba[blocks, tpb, 0, dyn_shared_mem_size](
        ca(Q), ca(K), ca(V), ca(O), N_inp, N_out, d
    ) 
    return O

N_out, N_inp, d = 2, 3, 4

Q = torch.rand(N_out, d).contiguous().cuda()
K = torch.rand(N_inp, d).contiguous().cuda()
V = torch.rand(N_inp, d).contiguous().cuda()


custom_fa = flash_attention_numba(Q, K, V)
torch_fa = torch.softmax(Q @ K.T, dim=-1) @ V

if not torch.isclose(custom_fa,  torch_fa ).all():
    print("Mismatch")
    print(f"\n{custom_fa}")
    print(f"\n{torch_fa}")



NotDefinedError: Failed in cuda mode pipeline (step: analyzing bytecode)
[1m[1mThe compiler failed to analyze the bytecode. Variable 'l_cur' is not defined.
[1m
File "../../../../../../tmp/ipykernel_5907/2646925363.py", line 58:[0m
[1m<source missing, REPL/exec in use?>[0m
[0m
[0m[1mDuring: Pass translate_bytecode[0m

In [None]:
N, d = 2, 3

Q = torch.rand(N, d).contiguous().cuda()
K = torch.rand(N, d).contiguous().cuda()
V = torch.rand(N, d).contiguous().cuda()

torch.softmax(Q @ K.T, dim=-1) @ V

tensor([[0.9416, 0.5529, 0.7670],
        [0.9408, 0.5372, 0.7598]], device='cuda:0')

# Matmul delete

In [None]:
@cuda.jit
def matmul_k_numba(m, n, out, tw):
    cbi,cbd,tid = cuda.blockIdx,cuda.blockDim,cuda.threadIdx
    tc,tr = tid.x,tid.y
    r,c = cbi.y * cbd.y + tr, cbi.x * cbd.x + tc
    h,k  = m.shape
    k2,w = n.shape

    shar = cuda.shared.array(0, dtype=np.float32)
    ms,ns = shar[:tw*tw],shar[tw*tw:2*tw*tw]

    p = np.float32(0.0)
    for ph in range(math.ceil(k/tw)):
        idx = ph*tw
        ms[tr*tw+tc] = m[r, tc+idx] if r<h and idx+tc<k else 0.
        ns[tr*tw+tc] = n[tr+idx, c] if c<w and idx+tr<k else 0.
        cuda.syncthreads()
        for i in range(tw): p += ms[tr*tw+i] * ns[i*tw+tc]
        cuda.syncthreads()
    if r < h and c < w: out[r, c] = p

In [None]:
def matmul_2d_numba(m, n, tw=16):
    h,k  = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    out = torch.zeros(h, w, dtype=m.dtype, device=m.device)
    dyn_shared_mem_size = 2 * tw * tw * 4
    tpb = tw,tw
    blocks = cdiv(w,tpb[0]), cdiv(h,tpb[1])
    matmul_k_numba[blocks, tpb, 0, dyn_shared_mem_size](ca(m), ca(n), ca(out), tw) 
    return out

In [None]:

N, L, M = 12, 34, 65

Q = torch.rand(N, L).contiguous().cuda()
K = torch.rand(L, M).contiguous().cuda()

torch.isclose(matmul_2d_numba(Q, K), Q@K).all()



tensor(True, device='cuda:0')

In [None]:
N, L, M = 12, 34, 65

Q = torch.rand(N, L).contiguous().cuda()

In [None]:
%%timeit -n 10
matmul_2d_numba(Q,K)
torch.cuda.synchronize()

262 μs ± 68.3 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
%%timeit -n 10
Q@K
torch.cuda.synchronize()

33.6 μs ± 20.9 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
