-
Notifications
You must be signed in to change notification settings - Fork 39
Adds unified mask application function with causal support #71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Introduces a template function that consolidates mask, bias, and scaling operations into a single kernel call. Supports both causal and non-causal masking patterns through template specialization. Removes unnecessary header includes and preprocessor definitions to streamline dependencies and improve compilation times. Applies proper coordinate transformations and bounds checking to ensure correct tensor indexing across different matrix dimensions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a unified template function for applying masks, bias, and scaling operations in a single kernel call, with support for both causal and non-causal masking patterns. The changes streamline the codebase by removing unnecessary header includes and preprocessor definitions while consolidating mask application logic.
- Adds a new
apply_masktemplate function with causal masking support through template specialization - Removes unnecessary CUTLASS headers and preprocessor definitions to reduce compilation dependencies
- Implements proper coordinate transformations for 2D tensor indexing (compared to existing 3D implementation)
Comments suppressed due to low confidence (2)
csrc/src/mask.h:47
- The parameter name 'Mask' (capitalized) is inconsistent with typical C++ naming conventions and conflicts with the existing 'Mask' struct name. Consider renaming to 'mask' (lowercase) for consistency with the existing codebase pattern.
bool inactive = (col_idx >= col_idx_limit) || (Mask(coord) == 0.0f);
csrc/src/mask.h:52
- The parameter name 'Bias' (capitalized) is inconsistent with typical C++ naming conventions. Consider renaming to 'bias' (lowercase) for consistency with the existing codebase pattern shown in the context.
tensor(coord) = tensor(coord) * scale_softmax + Bias(coord);
| static_assert(TensorType::rank == 2, "Only support 2D Tensor"); | ||
| static_assert(MaskType::rank == 2, "Only support 2D Mask"); | ||
| static_assert(BiasType::rank == 2, "Only support 2D Bias"); | ||
| const int lane_id = threadIdx.x % 32; | ||
| const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; | ||
| #pragma unroll | ||
| for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { |
Copilot
AI
Jul 21, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new apply_mask function only supports 2D tensors, but the existing Mask::apply_mask method in the codebase works with 3D tensors (rank == 3). This inconsistency could cause confusion and limit interoperability between the two implementations.
| static_assert(TensorType::rank == 2, "Only support 2D Tensor"); | |
| static_assert(MaskType::rank == 2, "Only support 2D Mask"); | |
| static_assert(BiasType::rank == 2, "Only support 2D Bias"); | |
| const int lane_id = threadIdx.x % 32; | |
| const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; | |
| #pragma unroll | |
| for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { | |
| static_assert(TensorType::rank == 2 || TensorType::rank == 3, "Only support 2D or 3D Tensor"); | |
| static_assert(MaskType::rank == 2 || MaskType::rank == 3, "Only support 2D or 3D Mask"); | |
| static_assert(BiasType::rank == 2 || BiasType::rank == 3, "Only support 2D or 3D Bias"); | |
| const int lane_id = threadIdx.x % 32; | |
| const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; | |
| #pragma unroll | |
| for (int bi = 0; bi < (TensorType::rank == 3 ? size<0, 2>(tensor) : 1); ++bi) { // Handle batch dimension for 3D tensors | |
| const int batch_idx = (TensorType::rank == 3 ? bi : 0); | |
| #pragma unroll | |
| for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { |
| #pragma unroll | ||
| for (int i = 0; i < size<0, 0>(tensor); ++i) { | ||
| const int row_idx = row_idx_base + i * 8; | ||
| const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; |
Copilot
AI
Jul 21, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The causal mask calculation 'row_idx + 1 + max_seqlen_k - max_seqlen_q' contains magic numbers and complex logic that should be extracted into a well-named helper function or have explanatory comments describing the mathematical relationship.
| const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; | |
| const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, compute_causal_col_idx_limit(row_idx, max_seqlen_k, max_seqlen_q)) : max_seqlen_k; |
Introduces a template function that consolidates mask, bias, and scaling operations into a single kernel call. Supports both causal and non-causal masking patterns through template specialization.
Removes unnecessary header includes and preprocessor definitions to streamline dependencies and improve compilation times.
Applies proper coordinate transformations and bounds checking to ensure correct tensor indexing across different matrix dimensions.