diff --git a/csrc/src/flash.h b/csrc/src/flash.h
new file mode 100644
index 0000000..1facc13
--- /dev/null
+++ b/csrc/src/flash.h
@@ -0,0 +1,226 @@
+/******************************************************************************
+ * Copyright (c) 2025, Jingze Shi and Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include "namespace_config.h"
+
+#include <cuda.h>
+#include <vector>
+
+#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
+
+namespace FLASH_NAMESPACE {
+constexpr int TOTAL_DIM = 0;
+constexpr int H_DIM = 1;
+constexpr int D_DIM = 2;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct QKV_params {
+    using index_t = int64_t;
+    // The QKV matrices.
+    void *__restrict__ q_ptr;  // Query tensor [batch_size, num_heads, query_len, head_dim]
+    void *__restrict__ k_ptr;  // Key tensor [batch_size, num_kv_heads, key_len, head_dim]
+    void *__restrict__ v_ptr;  // Value tensor [batch_size, num_kv_heads, key_len, head_dim]
+
+    // The strides of the Q, K and V matrices.
+    index_t q_batch_stride;  // Stride between batches of Q
+    index_t k_batch_stride;  // Stride between batches of K
+    index_t v_batch_stride;  // Stride between batches of V
+    index_t q_row_stride;    // Stride between rows of Q
+    index_t k_row_stride;    // Stride between rows of K
+    index_t v_row_stride;    // Stride between rows of V
+    index_t q_head_stride;   // Stride between heads of Q
+    index_t k_head_stride;   // Stride between heads of K
+    index_t v_head_stride;   // Stride between heads of V
+
+    // The number of heads.
+    int h, h_k;
+    // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
+    // different from nheads (query).
+    int h_h_k_ratio;  // Precomputed h / h_k
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct ZeroHold_params {
+    using index_t = int64_t;
+
+    void *__restrict__ zero_hold_ptr;  // Zero-hold states tensor [batch_size, num_kv_heads, query_len, key_len]
+
+    // The strides of the zero-hold states tensor.
+    index_t zero_hold_batch_stride;  // Stride between batches of zero-hold states
+    index_t zero_hold_head_stride;   // Stride between heads of zero-hold states
+    index_t zero_hold_query_stride;  // Stride for the third dimension (query_len) of zero-hold states
+    // Assuming the last dim (key_len) has stride 1 for zero_hold_ptr
+
+    index_t causal_mask_batch_stride;      // Stride between batches of causal_mask
+    index_t causal_mask_head_stride;       // Stride for the second dimension (size 1) of causal_mask
+    index_t causal_mask_query_len_stride;  // Stride for the third dimension (query_len) of causal_mask
+    // Assuming the last dim (key_len) has stride 1 for causal_mask_ptr
+
+    // The keep window size.
+    int keep_window_size;  // Number of tokens to keep in top-k (0 means don't apply top-k)
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Flash_fwd_params : public QKV_params, public ZeroHold_params {
+    // Input tensors
+    void *q_ptr = nullptr;
+    void *k_ptr = nullptr;
+    void *v_ptr = nullptr;
+    void *zero_hold_ptr = nullptr;
+    void *causal_mask_ptr = nullptr;
+
+    // Input tensor for the bias
+    void *b_ptr = nullptr;
+    // Output tensors
+    void *o_ptr = nullptr;
+    // Tensor storing the output of softmax
+    void *p_ptr = nullptr;
+    // Buffer for partial derivatives
+    void *do_ptr = nullptr;
+    // Tensor storing the logsumexp for numerical stability.
+    void *softmax_lse_ptr = nullptr;
+
+    // The O matrix (output).
+    void * __restrict__ oaccum_ptr;
+
+    // The strides of O.
+    index_t o_batch_stride;
+    index_t o_row_stride;
+    index_t o_head_stride;
+
+    // The pointer to the softmax sum.
+    void * __restrict__ softmax_lseaccum_ptr;
+
+    // The dimensions.
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
+
+    // The scaling factors for the kernel.
+    float scale_softmax;
+    float scale_softmax_log2;
+
+    // Array of length b+1 holding the starting offset of each sequence.
+    int * __restrict__ cu_seqlens_q;
+    int * __restrict__ cu_seqlens_k;
+    int * __restrict__ leftpad_k;
+
+    // If provided, the actual length of each k sequence.
+    int * __restrict__ seqused_k;
+
+    int *__restrict__ blockmask;
+
+    // The K_new and V_new matrices.
+    void * __restrict__ knew_ptr;
+    void * __restrict__ vnew_ptr;
+
+    // The strides of the K_new and V_new matrices.
+    index_t knew_batch_stride;
+    index_t vnew_batch_stride;
+    index_t knew_row_stride;
+    index_t vnew_row_stride;
+    index_t knew_head_stride;
+    index_t vnew_head_stride;
+
+    // The cos and sin matrices for rotary embedding.
+    void * __restrict__ rotary_cos_ptr;
+    void * __restrict__ rotary_sin_ptr;
+
+    // The indices to index into the KV cache.
+    int * __restrict__ cache_batch_idx;
+
+    // Paged KV cache
+    int * __restrict__ block_table;
+    index_t block_table_batch_stride;
+    int page_block_size;
+
+    // The dropout probability (probability of keeping an activation).
+    float p_dropout;
+    // uint32_t p_dropout_in_uint;
+    // uint16_t p_dropout_in_uint16_t;
+    uint8_t p_dropout_in_uint8_t;
+
+    // Scale factor of 1 / (1 - p_dropout).
+    float rp_dropout;
+    float scale_softmax_rp_dropout;
+
+    // Random state.
+    at::PhiloxCudaState philox_args;
+
+    // Pointer to the RNG seed (idx 0) and offset (idx 1).
+    uint64_t * rng_state;
+
+    bool is_bf16;
+    bool is_causal;
+
+    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+    bool is_seqlens_k_cumulative;
+
+    bool is_rotary_interleaved;
+
+    int num_splits;  // For the split-KV version
+
+    bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+    bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Flash_bwd_params : public Flash_fwd_params {
+
+    // The dO, dQKV and dZeroHold matrices.
+    void *__restrict__ do_ptr;
+    void *__restrict__ dq_ptr;
+    void *__restrict__ dk_ptr;
+    void *__restrict__ dv_ptr;
+    void *__restrict__ dzero_hold_ptr;
+
+    // To accumulate dQ
+    void *__restrict__ dq_accum_ptr;
+    void *__restrict__ dk_accum_ptr;
+    void *__restrict__ dv_accum_ptr;
+    void *__restrict__ dzero_hold_accum_ptr;
+
+    // To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension:
+    // void *__restrict__ dk_accum_ptr;
+    // void *__restrict__ dv_accum_ptr;
+
+    // The strides of the dO, dQ, dK and dV matrices.
+    // TD [2022-04-16]: We're using 32-bit indexing to save registers.
+    // The code probably won't work for arrays larger than 2GB.
+    index_t do_batch_stride;
+    index_t do_row_stride;
+    index_t do_head_stride;
+    index_t dq_batch_stride;
+    index_t dk_batch_stride;
+    index_t dv_batch_stride;
+    index_t dq_row_stride;
+    index_t dk_row_stride;
+    index_t dv_row_stride;
+    index_t dq_head_stride;
+    index_t dk_head_stride;
+    index_t dv_head_stride;
+    index_t dzero_hold_batch_stride;
+    index_t dzero_hold_head_stride;
+    index_t dzero_hold_query_stride;
+
+    // The pointer to the softmax d sum.
+    void *__restrict__ dsoftmax_sum;
+
+    bool deterministic;
+    index_t dq_accum_split_stride;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+
+template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE
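
For reference, the sketch below (not part of the patch) shows how a host-side launcher could populate the pointer, stride, size, and scaling fields of Flash_fwd_params for contiguous [batch_size, num_heads, seqlen, head_dim] tensors, following the layouts documented in the comments above. The helper name set_fwd_params is hypothetical, not declared anywhere in flash.h. Note that the assignments to q_ptr / k_ptr / v_ptr below resolve to the members redeclared inside Flash_fwd_params, which hide the __restrict__-qualified ones inherited from QKV_params.

// Illustrative sketch only; set_fwd_params is an assumed helper, not part of this header.
#include <cmath>

#include <torch/extension.h>

#include "flash.h"

static void set_fwd_params(FLASH_NAMESPACE::Flash_fwd_params &params,
                           const at::Tensor &q,   // [b, h,   seqlen_q, d]
                           const at::Tensor &k,   // [b, h_k, seqlen_k, d]
                           const at::Tensor &v,   // [b, h_k, seqlen_k, d]
                           at::Tensor &out,       // [b, h,   seqlen_q, d]
                           float softmax_scale) {
    params = {};  // Zero everything, including the optional pointers.

    params.is_bf16 = q.dtype() == torch::kBFloat16;

    // Raw device pointers.
    params.q_ptr = q.data_ptr();
    params.k_ptr = k.data_ptr();
    params.v_ptr = v.data_ptr();
    params.o_ptr = out.data_ptr();

    // Strides are in elements; the last (head_dim) dimension is assumed contiguous.
    params.q_batch_stride = q.stride(0);
    params.k_batch_stride = k.stride(0);
    params.v_batch_stride = v.stride(0);
    params.o_batch_stride = out.stride(0);
    params.q_head_stride  = q.stride(1);
    params.k_head_stride  = k.stride(1);
    params.v_head_stride  = v.stride(1);
    params.o_head_stride  = out.stride(1);
    params.q_row_stride   = q.stride(2);
    params.k_row_stride   = k.stride(2);
    params.v_row_stride   = v.stride(2);
    params.o_row_stride   = out.stride(2);

    // Sizes; h / h_k is the MQA/GQA group size.
    params.b           = q.size(0);
    params.h           = q.size(1);
    params.h_k         = k.size(1);
    params.h_h_k_ratio = params.h / params.h_k;
    params.seqlen_q    = q.size(2);
    params.seqlen_k    = k.size(2);
    params.d           = q.size(3);

    // The kernel also wants the scale pre-multiplied by log2(e) so it can use exp2f.
    params.scale_softmax      = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * static_cast<float>(M_LOG2E);
}

A real launcher would additionally set the zero-hold and causal-mask pointers and strides from ZeroHold_params, the dropout and RNG fields, and the cu_seqlens_* arrays for varlen paths; those are omitted here to keep the sketch focused on the tensor layout.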