In [None]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline

(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
==========================================================================================

**Author:** [Driss Guessous](https://github.com/drisspg)


Summary
=======

In this tutorial, we want to highlight a new `torch.nn.functional`
function that can be helpful for implementing transformer architectures.
The function is named
`torch.nn.functional.scaled_dot_product_attention`. For a detailed
description of the function, see the [PyTorch
documentation](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention).
This function has already been incorporated into
`torch.nn.MultiheadAttention` and `torch.nn.TransformerEncoderLayer`.


Overview
========

At a high level, this PyTorch function calculates the scaled dot product
attention (SDPA) between query, key, and value according to the
definition found in the paper [Attention is all you
need](https://arxiv.org/abs/1706.03762). While this function can be
written in PyTorch using existing functions, a fused implementation can
provide large performance benefits over a naive implementation.

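As a rough illustration of what "written in PyTorch using existing functions" means, the sketch below spells out a naive, unfused version and compares it with the built-in function. The helper name `naive_sdpa` and the tensor sizes are assumptions made for this example; it ignores masking and dropout and runs on CPU in `float32`.

```python
import math

import torch
import torch.nn.functional as F


def naive_sdpa(query, key, value):
    # softmax(Q K^T / sqrt(head_dim)) V, materializing the full attention matrix
    scale = 1.0 / math.sqrt(query.size(-1))
    attn = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn @ value


torch.manual_seed(0)
# (batch_size, num_heads, seq_length, head_dim); sizes chosen arbitrarily
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))

max_diff = (naive_sdpa(q, k, v) - F.scaled_dot_product_attention(q, k, v)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The two results should agree up to floating-point rounding. The fused kernels aim to improve speed and memory use, not to change the result.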

Fused implementations
=====================

For CUDA tensor inputs, the function will dispatch into one of the
following implementations:

-   [FlashAttention: Fast and Memory-Efficient Exact Attention with
    IO-Awareness](https://arxiv.org/abs/2205.14135)
-   [Memory-Efficient
    Attention](https://github.com/facebookresearch/xformers)
-   A PyTorch implementation defined in C++


<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>
<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">
<p>This tutorial requires PyTorch 2.0.0 or later.</p>
</div>


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)

Explicit Dispatcher Control
===========================

While the function will implicitly dispatch to one of the three
implementations, the user can also explicitly control the dispatch via
the use of a context manager. This context manager allows users to
explicitly disable certain implementations. If a user wants to ensure
the function is indeed using the fastest implementation for their
specific inputs, the context manager can be used to sweep through the
implementations and measure the performance of each.

In [49]:
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel


with sdpa_kernel(SDPBackend.MATH):
    math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

The default implementation runs in 1278.631 microseconds
The default implementation runs in 1250.900 microseconds
The math implementation runs in 7487.678 microseconds
The flash attention implementation runs in 1240.477 microseconds
The memory efficient implementation runs in 3143.190 microseconds


Hardware dependence
===================

Depending on what machine you ran the above cell on and what hardware is
available, your results might be different.

-   If you don't have a GPU and are running on CPU, then the context
    manager will have no effect and all three runs should return similar
    timings.
-   Depending on what compute capability your graphics card supports,
    flash attention or memory efficient attention might have failed.
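To see what hardware and backend settings a given machine reports, a short check along the following lines may help. This is a sketch: it only prints the compute capability and the global enable flags, and it does not by itself tell you whether a fused kernel accepts a particular input.

```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU: {torch.cuda.get_device_name()} (compute capability {major}.{minor})")
else:
    print("No CUDA device found; running on CPU.")

# Global switches for each backend (True means the backend may be selected).
backend_flags = {
    "flash": torch.backends.cuda.flash_sdp_enabled(),
    "mem_efficient": torch.backends.cuda.mem_efficient_sdp_enabled(),
    "math": torch.backends.cuda.math_sdp_enabled(),
}
print(backend_flags)
```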

Causal Self Attention
=====================

Below is an example implementation of a multi-headed causal self
attention block inspired by [Andrej Karpathy's
NanoGPT](https://github.com/karpathy/nanoGPT) repository.


In [51]:
class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # [x]: (batch_size, seq_length, embed_dimension)
        
        # Example values used in the shape comments below:
        # num_heads = 8, head_dim = 64, embed_dimension = 512
        
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        
        query_projected = self.c_attn(x)
        # [query_projected]: (batch_size, seq_length, 3 * embed_dimension) <- [x]: (batch_size, seq_length, embed_dimension)
        # [query_projected]: (32, 512, 3 * 512) <- (32, 512, 512)
        # [query_projected]= (32, 512, 1536)

        batch_size = query_projected.size(0)
        # <batch_size>= (batch_size, seq_length, 3 * embed_dimension)(0) = batch_size
        # <batch_size>= 32
        
        embed_dim = query_projected.size(2)
        # <embed_dim>= 3 * embed_dimension= 1536
        
        head_dim = embed_dim // (self.num_heads * 3)
        # <head_dim>= embed_dim // (num_heads * 3) = embed_dim // (8 * 3) = (3 * embed_dimension) // (8 * 3) 
        #           = embed_dimension // 8
        #           = embed_dimension // num_heads
        # <head_dim>= 64

        query, key, value = query_projected.chunk(3, -1)
        # [query]: (batch_size, seq_length, embed_dimension): (32, 512, 512)
        # [key]: (batch_size, seq_length, embed_dimension): (32, 512, 512)
        # [value]: (batch_size, seq_length, embed_dimension): (32, 512, 512)
        # chunk(3, -1) splits the tensor into 3 equal chunks along the last dimension
        
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        # [query]: (batch_size, num_heads, seq_length, head_dim) <- (batch_size, seq_length, num_heads, head_dim) <- (batch_size, seq_length, embed_dimension)
        # [query]: (32, 8, 512, 64) <- (32, 512, 8, 64) <- (32, 512, 512)
        # view(batch_size, -1, self.num_heads, head_dim) reshapes query from (batch_size, seq_length, embed_dimension) to (batch_size, seq_length, num_heads, head_dim); embed_dimension is split into num_heads and head_dim, i.e. embed_dimension = num_heads * head_dim
        # transpose(1, 2) swaps dimensions 1 and 2 (seq_length and num_heads)
        
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        # [key]: (batch_size, num_heads, seq_length, head_dim) <- (batch_size, seq_length, num_heads, head_dim) <- (batch_size, seq_length, embed_dimension)
        # [key]: (32, 8, 512, 64) <- (32, 512, 8, 64) <- (32, 512, 512)
        
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        # [value]: (batch_size, num_heads, seq_length, head_dim) <- (batch_size, seq_length, num_heads, head_dim) <- (batch_size, seq_length, embed_dimension)
        # [value]: (32, 8, 512, 64) <- (32, 512, 8, 64) <- (32, 512, 512)
        
        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        # [y]: (batch_size, num_heads, seq_length, head_dim)
        # [y]: (32, 8, 512, 64)
        
        # [query, key, value]: (batch_size, num_heads, seq_length, head_dim)= (32, 8, 512, 64)
        
        
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)
        # [y]: (batch_size, seq_length, embed_dimension) <- (batch_size, seq_length, num_heads, head_dim) <- (batch_size, num_heads, seq_length, head_dim)
        # [y]: (batch_size, seq_length, embed_dimension)
        
        y = self.resid_dropout(self.c_proj(y))
        # [y]: (batch_size, seq_length, embed_dimension)
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
# 512 = 8 * 64
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to(device).to(dtype).eval()
print(model)

CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)
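The `view`/`transpose` bookkeeping in `forward` can be checked in isolation. The sketch below uses small, arbitrary sizes (an assumption for illustration) and verifies that splitting the embedding dimension into heads and merging it back reproduces the input exactly.

```python
import torch

batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8
embed_dimension = num_heads * head_dim
x = torch.randn(batch_size, seq_length, embed_dimension)

# Split: (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
split = x.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

# Merge: undo the transpose, then fold the heads back into the embedding dimension.
# reshape works whether or not the transposed tensor is contiguous.
merged = split.transpose(1, 2).reshape(batch_size, -1, embed_dimension)

print(split.shape, merged.shape, torch.equal(merged, x))
```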


In [None]:
import math
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    # [query, key, value]: (batch_size, num_heads, seq_length, head_dim)= (32, 8, 512, 64)
    L, S = query.size(-2), key.size(-2)
    # <L>= seq_length= 512
    # <S>= seq_length= 512
    
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # <query.size(-1)>= head_dim= 64
    # <scale_factor>= 1 / sqrt(head_dim)= 1 / sqrt(64)= 1 / 8= 0.125
    
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    # [attn_bias]: (seq_length, seq_length)= (512, 512)

    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        # [temp_mask]: (seq_length, seq_length)= (512, 512)
        # tril(diagonal=0) keeps the lower-triangular part of the input (including the main diagonal) and zeroes the rest
        
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        # [attn_bias]: (seq_length, seq_length)= (512, 512)
        # [temp_mask.logical_not()]: (seq_length, seq_length)= (512, 512)
        # Example with L = S = 4:
        # Temp Mask:
        # tensor([[ True, False, False, False],
        #         [ True,  True, False, False],
        #         [ True,  True,  True, False],
        #         [ True,  True,  True,  True]])

        # Attention Bias:
        # tensor([[   0.,  -inf,  -inf,  -inf],
        #         [   0.,    0.,  -inf,  -inf],
        #         [   0.,    0.,    0.,  -inf],
        #         [   0.,    0.,    0.,    0.]])
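        # temp_mask is a lower-triangular boolean matrix: True marks positions that may be attended to, False marks positions that may not
        # attn_bias has the same lower-triangular pattern: 0 means no bias, and -inf drives the softmax weight at that position to zero, which masks it out

        attn_bias = attn_bias.to(query.dtype)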

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            # [attn_bias]: (seq_length, seq_length)= (512, 512)
            # Attention Mask:
            # tensor([[ True, False,  True,  True],
            #         [ True,  True, False,  True],
            #         [ True,  True,  True, False],
            #         [False,  True,  True,  True]])

            # Attention Bias after masked_fill_:
            # tensor([[   0., -inf,    0.,    0.],
            #         [   0.,    0., -inf,    0.],
            #         [   0.,    0.,    0., -inf],
            #         [-inf,    0.,    0.,    0.]])
            # attn_mask is a boolean tensor: True marks positions that may be attended to, False marks positions that may not
            # attn_mask.logical_not() flips True to False and False to True
            # attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) writes -inf into attn_bias wherever attn_mask is False, so those positions get zero weight after softmax and are masked out
                
        else:
            attn_bias = attn_mask + attn_bias
            # [attn_bias]: (seq_length, seq_length)= (512, 512)
            # Attention Mask:
            # tensor([[   0., -inf,    0.,    0.],
            #         [   0.,    0., -inf,    0.],
            #         [   0.,    0.,    0., -inf],
            #         [-inf,    0.,    0.,    0.]])

            # Attention Bias after adding attn_mask:
            # tensor([[   0., -inf,    0.,    0.],
            #         [   0.,    0., -inf,    0.],
            #         [   0.,    0.,    0., -inf],
            #         [-inf,    0.,    0.,    0.]])
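            # attn_mask is a floating-point tensor of bias values: 0 means no bias, -inf means the position should be masked out.
            # attn_bias starts as an all-zero tensor, of shape (4, 4) in this example.
            # Adding attn_mask to attn_bias element-wise therefore leaves attn_bias holding the values of attn_mask.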
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    # [attn_weight]: (batch_size, num_heads, seq_length_q, seq_length_k)= (32, 8, 512, 512)
    
    # [query]: (batch_size, num_heads, seq_length_q, head_dim)= (32, 8, 512, 64)
    # [key]: (batch_size, num_heads, seq_length_k, head_dim)= (32, 8, 512, 64)
    # [key.transpose(-2, -1)]: (batch_size, num_heads, head_dim, seq_length_k)= (32, 8, 64, 512)
    
    # [query @ key.transpose(-2, -1)]= [(batch_size, num_heads, seq_length_q, head_dim) @ (batch_size, num_heads, head_dim, seq_length_k)]= (batch_size, num_heads, seq_length_q, seq_length_k)= [(32, 8, 512, 64) @ (32, 8, 64, 512)]= (32, 8, 512, 512)
    
    attn_weight += attn_bias
    # [attn_weight]: (batch_size, num_heads, seq_length_q, seq_length_k)= (32, 8, 512, 512)
    # [attn_bias]: (seq_length_q, seq_length_k)= (512, 512)
    # attn_weight += attn_bias adds attn_bias to attn_weight (broadcast over batch and heads), so positions holding -inf get zero weight after softmax and are masked out
    
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # [attn_weight]: (batch_size, num_heads, seq_length_q, seq_length_k)= (32, 8, 512, 512)
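    # softmax(dim=-1) normalizes over the last dimension, i.e. over seq_length_k
    # Before softmax, attn_weight holds unnormalized attention scores.
    # Softmax maps each score into [0, 1] so that the scores in each row sum to 1.
    # After softmax, attn_weight is a probability distribution: for each query position, the weights over all key positions sum to 1.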
    
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    # [attn_weight]: (batch_size, num_heads, seq_length, seq_length)= (32, 8, 512, 512)
    # torch.dropout(attn_weight, dropout_p, train=True) applies dropout to attn_weight to reduce overfitting
    # dropout_p=0.0 disables dropout
    
    return attn_weight @ value
    # [value]: (batch_size, num_heads, seq_length_v, head_dim)= (32, 8, 512, 64)
    # [attn_weight @ value]: [(batch_size, num_heads, seq_length_q, seq_length_k) @ (batch_size, num_heads, seq_length_v, head_dim)]= (batch_size, num_heads, seq_length, head_dim)= (32, 8, 512, 64)
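The causal branch of the reference implementation can be sanity-checked against the built-in function. The sketch below, with small arbitrary sizes assumed for illustration, passes an explicit lower-triangular boolean mask and compares the result with `is_causal=True`. It also shows why a `-inf` bias masks a position: the softmax weight there becomes exactly zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L, S, head_dim = 4, 4, 8
# (batch_size, num_heads, seq_length, head_dim)
q, k, v = (torch.randn(1, 2, n, head_dim) for n in (L, S, S))

# Boolean mask: True marks positions a query may attend to.
causal_mask = torch.ones(L, S, dtype=torch.bool).tril()

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
same = torch.allclose(out_flag, out_mask, atol=1e-6)

# A -inf score receives exactly zero weight after softmax.
weights = torch.softmax(torch.tensor([0.0, float("-inf")]), dim=-1)
print(same, weights)
```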


`NestedTensor` and Dense tensor support
=======================================

SDPA supports both `NestedTensor` and Dense tensor inputs.
`NestedTensors` handle the case where the input is a batch of variable
length sequences without needing to pad each sequence to the maximum
length in the batch. For more information about `NestedTensors` see
[torch.nested](https://pytorch.org/docs/stable/nested.html) and
[NestedTensors
Tutorial](https://pytorch.org/tutorials/prototype/nestedtensor.html).

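To make the padding point concrete, the following sketch builds a small `NestedTensor` from three sequences of different lengths and converts it to a padded dense tensor for comparison. The sequence lengths and embedding size are arbitrary values chosen for this example.

```python
import torch

embed_dimension = 8
seq_lens = [3, 5, 2]

# One (seq_len, embed_dimension) tensor per sequence, with no padding
nt = torch.nested.nested_tensor([torch.randn(n, embed_dimension) for n in seq_lens])

# Dense equivalent: every sequence padded with zeros to the longest length
padded = torch.nested.to_padded_tensor(nt, 0.0)

stored = sum(seq_lens) * embed_dimension
print(nt.is_nested, padded.shape)
print(f"elements in the sequences: {stored}, elements after padding: {padded.numel()}")
```

Here the padded tensor holds 120 elements while the sequences themselves contain only 80, which is the overhead the nested representation avoids.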

In [54]:
import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
        # torch.randn(batch_size, max_sequence_len, embed_dimension, dtype=dtype, device=device) creates a dense tensor of shape (batch_size, max_sequence_len, embed_dimension)
        # Returns: (tensor, None)
        
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # random.gauss(pad_percentage, 0.01) samples from a normal distribution with mean pad_percentage and standard deviation 0.01, so most samples land close to pad_percentage
    # 1 - random.gauss(pad_percentage, 0.01) therefore lands close to 1 - pad_percentage
    # int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01))) gives a random sequence length close to max_sequence_len * (1 - pad_percentage)
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    # random.randint(0, batch_size - 1) returns a random integer in [0, batch_size - 1], both ends inclusive
    # The assignment above sets one randomly chosen sequence length to max_sequence_len
    
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )
    # The list comprehension builds one tensor of shape (seq_len, embed_dimension) per entry in seq_len_list
    # torch.nested.nested_tensor packs that list of tensors into a single nested tensor
    # Returns: (nested tensor, list of sequence lengths)

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

Random NT runs in 494.214 microseconds
Random Dense runs in 374.631 microseconds


Using SDPA with `torch.compile`
===============================

With the release of PyTorch 2.0, a new feature called `torch.compile()`
has been introduced, which can provide significant performance
improvements over eager mode. Scaled dot product attention is fully
composable with `torch.compile()`. To demonstrate this, let's compile
the `CausalSelfAttention` module using `torch.compile()` and observe the
resulting performance improvements.


In [59]:
batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in  {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
    f"The compiled module runs in  {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")

The non compiled module runs in  188.795 microseconds
The compiled module runs in  286.094 microseconds


The exact execution time depends on the machine. On my machine the
results were:

-   The non compiled module runs in 166.616 microseconds
-   The compiled module runs in 166.726 microseconds

That is not what we were expecting. Let's dig a little deeper. PyTorch
comes with an amazing built-in profiler that you can use to inspect the
performance characteristics of your code.


In [56]:
from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# Print the profiling results sorted by total CUDA time, limited to the top 10 rows.

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Non-Compilied Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us       8.605ms        59.57%       8.605ms       8.605ms             1  
                         Non-Compilied Causal Attention        18.39%       2.551ms        99.94%      13.859ms      13.859ms       0.000us         0.00%       5.839ms       5.839ms             1  
         



The previous code snippet generates a report of the top 10 PyTorch
functions that consumed the most GPU execution time, for both the
compiled and non-compiled module. The analysis reveals that the majority
of time spent on the GPU is concentrated on the same set of functions
for both modules. The reason for this here is that `torch.compile` is
very good at removing the framework overhead associated with PyTorch. If
your model is launching large, efficient CUDA kernels, which in this
case `CausalSelfAttention` is, then the overhead of PyTorch can be
hidden.


In reality, your module does not normally consist of a single
`CausalSelfAttention` block. When experimenting with [Andrej Karpathy's
NanoGPT](https://github.com/karpathy/nanoGPT) repository, compiling the
module took the time per train step from `6090.49ms` to `3273.17ms`!
This was done on commit `ae3a8d5` of NanoGPT, training on the
Shakespeare dataset.

Using SDPA with `attn_bias` subclasses
=======================================

As of PyTorch 2.3, we have added a new submodule that contains tensor
subclasses designed to be used with
`torch.nn.functional.scaled_dot_product_attention`. The module is named
`torch.nn.attention.bias` and contains the following two utilities for
generating causal attention variants:

-   `torch.nn.attention.bias.causal_upper_left`
-   `torch.nn.attention.bias.causal_lower_right`

<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>
<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">
<p>The current argument <code>is_causal</code> in <code>torch.nn.functional.scaled_dot_product_attention</code> is the same as using <code>torch.nn.attention.bias.causal_upper_left</code>.</p>
</div>

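The two alignments can be pictured as plain boolean masks. The sketch below is an illustration only, not how `torch.nn.attention.bias` is implemented: `causal_mask` is a hypothetical helper that builds the masks with `tril`, where the `diagonal` offset shifts the triangle so that it touches either the upper left or the lower right corner.

```python
import torch


def causal_mask(seq_len_q, seq_len_kv, align="upper_left"):
    # Hypothetical helper for illustration only.
    # upper_left: query token i may attend to key tokens 0..i
    # lower_right: the last query token lines up with the last key token
    offset = 0 if align == "upper_left" else seq_len_kv - seq_len_q
    return torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool).tril(diagonal=offset)


upper_left = causal_mask(2, 10, "upper_left")
lower_right = causal_mask(2, 10, "lower_right")
print(upper_left.int())
print(lower_right.int())
```

For a square attention matrix the offset is zero in both cases, so the two masks coincide.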


In [60]:
from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)

upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
# [[ True, False, False, False, False, False, False, False, False, False],
#  [ True,  True, False, False, False, False, False, False, False, False]]
# Each row is the mask for one query token:
#   upper_left_bias[0]: query token 0 can attend only to key token 0
#   upper_left_bias[1]: query token 1 can attend to the first 2 key tokens


lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)
# [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
#  [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]
# Each row is the mask for one query token:
#   lower_right_bias[0]: query token 0 can attend to the first 9 key tokens
#   lower_right_bias[1]: query token 1 can attend to all 10 key tokens

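# causal_upper_left aligns the causal mask with the upper left corner of the attention score matrix,
#   so query token 0 lines up with key token 0 (the queries are the first tokens of the sequence).
# causal_lower_right aligns the causal mask with the lower right corner of the attention score matrix,
#   so the last query token lines up with the last key token (the queries are the last tokens of the sequence).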


print(type(upper_left_bias))
print(type(lower_right_bias))

assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)

# As you can see from the previous output, they are the same type ``torch.nn.attention.bias.CausalBias``
# and subclass ``torch.Tensor``

# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)

# Upper Left Bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Another way of thinking about this concept is that when you use upper left bias,
# the 0th token in the query is aligned to the 0th token in the key.
# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# For lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k
# (for example, the mask entry for ``attn_score[-1][-1]`` is True, since the last token in q is at the same position
# as the last token in k), even if the sequence lengths of q and k are different.

# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)

# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)

<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])


Conclusion
==========

In this tutorial, we have demonstrated the basic usage of
`torch.nn.functional.scaled_dot_product_attention`. We have shown how
the `sdpa_kernel` context manager can be used to assert that a certain
implementation is used on GPU. We also built a simple
`CausalSelfAttention` module that works with `NestedTensor` and is torch
compilable. In the process, we have shown how the profiling tools can
be used to explore the performance characteristics of a user-defined
module.