-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[MLIR][Vector] Add unroll pattern for vector.constant_mask #171518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Nishant Patel (nbpatel) ChangesFull diff: https://github.com/llvm/llvm-project/pull/171518.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d8ed46c2820fe..3d76f3d7fbc46 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2547,7 +2547,9 @@ def Vector_TypeCastOp :
}
def Vector_ConstantMaskOp :
- Vector_Op<"constant_mask", [Pure]>,
+ Vector_Op<"constant_mask", [Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+ ]>,
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 462bd8c3dc4a6..81e7d76eefcfb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1094,6 +1094,91 @@ struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.constant_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// whether its elements should be masked based on the original mask dimensions
+/// and the slice's offset position.
+///
+/// Example:
+/// Given a constant_mask operation:
+/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1> // mask first 6x10
+/// elements
+///
+/// and a target unroll shape of <4x8>, the pattern produces:
+///
+/// %false = arith.constant dense<false> : vector<8x16xi1>
+///
+/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
+/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
+/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+///
+/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds
+/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
+/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+///
+/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds
+/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
+/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+///
+/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds
+/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
+/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollConstantMaskPattern
+ : public OpRewritePattern<vector::ConstantMaskOp> {
+ UnrollConstantMaskPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, constantMaskOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType resultType = constantMaskOp.getVectorType();
+ SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
+ Location loc = constantMaskOp.getLoc();
+
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+ VectorType targetVectorType =
+ VectorType::get(*targetShape, rewriter.getI1Type());
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+
+ // In each dimension (d), each unrolled vector computes its mask size as:
+ // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalSize, *targetShape)) {
+ SmallVector<int64_t> unrolledMaskDims;
+
+ for (auto [i, originalMaskDim] :
+ llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
+ // Calculate how many elements in this dimension should be masked
+ // for this particular slice
+ int64_t adjustedMaskSize = std::max(originalMaskDim - offsets[i], 0L);
+ int64_t unrolledMaskDim = std::min(adjustedMaskSize, (*targetShape)[i]);
+ unrolledMaskDims.push_back(unrolledMaskDim);
+ }
+
+ auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
+ loc, targetVectorType, unrolledMaskDims);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, unrolledMask, result, offsets, strides);
+ }
+ rewriter.replaceOp(constantMaskOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
/// Checks whether extractShape is a contiguous slice of shape.
/// For extractShape to be contiguous in shape:
/// 1) All but the leading dimension of extractShape and shape must match
@@ -1294,8 +1379,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
- UnrollCreateMaskPattern>(patterns.getContext(), options,
- benefit);
+ UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
+ patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 805e66f133c59..c2e7f6a9338b1 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -552,6 +552,23 @@ func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: return %[[S3]] : vector<16x16xi1>
+func.func @vector_constant_mask() -> vector<16x16xi1> {
+ %0 = vector.constant_mask [12, 10] : vector<16x16xi1>
+ return %0 : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_constant_mask
+// CHECK-SAME: () -> vector<16x16xi1>
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+// CHECK: %[[CST_TRUE:.*]] = arith.constant dense<true> : vector<8x8xi1>
+// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[CST_TRUE]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[MASK01:.*]] = vector.constant_mask [8, 2] : vector<8x8xi1>
+// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[MASK10:.*]] = vector.constant_mask [4, 8] : vector<8x8xi1>
+// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[MASK11:.*]] = vector.constant_mask [4, 2] : vector<8x8xi1>
+// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: return %[[INS11]] : vector<16x16xi1>
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f834d0cdd42bd..2cbb5ab3067f2 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -179,11 +179,13 @@ struct TestVectorUnrollingPatterns
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
- patterns, UnrollVectorOptions()
- .setNativeShape(ArrayRef<int64_t>{8, 8})
- .setFilterConstraint([](Operation *op) {
- return success(isa<vector::CreateMaskOp>(op));
- }));
+ patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8, 8})
+ .setFilterConstraint([](Operation *op) {
+ return success(
+ isa<vector::CreateMaskOp, vector::ConstantMaskOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns,
UnrollVectorOptions()
|
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
amd-eochoalo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, just a small nit and agreeing with other reviewers' comments.
This PR adds unrolling for vector.constant_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as:
min(max(originalMaskSize[d] - offset[d], 0), unrolledMaskSize[d]).