From b83b40825b53da1c94039503f772150af9c2fccb Mon Sep 17 00:00:00 2001
From: yzh119
Date: Sat, 9 Dec 2023 04:31:36 -0500
Subject: [PATCH] cache page pointers

---
 include/flashinfer/decode.cuh  | 38 ++++++++++++------
 include/flashinfer/page.cuh    | 73 ++++++++++++++++------------------
 include/flashinfer/prefill.cuh | 27 +++++++------
 3 files changed, 74 insertions(+), 64 deletions(-)

diff --git a/include/flashinfer/decode.cuh b/include/flashinfer/decode.cuh
index 1ee8f372..9b97346d 100644
--- a/include/flashinfer/decode.cuh
+++ b/include/flashinfer/decode.cuh
@@ -552,6 +552,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
   extern __shared__ uint8_t smem[];
   DTypeIn* k_smem = (DTypeIn*)smem;
   DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));
+  DTypeIn** k_ptrs_smem =
+      (DTypeIn**)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));
   float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));
 
   const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
@@ -588,12 +590,17 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
   // preload k/v tiles
   uint32_t stage_idx = 0;
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
+  const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];
+
+  static_assert(num_stages_smem <= bdx);
+  k_ptrs_smem[(tz * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
+      cur_page_indptr_begin + ((tz * bdy + ty) * bdx + tx) / paged_kv.page_size, kv_head_idx,
+      ((tz * bdy + ty) * bdx + tx) % paged_kv.page_size, 0, last_indptr);
+  block.sync();
 
 #pragma unroll
   for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
-    DTypeIn* k_ptr = paged_kv.template get_k_ptr(
-        cur_page_indptr_begin + ((iter * bdz + tz) * bdy + ty) / paged_kv.page_size, kv_head_idx,
-        ((iter * bdz + tz) * bdy + ty) % paged_kv.page_size, tx * vec_size);
+    DTypeIn* k_ptr = k_ptrs_smem[(iter * bdz + tz) * bdy + ty] + tx * vec_size;
     DTypeIn* v_ptr = k_ptr + paged_kv.kv_offset_delta();
     cp_async::pred_load(
         k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, k_ptr,
@@ -611,6 +618,15 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
 
 #pragma unroll 4
   for (uint32_t iter = 0; iter < (kv_chunk_len + bdy * bdz - 1) / (bdy * bdz); ++iter) {
+    if ((iter + num_stages_smem) % bdx == 0) {
+      k_ptrs_smem[(tz * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
+          cur_page_indptr_begin +
+              ((iter + num_stages_smem) * bdy * bdz + (tz * bdy + ty) * bdx + tx) /
+                  paged_kv.page_size,
+          kv_head_idx,
+          ((iter + num_stages_smem) * bdy * bdz + (tz * bdy + ty) * bdx + tx) % paged_kv.page_size,
+          0, last_indptr);
+    }
     const bool producer_pred_guard =
         ((iter + num_stages_smem) * bdz + tz) * bdy + ty < kv_chunk_len;
     // compute qk
     compute_qk(k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq,
                cur_chunk_start + iter * bdy * bdz, iter * bdy * bdz, kv_chunk_len, sm_scale, x, s);
     block.sync();
-    DTypeIn* k_ptr = paged_kv.template get_k_ptr(
-        cur_page_indptr_begin +
-            (((iter + num_stages_smem) * bdz + tz) * bdy + ty) / paged_kv.page_size,
-        kv_head_idx, (((iter + num_stages_smem) * bdz + tz) * bdy + ty) % paged_kv.page_size,
-        tx * vec_size);
+    DTypeIn* k_ptr =
+        k_ptrs_smem[(((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty] + tx * vec_size;
     DTypeIn* v_ptr = k_ptr + paged_kv.kv_offset_delta();
     // load k tiles
     cp_async::pred_load(
         k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, k_ptr,
@@ -1054,7 +1067,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation(
         {SWITCH_ROTARY_MODE(
             rotary_mode, ROTARY_MODE,
             {SWITCH_PAGE_SIZE(paged_kv.page_size, PAGE_SIZE, {
               constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
-              constexpr uint32_t num_stages_smem = 2U;
+              constexpr uint32_t num_stages_smem = 4U;
               constexpr uint32_t bdx = HEAD_DIM / vec_size;
               static_assert(bdx <= 32);
               constexpr uint32_t bdy = GROUP_SIZE;
@@ -1062,7 +1075,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation(
               constexpr uint32_t bdz = num_threads / (bdx * bdy);
               const uint32_t smem_size =
                   2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn) +
-                  2 * bdy * bdz * sizeof(float);
+                  std::max(num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float));
               auto cooperative_kernel = BatchDecodeWithPagedKVCacheKernel
diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh
--- a/include/flashinfer/page.cuh
+++ b/include/flashinfer/page.cuh
-  __device__ __forceinline__ DType* get_k_ptr(uint32_t page_iter, uint32_t head_idx,
+  __device__ __forceinline__ DType* get_k_ptr(IdType page_iter, uint32_t head_idx,
                                               uint32_t entry_idx, uint32_t feat_idx) const {
     if constexpr (page_storage == PageStorage::kIndices) {
-      if constexpr (access_mode == AccessMode::kProtective) {
-        if (page_iter < __ldg(indptr + batch_size)) {
-          return data +
-                 get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
-        } else {
-          return data;
-        }
-      } else {
+      return data + get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
+    } else {
+      return __ldg(ptrs + page_iter) + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
+    }
+  }
+
+  __device__ __forceinline__ DType* protective_get_k_ptr(IdType page_iter, uint32_t head_idx,
+                                                         uint32_t entry_idx, uint32_t feat_idx,
+                                                         IdType last_indptr) const {
+    if constexpr (page_storage == PageStorage::kIndices) {
+      if (page_iter < last_indptr) {
         return data + get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
+      } else {
+        return data;
       }
     } else {
-      if constexpr (access_mode == AccessMode::kProtective) {
-        if (page_iter < __ldg(indptr + batch_size)) {
-          return __ldg(ptrs + page_iter) + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
-        } else {
-          return __ldg(ptrs);
-        }
-      } else {
+      if (page_iter < last_indptr) {
         return __ldg(ptrs + page_iter) + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
+      } else {
+        return __ldg(ptrs);
       }
     }
   }
 
-  template
-  __device__ __forceinline__ DType* get_v_ptr(uint32_t page_iter, uint32_t head_idx,
+  __device__ __forceinline__ DType* get_v_ptr(IdType page_iter, uint32_t head_idx,
                                               uint32_t entry_idx, uint32_t feat_idx) const {
     if constexpr (page_storage == PageStorage::kIndices) {
-      if constexpr (access_mode == AccessMode::kProtective) {
-        if (page_iter < __ldg(indptr + batch_size)) {
-          return data +
-                 get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
-        } else {
-          return data;
-        }
-      } else {
+      return data + get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
+    } else {
+      return __ldg(ptrs + page_iter) + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
+    }
+  }
+
+  __device__ __forceinline__ DType* protective_get_v_ptr(IdType page_iter, uint32_t head_idx,
+                                                         uint32_t entry_idx, uint32_t feat_idx,
+                                                         IdType last_indptr) const {
+    if constexpr (page_storage == PageStorage::kIndices) {
+      if (page_iter < last_indptr) {
         return data + get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
+      } else {
+        return data;
       }
     } else {
-      if constexpr (access_mode == AccessMode::kProtective) {
-        if (page_iter < __ldg(indptr + batch_size)) {
-          return __ldg(ptrs + page_iter) + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
-        } else {
-          return __ldg(ptrs);
-        }
-      } else {
+      if (page_iter < last_indptr) {
         return __ldg(ptrs + page_iter) + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
+      } else {
+        return __ldg(ptrs);
       }
     }
   }
diff --git a/include/flashinfer/prefill.cuh b/include/flashinfer/prefill.cuh
index 7f579f39..0e29f395 100644
--- a/include/flashinfer/prefill.cuh
+++ b/include/flashinfer/prefill.cuh
@@ -140,7 +140,7 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offs
                                                 paged_kv_t& paged_kv,
                                                 const uint32_t kv_idx_base,
                                                 const uint32_t page_iter_base,
-                                                const uint32_t kv_len) {
+                                                const uint32_t kv_len, const IdType last_indptr) {
   constexpr SharedMemFillMode fill_mode =
       produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
   constexpr uint32_t head_dim = num_frags_y * 16;
@@ -154,10 +154,10 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offs
       const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4) / page_size;
       const uint32_t entry_idx = (4 * num_warps * i + ty * 4) % page_size + tx / 8;
       DType* gptr =
-          produce_v ? (paged_kv.template get_v_ptr(
-                          page_iter, kv_head_idx, entry_idx, (tx % 8) * cell_capacity()))
-                    : (paged_kv.template get_k_ptr(
-                          page_iter, kv_head_idx, entry_idx, (tx % 8) * cell_capacity()));
+          produce_v ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
+                                                    (tx % 8) * cell_capacity(), last_indptr)
+                    : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
+                                                    (tx % 8) * cell_capacity(), last_indptr);
 #pragma unroll
       for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
         smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len);
@@ -174,10 +174,10 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offs
      const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4 + tx / 8) / page_size;
       const uint32_t entry_idx = (4 * num_warps * i + ty * 4 + tx / 8) % page_size;
       DType* gptr =
-          produce_v ? (paged_kv.template get_v_ptr(
-                          page_iter, kv_head_idx, entry_idx, (tx % 8) * cell_capacity()))
-                    : (paged_kv.template get_k_ptr(
-                          page_iter, kv_head_idx, entry_idx, (tx % 8) * cell_capacity()));
+          produce_v ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
+                                                    (tx % 8) * cell_capacity(), last_indptr)
+                    : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
+                                                    (tx % 8) * cell_capacity(), last_indptr);
 #pragma unroll
       for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
         smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len);
@@ -1027,13 +1027,14 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
           smem_t::get_permuted_offset(tx % 16, tx / 16),
                kv_smem_offset_w = smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8);
 
+  const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];
   uint32_t page_iter_base = paged_kv.indptr[request_idx];
   page_produce_kv(
-      k_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len);
+      k_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len, last_indptr);
   cp_async::commit_group();
   page_produce_kv(
-      v_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len);
+      v_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len, last_indptr);
   cp_async::commit_group();
 
   const uint32_t num_iterations =
@@ -1077,7 +1078,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     page_iter_base += 16 * num_frags_z / page_size;
     kv_idx_base += 16 * num_frags_z;
     page_produce_kv(
-        k_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len);
+        k_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len, last_indptr);
     cp_async::commit_group();
     cp_async::wait_group<1>();
     block.sync();
@@ -1088,7 +1089,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     block.sync();
 
     page_produce_kv(
-        v_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len);
+        v_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len, last_indptr);
     cp_async::commit_group();
   }
   cp_async::wait_group<0>();
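
The patch above stops chasing the KV page table in global memory on every producer iteration: page.cuh splits the old templated get_k_ptr/get_v_ptr into plain getters plus protective_get_k_ptr/protective_get_v_ptr that take an explicit bound (last_indptr = paged_kv.indptr[batch_size]), the decode kernel stages one pointer per thread in shared memory (k_ptrs_smem) and refreshes that cache only once every bdx iterations, and the prefill path threads last_indptr through page_produce_kv into the protective getters. The standalone CUDA sketch below is a minimal illustration of the same cache-and-clamp pattern, not flashinfer code; gather_rows, protective_page_ptr, page_table, PAGE_SIZE, HEAD_DIM, and BDX are invented for the example.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int PAGE_SIZE = 4;   // rows per page (illustrative)
constexpr int HEAD_DIM = 8;    // floats per row (illustrative)
constexpr int BDX = 32;        // block size == pointers cached per refill

__device__ __forceinline__ const float* protective_page_ptr(const float* pool,
                                                            const int* page_table,
                                                            int page_iter, int num_pages) {
  // Out-of-range lookahead falls back to the pool base instead of reading a
  // garbage page index, mirroring the protective_get_* fallback in the patch.
  return (page_iter < num_pages) ? pool + (size_t)page_table[page_iter] * PAGE_SIZE * HEAD_DIM
                                 : pool;
}

__global__ void gather_rows(const float* pool, const int* page_table, int num_pages,
                            int num_rows, float* out) {
  __shared__ const float* row_ptrs[BDX];  // cached row pointers, refreshed every BDX rows
  const int tx = threadIdx.x;

  for (int base = 0; base < num_rows; base += BDX) {
    // Cooperative refill: thread tx resolves the pointer for logical row (base + tx).
    const int row = base + tx;
    row_ptrs[tx] = protective_page_ptr(pool, page_table, row / PAGE_SIZE, num_pages) +
                   (row % PAGE_SIZE) * HEAD_DIM;
    __syncthreads();

    // Hot loop: copy through the cached pointer; the page table is not touched again here.
    if (row < num_rows) {
      for (int d = 0; d < HEAD_DIM; ++d) out[row * HEAD_DIM + d] = row_ptrs[tx][d];
    }
    __syncthreads();  // keep the next refill from racing ahead of readers
  }
}

int main() {
  const int num_pages = 3, num_rows = 10;   // last logical page only partially used
  std::vector<float> h_pool(num_pages * PAGE_SIZE * HEAD_DIM);
  for (size_t i = 0; i < h_pool.size(); ++i) h_pool[i] = (float)i;
  std::vector<int> h_table = {2, 0, 1};     // logical page -> physical page

  float *d_pool, *d_out;
  int* d_table;
  cudaMalloc(&d_pool, h_pool.size() * sizeof(float));
  cudaMalloc(&d_out, num_rows * HEAD_DIM * sizeof(float));
  cudaMalloc(&d_table, h_table.size() * sizeof(int));
  cudaMemcpy(d_pool, h_pool.data(), h_pool.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_table, h_table.data(), h_table.size() * sizeof(int), cudaMemcpyHostToDevice);

  gather_rows<<<1, BDX>>>(d_pool, d_table, num_pages, num_rows, d_out);

  std::vector<float> h_out(num_rows * HEAD_DIM);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
  printf("row 0, dim 0 = %f (expect %f)\n", h_out[0], (float)(2 * PAGE_SIZE * HEAD_DIM));
  cudaFree(d_pool); cudaFree(d_out); cudaFree(d_table);
  return 0;
}

Refilling BDX pointers at a time amortizes the dependent page-table loads and keeps them out of the hot loop, which is the same motivation as the k_ptrs_smem refill guarded by (iter + num_stages_smem) % bdx == 0 in the decode kernel.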