In [9]:
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# --- Benchmark configuration ------------------------------------------------
D = 128                 # size of the last dimension of the q/k/v tensors
warmup, reps = 3, 5     # warm-up calls first, then the calls that are averaged
seqs = [1024]           # sequence lengths to sweep
bh_sets = [(128, 16)]   # (batch, heads) pairs to sweep

# Choose the device once: half precision on the GPU, single precision on CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device, dtype = torch.device('cuda'), torch.float16
else:
    device, dtype = torch.device('cpu'), torch.float32

# Fixed seeds so the random inputs are reproducible from run to run.
torch.manual_seed(0)
if use_cuda:
    torch.cuda.manual_seed_all(0)

def bench_sdpa(B, H, L, backend, warmup, reps):
    """Time ``F.scaled_dot_product_attention`` on random ``(B, H, L, D)`` inputs.

    Args:
        B: Batch size (first dimension of q/k/v).
        H: Number of heads (second dimension of q/k/v).
        L: Sequence length (third dimension of q/k/v).
        backend: ``'flash'``, ``'mem_efficient'`` or ``'math'``. Only
            consulted when CUDA is in use; the CPU path runs the call
            without selecting a kernel and ignores this argument.
        warmup: Number of calls made before timing starts.
        reps: Number of timed calls that are averaged. Must be >= 1.

    Returns:
        ``(avg_ms, tensors)`` where ``avg_ms`` is the mean time per call in
        milliseconds and ``tensors`` is the ``(q, k, v)`` tuple on CUDA or
        ``None`` on CPU. Callers that keep ``tensors`` keep the full-size
        inputs alive on the device.

    Raises:
        ValueError: If ``reps`` is below 1, or (on CUDA) ``backend`` is not
            one of the supported names.
    """
    if reps < 1:
        # Fail up front instead of dividing by zero after the work is done.
        raise ValueError(f"reps must be >= 1, got {reps}")

    kernels = None
    if use_cuda:
        # Resolve the kernel before allocating anything, so a bad name does
        # not cost three large device allocations.
        kernel_by_name = {
            'flash': SDPBackend.FLASH_ATTENTION,
            'mem_efficient': SDPBackend.EFFICIENT_ATTENTION,
            'math': SDPBackend.MATH,
        }
        kernel = kernel_by_name.get(backend) if isinstance(backend, str) else None
        if kernel is None:
            raise ValueError('Unknown backend')
        kernels = [kernel]

    q = torch.randn(B, H, L, D, device=device, dtype=dtype)
    k = torch.randn(B, H, L, D, device=device, dtype=dtype)
    v = torch.randn(B, H, L, D, device=device, dtype=dtype)

    # NOTE: results are deliberately not bound to a name in the loops below.
    # Binding them would keep the previous output alive while the next one is
    # being computed, doubling the output footprint at peak.
    if not use_cuda:
        import time

        for _ in range(warmup):
            F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        t0 = time.perf_counter()
        for _ in range(reps):
            F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        elapsed_s = time.perf_counter() - t0
        return elapsed_s * 1e3 / reps, None

    with sdpa_kernel(kernels):
        for _ in range(warmup):
            F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        # Make sure warm-up work has finished before the first timed call.
        torch.cuda.synchronize()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        total_ms = 0.0
        for _ in range(reps):
            start.record()
            F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
            end.record()
            # Both events must have completed before elapsed_time is read.
            torch.cuda.synchronize()
            total_ms += start.elapsed_time(end)
    return total_ms / reps, (q, k, v)

def _fmt_ms(ms):
    """Render a timing with three decimals, or ``'None'`` when the run failed."""
    return "None" if ms is None else f"{ms:.3f}"


# Sweep every (batch, heads) x sequence-length combination, timing the flash
# and math backends and, on CUDA, checking that they agree numerically.
for (B, H) in bh_sets:
    print(f"\n== B={B} H={H} ==")
    for L in seqs:
        timings = {}
        check_inputs = None
        for backend in ('flash', 'math'):
            try:
                ms, tensors = bench_sdpa(B, H, L, backend, warmup, reps)
            except Exception as e:
                print(f"L={L:4d} | {backend} failed: {e}")
                ms, tensors = None, None
            timings[backend] = ms
            if tensors is not None and check_inputs is None:
                # Keep an independent copy of one batch element for the
                # agreement check; a plain slice would be a view that pins
                # the whole full-size allocation.
                check_inputs = tuple(t[:1].clone() for t in tensors)
            # Drop the full-size inputs before the next backend runs so they
            # do not add to its peak memory.
            del tensors

        f_ms, m_ms = timings['flash'], timings['math']

        max_diff = None
        if use_cuda and check_inputs is not None and f_ms is not None and m_ms is not None:
            # Feed the SAME inputs to both kernels; comparing outputs of
            # different random inputs would say nothing about agreement.
            q, k, v = check_inputs
            with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
                out_f = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
            with sdpa_kernel([SDPBackend.MATH]):
                out_m = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
            if out_f.numel():
                max_diff = (out_f.float() - out_m.float()).abs().max().item()
            else:
                max_diff = 0.0

        line = f"L={L:4d} | flash={_fmt_ms(f_ms):>8} ms | math={_fmt_ms(m_ms):>8} ms"
        if max_diff is not None:
            line += f" | max|flash-math|={max_diff:.3e}"
        print(line)


== B=128 H=16 ==
L=1024 | math failed: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 23.56 GiB of which 221.38 MiB is free. Including non-PyTorch memory, this process has 23.32 GiB memory in use. Of the allocated memory 23.01 GiB is allocated by PyTorch, and 11.88 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
L=1024 | flash=16.651417541503907 ms | math=    None ms
