# Flash Attention in Torch, Numba and Cuda


We implement in 3 different ways the forward algorithm from the [Flash Attention 2 paper](https://arxiv.org/pdf/2307.08691)

We build the kernel for `d=128` and design it so that it computes the full attention in a single block.

![./flash_attention_fwd.png](./flash_attention_fwd.png)

todo: 
- registers calculation, show spills
- investigate why larger matrices slower with profiler? allegedly it becoems comput intensive as we use only 1 block, prove that?
- mention no thunder example
- fix logsumexp calculation

In [1]:
import numba
from numba.cuda import as_cuda_array as ca
from pathlib import Path
import numpy as np
import math
import torch
import sys, os
from torch.profiler import profile, ProfilerActivity
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

sys.path.insert(0, '../..')
from utils import load_cuda, get_sig

import os

def get_loaded_cuda_module(fname, verbose=False):
    cuda_src_path = f"./{fname}.cu"
    cuda_src = Path(cuda_src_path).read_text()
    cpp_src = get_sig(fname, cuda_src)
    return load_cuda(cuda_src, cpp_src, [fname], verbose=verbose)

def profile_kernel(module, fname, *args, **kwargs):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                with_stack=True, record_shapes=True) as prof:
        torch.cuda.synchronize()
        getattr(module, fname)(*args, **kwargs)
        torch.cuda.synchronize()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))


In [31]:
# Test tensors
def get_test_tensors(N_inp, N_out, d):
    Q = torch.randn(N_out, d).contiguous().to("cuda")
    K = torch.randn(N_inp, d).contiguous().to("cuda")
    V = torch.randn(N_inp, d).contiguous().to("cuda")
    scaling = 1.0 / math.sqrt(d)
    return Q, K, V, scaling

N_inp = 16
N_out = 16
d = 2

Q, K, V, scaling = get_test_tensors(N_inp, N_out, d)

# Get expected O
O_expected = torch.nn.functional.scaled_dot_product_attention(Q, K, V)
# Get expected L
S = (Q @ K.T) * scaling  # shape: (N_out, N_in)
max_per_row, _ = torch.max(S, dim=1, keepdim=True)  # shape: (N_out, 1)
exp_shifted = torch.exp(S - max_per_row)  # shape: (N_out, N_inp)
L_expected = torch.sum(exp_shifted, dim=-1)
L_expected2 = max_per_row + torch.log(L_expected)
L_expected3 = max_per_row + torch.logsumexp(S - max_per_row, dim=1)

assert (L_expected2-L_expected3).abs().max() < 5*1e-10

def check_diff(O, L=None, atol=5*1e-5):
    O_diff = (O-O_expected).abs().max()
    print("Max absolute difference:")
    if atol:
        assert O_diff < atol, f"O diff too large: {O_diff} > {atol=}"
    print("O: ", O_diff)
    if L is not None:
        L_diff = (L.squeeze()-L_expected).abs().max()
        if atol:
            assert L_diff < atol, f"L diff too large: {L_diff} > {atol=}"
        print("L: ", L_diff)


check_diff(O=torch.softmax(Q @ K.T * scaling, dim=-1) @ V)

Max absolute difference:
O:  tensor(7.4506e-08, device='cuda:0')


# Pure torch

In [32]:
def flash_attention_torch(Q, K, V, O, L, N_inp, N_out, d) -> None:
    """Forward algo from https://arxiv.org/pdf/2307.08691
    """

    B_c = 16
    B_r = 16
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r

    scaling = 1 / math.sqrt(d)

    # Q and O L split into T_r; K, V in T_c blocks
    for i in range(T_r):
        Q_i = Q[i * B_r : (i + 1) * B_r]
        O_i = torch.zeros(B_r, d)
        L_i = torch.zeros(B_r, 1)
        m_i = torch.full((B_r, 1), -math.inf)
        last_m_i = m_i
        for j in range(T_c):
            K_j = K[j * B_c : (j + 1) * B_c]
            V_j = V[j * B_c : (j + 1) * B_c]
            S_i = scaling * (Q_i @ K_j.T)
            m_i = torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)
            P_i = torch.exp(S_i - m_i)
            L_i = torch.exp(last_m_i - m_i) * L_i + P_i.sum(dim=-1, keepdim=True)
            O_i = torch.exp(last_m_i - m_i) * O_i + P_i @ V_j
            last_m_i = m_i
        O_i = (1.0 / L_i) * O_i
        O[i * B_r : (i + 1) * B_r] = O_i
        L[i * B_r : (i + 1) * B_r] = L_i

In [33]:
O_torch_loop = torch.zeros(N_out, d)
L_torch_loop = torch.zeros(N_out, 1)

