# Flash Attention in Torch, Numba and Cuda


We implement in 3 different ways the forward algorithm from the [Flash Attention 2 paper](https://arxiv.org/pdf/2307.08691):

1. Torch operations
2. Numba
3. Cuda

We do some basic performance analysis as well as running the custom kernel with cuda-python and thunder.


- We build the kernel for `d=128` and design it so that it computes the full attention in a single block.

![./flash_attention_fwd.png](./flash_attention_fwd.png)


## Utils

In [23]:
import numba
from numba.cuda import as_cuda_array as ca
from pathlib import Path
import numpy as np
import math
import torch
import sys, os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

sys.path.insert(0, "..")
from utils import load_cuda, get_sig, print_cuda_info

import os

TEST_DIMS = [
    (32, 32),
    (128, 64),
    (512, 512),
    (512, 1024),
]

def get_loaded_cuda_module(fname, verbose=False):
    cuda_src_path = f"./{fname}.cu"
    torch_src_path = f"./torch_extension_template.cu"
    cuda_src = Path(cuda_src_path).read_text()
    cuda_src += Path(torch_src_path).read_text()
    cuda_src = cuda_src.replace("your_function_name", fname)
    cpp_src = get_sig(fname, cuda_src)
    return load_cuda(cuda_src, cpp_src, [fname], verbose=verbose)


def check_close(O, O_expected, L=None, L_expected=None, atol=5*1e-5):
    O_diff = (O-O_expected).abs().max()
    print("Max absolute difference:")
    if atol:
        assert O_diff < atol, f"O diff too large: {O_diff} > {atol=}"
    print("O: ", O_diff)
    if L is not None:
        L_diff = (L.squeeze()-L_expected).abs().max()
        if atol:
            assert L_diff < atol, f"L diff too large: {L_diff} > {atol=}"
        print("L: ", L_diff)


In [24]:
# Test tensors
def get_test_tensors(N_inp, N_out, d):
    Q = torch.randn(N_out, d).contiguous().to("cuda")
    K = torch.randn(N_inp, d).contiguous().to("cuda")
    V = torch.randn(N_inp, d).contiguous().to("cuda")
    scaling = 1.0 / math.sqrt(d)

    # Get expected O
    O_expected = torch.nn.functional.scaled_dot_product_attention(Q, K, V)
    S = (Q @ K.T) * scaling  # shape: (N_out, N_inp)
    L_expected = torch.logsumexp(S, dim=-1)
    return Q, K, V, scaling, O_expected, L_expected

N_inp = 512
N_out = 512
d = 128

Q, K, V, scaling, O_expected, L_expected = get_test_tensors(N_inp, N_out, d)
check_close(O=torch.softmax(Q @ K.T * scaling, dim=-1) @ V, O_expected=O_expected)



Max absolute difference:
O:  tensor(5.6624e-07, device='cuda:0')


## Torch

In [25]:
def flash_attention_torch(Q, K, V, O, L, N_inp, N_out, d) -> None:
    """Forward algo from https://arxiv.org/pdf/2307.08691
    """

    B_c = 16
    B_r = 16
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r

    scaling = 1 / math.sqrt(d)

    # Q and O L split into T_r; K, V in T_c blocks
    for i in range(T_r):
        Q_i = Q[i * B_r : (i + 1) * B_r]
        O_i = torch.zeros(B_r, d)
        L_i = torch.zeros(B_r, 1)
        m_i = torch.full((B_r, 1), -math.inf)
        last_m_i = m_i
        for j in range(T_c):
            K_j = K[j * B_c : (j + 1) * B_c]
            V_j = V[j * B_c : (j + 1) * B_c]
            S_i = scaling * (Q_i @ K_j.T)
            m_i = torch.maximum(m_i, S_i.max(dim=-1, keepdim=True).values)
            P_i = torch.exp(S_i - m_i)
            L_i = torch.exp(last_m_i - m_i) * L_i + P_i.sum(dim=-1, keepdim=True)
            O_i = torch.exp(last_m_i - m_i) * O_i + P_i @ V_j
            last_m_i = m_i
        O_i = (1.0 / L_i) * O_i
        L_i = m_i + torch.log(L_i)
        O[i * B_r : (i + 1) * B_r] = O_i
        L[i * B_r : (i + 1) * B_r] = L_i

In [26]:
O_torch_loop = torch.zeros(N_out, d)
L_torch_loop = torch.zeros(N_out, 1)

flash_attention_torch(Q.to("cpu"), K.to("cpu"), V.to("cpu"), O_torch_loop, L_torch_loop, N_inp, N_out, d)

check_close(
    O_torch_loop.to("cuda"), 
    O_expected,
    L_torch_loop.to("cuda"),
    L_expected
)

Max absolute difference:
O:  tensor(5.9605e-07, device='cuda:0')
L:  tensor(9.5367e-07, device='cuda:0')


## Numba

Tiling strategy: each thread computes one value in

### All arrays in shared memory

In [27]:
@numba.cuda.jit
def flash_attention_numba_all_smem(Q, K, V, scaling: numba.float32, L, O, N_out, N_inp):
    B_c = 16
    B_r = 16
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r
    inp_dtype = K.dtype
    tid_x = numba.cuda.threadIdx.x
    tid_y = numba.cuda.threadIdx.y

    Q_i = numba.cuda.shared.array((B_r, d), inp_dtype)
    K_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    V_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    S = numba.cuda.shared.array((B_r, B_c), inp_dtype)
    # These can be in registers but wont fit too large
    l_i = numba.cuda.shared.array((B_r,), inp_dtype)
    m_i = numba.cuda.shared.array((B_r,), inp_dtype)
    O_i = numba.cuda.shared.array((B_r, d), inp_dtype)

                 
    for i in range(T_r):
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                Q_i[ii, dd] = Q[ii + i * B_r, dd]
                O_i[ii, dd] = 0
            l_i[ii] = 0
            m_i[ii] = -math.inf
        numba.cuda.syncthreads()

        for j in range(T_c):
            for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    K_j[jj, dd] = K[jj + j * B_c, dd]
                    V_j[jj, dd] = V[jj + j * B_c, dd]

            # S[ii][jj] = scaling * (Q_i @ K_j.T)
            numba.cuda.syncthreads()
            for ii in range(tid_x, B_r, numba.cuda.blockDim.x):
                for jj in range(tid_y, B_c, numba.cuda.blockDim.y):
                    S_ij = 0
                    for dd in range(d):
                        S_ij += Q_i[ii, dd] * K_j[jj, dd]
                    S_ij = scaling * S_ij
                    S[ii][jj] = S_ij

            numba.cuda.syncthreads()
            for ii in range(tid_y, B_r, numba.cuda.blockDim.y):
                m = m_i[ii]
                last_m = m
                for jj in range(B_c):
                    m = max(m, S[ii][jj])
                m_i[ii] = m
                l = math.exp(last_m - m) * l_i[ii]

                for dd in range(tid_x, d, numba.cuda.blockDim.x):
                    O_i[ii, dd] *= math.exp(last_m - m)
                for jj in range(B_c):
                    P_ij = math.exp(S[ii][jj] - m)  # Cache...
                    l += P_ij
                    for dd in range(tid_x, d, numba.cuda.blockDim.x):
                        O_i[ii, dd] += P_ij * V_j[jj, dd]
                l_i[ii] = l
                
        numba.cuda.syncthreads()
        for ii in range(tid_y, B_r, numba.cuda.blockDim.y):  
            for dd in range(tid_x, d, numba.cuda.blockDim.x):
                O[ii + i * B_r, dd] = O_i[ii, dd] / l_i[ii]
            L[ii + i * B_r] = m_i[ii] + math.log(l_i[ii])
        numba.cuda.syncthreads() 
   


In [28]:

block_dim_x = 32
block_dim_y = 16

for N_inp, N_out in TEST_DIMS:

    Q, K, V, scaling, O_expected, L_expected = get_test_tensors(N_inp, N_out, d)
    O_all_smem = torch.zeros(N_out, d, device="cuda").contiguous()
    L_all_smem = torch.zeros(N_out, device="cuda")
    tpb = (block_dim_x, block_dim_y)
    grid = (1,)
    flash_attention_numba_all_smem[grid, tpb](Q, K, V, scaling, L_all_smem, O_all_smem,  N_out, N_inp)
    torch.cuda.synchronize()
    check_close(
        O_all_smem, 
        O_expected,
        L_all_smem,
        L_expected,
    )



Max absolute difference:
O:  tensor(4.7684e-07, device='cuda:0')
L:  tensor(4.7684e-07, device='cuda:0')
Max absolute difference:
O:  tensor(3.8743e-07, device='cuda:0')
L:  tensor(9.5367e-07, device='cuda:0')
Max absolute difference:
O:  tensor(7.7486e-07, device='cuda:0')
L:  tensor(1.4305e-06, device='cuda:0')
Max absolute difference:
O:  tensor(6.5565e-07, device='cuda:0')
L:  tensor(1.4305e-06, device='cuda:0')


### Moving `m_i`, `l_i`, `O_i` to local arrays

Current shared-memory usage across threads:
```
Shar = (B_r * d * 4) # Q_i
+ (B_c * d * 4) # K_j
+ (B_c * d * 4) # V_j
+ (B_r * B_c * 4) # S
= ~25 KB
```

Current block-shared accumulators (`m_i`, `l_i`, `O_i`):
```
Loc = 4 * (B_r + B_r + (B_r * d)) = 8320 B ≈ 8 KB
```

Total shared usage: **~33 KB** (fine for 1 block/SM).

---

**Idea:** Move `m_i`, `l_i`, `O_i` to *per-thread* locals to fit in registers.

Problem: Full-size per-thread arrays would need

```
Loc * 32 * 16 ≈ 4 MB > 64 KB register file per SM
```

---

**Optimization:** With tiling `tpb = (32, 16)`:

- Each thread handles only  
  `d // blockDim.x = 4` columns in `x`  
  `B_r // blockDim.y = 1` row in `y`
- So per-thread locals can be much smaller:

```python
l_i = numba.cuda.local.array((1,), inp_dtype)   # 4 B
m_i = numba.cuda.local.array((1,), inp_dtype)   # 4 B
O_i = numba.cuda.local.array((4,), inp_dtype)   # 16 B
```
-> Per-thread = 24 B, per block = 24 * 32 * 16 = 12 KB < 64 KB -> avoid register pressure and spills.


This is how we set up `flash_attention_numba` below and the cuda version in `./flash_attention.cu`

In the performance section we run `./flash_attention_spilling_from_registers.cu` that fits full arrays in local variables, to show the performance decrease by slowing the kernel by ~2.5×

In [29]:
block_dim_x = 32
block_dim_y = 16
B_r = 16
B_c = 16
d_over_dim_x = d // block_dim_x
B_r_over_dim_y = B_r // block_dim_y

@numba.cuda.jit
def flash_attention_numba(Q, K, V, scaling: numba.float32, L, O, N_out, N_inp):
    B_c = 16
    B_r = 16
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r
    inp_dtype = K.dtype
    tid_x = numba.cuda.threadIdx.x
    tid_y = numba.cuda.threadIdx.y
    dim_y = numba.cuda.blockDim.y
    dim_x = numba.cuda.blockDim.x
    

    Q_i = numba.cuda.shared.array((B_r, d), inp_dtype)
    K_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    V_j = numba.cuda.shared.array((B_c, d), inp_dtype)
    S = numba.cuda.shared.array((B_r, B_c), inp_dtype)

    l_i = numba.cuda.local.array((B_r_over_dim_y,), inp_dtype)
    m_i = numba.cuda.local.array((B_r_over_dim_y,), inp_dtype)
    O_i = numba.cuda.local.array((B_r_over_dim_y, d_over_dim_x), inp_dtype)

                 
    for i in range(T_r):
        for ii in range(tid_y, B_r, dim_y):
            for dd in range(tid_x, d, dim_x):
                Q_i[ii, dd] = Q[ii + i * B_r, dd]
        numba.cuda.syncthreads()

        for ii in range(B_r_over_dim_y):
            for dd in range(d_over_dim_x):
                O_i[ii, dd] = 0
            l_i[ii] = 0
            m_i[ii] = -math.inf
        numba.cuda.syncthreads()

        for j in range(T_c):
            for jj in range(tid_y, B_c, dim_y):
                for dd in range(tid_x, d, dim_x):
                    K_j[jj, dd] = K[jj + j * B_c, dd]
                    V_j[jj, dd] = V[jj + j * B_c, dd]

            # S[ii][jj] = scaling * (Q_i @ K_j.T)
            numba.cuda.syncthreads()
            for ii in range(tid_x, B_r, dim_x):
                for jj in range(tid_y, B_c, dim_y):
                    S_ij = 0
                    for dd in range(d):
                        S_ij += Q_i[ii, dd] * K_j[jj, dd]
                    S_ij = scaling * S_ij
                    S[ii][jj] = S_ij

            numba.cuda.syncthreads()
            for ii in range(B_r_over_dim_y):
                m = m_i[ii]
                last_m = m
                for jj in range(B_c):
                    m = max(m, S[ii * dim_y + tid_y][jj])
                m_i[ii] = m
                l = numba.float32(math.exp(last_m - m)) * l_i[ii]

                for dd in range(d_over_dim_x):
                    O_i[ii, dd] *= numba.float32(math.exp(last_m - m))
                for jj in range(B_c):
                    P_ij = numba.float32(math.exp(S[ii * dim_y + tid_y][jj] - m))
                    l += P_ij
                    for dd in range(d_over_dim_x):
                        O_i[ii, dd] += P_ij * V_j[jj, dd * dim_x + tid_x]
                l_i[ii] = l
                
        numba.cuda.syncthreads()
        for ii in range(B_r_over_dim_y):  
            for dd in range(d_over_dim_x):
                O[ii * dim_y + tid_y + i * B_r, dd * dim_x + tid_x] = O_i[ii, dd] / l_i[ii]
            L[ii * dim_y + tid_y + i * B_r] = m_i[ii] + numba.float32(math.log(l_i[ii]))
        numba.cuda.syncthreads() 

In [30]:
for N_inp, N_out in TEST_DIMS:
    Q, K, V, scaling, O_expected, L_expected = get_test_tensors(N_inp, N_out, d)

    O_all_smem = torch.zeros(N_out, d, device="cuda").contiguous()
    L_all_smem = torch.zeros(N_out, device="cuda")
    tpb = (block_dim_x, block_dim_y)
    grid = (1,)
    flash_attention_numba[grid, tpb](Q, K, V, scaling, L_all_smem, O_all_smem,  N_out, N_inp)
    check_close(
        O_all_smem, 
        O_expected,
        L_all_smem,
        L_expected,
    )



Max absolute difference:
O:  tensor(4.1723e-07, device='cuda:0')
L:  tensor(4.7684e-07, device='cuda:0')
Max absolute difference:
O:  tensor(2.9802e-07, device='cuda:0')
L:  tensor(4.7684e-07, device='cuda:0')
Max absolute difference:
O:  tensor(4.1723e-07, device='cuda:0')
L:  tensor(9.5367e-07, device='cuda:0')
Max absolute difference:
O:  tensor(1.1921e-06, device='cuda:0')
L:  tensor(1.4305e-06, device='cuda:0')


## Cuda

### flash_attention_numba in Cuda

In [31]:
fname = "flash_attention"
module_cuda = get_loaded_cuda_module(fname, verbose=True)

Using /home/zeus/.cache/torch_extensions/py310_cu128 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/zeus/.cache/torch_extensions/py310_cu128/flash_attentiontest/build.ninja...
Building extension module flash_attentiontest...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module flash_attentiontest...


In [32]:
for N_inp, N_out in TEST_DIMS:

    Q, K, V, _, O_expected, L_expected = get_test_tensors(N_inp, N_out, d)
    O_move_registers, L_move_registers = getattr(module_cuda, fname)(Q, K, V)
    check_close(
        O_move_registers, 
        O_expected,
        L_move_registers,
        L_expected,
    )

Max absolute difference:
O:  tensor(8.9407e-07, device='cuda:0')
L:  tensor(4.7684e-07, device='cuda:0')
Max absolute difference:
O:  tensor(1.0133e-06, device='cuda:0')
L:  tensor(9.5367e-07, device='cuda:0')
Max absolute difference:
O:  tensor(1.4007e-06, device='cuda:0')
L:  tensor(9.5367e-07, device='cuda:0')
Max absolute difference:
O:  tensor(7.3016e-07, device='cuda:0')
L:  tensor(1.4305e-06, device='cuda:0')


## Cuda-Python

In [33]:
# Create program
# conda install conda-forge::cuda-python
from cuda import cuda, nvrtc

cuda_src_path = f"./flash_attention.cu"
cuda_src = Path(cuda_src_path).read_text()

N_inp = 32
N_out = 32
d = 128
B_r, B_c = 16, 16
T_r = (N_out + B_r -1) // B_r
T_c = (N_inp + B_r -1) // B_c
Q, K, V, scale_factor, O_expected, L_expected = get_test_tensors(N_inp, N_out, d)

err, prog = nvrtc.nvrtcCreateProgram(str.encode(cuda_src), b"flash_attention.cu", 0, [], [])

# Compile program
min, maj = torch.cuda.get_device_capability()
opts = [
    f"--gpu-architecture=compute_{min}{maj}".encode(), 
    "--device-as-default-execution-space".encode(),
    "--std=c++14".encode()]
err, = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)

