-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA…
… Kernel (#5418) * add rotary embedding kernel * add rotary_embedding_kernel * add fused rotary_emb and kvcache memcopy * add fused_rotary_emb_and_cache_kernel.cu * add fused_rotary_emb_and_memcopy * fix bugs in fused_rotary_emb_and_cache_kernel.cu * fix ci bugs * use vec memcopy and opt the global memory access * fix code style * fix test_rotary_embdding_unpad.py * codes revised based on the review comments * fix bugs about include path * rm inline
- Loading branch information
1 parent
ed431de
commit f366a5e
Showing
13 changed files
with
928 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import torch | ||
|
||
from colossalai.kernel.triton import get_xine_cache | ||
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin | ||
|
||
try: | ||
import triton # noqa | ||
|
||
except ImportError: | ||
print("please install triton from https://github.com/openai/triton") | ||
|
||
|
||
# Benchmark configuration: sweep max_num_tokens over powers of two
# (64..2048) and compare the torch reference against the triton kernel.
_benchmark_spec = {
    "x_names": ["max_num_tokens"],
    "x_vals": [2**i for i in range(6, 12)],
    "line_arg": "provider",
    "line_vals": ["torch_get_cos_sin", "triton_get_cos_sin"],
    "line_names": ["torch_get_cos_sin", "triton_get_cos_sin"],
    "styles": [("red", "-"), ("blue", "-")],
    "ylabel": "ms",
    "plot_name": "Get_cos-sin_func",
    "args": {"batch_size": 16, "head_dim": 256},
}
configs = [triton.testing.Benchmark(**_benchmark_spec)]
|
||
|
||
@triton.testing.perf_report(configs)
def benchmark_get_xine_cache(
    provider: str,
    max_num_tokens: int,
    batch_size: int,
    head_dim: int,
):
    """Time the cos/sin cache lookup for one provider.

    provider selects either the torch reference (``get_cos_sin``) or the
    triton kernel (``get_xine_cache``); returns the measured runtime in
    milliseconds as reported by ``triton.testing.do_bench``.
    """
    warmup = 10
    rep = 1000
    dtype = torch.float16
    # NOTE(review): cache length 8912 looks like a typo for 8192 — confirm upstream.
    cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
    sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
    # Random per-sequence lengths in [2, max_num_tokens).
    lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")

    if provider == "torch_get_cos_sin":
        def fn():
            return get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
    elif provider == "triton_get_cos_sin":
        def fn():
            return get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
    else:
        raise ValueError("Undefined provider")

    return triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||
|
||
if __name__ == "__main__":
    # Run the full benchmark sweep; results are printed and saved
    # (plot + data) to the current directory.
    benchmark_get_xine_cache.run(save_path=".", print_data=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
|
||
#include <c10/macros/Macros.h> | ||
#include <cuda_fp16.h> | ||
|
||
#include <cfloat> | ||
|
||
#include "string" | ||
|
||
// copy_vector<Datatype, N> copies N contiguous elements from src to dst,
// punning them into the widest built-in type that matches the total byte
// count so the compiler emits a single wide load/store.  Only the
// specializations below are defined; other combinations fail to link.
// Callers must guarantee src/dst are aligned to the vector width.
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

// 1 x bf16: plain scalar copy (2 bytes).
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *dst = *src;
}

// 2 x bf16 (4 bytes): one 32-bit transaction.
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 2>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *reinterpret_cast<float *>(dst) = *reinterpret_cast<const float *>(src);
}

// 4 x bf16 (8 bytes): one 64-bit transaction.
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *reinterpret_cast<float2 *>(dst) = *reinterpret_cast<const float2 *>(src);
}

// 8 x bf16 (16 bytes): one 128-bit transaction.
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 8>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *reinterpret_cast<float4 *>(dst) = *reinterpret_cast<const float4 *>(src);
}
|
||
// 1 x fp16: plain scalar copy (2 bytes).
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
                                                     const c10::Half *src) {
  *dst = *src;
}

// 2 x fp16 (4 bytes): one 32-bit transaction.
template <>
__device__ __inline__ void copy_vector<c10::Half, 2>(c10::Half *dst,
                                                     const c10::Half *src) {
  *reinterpret_cast<float *>(dst) = *reinterpret_cast<const float *>(src);
}

// 4 x fp16 (8 bytes): one 64-bit transaction.
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
                                                     const c10::Half *src) {
  *reinterpret_cast<float2 *>(dst) = *reinterpret_cast<const float2 *>(src);
}

// 8 x fp16 (16 bytes): one 128-bit transaction.
template <>
__device__ __inline__ void copy_vector<c10::Half, 8>(c10::Half *dst,
                                                     const c10::Half *src) {
  *reinterpret_cast<float4 *>(dst) = *reinterpret_cast<const float4 *>(src);
}
|
||
// 1 x float: plain scalar copy (4 bytes).
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
  *dst = *src;
}

// 2 x float (8 bytes): one 64-bit transaction.
template <>
__device__ __inline__ void copy_vector<float, 2>(float *dst, const float *src) {
  *reinterpret_cast<float2 *>(dst) = *reinterpret_cast<const float2 *>(src);
}

// 4 x float (16 bytes): one 128-bit transaction.
template <>
__device__ __inline__ void copy_vector<float, 4>(float *dst, const float *src) {
  *reinterpret_cast<float4 *>(dst) = *reinterpret_cast<const float4 *>(src);
}

// 8 x float (32 bytes): hardware caps aligned accesses at 128 bits, so
// this is issued as two float4 transactions rather than a single one.
template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
  *reinterpret_cast<float4 *>(dst) = *reinterpret_cast<const float4 *>(src);
  *reinterpret_cast<float4 *>(dst + 4) =
      *reinterpret_cast<const float4 *>(src + 4);
}
|
||
// Choose the widest safe vectorized access width, in elements, for `tensor`.
//
// The widest aligned GPU load/store is 128 bits (16 bytes), so the upper
// bound is 16 / sizeof(T) elements.  The usable width is further limited by
// the byte alignment of the tensor's data pointer: a vector of N elements
// must start on an N * sizeof(T)-byte boundary.
//
// Returns 4, 2, or 1 (a scalar access is always valid).
template <typename T>
int get_vec_size(const torch::Tensor &tensor) {
  const uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
  // 128-bit maximum aligned access, expressed in bytes and in elements.
  const int max_aligned_bytes = 128 / 8;
  const int vec_size_limit = max_aligned_bytes / sizeof(T);

  // Alignment requirements are in BYTES.  The previous version compared the
  // byte address against sizes expressed in bits (sizeof(T) * 8), which
  // demanded 8x stricter alignment than the hardware needs and forced
  // narrow accesses on pointers aligned to 16 but not 128 bytes.
  if (address % (sizeof(T) * 4) == 0) {
    return std::min(4, vec_size_limit);
  } else if (address % (sizeof(T) * 2) == 0) {
    return std::min(2, vec_size_limit);
  } else {
    return 1;
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.