flash_attention_torch(Q.to("cpu"), K.to("cpu"), V.to("cpu"), O_torch_loop, L_torch_loop, N_inp, N_out, d)

check_diff(
    O_torch_loop.to("cuda"), 
    L_torch_loop.to("cuda"),
)

Max absolute difference:
O:  tensor(2.3842e-07, device='cuda:0')
L:  tensor(1.4305e-06, device='cuda:0')


In [8]:
L_torch_loop, L_expected2

(tensor([[ 7.0635],
         [ 7.4613],
         [ 6.6346],
         [ 5.6546],
         [ 9.0227],
         [ 2.6117],
         [14.6637],
         [ 6.4889],
         [ 6.8281],
         [ 5.6799],
         [ 7.3275],
         [ 7.7920],
         [ 1.8350],
         [ 6.0733],
         [ 5.2760],
         [ 4.3893],
         [ 2.3933],
         [ 8.6296],
         [ 4.9935],
         [ 5.8907],
         [ 5.5498],
         [ 7.6121],
         [ 3.0597],
         [ 3.9153],
         [ 2.2110],
         [ 7.2019],
         [ 8.8995],
         [ 6.7409],
         [ 3.8753],
         [ 6.3121],
         [ 9.9120],
         [ 5.0796]]),
 tensor([[4.0131, 4.0679, 3.9505,  ..., 3.9006, 4.3519, 3.6834],
         [3.8243, 3.8791, 3.7617,  ..., 3.7118, 4.1631, 3.4946],
         [3.7784, 3.8332, 3.7157,  ..., 3.6659, 4.1172, 3.4487],
         ...,
         [4.4671, 4.5219, 4.4045,  ..., 4.3547, 4.8060, 4.1374],
         [3.5048, 3.5596, 3.4422,  ..., 3.3924, 3.8436, 3.1751],
         [4.1407, 4

## Numba

In [14]:
@numba.cuda.jit
def attention_numba_spilling(Q, K, V, scaling: numba.float32, L, O):
    B_c = 16
    B_r = 16
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r
    inp_dtype = K.dtype
    tid_x = numba.cuda.threadIdx.x
    tid_y = numba.cuda.threadIdx.y

    Q_i = numba.cuda.shared.array((B_r, d), inp_dtype)
    K_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    V_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    S = numba.cuda.shared.array((B_r, B_c), inp_dtype)
    
    # These can be in registers but wont fit too large
    # Also here they spill from SMEM
    l_i = numba.cuda.shared.array((B_r,), inp_dtype)
    m_i = numba.cuda.shared.array((B_r,), inp_dtype)
    O_i = numba.cuda.shared.array((B_r, d), inp_dtype)

                 
    for i in range(T_r):
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                Q_i[ii, dd] = Q[ii + i * B_r, dd]
                O_i[ii, dd] = 0
            l_i[ii] = 0
            m_i[ii] = -math.inf
        numba.cuda.syncthreads()

        for j in range(T_c):
            for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    K_j[jj, dd] = K[jj + j * B_c, dd]
                    V_j[jj, dd] = V[jj + j * B_c, dd]

            # S[ii][jj] = scaling * (Q_i @ K_j.T)
            numba.cuda.syncthreads()
            for ii in range(tid_x, B_r, numba.cuda.blockDim.x):
                for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                    S_ij = 0
                    for dd in range(d):
                        S_ij += Q_i[ii, dd] * K_j[jj, dd]
                    S_ij = scaling * S_ij
                    S[ii][jj] = S_ij

            numba.cuda.syncthreads()
            for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
                m = m_i[ii]
                last_m = m
                for jj in range(B_c):
                    m = max(m, S[ii][jj])
                m_i[ii] = m
                l = math.exp(last_m - m) * l_i[ii]

                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    O_i[ii, dd] *= math.exp(last_m - m)
                for jj in range(B_c):
                    P_ij = math.exp(S[ii][jj] - m)  # Cache...
                    l += P_ij
                    for dd in range(tid_x, d, numba.cuda.blockDim.x):
                        O_i[ii, dd] += P_ij * V_j[jj, dd]
                l_i[ii] = l
                
        numba.cuda.syncthreads()
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):  
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                O[ii + i * B_r, dd] = O_i[ii, dd] / l_i[ii]
            L[ii + i * B_r] = l_i[ii]   
        numba.cuda.syncthreads() 
   


In [15]:
O_splilling = torch.zeros(N_out, d, device="cuda").contiguous()
L_splilling = torch.zeros(N_out, device="cuda")
tpb = (8, 16)
grid = (1,)
attention_numba_spilling[grid, tpb](Q, K, V, scaling, L_splilling, O_splilling)
check_diff(O_splilling, L_splilling)



Max absolute difference:
O:  tensor(5.9605e-07, device='cuda:0')


AssertionError: L diff too large: 2.4368085861206055 > atol=5e-05

In [11]:
L_expected3 - (max_per_row+torch.log(L_splilling))

tensor([[-4.7684e-07, -4.7684e-07,  2.3842e-07,  ...,  2.3842e-07,
          0.0000e+00,  0.0000e+00],
        [-4.7684e-07, -2.3842e-07,  4.7684e-07,  ...,  2.3842e-07,
          4.7684e-07, -2.3842e-07],
        [-2.3842e-07, -4.7684e-07,  2.3842e-07,  ...,  4.7684e-07,
          4.7684e-07, -2.3842e-07],
        ...,
        [ 0.0000e+00,  0.0000e+00,  4.7684e-07,  ...,  4.7684e-07,
          4.7684e-07,  0.0000e+00],
        [-2.3842e-07, -2.3842e-07,  2.3842e-07,  ...,  2.3842e-07,
          2.3842e-07,  0.0000e+00],
        [-4.7684e-07,  0.0000e+00,  4.7684e-07,  ...,  4.7684e-07,
          4.7684e-07, -2.3842e-07]], device='cuda:0')

Let's make local threads arrays small as to make it more likely they will be in registers.

We have 1 block of size `(block_dim_x, block_dim_x)` Then for each ...

So we have 
```
B_r / block_dim_y * 2 + B_r * d / block_dim_x = 2 * 16 / 4 + 16 * 128 / 32 = 72
```
so 72 single pecision per thread -> (72 * 32 * 4) * 4bytes = 9216*4bytes= 36864 bytes = 36KB << 64KB register memory per SM.


If we do not shrink those we get
```
B_r * 2 + B_r *. d = 3 * 16 + 16* 128 = 2096 -> (2096 * 32 * 4) * 4 = 1MB > 64KB -> spill
```

Shared memory has enough to host 2096 * 4 = 8384 = 8KB extra? 
as already
Br*d + 2 Bc * d + bc*Br = 6400 -> 25KB < 48KB

But it;s slower than registers

In [None]:
2 * 16 / 4 + 16 * 128 / 32
72 * 32 * 4 * 4
36864/1024
3 * 16 + 16* 128
2096 * 32 * 4 * 4 /1024
2096 * 4
(16*128*2 + 16*128 + 16*16)*4

25600

In [None]:
# No spilling: make local arrays small
block_dim_x = 32
block_dim_y = 4
B_r = 16
o_per_thread_x = d // block_dim_x
o_per_thread_y = B_r // block_dim_y

@numba.cuda.jit
def attention_numba_no_spilling(Q, K, V, scaling: numba.float32, L, O):
    B_c = 16
    
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r
    inp_dtype = K.dtype
    tid_x = numba.cuda.threadIdx.x
    tid_y = numba.cuda.threadIdx.y

    Q_i = numba.cuda.shared.array((B_r, d), inp_dtype)
    K_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    V_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    S = numba.cuda.shared.array((B_r, B_c), inp_dtype)
    
    # These can be in registers
    l_i = numba.cuda.local.array((o_per_thread_y,), inp_dtype)
    m_i = numba.cuda.local.array((o_per_thread_y,), inp_dtype)
    O_i = numba.cuda.local.array((o_per_thread_y, o_per_thread_x), inp_dtype)

                 
    for i in range(T_r):
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                Q_i[ii, dd] = Q[ii + i * B_r, dd]
                O_i[ii//block_dim_y, dd//block_dim_x] = 0
            l_i[ii//block_dim_y] = 0
            m_i[ii//block_dim_y] = -math.inf
        numba.cuda.syncthreads()

        for j in range(T_c):
            for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    K_j[jj, dd] = K[jj + j * B_c, dd]
                    V_j[jj, dd] = V[jj + j * B_c, dd]

            # S[ii][jj] = scaling * (Q_i @ K_j.T)
            numba.cuda.syncthreads()
            for ii in range(tid_x, B_r, numba.cuda.blockDim.x):
                for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                    S_ij = 0
                    for dd in range(d):
                        S_ij += Q_i[ii, dd] * K_j[jj, dd]
                    S_ij = scaling * S_ij
                    S[ii][jj] = S_ij

            numba.cuda.syncthreads()
            for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
                # torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)
                # this needs to use the parallel reduction pattern
                m = m_i[ii//block_dim_y]
                last_m = m
                for jj in range(B_c):
                    m = max(m, S[ii][jj])
                m_i[ii//block_dim_y] = m
                l = math.exp(last_m - m) * l_i[ii//block_dim_y]

                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    O_i[ii//block_dim_y, dd//block_dim_x] *= math.exp(last_m - m)
                for jj in range(B_c):
                    P_ij = math.exp(S[ii][jj] - m)  # Cache...
                    l += P_ij
                    for dd in range(tid_x, d, numba.cuda.blockDim.x):
                        O_i[ii//block_dim_y, dd//block_dim_x] += P_ij * V_j[jj, dd]
                l_i[ii//block_dim_y] = l
        numba.cuda.syncthreads()
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):  
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                O[ii + i * B_r, dd] = O_i[ii//block_dim_y, dd//block_dim_x] / l_i[ii//block_dim_y]
            L[ii + i * B_r] = l_i[ii//block_dim_y]   
        numba.cuda.syncthreads() 
   


In [None]:
O_no_spilling = torch.zeros(N_out, d, device="cuda").contiguous()
L_no_spilling = torch.zeros(N_out, device="cuda")
tpb = (block_dim_x, block_dim_y)
grid = (1,)
attention_numba_no_spilling[grid, tpb](Q, K, V, scaling, L_no_spilling, O_no_spilling)
check_diff(O_no_spilling, L_no_spilling)



Max absolute difference:
O:  tensor(4.1723e-07, device='cuda:0')
L:  tensor(3.8147e-06, device='cuda:0')


## Use Cuda

### Register spilling

In [None]:
fname = "flash_attention"
module_cuda = get_loaded_cuda_module(fname)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [None]:
profile_kernel(module_cuda, fname, Q, K, V)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      flash_attention_k         0.00%       0.000us         0.00%       0.000us       0.000us      48.000us        94.76%      48.000us      48.000us             1  
                                            aten::zeros         0.81%      16.531us         6.33%     129.798us      64.899us       0.000us         0.00%       2.656us       1.328us             2  
         

In [None]:
O_cuda, L_cuda = getattr(module_cuda, fname)(Q, K, V)
check_diff(O_cuda, L_cuda)

Max absolute difference:
O:  tensor(0.0997, device='cuda:0')
L:  tensor(6.6757e-06, device='cuda:0')


## Appendix

### Registers spilling

In [None]:
fname_spill_from_registers = "flash_attention_spilling_from_registers"
module_cuda_spilling_from_registers = get_loaded_cuda_module(fname_spill_from_registers)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [None]:
fname_spill_from_smem = "flash_attention_spilling_from_smem"
module_cuda_spilling_from_smem = get_loaded_cuda_module(fname_spill_from_smem)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [None]:
profile_kernel(module_cuda_spilling_from_smem, fname_spill_from_smem, Q, K, V)

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      flash_attention_k         0.00%       0.000us         0.00%       0.000us       0.000us      46.560us        95.16%      46.560us      46.560us             1  
                                            aten::zeros         0.87%      18.211us         7.24%     151.370us      75.685us       0.000us         0.00%       2.367us       1.184us             2  
         

In [None]:
O_cuda_spilling, L_cuda_spilling = getattr(module_cuda_spilling_from_smem, fname_spill_from_smem)(Q, K, V)
check_diff(O_cuda_spilling, L_cuda_spilling)

Max absolute difference:
O:  tensor(5.9605e-07, device='cuda:0')
L:  tensor(6.6757e-06, device='cuda:0')


### Test performance on different dimensions

In [None]:
TEST_DIMS = [
    (32, 32),
    (64, 32),
    (128, 128),
    (512, 512),
]
for N_inp, N_out in TEST_DIMS:
    Q, K, V, _ = get_test_tensors(N_inp, N_out, d)
    print(f"\n\n**********\nDimensions: {N_out=}, {N_inp=}, {d=}")
    torch.cuda.synchronize()
    print("\n- Torch scaled_dot_product_attention")
    %timeit torch.nn.functional.scaled_dot_product_attention(Q, K, V); torch.cuda.synchronize()
    print("\n- Flash Attention")
    %timeit getattr(module_cuda, fname)(Q, K, V); torch.cuda.synchronize()
    # print("\n- Flash Attention: spill from SMEM")
    # %timeit getattr(module_cuda_spilling_from_smem, fname_spill_from_smem)(Q, K, V); torch.cuda.synchronize()
    print("\n- Flash Attention: spill from registers")
    %timeit getattr(module_cuda_spilling_from_registers, fname_spill_from_registers)(Q, K, V); torch.cuda.synchronize()



**********
Dimensions: N_out=32, N_inp=32, d=128

- Torch scaled_dot_product_attention


209 μs ± 2.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

- Flash Attention
111 μs ± 661 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

- Flash Attention: spill from SMEM
109 μs ± 865 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

- Flash Attention: spill from registers
350 μs ± 1.17 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


**********
Dimensions: N_out=32, N_inp=64, d=128

- Torch scaled_dot_product_attention
210 μs ± 3.23 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

- Flash Attention
155 μs ± 430 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

- Flash Attention: spill from SMEM
151 μs ± 609 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

- Flash Attention: spill from registers
628 μs ± 1.22 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


**********
Dimensions: N_out=128, N_inp=128, d=128

- Torch scaled_dot_product_attention
209 μs ± 1.98 μs per loop (mean ± std. dev. o

In [None]:
# Create program
# conda install cuda-python
from cuda import cuda, nvrtc

cuda_src_path = f"./{fname}.cu"
import re
cuda_src = Path(cuda_src_path).read_text()
cuda_src = re.sub(r'__host__.*', '', cuda_src, flags=re.DOTALL)
cuda_src = cuda_src.replace("#include <math_constants.h>",  "")

N_inp = 32
N_out = 32
d = 128
B_r, B_c = 16, 16
T_r = (N_out + B_r -1) // B_r
T_c = (N_inp + B_r -1) // B_c
Q, K, V, scale_factor = get_test_tensors(N_inp, N_out, d)

err, prog = nvrtc.nvrtcCreateProgram(str.encode(cuda_src), b"flash_attention.cu", 0, [], [])

# Compile program
min, maj = torch.cuda.get_device_capability()
opts = [f"--gpu-architecture=compute_{min}{maj}".encode()]
err, = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)

print(err)

# Get PTX from compilation
err, ptxSize = nvrtc.nvrtcGetPTXSize(prog)
ptx = b" " * ptxSize
err, = nvrtc.nvrtcGetPTX(prog, ptx)
print(err)

err, logSize = nvrtc.nvrtcGetProgramLogSize(prog)
log = b" " * logSize
err, = nvrtc.nvrtcGetProgramLog(prog, log)
print(log.decode())
# print(ptx.decode())

# Load PTX as module data and retrieve function
err, module = cuda.cuModuleLoadData(ptx)
print(err)
err, kernel = cuda.cuModuleGetFunction(module, b"flash_attention_k")
print(err, kernel)

# Allocate tensors
# S3 = torch.zeros(N_out, N_out, device="cuda")
O_cuda_py = torch.zeros(N_out, d, device="cuda")
L_cuda_py = torch.zeros(N_out, device="cuda")

# To quote the official tutorial: (https://nvidia.github.io/cuda-python/overview.html)
# The following code example is not intuitive
# Subject to change in a future release

int_args = torch.tensor([0, T_r, T_c], dtype=torch.int32)
float_args = torch.tensor([scale_factor], dtype=torch.float32)
ptr_args = torch.tensor([i.data_ptr() for i in (O_cuda_py, L_cuda_py, K, Q, V)], dtype=torch.uint64)

args = torch.tensor([
    *(i.data_ptr() for i in ptr_args),
    *(i.data_ptr() for i in float_args),
    *(i.data_ptr() for i in int_args)], dtype=torch.uint64)

args

0
0
 
0
0 <CUfunction 0x55ce7273cdc0>


tensor([94345151665600, 94345151665608, 94345151665616, 94345151665624,
        94345151665632, 94345172775232, 94345116918272, 94345116918276,
        94345116918280], dtype=torch.uint64)

In [None]:
O_expected = torch.nn.functional.scaled_dot_product_attention(Q, K, V)
def fn():
    err = cuda.cuLaunchKernel(
        kernel,
        1,  # grid x dim
        1,  # grid y dim
        1,  # grid z dim
        32,  # block x dim
        32,  # block y dim
        1,  # block z dim
        0,  # dynamic shared memory
        torch.cuda.current_stream().stream_id,  # stream
        args.data_ptr(),  # kernel arguments
        0,  # extra (ignore)
    )

fn()

(O_cuda_py - O_expected).abs().max()

tensor(5.3644e-07, device='cuda:0')

In [None]:
print(f"\n\n**********\nDimensions: {N_out=}, {N_inp=}, {d=}")
torch.cuda.synchronize()
print("\n- Torch scaled_dot_product_attention")
%timeit fn(); torch.cuda.synchronize()



**********
Dimensions: N_out=32, N_inp=32, d=128

- Torch scaled_dot_product_attention
69.6 μs ± 383 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


åppendix
TODO: for large matrices numerical error is too large

In [None]:
# Cuda setup