In [1]:
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from einops import rearrange, repeat
import pandas as pd
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
from tqdm import tqdm

from flash_attn import flash_attn_qkvpacked_func

try:
    from triton.ops.flash_attention import attention as attention_triton
except ImportError:
    attention_triton = None

try:
    import xformers.ops as xops
except ImportError:
    xops = None


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """Estimate the floating-point operation count of one attention pass.

    The forward count is ``4 * batch * seqlen**2 * nheads * headdim``,
    floor-divided by 2 when ``causal`` is true.

    Arguments:
        batch, seqlen, headdim, nheads: problem dimensions.
        causal: whether the causal (halved) count is used.
        mode: "fwd", "bwd" or "fwd_bwd".
    Output:
        The forward count (an int when the inputs are ints) for "fwd",
        2.5x that count for "bwd", and 3.5x that count for "fwd_bwd"
        (both floats).
    """
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    divisor = 2 if causal else 1
    forward_count = 4 * batch * seqlen**2 * nheads * headdim // divisor
    if mode == "fwd":
        return forward_count
    # Backward is counted as 2.5x forward; forward + backward is therefore 3.5x.
    multiplier = 2.5 if mode == "bwd" else 3.5
    return multiplier * forward_count

def efficiency(flop, time):
    """Convert an operation count and an elapsed time into a throughput figure.

    Returns ``flop / time / 10**12``, or 0.0 when ``time`` is NaN (the value
    the benchmark loop stores for measurements that could not be taken).
    """
    if math.isnan(time):
        return 0.0
    return flop / time / 10**12


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Reference scaled-dot-product attention built from plain PyTorch ops.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float, dropout probability applied to the attention weights
        causal: bool, when true an additive upper-triangular mask is applied
            to the scores before the softmax
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    bsz, seq_len, _, num_heads, head_dim = qkv.shape
    query, key, value = qkv.unbind(dim=2)
    # Fold batch and heads together so a single batched matmul covers all heads;
    # the key is laid out pre-transposed as (d, s).
    query = rearrange(query, 'b t h d -> (b h) t d')
    key_t = rearrange(key, 'b s h d -> (b h) d s')
    scale = 1.0 / math.sqrt(head_dim)
    # `baddbmm` needs an input tensor; with beta=0 its (uninitialised) contents are ignored.
    scratch = torch.empty(bsz * num_heads, seq_len, seq_len, dtype=qkv.dtype, device=qkv.device)
    logits = torch.baddbmm(scratch, query, key_t, beta=0, alpha=scale)
    logits = rearrange(logits, '(b h) t s -> b h t s', h=num_heads)
    if causal:
        # The mask is built in the default float dtype and cast afterwards, because
        # "triu_tril_cuda_template" is not implemented for 'BFloat16'.
        mask = torch.full((seq_len, seq_len), -10000.0, device=logits.device)
        mask = torch.triu(mask, 1)
        # TD [2022-09-30]: adding the mask was measured faster than masked_fill_.
        logits = logits + mask.to(dtype=logits.dtype)
    weights = F.dropout(torch.softmax(logits, dim=-1), dropout_p)
    context = torch.einsum('bhts,bshd->bthd', weights, value)
    return context.to(dtype=qkv.dtype)


def time_fwd_bwd(func, *args, **kwargs):
    """Benchmark `func` and return its (forward, backward) mean times.

    All positional and keyword arguments are forwarded unchanged to
    `benchmark_fwd_bwd`. Each of its two results is indexed at [1] and the
    `.mean` attribute of that element is returned
    (presumably a torch benchmark measurement — TODO confirm).
    """
    forward_result, backward_result = benchmark_fwd_bwd(func, *args, **kwargs)
    forward_mean = forward_result[1].mean
    backward_mean = backward_result[1].mean
    return forward_mean, backward_mean

# Settings shared by every measurement in the sweep.
repeats = 100
device = 'cuda'
dtype = torch.float16
dropout_p = 0.0
causal_vals = [False, True]

# Labels of the implementations that can be benchmarked in this environment.
# Triton and xformers are optional imports and are listed only when available.
methods = ["Flash2", "Pytorch"]
if attention_triton is not None:
    methods.append("Triton")
