In [1]:
import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel, SDPBackend
# from torch.nn.attention import sdp_kernel
import torch.nn.functional as F
import os
import numpy as np
import random

# Seed used whenever set_seed() is called without an explicit value.
DEFAULT_SEED = 42


def set_seed(seed=DEFAULT_SEED):
    """Seed the PyTorch, NumPy and Python random number generators.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.

    Args:
        seed: Value handed to every seeding function; defaults to
            ``DEFAULT_SEED``.
    """
    # Order: torch (CPU), torch (all CUDA devices), NumPy, stdlib random.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True


# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
# Let's define a helpful benchmarking function:
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return the mean runtime in microseconds.

    Args:
        f: Callable to measure.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        Mean of the ``blocked_autorange()`` measurement, scaled by 1e6.
    """
    timer_globals = {"f": f, "args": args, "kwargs": kwargs}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=timer_globals)
    measurement = timer.blocked_autorange()
    # The measurement's mean is scaled by 1e6 to express it in microseconds.
    return measurement.mean * 1e6

# Hyper-parameters of the attention inputs
batch_size, num_heads = 32, 32
max_sequence_len = 1024
embed_dimension = 32

dtype = torch.float16
device = "cuda:0"

# Random query, key and value tensors (drawn in that order), each created with
# dimensions (batch_size, num_heads, max_sequence_len, embed_dimension).
query, key, value = (
    torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
    for _ in range(3)
)

# Baseline timing with no sdp_kernel context active.
default_runtime_us = benchmark_torch_function_in_microseconds(
    F.scaled_dot_product_attention, query, key, value
)
print(f"The default implementation runs in {default_runtime_us:.3f} microseconds")

# Time each of the three scaled-dot-product-attention implementations in turn.

# Keyword arguments for sdp_kernel that switch on exactly one backend.
_ALL_BACKENDS_OFF = {"enable_math": False, "enable_flash": False, "enable_mem_efficient": False}
backend_map = {
    SDPBackend.MATH: {**_ALL_BACKENDS_OFF, "enable_math": True},
    SDPBackend.FLASH_ATTENTION: {**_ALL_BACKENDS_OFF, "enable_flash": True},
    SDPBackend.EFFICIENT_ATTENTION: {**_ALL_BACKENDS_OFF, "enable_mem_efficient": True},
}

# Math backend: any RuntimeError is allowed to propagate.
with sdp_kernel(**backend_map[SDPBackend.MATH]):
    math_runtime_us = benchmark_torch_function_in_microseconds(
        F.scaled_dot_product_attention, query, key, value
    )
    print(f"The math implementation runs in {math_runtime_us:.3f} microseconds")

# Flash attention backend: a RuntimeError is reported instead of raised.
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        flash_runtime_us = benchmark_torch_function_in_microseconds(
            F.scaled_dot_product_attention, query, key, value
        )
        print(f"The flash attention implementation runs in {flash_runtime_us:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

# Memory-efficient backend: a RuntimeError is reported instead of raised.
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        mem_efficient_runtime_us = benchmark_torch_function_in_microseconds(
            F.scaled_dot_product_attention, query, key, value
        )
        print(f"The memory efficient implementation runs in {mem_efficient_runtime_us:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")


The default implementation runs in 927.106 microseconds




The math implementation runs in 10094.952 microseconds
The flash attention implementation runs in 918.287 microseconds
The memory efficient implementation runs in 1782.254 microseconds


In [1]:
# Switch to bfloat16 and rebuild the inputs right after re-seeding.
# Relies on batch_size, num_heads, max_sequence_len, embed_dimension and
# device already being defined by an earlier cell.
dtype = torch.bfloat16

set_seed()
query, key, value = (
    torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
    for _ in range(3)
)


# with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
# print(F.scaled_dot_product_attention(query, key, value))
    


In [2]:
from decimal import Decimal
import decimal

def data_round(x):
    """Round ``x`` half-up to four decimal places and return it as a string.

    Args:
        x: Value to round; anything whose ``str()`` form ``Decimal`` can
            parse (float, int, numeric string, ``Decimal``).

    Returns:
        The rounded value formatted with exactly four decimal places,
        e.g. ``"0.9877"``.
    """
    # Build the Decimal from str(x) rather than from the float itself:
    # Decimal(float) captures the exact binary value (0.00015 is stored
    # slightly below 0.00015), so a tie would round down despite ROUND_HALF_UP.
    return str(Decimal(str(x)).quantize(Decimal('0.0000'), rounding=decimal.ROUND_HALF_UP))


def res_trans(res):
    """Print the four headline metrics of ``res``, one rounded value per line.

    Args:
        res: Mapping that contains the keys "ACC", "F1", "Precision" and
            "Recall"; each value is passed through ``data_round``.
    """
    for metric_key in ("ACC", "F1", "Precision", "Recall"):
        print(data_round(res[metric_key]))

# Hard-coded result record; res_trans prints only the four metric entries.
res = {
    "model": "Mistral-7B-Instruct-v0.3",
    "train_test_split": "8:2",
    "train_ratio": "1.0",
    "train_loss": 0.06997728,
    "lr": "1.4e-4",
    "ACC": 0.9877049180327869,
    "F1": 0.9832533980502562,
    "Precision": 0.9883868032359371,
    "Recall": 0.9783695652173913,
}

res_trans(res)


0.9877
0.9833
0.9884
0.9784


In [4]:
import numpy as np
# Average of five recorded values; the cell displays the mean.
x = np.array([0.9869, 0.9877, 0.9852, 0.9877, 0.9877])
np.mean(x)
# 0.98672 0.98622 0.9852399999999999

0.98704

In [3]:
# Scratch arithmetic; the cell's displayed value is this product.
25 * 0.04

1.0