diff --git a/csrc/src/mask.h b/csrc/src/mask.h
index 4855f28..80c7449 100644
--- a/csrc/src/mask.h
+++ b/csrc/src/mask.h
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include
 
 #ifndef BLOCK_THREADS
 #define BLOCK_THREADS 128  // Common CUDA thread block size (multiple of 32)
@@ -23,7 +24,37 @@ namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
-// Struct wrapper for dynamic mask application
+// Value-index pair for top-k selection
+template <typename ValueType>
+struct TopKPair {
+    ValueType value;
+    int col_index;
+
+    __device__ __forceinline__ TopKPair() : value(ValueType(-INFINITY)), col_index(-1) {}
+    __device__ __forceinline__ TopKPair(ValueType v, int idx) : value(v), col_index(idx) {}
+
+    __device__ __forceinline__ bool is_valid() const {
+        return col_index >= 0 && isfinite(value);
+    }
+};
+
+// Comparison functor for a descending sort (greater values first)
+template <typename ValueType>
+struct DescendingComparator {
+    __device__ __forceinline__ bool operator()(const TopKPair<ValueType>& a, const TopKPair<ValueType>& b) const {
+        // if (isfinite(a.value) && isfinite(b.value)) {
+        //     return a.value > b.value;
+        // } else if (isfinite(a.value)) {
+        //     return true;   // a is valid, b is not
+        // } else if (isfinite(b.value)) {
+        //     return false;  // b is valid, a is not
+        // } else {
+        //     return a.col_index < b.col_index;  // Compare indices if both are invalid
+        // }
+        return a.value > b.value;  // Descending order
+    }
+};
+
 template <bool Is_causal>
 struct DynamicMask {
     const int max_seqlen_k, max_seqlen_q;
@@ -100,112 +131,92 @@ struct DynamicMask {
             return;
         }
 
-        // Apply top-k selection per row if needed
-        #pragma unroll
+        // Declare shared memory for BlockMergeSort at block scope
+        using BlockMergeSortT = cub::BlockMergeSort<TopKPair<ElementZeroHold>, BlockThreads, ITEMS_PER_THREAD>;
+        __shared__ typename BlockMergeSortT::TempStorage temp_storage;
+        // Process each row with top-k sorting
         for (int mi = 0; mi < size<0, 1>(zero_hold); ++mi) {
             const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-            #pragma unroll
             for (int i = 0; i < size<0, 0>(zero_hold); ++i) {
                 const int row_idx = row_idx_base + i * 8;
-                // Skip if out of bounds
                 if (row_idx >= max_seqlen_q) continue;
-
-                // Temporarily mark all active elements as inactive for selection
-                #pragma unroll
-                for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
-                    #pragma unroll
-                    for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
-                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                        if (active_indices(coord)) {
-                            active_indices(coord) = false;
-                        }
-                    }
-                }
-                __syncthreads();
-                // Shared memory for reduction
-                __shared__ float s_max_vals[BlockThreads];
-                __shared__ int s_max_indices_nj[BlockThreads];
-                __shared__ int s_max_indices_j[BlockThreads];
+                // Step 1: Thread-local storage for collecting current row elements
+                TopKPair<ElementZeroHold> thread_data[ITEMS_PER_THREAD];
 
-                // Iteratively select top-k elements
-                for (int k = 0; k < keep_window_size; ++k) {
-                    float thread_max = -FLT_MAX;
-                    int thread_max_nj = -1;
-                    int thread_max_j = -1;
-
-                    // Each thread finds its local maximum using the same loop structure
-                    #pragma unroll
-                    for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
-                        const int col_idx_base = col_idx_offset + nj * 8;
-                        #pragma unroll
-                        for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
-                            const int col_idx = col_idx_base + j;
-                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-
-                            bool valid = (col_idx < max_seqlen_k) && !(Is_causal && col_idx > row_idx);
-                            float val = static_cast<float>(zero_hold(coord));
-                            if (valid && !active_indices(coord) && !isinf(val) && val > thread_max) {
-                                thread_max = val;
-                                thread_max_nj = nj;
-                                thread_max_j = j;
-                            }
-                        }
-                    }
-
-                    // Store thread-local maximum
-                    s_max_vals[tid] = thread_max;
-                    s_max_indices_nj[tid] = thread_max_nj;
-                    s_max_indices_j[tid] = thread_max_j;
-                    __syncthreads();
+                // Initialize all elements as invalid
+                for (int item = 0; item < ITEMS_PER_THREAD; ++item) {
+                    thread_data[item] = TopKPair<ElementZeroHold>();
+                }
+
+                // Collect data from the current row
+                for (int item = 0; item < ITEMS_PER_THREAD; ++item) {
+                    int global_idx = tid * ITEMS_PER_THREAD + item;
-
-                    // Parallel reduction to find global maximum
-                    for (int stride = BlockThreads / 2; stride > 0; stride >>= 1) {
-                        if (tid < stride) {
-                            if (s_max_vals[tid] < s_max_vals[tid + stride]) {
-                                s_max_vals[tid] = s_max_vals[tid + stride];
-                                s_max_indices_nj[tid] = s_max_indices_nj[tid + stride];
-                                s_max_indices_j[tid] = s_max_indices_j[tid + stride];
+                    if (global_idx < max_seqlen_k) {
+                        // Find the element with column index == global_idx in the current row
+                        for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
+                            const int col_idx_base = col_idx_offset + nj * 8;
+                            for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
+                                const int col_idx = col_idx_base + j;
+                                if (col_idx == global_idx) {
+                                    auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+
+                                    // If active, collect its value and index
+                                    if (active_indices(coord)) {
+                                        ElementZeroHold val = zero_hold(coord);
+                                        thread_data[item] = TopKPair<ElementZeroHold>(val, col_idx);
+                                    }
+                                    break;  // Found the element, no need to continue
+                                }
                             }
                         }
-                        __syncthreads();
-                    }
-
-                    // Mark the selected index as active
-                    if (tid == 0 && s_max_indices_nj[0] >= 0 && s_max_indices_j[0] >= 0) {
-                        auto coord = make_coord(make_coord(i, mi), make_coord(s_max_indices_j[0], s_max_indices_nj[0]));
-                        active_indices(coord) = true;
-                    }
-                    __syncthreads();
-
-                    // Early exit if no more valid elements
-                    if (s_max_vals[0] == -FLT_MAX) {
-                        break;
                     }
                 }
-
-                // Clear non-selected values using the same loop structure
-                #pragma unroll
+                // Step 2: Block-wide collaborative sorting with an explicit comparator
+                DescendingComparator<ElementZeroHold> comp;
+                BlockMergeSortT(temp_storage).Sort(thread_data, comp);
+                __syncthreads();  // Ensure sorting is complete
+
+                // Step 3: Update active_indices, keeping only the top-k columns
+                // Traverse each coordinate and check whether its col_idx is in the top-k
                 for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
-                    #pragma unroll
+                    const int col_idx_base = col_idx_offset + nj * 8;
                     for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
+                        const int col_idx = col_idx_base + j;
                         auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                        if (!active_indices(coord)) {
-                            zero_hold(coord) = ElementZeroHold(-INFINITY);
+
+                        // If the current position is active, check whether it is in the top-k
+                        if (active_indices(coord)) {
+                            // Check if this element is in the thread's own top-k data
+                            bool is_in_topk = false;
+
+                            for (int item = 0; item < ITEMS_PER_THREAD; ++item) {
+                                // Global position in sorted order
+                                int global_pos = tid * ITEMS_PER_THREAD + item;
+
+                                // Only elements with global_pos < keep_window_size are in the top-k
+                                if (global_pos < keep_window_size &&
+                                    thread_data[item].col_index == col_idx) {
+                                    is_in_topk = true;
+                                    break;
+                                }
+                            }
+
+                            // If not in the top-k, mark as inactive
+                            if (!is_in_topk) {
+                                active_indices(coord) = false;
+                            }
                         }
                     }
                 }
-                __syncthreads();
-
+                __syncthreads();  // Ensure row processing is complete
             }
         }
     }
-
-    template <
-        bool Causal_mask=false, bool Is_even_MN=true,
-        typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2
-    >
+    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
     __forceinline__ __device__ void apply_mask(
         Tensor<Engine0, Layout0> &tensor_,      // acc_s (attention scores, 3D)
         Tensor<Engine1, Layout1> &tZeroHold,    // Zero-hold states (3D)
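
As a sanity check for the new code path, the block-wide top-k pattern it relies on (a cub::BlockMergeSort over value/column pairs with a descending comparator, then keeping only the first keep_window_size ranks of the sorted blocked arrangement) can be exercised in isolation. The sketch below is not part of the patch; ScoredCol, GreaterByValue, topk_row_kernel, and the assumption that one block handles one row of at most BLOCK_THREADS * ITEMS_PER_THREAD columns are illustrative choices only.

// Minimal standalone sketch, assuming CUB >= 1.10 for cub::BlockMergeSort.
#include <cub/block/block_merge_sort.cuh>
#include <math.h>

// Hypothetical value/column pair, mirroring the TopKPair idea above.
struct ScoredCol {
    float value;
    int   col;
};

// Larger values sort first, so ranks [0, k) after the sort are the top-k.
struct GreaterByValue {
    __device__ __forceinline__ bool operator()(const ScoredCol &a, const ScoredCol &b) const {
        return a.value > b.value;
    }
};

// One thread block selects the top-k columns of a single row of `scores`
// (length n <= BLOCK_THREADS_ * ITEMS_PER_THREAD_) and marks them in `keep`,
// which the caller is assumed to have zero-initialized.
template <int BLOCK_THREADS_, int ITEMS_PER_THREAD_>
__global__ void topk_row_kernel(const float *scores, int *keep, int n, int k) {
    using BlockSort = cub::BlockMergeSort<ScoredCol, BLOCK_THREADS_, ITEMS_PER_THREAD_>;
    __shared__ typename BlockSort::TempStorage temp_storage;

    // Blocked arrangement: thread t owns elements [t*ITEMS_PER_THREAD_, (t+1)*ITEMS_PER_THREAD_).
    ScoredCol items[ITEMS_PER_THREAD_];
    for (int i = 0; i < ITEMS_PER_THREAD_; ++i) {
        int idx = threadIdx.x * ITEMS_PER_THREAD_ + i;
        items[i].value = (idx < n) ? scores[idx] : -INFINITY;  // pad out-of-range slots
        items[i].col   = (idx < n) ? idx : -1;
    }

    // Block-wide merge sort; afterwards the blocked arrangement is globally
    // sorted, so items[i] in thread t has global rank t*ITEMS_PER_THREAD_ + i.
    BlockSort(temp_storage).Sort(items, GreaterByValue());
    __syncthreads();

    for (int i = 0; i < ITEMS_PER_THREAD_; ++i) {
        int rank = threadIdx.x * ITEMS_PER_THREAD_ + i;
        if (rank < k && items[i].col >= 0) {
            keep[items[i].col] = 1;  // this column survives the top-k mask
        }
    }
}

// Example launch: topk_row_kernel<128, 8><<<1, 128>>>(d_scores, d_keep, n, keep_window_size);

The trailing underscores on the template parameters are only there to avoid colliding with the BLOCK_THREADS macro defined at the top of mask.h if the snippet is compiled in the same translation unit.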