cache page pointers
yzh119 committed Dec 9, 2023
1 parent 11364ca commit b83b408
Showing 3 changed files with 74 additions and 64 deletions.
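This commit caches paged-KV page pointers in shared memory for the batch decode kernel. Previously, BatchDecodeWithPagedKVCacheKernel resolved each k/v tile's address on every iteration through get_k_ptr<AccessMode::kProtective>, which re-reads __ldg(indptr + batch_size) and performs a div/mod by page_size per call. Now the block cooperatively resolves one pointer per thread into a shared-memory array (k_ptrs_smem), refreshes that cache once every bdx iterations, and the hot loop only reads the cached pointer plus a per-thread offset. The bounds-checked lookup is split out of get_k_ptr/get_v_ptr into protective_get_k_ptr/protective_get_v_ptr, which take the sentinel last_indptr as an argument so it is loaded once per block rather than on every call; the prefill producer (page_produce_kv) gets the same hoisting. num_stages_smem also goes from 2 to 4.

Below is a minimal, self-contained sketch of the pointer-caching pattern under simplified assumptions (one row of kHeadDim floats per KV entry, a 2-D thread block, illustrative names such as page_ptrs/ptr_cache/copy_paged_rows); it is not the flashinfer kernel itself. kBdx threads cooperate on one entry, kBdy entries are consumed per iteration, and the block resolves kBdx * kBdy entry pointers in one cooperative step, amortizing the div/mod and pointer fetch over kBdx iterations. Launch with dim3(kBdx, kBdy).

#include <cuda_runtime.h>

constexpr int kBdx = 32;       // threads cooperating on one entry
constexpr int kBdy = 4;        // entries consumed per iteration
constexpr int kHeadDim = 128;  // floats per entry; each thread copies kHeadDim / kBdx of them

__global__ void copy_paged_rows(float* const* page_ptrs,  // base pointer of each page
                                int page_size, int num_entries, float* out) {
  __shared__ float* ptr_cache[kBdx * kBdy];  // one cached entry pointer per thread
  const int tx = threadIdx.x, ty = threadIdx.y;
  if (num_entries <= 0) return;

  // Cooperative refresh: thread (tx, ty) resolves the pointer of entry
  // first_entry + ty * kBdx + tx, clamped to a valid entry so every cached
  // pointer stays inside the allocation ("protective" resolution).
  auto refresh = [&](int first_entry) {
    int e = min(first_entry + ty * kBdx + tx, num_entries - 1);
    ptr_cache[ty * kBdx + tx] = page_ptrs[e / page_size] + (e % page_size) * kHeadDim;
  };

  refresh(0);
  __syncthreads();
  const int num_iters = (num_entries + kBdy - 1) / kBdy;
  for (int iter = 0; iter < num_iters; ++iter) {
    if (iter > 0 && iter % kBdx == 0) {
      __syncthreads();       // everyone is done with the old pointers
      refresh(iter * kBdy);  // resolve the next kBdx * kBdy entries
      __syncthreads();
    }
    int entry = iter * kBdy + ty;
    if (entry < num_entries) {
      // Hot path: the entry's address is a shared-memory read plus a lane offset.
      const float* src = ptr_cache[(iter % kBdx) * kBdy + ty];
      for (int j = 0; j < kHeadDim / kBdx; ++j)
        out[entry * kHeadDim + tx + j * kBdx] = src[tx + j * kBdx];
    }
  }
}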
38 changes: 26 additions & 12 deletions include/flashinfer/decode.cuh
@@ -552,6 +552,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
extern __shared__ uint8_t smem[];
DTypeIn* k_smem = (DTypeIn*)smem;
DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));
DTypeIn** k_ptrs_smem =
(DTypeIn**)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));
float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn));

const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
@@ -588,12 +590,17 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
// preload k/v tiles
uint32_t stage_idx = 0;
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

static_assert(num_stages_smem <= bdx);
k_ptrs_smem[(tz * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
cur_page_indptr_begin + ((tz * bdy + ty) * bdx + tx) / paged_kv.page_size, kv_head_idx,
((tz * bdy + ty) * bdx + tx) % paged_kv.page_size, 0, last_indptr);
block.sync();

#pragma unroll
for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
DTypeIn* k_ptr = paged_kv.template get_k_ptr<AccessMode::kProtective>(
cur_page_indptr_begin + ((iter * bdz + tz) * bdy + ty) / paged_kv.page_size, kv_head_idx,
((iter * bdz + tz) * bdy + ty) % paged_kv.page_size, tx * vec_size);
DTypeIn* k_ptr = k_ptrs_smem[(iter * bdz + tz) * bdy + ty] + tx * vec_size;
DTypeIn* v_ptr = k_ptr + paged_kv.kv_offset_delta();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, k_ptr,
@@ -611,6 +618,15 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(

#pragma unroll 4
for (uint32_t iter = 0; iter < (kv_chunk_len + bdy * bdz - 1) / (bdy * bdz); ++iter) {
if ((iter + num_stages_smem) % bdx == 0) {
k_ptrs_smem[(tz * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
cur_page_indptr_begin +
((iter + num_stages_smem) * bdy * bdz + (tz * bdy + ty) * bdx + tx) /
paged_kv.page_size,
kv_head_idx,
((iter + num_stages_smem) * bdy * bdz + (tz * bdy + ty) * bdx + tx) % paged_kv.page_size,
0, last_indptr);
}
const bool producer_pred_guard =
((iter + num_stages_smem) * bdz + tz) * bdy + ty < kv_chunk_len;
// compute qk
@@ -621,11 +637,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
cur_chunk_start + iter * bdy * bdz, iter * bdy * bdz, kv_chunk_len, sm_scale, x, s);
block.sync();

DTypeIn* k_ptr = paged_kv.template get_k_ptr<AccessMode::kProtective>(
cur_page_indptr_begin +
(((iter + num_stages_smem) * bdz + tz) * bdy + ty) / paged_kv.page_size,
kv_head_idx, (((iter + num_stages_smem) * bdz + tz) * bdy + ty) % paged_kv.page_size,
tx * vec_size);
DTypeIn* k_ptr =
k_ptrs_smem[(((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty] + tx * vec_size;
DTypeIn* v_ptr = k_ptr + paged_kv.kv_offset_delta();
// load k tiles
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
@@ -1054,15 +1067,15 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation(
{SWITCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE, {SWITCH_PAGE_SIZE(paged_kv.page_size, PAGE_SIZE, {
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t num_stages_smem = 4U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
static_assert(bdx <= 32);
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
const uint32_t smem_size =
2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn) +
2 * bdy * bdz * sizeof(float);
std::max(num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float));

auto cooperative_kernel =
BatchDecodeWithPagedKVCacheKernel<true, ROTARY_MODE, PAGE_SIZE, num_stages_smem,
@@ -1117,14 +1130,15 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
const uint32_t batch_size = paged_kv.batch_size;

constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t num_stages_smem = 4U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
static_assert(bdx <= 32);
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
const uint32_t smem_size =
2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) + 2 * bdy * bdz * sizeof(float);
2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) +
std::max(num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float));

if (tmp == nullptr) {
// do not use cooperative kernel
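In both dispatch paths the extra shared-memory term now sizes whichever of two overlapping regions is larger: the per-thread pointer cache (num_threads * sizeof(DTypeIn*)) or the 2 * bdy * bdz floats of merge state. Both regions start at the same offset after the k/v staging buffers (see k_ptrs_smem and smem_md in the first hunk), presumably because the pointer cache is only live while KV tiles are being streamed and the float workspace only during the final merge. A worked example with illustrative parameters (HEAD_DIM = 128, 2-byte DTypeIn, GROUP_SIZE = 1): vec_size = max(16/2, 128/32) = 8, bdx = 128/8 = 16, bdy = 1, num_threads = max(128, 16) = 128, bdz = 8, so the old layout (num_stages_smem = 2) needed 2*2*1*8*128*2 + 2*1*8*4 = 8256 bytes, while the new one (num_stages_smem = 4 plus the pointer cache) needs 2*4*1*8*128*2 + max(128*8, 64) = 17408 bytes, still well under the default 48 KB dynamic shared-memory limit.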
73 changes: 34 additions & 39 deletions include/flashinfer/page.cuh
@@ -29,11 +29,6 @@ enum class PageStorage {
kPointer = 1U, // Store the pointers to each active page.
};

enum class AccessMode {
kProtective = 0U, // Check whether page_iter is out of range
kNonProtective = 1U // Do not check whether page_iter is out of range
};

/*!
* \brief Paged key-value cache
* \tparam page_storage Whether to store indices or pointers of each active page
@@ -285,56 +280,56 @@ struct paged_kv_t {
return num_heads * page_size * head_dim;
}

template <AccessMode access_mode = AccessMode::kNonProtective>
__device__ __forceinline__ DType* get_k_ptr(uint32_t page_iter, uint32_t head_idx,
__device__ __forceinline__ DType* get_k_ptr(IdType page_iter, uint32_t head_idx,
uint32_t entry_idx, uint32_t feat_idx) const {
if constexpr (page_storage == PageStorage::kIndices) {
if constexpr (access_mode == AccessMode::kProtective) {
if (page_iter < __ldg(indptr + batch_size)) {
return data +
get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return data;
}
} else {
return data + get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return __ldg(ptrs + page_iter) + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
}
}

__device__ __forceinline__ DType* protective_get_k_ptr(IdType page_iter, uint32_t head_idx,
uint32_t entry_idx, uint32_t feat_idx,
IdType last_indptr) const {
if constexpr (page_storage == PageStorage::kIndices) {
if (page_iter < last_indptr) {
return data + get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return data;
}
} else {
if constexpr (access_mode == AccessMode::kProtective) {
if (page_iter < __ldg(indptr + batch_size)) {
return __ldg(ptrs + page_iter) + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
} else {
return __ldg(ptrs);
}
} else {
if (page_iter < last_indptr) {
return __ldg(ptrs + page_iter) + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
} else {
return __ldg(ptrs);
}
}
}

template <AccessMode access_mode = AccessMode::kNonProtective>
__device__ __forceinline__ DType* get_v_ptr(uint32_t page_iter, uint32_t head_idx,
__device__ __forceinline__ DType* get_v_ptr(IdType page_iter, uint32_t head_idx,
uint32_t entry_idx, uint32_t feat_idx) const {
if constexpr (page_storage == PageStorage::kIndices) {
if constexpr (access_mode == AccessMode::kProtective) {
if (page_iter < __ldg(indptr + batch_size)) {
return data +
get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return data;
}
} else {
return data + get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return __ldg(ptrs + page_iter) + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
}
}

__device__ __forceinline__ DType* protective_get_v_ptr(IdType page_iter, uint32_t head_idx,
uint32_t entry_idx, uint32_t feat_idx,
IdType last_indptr) const {
if constexpr (page_storage == PageStorage::kIndices) {
if (page_iter < last_indptr) {
return data + get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return data;
}
} else {
if constexpr (access_mode == AccessMode::kProtective) {
if (page_iter < __ldg(indptr + batch_size)) {
return __ldg(ptrs + page_iter) + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
} else {
return __ldg(ptrs);
}
} else {
if (page_iter < last_indptr) {
return __ldg(ptrs + page_iter) + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
} else {
return __ldg(ptrs);
}
}
}
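In page.cuh the AccessMode template parameter disappears: get_k_ptr/get_v_ptr are now always non-protective, and the bounds-checked variants become separate protective_get_k_ptr/protective_get_v_ptr methods that take the sentinel last_indptr explicitly instead of re-reading __ldg(indptr + batch_size) on every call. A sketch of the intended caller-side pattern follows (a hypothetical device function assuming flashinfer/page.cuh is included; consume_kv and its parameters are illustrative, not code from this commit): hoist the sentinel once per block, then pass it into each protective lookup.

// Hypothetical caller; paged_kv_t, protective_get_k_ptr, and kv_offset_delta
// are the accessors defined above, everything else is illustrative.
template <PageStorage page_storage, typename DType, typename IdType>
__device__ void consume_kv(paged_kv_t<page_storage, DType, IdType>& paged_kv,
                           uint32_t kv_head_idx, uint32_t feat_idx,
                           IdType page_iter_begin, uint32_t num_entries) {
  // Read the out-of-range sentinel once instead of inside every accessor call.
  const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];
  for (uint32_t i = 0; i < num_entries; ++i) {
    IdType page_iter = page_iter_begin + i / paged_kv.page_size;
    uint32_t entry_idx = i % paged_kv.page_size;
    // Bounds-checked lookup: when page_iter runs past last_indptr it falls back
    // to a safe base pointer, so speculative prefetch addresses stay valid.
    DType* k = paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
                                             feat_idx, last_indptr);
    DType* v = k + paged_kv.kv_offset_delta();
    // ... issue async copies / predicated loads from k and v ...
  }
}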
27 changes: 14 additions & 13 deletions include/flashinfer/prefill.cuh
@@ -140,7 +140,7 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offs
paged_kv_t<page_storage, DType, IdType>& paged_kv,
const uint32_t kv_idx_base,
const uint32_t page_iter_base,
const uint32_t kv_len) {
const uint32_t kv_len, const IdType last_indptr) {
constexpr SharedMemFillMode fill_mode =
produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
constexpr uint32_t head_dim = num_frags_y * 16;
@@ -154,10 +154,10 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offs
const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4) / page_size;
const uint32_t entry_idx = (4 * num_warps * i + ty * 4) % page_size + tx / 8;
DType* gptr =
produce_v ? (paged_kv.template get_v_ptr<AccessMode::kProtective>(
page_iter, kv_head_idx, entry_idx, (tx % 8) * cell_capacity<DType>()))
: (paged_kv.template get_k_ptr<AccessMode::kProtective>(
page_iter, kv_head_idx, entry_idx, (tx % 8) * cell_capacity<DType>()));
produce_v ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * cell_capacity<DType>(), last_indptr)
: paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * cell_capacity<DType>(), last_indptr);
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
@@ -174,10 +174,10 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offs
const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4 + tx / 8) / page_size;
const uint32_t entry_idx = (4 * num_warps * i + ty * 4 + tx / 8) % page_size;
DType* gptr =
produce_v ? (paged_kv.template get_v_ptr<AccessMode::kProtective>(
page_iter, kv_head_idx, entry_idx, (tx % 8) * cell_capacity<DType>()))
: (paged_kv.template get_k_ptr<AccessMode::kProtective>(
page_iter, kv_head_idx, entry_idx, (tx % 8) * cell_capacity<DType>()));
produce_v ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * cell_capacity<DType>(), last_indptr)
: paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * cell_capacity<DType>(), last_indptr);
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
@@ -1027,13 +1027,14 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
smem_t::get_permuted_offset<num_cells_per_in_channel>(tx % 16, tx / 16),
kv_smem_offset_w =
smem_t::get_permuted_offset<num_cells_per_in_channel>(ty * 4 + tx / 8, tx % 8);
const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

uint32_t page_iter_base = paged_kv.indptr[request_idx];
page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len);
k_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len, last_indptr);
cp_async::commit_group();
page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len);
v_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len, last_indptr);
cp_async::commit_group();

const uint32_t num_iterations =
@@ -1077,7 +1078,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
page_iter_base += 16 * num_frags_z / page_size;
kv_idx_base += 16 * num_frags_z;
page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len);
k_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len, last_indptr);
cp_async::commit_group();
cp_async::wait_group<1>();
block.sync();
@@ -1088,7 +1089,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(

block.sync();
page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len);
v_smem, &kv_smem_offset_w, paged_kv, kv_idx_base, page_iter_base, kv_len, last_indptr);
cp_async::commit_group();
}
cp_async::wait_group<0>();
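The prefill path applies the same hoisting: page_produce_kv gains a trailing const IdType last_indptr parameter, BatchPrefillWithPagedKVCacheKernel reads paged_kv.indptr[paged_kv.batch_size] once before the first produce/commit pair, and every producer call forwards that value to protective_get_k_ptr / protective_get_v_ptr instead of re-reading the indptr entry inside each pointer lookup.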