print(err)

# Get PTX from compilation
err, ptxSize = nvrtc.nvrtcGetPTXSize(prog)
ptx = b" " * ptxSize
err, = nvrtc.nvrtcGetPTX(prog, ptx)
print(err)

err, logSize = nvrtc.nvrtcGetProgramLogSize(prog)
log = b" " * logSize
err, = nvrtc.nvrtcGetProgramLog(prog, log)
print(log.decode())
# print(ptx.decode())

# Load PTX as module data and retrieve function
err, module = cuda.cuModuleLoadData(ptx)
print(err)
err, kernel = cuda.cuModuleGetFunction(module, b"flash_attention_k")
print(err, kernel)

# Allocate tensors
# S3 = torch.zeros(N_out, N_out, device="cuda")
O_cuda_py = torch.zeros(N_out, d, device="cuda")
L_cuda_py = torch.zeros(N_out, device="cuda")

# To quote the official tutorial: (https://nvidia.github.io/cuda-python/overview.html)
# The following code example is not intuitive
# Subject to change in a future release

int_args = torch.tensor([0, T_r, T_c], dtype=torch.int32)
float_args = torch.tensor([scale_factor], dtype=torch.float32)
ptr_args = torch.tensor([i.data_ptr() for i in (O_cuda_py, L_cuda_py, Q, K, V)], dtype=torch.uint64)

