In [1]:
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

D = 64
warmup = 3
reps = 5
seqs = [128, 256, 512, 1024]
bh_sets = [(2,4), (4,8), (8,8), (8,16), (16,16)]

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
dtype = torch.float16 if use_cuda else torch.float32

torch.manual_seed(0)
if use_cuda:
    torch.cuda.manual_seed_all(0)

def bench_sdpa(B, H, L, backend, warmup, reps):
    q = torch.randn(B, H, L, D, device=device, dtype=dtype)
    k = torch.randn(B, H, L, D, device=device, dtype=dtype)
    v = torch.randn(B, H, L, D, device=device, dtype=dtype)

    if not use_cuda:
        import time
        t0 = time.perf_counter()
        for _ in range(warmup):
            _ = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        t1 = time.perf_counter()
        for _ in range(reps):
            _ = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        t2 = time.perf_counter()
        warmup_ms = (t1 - t0) * 1e3
        avg_ms = (t2 - t1) * 1e3 / reps
        return avg_ms, None
    else:
        if backend == 'flash':
            backends = [SDPBackend.FLASH_ATTENTION]
        elif backend == 'mem_efficient':
            backends = [SDPBackend.EFFICIENT_ATTENTION]
        elif backend == 'math':
            backends = [SDPBackend.MATH]
        else:
            raise ValueError('Unknown backend')

        with sdpa_kernel(backends):
            for _ in range(warmup):
                _ = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
            torch.cuda.synchronize()

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(reps):
                _ = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
            end.record()
            torch.cuda.synchronize()
            avg_ms = start.elapsed_time(end) / reps
        return avg_ms, (q, k, v)

for (B, H) in bh_sets:
    print(f"\n== B={B} H={H} ==")
    for L in seqs:
        try:
            f_ms, tensors = bench_sdpa(B, H, L, 'flash', warmup, reps)
        except Exception as e:
            print(f"L={L:4d} | flash failed: {e}")
            f_ms, tensors = None, None

        try:
            m_ms, tensors_m = bench_sdpa(B, H, L, 'math', warmup, reps)
        except Exception as e:
            print(f"L={L:4d} | math failed: {e}")
            m_ms, tensors_m = None, None

        if use_cuda and tensors is not None and tensors_m is not None:
            q, k, v = tensors
            q2, k2, v2 = tensors_m
            with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
                out_f = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
            with sdpa_kernel([SDPBackend.MATH]):
                out_m = F.scaled_dot_product_attention(q2, k2, v2, dropout_p=0.0, is_causal=False)

        print(f"L={L:4d} | flash={f_ms!s:>8} ms | math={m_ms!s:>8} ms")


== B=2 H=4 ==
L= 128 | flash=0.027820798754692077 ms | math=0.16079360246658325 ms
L= 256 | flash=0.026848000288009644 ms | math=0.16028800010681152 ms
L= 512 | flash=0.03256320059299469 ms | math=0.2091007947921753 ms
L=1024 | flash=0.05509120225906372 ms | math=0.5545983791351319 ms

== B=4 H=8 ==
L= 128 | flash=0.02682879865169525 ms | math=0.15912959575653077 ms
L= 256 | flash=0.02622080147266388 ms | math=0.19537919759750366 ms
L= 512 | flash=0.05631999969482422 ms | math=0.5799680233001709 ms
L=1024 | flash=0.18350080251693726 ms | math=1.9929088592529296 ms

== B=8 H=8 ==
L= 128 | flash=0.0268095999956131 ms | math=0.159334397315979 ms
L= 256 | flash=0.03583999872207642 ms | math=0.3483648061752319 ms
L= 512 | flash=0.1000704050064087 ms | math=1.0936320304870606 ms
L=1024 | flash=0.3090431928634644 ms | math=3.7894142150878904 ms

== B=8 H=16 ==
L= 128 | flash=0.027238398790359497 ms | math=0.22525439262390137 ms
L= 256 | flash=0.058982402086257935 ms | math=0.6461184024810791