Skip to content

Conversation

@nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Dec 9, 2025

This PR adds unrolling for vector.constant_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as:
min(max(originalMaskSize[d] - offset[d], 0), unrolledMaskSize[d]).

@llvmbot
Copy link
Member

llvmbot commented Dec 9, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/171518.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+3-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+87-2)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+17)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+7-5)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d8ed46c2820fe..3d76f3d7fbc46 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2547,7 +2547,9 @@ def Vector_TypeCastOp :
 }
 
 def Vector_ConstantMaskOp :
-  Vector_Op<"constant_mask", [Pure]>,
+  Vector_Op<"constant_mask", [Pure, 
+   DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+   ]>,
     Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
     Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a constant vector mask";
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 462bd8c3dc4a6..81e7d76eefcfb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1094,6 +1094,91 @@ struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// This pattern unrolls `vector.constant_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// whether its elements should be masked based on the original mask dimensions
+/// and the slice's offset position.
+///
+/// Example:
+///   Given a constant_mask operation:
+///     %0 = vector.constant_mask [6, 10] : vector<8x16xi1>  // mask first 6x10
+///     elements
+///
+///   and a target unroll shape of <4x8>, the pattern produces:
+///
+///     %false = arith.constant dense<false> : vector<8x16xi1>
+///
+///     Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
+///     %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
+///     %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///
+///     Slice [0,8]: elements [0:4, 8:16] - partially within bounds
+///     %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
+///     %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///
+///     Slice [4,0]: elements [4:8, 0:8] - partially within bounds
+///     %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
+///     %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///
+///     Slice [4,8]: elements [4:8, 8:16] - partially within bounds
+///     %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
+///     %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollConstantMaskPattern
+    : public OpRewritePattern<vector::ConstantMaskOp> {
+  UnrollConstantMaskPattern(MLIRContext *context,
+                            const vector::UnrollVectorOptions &options,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, constantMaskOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType resultType = constantMaskOp.getVectorType();
+    SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
+    Location loc = constantMaskOp.getLoc();
+
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    VectorType targetVectorType =
+        VectorType::get(*targetShape, rewriter.getI1Type());
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+
+    // In each dimension (d), each unrolled vector computes its mask size as:
+    // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
+      SmallVector<int64_t> unrolledMaskDims;
+
+      for (auto [i, originalMaskDim] :
+           llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
+        // Calculate how many elements in this dimension should be masked
+        // for this particular slice
+        int64_t adjustedMaskSize = std::max(originalMaskDim - offsets[i], 0L);
+        int64_t unrolledMaskDim = std::min(adjustedMaskSize, (*targetShape)[i]);
+        unrolledMaskDims.push_back(unrolledMaskDim);
+      }
+
+      auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
+          loc, targetVectorType, unrolledMaskDims);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, unrolledMask, result, offsets, strides);
+    }
+    rewriter.replaceOp(constantMaskOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 /// Checks whether extractShape is a contiguous slice of shape.
 /// For extractShape to be contiguous in shape:
 /// 1) All but the leading dimension of extractShape and shape must match
@@ -1294,8 +1379,8 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
                UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
-               UnrollCreateMaskPattern>(patterns.getContext(), options,
-                                        benefit);
+               UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
+      patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 805e66f133c59..c2e7f6a9338b1 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -552,6 +552,23 @@ func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
 // CHECK:   %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
 // CHECK:   return %[[S3]] : vector<16x16xi1>
 
+func.func @vector_constant_mask() -> vector<16x16xi1> {
+  %0 = vector.constant_mask [12, 10] : vector<16x16xi1>
+  return %0 : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_constant_mask
+// CHECK-SAME: () -> vector<16x16xi1>
+//       CHECK:   %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+//       CHECK:   %[[CST_TRUE:.*]] = arith.constant dense<true> : vector<8x8xi1>
+//       CHECK:   %[[INS00:.*]] = vector.insert_strided_slice %[[CST_TRUE]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[MASK01:.*]] = vector.constant_mask [8, 2] : vector<8x8xi1>
+//       CHECK:   %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[MASK10:.*]] = vector.constant_mask [4, 8] : vector<8x8xi1>
+//       CHECK:   %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[MASK11:.*]] = vector.constant_mask [4, 2] : vector<8x8xi1>
+//       CHECK:   %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   return %[[INS11]] : vector<16x16xi1>
 
 func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
   %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f834d0cdd42bd..2cbb5ab3067f2 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -179,11 +179,13 @@ struct TestVectorUnrollingPatterns
                                        return success(isa<vector::StepOp>(op));
                                      }));
     populateVectorUnrollPatterns(
-        patterns, UnrollVectorOptions()
-                      .setNativeShape(ArrayRef<int64_t>{8, 8})
-                      .setFilterConstraint([](Operation *op) {
-                        return success(isa<vector::CreateMaskOp>(op));
-                      }));
+        patterns,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{8, 8})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<vector::CreateMaskOp, vector::ConstantMaskOp>(op));
+            }));
     populateVectorUnrollPatterns(
         patterns,
         UnrollVectorOptions()

@nbpatel nbpatel requested a review from amd-eochoalo December 9, 2025 22:42
@github-actions
Copy link

github-actions bot commented Dec 9, 2025

🪟 Windows x64 Test Results

  • 3348 tests passed
  • 411 tests skipped

✅ The build succeeded and all tests passed.

Copy link
Contributor

@amd-eochoalo amd-eochoalo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, just a small nit and agreeing with other reviewers' comments.

@nbpatel nbpatel merged commit 71ee84a into llvm:main Dec 11, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants