feat: initial support of logits hook (#298)
Implements the feature requested in #257.
yzh119 committed Jun 14, 2024
1 parent 5602659 commit ab1e2ad
Showing 36 changed files with 1,239 additions and 744 deletions.
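The centerpiece of the change is a new logits_post_hook.cuh header plus a LogitsPostHook / logits_post_hook template parameter threaded through every decode kernel and dispatcher. The header itself is not shown in the portion of the diff rendered here, so below is a minimal sketch of what it plausibly contains, inferred only from the call sites visible on this page (apply_logits_post_hook<logits_post_hook>(s[j]) and the sm_scale rescaling by 1.f / 30.f). The enumerator names and the hard-coded cap of 30 are assumptions, not the actual header:

// Hypothetical sketch of include/flashinfer/attention/logits_post_hook.cuh.
// Names and the soft-cap constant are inferences from the diff on this page,
// not the header actually shipped in this commit.
enum class LogitsPostHook {
  kNone = 0,    // identity: leave the qk logits untouched
  kSoftCap = 1  // tanh-style soft cap: logits = 30 * tanh(logits / 30)
};

// The kernels pre-fold constants into sm_scale before calling this:
//   kNone:    sm_scale *= log2e  -> hook is the identity, exp2() comes later
//   kSoftCap: sm_scale *= 1/30   -> hook receives x = qk * sm_scale / 30
// so the capped variant only needs to compute log2e * 30 * tanh(x), giving
// exp2(log2e * 30 * tanh(x)) == exp(30 * tanh(qk * sm_scale / 30)).
template <LogitsPostHook hook>
__device__ __forceinline__ float apply_logits_post_hook(float x) {
  if constexpr (hook == LogitsPostHook::kSoftCap) {
    return 30.0f * 1.4426950408889634f /* log2e */ * tanhf(x);
  } else {
    return x;  // LogitsPostHook::kNone
  }
}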
252 changes: 132 additions & 120 deletions CMakeLists.txt

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions cmake/config.cmake
@@ -23,6 +23,7 @@ set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)
+set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_PAGE_SIZES 1 16 32)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
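The new FLASHINFER_GEN_LOGITS_POST_HOOKS entry joins the other FLASHINFER_GEN_* lists that control which kernel variants are compiled ahead of time; the default generates only hook 0. Assuming 0 maps to the identity hook and 1 to the capped hook, a build that wants both variants would write `set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0 1)` and pay for it in binary size, since each listed value multiplies the number of generated instantiations, as the comment above the list warns.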
88 changes: 50 additions & 38 deletions include/flashinfer/attention/decode.cuh
Expand Up @@ -36,6 +36,7 @@
#include "../utils.cuh"
#include "../vec_dtypes.cuh"
#include "cascade.cuh"
#include "logits_post_hook.cuh"
#include "state.cuh"

namespace flashinfer {
@@ -48,6 +49,7 @@ namespace {

/*!
* \brief Load k tile from smem and compute qk
+ * \tparam logits_post_hook The logits post hook used in the kernel
* \tparam pos_encoding_mode The positional encoding mode used in the kernel
* \tparam head_dim A template integer indicates the head dimension
* \tparam vec_size A template integer indicates the vector size
@@ -65,8 +67,8 @@ namespace {
* \param s A float indicates the thread-local result of qk
* \param st The self-attention state to be updated
*/
-template <PosEncodingMode pos_encoding_mode, uint32_t vec_size, uint32_t bdx, uint32_t tile_size,
-          typename T>
+template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode, uint32_t vec_size,
+          uint32_t bdx, uint32_t tile_size, typename T>
__device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx,
const vec_t<float, vec_size>& q_vec,
const vec_t<float, vec_size>& freq, uint32_t kv_idx_base,
@@ -96,6 +98,7 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
s[j] += math::shfl_xor_sync(s[j], offset);
}
s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4;
+      s[j] = apply_logits_post_hook<logits_post_hook>(s[j]);
if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) {
s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset);
}
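Read together, the hunk above fixes the hook's place in the per-element pipeline of compute_qk. A comment-form paraphrase of the resulting order of operations, using hypothetical shorthand for the index arithmetic:

// Per-element pipeline inside compute_qk, paraphrased from the hunk above:
//   s[j]  = dot(q_vec, k_j) * sm_scale     // accumulated, then shfl-reduced
//   s[j]  = in_bounds ? s[j] : -5e4        // mask out-of-range kv positions
//   s[j]  = apply_logits_post_hook<logits_post_hook>(s[j]);  // new in this commit
//   s[j] += alibi_slope * (kv_pos - q_offset)  // only for PosEncodingMode::kALiBi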
@@ -178,6 +181,7 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f

/*!
* \brief FlashAttention decoding cuda kernel with kv-cache for a single request
+ * \tparam logits_post_hook The logits post hook used in the kernel
* \tparam kv_layout The layout of k/v matrices (NHD or HND)
* \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not
* \tparam pos_encoding_mode The positional encoding mode
@@ -202,9 +206,10 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
* of "theta" used in RoPE (Rotary Positional Embeddings)
* \param kv_chunk_size A integer indicates the kv-chunk size
*/
-template <QKVLayout kv_layout, bool partition_kv, PosEncodingMode pos_encoding_mode,
-          uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
-          uint32_t bdy, uint32_t bdz, typename DTypeQ, typename DTypeKV, typename DTypeOut>
+template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, bool partition_kv,
+          PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem, uint32_t tile_size_per_bdx,
+          uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename DTypeQ,
+          typename DTypeKV, typename DTypeOut>
__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
DTypeOut* __restrict__ tmp,
@@ -213,7 +218,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
float rope_rcp_theta, uint32_t kv_chunk_size) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
-  sm_scale *= math::log2e;
+  sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);

constexpr uint32_t head_dim = bdx * vec_size;
uint32_t kv_head_idx = blockIdx.y;
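The one-line sm_scale change above is where the hook's constants get folded in: the kernels exponentiate with exp2 rather than exp, so with no hook the usual log2e conversion is absorbed into sm_scale up front, and with a hook active the (assumed) soft-cap denominator of 30 is absorbed instead, leaving the log2e and cap-magnitude factors to apply_logits_post_hook itself. A small host-side check of that algebra, under the same soft-cap assumption as the sketch near the top of this page:

// Hypothetical check that folding constants into sm_scale preserves the math:
// exp2(log2e * 30 * tanh(x / 30)) must equal exp(30 * tanh(x / 30)).
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const float log2e = 1.4426950408889634f;
  for (float qk : {-80.f, -3.5f, 0.f, 2.25f, 64.f}) {
    float capped = 30.f * std::tanh(qk / 30.f);  // soft-capped logit
    float via_exp2 = std::exp2(log2e * capped);  // what the kernel computes
    float via_exp = std::exp(capped);            // the intended value
    assert(std::fabs(via_exp2 - via_exp) <= 1e-3f * via_exp + 1e-6f);
    std::printf("qk=%+7.2f  capped=%+8.4f  exp=%g\n", qk, capped, via_exp);
  }
  return 0;
}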
@@ -297,7 +302,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
// compute qk
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
-    compute_qk<pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
+    compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size,
seq_len - 1, alibi_slope, s, st_local);
@@ -356,16 +361,16 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
}
}

-template <QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem,
-          uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename DTypeQ,
-          typename DTypeKV, typename DTypeOut>
+template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
+          uint32_t num_stages_smem, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
+          typename DTypeQ, typename DTypeKV, typename DTypeOut>
__global__ void BatchDecodeWithPaddedKVCacheKernel(
DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v,
DTypeOut* __restrict__ o, float* __restrict__ lse,
tensor_info_t<kv_layout, bdy, bdx * vec_size> info, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta) {
auto block = cg::this_thread_block();
-  sm_scale *= math::log2e;
+  sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);

constexpr uint32_t head_dim = bdx * vec_size;
uint32_t kv_head_idx = blockIdx.y;
@@ -438,7 +443,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
// compute qk
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
-    compute_qk<pos_encoding_mode, vec_size, bdx, bdy>(
+    compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy>(
k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq,
consumer_kv_idx_base, iter * bdy * bdz, seq_len, seq_len - 1, alibi_slope, s, st_local);
block.sync();
@@ -489,6 +494,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(

/*!
* \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests
+ * \tparam logits_post_hook The logits post hook used in the kernel
* \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not
* \tparam pos_encoding_mode The positional encoding mode
* \tparam vec_size A template integer indicates the vector size
@@ -512,10 +518,10 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
* \param rope_rcp_theta A floating number indicate the reciprocal
* of "theta" used in RoPE (Rotary Positional Embeddings)
*/
-template <bool partition_kv, PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem,
-          uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
-          PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV,
-          typename DTypeOut, typename IdType>
+template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode pos_encoding_mode,
+          uint32_t num_stages_smem, uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx,
+          uint32_t bdy, uint32_t bdz, PageStorage page_storage, QKVLayout kv_layout,
+          typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
@@ -524,7 +530,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta) {
auto block = cg::this_thread_block();
-  sm_scale *= math::log2e;
+  sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);

constexpr uint32_t head_dim = bdx * vec_size;
const uint32_t batch_idx = blockIdx.x;
@@ -649,7 +655,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
// compute qk
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
-    compute_qk<pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
+    compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
freq,
(paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
@@ -760,8 +766,8 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo
* \param stream The cuda stream to launch the kernel
* \return status Indicates whether CUDA calls are successful
*/
-template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
-          PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV,
+template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
+          QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV,
typename DTypeOut>
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, uint32_t num_kv_heads,
@@ -786,9 +792,9 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
if (seq_len <= 256 || tmp == nullptr) {
// no need to use partition-kv kernel
auto kernel =
-        SingleDecodeWithKVCacheKernel<KV_LAYOUT, /*partition_kv=*/false, POS_ENCODING_MODE,
-                                      num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy, bdz,
-                                      DTypeQ, DTypeKV, DTypeOut>;
+        SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, /*partition_kv=*/false,
+                                      POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
+                                      vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

@@ -807,9 +813,10 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
-    auto kernel = SingleDecodeWithKVCacheKernel<KV_LAYOUT, /*partition_kv=*/true, POS_ENCODING_MODE,
-                                                num_stages_smem, tile_size_per_bdx, vec_size, bdx,
-                                                bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
+    auto kernel =
+        SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, /*partition_kv=*/true,
+                                      POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
+                                      vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

@@ -848,8 +855,9 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
return cudaSuccess;
}

-template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
-          PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
+template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage,
+          LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE,
+          typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
@@ -877,9 +885,10 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
auto kernel =
-        BatchDecodeWithPagedKVCacheKernel</*partition_kv=*/false, POS_ENCODING_MODE,
-                                          num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy,
-                                          bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
+        BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/false,
+                                          POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
+                                          vec_size, bdx, bdy, bdz, page_storage, kv_layout, DTypeQ,
+                                          DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
@@ -898,9 +907,10 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
} else {
// use partition-kv kernel
auto partition_kv_kernel =
-        BatchDecodeWithPagedKVCacheKernel</*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem,
-                                          tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage,
-                                          kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
+        BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, /*partition_kv=*/true,
+                                          POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx,
+                                          vec_size, bdx, bdy, bdz, page_storage, kv_layout, DTypeQ,
+                                          DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
@@ -946,8 +956,9 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
* \param stream The cuda stream to launch the kernel
* \return status Indicates whether CUDA calls are successful
*/
-template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
-          PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
+template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
+          QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV,
+          typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
@@ -970,8 +981,9 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeK

dim3 nblks(batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
-  auto kernel = BatchDecodeWithPaddedKVCacheKernel<KV_LAYOUT, POS_ENCODING_MODE, num_stages_smem,
-                                                   vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
+  auto kernel = BatchDecodeWithPaddedKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
+                                                   num_stages_smem, vec_size, bdx, bdy, bdz, DTypeQ,
+                                                   DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
tensor_info_t<KV_LAYOUT, GROUP_SIZE, HEAD_DIM> info(1, padded_kv_len, num_kv_heads);
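One thing the scattered hunks make easy to miss: the hook parameter does not land in the same slot in every dispatcher. Read directly off the template declarations above:

// Template parameter order after this commit (from the hunks above):
//   SingleDecodeWithKVCacheDispatched<GROUP_SIZE, HEAD_DIM, LOGITS_POST_HOOK,
//       KV_LAYOUT, POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut>
//   BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
//       LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, DTypeQ, DTypeKV,
//       DTypeOut, IdType>
//   BatchDecodeWithPaddedKVCacheDispatched<GROUP_SIZE, HEAD_DIM,
//       LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE, DTypeQ, DTypeKV,
//       DTypeOut>
// In the paged-KV dispatcher the hook follows page_storage; in the other two
// it directly follows HEAD_DIM.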