In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from rich import print
import torch.utils.benchmark as benchmark
import math

In [33]:
# Select the compute device, preferring whichever accelerator is actually usable:
#   * CUDA  - NVIDIA GPUs (typically Windows / Linux),
#   * MPS   - Apple GPUs on macOS,
#   * CPU   - fallback when no accelerator can be used.
#
# `torch.backends.mps.is_built()` only reports whether this PyTorch binary was
# compiled with MPS support; `is_available()` additionally checks that the
# current machine can really run on an MPS device, which is what matters here.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [58]:
# Hyperparameters for the benchmark.

# Shape of the synthetic input batch.
batch_size, max_sequence_len = 32, 128

# Attention geometry. Note: despite its name, `heads_per_dim` is multiplied by
# `num_heads` to obtain the embedding width, i.e. it acts as the per-head size.
num_heads, heads_per_dim = 8, 64
embed_dimension = heads_per_dim * num_heads

# Side length of the causal mask used by the non-fused attention fallback.
block_size = 1024

# Precision used for both the model weights and the input tensor.
dtype = torch.float16

In [53]:
# Timing helper.
def torch_timer(f, *args, **kwargs):
    """Benchmark ``f(*args, **kwargs)`` and return its mean run time in microseconds.

    The call is timed with ``torch.utils.benchmark.Timer.blocked_autorange()``;
    the ``mean`` of the resulting measurement is multiplied by ``1e6`` before
    being returned.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        The measurement's ``mean`` scaled by ``1e6``.
    """
    # Names made visible to the timed statement.
    namespace = {"f": f, "args": args, "kwargs": kwargs}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=namespace)
    measurement = timer.blocked_autorange()
    return measurement.mean * 1e6

In [54]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention layer.

    A single linear layer produces the query/key/value projections for all
    heads, attention is computed per head with a causal (lower-triangular)
    mask, and the concatenated head outputs go through an output projection.

    When ``torch.nn.functional.scaled_dot_product_attention`` exists
    (PyTorch >= 2.0) the fused kernel is used; otherwise attention is computed
    manually using a pre-built causal mask buffer named ``bias``.

    Args:
        num_heads: Number of attention heads; must divide ``embed_dimension``.
        embed_dimension: Width of the input/output embeddings.
        bias: Whether the linear projections carry a bias term.
        dropout: Dropout probability for the attention weights and the
            projected output. Only active in training mode.
        block_size: Side length of the causal mask registered for the manual
            fallback path, i.e. the longest sequence that path can handle.
            Unused when the fused kernel is available.
    """

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, dropout:float=0.0,
                 block_size: int=1024):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        self.dropout = dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            # (sized from the `block_size` argument rather than a notebook-level global)
            self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size))
                                        .view(1, 1, block_size, block_size))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor unpacked as ``(B, T, C)`` = (batch, sequence length,
                embedding width); ``C`` is split evenly across the heads.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (embed_dimension)
        head_size = C // self.num_heads

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.embed_dimension, dim=2)
        k = k.view(B, T, self.num_heads, head_size).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.num_heads, head_size).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.num_heads, head_size).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            # The functional API has no notion of train/eval mode and applies
            # whatever dropout_p it is given, so dropout must be switched off
            # explicitly when the module is in eval mode.
            dropout_p = self.dropout if self.training else 0.0
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

In [55]:
# Build the attention module, move it to the selected `device` (the same one
# the input tensor is created on, instead of a hard-coded "mps"), cast it to
# the benchmark dtype and switch it to eval mode.
model = CausalSelfAttention(num_heads=num_heads,
                            embed_dimension=embed_dimension,
                            bias=False,
                            dropout=0.1).to(device).to(dtype).eval()
print(model)

In [56]:
# Synthetic input: uniform random values with shape
# (batch_size, max_sequence_len, embed_dimension), created on `device` in `dtype`.
input_shape = (batch_size, max_sequence_len, embed_dimension)
x = torch.rand(input_shape, device=device, dtype=dtype)

# Time the original (non-compiled) model on that input.
print(f"原始model 运行时间： {torch_timer(model, x):.3f} microseconds")



In [57]:
# TorchDynamo settings: suppress compilation errors and log verbosely.
# NOTE(review): the log captured below this cell shows dynamo reporting
# "WON'T CONVERT forward"; if conversion failed, the timing printed here may be
# of an eager fallback rather than compiled code - confirm before comparing.
dynamo_config = torch._dynamo.config
dynamo_config.suppress_errors = True
dynamo_config.verbose = True

compiled_model = torch.compile(model)
# Call once before timing so the first-call work is not part of the measurement.
compiled_model(x)
print(f"compiled model 运行时间： {torch_timer(compiled_model, x):.3f} microseconds")

[2023-03-19 15:01:54,160] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT forward /var/folders/1_/3zjkdvv134x2rw13wbvrh2vh0000gn/T/ipykernel_11159/617640278.py line 24 
 25           0 LOAD_FAST                1 (x)
              2 LOAD_METHOD              0 (size)
              4 CALL_METHOD              0
              6 UNPACK_SEQUENCE          3
              8 STORE_FAST               2 (B)
             10 STORE_FAST               3 (T)
             12 STORE_FAST               4 (C)

 28          14 LOAD_FAST                0 (self)
             16 LOAD_METHOD              1 (c_attn)
             18 LOAD_FAST                1 (x)
             20 CALL_METHOD              1
             22 LOAD_ATTR                2 (split)
             24 LOAD_FAST                0 (self)
             26 LOAD_ATTR                3 (embed_dimension)
             28 LOAD_CONST               1 (2)
             30 LOAD_CONST               2 (('dim',))
             32 CALL_FUNCTION_KW         2
     

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

def profile_attention(module, inputs, label, activities, sort_by, iterations=25, row_limit=10):
    """Profile repeated forward calls of ``module`` and print the summary table.

    Args:
        module: Callable to profile (e.g. the eager or the compiled model).
        inputs: Argument passed to ``module`` on every iteration.
        label: Name given to the ``record_function`` range wrapping the loop.
        activities: ``ProfilerActivity`` values to record.
        sort_by: Column used to sort the printed table.
        iterations: Number of forward calls inside the profiled range.
        row_limit: Maximum number of rows printed.

    Returns:
        The profiler object, so the trace can be inspected or exported afterwards.
    """
    with profile(activities=activities, record_shapes=False) as prof:
        with record_function(label):
            for _ in range(iterations):
                module(inputs)
    print(prof.key_averages().table(sort_by=sort_by, row_limit=row_limit))
    return prof


# Always record CPU activity; add CUDA only when running on a CUDA device.
# `str(device)` also covers values such as "cuda:0" or a torch.device object.
activities = [ProfilerActivity.CPU]
if str(device).startswith("cuda"):
    activities.append(ProfilerActivity.CUDA)

# Sorting by CUDA time is only meaningful when CUDA activity is recorded;
# otherwise sort by CPU time so the table ordering is informative.
sort_key = "cuda_time_total" if ProfilerActivity.CUDA in activities else "cpu_time_total"

prof = profile_attention(model, x, " Non-Compilied Causal Attention", activities, sort_key)

prof = profile_attention(compiled_model, x, "Compiled Causal Attention", activities, sort_key)

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
# prof.export_chrome_trace("compiled_causal_attention_trace.json")