Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel #5418

Merged
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions colossalai/inference/modeling/models/nopadding_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def llama_model_forward(
# NOTE: After testing, the performance of this configuration is relatively good. With updates
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
# selection should be conducted.
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False
# if batch_size >= 32 and kv_seq_len > 512:
# use_cuda_kernel = False
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved

hidden_states = self.embed_tokens(input_ids)

Expand Down Expand Up @@ -298,8 +298,12 @@ def forward(
)

block_size = k_cache.size(-2)

if is_prompts:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
if use_cuda_kernel:
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
else:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
Expand All @@ -315,9 +319,21 @@ def forward(
)
else:
if use_cuda_kernel:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
# using non-fused operation
# inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
# inference_ops.decode_kv_cache_memcpy(
# key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
# )
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved
inference_ops.rotary_embedding_and_cache_copy(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
sequence_lengths,
block_tables,
)
else:
decoding_fused_rotary_embedding(
Expand Down
4 changes: 2 additions & 2 deletions colossalai/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
yuanheng-zhao marked this conversation as resolved.
Show resolved Hide resolved
self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token

inference_ops = InferenceOpsLoader().load()

try:
import triton # noqa

Expand All @@ -16,9 +19,19 @@
x_names=["num_tokens"],
x_vals=[2**i for i in range(4, 11)],
line_arg="provider",
line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
styles=[("red", "-"), ("blue", "-")],
line_vals=[
"no_fused_triton_rotary_emb_func",
"fused_triton_rotary_emb_func",
"no_fused_cuda_rotary_emb_func",
"fused_cuda_rotary_emb_func",
],
line_names=[
"no_fused_triton_rotary_emb_func",
"fused_triton_rotary_emb_func",
"no_fused_cuda_rotary_emb_func",
"fused_cuda_rotary_emb_func",
],
styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
ylabel="ms",
plot_name=f"rotary_emb-batch-{BATCH}",
args={"num_kv_heads": 16},
Expand All @@ -32,7 +45,7 @@ def benchmark_rotary_emb(
num_tokens: int,
num_kv_heads: int,
):
BATCH_SIZE = 4
BATCH_SIZE = 16
SEQ_LEN = num_tokens // BATCH_SIZE
max_num_blocks_per_seq = 8
block_size = 64
Expand Down Expand Up @@ -68,7 +81,7 @@ def benchmark_rotary_emb(
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")

if provider == "no_fused_rotary_emb_func":
if provider == "no_fused_triton_rotary_emb_func":
fn = lambda: [
rotary_embedding(new_q, new_k, cos, sin),
copy_kv_to_blocked_cache(
Expand All @@ -77,7 +90,16 @@ def benchmark_rotary_emb(
]
elif provider == "fused_triton_rotary_emb_func":
fn = lambda: decoding_fused_rotary_embedding(
new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
)
elif provider == "no_fused_cuda_rotary_emb_func":
fn = lambda: [
inference_ops.rotary_embedding(new_q, new_k, cos, sin),
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
]
elif provider == "fused_cuda_rotary_emb_func":
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
)
else:
raise ValueError("Undefined provider")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import torch
import triton
from vllm._C import ops

from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import rotary_embedding

inference_ops = InferenceOpsLoader().load()

BATCH = 16
configs = [
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[2**i for i in range(4, 12)],
line_arg="provider",
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
styles=[("red", "-"), ("blue", "-")],
line_vals=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
line_names=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
ylabel="ms",
plot_name=f"rotary_emb-batch-{BATCH}",
args={"num_kv_heads": 16},
Expand Down Expand Up @@ -48,12 +52,19 @@ def benchmark_rotary_emb(
cos_shape = (4096, head_dim // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
lengths = torch.tensor([3, 4, 6, 7], device="cuda")

if provider == "torch_rotary_emb_func":
fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
elif provider == "triton_rotary_emb_func":
fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
cos_sin = torch.stack((cos, sin), dim=1).contiguous()

positions = torch.arange(num_tokens).cuda()

if provider == "triton_func":
fn = lambda: rotary_embedding(q, k, cos, sin)
elif provider == "colossal_cuda_func":
fn = lambda: inference_ops.rotary_embedding(q, k, cos, sin)
elif provider == "vllm_cuda_func":
q = q.view(num_tokens, -1)
k = k.view(num_tokens, -1)
fn = lambda: ops.rotary_embedding(positions, q, k, head_dim, cos_sin, True)
else:
raise ValueError("Undefined provider")

Expand Down
54 changes: 54 additions & 0 deletions examples/inference/benchmark_ops/benchmark_xine_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch

from colossalai.kernel.triton import get_xine_cache
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin

try:
import triton # noqa

except ImportError:
print("please install triton from https://github.com/openai/triton")


configs = [
triton.testing.Benchmark(
x_names=["max_num_tokens"],
x_vals=[2**i for i in range(6, 12)],
line_arg="provider",
line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name="Get_cos-sin_func",
args={"batch_size": 16, "head_dim": 256},
)
]


@triton.testing.perf_report(configs)
def benchmark_get_xine_cache(
provider: str,
max_num_tokens: int,
batch_size: int,
head_dim: int,
):
warmup = 10
rep = 1000
dtype = torch.float16
cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")

if provider == "torch_get_cos_sin":
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
elif provider == "triton_get_cos_sin":
fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
else:
raise ValueError("Undefined provider")

ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms


if __name__ == "__main__":
benchmark_get_xine_cache.run(save_path=".", print_data=True)
23 changes: 23 additions & 0 deletions extensions/csrc/cuda/colossal_inference_C_frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,30 @@ void decode_kv_cache_memcpy(
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]

void rotary_embedding(
torch::Tensor& query, // [total_tokens, head_num, head_dim]
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
torch::Tensor& cos, // [total_tokens, head_dim]
torch::Tensor& sin); // [total_tokens, head_dim]

void rotary_embedding_and_cache_copy(
torch::Tensor& query, // [num_tokens, head_num, head_dim]
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
torch::Tensor& cos, // [num_tokens, head_dim]
torch::Tensor& sin, // [num_tokens, head_dim]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
m.def(
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
"performing Rotary Embedding-related calculations and KVCache Memcopy.");
m.def("rotary_embedding", &rotary_embedding,
"performing Rotary Embedding-related calculations.");
}
Loading
Loading