Commit
[Inference/kernel] Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418)

* add rotary embedding kernel

* add rotary_embedding_kernel

* add fused rotary_emb and kvcache memcopy

* add fused_rotary_emb_and_cache_kernel.cu

* add fused_rotary_emb_and_memcopy

* fix bugs in fused_rotary_emb_and_cache_kernel.cu

* fix ci bugs

* use vec memcopy and optimize global memory access

* fix code style

* fix test_rotary_embdding_unpad.py

* revise code based on review comments

* fix bugs about include path

* rm inline
yuehuayingxueluo committed Mar 13, 2024
1 parent ed431de commit f366a5e
Showing 13 changed files with 928 additions and 78 deletions.
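
The main win is in the decode path: instead of one kernel to rotate q/k and a second to copy k/v into the blocked cache, a single fused CUDA op does both. A minimal sketch of the two call patterns (op names and argument orders are taken from the diffs below; the wrapper functions and shape notes are illustrative, not part of this commit):

    import torch

    # Assumed decode-step shapes, mirroring the benchmark code below:
    #   q, k, v:           [num_tokens, num_heads, head_dim] (one new token per sequence)
    #   cos, sin:          [num_tokens, head_dim // 2]
    #   k_cache, v_cache:  blocked KV cache; block_tables maps each sequence to its blocks

    def decode_step_unfused(inference_ops, q, k, v, cos, sin, k_cache, v_cache, seq_lens, block_tables):
        # Two launches: rotate q/k in place, then copy k/v into the blocked cache.
        inference_ops.rotary_embedding(q, k, cos, sin)
        inference_ops.decode_kv_cache_memcpy(k, v, k_cache, v_cache, seq_lens, block_tables)

    def decode_step_fused(inference_ops, q, k, v, cos, sin, k_cache, v_cache, seq_lens, block_tables):
        # One launch: k/v are read from global memory once, rotated, and written
        # straight into the cache, saving a launch and a full k/v round trip.
        inference_ops.rotary_embedding_and_cache_copy(
            q, k, v, cos, sin, k_cache, v_cache, seq_lens, block_tables
        )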
19 changes: 15 additions & 4 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -320,8 +320,12 @@ def forward(
         )
 
         block_size = k_cache.size(-2)
+
         if is_prompts:
-            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+            if use_cuda_kernel:
+                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+            else:
+                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
@@ -337,9 +341,16 @@ def forward(
             )
         else:
             if use_cuda_kernel:
-                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                inference_ops.decode_kv_cache_memcpy(
-                    key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
+                inference_ops.rotary_embedding_and_cache_copy(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cos_sin[0],
+                    cos_sin[1],
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
                 )
             else:
                 decoding_fused_rotary_embedding(
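
For reference, rotary_embedding rotates q and k in place using the standard RoPE formula. A plain-torch equivalent, assuming the half-split layout implied by the cos/sin caches of shape [num_tokens, head_dim // 2] (a sketch for intuition; the kernels' exact layout may differ):

    import torch

    def torch_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # x:        [num_tokens, num_heads, head_dim]
        # cos, sin: [num_tokens, head_dim // 2], broadcast over heads
        half = x.size(-1) // 2
        x0, x1 = x[..., :half], x[..., half:]
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
        out0 = x0 * cos - x1 * sin
        out1 = x1 * cos + x0 * sin
        return torch.cat((out0, out1), dim=-1)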
4 changes: 2 additions & 2 deletions colossalai/inference/utils.py
@@ -47,5 +47,5 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
     t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
     freqs = torch.outer(t, inv_freq)
 
-    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
-    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+    self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
+    self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
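
This fix stops hardcoding float16: init_to_get_rotary now caches the cos/sin tables in the model's own dtype, so fp32 and bf16 models no longer silently downcast their rotary tables. For context, the function's overall shape is the standard RoPE precomputation (everything outside the hunk is assumed here, shown only to situate the two changed lines):

    import torch

    def init_to_get_rotary(self, base=10000, use_elem=False):
        # Config-derived values; these placeholder values are assumptions for the sketch.
        rope_scaling_factor = 1.0
        max_seq_len = 4096
        dim = 128  # head dim

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
        freqs = torch.outer(t, inv_freq)

        # The fix: cache in self.dtype rather than a hardcoded torch.float16.
        self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
        self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()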
@@ -1,8 +1,11 @@
 import torch
 
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
 from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
 
+inference_ops = InferenceOpsLoader().load()
+
 try:
     import triton  # noqa
 
@@ -16,9 +19,19 @@
         x_names=["num_tokens"],
         x_vals=[2**i for i in range(4, 11)],
         line_arg="provider",
-        line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
-        line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
-        styles=[("red", "-"), ("blue", "-")],
+        line_vals=[
+            "no_fused_triton_rotary_emb_func",
+            "fused_triton_rotary_emb_func",
+            "no_fused_cuda_rotary_emb_func",
+            "fused_cuda_rotary_emb_func",
+        ],
+        line_names=[
+            "no_fused_triton_rotary_emb_func",
+            "fused_triton_rotary_emb_func",
+            "no_fused_cuda_rotary_emb_func",
+            "fused_cuda_rotary_emb_func",
+        ],
+        styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
         ylabel="ms",
         plot_name=f"rotary_emb-batch-{BATCH}",
         args={"num_kv_heads": 16},
@@ -32,7 +45,7 @@ def benchmark_rotary_emb(
     num_tokens: int,
     num_kv_heads: int,
 ):
-    BATCH_SIZE = 4
+    BATCH_SIZE = 16
     SEQ_LEN = num_tokens // BATCH_SIZE
     max_num_blocks_per_seq = 8
     block_size = 64
@@ -68,7 +81,7 @@ def benchmark_rotary_emb(
     kv_seq_lengths = past_kv_seq_lengths + 1
     block_tables = block_tables.to(device="cuda")
 
-    if provider == "no_fused_rotary_emb_func":
+    if provider == "no_fused_triton_rotary_emb_func":
         fn = lambda: [
             rotary_embedding(new_q, new_k, cos, sin),
             copy_kv_to_blocked_cache(
@@ -77,7 +90,16 @@ def benchmark_rotary_emb(
         ]
     elif provider == "fused_triton_rotary_emb_func":
         fn = lambda: decoding_fused_rotary_embedding(
-            new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths
+            new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
         )
+    elif provider == "no_fused_cuda_rotary_emb_func":
+        fn = lambda: [
+            inference_ops.rotary_embedding(new_q, new_k, cos, sin),
+            inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
+        ]
+    elif provider == "fused_cuda_rotary_emb_func":
+        fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
+            new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
+        )
     else:
         raise ValueError("Undefined provider")
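
Like the other scripts under examples/inference/benchmark_ops, this benchmark is presumably driven by triton.testing's runner; a typical entry point (mirroring the pattern benchmark_xine_copy.py uses below):

    if __name__ == "__main__":
        benchmark_rotary_emb.run(save_path=".", print_data=True)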
@@ -1,17 +1,21 @@
 import torch
 import triton
+from vllm._C import ops
 
-from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.kernel.triton import rotary_embedding
+
+inference_ops = InferenceOpsLoader().load()
 
 BATCH = 16
 configs = [
     triton.testing.Benchmark(
         x_names=["num_tokens"],
         x_vals=[2**i for i in range(4, 12)],
         line_arg="provider",
-        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
-        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
-        styles=[("red", "-"), ("blue", "-")],
+        line_vals=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
+        line_names=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
+        styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
         ylabel="ms",
         plot_name=f"rotary_emb-batch-{BATCH}",
         args={"num_kv_heads": 16},
@@ -48,12 +52,19 @@ def benchmark_rotary_emb(
     cos_shape = (4096, head_dim // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    lengths = torch.tensor([3, 4, 6, 7], device="cuda")
+    cos_sin = torch.stack((cos, sin), dim=1).contiguous()
 
-    if provider == "torch_rotary_emb_func":
-        fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
-    elif provider == "triton_rotary_emb_func":
-        fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
+    positions = torch.arange(num_tokens).cuda()
+
+    if provider == "triton_func":
+        fn = lambda: rotary_embedding(q, k, cos, sin)
+    elif provider == "colossal_cuda_func":
+        fn = lambda: inference_ops.rotary_embedding(q, k, cos, sin)
+    elif provider == "vllm_cuda_func":
+        q = q.view(num_tokens, -1)
+        k = k.view(num_tokens, -1)
+        fn = lambda: ops.rotary_embedding(positions, q, k, head_dim, cos_sin, True)
     else:
         raise ValueError("Undefined provider")
 
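
Before timing, it is worth a one-off check that the Triton and CUDA paths agree. A minimal sketch reusing the q, k, cos, sin built above (clone() because both ops rotate in place, as their use in the model code suggests; tolerances are illustrative):

    q_ref, k_ref = q.clone(), k.clone()
    rotary_embedding(q_ref, k_ref, cos, sin)        # Triton kernel
    inference_ops.rotary_embedding(q, k, cos, sin)  # CUDA kernel
    assert torch.allclose(q, q_ref, atol=1e-3, rtol=1e-3)
    assert torch.allclose(k, k_ref, atol=1e-3, rtol=1e-3)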
54 changes: 54 additions & 0 deletions examples/inference/benchmark_ops/benchmark_xine_copy.py
@@ -0,0 +1,54 @@
import torch

from colossalai.kernel.triton import get_xine_cache
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin

try:
    import triton  # noqa

except ImportError:
    print("please install triton from https://github.com/openai/triton")


configs = [
    triton.testing.Benchmark(
        x_names=["max_num_tokens"],
        x_vals=[2**i for i in range(6, 12)],
        line_arg="provider",
        line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
        line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name="Get_cos-sin_func",
        args={"batch_size": 16, "head_dim": 256},
    )
]


@triton.testing.perf_report(configs)
def benchmark_get_xine_cache(
    provider: str,
    max_num_tokens: int,
    batch_size: int,
    head_dim: int,
):
    warmup = 10
    rep = 1000
    dtype = torch.float16
    cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
    sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
    lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")

    if provider == "torch_get_cos_sin":
        fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
    elif provider == "triton_get_cos_sin":
        fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
    else:
        raise ValueError("Undefined provider")

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    return ms


if __name__ == "__main__":
    benchmark_get_xine_cache.run(save_path=".", print_data=True)
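
get_cos_sin, imported from the test module, is the eager-torch baseline that get_xine_cache is benchmarked against. A sketch of what such a gather plausibly looks like (an assumption about the helper's semantics, not its actual source):

    import torch

    def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=torch.float16):
        if is_prompts:
            # Prefill: positions 0 .. len-1 of every sequence need a cos/sin row.
            index = torch.cat([torch.arange(l, device=lengths.device) for l in lengths.tolist()])
        else:
            # Decode: only the latest position (len - 1) of each sequence.
            index = lengths - 1
        return cos_cache[index].to(dtype), sin_cache[index].to(dtype)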
98 changes: 98 additions & 0 deletions extensions/csrc/common/vector_copy_utils.h
@@ -0,0 +1,98 @@

#include <c10/macros/Macros.h>
#include <cuda_fp16.h>

#include <cfloat>

#include "string"

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 2>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *((float *)dst) = *((float *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 8>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *((float4 *)dst) = *((float4 *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
                                                     const c10::Half *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 2>(c10::Half *dst,
                                                     const c10::Half *src) {
  *((float *)dst) = *((float *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
                                                     const c10::Half *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 8>(c10::Half *dst,
                                                     const c10::Half *src) {
  *((float4 *)dst) = *((float4 *)src);
}

template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<float, 2>(float *dst, const float *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<float, 4>(float *dst, const float *src) {
  *((float4 *)dst) = *((float4 *)src);
}

template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
  // Since the maximum memory alignment length is 128 bits, we choose float4
  // here.
  *((float4 *)dst) = *((float4 *)src);
  *((float4 *)(dst + 4)) = *((float4 *)(src + 4));
}

template <typename T>
int get_vec_size(const torch::Tensor &tensor) {
  uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
  const int max_aligned_size = 128;
  const int dtype_size = sizeof(T) * 8;

  const int vec_size = max_aligned_size / sizeof(T) / 8;

  if (address % (dtype_size * 4) == 0) {
    return std::min(4, vec_size);
  } else if (address % (dtype_size * 2) == 0) {
    return std::min(2, vec_size);
  } else {
    return 1;
  }
}
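
get_vec_size picks the widest vectorized copy, capped at 128 bits, that the tensor's base-address alignment permits. The same selection policy in torch-side Python (a sketch of the logic only; note that the CUDA helper's dtype_size is in bits, so its modulo checks are effectively eight times stricter than a byte-size reading would suggest):

    import torch

    def get_vec_size(tensor: torch.Tensor) -> int:
        address = tensor.data_ptr()
        elem_bytes = tensor.element_size()
        max_vec = 16 // elem_bytes  # at most one 128-bit (16-byte) load/store
        # Prefer 4-element, then 2-element vectors when the base address is aligned.
        if address % (elem_bytes * 4) == 0:
            return min(4, max_vec)
        if address % (elem_bytes * 2) == 0:
            return min(2, max_vec)
        return 1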
3 changes: 3 additions & 0 deletions extensions/csrc/cuda/activation_kernel.cu
@@ -39,6 +39,9 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
   auto ins_shape = ins.sizes().vec();
 
   ins_shape[0] = ins_shape[0]/2;
+  if (ins_shape[0] == 1) {
+    ins_shape.erase(ins_shape.begin());
+  }
   auto outs = torch::zeros(ins_shape,ins.options());
   auto outs_shape = ins.sizes().vec();
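
silu_and_mul consumes the gate and up projections stacked along dim 0 and returns silu(x) * y; the new branch drops a singleton leading dim so single-token outputs come back squeezed. A torch reference for the visible behavior (a sketch inferred from this hunk, not the kernel's full source):

    import torch
    import torch.nn.functional as F

    def silu_and_mul_ref(ins: torch.Tensor) -> torch.Tensor:
        half = ins.size(0) // 2
        x, y = ins[:half], ins[half:]  # gate and up halves, stacked along dim 0
        out = F.silu(x) * y
        # Mirror the new squeeze: [1, ...] becomes [...], e.g. for single-token decode.
        if out.size(0) == 1:
            out = out.squeeze(0)
        return out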
