In [9]:
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Attention problem shape: head dimension and number of heads.
D, H = 128, 16

# Sequence lengths 512 .. 16384 (powers of two) paired element-wise with
# batch sizes 32 .. 1, so every (batch, length) pair satisfies
# batch * length == 16384.
seqs = [2 ** exp for exp in range(9, 15)]
b = [2 ** exp for exp in range(5, -1, -1)]

# Number of untimed and timed calls made for each configuration.
warmup, reps = 3, 5

# Run on the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
dtype = torch.float16

# Fixed seeds so the random inputs are reproducible between runs.
torch.manual_seed(0)
if use_cuda:
    torch.cuda.manual_seed_all(0)

def bench_sdpa(B, H, L, backend, warmup, reps):
    """Time ``F.scaled_dot_product_attention`` on random inputs.

    Three random tensors of shape ``(B, H, L, D)`` are created on the
    module-level ``device`` with the module-level ``dtype``; ``D`` is the
    module-level head dimension. Attention is run without dropout and
    without a causal mask.

    Args:
        B: Batch size.
        H: Number of attention heads.
        L: Sequence length.
        backend: Name of the kernel to force when running on CUDA. Only
            ``'flash'`` is recognised. The CPU path does not use it.
        warmup: Number of untimed calls made before measuring.
        reps: Number of timed calls to average over.

    Returns:
        A tuple ``(avg_ms, tensors)``. ``avg_ms`` is the mean time per call
        in milliseconds. ``tensors`` is ``(q, k, v)`` on CUDA and ``None``
        on CPU.

    Raises:
        ValueError: On CUDA, when ``backend`` is not ``'flash'``.
    """
    # Validate before allocating so a bad backend name does not first
    # materialise three large tensors on the GPU.
    if use_cuda:
        if backend == 'flash':
            backends = [SDPBackend.FLASH_ATTENTION]
        else:
            raise ValueError('Unknown backend')

    q = torch.randn(B, H, L, D, device=device, dtype=dtype)
    k = torch.randn(B, H, L, D, device=device, dtype=dtype)
    v = torch.randn(B, H, L, D, device=device, dtype=dtype)

    if not use_cuda:
        # CPU calls are timed with the wall clock around the whole loop.
        import time
        for _ in range(warmup):
            _ = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        t_start = time.perf_counter()
        for _ in range(reps):
            _ = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        t_end = time.perf_counter()
        avg_ms = (t_end - t_start) * 1e3 / reps
        return avg_ms, None

    with sdpa_kernel(backends):
        for _ in range(warmup):
            _ = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        # Make sure the warmup work has finished before timing starts.
        torch.cuda.synchronize()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        sum_time = 0
        for _ in range(reps):
            start.record()
            _ = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
            end.record()
            # elapsed_time() needs both events to have completed.
            torch.cuda.synchronize()
            sum_time += start.elapsed_time(end)
    return sum_time / reps, (q, k, v)

# Sweep the paired (batch size, sequence length) configurations and report
# the average latency and achieved throughput for each one.
for (B, L) in zip(b, seqs):
    try:
        f_ms, tensors = bench_sdpa(B, H, L, 'flash', warmup, reps)
    except Exception as e:
        print(f"L={L:4d} | flash failed: {e}")
        f_ms, tensors = None, None
        # No timing is available for this configuration; computing the
        # throughput from None would raise TypeError and abort the sweep.
        continue

    # Counts the two L x L x D matrix products of attention (scores and
    # weighted values) at 2 FLOPs per multiply-add, per head and batch item.
    flops = 4.0 * L * L * D * H * B
    # f_ms is in milliseconds; convert to seconds, then scale to tera-FLOP/s.
    tflops = (flops / (f_ms / 1000.0)) / 1e12
    print(f"B={B} H= {H} L={L:4d} | flash={f_ms:.4} ms | TFLOPS = {tflops}")

B=32 H= 16 L= 512 | flash=1.677 ms | TFLOPS = 40.97500513046423
B=16 H= 16 L=1024 | flash=2.712 ms | TFLOPS = 50.67760280600406
B=8 H= 16 L=2048 | flash=4.529 ms | TFLOPS = 60.68763895319586
B=4 H= 16 L=4096 | flash=8.527 ms | TFLOPS = 64.4738494384939
B=2 H= 16 L=8192 | flash=16.84 ms | TFLOPS = 65.30799629394922
B=1 H= 16 L=16384 | flash=31.06 ms | TFLOPS = 70.79427202352494
