Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3045,6 +3045,7 @@ def Vector_StepOp : Vector_Op<"step", [
}];
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
let assemblyFormat = "attr-dict `:` type($result)";
let hasCanonicalizer = 1;
}

def Vector_YieldOp : Vector_Op<"yield", [
Expand Down
97 changes: 97 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7524,6 +7524,103 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}

namespace {

/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
/// constant large enough such that the result is the same at all indices.
///
/// For example, rewrite the 'greater than' comparison below,
///
/// %cst = arith.constant dense<7> : vector<3xindex>
/// %stp = vector.step : vector<3xindex>
/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
///
/// as,
///
/// %out = arith.constant dense<false> : vector<3xi1>.
///
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
/// false at ALL indices we fold. If the constant was 1, then
/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not fold,
/// conservatively preferring the 'compact' vector.step representation.
Comment on lines +7533 to +7545
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit/optional: markdown renders nicely in some editors

Suggested change
///
/// %cst = arith.constant dense<7> : vector<3xindex>
/// %stp = vector.step : vector<3xindex>
/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
///
/// as,
///
/// %out = arith.constant dense<false> : vector<3xi1>.
///
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
/// false at ALL indices we fold. If the constant was 1, then
/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not fold,
/// conservatively preferring the 'compact' vector.step representation.
/// ```mlir
/// %cst = arith.constant dense<7> : vector<3xindex>
/// %stp = vector.step : vector<3xindex>
/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
/// ```
///
/// as,
/// ```mlir
/// %out = arith.constant dense<false> : vector<3xi1>.
/// ```
///
/// Above `[0, 1, 2] > [7, 7, 7] => [false, false, false]`. Because the result is
/// false at ALL indices we fold. If the constant was 1, then
/// `[0, 1, 2] > [1, 1, 1] => [false, false, true]` and we do not fold,
/// conservatively preferring the 'compact' vector.step representation.

struct StepCompareFolder : public OpRewritePattern<StepOp> {
using OpRewritePattern::OpRewritePattern;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using OpRewritePattern::OpRewritePattern;
using Base::Base;


LogicalResult matchAndRewrite(StepOp stepOp,
PatternRewriter &rewriter) const override {
const int64_t stepSize = stepOp.getResult().getType().getNumElements();

for (auto &use : stepOp.getResult().getUses()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: spell out the type since it's not immediately obvious based on the RHS

auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
if (!cmpiOp)
continue;

// arith.cmpi canonicalizer makes constants final operands.
const unsigned stepOperandNumber = use.getOperandNumber();
if (stepOperandNumber != 0)
continue;

// Check that operand 1 is a constant.
unsigned constOperandNumber = 1;
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
auto maybeConstValue = getConstantIntValue(otherOperand);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: spell out the type

if (!maybeConstValue.has_value())
continue;

int64_t constValue = maybeConstValue.value();
arith::CmpIPredicate pred = cmpiOp.getPredicate();

// Decide whether the comparison `step <pred> constValue` produces the same
// boolean for every element of the step vector. Returns that boolean when
// it does, and std::nullopt when the predicate is not handled or the
// constant is not large enough. The largest step element is stepSize - 1
// (stepSize is the element count of the step result).
// NOTE(review): stepSize and constValue are both int64_t, so the bound
// checks below are signed comparisons even though the predicates are
// unsigned -- confirm the intended behaviour for constants that are
// negative when read as signed.
auto maybeSplat = [&]() -> std::optional<bool> {
// Handle ult (unsigned less than) and uge (unsigned greater equal).
// When stepSize <= constValue every step element is below the constant,
// so ult is all-true and uge is all-false.
if ((pred == arith::CmpIPredicate::ult ||
pred == arith::CmpIPredicate::uge) &&
stepSize <= constValue)
return pred == arith::CmpIPredicate::ult;

// Handle ule (unsigned less equal) and ugt (unsigned greater than).
// The bound is one lower than above because equality with the largest
// step element (stepSize - 1) still satisfies ule and still fails ugt.
if ((pred == arith::CmpIPredicate::ule ||
pred == arith::CmpIPredicate::ugt) &&
stepSize - 1 <= constValue) {
return pred == arith::CmpIPredicate::ule;
}

// Handle eq and ne.
// When stepSize <= constValue no step element can equal the constant,
// so eq is all-false and ne is all-true.
if ((pred == arith::CmpIPredicate::eq ||
pred == arith::CmpIPredicate::ne) &&
stepSize <= constValue)
return pred == arith::CmpIPredicate::ne;

// Any other predicate, or a constant too small to decide every element.
return std::nullopt;
}();

if (!maybeSplat.has_value())
continue;

rewriter.setInsertionPointAfter(cmpiOp);

auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
if (!type)
continue;

DenseElementsAttr boolAttr =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DenseElementsAttr boolAttr =
auto boolAttr =

DenseElementsAttr::get(type, maybeSplat.value());
Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
type, boolAttr);

rewriter.replaceOp(cmpiOp, splat);
return success();
}

return failure();
}
};
} // namespace

void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<StepCompareFolder>(context);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't this a folder for cmpi instead?

}

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
Expand Down
Loading