args = torch.tensor([
    *(i.data_ptr() for i in ptr_args),
    *(i.data_ptr() for i in float_args),
    *(i.data_ptr() for i in int_args)], dtype=torch.uint64)

args

nvrtcResult.NVRTC_SUCCESS
nvrtcResult.NVRTC_SUCCESS
 
CUresult.CUDA_SUCCESS
CUresult.CUDA_SUCCESS <CUfunction 0x3083c470>


tensor([700605696, 700605704, 700605712, 700605720, 700605728, 815044608,
        814450368, 814450372, 814450376], dtype=torch.uint64)

In [34]:
def fn():
    err = cuda.cuLaunchKernel(
        kernel,
        1,  # grid x dim
        1,  # grid y dim
        1,  # grid z dim
        16,  # block x dim
        16,  # block y dim
        1,  # block z dim
        0,  # dynamic shared memory
        torch.cuda.current_stream().stream_id,  # stream
        args.data_ptr(),  # kernel arguments
        0,  # extra (ignore)
    )

fn()

(O_cuda_py - O_expected).abs().max()

tensor(1.1325e-06, device='cuda:0')

In [35]:
print(f"\n\n**********\nDimensions: {N_out=}, {N_inp=}, {d=}")
torch.cuda.synchronize()
print("\n- Custom Flash Attention: Cuda-python")
%timeit fn(); torch.cuda.synchronize()



