diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 252c0b72456df..dbb5d0f659159 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -3045,6 +3045,7 @@ def Vector_StepOp : Vector_Op<"step", [ }]; let results = (outs VectorOfRankAndType<[1], [Index]>:$result); let assemblyFormat = "attr-dict `:` type($result)"; + let hasCanonicalizer = 1; } def Vector_YieldOp : Vector_Op<"yield", [ diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eb4686997c1b9..2a7dff2a99e88 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7524,6 +7524,103 @@ void StepOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), result); } +namespace { + +/// Fold `vector.step -> arith.cmpi` when the step value is compared to a +/// constant large enough such that the result is the same at all indices. +/// +/// For example, rewrite the 'greater than' comparison below, +/// +/// %cst = arith.constant dense<7> : vector<3xindex> +/// %stp = vector.step : vector<3xindex> +/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex> +/// +/// as, +/// +/// %out = arith.constant dense : vector<3xi1>. +/// +/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is +/// false at ALL indices we fold. If the constant was 1, then +/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively +/// preferring the 'compact' vector.step representation. +struct StepCompareFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(StepOp stepOp, + PatternRewriter &rewriter) const override { + const int64_t stepSize = stepOp.getResult().getType().getNumElements(); + + for (auto &use : stepOp.getResult().getUses()) { + auto cmpiOp = dyn_cast(use.getOwner()); + if (!cmpiOp) + continue; + + // arith.cmpi canonicalizer makes constants final operands. + const unsigned stepOperandNumber = use.getOperandNumber(); + if (stepOperandNumber != 0) + continue; + + // Check that operand 1 is a constant. + unsigned constOperandNumber = 1; + Value otherOperand = cmpiOp.getOperand(constOperandNumber); + auto maybeConstValue = getConstantIntValue(otherOperand); + if (!maybeConstValue.has_value()) + continue; + + int64_t constValue = maybeConstValue.value(); + arith::CmpIPredicate pred = cmpiOp.getPredicate(); + + auto maybeSplat = [&]() -> std::optional { + // Handle ult (unsigned less than) and uge (unsigned greater equal). + if ((pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::uge) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ult; + + // Handle ule and ugt. + if ((pred == arith::CmpIPredicate::ule || + pred == arith::CmpIPredicate::ugt) && + stepSize - 1 <= constValue) { + return pred == arith::CmpIPredicate::ule; + } + + // Handle eq and ne. + if ((pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ne; + + return std::nullopt; + }(); + + if (!maybeSplat.has_value()) + continue; + + rewriter.setInsertionPointAfter(cmpiOp); + + auto type = dyn_cast(cmpiOp.getResult().getType()); + if (!type) + continue; + + DenseElementsAttr boolAttr = + DenseElementsAttr::get(type, maybeSplat.value()); + Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(), + type, boolAttr); + + rewriter.replaceOp(cmpiOp, splat); + return success(); + } + + return failure(); + } +}; +} // namespace + +void StepOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir new file mode 100644 index 0000000000000..eb997438d2d51 --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir @@ -0,0 +1,393 @@ +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s + +///===----------------------------------------------===// +/// Tests of `StepCompareFolder` +///===----------------------------------------------===// + + +///===------------------------------------===// +/// Tests of `ugt` (unsigned greater than) +///===------------------------------------===// + +// CHECK-LABEL: @check_ugt_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ugt_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 3 > [0, 1, 2] => true + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_2_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ugt_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 > [0, 1, 2] => not constant + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ugt_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 1 > [0, 1, 2] => not constant + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ugt_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 3 => false + %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_max_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ugt_constant_max_rhs() -> vector<3xi1> { + // The largest i64 possible: + %cst = arith.constant dense<0x7fffffffffffffff> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ugt, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + + +// ----- + +// CHECK-LABEL: @check_ugt_constant_2_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ugt_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 2 => false + %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ugt_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ugt_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 1 => not constant + %1 = arith.cmpi ugt, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `uge` (unsigned greater than or equal) +///===------------------------------------===// + +// CHECK-LABEL: @check_uge_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_uge_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 3 >= [0, 1, 2] => true + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_2_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_uge_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 >= [0, 1, 2] => true + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_uge_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 1 >= [0, 1, 2] => not constant + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_uge_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 3 => false + %1 = arith.cmpi uge, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_2_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_uge_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 2 => not constant + %1 = arith.cmpi uge, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_uge_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_uge_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 1 => not constant + %1 = arith.cmpi uge, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + + + +///===------------------------------------===// +/// Tests of `ult` (unsigned less than) +///===------------------------------------===// + +// CHECK-LABEL: @check_ult_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ult_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_2_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ult_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ult_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ult_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_2_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ult_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ult_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ult_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ult, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `ule` (unsigned less than or equal) +///===------------------------------------===// + +// CHECK-LABEL: @check_ule_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ule_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_2_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ule_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ule_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ule_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_2_rhs +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ule_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ule_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ule_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `eq` (equal) +///===------------------------------------===// + +// CHECK-LABEL: @check_eq_constant_3 +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_eq_constant_3() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi eq, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_eq_constant_2 +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_eq_constant_2() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi eq, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `ne` (not equal) +///===------------------------------------===// + +// CHECK-LABEL: @check_ne_constant_3 +// CHECK: %[[CST:.*]] = arith.constant dense : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @check_ne_constant_3() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ne, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @check_ne_constant_2 +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @check_ne_constant_2() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ne, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} +