Skip to content

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Sep 9, 2025

This PR adds unrolling pattern for vector.step op to VectorUnroll transform.

@llvmbot
Copy link
Member

llvmbot commented Sep 9, 2025

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds unrolling pattern for vector.step op to VectorUnroll transform.


Full diff: https://github.com/llvm/llvm-project/pull/157752.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+49-1)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+23)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cca1607f..3d43e26c6be42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
 
 def Vector_StepOp : Vector_Op<"step", [
     Pure,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
   ]> {
   let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..7ea13b5723ad8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape());
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..0671dd1c4bea2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+  UnrollStepPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, stepOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType vecType = stepOp.getType();
+    if (vecType.isScalable()) {
+      // Scalable vectors are not supported by this pattern.
+      return failure();
+    }
+    int64_t originalSize = vecType.getShape()[0];
+    Location loc = stepOp.getLoc();
+    SmallVector<int64_t> strides(1, 1);
+
+    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+                                             rewriter.getZeroAttr(vecType));
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange({originalSize}, *targetShape)) {
+      int64_t tileOffset = offsets[0];
+      auto targetVecType =
+          VectorType::get(*targetShape, vecType.getElementType());
+      Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
+      Value offsetVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
+      Value bcastOffset =
+          rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
+      Value tileStep =
+          rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tileStep, result, offsets, strides);
+    }
+    rewriter.replaceOp(stepOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
-               UnrollStorePattern, UnrollBroadcastPattern>(
+               UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
       patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..777af995b4554 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
   // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
   // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
   // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+
+func.func @vector_step() -> vector<32xindex> {
+    %0 = vector.step : vector<32xindex>
+    return %0 : vector<32xindex>
+}
+// CHECK-LABEL: func @vector_step
+// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: return %[[INS3]] : vector<32xindex>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..1cd092cec2b81 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
                       .setFilterConstraint([](Operation *op) {
                         return success(isa<vector::ReductionOp>(op));
                       }));
+    populateVectorUnrollPatterns(patterns,
+                                 UnrollVectorOptions()
+                                     .setNativeShape(ArrayRef<int64_t>{8})
+                                     .setFilterConstraint([](Operation *op) {
+                                       return success(isa<vector::StepOp>(op));
+                                     }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

@llvmbot
Copy link
Member

llvmbot commented Sep 9, 2025

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

Changes

This PR adds unrolling pattern for vector.step op to VectorUnroll transform.


Full diff: https://github.com/llvm/llvm-project/pull/157752.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+49-1)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+23)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cca1607f..3d43e26c6be42 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
 
 def Vector_StepOp : Vector_Op<"step", [
     Pure,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
   ]> {
   let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..7ea13b5723ad8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7442,6 +7442,10 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+std::optional<SmallVector<int64_t, 4>> StepOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(llvm::cast<VectorType>(getType()).getShape());
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..0671dd1c4bea2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,54 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+  UnrollStepPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, stepOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType vecType = stepOp.getType();
+    if (vecType.isScalable()) {
+      // Scalable vectors are not supported by this pattern.
+      return failure();
+    }
+    int64_t originalSize = vecType.getShape()[0];
+    Location loc = stepOp.getLoc();
+    SmallVector<int64_t> strides(1, 1);
+
+    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+                                             rewriter.getZeroAttr(vecType));
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange({originalSize}, *targetShape)) {
+      int64_t tileOffset = offsets[0];
+      auto targetVecType =
+          VectorType::get(*targetShape, vecType.getElementType());
+      Value baseStep = rewriter.create<vector::StepOp>(loc, targetVecType);
+      Value offsetVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, tileOffset);
+      Value bcastOffset =
+          rewriter.create<vector::BroadcastOp>(loc, targetVecType, offsetVal);
+      Value tileStep =
+          rewriter.create<arith::AddIOp>(loc, baseStep, bcastOffset);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tileStep, result, offsets, strides);
+    }
+    rewriter.replaceOp(stepOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +866,6 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
-               UnrollStorePattern, UnrollBroadcastPattern>(
+               UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
       patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..777af995b4554 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,26 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
   // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
   // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
   // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+
+func.func @vector_step() -> vector<32xindex> {
+    %0 = vector.step : vector<32xindex>
+    return %0 : vector<32xindex>
+}
+// CHECK-LABEL: func @vector_step
+// CHECK: %[[CST3:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP0:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP1:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP1]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP2:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP2]], %[[CST2]] : vector<8xindex>
+// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[STEP3:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP3]], %[[CST3]] : vector<8xindex>
+// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: return %[[INS3]] : vector<32xindex>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..1cd092cec2b81 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
                       .setFilterConstraint([](Operation *op) {
                         return success(isa<vector::ReductionOp>(op));
                       }));
+    populateVectorUnrollPatterns(patterns,
+                                 UnrollVectorOptions()
+                                     .setNativeShape(ArrayRef<int64_t>{8})
+                                     .setFilterConstraint([](Operation *op) {
+                                       return success(isa<vector::StepOp>(op));
+                                     }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

@nbpatel
Copy link
Contributor Author

nbpatel commented Sep 12, 2025

pinging for review

@kuhar kuhar requested a review from amd-eochoalo September 12, 2025 15:46
Copy link
Contributor

@amd-eochoalo amd-eochoalo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the submission. I think it looks good, just some minor changes.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, minor comments

@nbpatel
Copy link
Contributor Author

nbpatel commented Sep 16, 2025

Addressed all the feedback.

Copy link
Contributor

@amd-eochoalo amd-eochoalo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @nbpatel, all good on my side. Please address other reviewer's comments before merging.

@nbpatel nbpatel merged commit 0e5c32b into llvm:main Sep 16, 2025
9 checks passed
kimsh02 pushed a commit to kimsh02/llvm-project that referenced this pull request Sep 19, 2025
This PR adds unrolling pattern for vector.step op to VectorUnroll
transform.
@nbpatel nbpatel deleted the xegpu-step-blocking branch September 25, 2025 20:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants