-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][Vector] Fold vector.step compared to constant #161615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7524,6 +7524,103 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, | |||||
setResultRanges(getResult(), result); | ||||||
} | ||||||
|
||||||
namespace { | ||||||
|
||||||
/// Fold `vector.step -> arith.cmpi` when the step value is compared to a | ||||||
/// constant large enough such that the result is the same at all indices. | ||||||
/// | ||||||
/// For example, rewrite the 'greater than' comparison below, | ||||||
/// | ||||||
/// %cst = arith.constant dense<7> : vector<3xindex> | ||||||
/// %stp = vector.step : vector<3xindex> | ||||||
/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex> | ||||||
/// | ||||||
/// as, | ||||||
/// | ||||||
/// %out = arith.constant dense<false> : vector<3xi1>. | ||||||
/// | ||||||
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is | ||||||
/// false at ALL indices we fold. If the constant was 1, then | ||||||
/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively | ||||||
/// preferring the 'compact' vector.step representation. | ||||||
struct StepCompareFolder : public OpRewritePattern<StepOp> { | ||||||
using OpRewritePattern::OpRewritePattern; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
LogicalResult matchAndRewrite(StepOp stepOp, | ||||||
PatternRewriter &rewriter) const override { | ||||||
const int64_t stepSize = stepOp.getResult().getType().getNumElements(); | ||||||
|
||||||
for (auto &use : stepOp.getResult().getUses()) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: spell out the type since it's not immediately obvious based on the RHS |
||||||
auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner()); | ||||||
if (!cmpiOp) | ||||||
continue; | ||||||
|
||||||
// arith.cmpi canonicalizer makes constants final operands. | ||||||
const unsigned stepOperandNumber = use.getOperandNumber(); | ||||||
if (stepOperandNumber != 0) | ||||||
continue; | ||||||
|
||||||
// Check that operand 1 is a constant. | ||||||
unsigned constOperandNumber = 1; | ||||||
Value otherOperand = cmpiOp.getOperand(constOperandNumber); | ||||||
auto maybeConstValue = getConstantIntValue(otherOperand); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: spell out the type |
||||||
if (!maybeConstValue.has_value()) | ||||||
continue; | ||||||
|
||||||
int64_t constValue = maybeConstValue.value(); | ||||||
arith::CmpIPredicate pred = cmpiOp.getPredicate(); | ||||||
|
||||||
auto maybeSplat = [&]() -> std::optional<bool> { | ||||||
// Handle ult (unsigned less than) and uge (unsigned greater equal). | ||||||
if ((pred == arith::CmpIPredicate::ult || | ||||||
pred == arith::CmpIPredicate::uge) && | ||||||
stepSize <= constValue) | ||||||
return pred == arith::CmpIPredicate::ult; | ||||||
|
||||||
// Handle ule and ugt. | ||||||
if ((pred == arith::CmpIPredicate::ule || | ||||||
pred == arith::CmpIPredicate::ugt) && | ||||||
stepSize - 1 <= constValue) { | ||||||
return pred == arith::CmpIPredicate::ule; | ||||||
} | ||||||
|
||||||
// Handle eq and ne. | ||||||
if ((pred == arith::CmpIPredicate::eq || | ||||||
pred == arith::CmpIPredicate::ne) && | ||||||
stepSize <= constValue) | ||||||
return pred == arith::CmpIPredicate::ne; | ||||||
|
||||||
return std::nullopt; | ||||||
}(); | ||||||
|
||||||
if (!maybeSplat.has_value()) | ||||||
continue; | ||||||
|
||||||
rewriter.setInsertionPointAfter(cmpiOp); | ||||||
|
||||||
auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType()); | ||||||
if (!type) | ||||||
continue; | ||||||
|
||||||
DenseElementsAttr boolAttr = | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
DenseElementsAttr::get(type, maybeSplat.value()); | ||||||
Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(), | ||||||
type, boolAttr); | ||||||
|
||||||
rewriter.replaceOp(cmpiOp, splat); | ||||||
return success(); | ||||||
} | ||||||
|
||||||
return failure(); | ||||||
} | ||||||
}; | ||||||
} // namespace | ||||||
|
||||||
void StepOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||||||
MLIRContext *context) { | ||||||
results.add<StepCompareFolder>(context); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why isn't this a folder for |
||||||
} | ||||||
|
||||||
//===----------------------------------------------------------------------===// | ||||||
// Vector Masking Utilities | ||||||
//===----------------------------------------------------------------------===// | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit/optional: markdown renders nicely in some editors