csrc/src/mask.h (179 changes: 95 additions & 84 deletions)
@@ -10,6 +10,7 @@
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/fast_math.h>
#include <cub/block/block_merge_sort.cuh>

#ifndef BLOCK_THREADS
#define BLOCK_THREADS 128 // Common CUDA thread block size (multiple of 32)
@@ -23,7 +24,37 @@ namespace FLASH_NAMESPACE {

using namespace cute;

// Value-Index pair for top-k selection
template<typename ValueType>
struct TopKPair {
    ValueType value;
    int col_index;

    __device__ __forceinline__ TopKPair() : value(ValueType(-INFINITY)), col_index(-1) {}
    __device__ __forceinline__ TopKPair(ValueType v, int idx) : value(v), col_index(idx) {}

    __device__ __forceinline__ bool is_valid() const {
        // Cast to float so isfinite also works for half-precision value types
        return col_index >= 0 && isfinite(static_cast<float>(value));
    }
};

// Comparison functor for descending sort (greater values first)
template<typename ValueType>
struct DescendingComparator {
    __device__ __forceinline__ bool operator()(const TopKPair<ValueType>& a, const TopKPair<ValueType>& b) const {
        // Default-constructed (invalid) pairs carry -INFINITY, so they naturally
        // sink to the back of a descending sort; inputs are assumed NaN-free.
        return a.value > b.value; // Descending order
    }
};
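
// Illustrative sketch: a standalone kernel showing the same pad / block-sort /
// rank-select pattern that DynamicMask applies below. The kernel name and the
// EXAMPLE_THREADS / EXAMPLE_ITEMS parameters are hypothetical and exist only for
// this example; it assumes a single thread block whose EXAMPLE_THREADS * EXAMPLE_ITEMS
// slots cover the n input scores, padding the tail with the -INFINITY sentinel.
template <int EXAMPLE_THREADS = 128, int EXAMPLE_ITEMS = 4>
__global__ void block_topk_sketch(const float* scores, int n, int k, int* topk_cols) {
    using Sorter = cub::BlockMergeSort<TopKPair<float>, EXAMPLE_THREADS, EXAMPLE_ITEMS>;
    __shared__ typename Sorter::TempStorage temp_storage;

    // Each thread owns a blocked slice of EXAMPLE_ITEMS consecutive slots.
    TopKPair<float> items[EXAMPLE_ITEMS];
    for (int it = 0; it < EXAMPLE_ITEMS; ++it) {
        int idx = threadIdx.x * EXAMPLE_ITEMS + it;
        items[it] = (idx < n) ? TopKPair<float>(scores[idx], idx)
                              : TopKPair<float>(); // -INFINITY sentinel pads the tail
    }

    // Block-wide descending sort: afterwards, thread t holds the globally
    // sorted ranks [t * EXAMPLE_ITEMS, (t + 1) * EXAMPLE_ITEMS).
    Sorter(temp_storage).Sort(items, DescendingComparator<float>{});
    __syncthreads();

    // The first k ranks are the top-k; each thread writes out the ones it holds.
    for (int it = 0; it < EXAMPLE_ITEMS; ++it) {
        int rank = threadIdx.x * EXAMPLE_ITEMS + it;
        if (rank < k && items[it].is_valid()) topk_cols[rank] = items[it].col_index;
    }
}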

template <bool Is_causal, int BlockThreads>
struct DynamicMask {
    const int max_seqlen_k, max_seqlen_q;
@@ -100,112 +131,92 @@
            return;
        }

        // Declare shared memory for BlockMergeSort at block scope
        using BlockMergeSortT = cub::BlockMergeSort<TopKPair<ElementZeroHold>, BlockThreads, ITEMS_PER_THREAD>;
        __shared__ typename BlockMergeSortT::TempStorage temp_storage;

        // Process each row with top-k sorting
        for (int mi = 0; mi < size<0, 1>(zero_hold); ++mi) {
            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
            #pragma unroll
            for (int i = 0; i < size<0, 0>(zero_hold); ++i) {
                const int row_idx = row_idx_base + i * 8;
                // Skip rows that fall outside the query sequence
                if (row_idx >= max_seqlen_q) continue;

                // Step 1: thread-local storage for collecting the current row's elements
                TopKPair<ElementZeroHold> thread_data[ITEMS_PER_THREAD];

                // Initialize all elements as invalid
                for (int item = 0; item < ITEMS_PER_THREAD; ++item) {
                    thread_data[item] = TopKPair<ElementZeroHold>();
                }

                // Collect data from the current row: thread tid gathers the columns whose
                // global index falls in [tid * ITEMS_PER_THREAD, (tid + 1) * ITEMS_PER_THREAD)
                for (int item = 0; item < ITEMS_PER_THREAD; ++item) {
                    int global_idx = tid * ITEMS_PER_THREAD + item;
                    if (global_idx < max_seqlen_k) {
                        // Find the element with column index == global_idx in the current row
                        for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
                            const int col_idx_base = col_idx_offset + nj * 8;
                            for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
                                const int col_idx = col_idx_base + j;
                                if (col_idx == global_idx) {
                                    auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
                                    // If active, collect its value and column index
                                    if (active_indices(coord)) {
                                        ElementZeroHold val = zero_hold(coord);
                                        thread_data[item] = TopKPair<ElementZeroHold>(val, col_idx);
                                    }
                                    break; // Found the element, no need to continue
                                }
                            }
                        }
                    }
                }

                // Step 2: block-wide collaborative sorting with the explicit comparator
                DescendingComparator<ElementZeroHold> comp;
                BlockMergeSortT(temp_storage).Sort(thread_data, comp);
                __syncthreads(); // Ensure sorting is complete
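                // After Sort(), pairs sit in a blocked arrangement: thread t holds the
                // globally sorted ranks [t * ITEMS_PER_THREAD, (t + 1) * ITEMS_PER_THREAD),
                // so an element is in the top-k exactly when its rank is below keep_window_size.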

                // Step 3: update active_indices, keeping only the top-k columns.
                // Traverse each coordinate and check whether its col_idx is in the top-k.
                for (int nj = 0; nj < size<1, 1>(zero_hold); ++nj) {
                    const int col_idx_base = col_idx_offset + nj * 8;
                    for (int j = 0; j < size<1, 0>(zero_hold); ++j) {
                        const int col_idx = col_idx_base + j;
                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));

                        // If the current position is active, check whether it is in the top-k
                        if (active_indices(coord)) {
                            // Check whether this element appears in this thread's sorted slice
                            bool is_in_topk = false;
                            for (int item = 0; item < ITEMS_PER_THREAD; ++item) {
                                // Global position in sorted order
                                int global_pos = tid * ITEMS_PER_THREAD + item;
                                // Only elements with global_pos < keep_window_size are top-k
                                if (global_pos < keep_window_size &&
                                    thread_data[item].col_index == col_idx) {
                                    is_in_topk = true;
                                    break;
                                }
                            }
                            // If not in the top-k, deactivate the position
                            if (!is_in_topk) {
                                active_indices(coord) = false;
                            }
                        }
                    }
                }

                __syncthreads(); // Ensure row processing is complete before temp_storage is reused
            }
        }
    }
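
    // Hypothetical launch of the block_topk_sketch example above: one 128-thread
    // block covers 128 * 4 = 512 score slots and selects the top 8 columns:
    //
    //   block_topk_sketch<128, 4><<<1, 128>>>(d_scores, /*n=*/512, /*k=*/8, d_topk_cols);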


template <
    bool Causal_mask=false, bool Is_even_MN=true,
    typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2
>
__forceinline__ __device__ void apply_mask(
    Tensor<Engine0, Layout0> &tensor_,   // acc_s (attention scores, 3D)
    Tensor<Engine1, Layout1> &tZeroHold, // Zero-hold states (3D)