From 983c12b51e8b03830f9e04b72f7ded5fab87be86 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 9 Sep 2025 18:45:29 +0000 Subject: [PATCH 1/5] Add unroll pattern for StepOp --- .../mlir/Dialect/Vector/IR/VectorOps.td | 1 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ++ .../Vector/Transforms/VectorUnroll.cpp | 50 ++++++++++++++++++- .../Dialect/Vector/vector-unroll-options.mlir | 23 +++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 6 +++ 5 files changed, 83 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 77e26cca1607f..3d43e26c6be42 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -3018,6 +3018,7 @@ def Vector_ScanOp : def Vector_StepOp : Vector_Op<"step", [ Pure, + DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]> ]> { let summary = "A linear sequence of values from 0 to N"; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 85e485c28c74e..7ea13b5723ad8 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), result); } +std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() { + return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape()); +} + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index e8ecb0c0be846..0671dd1c4bea2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> { 
vector::UnrollVectorOptions options; }; +struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> { + UnrollStepPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(vector::StepOp stepOp, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, stepOp); + if (!targetShape) + return failure(); + + VectorType vecType = stepOp.getType(); + if (vecType.isScalable()) { + // Scalable vectors are not supported by this pattern. + return failure(); + } + int64_t originalSize = vecType.getShape()[0]; + Location loc = stepOp.getLoc(); + SmallVector<int64_t> strides(1, 1); + + Value result = arith::ConstantOp::create(rewriter, loc, vecType, + rewriter.getZeroAttr(vecType)); + + for (SmallVector<int64_t> offsets : + StaticTileOffsetRange({originalSize}, *targetShape)) { + int64_t tileOffset = offsets[0]; + auto targetVecType = + VectorType::get(*targetShape, vecType.getElementType()); + Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType); + Value offsetVal = + rewriter.create<arith::ConstantIndexOp>(loc, tileOffset); + Value bcastOffset = + rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal); + Value tileStep = + rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset); + + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( + loc, tileStep, result, offsets, strides); + } + rewriter.replaceOp(stepOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollContractionPattern, UnrollElementwisePattern, UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, - UnrollStorePattern, UnrollBroadcastPattern>( + UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>( patterns.getContext(), options, benefit); } diff --git 
a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e129cd5c40b9c..777af995b4554 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) { // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16> // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16> // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16> + + +func.func @vector_step() -> vector<32xindex> { + %0 = vector.step : vector<32xindex> + return %0 : vector<32xindex> +} +// CHECK-LABEL: func @vector_step +// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex> +// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex> +// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex> +// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex> +// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex> +// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex> +// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex> +// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex> +// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex> +// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex> +// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex> +// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex> +// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex> +// CHECK: %[[ADD3:.*]] = arith.addi 
%[[STEP3]], %[[CST3]] : vector<8xindex> +// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex> +// CHECK: return %[[INS3]] : vector<32xindex> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index bb1598ee3efe5..1cd092cec2b81 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success(isa(op)); })); + populateVectorUnrollPatterns(patterns, + UnrollVectorOptions() + .setNativeShape(ArrayRef{8}) + .setFilterConstraint([](Operation *op) { + return success(isa(op)); + })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() .setNativeShape(ArrayRef{1, 3, 4, 2}) From 8c6f310b09c71e6917227ae0b5e99c71fe1257a1 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 16 Sep 2025 16:04:12 +0000 Subject: [PATCH 2/5] Address Feedback --- .../Vector/Transforms/VectorUnroll.cpp | 47 +++++++++++++++---- .../Dialect/Vector/vector-unroll-options.mlir | 19 ++++---- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 0671dd1c4bea2..8865b96241548 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -809,6 +809,32 @@ struct UnrollBroadcastPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +/// This pattern unrolls `vector.step` operations according to the provided +/// target unroll shape. It decomposes a large step vector into smaller step +/// vectors (segments) and assembles the result by inserting each computed +/// segment into the appropriate offset of the original vector. 
+/// +/// The pattern does not support scalable vectors and will fail to match them. +/// +/// For each segment, it adds the base step vector and the segment's offset, +/// then inserts the result into the output vector at the corresponding +/// position. +/// +/// Example: +/// Given a step operation: +/// %0 = vector.step : vector<8xindex> +/// +/// and a target unroll shape of <4>, the pattern produces: +/// +/// %base = vector.step : vector<4xindex> +/// %zero = arith.constant dense<0> : vector<4xindex> +/// %result0 = vector.insert_strided_slice %base, %zero +/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex> +/// %offset = arith.constant dense<4> : vector<4xindex> +/// %segment1 = arith.addi %base, %offset : vector<4xindex> +/// %result1 = vector.insert_strided_slice %segment1, %result0 +/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex> +/// struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> { UnrollStepPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, @@ -817,7 +843,8 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> { LogicalResult matchAndRewrite(vector::StepOp stepOp, PatternRewriter &rewriter) const override { - auto targetShape = getTargetShape(options, stepOp); + std::optional<SmallVector<int64_t>> targetShape = + getTargetShape(options, stepOp); if (!targetShape) return failure(); @@ -833,18 +860,18 @@ struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> { Value result = arith::ConstantOp::create(rewriter, loc, vecType, rewriter.getZeroAttr(vecType)); + VectorType targetVecType = + VectorType::get(*targetShape, vecType.getElementType()); + Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType); for (SmallVector<int64_t> offsets : StaticTileOffsetRange({originalSize}, *targetShape)) { - int64_t tileOffset = offsets[0]; - auto targetVecType = - VectorType::get(*targetShape, vecType.getElementType()); - Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType); - Value offsetVal = - 
rewriter.create<arith::ConstantIndexOp>(loc, tileOffset); - Value bcastOffset = - rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal); + Value bcastOffset = arith::ConstantOp::create( + rewriter, loc, targetVecType, + DenseElementsAttr::get( + targetVecType, + IntegerAttr::get(targetVecType.getElementType(), offsets[0]))); Value tileStep = - rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset); + arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset); result = rewriter.createOrFold<vector::InsertStridedSliceOp>( loc, tileStep, result, offsets, strides); diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 777af995b4554..35db14e0f7f1d 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -427,19 +427,16 @@ func.func @vector_step() -> vector<32xindex> { return %0 : vector<32xindex> } // CHECK-LABEL: func @vector_step -// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex> -// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex> +// CHECK: %[[CST:.*]] = arith.constant dense<24> : vector<8xindex> +// CHECK: %[[CST0:.*]] = arith.constant dense<16> : vector<8xindex> // CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex> -// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex> -// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex> -// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex> -// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex> -// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex> +// CHECK: %[[CST2:.*]] = arith.constant dense<0> : vector<32xindex> +// CHECK: %[[STEP:.*]] = vector.step : vector<8xindex> +// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP]], %[[CST2]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex> +// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP]], %[[CST1]] : 
vector<8xindex> // CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex> -// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex> -// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex> +// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP]], %[[CST0]] : vector<8xindex> // CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex> -// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex> -// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex> +// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex> // CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex> // CHECK: return %[[INS3]] : vector<32xindex> From f2673413c4df547e974095ebb9ba52cde39f5de4 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 16 Sep 2025 16:19:24 +0000 Subject: [PATCH 3/5] Fix shape --- mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 8865b96241548..f778d3d860fe9 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -827,7 +827,7 @@ struct UnrollBroadcastPattern : public OpRewritePattern { /// and a target unroll shape of <4>, the pattern produces: /// /// %base = vector.step : vector<4xindex> -/// %zero = arith.constant dense<0> : vector<4xindex> +/// %zero = arith.constant dense<0> : vector<8xindex> /// %result0 = vector.insert_strided_slice %base, %zero /// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex> /// %offset = arith.constant dense<4> : vector<4xindex> From 4d8a2865e981800e64122c37136e55824d9d0fa3 Mon Sep 17 
00:00:00 2001 From: nbpatel Date: Tue, 16 Sep 2025 17:28:01 +0000 Subject: [PATCH 4/5] Remove getShapeForUnroll for step --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 3d43e26c6be42..c60fa3b85b396 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -3018,7 +3018,7 @@ def Vector_ScanOp : def Vector_StepOp : Vector_Op<"step", [ Pure, - DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>, + DeclareOpInterfaceMethods<VectorUnrollOpInterface>, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]> ]> { let summary = "A linear sequence of values from 0 to N"; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 7ea13b5723ad8..85e485c28c74e 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7442,10 +7442,6 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), result); } -std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() { - return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape()); -} - //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// From 80359dd4aa1c792de0987354be91c8737b507d81 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 16 Sep 2025 18:28:33 +0000 Subject: [PATCH 5/5] Feedback --- mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index f778d3d860fe9..79786f33a2d46 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -860,10 +860,10 @@ 
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> { Value result = arith::ConstantOp::create(rewriter, loc, vecType, rewriter.getZeroAttr(vecType)); - VectorType targetVecType = + auto targetVecType = VectorType::get(*targetShape, vecType.getElementType()); Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType); - for (SmallVector<int64_t> offsets : + for (const SmallVector<int64_t> &offsets : StaticTileOffsetRange({originalSize}, *targetShape)) { Value bcastOffset = arith::ConstantOp::create( rewriter, loc, targetVecType,