Skip to content

Commit

Permalink
fix code style
Browse files Browse the repository at this point in the history
  • Loading branch information
isky-cd committed Mar 8, 2024
1 parent dfe8184 commit b1564c4
Show file tree
Hide file tree
Showing 9 changed files with 521 additions and 195 deletions.
4 changes: 2 additions & 2 deletions colossalai/inference/modeling/models/nopadding_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def llama_model_forward(
# NOTE: After testing, the performance of this configuration is relatively good. With updates
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
# selection should be conducted.
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False
# if batch_size >= 32 and kv_seq_len > 512:
# use_cuda_kernel = False

hidden_states = self.embed_tokens(input_ids)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import torch
import triton
from vllm._C import ops

from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import rotary_embedding

inference_ops = InferenceOpsLoader().load()

BATCH = 16
configs = [
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[2**i for i in range(4, 12)],
line_arg="provider",
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
styles=[("red", "-"), ("blue", "-")],
line_vals=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
line_names=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
ylabel="ms",
plot_name=f"rotary_emb-batch-{BATCH}",
args={"num_kv_heads": 16},
Expand Down Expand Up @@ -48,12 +52,19 @@ def benchmark_rotary_emb(
cos_shape = (4096, head_dim // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
lengths = torch.tensor([3, 4, 6, 7], device="cuda")

if provider == "torch_rotary_emb_func":
fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
elif provider == "triton_rotary_emb_func":
fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
cos_sin = torch.stack((cos, sin), dim=1).contiguous()

positions = torch.arange(num_tokens).cuda()

if provider == "triton_func":
fn = lambda: rotary_embedding(q, k, cos, sin)
elif provider == "colossal_cuda_func":
fn = lambda: inference_ops.rotary_embedding(q, k, cos, sin)
elif provider == "vllm_cuda_func":
q = q.view(num_tokens, -1)
k = k.view(num_tokens, -1)
fn = lambda: ops.rotary_embedding(positions, q, k, head_dim, cos_sin, True)
else:
raise ValueError("Undefined provider")

Expand Down
54 changes: 54 additions & 0 deletions examples/inference/benchmark_ops/benchmark_xine_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch

from colossalai.kernel.triton import get_xine_cache
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin

try:
import triton # noqa

except ImportError:
print("please install triton from https://github.com/openai/triton")


configs = [
triton.testing.Benchmark(
x_names=["max_num_tokens"],
x_vals=[2**i for i in range(6, 12)],
line_arg="provider",
line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name="Get_cos-sin_func",
args={"batch_size": 16, "head_dim": 256},
)
]


@triton.testing.perf_report(configs)
def benchmark_get_xine_cache(
provider: str,
max_num_tokens: int,
batch_size: int,
head_dim: int,
):
warmup = 10
rep = 1000
dtype = torch.float16
cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")

if provider == "torch_get_cos_sin":
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
elif provider == "triton_get_cos_sin":
fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
else:
raise ValueError("Undefined provider")

ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms


if __name__ == "__main__":
benchmark_get_xine_cache.run(save_path=".", print_data=True)
141 changes: 105 additions & 36 deletions extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
#include <torch/extension.h>

#include "type_shim.h"
#include "vector_copy_utils.h"

template<typename scalar_t>
template<typename scalar_t, int VecSize>
__global__ void decode_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables,
const int num_heads,
const int head_size,
const int head_num,
const int head_dim,
const int block_size,
const int64_t key_stride,
const int64_t value_stride,
Expand All @@ -23,66 +24,134 @@ __global__ void decode_kv_cache_memcpy_kernel(
const int seq_len = sequence_lengths[seq_id] - 1;
const int block_offset = seq_len % block_size;
const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size];
const int hidden_size = num_heads * head_size;
const int hidden_size = head_num * head_dim;

if ( block_id < 0 ) {
return ;
}

for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
const int head_id = i / head_size;
const int head_offset = i % head_size;
for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_src_id = block_id * hidden_size * block_size
+ head_id * block_size * head_size
+ block_offset * head_size + head_offset;
const int64_t target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;

key_cache[target_src_id] = key[key_src_id];
value_cache[target_src_id] = value[value_src_id];
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}

}

void decode_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, num_heads, head_size]
at::Tensor& value, // [num_tokens, num_heads, head_size]
at::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
at::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size]
template<typename scalar_t>
void apply_decode_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);

int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int block_table_stride = block_tables.stride(0);

int vec_size = get_vec_size<scalar_t>(key);

if (head_dim % vec_size != 0) {
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
vec_size = 1;
}

int thread_nums = head_num * head_dim / vec_size;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
dim3 block(std::min(thread_nums, 512));

switch (vec_size) {
case 1:
decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
key_stride,
value_stride,
block_table_stride
);
break;
case 2:
decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
key_stride,
value_stride,
block_table_stride
);
break;
case 4:
decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
head_num,
head_dim,
block_size,
key_stride,
value_stride,
block_table_stride
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
}

AT_CUDA_CHECK(cudaGetLastError());

}

void decode_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
{
DISPATCH_FLOAT_HALF_AND_BFLOAT(
key.scalar_type(),
"decode_kv_cache_memcpy",
decode_kv_cache_memcpy_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
num_heads,
head_size,
block_size,
key_stride,
value_stride,
block_table_stride
apply_decode_kv_cache_memcpy<scalar_t>(
key,
value,
key_cache,
value_cache,
sequence_lengths,
block_tables
);)

AT_CUDA_CHECK(cudaGetLastError());

}
Loading

0 comments on commit b1564c4

Please sign in to comment.