21 changes: 11 additions & 10 deletions csrc/src/flash_fwd_kernel.h
@@ -19,6 +19,7 @@
 #include "softmax.h"
 #include "mask.h"
 #include "dropout.h"
+#include "rotary.h"
 
 namespace FLASH_NAMESPACE {
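The only functional change in this hunk is the new `rotary.h` include, which brings rotary-embedding helpers into the forward kernel's translation unit. Those helpers are not shown in this diff; purely for orientation, here is a minimal scalar sketch of the standard RoPE rotation. The function name and signature are illustrative, not this repo's API, and the real helpers operate on CuTe tensors inside the kernel.

```cpp
#include <cmath>

// Illustrative only: standard RoPE rotates each pair (x[2i], x[2i+1]) by
// angle pos * theta_i, with theta_i = base^(-2i / dim). The helpers in
// rotary.h are not shown in this diff and will differ in form.
inline void apply_rope_reference(float* x, int dim, int pos, float base = 10000.f) {
    for (int i = 0; i < dim / 2; ++i) {
        const float theta = std::pow(base, -2.f * static_cast<float>(i) / dim);
        const float c = std::cos(pos * theta), s = std::sin(pos * theta);
        const float x0 = x[2 * i], x1 = x[2 * i + 1];
        x[2 * i]     = x0 * c - x1 * s;  // rotate the even component
        x[2 * i + 1] = x0 * s + x1 * c;  // rotate the odd component
    }
}
```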

@@ -179,7 +180,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
         make_stride(params.seqlen_k_rounded, _1{})
     );
     Tensor mZOH = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element*>(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast<Element*>(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, bidb)),
         make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.zoh_head_stride, params.zoh_row_stride, _1{})
     );
@@ -189,7 +190,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
         make_coord(m_block, _)
     ); // (kBlockM, kBlockN, nblocksN)
     Tensor mActiveMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element*>(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast<Element*>(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, bidb)),
         make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.active_mask_head_stride, params.active_mask_row_stride, _1{})
     );
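Both hunks above make the same change: the BlockInfo offset helpers for the ZOH and active-mask tensors now receive the row stride in addition to the batch stride. That is what a packed variable-length (varlen) layout needs, since there a sequence starts at a row offset inside one packed buffer rather than at a whole-batch offset. A minimal sketch of what such a helper could look like, assuming a FlashAttention-style BlockInfo with a cumulative `sum_s_q` start-row member; the actual implementation in this PR is not shown in the diff.

```cpp
// Hypothetical sketch, modeled on FlashAttention's BlockInfo::q_offset.
// sum_s_q is the cumulative start row of this sequence in the packed varlen
// layout, or -1 when the batch is fixed-length and padded.
template <typename index_t>
__device__ index_t zoh_offset_sketch(const index_t batch_stride,
                                     const index_t row_stride,
                                     const int bidb,
                                     const int sum_s_q) {
    return sum_s_q == -1
        ? bidb * batch_stride                          // padded: step over whole batches
        : static_cast<index_t>(sum_s_q) * row_stride;  // varlen: step over packed rows
}
```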
@@ -240,25 +241,25 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
-    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
+    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
-    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
+    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
-    Tensor tZOHgZOH = gmem_thr_copy_ZOH.partition_S(gZOH); // (ZOHCPY, ZOHCPY_M, ZOHCPY_N, nblocksN)
+    Tensor tZOHgZOH = gmem_thr_copy_ZOH.partition_S(gZOH); // (ZOHCPY, ZOHCPY_M, ZOHCPY_N, nblocksN)
     Tensor tZOHsZOH = gmem_thr_copy_ZOH.partition_D(sZOH);
     Tensor tAMgAM = gmem_thr_copy_AM.partition_S(gActiveMask); // (AMCPY, AMCPY_M, AMCPY_N, nblocksN)
     Tensor tAMsAM = gmem_thr_copy_AM.partition_D(sActiveMask);
 
     // Matrix Multiply Accumulate
     typename Kernel_traits::TiledMma tiled_mma;
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
-    Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
-    Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
+    Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA, MMA_M, MMA_K)
+    Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K)
     Tensor tSrZOH = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
     Tensor tSrAM = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
-    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
+    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K, MMA_N)
     Tensor tSgS = thr_mma.partition_C(gP);
-    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
+    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (MMA, MMA_M, MMA_K)
 
     // Copy Atom retiling
     auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
@@ -383,7 +384,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
     FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
 
     // Init dynamic mask processor
-    FLASH_NAMESPACE::DynamicMask<Is_causal, Kernel_traits::kNThreads> dynamic_mask(
+    FLASH_NAMESPACE::DynamicMask<Is_causal> dynamic_mask(
         binfo.actual_seqlen_k, binfo.actual_seqlen_q,
         params.keep_window_size
     );
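This hunk drops `Kernel_traits::kNThreads` from the mask processor's template parameters, so the type now carries only the causal flag; the sequence lengths and `keep_window_size` stay runtime constructor arguments. A rough sketch of the interface this call site implies follows. Only the template parameter and the constructor arguments are taken from the diff; the member names and the window test are assumptions for illustration.

```cpp
// Hypothetical sketch of the interface implied by the call site above.
template <bool Is_causal>
struct DynamicMaskSketch {
    const int max_seqlen_k, max_seqlen_q;
    const int keep_window_size;  // number of key positions each query row keeps

    __device__ DynamicMaskSketch(const int max_seqlen_k_, const int max_seqlen_q_,
                                 const int keep_window_size_)
        : max_seqlen_k(max_seqlen_k_), max_seqlen_q(max_seqlen_q_),
          keep_window_size(keep_window_size_) {}

    // A row needs dynamic masking only if it can see more keys than the
    // window keeps; under causal masking, row i sees min(i + 1, seqlen_k) keys.
    __device__ bool need_mask(const int row_idx) const {
        const int visible = Is_causal ? min(row_idx + 1, max_seqlen_k) : max_seqlen_k;
        return visible > keep_window_size;
    }
};
```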