diff --git a/include/flashinfer/decode.cuh b/include/flashinfer/decode.cuh
index 5c87e55e..0d57e72a 100644
--- a/include/flashinfer/decode.cuh
+++ b/include/flashinfer/decode.cuh
@@ -227,8 +227,8 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f
  * \param kv_chunk_size A integer indicates the kv-chunk size
  */
 template <QKVLayout layout, bool cooperative, RotaryMode rotary_mode, uint32_t num_stages_smem,
-          uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename DTypeIn,
-          typename DTypeOut>
+          uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
+          typename DTypeIn, typename DTypeOut>
 __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k,
                                               DTypeIn* __restrict__ v, DTypeOut* __restrict__ o,
                                               float* __restrict__ tmp,
@@ -249,8 +249,10 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
   extern __shared__ uint8_t smem[];
   DTypeIn* k_smem = (DTypeIn*)smem;
-  DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));
-  float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));
+  DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim *
+                                          sizeof(DTypeIn));
+  float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim *
+                                       sizeof(DTypeIn));

   uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
   vec_t q_vec;
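
The two hunks above carry the core of the change: every shared-memory tile is widened by a factor of tile_size_per_bdx, so one pipeline stage now buffers bdy * tile_size_per_bdx * bdz keys (and as many values) of head_dim elements each, followed by one (m, d) float pair per (bdy, bdz) slot. Since the host code below sets tile_size_per_bdx = 8U / GROUP_SIZE, the product GROUP_SIZE * tile_size_per_bdx is pinned at 8, which keeps the footprint roughly invariant across GQA group sizes. A minimal host-side sketch of the resulting budget; the helper name and the numbers in main are illustrative, not taken from the patch:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical helper mirroring the layout above: two tile arrays (k_smem,
    // v_smem), each holding num_stages_smem stages of bdy * tile_size_per_bdx * bdz
    // rows of head_dim elements, then 2 floats (m, d) per (bdy, bdz) slot.
    constexpr uint32_t smem_bytes(uint32_t num_stages_smem, uint32_t bdy,
                                  uint32_t tile_size_per_bdx, uint32_t bdz,
                                  uint32_t head_dim, uint32_t dtype_bytes) {
      return 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * dtype_bytes +
             2U * bdy * bdz * sizeof(float);
    }

    int main() {
      // Illustrative only: GROUP_SIZE = 1 (so tile_size_per_bdx = 8), two stages,
      // bdy = 1, bdz = 4, head_dim = 128, fp16 in -> 32 KB of tiles + 32 B of (m, d).
      std::printf("%u bytes\n", (unsigned)smem_bytes(2, 1, 8, 4, 128, 2));
      return 0;
    }
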
@@ -280,61 +282,81 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
 #pragma unroll
   for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
-    cp_async::pred_load(
-        k_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
-        k + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
-                                    tx * vec_size),
-        producer_kv_idx_base + tz * bdy + ty < chunk_end);
+    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      cp_async::pred_load(
+          k_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
+              tx * vec_size,
+          k + info.get_kv_elem_offset(
+                  producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx,
+                  tx * vec_size),
+          producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end);
+    }
     cp_async::commit_group();
-    cp_async::pred_load(
-        v_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
-        v + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
-                                    tx * vec_size),
-        producer_kv_idx_base + tz * bdy + ty < chunk_end);
+    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      cp_async::pred_load(
+          v_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
+              tx * vec_size,
+          v + info.get_kv_elem_offset(
+                  producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx,
+                  tx * vec_size),
+          producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end);
+    }
     cp_async::commit_group();
-    producer_kv_idx_base += bdy * bdz;
+    producer_kv_idx_base += bdy * bdz * tile_size_per_bdx;
   }

   // pipelining k/v tiles loading and state updating
   uint32_t consumer_kv_idx_base = chunk_start, stage_idx = 0;
   state_t st_local;
-  float s[bdy];
+  float s[bdy * tile_size_per_bdx];

-#pragma unroll 4
-  for (uint32_t iter = 0; iter < (kv_chunk_size + bdy * bdz - 1) / (bdy * bdz); ++iter) {
+#pragma unroll 2
+  for (uint32_t iter = 0;
+       iter < (kv_chunk_size + tile_size_per_bdx * bdy * bdz - 1) / (tile_size_per_bdx * bdy * bdz);
+       ++iter) {
     // compute qk
     cp_async::wait_group<2 * num_stages_smem - 1>();
     block.sync();
-    compute_qk(
-        k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq,
-        consumer_kv_idx_base, iter * bdy * bdz, kv_chunk_size, sm_scale, s, st_local);
+    compute_qk(
+        k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
+        freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, sm_scale,
+        s, st_local);
     block.sync();
     // load k
-    cp_async::pred_load(
-        k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
-        k + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
-                                    tx * vec_size),
-        producer_kv_idx_base + tz * bdy + ty < chunk_end);
+    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      cp_async::pred_load(
+          k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
+              tx * vec_size,
+          k + info.get_kv_elem_offset(
+                  producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx,
+                  tx * vec_size),
+          producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end);
+    }
     cp_async::commit_group();
     // update m/d/o state
     cp_async::wait_group<2 * num_stages_smem - 1>();
     block.sync();
-    update_local_state(v_smem + (stage_idx * bdz + tz) * bdy * head_dim, s,
-                       stage_idx, st_local);
+    update_local_state(
+        v_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, s, stage_idx,
+        st_local);
     block.sync();
     // load v
-    cp_async::pred_load(
-        v_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
-        v + info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
-                                    tx * vec_size),
-        producer_kv_idx_base + tz * bdy + ty < chunk_end);
+    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      cp_async::pred_load(
+          v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
+              tx * vec_size,
+          v + info.get_kv_elem_offset(
+                  producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j, kv_head_idx,
+                  tx * vec_size),
+          producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end);
+    }
     cp_async::commit_group();
     stage_idx = (stage_idx + 1) % num_stages_smem;
-    producer_kv_idx_base += bdy * bdz;
-    consumer_kv_idx_base += bdy * bdz;
+    producer_kv_idx_base += tile_size_per_bdx * bdy * bdz;
+    consumer_kv_idx_base += tile_size_per_bdx * bdy * bdz;
   }
   cp_async::wait_group<0>();
   block.sync();
@@ -353,7 +375,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
   // sync global states
   if (kv_chunk_idx == 0) {
     state_t st_global;
-#pragma unroll 4
+#pragma unroll 2
     for (uint32_t iter = 0; iter < (num_kv_chunks + bdz - 1) / bdz; ++iter) {
       uint32_t kv_chunk_idx = iter * bdz + tz;
       if (kv_chunk_idx < num_kv_chunks) {
@@ -775,13 +797,15 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t&
         constexpr uint32_t num_threads =
             std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)), bdx * bdy);
         constexpr uint32_t bdz = num_threads / (bdx * bdy);
-        const uint32_t smem_size =
-            2U * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn) +
-            2U * bdy * bdz * sizeof(float);
+        constexpr uint32_t tile_size_per_bdx = 8U / GROUP_SIZE;
+        const uint32_t smem_size = 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz *
+                                       head_dim * sizeof(DTypeIn) +
+                                   2U * bdy * bdz * sizeof(float);
         auto kernel = SingleDecodeWithKVCacheKernel<LAYOUT, true, ROTARY_MODE, num_stages_smem,
-                                                    vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut>;
+                                                    tile_size_per_bdx, vec_size, bdx, bdy, bdz,
+                                                    DTypeIn, DTypeOut>;
         int num_blocks_per_sm = 0;
         int num_sm = 0;
         int dev_id = 0;
@@ -793,8 +817,7 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t&
         max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm);
         uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
         uint32_t kv_chunk_size =
-            max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks,
-                uint32_t(std::sqrt(seq_len / GROUP_SIZE)) * 4);
+            max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks, 256);
         uint32_t num_kv_chunks = (seq_len + kv_chunk_size - 1) / kv_chunk_size;
         tmp_size = num_qo_heads * num_kv_chunks * (head_dim + 2);
       })})})});
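
This hunk (and the matching one in SingleDecodeWithKVCache below) replaces the old sqrt(seq_len / GROUP_SIZE) * 4 floor with a fixed floor of 256 kv entries per block, so shorter sequences are no longer shredded into tiny chunks whose per-chunk merge overhead dominates. A standalone sketch of the partitioning arithmetic; the function name is hypothetical, and max_grid_size would come from the occupancy query above:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Spread seq_len over at most max_grid_size / num_kv_heads blocks, but never
    // give one block fewer than 256 kv entries (sketch of the rule in the hunk).
    uint32_t choose_kv_chunk_size(uint32_t seq_len, uint32_t max_grid_size,
                                  uint32_t num_kv_heads) {
      uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
      return std::max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks, 256U);
    }

    int main() {
      // e.g. seq_len = 8192 with 264 resident blocks and 8 kv heads:
      // max_num_kv_chunks = 33, ceil(8192 / 33) = 249 -> clamped to 256, so each
      // kv head is split into ceil(8192 / 256) = 32 chunks and tmp must hold
      // num_qo_heads * 32 * (head_dim + 2) floats.
      std::printf("%u\n", (unsigned)choose_kv_chunk_size(8192, 264, 8));
      return 0;
    }
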
@@ -855,14 +878,16 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
           std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)), bdx * bdy);
       constexpr uint32_t bdz = num_threads / (bdx * bdy);
       tensor_info_t info(1, seq_len, num_kv_heads, head_dim);
-      const uint32_t smem_size =
-          2U * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn) +
-          2U * bdy * bdz * sizeof(float);
-      if (seq_len <= 128U / uint32_t(std::sqrt(GROUP_SIZE)) || tmp == nullptr) {
+      constexpr uint32_t tile_size_per_bdx = 8U / GROUP_SIZE;
+      const uint32_t smem_size = 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz *
+                                     head_dim * sizeof(DTypeIn) +
+                                 2U * bdy * bdz * sizeof(float);
+      if (seq_len <= 256 || tmp == nullptr) {
         // no need to use cooperative kernel
         auto kernel = SingleDecodeWithKVCacheKernel<LAYOUT, false, ROTARY_MODE, num_stages_smem,
-                                                    vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut>;
+                                                    tile_size_per_bdx, vec_size, bdx, bdy, bdz,
+                                                    DTypeIn, DTypeOut>;
         FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
@@ -884,7 +909,8 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
         // use cooperative kernel
         auto kernel = SingleDecodeWithKVCacheKernel<LAYOUT, true, ROTARY_MODE, num_stages_smem,
-                                                    vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut>;
+                                                    tile_size_per_bdx, vec_size, bdx, bdy, bdz,
+                                                    DTypeIn, DTypeOut>;
         FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
@@ -899,8 +925,7 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
         uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm);
         uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
         uint32_t kv_chunk_size =
-            max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks,
-                uint32_t(std::sqrt(seq_len / GROUP_SIZE)) * 4);
+            max((seq_len + max_num_kv_chunks - 1U) / max_num_kv_chunks, 256);
         dim3 nblks = dim3((seq_len + kv_chunk_size - 1) / kv_chunk_size, num_kv_heads);
         if (nblks.x == 0 || nblks.y == 0) {
           std::cerr << "Invalid kernel configuration: nblks=(" << nblks.x << ","
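
In the cooperative path configured above, each block handles one (kv chunk, kv head) pair and writes a partial result per head into tmp: head_dim output elements plus the running (m, d) softmax statistics, matching the head_dim + 2 stride in the work estimation. The block with kv_chunk_idx == 0 then folds these partials in the "sync global states" loop shown earlier. A scalar sketch of that fold, assuming the partial output is stored already normalized (the real state_t carries a vec_size-wide output; names here are illustrative):

    #include <cmath>
    #include <cstdio>

    // Scalar analogue of merging two per-chunk softmax states: o is the chunk's
    // normalized partial output, m the max score, d the sum of exp(score - m).
    // The merge is associative, so chunks can be folded in any order.
    struct State {
      float o, m, d;
    };

    State merge(State a, State b) {
      float m = std::fmax(a.m, b.m);
      float da = a.d * std::exp(a.m - m), db = b.d * std::exp(b.m - m);
      return {(a.o * da + b.o * db) / (da + db), m, da + db};
    }

    int main() {
      State c0{0.25f, 1.5f, 4.0f}, c1{0.75f, 2.0f, 3.0f};  // two chunks' partials
      State s = merge(c0, c1);
      std::printf("o=%.4f m=%.1f d=%.4f\n", s.o, s.m, s.d);
      return 0;
    }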