if xops is not None:
    methods.extend(["xformers.c", "xformers.f"])

# Sequence lengths to benchmark for each head count. Every shape uses
# batch_size=1 and headdim=128.
_seqlens_by_nheads = {
    40: [512, 1024, 2048, 4096, 8192],
    8: [1536, 2048, 3072, 6144],
    16: [1536, 2048, 3072, 6144],
    64: [2048, 4096, 8192],
}

# Each entry is (batch_size, nheads, headdim, seqlen).
configurations = [
    (1, head_count, 128, length)
    for head_count, lengths in _seqlens_by_nheads.items()
    for length in lengths
]

# Benchmark sweep: for every (causal, shape) combination, time the forward and
# backward pass of each available attention implementation.
#
# Results are stored in dicts keyed by (config, method), where
# config = (causal, batch_size, nheads, headdim, seqlen):
#   time_f / time_b / time_f_b    -> mean forward / backward / forward+backward time
#   speed_f / speed_b / speed_f_b -> throughput from `efficiency(flops(...), time)`


def _time_or_nan(label, func, *args, **kwargs):
    """Time `func` with the shared `repeats` setting, tolerating failures.

    Arguments:
        label: text identifying the measurement in the skip message.
        func, *args, **kwargs: forwarded to `time_fwd_bwd`.
    Output:
        (forward_time, backward_time), or (nan, nan) if the measurement raised
        (e.g. CUDA out-of-memory or an unsupported kernel configuration).
        The reason is written next to the progress bar instead of being
        silently discarded.
    """
    try:
        return time_fwd_bwd(func, *args, repeats=repeats, verbose=False, **kwargs)
    except Exception as exc:  # not a bare except: Ctrl-C must still interrupt the sweep
        tqdm.write(f"[skipped] {label}: {type(exc).__name__}: {exc}")
        return float('nan'), float('nan')


# Flatten the sweep so the progress bar can drive the loop and close itself
# when the last combination has been processed.
sweep = [(causal, shape) for causal in causal_vals for shape in configurations]
total_iterations = len(sweep)
progress_bar = tqdm(sweep, total=total_iterations, desc="Processing Configurations")

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}

# xformers backends to benchmark: method label -> (forward op, backward op).
xformers_ops = {}
if xops is not None:
    xformers_ops = {
        "xformers.c": (xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp),
        "xformers.f": (xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
    }

for causal, config_4 in progress_bar:
    batch_size, nheads, headdim, seqlen = config_4
    config = (causal, batch_size, nheads, headdim, seqlen)

    # FlashAttention-2 is the reference method: a failure here is not tolerated.
    qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                      requires_grad=True)
    f, b = time_fwd_bwd(
        flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
    )
    time_f[config, "Flash2"] = f
    time_b[config, "Flash2"] = b

    # Plain PyTorch materialises the full score matrix and may run out of memory.
    qkv = qkv.detach().requires_grad_(True)
    f, b = _time_or_nan(f"Pytorch {config}", attention_pytorch, qkv, dropout_p, causal=causal)
    time_f[config, "Pytorch"] = f
    time_b[config, "Pytorch"] = b

    if attention_triton is not None:
        q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
                               requires_grad=True) for _ in range(3)]
        # Try both values of sequence_parallel and keep the faster backward time.
        f, b = _time_or_nan(
            f"Triton {config}", attention_triton, q, k, v, causal, headdim**(-0.5), False
        )
        _, b_seq_parallel = _time_or_nan(
            f"Triton sequence_parallel {config}",
            attention_triton, q, k, v, causal, headdim**(-0.5), True
        )
        backward_candidates = [t for t in (b, b_seq_parallel) if not math.isnan(t)]
        time_f[config, "Triton"] = f
        time_b[config, "Triton"] = min(backward_candidates) if backward_candidates else float('nan')

    for method, fwd_bwd_ops in xformers_ops.items():
        q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
                               requires_grad=True) for _ in range(3)]
        # Same repeats/verbose settings as the other methods so timings are comparable.
        f, b = time_fwd_bwd(
            xops.memory_efficient_attention, q, k, v,
            attn_bias=xops.LowerTriangularMask() if causal else None,
            op=fwd_bwd_ops,
            repeats=repeats, verbose=False
        )
        time_f[config, method] = f
        time_b[config, method] = b

    # Derive combined time and throughput for every method measured above.
    # `efficiency` maps NaN timings (skipped measurements) to 0.0.
    for method in methods:
        time_f_b[config, method] = time_f[config, method] + time_b[config, method]
        speed_f[config, method] = efficiency(
            flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
            time_f[config, method]
        )
        speed_b[config, method] = efficiency(
            flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
            time_b[config, method]
        )
        speed_f_b[config, method] = efficiency(
            flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
            time_f_b[config, method]
        )

