In [1]:
import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel, SDPBackend
# from torch.nn.attention import sdp_kernel
import torch.nn.functional as F
import os
import numpy as np
import random

DEFAULT_SEED = 42

def set_seed(seed=DEFAULT_SEED):
    """Seed every RNG this notebook uses and pin cuDNN to deterministic kernels.

    Seeds Python's ``random``, NumPy's legacy global RNG, and PyTorch's CPU
    generator plus the generators of all CUDA devices with the same value.

    Args:
        seed: integer seed applied to every generator (defaults to ``DEFAULT_SEED``).
    """
    torch.manual_seed(seed)
    # Seeds every visible CUDA device, not only the current one.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # `deterministic` alone is not sufficient: with `benchmark=True` cuDNN
    # autotunes and may select a different (non-deterministic) algorithm per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
# Let's define a helpful benchmarking function:
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Return the mean run time of ``f(*args, **kwargs)`` in microseconds.

    Uses ``torch.utils.benchmark.Timer.blocked_autorange``, which chooses the
    number of repetitions itself; the returned figure is the measurement's
    ``mean`` (seconds per call) scaled to microseconds.
    """
    namespace = {"f": f, "args": args, "kwargs": kwargs}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=namespace)
    seconds_per_call = timer.blocked_autorange().mean
    return seconds_per_call * 1e6

# Let's define the hyper-parameters of our input.
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16
# The GPU index was hard-coded; keep "cuda:2" as the default but allow an
# override (e.g. SDPA_BENCH_DEVICE=cuda:0) so the notebook can run on machines
# with fewer GPUs without editing code.
device = os.environ.get("SDPA_BENCH_DEVICE", "cuda:2")

# Seed before drawing the random inputs so their values are reproducible
# across kernel restarts (the bfloat16 cell further down seeds the same way).
set_seed()
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

default_us = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
print(f"The default implementation runs in {default_us:.3f} microseconds")

# Let's explore the speed of each of the 3 implementations.
# NOTE(review): recent PyTorch releases deprecate torch.backends.cuda.sdp_kernel
# in favour of torch.nn.attention.sdpa_kernel (note the extra "a"); consider migrating.

# Each backend is enabled by exactly one sdp_kernel flag; all others are turned off.
_backend_flag = {
    SDPBackend.MATH: "enable_math",
    SDPBackend.FLASH_ATTENTION: "enable_flash",
    SDPBackend.EFFICIENT_ATTENTION: "enable_mem_efficient",
}

# Helpful arg mapper: backend -> kwargs for sdp_kernel that select only that backend.
backend_map = {
    backend: {flag: flag == enabled_flag for flag in _backend_flag.values()}
    for backend, enabled_flag in _backend_flag.items()
}


def _sdpa_microseconds():
    """Time F.scaled_dot_product_attention on the current global query/key/value."""
    return benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)


# The math backend is expected to always work, so it is deliberately not guarded.
with sdp_kernel(**backend_map[SDPBackend.MATH]):
    math_us = _sdpa_microseconds()
    print(f"The math implementation runs in {math_us:.3f} microseconds")

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        flash_us = _sdpa_microseconds()
        print(f"The flash attention implementation runs in {flash_us:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        efficient_us = _sdpa_microseconds()
        print(f"The memory efficient implementation runs in {efficient_us:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")


The default implementation runs in 939.589 microseconds




The math implementation runs in 928.117 microseconds
The flash attention implementation runs in 2005.556 microseconds
The memory efficient implementation runs in 176.585 microseconds


In [237]:
# Regenerate query/key/value in bfloat16 with a fixed seed and print the SDPA output.
dtype = torch.bfloat16

set_seed()
qkv_shape = (batch_size, num_heads, max_sequence_len, embed_dimension)
# Drawn in order query, key, value -- the same RNG consumption order as three separate calls.
query, key, value = (
    torch.rand(*qkv_shape, device=device, dtype=dtype) for _ in range(3)
)

# To pin a specific kernel, wrap the call below in e.g.
# `with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):`.
print(F.scaled_dot_product_attention(query, key, value))
    


tensor([[[[0.5000, 0.4902, 0.4863,  ..., 0.4941, 0.4980, 0.4863],
          [0.5000, 0.4883, 0.4863,  ..., 0.4941, 0.4980, 0.4844],
          [0.5000, 0.4902, 0.4863,  ..., 0.4941, 0.4980, 0.4824],
          ...,
          [0.5000, 0.4902, 0.4863,  ..., 0.4961, 0.4980, 0.4844],
          [0.5000, 0.4902, 0.4863,  ..., 0.4961, 0.4980, 0.4844],
          [0.5000, 0.4902, 0.4844,  ..., 0.4961, 0.4961, 0.4844]],

         [[0.4922, 0.4941, 0.4980,  ..., 0.5039, 0.4883, 0.4863],
          [0.4902, 0.4961, 0.4980,  ..., 0.5039, 0.4883, 0.4863],
          [0.4922, 0.4941, 0.4980,  ..., 0.5039, 0.4883, 0.4863],
          ...,
          [0.4922, 0.4961, 0.4980,  ..., 0.5039, 0.4883, 0.4863],
          [0.4922, 0.4961, 0.4980,  ..., 0.5039, 0.4883, 0.4863],
          [0.4922, 0.4941, 0.4961,  ..., 0.5039, 0.4883, 0.4883]],

         [[0.4883, 0.5195, 0.4902,  ..., 0.4941, 0.5000, 0.5117],
          [0.4902, 0.5195, 0.4902,  ..., 0.4941, 0.5000, 0.5117],
          [0.4902, 0.5195, 0.4883,  ..., 0