**********
Dimensions: N_out=32, N_inp=32, d=128

- Custom Flash Attention: Cuda-python


65 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Performance 

Run timeit on different dimensions. 

- Recall that we build the kernel for `d=128` and design it so that it computes the full attention in a single block.

- For matrices with small `N_out` this implementation is significantly faster than `scaled_dot_product_attention`.

- But by making N larger this implementation slows down dramatically it only uses a single block.

- Note that the register splilling version is much slower even compared to `scaled_dot_product_attention`.


### Load registers spilling version

This is the version loading full arrays as local variables in threads, which leads to spilling.

In [36]:
fname_spill_from_registers = "flash_attention_spilling_from_registers"
module_cuda_spilling_from_registers = get_loaded_cuda_module(fname_spill_from_registers)
O_cuda_spilling, L_cuda_spilling = getattr(module_cuda_spilling_from_registers, fname_spill_from_registers)(Q, K, V)
check_close(O_cuda_spilling, O_expected, L_cuda_spilling, L_expected)

Max absolute difference:
O:  tensor(1.1325e-06, device='cuda:0')
L:  tensor(4.7684e-07, device='cuda:0')


In [37]:
for N_inp, N_out in TEST_DIMS:
    Q, K, V, _, _, _ = get_test_tensors(N_inp, N_out, d)
    print(f"\n\n**********\nDimensions: {N_out=}, {N_inp=}, {d=}")
    torch.cuda.synchronize()
    print("\n- Torch scaled_dot_product_attention")
    %timeit torch.nn.functional.scaled_dot_product_attention(Q, K, V); torch.cuda.synchronize()
    print("\n- Custom Flash Attention")
    %timeit getattr(module_cuda, fname)(Q, K, V); torch.cuda.synchronize()
    print("\n- Custom Flash Attention: spill from registers")
    %timeit getattr(module_cuda_spilling_from_registers, fname_spill_from_registers)(Q, K, V); torch.cuda.synchronize()



