diff --git a/csrc/src/mask.h b/csrc/src/mask.h
index a63c322..166e765 100644
--- a/csrc/src/mask.h
+++ b/csrc/src/mask.h
@@ -4,22 +4,59 @@
 #pragma once
 
 #include "namespace_config.h"
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#ifndef ITEMS_PER_THREAD
-#define ITEMS_PER_THREAD 32
-#endif
+#include <cute/tensor.hpp>
 
 namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
+template <bool Causal_mask=false, typename TensorType, typename MaskType, typename BiasType>
+__forceinline__ __device__ void apply_mask(
+    TensorType &tensor,
+    MaskType &Mask,
+    BiasType &Bias,
+    const float scale_softmax,
+    const int col_idx_offset_,
+    const int max_seqlen_k,
+    const int row_idx_offset,
+    const int max_seqlen_q,
+    const int warp_row_stride
+) {
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    static_assert(TensorType::rank == 2, "Only support 2D Tensor");
+    static_assert(MaskType::rank == 2, "Only support 2D Mask");
+    static_assert(BiasType::rank == 2, "Only support 2D Bias");
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    #pragma unroll
+    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        #pragma unroll
+        for (int i = 0; i < size<0, 0>(tensor); ++i) {
+            const int row_idx = row_idx_base + i * 8;
+            const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+            #pragma unroll
+            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                const int col_idx_base = col_idx_offset + nj * 8;
+                #pragma unroll
+                for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                    const int col_idx = col_idx_base + j;
+                    // Without the "make_coord" we get wrong results
+                    auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                    bool inactive = (col_idx >= col_idx_limit) || (Mask(coord) == 0.0f);
+                    if (inactive) {
+                        tensor(coord) = -INFINITY;
+                    } else {
+                        // Apply scaling and bias
+                        tensor(coord) = tensor(coord) * scale_softmax + Bias(coord);
+                    }
+                }
+            }
+        }
+    }
+}
+
 template
 struct Mask {
     const int max_seqlen_k, max_seqlen_q;