Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 47 additions & 10 deletions csrc/src/mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,59 @@

#pragma once
#include "namespace_config.h"
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/fast_math.h>
#include <cub/block/block_merge_sort.cuh>

#ifndef ITEMS_PER_THREAD
#define ITEMS_PER_THREAD 32
#endif
#include <cute/tensor.hpp>

namespace FLASH_NAMESPACE {

using namespace cute;

// Applies attention masking, softmax scaling, and an additive bias to a
// per-thread fragment of attention scores, in registers.
//
// `tensor` holds scores laid out as (nrow=(2, MMA_M), ncol=(2, MMA_N)),
// the fragment shape produced by an m16n8k16-style MMA: each thread of a
// warp owns a 2x2-per-tile scatter of elements, which is why rows step by
// 8 and columns by (lane_id % 4) * 2 below.
//
// For every element:
//   - out of range (col >= limit) or masked out (Mask == 0) -> set to -INFINITY
//     so it contributes nothing after softmax;
//   - otherwise -> score * scale_softmax + Bias.
//
// Template parameters:
//   Causal_mask  - when true, additionally enforces the causal (lower
//                  triangular) constraint, aligned to the bottom-right of
//                  the (max_seqlen_q x max_seqlen_k) score matrix.
//
// Arguments:
//   tensor           - score fragment, modified in place (rank-2 CuTe tensor)
//   Mask             - multiplicative mask fragment; 0.0f marks inactive elements
//   Bias             - additive bias fragment, same layout as tensor
//   scale_softmax    - scale applied to surviving scores before the bias
//   col_idx_offset_  - global column index of this tile's first column
//   max_seqlen_k     - number of key columns (upper bound on valid columns)
//   row_idx_offset   - global row index of this thread's first row
//   max_seqlen_q     - number of query rows (used for causal alignment)
//   warp_row_stride  - row distance between successive MMA_M tiles
template <bool Causal_mask=false, typename TensorType, typename MaskType, typename BiasType>
__forceinline__ __device__ void apply_mask(
    TensorType &tensor,
    MaskType &Mask,
    BiasType &Bias,
    const float scale_softmax,
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset,
    const int max_seqlen_q,
    const int warp_row_stride
) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(TensorType::rank == 2, "Only support 2D Tensor");
    static_assert(MaskType::rank == 2, "Only support 2D Mask");
    static_assert(BiasType::rank == 2, "Only support 2D Bias");
    const int lane_id = threadIdx.x % 32;
    // Each quad of lanes covers adjacent column pairs within the MMA tile.
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int m_tile = 0; m_tile < size<0, 1>(tensor); ++m_tile) {
        const int row_idx_base = row_idx_offset + m_tile * warp_row_stride;
        #pragma unroll
        for (int r = 0; r < size<0, 0>(tensor); ++r) {
            const int row_idx = row_idx_base + r * 8;
            // Causal limit is bottom-right aligned: row i may attend up to
            // column i + 1 + (max_seqlen_k - max_seqlen_q), clamped to the
            // key length. Non-causal keeps the full key range.
            const int col_idx_limit = Causal_mask
                ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q)
                : max_seqlen_k;
            #pragma unroll
            for (int n_tile = 0; n_tile < size<1, 1>(tensor); ++n_tile) {
                const int col_idx_base = col_idx_offset + n_tile * 8;
                #pragma unroll
                for (int c = 0; c < size<1, 0>(tensor); ++c) {
                    const int col_idx = col_idx_base + c;
                    // Without the "make_coord" we get wrong results
                    auto coord = make_coord(make_coord(r, m_tile), make_coord(c, n_tile));
                    const bool active = (col_idx < col_idx_limit) && (Mask(coord) != 0.0f);
                    // Inactive positions are forced to -inf so softmax zeroes
                    // them; active positions get scaling plus the bias term.
                    tensor(coord) = active
                        ? tensor(coord) * scale_softmax + Bias(coord)
                        : -INFINITY;
                }
            }
        }
    }
}

template <bool Is_causal>
struct Mask {
const int max_seqlen_k, max_seqlen_q;
Expand Down