@LoserCheems
Collaborator

Introduces a template function that consolidates mask, bias, and scaling operations into a single kernel call. Supports both causal and non-causal masking patterns through template specialization.

Removes unnecessary header includes and preprocessor definitions to streamline dependencies and improve compilation times.

Applies proper coordinate transformations and bounds checking to ensure correct tensor indexing across different matrix dimensions.

Contributor

Copilot AI left a comment

Pull Request Overview

This PR introduces a unified template function for applying masks, bias, and scaling operations in a single kernel call, with support for both causal and non-causal masking patterns. The changes streamline the codebase by removing unnecessary header includes and preprocessor definitions while consolidating mask application logic.

  • Adds a new apply_mask template function with causal masking support through template specialization
  • Removes unnecessary CUTLASS headers and preprocessor definitions to reduce compilation dependencies
  • Implements proper coordinate transformations for 2D tensor indexing (compared to existing 3D implementation)
Comments suppressed due to low confidence (2)

csrc/src/mask.h:47

  • The parameter name 'Mask' (capitalized) is inconsistent with typical C++ naming conventions and conflicts with the existing 'Mask' struct name. Consider renaming to 'mask' (lowercase) for consistency with the existing codebase pattern.
                    bool inactive = (col_idx >= col_idx_limit) || (Mask(coord) == 0.0f);

csrc/src/mask.h:52

  • The parameter name 'Bias' (capitalized) is inconsistent with typical C++ naming conventions. Consider renaming to 'bias' (lowercase) for consistency with the existing codebase pattern shown in the context.
                        tensor(coord) = tensor(coord) * scale_softmax + Bias(coord);
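The two quoted lines describe the per-element rule of the fused operation: an element is inactive when its column lies past the limit or its mask value is zero, and otherwise it is scaled and offset by the bias. A minimal host-side sketch of that rule on row-major matrices might look as follows. The name `apply_mask_reference`, the flat `std::vector` storage, and the choice of `-INFINITY` for inactive elements are assumptions for illustration; the PR's actual kernel works on per-thread register fragments and the excerpt does not show what value it writes to inactive elements.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference, not the PR's kernel: applies mask, bias and
// scaling in one pass over a row-major seqlen_q x seqlen_k score matrix.
template <bool Causal_mask>
void apply_mask_reference(std::vector<float>& scores,
                          const std::vector<float>& mask,
                          const std::vector<float>& bias,
                          int seqlen_q, int seqlen_k, float scale_softmax) {
    for (int row = 0; row < seqlen_q; ++row) {
        // Same column-limit expression as in the reviewed code.
        const int col_limit = Causal_mask
            ? std::min(seqlen_k, row + 1 + seqlen_k - seqlen_q)
            : seqlen_k;
        for (int col = 0; col < seqlen_k; ++col) {
            const std::size_t idx = static_cast<std::size_t>(row) * seqlen_k + col;
            const bool inactive = (col >= col_limit) || (mask[idx] == 0.0f);
            // Assumption: inactive elements become -inf so softmax ignores them.
            scores[idx] = inactive ? -INFINITY
                                   : scores[idx] * scale_softmax + bias[idx];
        }
    }
}
```

A reference like this is mainly useful for checking a GPU kernel's output element by element on small inputs.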

Comment on lines +27 to +33
static_assert(TensorType::rank == 2, "Only support 2D Tensor");
static_assert(MaskType::rank == 2, "Only support 2D Mask");
static_assert(BiasType::rank == 2, "Only support 2D Bias");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
Copilot AI Jul 21, 2025

The new apply_mask function only supports 2D tensors, but the existing Mask::apply_mask method in the codebase works with 3D tensors (rank == 3). This inconsistency could cause confusion and limit interoperability between the two implementations.

Suggested change
- static_assert(TensorType::rank == 2, "Only support 2D Tensor");
- static_assert(MaskType::rank == 2, "Only support 2D Mask");
- static_assert(BiasType::rank == 2, "Only support 2D Bias");
- const int lane_id = threadIdx.x % 32;
- const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
- #pragma unroll
- for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+ static_assert(TensorType::rank == 2 || TensorType::rank == 3, "Only support 2D or 3D Tensor");
+ static_assert(MaskType::rank == 2 || MaskType::rank == 3, "Only support 2D or 3D Mask");
+ static_assert(BiasType::rank == 2 || BiasType::rank == 3, "Only support 2D or 3D Bias");
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+ #pragma unroll
+ for (int bi = 0; bi < (TensorType::rank == 3 ? size<0, 2>(tensor) : 1); ++bi) { // Handle batch dimension for 3D tensors
+ const int batch_idx = (TensorType::rank == 3 ? bi : 0);
+ #pragma unroll
+ for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
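The `(lane_id % 4) * 2` term in the snippet under review is easier to follow with the thread-to-element mapping spelled out. The sketch below assumes the accumulator fragment layout of an m16n8 warp-level MMA tile, in which each of the 32 lanes holds two adjacent columns in two rows that are 8 apart. The excerpt does not state which layout the kernel uses, so this layout and the names `Coord`, `fragment_coord` and `tile_coverage` are hypothetical. The code runs on the host and only checks that the mapping covers one 16x8 tile exactly once.

```cpp
#include <array>
#include <cassert>

struct Coord { int row; int col; };

// Assumed m16n8 accumulator layout: lane_id / 4 selects the row within an
// 8-row half, i selects the upper or lower half, (lane_id % 4) * 2 selects a
// pair of adjacent columns, and j selects the column within that pair.
constexpr Coord fragment_coord(int lane_id, int i, int j) {
    return Coord{lane_id / 4 + i * 8, (lane_id % 4) * 2 + j};
}

// Counts how many (lane, i, j) triples land on each cell of a 16x8 tile.
inline std::array<int, 16 * 8> tile_coverage() {
    std::array<int, 16 * 8> hits{};
    for (int lane = 0; lane < 32; ++lane) {
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 2; ++j) {
                const Coord c = fragment_coord(lane, i, j);
                ++hits[c.row * 8 + c.col];
            }
        }
    }
    return hits;
}
```

Under this assumption, every cell of the tile is owned by exactly one lane, which is why a per-thread loop over fragment elements can mask the whole tile without any cross-thread coordination.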

#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
Copilot AI Jul 21, 2025

The causal mask calculation 'row_idx + 1 + max_seqlen_k - max_seqlen_q' contains magic numbers and complex logic that should be extracted into a well-named helper function or have explanatory comments describing the mathematical relationship.

Suggested change
- const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+ const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, compute_causal_col_idx_limit(row_idx, max_seqlen_k, max_seqlen_q)) : max_seqlen_k;
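The suggestion names `compute_causal_col_idx_limit` but does not define it. One possible definition, shown here as a sketch, simply wraps the original expression. The term `max_seqlen_k - max_seqlen_q` shifts the causal diagonal so that the last query row can see every key column (a bottom-right aligned mask), and `row_idx + 1` is the exclusive upper bound that lets a row attend to its own position. The wrapper `col_idx_limit` is a hypothetical name added only to make the expression testable on the host.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical body for the helper named in the suggestion: the first column
// index that row_idx may NOT attend to under a bottom-right aligned causal mask.
constexpr int compute_causal_col_idx_limit(int row_idx, int max_seqlen_k,
                                           int max_seqlen_q) {
    return row_idx + 1 + max_seqlen_k - max_seqlen_q;
}

// Mirrors the reviewed line: clamp to max_seqlen_k when causal, otherwise
// allow every key column.
template <bool Causal_mask>
int col_idx_limit(int row_idx, int max_seqlen_k, int max_seqlen_q) {
    return Causal_mask
        ? std::min(max_seqlen_k,
                   compute_causal_col_idx_limit(row_idx, max_seqlen_k, max_seqlen_q))
        : max_seqlen_k;
}
```

With 2 queries and 5 keys, row 0 gets a limit of 4 and row 1 a limit of 5. When there are more queries than keys the limit for early rows is zero or negative, so every column in those rows is masked.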

@LoserCheems LoserCheems merged commit a51b716 into main Jul 21, 2025