Commit 0aa27f1
[Inference] Move benchmark-related code to the example directory. (#5408)

* Move benchmark-related code to the example directory.
* Fix bugs in test_fused_rotary_embedding.py

1 parent: 600881a
Showing 11 changed files with 479 additions and 433 deletions.
examples/inference/benchmark_ops/benchmark_context_attn_unpad.py (113 additions, 0 deletions)
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref

try:
    import triton  # noqa

except ImportError:
    print("please install triton from https://github.com/openai/triton")

HEAD_DIM = 32
BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
WARM_UPS = 10
REPS = 100
configs = [
    triton.testing.Benchmark(
        x_names=["KV_LEN"],
        x_vals=[2**i for i in range(8, 13)],
        # x_vals=[x for x in range(256, 8192, 256)],
        line_arg="provider",
        line_vals=["torch", "triton"],
        line_names=["Torch", "Triton"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name=f"context_attn-block_size-{BLOCK_SIZE}-batch{BATCH}",
        args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
    )
]


@triton.testing.perf_report(configs)
def bench_kernel(
    bsz,
    KV_LEN,
    provider,
    block_size: int,
    kv_group_num: int,
    same_context_len: bool,
):
    num_attn_heads = 16
    max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
    max_seq_len = block_size * max_num_blocks_per_seq

    num_kv_heads = num_attn_heads // kv_group_num
    assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
    dtype = torch.float16
    device = get_current_device()

    if same_context_len:
        context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
    else:
        context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
    num_tokens = torch.sum(context_lengths).item()

    qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM)
    qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
    q_unpad = q_unpad.contiguous()
    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
        k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
    )
    block_tables = block_tables.to(device=device)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch":
        q_padded = PagedAttention.pad_and_reshape(q_unpad, context_lengths, max_seq_len, num_attn_heads, HEAD_DIM)
        k_padded = PagedAttention.pad_and_reshape(k_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
        v_padded = PagedAttention.pad_and_reshape(v_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
        q_padded, k_padded, v_padded = (
            q_padded.to(device=device),
            k_padded.to(device=device),
            v_padded.to(device=device),
        )
        q_padded = q_padded.transpose(1, 2)
        k_padded = PagedAttention.repeat_kv(k_padded.transpose(1, 2), kv_group_num)
        v_padded = PagedAttention.repeat_kv(v_padded.transpose(1, 2), kv_group_num)
        # This benchmark ignores the padding mask, so *only* use same-length inputs when benchmarking.
        attn_mask = AttentionMaskConverter._make_causal_mask(
            (bsz, max_seq_len), q_padded.dtype, q_padded.device, past_key_values_length=0
        )
        attn_mask = attn_mask.to(device=q_padded.device)
        fn = lambda: torch_attn_ref(
            q_padded,
            k_padded,
            v_padded,
            attn_mask,
            bsz,
            max_seq_len,
            max_seq_len,
            num_attn_heads,
            num_kv_heads,
            HEAD_DIM,
        )
        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
    if provider == "triton":
        k_cache_triton = torch.zeros_like(k_cache_ref)
        v_cache_triton = torch.zeros_like(v_cache_ref)
        fn = lambda: context_attention_unpadded(
            q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
        )
        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)

    return ms, min_ms, max_ms


if __name__ == "__main__":
    bench_kernel.run(save_path=".", print_data=True)
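A note on the layout being benchmarked: the Triton kernel above consumes variable-length sequences packed back-to-back along the token dimension, while the Torch reference first pads each sequence to max_seq_len (roughly what PagedAttention.pad_and_reshape is used for here). A minimal sketch of how the two layouts relate, with names local to this sketch rather than taken from the script:

import torch

# Two sequences of different lengths, packed into one (num_tokens, heads, dim) tensor.
context_lengths = torch.tensor([3, 5], dtype=torch.int32)
num_tokens = int(context_lengths.sum())   # 8 packed tokens in total
max_seq_len = int(context_lengths.max())  # padded length per sequence

packed = torch.arange(num_tokens, dtype=torch.float32).view(num_tokens, 1, 1)
padded = torch.zeros(len(context_lengths), max_seq_len, 1, 1)

offset = 0
for i, length in enumerate(context_lengths.tolist()):
    # Copy each sequence's packed tokens into its padded row; positions beyond
    # its length stay zero (which is why a padding/causal mask matters).
    padded[i, :length] = packed[offset : offset + length]
    offset += length

print(padded.squeeze())
# tensor([[0., 1., 2., 0., 0.],
#         [3., 4., 5., 6., 7.]])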
examples/inference/benchmark_ops/benchmark_decoding_attn.py (110 additions, 0 deletions)
import torch

from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
    convert_kv_unpad_to_padded,
    generate_caches_and_block_tables_v2,
    prepare_padding_mask,
    torch_attn_ref,
)
from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data

try:
    import triton  # noqa

except ImportError:
    print("please install triton from https://github.com/openai/triton")

Q_LEN = 1
HEAD_DIM = 128
BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
WARM_UPS = 10
REPS = 100
configs = [
    triton.testing.Benchmark(
        x_names=["KV_LEN"],
        x_vals=[2**i for i in range(8, 14)],
        # x_vals=[x for x in range(256, 8192, 256)],
        line_arg="provider",
        line_vals=["torch", "triton"],
        line_names=["Torch", "Triton"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
        args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
    )
]


@triton.testing.perf_report(configs)
def bench_kernel(
    bsz,
    KV_LEN,
    provider,
    block_size: int,
    kv_group_num: int,
    same_context_len: bool,
):
    num_attn_heads = 16
    max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
    max_seq_len = block_size * max_num_blocks_per_seq

    num_kv_heads = num_attn_heads // kv_group_num
    assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
    dtype = torch.float16
    device = get_current_device()

    q, k_unpad, v_unpad, kv_lengths = prepare_data(
        bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
    )
    max_seq_len_in_b = kv_lengths.max().item()  # for random lengths

    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch":
        k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b)
        v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b)
        torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)
        fn = lambda: torch_attn_ref(
            q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
        )
        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
    if provider == "triton":
        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
        )
        block_tables = block_tables.to(device=device)
        # The maximum split length along the KV dimension is the KV-cache block size.
        kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
        output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
        mid_output = torch.empty(
            size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
        )
        mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
        sm_scale = 1.0 / (HEAD_DIM**0.5)
        fn = lambda: flash_decoding_attention(
            # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
            # refer to attention forward in modeling.
            q.squeeze(2),
            k_cache,
            v_cache,
            kv_lengths,
            block_tables,
            block_size,
            max_seq_len_in_b,
            output,
            mid_output,
            mid_output_lse,
            sm_scale=sm_scale,
            kv_group_num=kv_group_num,
        )  # [bsz, 1, num_heads, head_dim]
        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)

    return ms, min_ms, max_ms


if __name__ == "__main__":
    bench_kernel.run(save_path=".", print_data=True)
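The mid_output and mid_output_lse buffers above hold per-split partial results; flash-decoding-style kernels typically combine such splits with a log-sum-exp reduction. The following is an illustrative numerical sketch of that reduction in plain PyTorch, assuming that interpretation; it is not the kernel's actual code:

import torch

scores = torch.randn(2, 8)  # attention scores: 2 queries over 8 KV positions
v = torch.randn(8, 4)       # values

# Reference: softmax over the full KV length.
ref = torch.softmax(scores, dim=-1) @ v

# Split the KV positions in two, keeping a per-split softmax output and log-sum-exp.
outs, lses = [], []
for s, vv in ((scores[:, :4], v[:4]), (scores[:, 4:], v[4:])):
    m = s.max(dim=-1, keepdim=True).values
    w = (s - m).exp()
    outs.append((w @ vv) / w.sum(dim=-1, keepdim=True))  # softmax within the split
    lses.append(m.squeeze(-1) + w.sum(dim=-1).log())      # log-sum-exp of the split

# Combine the splits: weight each partial output by its share of the total softmax mass.
lse_all = torch.logsumexp(torch.stack(lses), dim=0)
combined = sum(o * (l - lse_all).exp().unsqueeze(-1) for o, l in zip(outs, lses))

print(torch.allclose(ref, combined, atol=1e-6))  # True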
examples/inference/benchmark_ops/benchmark_fused_rotary_embedding.py (65 additions, 0 deletions)
import torch
import triton

from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding

BATCH = 16
configs = [
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[2**i for i in range(4, 12)],
        line_arg="provider",
        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name=f"rotary_emb-batch-{BATCH}",
        args={"num_kv_heads": 16},
    )
]


def torch_rotary_emb(x, cos, sin):
    seq_len, h, dim = x.shape
    x0 = x[:, :, 0 : dim // 2]
    x1 = x[:, :, dim // 2 : dim]
    cos = cos.view((seq_len, 1, dim // 2))
    sin = sin.view((seq_len, 1, dim // 2))
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat((o0, o1), dim=-1)


@triton.testing.perf_report(configs)
def benchmark_rotary_emb(
    provider: str,
    num_tokens: int,
    num_kv_heads: int,
):
    warmup = 10
    rep = 100

    head_dim = 128
    dtype = torch.float16
    q_shape = (num_tokens, num_kv_heads, head_dim)
    q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
    k_shape = (num_tokens, num_kv_heads, head_dim)
    k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
    cos_shape = (4096, head_dim // 2)
    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    lengths = torch.tensor([3, 4, 6, 7], device="cuda")

    if provider == "torch_rotary_emb_func":
        fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
    elif provider == "triton_rotary_emb_func":
        fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
    else:
        raise ValueError("Undefined provider")

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    return ms


if __name__ == "__main__":
    benchmark_rotary_emb.run(save_path=".", print_data=True)
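The torch_rotary_emb reference above is the standard RoPE rotation: each (x0, x1) pair is rotated by its positional angle, which is the same as multiplying the pair, viewed as a complex number, by e^{i*theta}. A small sanity check of that identity (illustrative only, not part of the script):

import torch

theta = torch.tensor(0.3)
x0, x1 = torch.tensor(1.0), torch.tensor(2.0)

# Pairwise rotation, as written in torch_rotary_emb above.
o0 = x0 * torch.cos(theta) - x1 * torch.sin(theta)
o1 = x0 * torch.sin(theta) + x1 * torch.cos(theta)

# The same rotation expressed as complex multiplication by e^{i*theta}.
z = torch.complex(x0, x1) * torch.exp(1j * theta)
print(torch.allclose(o0, z.real), torch.allclose(o1, z.imag))  # True True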
examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py (78 additions, 0 deletions)
import torch
import triton

from colossalai.kernel.triton import rms_layernorm

try:
    import triton  # noqa

except ImportError:
    print("please install triton from https://github.com/openai/triton")

# Triton benchmark plot attributes
configs = [
    triton.testing.Benchmark(
        x_names=["SEQUENCE_TOTAL"],
        x_vals=[i for i in range(128, 1025, 128)],
        line_arg="provider",
        line_vals=[
            "vllm_rms_layernorm",
            "triton_rms_layernorm",
            "triton_rms_layernorm_with_residual",
            "vllm_rms_layernorm_with_residual",
        ],
        line_names=[
            "vllm_rms_layernorm",
            "triton_rms_layernorm",
            "triton_rms_layernorm_with_residual",
            "vllm_rms_layernorm_with_residual",
        ],
        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
        ylabel="ms",
        plot_name="RMSNorm benchmarking results",
        args={"HIDDEN_SIZE": 1024},
    )
]


@triton.testing.perf_report(configs)
def benchmark_rms_layernorm(
    provider: str,
    SEQUENCE_TOTAL: int,
    HIDDEN_SIZE: int,
):
    try:
        from vllm.model_executor.layers.layernorm import RMSNorm
    except ImportError:
        raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")

    warmup = 10
    rep = 1000

    dtype = torch.float16
    eps = 1e-5
    x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
    w_shape = (x_shape[-1],)
    residual = torch.rand(x_shape, dtype=dtype, device="cuda")
    weight = torch.ones(w_shape, dtype=dtype, device="cuda")
    vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda")
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
    if provider == "vllm_rms_layernorm":
        fn = lambda: vllm_norm(x)
    elif provider == "triton_rms_layernorm":
        fn = lambda: rms_layernorm(x, weight, eps=eps)
    elif provider == "vllm_rms_layernorm_with_residual":
        fn = lambda: vllm_norm(x, residual=residual)
    elif provider == "triton_rms_layernorm_with_residual":
        fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
    else:
        raise ValueError("Undefined provider.")

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

    return ms


if __name__ == "__main__":
    benchmark_rms_layernorm.run(save_path=".", print_data=True)
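For context, the operation both providers are expected to compute here is standard RMSNorm: scale the input by the reciprocal root-mean-square over the hidden dimension, then apply the learned weight. A minimal plain-PyTorch sketch under that assumption (for orientation only; it is not a drop-in replacement for either benchmarked kernel):

import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Accumulate in float32 for numerical stability, then cast back to the input dtype.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

x = torch.randn(4, 1024, dtype=torch.float16)
w = torch.ones(1024, dtype=torch.float16)
print(rmsnorm_ref(x, w).shape)  # torch.Size([4, 1024])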