Processing Configurations: 100%|██████████| 32/32 [01:25<00:00,  6.34s/it]

In [2]:
# Close the tqdm progress bar created by the benchmark cell.
progress_bar.close()

Processing Configurations: 100%|██████████| 32/32 [01:25<00:00,  2.68s/it]


In [3]:
# Display the forward-pass timings, keyed by
# ((causal, batch_size, nheads, headdim, seqlen), method).
time_f

{((False, 1, 40, 128, 512), 'Flash2'): 7.398700341582298e-05,
 ((False, 1, 40, 128, 512), 'Pytorch'): 0.00018832825124263765,
 ((False, 1, 40, 128, 512), 'Triton'): nan,
 ((False, 1, 40, 128, 1024), 'Flash2'): 0.0001949380338191986,
 ((False, 1, 40, 128, 1024), 'Pytorch'): 0.0005280192196369171,
 ((False, 1, 40, 128, 1024), 'Triton'): nan,
 ((False, 1, 40, 128, 2048), 'Flash2'): 0.00044373203068971635,
 ((False, 1, 40, 128, 2048), 'Pytorch'): 0.0015537716820836066,
 ((False, 1, 40, 128, 2048), 'Triton'): nan,
 ((False, 1, 40, 128, 4096), 'Flash2'): 0.00178089814260602,
 ((False, 1, 40, 128, 4096), 'Pytorch'): 0.0059799677319824695,
 ((False, 1, 40, 128, 4096), 'Triton'): nan,
 ((False, 1, 40, 128, 8192), 'Flash2'): 0.007277089785784483,
 ((False, 1, 40, 128, 8192), 'Pytorch'): 0.023374582156538964,
 ((False, 1, 40, 128, 8192), 'Triton'): nan,
 ((False, 1, 8, 128, 1536), 'Flash2'): 7.451500743627549e-05,
 ((False, 1, 8, 128, 1536), 'Pytorch'): 0.0002345100976526737,
 ((False, 1, 8, 128,

In [4]:
# Display the backward-pass timings, keyed by
# ((causal, batch_size, nheads, headdim, seqlen), method).
time_b

{((False, 1, 40, 128, 512), 'Flash2'): 0.00023322777822613717,
 ((False, 1, 40, 128, 512), 'Pytorch'): 0.000425503458827734,
 ((False, 1, 40, 128, 512), 'Triton'): nan,
 ((False, 1, 40, 128, 1024), 'Flash2'): 0.0006006010062992573,
 ((False, 1, 40, 128, 1024), 'Pytorch'): 0.000992963407188654,
 ((False, 1, 40, 128, 1024), 'Triton'): nan,
 ((False, 1, 40, 128, 2048), 'Flash2'): 0.001444813795387745,
 ((False, 1, 40, 128, 2048), 'Pytorch'): 0.003181739915162325,
 ((False, 1, 40, 128, 2048), 'Triton'): nan,
 ((False, 1, 40, 128, 4096), 'Flash2'): 0.005307868830859661,
 ((False, 1, 40, 128, 4096), 'Pytorch'): 0.011935706213116647,
 ((False, 1, 40, 128, 4096), 'Triton'): nan,
 ((False, 1, 40, 128, 8192), 'Flash2'): 0.020440635737031698,
 ((False, 1, 40, 128, 8192), 'Pytorch'): 0.043976778630167246,
 ((False, 1, 40, 128, 8192), 'Triton'): nan,
 ((False, 1, 8, 128, 1536), 'Flash2'): 0.00019406640902161598,
 ((False, 1, 8, 128, 1536), 'Pytorch'): 0.00044881651178002357,
 ((False, 1, 8, 128, 15

In [5]:
# Display the backward-pass timings.
# NOTE(review): same expression as the preceding cell; likely a leftover duplicate.
time_b

{((False, 1, 40, 128, 512), 'Flash2'): 0.00023322777822613717,
 ((False, 1, 40, 128, 512), 'Pytorch'): 0.000425503458827734,
 ((False, 1, 40, 128, 512), 'Triton'): nan,
 ((False, 1, 40, 128, 1024), 'Flash2'): 0.0006006010062992573,
 ((False, 1, 40, 128, 1024), 'Pytorch'): 0.000992963407188654,
 ((False, 1, 40, 128, 1024), 'Triton'): nan,
 ((False, 1, 40, 128, 2048), 'Flash2'): 0.001444813795387745,
 ((False, 1, 40, 128, 2048), 'Pytorch'): 0.003181739915162325,
 ((False, 1, 40, 128, 2048), 'Triton'): nan,
 ((False, 1, 40, 128, 4096), 'Flash2'): 0.005307868830859661,
 ((False, 1, 40, 128, 4096), 'Pytorch'): 0.011935706213116647,
 ((False, 1, 40, 128, 4096), 'Triton'): nan,
 ((False, 1, 40, 128, 8192), 'Flash2'): 0.020440635737031698,
 ((False, 1, 40, 128, 8192), 'Pytorch'): 0.043976778630167246,
 ((False, 1, 40, 128, 8192), 'Triton'): nan,
 ((False, 1, 8, 128, 1536), 'Flash2'): 0.00019406640902161598,
 ((False, 1, 8, 128, 1536), 'Pytorch'): 0.00044881651178002357,
 ((False, 1, 8, 128, 15

In [6]:
# Print every forward-pass timing, one per line, in insertion order.
for elapsed in time_f.values():
    print(elapsed)

7.398700341582298e-05
0.00018832825124263765
nan
0.0001949380338191986
0.0005280192196369171
nan
0.00044373203068971635
0.0015537716820836066
nan
0.00178089814260602
0.0059799677319824695
nan
0.007277089785784483
0.023374582156538964
nan
7.451500743627549e-05
0.0002345100976526737
nan
0.00014221351593732834
0.00034161580726504325
nan
0.000214122012257576
0.000809457041323185
nan
0.0008692327700555324
0.0030595683120191097
nan
0.00012004587799310684
0.00043376151472330094
nan
0.0002387201227247715
0.0006345830112695693
nan
0.0004338732920587063
0.0015640468709170819
nan
0.0017551548406481742
0.006032583843916655
nan
0.0007717288658022881
0.002488524131476879
nan
0.0029259537532925605
0.009728931300342082
nan
0.011624757833778858
0.037755873911082746
nan
6.22955895960331e-05
0.00027652382850646974
nan
0.0001347656361758709
0.0006315072998404503
nan
0.00030866755172610284
0.0025087839551270006
nan
0.0010923170484602451
0.010051079392433167
nan
0.004097110796719789
0.03983450936153531
nan


In [8]:
def _times_to_frame(times, value_column):
    """Turn a {(config, method): time} mapping into a table with one row per entry.

    The (causal, batch_size, nheads, headdim, seqlen) config tuple is spread
    over separate columns, the timing goes into `value_column`, and rows whose
    method is 'Triton' are dropped.
    """
    records = [(*config, method, elapsed) for (config, method), elapsed in times.items()]
    frame = pd.DataFrame(
        records,
        columns=['Causal', 'BatchSize', 'nHeads', 'HeadDim', 'SeqLen', 'Method', value_column],
    )
    return frame.loc[frame['Method'] != 'Triton']


# Split the (config, method) keys into separate columns.
df_time_f = _times_to_frame(time_f, 'Time F')
df_time_b = _times_to_frame(time_b, 'Time B')

# Save both DataFrames to one Excel workbook, one sheet each.
with pd.ExcelWriter('times10.xlsx') as writer:
    for sheet_name, frame in (('Forward Times', df_time_f), ('Backward Times', df_time_b)):
        frame.to_excel(writer, sheet_name=sheet_name, index=False)