**********
Dimensions: N_out=32, N_inp=32, d=128

- Torch scaled_dot_product_attention


176 µs ± 2.54 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

- Custom Flash Attention
88.5 µs ± 901 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

- Custom Flash Attention: spill from registers
353 µs ± 1.73 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


**********
Dimensions: N_out=64, N_inp=128, d=128

- Torch scaled_dot_product_attention
177 µs ± 3.12 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

- Custom Flash Attention
337 µs ± 958 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

- Custom Flash Attention: spill from registers
2.41 ms ± 4.43 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


**********
Dimensions: N_out=512, N_inp=512, d=128

- Torch scaled_dot_product_attention
207 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

- Custom Flash Attention
9.02 ms ± 6.53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

- Custom Flash Attention: spill from registers
74.8 ms ± 21.

## Profile

We can compare the performance of the kernel with and without register spilling as follows:
```
# Get ptx files
nvcc -arch=sm_80 -ptx flash_attention.cu -o flash_attention.ptx
nvcc -arch=sm_80 -ptx flash_attention_spilling_from_registers.cu -o flash_attention_spilling_from_registers.ptx

# Get ncu metrics
nvcc -O3 -o test_attention main.cu flash_attention.cu flash_attention_spilling_from_registers.cu
ncu ./test_attention
```

