In [1]:
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
import tqdm 

def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
    """Reference (non-fused) multi-head self-attention over a packed QKV tensor.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        attn_mask: (batch_size, seqlen) bool tensor, True for real tokens and False
            for padding keys. None means no key is padded.
        dropout_p: float, dropout probability applied to the attention weights.
            F.dropout is called with its default training=True, so dropout is
            always active when dropout_p > 0.
        upcast: if True, compute in float32; the result is cast back to qkv.dtype.
        causal: if True, query position t only attends to key positions s <= t.
    Output:
        output: (batch_size, seqlen, nheads, head_dim), same dtype as qkv

    Note:
        A batch row whose keys are all masked out yields NaN for that row, because
        softmax is taken over a row of -inf scores.
    """
    q, k, v = (qkv.float() if upcast else qkv).unbind(dim=2)
    seqlen = qkv.shape[1]
    d = qkv.shape[-1]
    # The 1/sqrt(head_dim) softmax scaling is folded into k before the matmul.
    scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
    if attn_mask is not None:
        # (b, s) -> (b, 1, 1, s): broadcast the key-padding mask over heads and queries.
        scores.masked_fill_(~attn_mask[:, None, None, :], float('-inf'))
    if causal:
        # Strict upper triangle marks the future keys (s > t) that must be hidden.
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
        scores.masked_fill_(causal_mask, float('-inf'))
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)

# Benchmark grid; each entry is (batch_size, nheads, headdim, seqlen).
configurations = [
    (1, 40, 128, 512),
    (1, 40, 128, 1024),
    (1, 40, 128, 2048),
    (1, 40, 128, 4096),
    (1, 40, 128, 8192),
    (1, 8, 128, 1536),
    (1, 8, 128, 2048),
    (1, 8, 128, 3072),
    (1, 8, 128, 6144),
    (1, 16, 128, 1536),
    (1, 16, 128, 2048),
    (1, 16, 128, 3072),
    (1, 16, 128, 6144),
    (1, 64, 128, 2048),
    (1, 64, 128, 4096),
    (1, 64, 128, 8192),
]
# Mean timings keyed by ((causal, batch_size, nheads, headdim, seqlen), method).
time_f = {}
time_b = {}
causal_vals = [False, True]
total_iterations = len(causal_vals) * len(configurations)
torch.manual_seed(0)
repeats = 100
dropout_p = 0.0

dtype = torch.float16
device = 'cuda'

# attention_ref builds a full (batch, nheads, seqlen, seqlen) score tensor, so the
# dense PyTorch baseline is opt-in instead of being prepared for every configuration.
BENCHMARK_PYTORCH_REFERENCE = False

with tqdm.tqdm(total=total_iterations, desc="Processing Configurations") as progress_bar:
    for causal in causal_vals:  # benchmark every configuration without and with causal masking
        for config_4 in configurations:
            batch_size, nheads, d, seqlen = config_4
            config = (causal, batch_size, nheads, d, seqlen)
            n = nheads * d
            x = torch.randn(batch_size, seqlen, n, device=device, dtype=dtype, requires_grad=True)
            Wqkv = torch.nn.Linear(n, 3 * n, device=device, dtype=dtype)

            # Every sequence keeps between seqlen - 20 and seqlen - 1 leading tokens;
            # positions at or beyond its length are padding (False in the mask).
            lengths = torch.randint(seqlen - 20, seqlen, (batch_size, 1), device=device)
            attention_mask_bool = repeat(torch.arange(seqlen, device=device), 's -> b s', b=batch_size) < lengths

            x_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(x, attention_mask_bool)
            # detach + requires_grad_ makes the packed qkv a leaf, so the backward
            # benchmark does not also differentiate through the projection.
            qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
                                  h=nheads).detach().requires_grad_()

            fn = lambda qkv_unpad: flash_attn_unpadded_qkvpacked_func(
                qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
            )
            FA1 = benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
            # FA1[0] is the forward result and FA1[1] the backward result; item [1]
            # of each is the Measurement object whose mean is recorded.
            time_f[config, "Flash1"] = FA1[0][1].mean
            time_b[config, "Flash1"] = FA1[1][1].mean

            if BENCHMARK_PYTORCH_REFERENCE:
                qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()
                ref_fn = lambda qkv: attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
                pytorch = benchmark_all(ref_fn, qkv, repeats=repeats, desc='PyTorch Standard Attention')
                time_f[config, "Pytorch"] = pytorch[0][1].mean
                time_b[config, "Pytorch"] = pytorch[1][1].mean

            progress_bar.update(1)
        

FlashAttention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbd4ecbc7f0>
fn_amp(*inputs, **kwinputs)
  184.84 us
  1 measurement, 100 runs , 10 threads
FlashAttention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbd4ecbc790>
y.backward(grad, retain_graph=True)
  496.08 us
  1 measurement, 100 runs , 10 threads
FlashAttention - Forward + Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbd4ecd9160>
f(grad, *inputs, **kwinputs)
  423.75 us
  1 measurement, 100 runs , 10 threads
FlashAttention - Forward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbd4ecd9100>
fn_amp(*inputs, **kwinputs)
  336.10 us
  1 measurement, 100 runs , 10 threads
FlashAttention - Backward pass
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbd4ecd9130>
y.backward(grad, retain_graph=True)
  854.57 us
  1 measurement, 100 runs , 10 threads
FlashAttention - Forward + Backward pass
<torch.utils.b

In [4]:
import pandas as pd
def _timings_to_frame(timings, value_column):
    """Flatten a {(config, method): time} mapping into one DataFrame row per entry.

    config is the 5-tuple (causal, batch_size, nheads, headdim, seqlen); its fields
    become separate columns, followed by the method name and the timing value.
    Rows whose method is 'Triton' are dropped.
    """
    header = ['Causal', 'BatchSize', 'nHeads', 'HeadDim', 'SeqLen', 'Method', value_column]
    records = []
    for ((is_causal, bsz, heads, headdim, length), method), elapsed in timings.items():
        records.append((is_causal, bsz, heads, headdim, length, method, elapsed))
    frame = pd.DataFrame(records, columns=header)
    return frame.loc[frame['Method'] != 'Triton']


# Split the (config, method) keys into separate columns, one frame per pass.
df_time_f = _timings_to_frame(time_f, 'Time F')
df_time_b = _timings_to_frame(time_b, 'Time B')

# Save both DataFrames to a single Excel workbook, one sheet per pass.
with pd.ExcelWriter('times11.xlsx') as writer:
    df_time_f.to_excel(writer, index=False, sheet_name='Forward Times')
    df_time_b.to_excel(writer, index=False, sheet_name='Backward Times')

In [None]:
# Inspect the mean forward-pass time of the most recent benchmark_all call: FA1 is
# left over from the last loop iteration of the benchmark cell, so this cell only
# works after that cell has run. Value is presumably in seconds — TODO confirm
# against torch.utils.benchmark.Measurement.
FA1[0][1].mean

0.0036811568463842076