accelerate single decode
yzh119 committed Dec 17, 2023
1 parent 1f57b6c commit 2a3d6d0
Showing 1 changed file with 75 additions and 50 deletions.
125 changes: 75 additions & 50 deletions include/flashinfer/decode.cuh
@@ -227,8 +227,8 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
 * \param kv_chunk_size An integer indicating the kv-chunk size
*/
template <QKVLayout layout, bool cooperative, RotaryMode rotary_mode, uint32_t num_stages_smem,
-          uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename DTypeIn,
-          typename DTypeOut>
+          uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
+          typename DTypeIn, typename DTypeOut>
__global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k,
DTypeIn* __restrict__ v, DTypeOut* __restrict__ o,
float* __restrict__ tmp,
@@ -249,8 +249,10 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*

extern __shared__ uint8_t smem[];
DTypeIn* k_smem = (DTypeIn*)smem;
-  DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));
-  float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));
+  DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim *
+                                          sizeof(DTypeIn));
+  float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim *
+                                       sizeof(DTypeIn));

uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
vec_t<float, vec_size> q_vec;
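
Note: a quick sanity check of the shared-memory budget implied by the new k_smem/v_smem/smem_md layout above. This is a standalone sketch with assumed values (head_dim = 128, fp16 inputs, num_stages_smem = 2, bdy = 2, bdz = 4, tile_size_per_bdx = 4); the kernel's actual constants come from the dispatch heuristics later in this file.

// Hypothetical check of the k/v shared-memory layout; all constants are
// illustrative, not the kernel's actual dispatch values.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t num_stages_smem = 2, bdy = 2, bdz = 4;
  constexpr uint32_t tile_size_per_bdx = 4, head_dim = 128;
  constexpr uint32_t dtype_bytes = 2;  // fp16 DTypeIn

  // One pipeline stage holds bdy * tile_size_per_bdx * bdz rows of head_dim.
  constexpr uint32_t stage_bytes =
      bdy * tile_size_per_bdx * bdz * head_dim * dtype_bytes;
  // k_smem and v_smem each hold num_stages_smem stages; smem_md follows them.
  constexpr uint32_t kv_bytes = 2 * num_stages_smem * stage_bytes;
  constexpr uint32_t md_bytes = 2 * bdy * bdz * sizeof(float);

  printf("stage: %u B, k+v: %u B, m/d: %u B, total: %u B\n", stage_bytes,
         kv_bytes, md_bytes, kv_bytes + md_bytes);  // 8192, 32768, 64, 32832
  return 0;
}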
@@ -280,61 +282,81 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
#pragma unroll
for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-        k_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
-        k + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
-                                    tx * vec_size),
-        producer_kv_idx_base + tz * bdy + ty < chunk_end);
+    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+          k_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
+              tx * vec_size,
+          k + info.get_kv_elem_offset(
+                  producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx,
+                  tx * vec_size),
+          producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end);
+    }
cp_async::commit_group();
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
-        v_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
-        v + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
-                                    tx * vec_size),
-        producer_kv_idx_base + tz * bdy + ty < chunk_end);
+    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
+          v_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
+              tx * vec_size,
+          v + info.get_kv_elem_offset(
+                  producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx,
+                  tx * vec_size),
+          producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end);
+    }
cp_async::commit_group();
-    producer_kv_idx_base += bdy * bdz;
+    producer_kv_idx_base += bdy * bdz * tile_size_per_bdx;
}

// pipelining k/v tiles loading and state updating
uint32_t consumer_kv_idx_base = chunk_start, stage_idx = 0;
state_t<vec_size> st_local;
-  float s[bdy];
+  float s[bdy * tile_size_per_bdx];

-#pragma unroll 4
-  for (uint32_t iter = 0; iter < (kv_chunk_size + bdy * bdz - 1) / (bdy * bdz); ++iter) {
+#pragma unroll 2
+  for (uint32_t iter = 0;
+       iter < (kv_chunk_size + tile_size_per_bdx * bdy * bdz - 1) / (tile_size_per_bdx * bdy * bdz);
+       ++iter) {
// compute qk
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
-    compute_qk<rotary_mode, vec_size, bdx, bdy>(
-        k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq,
-        consumer_kv_idx_base, iter * bdy * bdz, kv_chunk_size, sm_scale, s, st_local);
+    compute_qk<rotary_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
+        k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
+        freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, sm_scale,
+        s, st_local);
block.sync();
// load k
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-        k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
-        k + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
-                                    tx * vec_size),
-        producer_kv_idx_base + tz * bdy + ty < chunk_end);
+    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+          k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
+              tx * vec_size,
+          k + info.get_kv_elem_offset(
+                  producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx,
+                  tx * vec_size),
+          producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end);
+    }
cp_async::commit_group();

// update m/d/o state
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
-    update_local_state<vec_size, bdx, bdy>(v_smem + (stage_idx * bdz + tz) * bdy * head_dim, s,
-                                           stage_idx, st_local);
+    update_local_state<vec_size, bdx, bdy * tile_size_per_bdx>(
+        v_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, s, stage_idx,
+        st_local);
block.sync();

// load v
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
-        v_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
-        v + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
-                                    tx * vec_size),
-        producer_kv_idx_base + tz * bdy + ty < chunk_end);
+    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
+          v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
+              tx * vec_size,
+          v + info.get_kv_elem_offset(
+                  producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx,
+                  tx * vec_size),
+          producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end);
+    }
cp_async::commit_group();

stage_idx = (stage_idx + 1) % num_stages_smem;
-    producer_kv_idx_base += bdy * bdz;
-    consumer_kv_idx_base += bdy * bdz;
+    producer_kv_idx_base += tile_size_per_bdx * bdy * bdz;
+    consumer_kv_idx_base += tile_size_per_bdx * bdy * bdz;
}
cp_async::wait_group<0>();
block.sync();
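
Note on the two fill modes used above: with a false predicate, kNoFill leaves the destination untouched (scores for those k rows appear to be masked in compute_qk via kv_chunk_size), while kFillZero zeroes the destination so out-of-bounds v rows contribute nothing to the weighted sum. Below is a scalar analogue of these semantics, purely illustrative; the real cp_async::pred_load issues asynchronous vectorized copies.

// Illustrative scalar analogue of cp_async::pred_load's fill modes; not
// the actual implementation, which uses cp.async hardware copies.
#include <cstddef>

enum class SharedMemFillMode { kNoFill, kFillZero };

template <SharedMemFillMode fill_mode, typename T>
void pred_load_scalar(T* dst, const T* src, std::size_t n, bool predicate) {
  if (predicate) {
    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];
  } else if (fill_mode == SharedMemFillMode::kFillZero) {
    // Out-of-range v rows read back as zero and drop out of the sum.
    for (std::size_t i = 0; i < n; ++i) dst[i] = T(0);
  }
  // kNoFill: leave dst as-is; the consumer masks those positions instead.
}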
@@ -353,7 +375,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
// sync global states
if (kv_chunk_idx == 0) {
state_t<vec_size> st_global;
-#pragma unroll 4
+#pragma unroll 2
for (uint32_t iter = 0; iter < (num_kv_chunks + bdz - 1) / bdz; ++iter) {
uint32_t kv_chunk_idx = iter * bdz + tz;
if (kv_chunk_idx < num_kv_chunks) {
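
Note: the loop above merges per-chunk partial results with the usual online-softmax combine. Each chunk carries a triple (m, d, o): running max, normalizer, and un-normalized output. A minimal sketch of the combine rule, assuming state_t stores exactly these fields (names here are hypothetical):

// Minimal sketch of merging two per-chunk softmax states; field names are
// illustrative and need not match state_t's actual members.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct PartialState {
  float m;               // running max of logits seen so far
  float d;               // sum of exp(logit - m)
  std::vector<float> o;  // un-normalized weighted sum of v rows
};

PartialState merge(const PartialState& a, const PartialState& b) {
  PartialState out;
  out.m = std::max(a.m, b.m);
  const float sa = std::exp(a.m - out.m), sb = std::exp(b.m - out.m);
  out.d = a.d * sa + b.d * sb;
  out.o.resize(a.o.size());
  for (std::size_t i = 0; i < a.o.size(); ++i)
    out.o[i] = a.o[i] * sa + b.o[i] * sb;  // rescale to the shared max first
  return out;
}
// The final attention output is out.o[i] / out.d; the combine is associative,
// so chunks can be merged in any order.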
@@ -775,13 +797,15 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t&
constexpr uint32_t num_threads =
std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)), bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
-          const uint32_t smem_size =
-              2U * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn) +
-              2U * bdy * bdz * sizeof(float);
+          constexpr uint32_t tile_size_per_bdx = 8U / GROUP_SIZE;
+          const uint32_t smem_size = 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz *
+                                         head_dim * sizeof(DTypeIn) +
+                                     2U * bdy * bdz * sizeof(float);

auto kernel =
SingleDecodeWithKVCacheKernel<QKV_LAYOUT, true, ROTARY_MODE, num_stages_smem,
-                                            vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut>;
+                                            tile_size_per_bdx, vec_size, bdx, bdy, bdz,
+                                            DTypeIn, DTypeOut>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
@@ -793,8 +817,7 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t&
max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm);
uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
uint32_t kv_chunk_size =
-              max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks,
-                  uint32_t(std::sqrt(seq_len / GROUP_SIZE)) * 4);
+              max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks, 256);
uint32_t num_kv_chunks = (seq_len + kv_chunk_size - 1) / kv_chunk_size;
tmp_size = num_qo_heads * num_kv_chunks * (head_dim + 2);
})})})});
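
Note: the heuristic above replaces the sqrt-based lower bound on kv_chunk_size with a flat 256, so sequences are split into fewer, larger chunks. A worked example with assumed numbers (occupancy and problem size are made up, matching no particular GPU):

// Worked example of the kv-chunk sizing; every input here is assumed.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t num_blocks_per_sm = 2, num_sm = 108;  // assumed occupancy
  const uint32_t num_kv_heads = 8, seq_len = 16384;    // assumed problem
  const uint32_t max_grid_size = num_blocks_per_sm * num_sm;        // 216
  const uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;  // 27
  const uint32_t kv_chunk_size =
      std::max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks, 256U);
  const uint32_t num_kv_chunks = (seq_len + kv_chunk_size - 1) / kv_chunk_size;
  // kv_chunk_size = max(ceil(16384 / 27), 256) = max(607, 256) = 607
  // num_kv_chunks = ceil(16384 / 607) = 27
  printf("kv_chunk_size=%u num_kv_chunks=%u\n", kv_chunk_size, num_kv_chunks);
  return 0;
}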
@@ -855,14 +878,16 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)), bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
tensor_info_t<QKV_LAYOUT, GROUP_SIZE> info(1, seq_len, num_kv_heads, head_dim);
-        const uint32_t smem_size =
-            2U * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn) +
-            2U * bdy * bdz * sizeof(float);
-        if (seq_len <= 128U / uint32_t(std::sqrt(GROUP_SIZE)) || tmp == nullptr) {
+        constexpr uint32_t tile_size_per_bdx = 8U / GROUP_SIZE;
+        const uint32_t smem_size = 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz *
+                                       head_dim * sizeof(DTypeIn) +
+                                   2U * bdy * bdz * sizeof(float);
+        if (seq_len <= 256 || tmp == nullptr) {
// no need to use cooperative kernel
auto kernel =
SingleDecodeWithKVCacheKernel<QKV_LAYOUT, false, ROTARY_MODE, num_stages_smem,
-                                            vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut>;
+                                            tile_size_per_bdx, vec_size, bdx, bdy, bdz,
+                                            DTypeIn, DTypeOut>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
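
Note: both dispatch paths now pick tile_size_per_bdx = 8U / GROUP_SIZE, so smaller query-group sizes get wider per-thread KV tiles. If bdy tracks GROUP_SIZE (as the heuristics here suggest), each z-slice covers a fixed 8 KV rows per iteration regardless of GROUP_SIZE. A sketch of the mapping, assuming GROUP_SIZE is a power of two in {1, 2, 4, 8}:

// Sketch of the tile_size_per_bdx heuristic; the {1,2,4,8} domain is an
// assumption implied by 8U / GROUP_SIZE, not something this diff states.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
  for (uint32_t group_size : {1u, 2u, 4u, 8u}) {
    const uint32_t tile_size_per_bdx = 8u / group_size;
    printf("GROUP_SIZE=%u -> tile_size_per_bdx=%u (rows per z-slice: %u)\n",
           group_size, tile_size_per_bdx, group_size * tile_size_per_bdx);
  }
  return 0;
}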

@@ -884,7 +909,8 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
// use cooperative kernel
auto kernel =
SingleDecodeWithKVCacheKernel<QKV_LAYOUT, true, ROTARY_MODE, num_stages_smem,
-                                            vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut>;
+                                            tile_size_per_bdx, vec_size, bdx, bdy, bdz,
+                                            DTypeIn, DTypeOut>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

@@ -899,8 +925,7 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm);
uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
uint32_t kv_chunk_size =
-            max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks,
-                uint32_t(std::sqrt(seq_len / GROUP_SIZE)) * 4);
+            max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks, 256);
dim3 nblks = dim3((seq_len + kv_chunk_size - 1) / kv_chunk_size, num_kv_heads);
if (nblks.x == 0 || nblks.y == 0) {
std::cerr << "Invalid kernel configuration: nblks=(" << nblks.x << ","
