Skip to content

Commit

Permalink
[mlir][scf] Match any constants instead of arith.constant
Browse files Browse the repository at this point in the history
By matching `arith.constant` specifically, SCF canonicalizers/folders
are incompatible with other kinds of constants. Use the generic
matchers instead.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D135517
  • Loading branch information
Mogball committed Oct 13, 2022
1 parent 86e9181 commit 6005a1d
Showing 1 changed file with 25 additions and 28 deletions.
53 changes: 25 additions & 28 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
}

LogicalResult ForOp::verify() {
if (auto cst = getStep().getDefiningOp<arith::ConstantIndexOp>())
if (cst.value() <= 0)
return emitOpError("constant step operand must be positive");
IntegerAttr step;
if (matchPattern(getStep(), m_Constant(&step)) && step.getInt() <= 0)
return emitOpError("constant step operand must be positive");

auto opNumResults = getNumResults();
if (opNumResults == 0)
Expand Down Expand Up @@ -719,11 +719,10 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
/// Returns llvm::None when the difference between two AffineValueMap is
/// dynamic.
static Optional<int64_t> computeConstDiff(Value l, Value u) {
auto clb = l.getDefiningOp<arith::ConstantOp>();
auto cub = u.getDefiningOp<arith::ConstantOp>();
if (cub && clb) {
llvm::APInt lbValue = clb.getValue().cast<IntegerAttr>().getValue();
llvm::APInt ubValue = cub.getValue().cast<IntegerAttr>().getValue();
IntegerAttr clb, cub;
if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
llvm::APInt lbValue = clb.getValue();
llvm::APInt ubValue = cub.getValue();
return (ubValue - lbValue).getSExtValue();
}

Expand Down Expand Up @@ -763,13 +762,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
return success();
}

auto step = op.getStep().getDefiningOp<arith::ConstantOp>();
if (!step)
IntegerAttr step;
if (!matchPattern(op.getStep(), m_Constant(&step)))
return failure();

// If the loop is known to have 1 iteration, inline its body and remove the
// loop.
llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
llvm::APInt stepValue = step.getValue();
if (stepValue.sge(*diff)) {
SmallVector<Value, 4> blockArgs;
blockArgs.reserve(op.getNumIterOperands() + 1);
Expand Down Expand Up @@ -1674,11 +1673,11 @@ struct RemoveStaticCondition : public OpRewritePattern<IfOp> {

LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
if (!constant)
BoolAttr condition;
if (!matchPattern(op.getCondition(), m_Constant(&condition)))
return failure();

if (constant.getValue().cast<BoolAttr>().getValue())
if (condition.getValue())
replaceOpWithRegion(rewriter, op, op.getThenRegion());
else if (!op.getElseRegion().empty())
replaceOpWithRegion(rewriter, op, op.getElseRegion());
Expand Down Expand Up @@ -1777,7 +1776,7 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
PatternRewriter &rewriter) const override {
// Early exit if the condition is constant since replacing a constant
// in the body with another constant isn't a simplification.
if (op.getCondition().getDefiningOp<arith::ConstantOp>())
if (matchPattern(op.getCondition(), m_Constant()))
return failure();

bool changed = false;
Expand Down Expand Up @@ -1881,25 +1880,23 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
continue;
}

auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
if (!trueYield)
BoolAttr trueYield, falseYield;
if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
!matchPattern(falseResult, m_Constant(&falseYield)))
continue;

if (!trueYield.getType().isInteger(1))
continue;

auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
if (!falseYield)
continue;

bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
bool trueVal = trueYield.getValue();
bool falseVal = falseYield.getValue();
if (!trueVal && falseVal) {
if (!opResult.use_empty()) {
Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
Value notCond = rewriter.create<arith::XOrIOp>(
op.getLoc(), op.getCondition(),
rewriter.create<arith::ConstantOp>(
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
constDialect
->materializeConstant(rewriter,
rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
op.getLoc())
->getResult(0));
opResult.replaceAllUsesWith(notCond);
changed = true;
}
Expand Down

0 comments on commit 6005a1d

Please sign in to comment.