### PTX comparison

TODO


### Nsight Compute Comparison

| **Metric** | **Spilling Kernel (`flash_attention_spilling_from_registers_k`)** | **Non-Spilling Kernel (`flash_attention_k`)** | **Difference / Interpretation** |
|------------|------------------------------------------------------------------|-----------------------------------------------|---------------------------------|
| **Duration** | **13.02 ms** | **2.10 ms** | Spilling kernel is ~6× slower. |
| **Elapsed Cycles** | 7.6 M | 1.2 M | Extra cycles lost to spills. |
| **Compute (SM) Throughput** | **0.26%** | **1.20%** | 5× higher compute utilization in non-spilling kernel. |
| **L2 Cache Throughput** | **1.71%** | **0.56%** | Spilling kernel hits L2 more — spilled registers go to local memory via L2. |




```
flash_attention_spilling_from_registers_k (1, 1, 1)x(32, 16, 1), Context 1, Stream 7, Device 0, CC 7.5
    Section: GPU Speed Of Light Throughput
    ----------------------- ------------- ------------
    Metric Name               Metric Unit Metric Value
    ----------------------- ------------- ------------
    DRAM Frequency          cycle/nsecond         5.00
    SM Frequency            cycle/usecond       585.15
    Elapsed Cycles                  cycle      7619363
    Memory Throughput                   %         1.86
    DRAM Throughput                     %         0.01
    Duration                      msecond        13.02
    L1/TEX Cache Throughput             %        74.55
    L2 Cache Throughput                 %         1.71
    SM Active Cycles                cycle    190430.67
    Compute (SM) Throughput             %         0.26
    ----------------------- ------------- ------------

    WRN   This kernel grid is too small to fill the available resources on this device, resulting in only 0.0 full      
          waves across all SMs. Look at Launch Statistics for more details.                                             

    Section: Launch Statistics
    -------------------------------- --------------- ---------------
    Metric Name                          Metric Unit    Metric Value
    -------------------------------- --------------- ---------------
    Block Size                                                   512
    Function Cache Configuration                     CachePreferNone
    Grid Size                                                      1
    Registers Per Thread             register/thread              60
    Shared Memory Configuration Size           Kbyte           65.54
    Driver Shared Memory Per Block        byte/block               0
    Dynamic Shared Memory Per Block       byte/block               0
    Static Shared Memory Per Block       Kbyte/block           25.60
    Threads                                   thread             512
    Waves Per SM                                                0.01
    -------------------------------- --------------- ---------------

    WRN   The grid for this launch is configured to execute only 1 blocks, which is less than the GPU's 40              
          multiprocessors. This can underutilize some multiprocessors. If you do not intend to execute this kernel      
          concurrently with other workloads, consider reducing the block size to have at least one block per            
          multiprocessor or increase the size of the grid to fully utilize the available hardware resources. See the    
          Hardware Model (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-hw-model)            
          description for more details on launch configurations.                                                        

    Section: Occupancy
    ------------------------------- ----------- ------------
    Metric Name                     Metric Unit Metric Value
    ------------------------------- ----------- ------------
    Block Limit SM                        block           16
    Block Limit Registers                 block            2
    Block Limit Shared Mem                block            2
    Block Limit Warps                     block            2
    Theoretical Active Warps per SM        warp           32
    Theoretical Occupancy                     %          100
    Achieved Occupancy                        %        50.00
    Achieved Active Warps Per SM           warp        16.00
    ------------------------------- ----------- ------------

    WRN   This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated     
          theoretical (100.0%) and measured achieved occupancy (50.0%) can be the result of warp scheduling overheads   
          or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block    
          as well as across blocks of the same kernel. See the CUDA Best Practices Guide                                
          (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on           
          optimizing occupancy.                                                                                         

  flash_attention_k (1, 1, 1)x(32, 16, 1), Context 1, Stream 7, Device 0, CC 7.5
    Section: GPU Speed Of Light Throughput
    ----------------------- ------------- ------------
    Metric Name               Metric Unit Metric Value
    ----------------------- ------------- ------------
    DRAM Frequency          cycle/nsecond         4.99
    SM Frequency            cycle/usecond       584.08
    Elapsed Cycles                  cycle      1225715
    Memory Throughput                   %         1.72
    DRAM Throughput                     %         0.04
    Duration                      msecond         2.10
    L1/TEX Cache Throughput             %        68.85
    L2 Cache Throughput                 %         0.56
    SM Active Cycles                cycle     30658.20
    Compute (SM) Throughput             %         1.20
    ----------------------- ------------- ------------

    WRN   This kernel grid is too small to fill the available resources on this device, resulting in only 0.0 full      
          waves across all SMs. Look at Launch Statistics for more details.                                             

    Section: Launch Statistics
    -------------------------------- --------------- ---------------
    Metric Name                          Metric Unit    Metric Value
    -------------------------------- --------------- ---------------
    Block Size                                                   512
    Function Cache Configuration                     CachePreferNone
    Grid Size                                                      1
    Registers Per Thread             register/thread              62
    Shared Memory Configuration Size           Kbyte           65.54
    Driver Shared Memory Per Block        byte/block               0
    Dynamic Shared Memory Per Block       byte/block               0
    Static Shared Memory Per Block       Kbyte/block           25.60
    Threads                                   thread             512
    Waves Per SM                                                0.01
    -------------------------------- --------------- ---------------

    WRN   The grid for this launch is configured to execute only 1 blocks, which is less than the GPU's 40              
          multiprocessors. This can underutilize some multiprocessors. If you do not intend to execute this kernel      
          concurrently with other workloads, consider reducing the block size to have at least one block per            
          multiprocessor or increase the size of the grid to fully utilize the available hardware resources. See the    
          Hardware Model (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-hw-model)            
          description for more details on launch configurations.                                                        

    Section: Occupancy
    ------------------------------- ----------- ------------
    Metric Name                     Metric Unit Metric Value
    ------------------------------- ----------- ------------
    Block Limit SM                        block           16
    Block Limit Registers                 block            2
    Block Limit Shared Mem                block            2
    Block Limit Warps                     block            2
    Theoretical Active Warps per SM        warp           32
    Theoretical Occupancy                     %          100
    Achieved Occupancy                        %        50.00
    Achieved Active Warps Per SM           warp        16.00
    ------------------------------- ----------- ------------

    WRN   This kernel's theoretical occupancy is not impacted by any block limit. The difference between calculated     
          theoretical (100.0%) and measured achieved occupancy (50.0%) can be the result of warp scheduling overheads   
          or workload imbalances during the kernel execution. Load imbalances can occur between warps within a block    
          as well as across blocks of the same kernel. See the CUDA Best Practices Guide                                
          (https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#occupancy) for more details on           
          optimizing occupancy.                                                                                         
```

