Skip to content

Conversation

newling
Copy link
Contributor

@newling newling commented Oct 2, 2025

This PR adds a canonicalizer to vector.step that folds vector.step iff the result of the fold is a splat value. An alternative would be to always constant fold it, but that might result in some very large/cumbersome constants.

I do wonder if vector.step might be better represented as some sort of attribute in the arith dialect, like %step = arith.constant iota<32> : vector<32xindex>.

Signed-off-by: James Newling <james.newling@gmail.com>
@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

This PR adds a canonicalizer to vector.step that folds vector.step iff the result of the fold is a splat value. An alternative would be to always constant fold it, but that might result in some very large/cumbersome constants.

I do wonder if vector.step might be better represented as some sort of attribute in the arith dialect, like %step = arith.constant iota<32> : vector<32xindex>.


Full diff: https://github.com/llvm/llvm-project/pull/161615.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+95)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-step.mlir (+379)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..dbb5d0f659159 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3045,6 +3045,7 @@ def Vector_StepOp : Vector_Op<"step", [
   }];
   let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
   let assemblyFormat = "attr-dict `:` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_YieldOp : Vector_Op<"yield", [
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eb4686997c1b9..306be186308b0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7524,6 +7524,101 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+///
+/// as,
+///
+/// %out = arith.constant dense<false> : vector<3xi1>.
+///
+/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
+/// false at ALL indices we fold. If the constant was 1, then
+/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively
+/// preferring the 'compact' vector.step representation.
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+    for (auto &use : stepOp.getResult().getUses()) {
+      if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
+        const unsigned stepOperandNumber = use.getOperandNumber();
+
+        // arith.cmpi canonicalizer makes constants final operands.
+        if (stepOperandNumber != 0)
+          continue;
+
+        // Check that operand 1 is a constant.
+        unsigned constOperandNumber = 1;
+        Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+        auto maybeConstValue = getConstantIntValue(otherOperand);
+        if (!maybeConstValue.has_value())
+          continue;
+
+        int64_t constValue = maybeConstValue.value();
+        arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+        auto maybeSplat = [&]() -> std::optional<bool> {
+          // Handle ult (unsigned less than) and uge (unsigned greater equal).
+          if ((pred == arith::CmpIPredicate::ult ||
+               pred == arith::CmpIPredicate::uge) &&
+              stepSize <= constValue)
+            return pred == arith::CmpIPredicate::ult;
+
+          // Handle ule and ugt.
+          if ((pred == arith::CmpIPredicate::ule ||
+               pred == arith::CmpIPredicate::ugt) &&
+              stepSize <= constValue + 1)
+            return pred == arith::CmpIPredicate::ule;
+
+          // Handle eq and ne.
+          if ((pred == arith::CmpIPredicate::eq ||
+               pred == arith::CmpIPredicate::ne) &&
+              stepSize <= constValue)
+            return pred == arith::CmpIPredicate::ne;
+
+          return std::optional<bool>();
+        }();
+
+        if (!maybeSplat.has_value())
+          continue;
+
+        rewriter.setInsertionPointAfter(cmpiOp);
+
+        auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+        if (!type)
+          continue;
+
+        DenseElementsAttr boolAttr =
+            DenseElementsAttr::get(type, maybeSplat.value());
+        Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+                                                      type, boolAttr);
+
+        rewriter.replaceOp(cmpiOp, splat);
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<StepCompareFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
new file mode 100644
index 0000000000000..effeb3d9c093a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
@@ -0,0 +1,379 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+///===----------------------------------------------===//
+///  Tests of `StepCompareFolder`
+///===----------------------------------------------===//
+
+
+///===------------------------------------===//
+///  Tests of `ugt` (unsigned greater than)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ugt_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 3 > [0, 1, 2] => true
+  %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_2_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ugt_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 2 > [0, 1, 2] => not constant
+  %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ugt_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 1 > [0, 1, 2] => not constant
+  %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] > 3 => false
+  %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_2_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] > 2 => false
+  %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ugt_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] > 1 => not constant
+  %1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `uge` (unsigned greater than or equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_uge_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 3 >= [0, 1, 2] => true
+  %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_2_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 2 >= [0, 1, 2] => true
+  %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_uge_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 1 >= [0, 1, 2] => not constant
+  %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] >= 3 => false
+  %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_2_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_uge_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] >= 2 => not constant
+  %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_uge_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] >= 1 => not constant
+  %1 = arith.cmpi uge, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+
+
+///===------------------------------------===//
+///  Tests of `ult` (unsigned less than)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ult_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_2_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ult_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_2_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ult_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ult_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ult, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `ule` (unsigned less than or equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ule_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_2_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ule_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ule_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_2_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ule_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `eq` (equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_eq_constant_3
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_eq_constant_3() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_eq_constant_2
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_eq_constant_2() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `ne` (not equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @check_ne_constant_3
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ne_constant_3() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ne_constant_2
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @check_ne_constant_2() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+

Signed-off-by: James Newling <james.newling@gmail.com>
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative would be to always constant fold it, but that might result in some very large/cumbersome constants.

FYI, we have some max vector sizes defined for insert/extract folders for this exact reason. I think it would work here too.

Signed-off-by: James Newling <james.newling@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants