In [1]:
from time import time
from tqdm import tqdm
import platform
import torch
import torch.nn.functional as tnf
from torch.nn.attention import SDPBackend, sdpa_kernel

device = torch.device('cuda:0')

print(torch.backends.cuda.flash_sdp_enabled())
print(torch.backends.cuda.mem_efficient_sdp_enabled())
print(torch.backends.cuda.math_sdp_enabled())

True
True
True


In [2]:
print(platform.platform())
print(f'{torch.__version__=}')
print(f'{torch.cuda.get_device_properties(device)}'.removeprefix('_CudaDeviceProperties'))


Linux-6.5.0-35-generic-x86_64-with-glibc2.35
torch.__version__='2.3.0'
(name='NVIDIA GeForce RTX 3090', major=8, minor=6, total_memory=24259MB, multi_processor_count=82)


In [3]:
def speedtest():
    torch.cuda.reset_peak_memory_stats()
    start_time = time()
    N = 100
    for _ in tqdm(range(N), ascii=True):
        x = torch.randn(1024, 1024, 16, 64, device=device, dtype=torch.float16)
        y = tnf.scaled_dot_product_attention(x, x, x)
        torch.cuda.synchronize()
    elapsed_time = time() - start_time
    time_per_iter = elapsed_time / N
    print(f'{time_per_iter=:.4f} s/iter')
    mem = torch.cuda.max_memory_allocated(device) / 1e9
    print(f'{mem=:.2f} GB')

In [4]:
speedtest()

100%|##########| 100/100 [00:07<00:00, 14.17it/s]

time_per_iter=0.0706 s/iter
mem=6.51 GB





In [5]:
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    speedtest()

100%|##########| 100/100 [00:07<00:00, 14.21it/s]

time_per_iter=0.0704 s/iter
mem=6.51 GB





In [6]:
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    speedtest()

100%|##########| 100/100 [00:02<00:00, 39.59it/s]

time_per_iter=0.0253 s/iter
mem=6.44 GB





In [7]:
with sdpa_kernel(SDPBackend.MATH):
    speedtest()

100%|##########| 100/100 [00:02<00:00, 35.09it/s]

time_per_iter=0.0285 s/iter
mem=9.14 GB



