In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from rich import print


In [80]:
# Pick the compute device for every tensor created below.
# Preference order: CUDA GPU (Windows/Linux) -> Apple-silicon GPU via MPS (macOS) -> CPU.
#
# NOTE: `torch.backends.mps.is_built()` only reports that this PyTorch binary
# was compiled with MPS support; it can be True on a machine that has no
# usable MPS device. `is_available()` is the runtime check, so it is the one
# that must gate the choice of "mps".
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

小的示例

In [81]:
# query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
# print(F.scaled_dot_product_attention(query, key, value))

In [83]:
# Timing helper
import timeit

import torch.utils.benchmark as benchmark


def _mps_synchronized_clock():
    """Return the wall-clock time after draining pending MPS work.

    Used as the ``timer`` callable of ``benchmark.Timer`` so that the clock is
    only read once the MPS device has finished the work queued so far.
    """
    torch.mps.synchronize()
    return timeit.default_timer()


def torch_timer(f, *args, **kwargs):
    """Benchmark ``f(*args, **kwargs)`` and return the mean time per call.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        float: ``blocked_autorange().mean`` scaled by 1e6 (the callers in this
        notebook report it as microseconds).

    Raises:
        Whatever ``f`` raises; the exception propagates out of the benchmark.
    """
    timer_options = {}
    # NOTE(review): the default benchmark clock synchronizes CUDA but, as far
    # as the PyTorch sources show, not MPS. Without an explicit sync the
    # measurement on MPS may stop before the queued kernels have finished.
    if torch.backends.mps.is_available():
        timer_options["timer"] = _mps_synchronized_clock
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        **timer_options,
    )
    return t0.blocked_autorange().mean * 1e6


In [84]:
# Hyperparameters for the benchmark tensors.
# query/key/value are later created with shape
# (batch_size, num_heads, max_sequence_len, embed_dimension).
dtype = torch.float16
batch_size, max_sequence_len = 64, 256
num_heads, embed_dimension = 32, 32

In [85]:
# Simulated q / k / v: three independently drawn uniform-random tensors that
# share one shape, device and dtype. They are generated in the order
# query, key, value.
query, key, value = (
    torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    for _ in range(3)
)

In [86]:
# Baseline: time the call without constraining the attention backend.
baseline_us = torch_timer(F.scaled_dot_product_attention, query, key, value)
print(f"基本对照方案 运行时间： {baseline_us:.3f} microseconds")



性能对照测试

In [96]:
from torch.backends.cuda import sdp_kernel
from enum import IntEnum

class SDPBackend(IntEnum):
    """Integer identifiers for the scaled-dot-product-attention backends.

    In this notebook the members serve as keys of ``backend_map``, which
    translates each backend into the flag set handed to ``sdp_kernel``.
    """

    # Sentinel value; it has no entry in ``backend_map``.
    ERROR = -1
    # The three backends that can be toggled individually.
    MATH = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2

# Flag sets for the `sdp_kernel` context manager, one per backend.
# Each entry switches on exactly one of the three backends and turns the
# other two off; the dict is unpacked as keyword arguments at the call site.
backend_map = {
    SDPBackend.MATH: dict(
        enable_math=True, enable_flash=False, enable_mem_efficient=False
    ),
    SDPBackend.FLASH_ATTENTION: dict(
        enable_math=False, enable_flash=True, enable_mem_efficient=False
    ),
    SDPBackend.EFFICIENT_ATTENTION: dict(
        enable_math=False, enable_flash=False, enable_mem_efficient=True
    ),
}


In [97]:
# Time the call with only the math backend enabled.
math_only_flags = backend_map[SDPBackend.MATH]
with sdp_kernel(**math_only_flags):
    math_us = torch_timer(F.scaled_dot_product_attention, query, key, value)
    print(f"math 运行时间： {math_us:.3f} microseconds")



In [98]:
# Time the call with only the flash-attention backend enabled.
# A RuntimeError raised while benchmarking is reported as "not supported"
# instead of aborting the cell.
flash_only_flags = backend_map[SDPBackend.FLASH_ATTENTION]
with sdp_kernel(**flash_only_flags):
    try:
        flash_us = torch_timer(F.scaled_dot_product_attention, query, key, value)
        print(f"flash attention 运行时间： {flash_us:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported")



In [99]:
# Time the call with only the memory-efficient backend enabled.
# A RuntimeError raised while benchmarking is reported as "not supported"
# instead of aborting the cell.
mem_efficient_only_flags = backend_map[SDPBackend.EFFICIENT_ATTENTION]
with sdp_kernel(**mem_efficient_only_flags):
    try:
        mem_efficient_us = torch_timer(
            F.scaled_dot_product_attention, query, key, value
        )
        print(f"Memory efficient 运行时间： {mem_efficient_us:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported")

