From 4555ad22c26a2b2a7e5f335d5d7e6eefb814e35f Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 1 Oct 2025 18:07:36 -0700 Subject: [PATCH 1/3] add folder Signed-off-by: James Newling --- .../mlir/Dialect/Vector/IR/VectorOps.td | 1 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 114 ++++++ .../Vector/canonicalize/vector-step.mlir | 379 ++++++++++++++++++ 3 files changed, 494 insertions(+) create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-step.mlir diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 252c0b72456df..dbb5d0f659159 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -3045,6 +3045,7 @@ def Vector_StepOp : Vector_Op<"step", [ }]; let results = (outs VectorOfRankAndType<[1], [Index]>:$result); let assemblyFormat = "attr-dict `:` type($result)"; + let hasCanonicalizer = 1; } def Vector_YieldOp : Vector_Op<"yield", [ diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eb4686997c1b9..6a18df72c5335 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7524,6 +7524,120 @@ void StepOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), result); } +namespace { + +/// Constant fold vector.step when it is compared to constant with arith.cmpi +/// and the result is the same at all indices. For example, rewrite: +/// +/// %cst = arith.constant dense<7> : vector<3xindex> +/// %0 = vector.step : vector<3xindex> +/// %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> +/// +/// as +/// +/// %out = arith.constant dense : vector<3xi1> +/// +/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is +/// false at ALL indices we fold to the constant. false. If the constant was 1, +/// then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant +/// fold, preferring the more 'compact' vector.step representation. +struct StepCompareFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(StepOp stepOp, + PatternRewriter &rewriter) const override { + + int64_t stepSize = stepOp.getResult().getType().getNumElements(); + + for (auto &use : stepOp.getResult().getUses()) { + if (auto cmpiOp = dyn_cast(use.getOwner())) { + unsigned stepOperandNumber = use.getOperandNumber(); + + // arith.cmpi has a canonicalizer to put constants on operand 1. Let it + // run first. + if (stepOperandNumber != 0) { + continue; + } + + // Check that operand 1 is a constant. + unsigned otherOperandNumber = 1; + Value otherOperand = cmpiOp.getOperand(otherOperandNumber); + auto maybeConstValue = getConstantIntValue(otherOperand); + if (!maybeConstValue.has_value()) + continue; + int64_t constValue = maybeConstValue.value(); + + arith::CmpIPredicate pred = cmpiOp.getPredicate(); + + auto maybeSplat = [&]() -> std::optional { + // Handle ult (unsigned less than) and uge (unsigned greater equal). + // Examples where stepSize = constValue = 3, for the 4 + // cases of [ult, uge] x [stepOperandNumber = 0, 1]: + // + // pred stepOperandNumber + // ==== ================= + // ult 0 [0, 1, 2] < 3 ==> true. + // ult 1 3 < [0, 1, 2] ==> false. + // uge 0 [0, 1, 2] >= 3 ==> true. + // uge 1 3 >= [0, 1, 2] ==> false. + // + // If constValue is any smaller, the comparison is not constant. + if (pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::uge) { + if (stepSize <= constValue) { + return pred == arith::CmpIPredicate::ult; + } + } + + // Handle ule and ugt. + // + // pred stepOperandNumber + // ==== ================= + // ule 0 [0, 1, 2] <= 2 ==> true + // (stepSize = 3, constValue = 2). + if (pred == arith::CmpIPredicate::ule || + pred == arith::CmpIPredicate::ugt) { + if (stepSize <= constValue + 1) { + return pred == arith::CmpIPredicate::ule; + } + } + + // Handle eq and ne + if (pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) { + if (stepSize <= constValue) { + return pred == arith::CmpIPredicate::ne; + } + } + + return std::optional(); + }(); + + if (!maybeSplat.has_value()) + continue; + + rewriter.setInsertionPointAfter(cmpiOp); + auto boolConst = mlir::arith::ConstantOp::create( + rewriter, cmpiOp.getLoc(), + rewriter.getBoolAttr(maybeSplat.value())); + auto splat = vector::BroadcastOp::create( + rewriter, cmpiOp.getLoc(), cmpiOp.getResult().getType(), boolConst); + + rewriter.replaceOp(cmpiOp, splat.getResult()); + return success(); + } + } + + return failure(); + } +}; +} // namespace + +void StepOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir new file mode 100644 index 0000000000000..e213c78f5ea42 --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir @@ -0,0 +1,379 @@ +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s + +///===----------------------------------------------===// +/// Tests of `StepCompareFolder` +///===----------------------------------------------===// + + +///===--------------===// +/// Tests of `ugt` (unsigned greater than) +///===--------------===// + +// CHECK-LABEL: @check_ugt_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ugt_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 3 > [0, 1, 2] => true + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_2_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ugt_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 > [0, 1, 2] => not constant + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ugt_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 1 > [0, 1, 2] => not constant + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ugt_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 3 => false + %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_2_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ugt_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 2 => false + %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ugt_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 1 => not constant + %1 = arith.cmpi ugt, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===--------------===// +/// Tests of `uge` (unsigned greater than or equal) +///===--------------===// + +// CHECK-LABEL: @check_uge_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_uge_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 3 >= [0, 1, 2] => true + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_2_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_uge_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 >= [0, 1, 2] => true + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_uge_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 1 >= [0, 1, 2] => not constant + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_uge_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 3 => false + %1 = arith.cmpi uge, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_2_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_uge_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 2 => not constant + %1 = arith.cmpi uge, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_uge_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 1 => not constant + %1 = arith.cmpi uge, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + + + +///===--------------===// +/// Tests of `ult` (unsigned less than) +///===--------------===// + +// CHECK-LABEL: @check_ult_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ult_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_2_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ult_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ult_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ult_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_2_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ult_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ult_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===--------------===// +/// Tests of `ule` (unsigned less than or equal) +///===--------------===// + +// CHECK-LABEL: @check_ule_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ule_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_2_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ule_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ule_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ule_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_2_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ule_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ule_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===--------------===// +/// Tests of `eq` (equal) +///===--------------===// + +// CHECK-LABEL: @check_eq_constant_3 +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_eq_constant_3() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi eq, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_eq_constant_2 +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_eq_constant_2() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi eq, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===--------------===// +/// Tests of `ne` (not equal) +///===--------------===// + +// CHECK-LABEL: @check_ne_constant_3 +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ne_constant_3() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ne, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ne_constant_2 +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ne_constant_2() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ne, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + From 7232237305d5374c802995d42fe5e6126852911e Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 2 Oct 2025 09:07:52 -0700 Subject: [PATCH 2/3] cosmetics Signed-off-by: James Newling --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 101 +++++++----------- .../Vector/canonicalize/vector-step.mlir | 24 ++--- 2 files changed, 53 insertions(+), 72 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 6a18df72c5335..306be186308b0 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7526,89 +7526,66 @@ void StepOp::inferResultRanges(ArrayRef argRanges, namespace { -/// Constant fold vector.step when it is compared to constant with arith.cmpi -/// and the result is the same at all indices. For example, rewrite: +/// Fold `vector.step -> arith.cmpi` when the step value is compared to a +/// constant large enough such that the result is the same at all indices. +/// +/// For example, rewrite the 'greater than' comparison below, /// /// %cst = arith.constant dense<7> : vector<3xindex> -/// %0 = vector.step : vector<3xindex> -/// %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> +/// %stp = vector.step : vector<3xindex> +/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex> /// -/// as +/// as, /// -/// %out = arith.constant dense : vector<3xi1> +/// %out = arith.constant dense : vector<3xi1>. /// /// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is -/// false at ALL indices we fold to the constant. false. If the constant was 1, -/// then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant -/// fold, preferring the more 'compact' vector.step representation. +/// false at ALL indices we fold. If the constant was 1, then +/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively +/// preferring the 'compact' vector.step representation. struct StepCompareFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(StepOp stepOp, PatternRewriter &rewriter) const override { - - int64_t stepSize = stepOp.getResult().getType().getNumElements(); + const int64_t stepSize = stepOp.getResult().getType().getNumElements(); for (auto &use : stepOp.getResult().getUses()) { if (auto cmpiOp = dyn_cast(use.getOwner())) { - unsigned stepOperandNumber = use.getOperandNumber(); + const unsigned stepOperandNumber = use.getOperandNumber(); - // arith.cmpi has a canonicalizer to put constants on operand 1. Let it - // run first. - if (stepOperandNumber != 0) { + // arith.cmpi canonicalizer makes constants final operands. + if (stepOperandNumber != 0) continue; - } // Check that operand 1 is a constant. - unsigned otherOperandNumber = 1; - Value otherOperand = cmpiOp.getOperand(otherOperandNumber); + unsigned constOperandNumber = 1; + Value otherOperand = cmpiOp.getOperand(constOperandNumber); auto maybeConstValue = getConstantIntValue(otherOperand); if (!maybeConstValue.has_value()) continue; - int64_t constValue = maybeConstValue.value(); + int64_t constValue = maybeConstValue.value(); arith::CmpIPredicate pred = cmpiOp.getPredicate(); auto maybeSplat = [&]() -> std::optional { // Handle ult (unsigned less than) and uge (unsigned greater equal). - // Examples where stepSize = constValue = 3, for the 4 - // cases of [ult, uge] x [stepOperandNumber = 0, 1]: - // - // pred stepOperandNumber - // ==== ================= - // ult 0 [0, 1, 2] < 3 ==> true. - // ult 1 3 < [0, 1, 2] ==> false. - // uge 0 [0, 1, 2] >= 3 ==> true. - // uge 1 3 >= [0, 1, 2] ==> false. - // - // If constValue is any smaller, the comparison is not constant. - if (pred == arith::CmpIPredicate::ult || - pred == arith::CmpIPredicate::uge) { - if (stepSize <= constValue) { - return pred == arith::CmpIPredicate::ult; - } - } + if ((pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::uge) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ult; // Handle ule and ugt. - // - // pred stepOperandNumber - // ==== ================= - // ule 0 [0, 1, 2] <= 2 ==> true - // (stepSize = 3, constValue = 2). - if (pred == arith::CmpIPredicate::ule || - pred == arith::CmpIPredicate::ugt) { - if (stepSize <= constValue + 1) { - return pred == arith::CmpIPredicate::ule; - } - } + if ((pred == arith::CmpIPredicate::ule || + pred == arith::CmpIPredicate::ugt) && + stepSize <= constValue + 1) + return pred == arith::CmpIPredicate::ule; - // Handle eq and ne - if (pred == arith::CmpIPredicate::eq || - pred == arith::CmpIPredicate::ne) { - if (stepSize <= constValue) { - return pred == arith::CmpIPredicate::ne; - } - } + // Handle eq and ne. + if ((pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ne; return std::optional(); }(); @@ -7617,13 +7594,17 @@ struct StepCompareFolder : public OpRewritePattern { continue; rewriter.setInsertionPointAfter(cmpiOp); - auto boolConst = mlir::arith::ConstantOp::create( - rewriter, cmpiOp.getLoc(), - rewriter.getBoolAttr(maybeSplat.value())); - auto splat = vector::BroadcastOp::create( - rewriter, cmpiOp.getLoc(), cmpiOp.getResult().getType(), boolConst); - rewriter.replaceOp(cmpiOp, splat.getResult()); + auto type = dyn_cast(cmpiOp.getResult().getType()); + if (!type) + continue; + + DenseElementsAttr boolAttr = + DenseElementsAttr::get(type, maybeSplat.value()); + Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(), + type, boolAttr); + + rewriter.replaceOp(cmpiOp, splat); return success(); } } diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir index e213c78f5ea42..effeb3d9c093a 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir @@ -5,9 +5,9 @@ ///===----------------------------------------------===// -///===--------------===// +///===------------------------------------===// /// Tests of `ugt` (unsigned greater than) -///===--------------===// +///===------------------------------------===// // CHECK-LABEL: @check_ugt_constant_3_lhs // CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> @@ -87,9 +87,9 @@ func.func @check_ugt_constant_1_rhs() -> vector<3xi1> { // ----- -///===--------------===// +///===------------------------------------===// /// Tests of `uge` (unsigned greater than or equal) -///===--------------===// +///===------------------------------------===// // CHECK-LABEL: @check_uge_constant_3_lhs // CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> @@ -171,9 +171,9 @@ func.func @check_uge_constant_1_rhs() -> vector<3xi1> { -///===--------------===// +///===------------------------------------===// /// Tests of `ult` (unsigned less than) -///===--------------===// +///===------------------------------------===// // CHECK-LABEL: @check_ult_constant_3_lhs // CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> @@ -247,9 +247,9 @@ func.func @check_ult_constant_1_rhs() -> vector<3xi1> { // ----- -///===--------------===// +///===------------------------------------===// /// Tests of `ule` (unsigned less than or equal) -///===--------------===// +///===------------------------------------===// // CHECK-LABEL: @check_ule_constant_3_lhs // CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> @@ -323,9 +323,9 @@ func.func @check_ule_constant_1_rhs() -> vector<3xi1> { // ----- -///===--------------===// +///===------------------------------------===// /// Tests of `eq` (equal) -///===--------------===// +///===------------------------------------===// // CHECK-LABEL: @check_eq_constant_3 // CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> @@ -351,9 +351,9 @@ func.func @check_eq_constant_2() -> vector<3xi1> { // ----- -///===--------------===// +///===------------------------------------===// /// Tests of `ne` (not equal) -///===--------------===// +///===------------------------------------===// // CHECK-LABEL: @check_ne_constant_3 // CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> From 934606f8cad1e2fa1fc3dec4a460a178db0d169f Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 2 Oct 2025 10:42:28 -0700 Subject: [PATCH 3/3] address Jakub's comments Signed-off-by: James Newling --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 94 ++++++++++--------- .../Vector/canonicalize/vector-step.mlir | 16 +++- 2 files changed, 63 insertions(+), 47 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 306be186308b0..2a7dff2a99e88 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7551,62 +7551,64 @@ struct StepCompareFolder : public OpRewritePattern { const int64_t stepSize = stepOp.getResult().getType().getNumElements(); for (auto &use : stepOp.getResult().getUses()) { - if (auto cmpiOp = dyn_cast(use.getOwner())) { - const unsigned stepOperandNumber = use.getOperandNumber(); - - // arith.cmpi canonicalizer makes constants final operands. - if (stepOperandNumber != 0) - continue; - - // Check that operand 1 is a constant. - unsigned constOperandNumber = 1; - Value otherOperand = cmpiOp.getOperand(constOperandNumber); - auto maybeConstValue = getConstantIntValue(otherOperand); - if (!maybeConstValue.has_value()) - continue; + auto cmpiOp = dyn_cast(use.getOwner()); + if (!cmpiOp) + continue; - int64_t constValue = maybeConstValue.value(); - arith::CmpIPredicate pred = cmpiOp.getPredicate(); + // arith.cmpi canonicalizer makes constants final operands. + const unsigned stepOperandNumber = use.getOperandNumber(); + if (stepOperandNumber != 0) + continue; - auto maybeSplat = [&]() -> std::optional { - // Handle ult (unsigned less than) and uge (unsigned greater equal). - if ((pred == arith::CmpIPredicate::ult || - pred == arith::CmpIPredicate::uge) && - stepSize <= constValue) - return pred == arith::CmpIPredicate::ult; + // Check that operand 1 is a constant. + unsigned constOperandNumber = 1; + Value otherOperand = cmpiOp.getOperand(constOperandNumber); + auto maybeConstValue = getConstantIntValue(otherOperand); + if (!maybeConstValue.has_value()) + continue; - // Handle ule and ugt. - if ((pred == arith::CmpIPredicate::ule || - pred == arith::CmpIPredicate::ugt) && - stepSize <= constValue + 1) - return pred == arith::CmpIPredicate::ule; + int64_t constValue = maybeConstValue.value(); + arith::CmpIPredicate pred = cmpiOp.getPredicate(); + + auto maybeSplat = [&]() -> std::optional { + // Handle ult (unsigned less than) and uge (unsigned greater equal). + if ((pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::uge) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ult; + + // Handle ule and ugt. + if ((pred == arith::CmpIPredicate::ule || + pred == arith::CmpIPredicate::ugt) && + stepSize - 1 <= constValue) { + return pred == arith::CmpIPredicate::ule; + } - // Handle eq and ne. - if ((pred == arith::CmpIPredicate::eq || - pred == arith::CmpIPredicate::ne) && - stepSize <= constValue) - return pred == arith::CmpIPredicate::ne; + // Handle eq and ne. + if ((pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ne; - return std::optional(); - }(); + return std::nullopt; + }(); - if (!maybeSplat.has_value()) - continue; + if (!maybeSplat.has_value()) + continue; - rewriter.setInsertionPointAfter(cmpiOp); + rewriter.setInsertionPointAfter(cmpiOp); - auto type = dyn_cast(cmpiOp.getResult().getType()); - if (!type) - continue; + auto type = dyn_cast(cmpiOp.getResult().getType()); + if (!type) + continue; - DenseElementsAttr boolAttr = - DenseElementsAttr::get(type, maybeSplat.value()); - Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(), - type, boolAttr); + DenseElementsAttr boolAttr = + DenseElementsAttr::get(type, maybeSplat.value()); + Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(), + type, boolAttr); - rewriter.replaceOp(cmpiOp, splat); - return success(); - } + rewriter.replaceOp(cmpiOp, splat); + return success(); } return failure(); diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir index effeb3d9c093a..eb997438d2d51 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s ///===----------------------------------------------===// /// Tests of `StepCompareFolder` @@ -59,6 +59,20 @@ func.func @check_ugt_constant_3_rhs() -> vector<3xi1> { return %1 : vector<3xi1> } +// ----- + +// CHECK-LABEL: @check_ugt_constant_max_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ugt_constant_max_rhs() -> vector<3xi1> { + // The largest i64 possible: + %cst = arith.constant dense<0x7fffffffffffffff> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ugt, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + + // ----- // CHECK-LABEL: @check_ugt_constant_2_rhs