[Inference] Move benchmark-related code to the example directory. #5408

Merged
113 changes: 113 additions & 0 deletions examples/inference/benchmark_ops/benchmark_context_attn_unpad.py
@@ -0,0 +1,113 @@
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref

try:
import triton # noqa

except ImportError:
print("please install triton from https://github.com/openai/triton")

HEAD_DIM = 32
BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
WARM_UPS = 10
REPS = 100
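# Sweep KV_LEN (the per-sequence context length) and compare the padded torch reference
# against the unpadded triton kernel at a fixed batch size and block size.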
configs = [
triton.testing.Benchmark(
x_names=["KV_LEN"],
x_vals=[2**i for i in range(8, 13)],
# x_vals=[x for x in range(256, 8192, 256)],
line_arg="provider",
line_vals=["torch", "triton"],
line_names=["Torch", "Triton"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"context_attn-block_size-{BLOCK_SIZE}-batch{BATCH}",
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
)
]


@triton.testing.perf_report(configs)
def bench_kernel(
bsz,
KV_LEN,
provider,
block_size: int,
kv_group_num: int,
same_context_len: bool,
):
num_attn_heads = 16
max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
max_seq_len = block_size * max_num_blocks_per_seq

num_kv_heads = num_attn_heads // kv_group_num
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
dtype = torch.float16
device = get_current_device()

if same_context_len:
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
else:
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
num_tokens = torch.sum(context_lengths).item()

qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM)
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
q_unpad = q_unpad.contiguous()
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)

quantiles = [0.5, 0.2, 0.8]
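    # With these quantiles, do_bench reports the median, 20th- and 80th-percentile latencies,
    # unpacked below as ms, min_ms, max_ms.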
if provider == "torch":
q_padded = PagedAttention.pad_and_reshape(q_unpad, context_lengths, max_seq_len, num_attn_heads, HEAD_DIM)
k_padded = PagedAttention.pad_and_reshape(k_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
v_padded = PagedAttention.pad_and_reshape(v_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
q_padded, k_padded, v_padded = (
q_padded.to(device=device),
k_padded.to(device=device),
v_padded.to(device=device),
)
q_padded = q_padded.transpose(1, 2)
k_padded = PagedAttention.repeat_kv(k_padded.transpose(1, 2), kv_group_num)
v_padded = PagedAttention.repeat_kv(v_padded.transpose(1, 2), kv_group_num)
        # This benchmark ignores the padding mask, so *only* same-length inputs should be used for benchmarking.
attn_mask = AttentionMaskConverter._make_causal_mask(
(bsz, max_seq_len), q_padded.dtype, q_padded.device, past_key_values_length=0
)
attn_mask = attn_mask.to(device=q_padded.device)
fn = lambda: torch_attn_ref(
q_padded,
k_padded,
v_padded,
attn_mask,
bsz,
max_seq_len,
max_seq_len,
num_attn_heads,
num_kv_heads,
HEAD_DIM,
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
if provider == "triton":
k_cache_triton = torch.zeros_like(k_cache_ref)
v_cache_triton = torch.zeros_like(v_cache_ref)
fn = lambda: context_attention_unpadded(
q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)

return ms, min_ms, max_ms


if __name__ == "__main__":
bench_kernel.run(save_path=".", print_data=True)
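A usage sketch for this script, assuming Triton's standard perf_report behavior (not spelled out in the PR): bench_kernel.run sweeps KV_LEN over the configured x_vals, prints the latency table because print_data=True, and with save_path="." writes the plot and CSV named by plot_name into the working directory.

# Hypothetical invocation from the repository root (requires a CUDA device and triton installed):
#   python examples/inference/benchmark_ops/benchmark_context_attn_unpad.py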
110 changes: 110 additions & 0 deletions examples/inference/benchmark_ops/benchmark_decoding_attn.py
@@ -0,0 +1,110 @@
import torch

from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
convert_kv_unpad_to_padded,
generate_caches_and_block_tables_v2,
prepare_padding_mask,
torch_attn_ref,
)
from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data

try:
import triton # noqa

except ImportError:
print("please install triton from https://github.com/openai/triton")

Q_LEN = 1
HEAD_DIM = 128
BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
WARM_UPS = 10
REPS = 100
configs = [
triton.testing.Benchmark(
x_names=["KV_LEN"],
x_vals=[2**i for i in range(8, 14)],
# x_vals=[x for x in range(256, 8192, 256)],
line_arg="provider",
line_vals=["torch", "triton"],
line_names=["Torch", "Triton"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
)
]


@triton.testing.perf_report(configs)
def bench_kernel(
bsz,
KV_LEN,
provider,
block_size: int,
kv_group_num: int,
same_context_len: bool,
):
num_attn_heads = 16
max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
max_seq_len = block_size * max_num_blocks_per_seq

num_kv_heads = num_attn_heads // kv_group_num
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
dtype = torch.float16
device = get_current_device()

q, k_unpad, v_unpad, kv_lengths = prepare_data(
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
)
max_seq_len_in_b = kv_lengths.max().item() # for random lengths

quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b)
torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)
fn = lambda: torch_attn_ref(
q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
if provider == "triton":
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
        # the maximum number of kv splits equals the number of kv cache blocks needed for the longest sequence
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
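        # `output` holds the final result; `mid_output` and `mid_output_lse` are per-split partial
        # outputs and log-sum-exp terms used by the flash-decoding style split-kv reduction.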
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
sm_scale = 1.0 / (HEAD_DIM**0.5)
fn = lambda: flash_decoding_attention(
            # q is passed as q.squeeze(2) because the q_len dimension (always 1 during decoding)
            # is hidden; see the attention forward in the modeling code.
q.squeeze(2),
k_cache,
v_cache,
kv_lengths,
block_tables,
block_size,
max_seq_len_in_b,
output,
mid_output,
mid_output_lse,
sm_scale=sm_scale,
kv_group_num=kv_group_num,
) # [bsz, 1, num_heads, head_dim]
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)

return ms, min_ms, max_ms


if __name__ == "__main__":
bench_kernel.run(save_path=".", print_data=True)
@@ -0,0 +1,65 @@
import torch
import triton

from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding

BATCH = 16
configs = [
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[2**i for i in range(4, 12)],
line_arg="provider",
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"rotary_emb-batch-{BATCH}",
args={"num_kv_heads": 16},
)
]


def torch_rotary_emb(x, cos, sin):
seq_len, h, dim = x.shape
x0 = x[:, :, 0 : dim // 2]
x1 = x[:, :, dim // 2 : dim]
cos = cos.view((seq_len, 1, dim // 2))
sin = sin.view((seq_len, 1, dim // 2))
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
return torch.cat((o0, o1), dim=-1)
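# torch_rotary_emb above follows the split-half rotary convention: the two halves of head_dim
# act as the real/imaginary parts and are rotated by (cos, sin).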


@triton.testing.perf_report(configs)
def benchmark_rotary_emb(
provider: str,
num_tokens: int,
num_kv_heads: int,
):
warmup = 10
rep = 100

head_dim = 128
dtype = torch.float16
q_shape = (num_tokens, num_kv_heads, head_dim)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
k_shape = (num_tokens, num_kv_heads, head_dim)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
cos_shape = (4096, head_dim // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
lengths = torch.tensor([3, 4, 6, 7], device="cuda")

if provider == "torch_rotary_emb_func":
fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
elif provider == "triton_rotary_emb_func":
fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
else:
raise ValueError("Undefined provider")

ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms


if __name__ == "__main__":
benchmark_rotary_emb.run(save_path=".", print_data=True)
78 changes: 78 additions & 0 deletions examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py
@@ -0,0 +1,78 @@
import torch
import triton

from colossalai.kernel.triton import rms_layernorm

try:
import triton # noqa

except ImportError:
print("please install triton from https://github.com/openai/triton")


# Triton benchmark plot attributions
configs = [
triton.testing.Benchmark(
x_names=["SEQUENCE_TOTAL"],
x_vals=[i for i in range(128, 1025, 128)],
line_arg="provider",
line_vals=[
"vllm_rms_layernorm",
"triton_rms_layernorm",
"triton_rms_layernorm_with_residual",
"vllm_rms_layernorm_with_residual",
],
line_names=[
"vllm_rms_layernorm",
"triton_rms_layernorm",
"triton_rms_layernorm_with_residual",
"vllm_rms_layernorm_with_residual",
],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
ylabel="ms",
plot_name=f"RMSNorm benchmarking results",
args={"HIDDEN_SIZE": 1024},
)
]


@triton.testing.perf_report(configs)
def benchmark_rms_layernorm(
provider: str,
SEQUENCE_TOTAL: int,
HIDDEN_SIZE: int,
):
try:
from vllm.model_executor.layers.layernorm import RMSNorm
except ImportError:
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")

warmup = 10
rep = 1000

dtype = torch.float16
eps = 1e-5
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
w_shape = (x_shape[-1],)
residual = torch.rand(x_shape, dtype=dtype, device="cuda")
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda")
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
if provider == "vllm_rms_layernorm":
fn = lambda: vllm_norm(x)
elif provider == "triton_rms_layernorm":
fn = lambda: rms_layernorm(x, weight, eps=eps)
elif provider == "vllm_rms_layernorm_with_residual":
fn = lambda: vllm_norm(x, residual=residual)
elif provider == "triton_rms_layernorm_with_residual":
fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
else:
raise ValueError("Undefined provider.")

ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

return ms


if __name__ == "__main__":
benchmark_rms_layernorm.run(save_path=".", print_data=True)
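For orientation, a minimal PyTorch reference of the RMSNorm computation that both providers benchmark; this is an illustrative sketch, not the PR's kernel or vllm's implementation:

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Scale each row by the reciprocal root-mean-square over the hidden dimension, then apply weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight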