226 changes: 226 additions & 0 deletions csrc/src/flash.h
@@ -0,0 +1,226 @@
/******************************************************************************
* Copyright (c) 2025, Jingze Shi and Tri Dao.
******************************************************************************/

#pragma once

#include "namespace_config.h"

#include <cuda.h>
#include <vector>

#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState

namespace FLASH_NAMESPACE {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct QKV_params {
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr; // Query tensor [batch_size, num_heads, query_len, head_dim]
void *__restrict__ k_ptr; // Key tensor [batch_size, num_kv_heads, key_len, head_dim]
void *__restrict__ v_ptr; // Value tensor [batch_size, num_kv_heads, key_len, head_dim]

// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride; // Stride between batches of Q
index_t k_batch_stride; // Stride between batches of K
index_t v_batch_stride; // Stride between batches of V
index_t q_row_stride; // Stride between rows of Q
index_t k_row_stride; // Stride between rows of K
index_t v_row_stride; // Stride between rows of V
index_t q_head_stride; // Stride between heads of Q
index_t k_head_stride; // Stride between heads of K
index_t v_head_stride; // Stride between heads of V

// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precomputed h / h_k
};
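// Illustrative sketch (an assumption, not part of the original header): under MQA/GQA a
// kernel typically maps a query head index `hid` to its KV head as `hid / h_h_k_ratio`.
// For example, h = 32 query heads and h_k = 8 KV heads give h_h_k_ratio = 4, so query
// heads 0..3 all read KV head 0:
//
//     const int hid_kv = hid / params.h_h_k_ratio;   // h_h_k_ratio == h / h_k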

////////////////////////////////////////////////////////////////////////////////////////////////////

struct ZeroHold_params {
using index_t = int64_t;

void *__restrict__ zero_hold_ptr; // Zero-hold states tensor [batch_size, num_kv_heads, query_len, key_len]

// The stride of the zero-hold states tensor.
index_t zero_hold_batch_stride; // Stride between batches of zero-hold states
index_t zero_hold_head_stride; // Stride between heads of zero-hold states
index_t zero_hold_query_stride; // Stride for the third dimension (query_len) of zero-hold states
// Assuming the last dim (key_len) of zero_hold_ptr has stride 1

index_t causal_mask_batch_stride; // Stride between batches of causal_mask
index_t causal_mask_head_stride; // Stride for the second dimension (size 1) of causal_mask
index_t causal_mask_query_len_stride; // Stride for the third dimension (query_len) of causal_mask
// Assuming last dim (key_len) has stride 1 for the causal_mask_ptr

// The keep window size.
int keep_window_size; // Number of tokens to keep in top-k (0 means don't apply top-k)
};
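// Illustrative sketch (an assumption): addressing one element of the zero-hold states
// tensor [batch_size, num_kv_heads, query_len, key_len] with the strides above; the last
// (key_len) dimension is taken as contiguous, as noted in the comments. `Element`, `bidb`,
// `bidh`, `q_idx` and `k_idx` are placeholder names.
//
//     const Element *zh = reinterpret_cast<const Element *>(params.zero_hold_ptr)
//                       + bidb  * params.zero_hold_batch_stride
//                       + bidh  * params.zero_hold_head_stride
//                       + q_idx * params.zero_hold_query_stride
//                       + k_idx;
//     // With keep_window_size > 0, only the top-`keep_window_size` entries per query row are kept.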

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public QKV_params, public ZeroHold_params {
// The Q/K/V and zero-hold pointers (and their strides) are inherited from
// QKV_params and ZeroHold_params; redeclaring them here would shadow the
// base-class members.

// Input tensor for the causal mask [batch_size, 1, query_len, key_len]
void *causal_mask_ptr = nullptr;
// Input tensor for the bias
void *b_ptr = nullptr;
// Output tensor
void *o_ptr = nullptr;
// Tensor storing the output of softmax
void *p_ptr = nullptr;
// Tensor storing the logsumexp for numerical stability.
void *softmax_lse_ptr = nullptr;

// Accumulator for the output O, used by the split-KV path.
void * __restrict__ oaccum_ptr;

// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;

// Accumulator for the softmax logsumexp, used by the split-KV path.
void * __restrict__ softmax_lseaccum_ptr;

// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;

// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
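
// Illustrative sketch (an assumption, mirroring the usual FlashAttention host-side setup):
// the kernels evaluate exp() via exp2f(), so both scales are precomputed once:
//
//     params.scale_softmax      = softmax_scale;               // e.g. 1.f / sqrtf(head_dim)
//     params.scale_softmax_log2 = softmax_scale * M_LOG2E;     // e^x == 2^(x * log2(e))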

// Arrays of length b+1 holding the starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
// If provided, the number of left-padded tokens of each k sequence (length b).
int * __restrict__ leftpad_k;

// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;

int *__restrict__ blockmask;

// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;

// The stride between rows of the K_new and V_new matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;

// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;

// The indices to index into the KV cache.
int * __restrict__ cache_batch_idx;

// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
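
// Illustrative sketch (an assumption): with a paged KV cache, key position `k_idx` of
// batch `bidb` is resolved through the block table; `bidb` and `k_idx` are placeholders.
//
//     const int *block_row     = params.block_table + bidb * params.block_table_batch_stride;
//     const int physical_block = block_row[k_idx / params.page_block_size];
//     const int block_offset   = k_idx % params.page_block_size;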

// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;

// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
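
// Illustrative sketch (an assumption, mirroring the usual FlashAttention host-side setup):
// p_dropout is stored as the keep probability, and the derived fields follow from it:
//
//     params.p_dropout                = 1.f - dropout_p;                    // keep probability
//     params.p_dropout_in_uint8_t     = uint8_t(std::floor(params.p_dropout * 255.0));
//     params.rp_dropout               = 1.f / params.p_dropout;             // rescale surviving activations
//     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;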

// Random state.
at::PhiloxCudaState philox_args;

// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;

bool is_bf16;
bool is_causal;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
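
// Illustrative sketch (an assumption, restating the comment above): how a kernel could
// resolve the key length of batch `bidb`, with `seqused_k` taking precedence when set:
//
//     int seqlen_k = params.is_seqlens_k_cumulative
//         ? params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]
//         : params.cu_seqlens_k[bidb];
//     if (params.seqused_k != nullptr) { seqlen_k = params.seqused_k[bidb]; }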

bool is_rotary_interleaved;

int num_splits; // For split-KV version

bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv * ngroups), d) to (b, ngroups, nheads_kv, d).
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_bwd_params : public Flash_fwd_params {

// The dO, dQKV and dZeroHold matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
void *__restrict__ dzero_hold_ptr;

// To accumulate dQ, dK, dV and dZeroHold, e.g. when splitting the bwd along the seqlen_q dimension.
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
void *__restrict__ dzero_hold_accum_ptr;

// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
index_t dzero_hold_batch_stride;
index_t dzero_hold_head_stride;
index_t dzero_hold_query_stride;

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;

bool deterministic;
index_t dq_accum_split_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
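
// Illustrative sketch (an assumption): one way the declarations above are typically
// instantiated on the host, e.g. bf16, head_dim 128, causal masking:
//
//     Flash_fwd_params params{};
//     // ... fill in pointers, strides and dimensions ...
//     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
//     run_mha_fwd_<cutlass::bfloat16_t, /*Headdim=*/128, /*Is_causal=*/true>(params, stream);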

} // namespace FLASH_NAMESPACE