In [1]:
from time import time
from tqdm import tqdm
import platform
import torch
import torch.nn.functional as tnf
from torch.nn.attention import SDPBackend, sdpa_kernel

device = torch.device('cuda:0')

print(torch.backends.cuda.flash_sdp_enabled())
print(torch.backends.cuda.mem_efficient_sdp_enabled())
print(torch.backends.cuda.math_sdp_enabled())

True
True
True


In [2]:
print(platform.platform())
print(f'{torch.__version__=}')
print(f'{torch.cuda.get_device_properties(device)}'.removeprefix('_CudaDeviceProperties'))


Linux-6.5.0-35-generic-x86_64-with-glibc2.35
torch.__version__='2.3.0'
(name='NVIDIA GeForce RTX 3090', major=8, minor=6, total_memory=24259MB, multi_processor_count=82)


In [3]:
def speedtest():
    torch.cuda.reset_peak_memory_stats()
    start_time = time()
    N = 100
    for _ in tqdm(range(N), ascii=True):
        x = torch.randn(1024, 64, 16, 64, device=device, dtype=torch.float16)
        y = tnf.scaled_dot_product_attention(x, x, x)
        torch.cuda.synchronize()
    elapsed_time = time() - start_time
    time_per_iter = elapsed_time / N
    print(f'{time_per_iter=:.4f} s/iter')
    mem = torch.cuda.max_memory_allocated(device) / 1e9
    print(f'{mem=:.2f} GB')
    print(f'{y.shape=}')

In [4]:
speedtest()

100%|##########| 100/100 [00:01<00:00, 91.60it/s]

time_per_iter=0.0109 s/iter
mem=0.41 GB
y.shape=torch.Size([1024, 64, 16, 64])





In [5]:
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    speedtest()

100%|##########| 100/100 [00:01<00:00, 98.42it/s]

time_per_iter=0.0102 s/iter
mem=0.41 GB
y.shape=torch.Size([1024, 64, 16, 64])





In [6]:
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    speedtest()

100%|##########| 100/100 [00:00<00:00, 280.64it/s]

time_per_iter=0.0036 s/iter
mem=0.40 GB
y.shape=torch.Size([1024, 64, 16, 64])





In [7]:
with sdpa_kernel(SDPBackend.MATH):
    speedtest()

100%|##########| 100/100 [00:00<00:00, 227.48it/s]

time_per_iter=0.0044 s/iter
mem=0.58 GB
y.shape=torch.Size([1024, 64, 16, 64])





In [20]:
# sanity check
x = torch.randn(1024, 64, 16, 64, device=device, dtype=torch.float16)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    y1 = tnf.scaled_dot_product_attention(x, x, x)
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    y2 = tnf.scaled_dot_product_attention(x, x, x)
with sdpa_kernel(SDPBackend.MATH):
    y3 = tnf.scaled_dot_product_attention(x, x, x)
y1, y2, y3 = y1.float(), y2.float(), y3.float()
print(torch.allclose(y1.float(), y2.float()), torch.allclose(y1.float(), y3.float()))
print(torch.allclose(y2, y3))
print(torch.abs(y1 - y2).max(), torch.abs(y1 - y3).max(), torch.abs(y2 - y3).max())

False False
False
tensor(1.4369e-07, device='cuda:0') tensor(0.0001, device='cuda:0') tensor(0.0001, device='cuda:0')


In [16]:
y1.max(), y2.max(), y3.max()

(tensor(5.4766, device='cuda:0'),
 tensor(5.4766, device='cuda:0'),
 tensor(5.4766, device='cuda:0'))