From 1eb6c9a2f42a9ffdcc086a546f45062bf81e00f8 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 21 Jul 2025 12:29:22 +0800 Subject: [PATCH] Adds unified mask application function with causal support Introduces a template function that consolidates mask, bias, and scaling operations into a single kernel call. Supports both causal and non-causal masking patterns through template specialization. Removes unnecessary header includes and preprocessor definitions to streamline dependencies and improve compilation times. Applies proper coordinate transformations and bounds checking to ensure correct tensor indexing across different matrix dimensions. --- csrc/src/mask.h | 57 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/csrc/src/mask.h b/csrc/src/mask.h index a63c322..166e765 100644 --- a/csrc/src/mask.h +++ b/csrc/src/mask.h @@ -4,22 +4,59 @@ #pragma once #include "namespace_config.h" -#include -#include -#include -#include -#include -#include -#include -#ifndef ITEMS_PER_THREAD -#define ITEMS_PER_THREAD 32 -#endif +#include namespace FLASH_NAMESPACE { using namespace cute; +template +__forceinline__ __device__ void apply_mask( + TensorType &tensor, + MaskType &Mask, + BiasType &Bias, + const float scale_softmax, + const int col_idx_offset_, + const int max_seqlen_k, + const int row_idx_offset, + const int max_seqlen_q, + const int warp_row_stride +) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(TensorType::rank == 2, "Only support 2D Tensor"); + static_assert(MaskType::rank == 2, "Only support 2D Mask"); + static_assert(BiasType::rank == 2, "Only support 2D Bias"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + // Without the "make_coord" we get wrong results + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + bool inactive = (col_idx >= col_idx_limit) || (Mask(coord) == 0.0f); + if (inactive) { + tensor(coord) = -INFINITY; + } else { + // Apply scaling and bias + tensor(coord) = tensor(coord) * scale_softmax + Bias(coord); + } + } + } + } + } +} + template struct Mask { const int max_seqlen_k, max_seqlen_q;