## Thunder

[Installation guide](https://lightning.ai/docs/thunder/latest/fundamentals/installation.html)

In [38]:
import thunder

attn_ex = thunder.extend.OperatorExecutor('attn_ex', version=0.01)
thunder.add_default_executor(attn_ex)

# [attn_ex, attn_ex, sdpa, nvfuser]

def my_attn_impl(query, key, value, scale):
    n_out, d = query.shape

    # S3 = torch.zeros(N_out, N_out, device="cuda")
    O3 = torch.zeros(N_out, d, device="cuda")
    L3 = torch.zeros(N_out, device="cuda")

    B_c = 16
    B_r = 16
    T_c = (N_inp + B_c - 1) // B_c
    T_r = (N_out + B_r - 1) // B_r

    int_args = torch.tensor([N_out, T_r, T_c], dtype=torch.int32)
    float_args = torch.tensor([scale_factor], dtype=torch.float32)
    ptr_args = torch.tensor([i.data_ptr() for i in (O3, L3, key, query, value)], dtype=torch.uint64)

    args = torch.tensor([
        *(i.data_ptr() for i in ptr_args),
        *(i.data_ptr() for i in float_args),
        *(i.data_ptr() for i in int_args)], dtype=torch.uint64
    )

    err, _ = cuda.cuLaunchKernel(
        kernel,
        1,  # grid x dim
        1,  # grid y dim
        1,  # grid z dim
        32, # block x dim
        32, # block y dim
        1,  # block z dim
        0,  # dynamic shared memory
        torch.cuda.current_stream().stream_id,  # stream
        args.data_ptr(),  # kernel arguments
        0,  # extra (ignore)
    )
    assert err == cuda.CUresult.CUDA_SUCCESS, err
    return O3, L3


In [39]:
## Register our implementation as an operator
def my_attn_meta(query, key, value, scale):
    return thunder.TensorProxy(like=query), thunder.TensorProxy(like=query, shape=(query.shape[:-1],))

my_attn = attn_ex.register_operator('my_attn', meta=my_attn_meta, fn=my_attn_impl)

In [40]:
def my_attn_checker(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    if attn_mask is not None or dropout_p == 0.0 or is_causal:
        return False
    if len(query.shape) > 2:
            return (query.device.device_type == thunder.devices.DeviceType.CUDA and
                key.device == query.device and
                value.device == query.device)
    return False

def my_attn_transform(query, key, value, attn_masks=None, dropout_p=0.0, is_causal=False, scale=None):
    if scale is None:
        scale = query.size(-1) ** -0.5
    out = my_attn(query, key, value, scale)
    return out[0]

attn_ex.register_implementation(thunder.torch.scaled_dot_product_attention, checker=my_attn_checker,
                                  execution_transform=my_attn_transform)


#### Run...

In [41]:

def test_fn(query, key, value):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=False)

jfn = thunder.jit(test_fn)

print((jfn(Q, K, V) - test_fn(Q, K, V)).abs().max())
print(thunder.last_traces(jfn)[-1])

tensor(0., device='cuda:0')
# Constructed by Unwrap the actual return value
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(query, key, value):
  # query: "cuda:0 f32[1024, 128]"
  # key: "cuda:0 f32[512, 128]"
  # value: "cuda:0 f32[512, 128]"

  # /tmp/ipykernel_30370/3826419770.py:2: 	        return torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=False)
  t41 = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, False, scale=None)  # t41: "cuda:0 f32[1024, 128]"
    # t41 = ltorch.scaled_dot_product_attention(query, key, value, None, 0.0, False, scale=None)  # t41: "cuda:0 f32[1024, 128]"
      # t28 = ltorch.mul(query, 0.29730177875068026)  # t28: "cuda:0 f32[1024, 128]"
        # t28 = prims.mul(query, 0.29730177875068026)  # t28: "cuda:0 f32[1024, 128]"
      # t29 = ltorch.transpose(key, -2, -1)  # t29: "cuda:0 f32[128, 512]"
        #

## Cuda info

In [42]:
print_cuda_info()


=== PyTorch CUDA Info ===
PyTorch version: 2.7.1+cu128
CUDA available: True
CUDA version: 12.8
cuDNN version: 90701
Number of GPUs: 1
  GPU 0: Tesla T4
    Current device: 0
    Memory allocated: 10.77 MB
    Memory cached   : 27.26 MB

=== nvidia-smi Info (if available) ===
Mon Aug 18 10:30:31 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.261.03             Driver Version: 535.261.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:1E.0 Off |                    0 |
| N/A   42C    P0              34W /  70W |    917MiB / 15360M