-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Vector] Add unrolling pattern for vector StepOp #157752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Nishant Patel (nbpatel) ChangesThis PR adds unrolling pattern for vector.step op to VectorUnroll transform. Full diff: https://github.com/llvm/llvm-project/pull/157752.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cca1607f..3d43e26c6be42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
def Vector_StepOp : Vector_Op<"step", [
Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]> {
let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..7ea13b5723ad8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape());
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..0671dd1c4bea2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+ UnrollStepPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(vector::StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, stepOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType vecType = stepOp.getType();
+ if (vecType.isScalable()) {
+ // Scalable vectors are not supported by this pattern.
+ return failure();
+ }
+ int64_t originalSize = vecType.getShape()[0];
+ Location loc = stepOp.getLoc();
+ SmallVector<int64_t> strides(1, 1);
+
+ Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+ rewriter.getZeroAttr(vecType));
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange({originalSize}, *targetShape)) {
+ int64_t tileOffset = offsets[0];
+ auto targetVecType =
+ VectorType::get(*targetShape, vecType.getElementType());
+ Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
+ Value offsetVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
+ Value bcastOffset =
+ rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
+ Value tileStep =
+ rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
+
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tileStep, result, offsets, strides);
+ }
+ rewriter.replaceOp(stepOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
- UnrollStorePattern, UnrollBroadcastPattern>(
+ UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
patterns.getContext(), options, benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..777af995b4554 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+
+func.func @vector_step() -> vector<32xindex> {
+ %0 = vector.step : vector<32xindex>
+ return %0 : vector<32xindex>
+}
+// CHECK-LABEL: func @vector_step
+// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: return %[[INS3]] : vector<32xindex>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..1cd092cec2b81 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
+ populateVectorUnrollPatterns(patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::StepOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
|
@llvm/pr-subscribers-mlir-vector Author: Nishant Patel (nbpatel) ChangesThis PR adds unrolling pattern for vector.step op to VectorUnroll transform. Full diff: https://github.com/llvm/llvm-project/pull/157752.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cca1607f..3d43e26c6be42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
def Vector_StepOp : Vector_Op<"step", [
Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]> {
let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..7ea13b5723ad8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape());
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..0671dd1c4bea2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+ UnrollStepPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(vector::StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, stepOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType vecType = stepOp.getType();
+ if (vecType.isScalable()) {
+ // Scalable vectors are not supported by this pattern.
+ return failure();
+ }
+ int64_t originalSize = vecType.getShape()[0];
+ Location loc = stepOp.getLoc();
+ SmallVector<int64_t> strides(1, 1);
+
+ Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+ rewriter.getZeroAttr(vecType));
+
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange({originalSize}, *targetShape)) {
+ int64_t tileOffset = offsets[0];
+ auto targetVecType =
+ VectorType::get(*targetShape, vecType.getElementType());
+ Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
+ Value offsetVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
+ Value bcastOffset =
+ rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
+ Value tileStep =
+ rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
+
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tileStep, result, offsets, strides);
+ }
+ rewriter.replaceOp(stepOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
- UnrollStorePattern, UnrollBroadcastPattern>(
+ UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
patterns.getContext(), options, benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..777af995b4554 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+
+func.func @vector_step() -> vector<32xindex> {
+ %0 = vector.step : vector<32xindex>
+ return %0 : vector<32xindex>
+}
+// CHECK-LABEL: func @vector_step
+// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: return %[[INS3]] : vector<32xindex>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..1cd092cec2b81 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
+ populateVectorUnrollPatterns(patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::StepOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
|
pinging for review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the submission. I think it looks good, just some minor changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, minor comments
Addressed all the feedback. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @nbpatel, all good on my side. Please address other reviewer's comments before merging.
This PR adds unrolling pattern for vector.step op to VectorUnroll transform.
This PR adds unrolling pattern for vector.step op to VectorUnroll transform.