Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ def AffineForOp : Affine_Op<"for",
Speculation::Speculatability getSpeculatability();
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasRegionVerifier = 1;
Expand Down
171 changes: 80 additions & 91 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2460,6 +2460,65 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) {
return success(folded);
}

/// Returns constant trip count in trivial cases.
static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
int64_t step = forOp.getStepAsInt();
if (!forOp.hasConstantBounds() || step <= 0)
return std::nullopt;
int64_t lb = forOp.getConstantLowerBound();
int64_t ub = forOp.getConstantUpperBound();
return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
}

/// Fold the empty loop.
static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) {
if (!llvm::hasSingleElement(*forOp.getBody()))
return {};
if (forOp.getNumResults() == 0)
return {};
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
if (tripCount == 0) {
// The initial values of the iteration arguments would be the op's
// results.
return forOp.getInits();
}
SmallVector<Value, 4> replacements;
auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
auto iterArgs = forOp.getRegionIterArgs();
bool hasValDefinedOutsideLoop = false;
bool iterArgsNotInOrder = false;
for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
Value val = yieldOp.getOperand(i);
BlockArgument *iterArgIt = llvm::find(iterArgs, val);
// TODO: It should be possible to perform a replacement by computing the
// last value of the IV based on the bounds and the step.
if (val == forOp.getInductionVar())
return {};
if (iterArgIt == iterArgs.end()) {
// `val` is defined outside of the loop.
assert(forOp.isDefinedOutsideOfLoop(val) &&
"must be defined outside of the loop");
hasValDefinedOutsideLoop = true;
replacements.push_back(val);
} else {
unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
if (pos != i)
iterArgsNotInOrder = true;
replacements.push_back(forOp.getInits()[pos]);
}
}
// Bail out when the trip count is unknown and the loop returns any value
// defined outside of the loop or any iterArg out of order.
if (!tripCount.has_value() &&
(hasValDefinedOutsideLoop || iterArgsNotInOrder))
return {};
// Bail out when the loop iterates more than once and it returns any iterArg
// out of order.
if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
return {};
return llvm::to_vector_of<OpFoldResult>(replacements);
}

/// Canonicalize the bounds of the given loop.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
Expand Down Expand Up @@ -2491,79 +2550,30 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
return success();
}

namespace {
/// Returns constant trip count in trivial cases.
static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
int64_t step = forOp.getStepAsInt();
if (!forOp.hasConstantBounds() || step <= 0)
return std::nullopt;
int64_t lb = forOp.getConstantLowerBound();
int64_t ub = forOp.getConstantUpperBound();
return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
/// Returns true if the affine.for has zero iterations in trivial cases.
static bool hasTrivialZeroTripCount(AffineForOp op) {
return getTrivialConstantTripCount(op) == 0;
}

/// This is a pattern to fold trivially empty loop bodies.
/// TODO: This should be moved into the folding hook.
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
using OpRewritePattern<AffineForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(AffineForOp forOp,
PatternRewriter &rewriter) const override {
// Check that the body only contains a yield.
if (!llvm::hasSingleElement(*forOp.getBody()))
return failure();
if (forOp.getNumResults() == 0)
return success();
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
if (tripCount == 0) {
// The initial values of the iteration arguments would be the op's
// results.
rewriter.replaceOp(forOp, forOp.getInits());
return success();
}
SmallVector<Value, 4> replacements;
auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
auto iterArgs = forOp.getRegionIterArgs();
bool hasValDefinedOutsideLoop = false;
bool iterArgsNotInOrder = false;
for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
Value val = yieldOp.getOperand(i);
auto *iterArgIt = llvm::find(iterArgs, val);
// TODO: It should be possible to perform a replacement by computing the
// last value of the IV based on the bounds and the step.
if (val == forOp.getInductionVar())
return failure();
if (iterArgIt == iterArgs.end()) {
// `val` is defined outside of the loop.
assert(forOp.isDefinedOutsideOfLoop(val) &&
"must be defined outside of the loop");
hasValDefinedOutsideLoop = true;
replacements.push_back(val);
} else {
unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
if (pos != i)
iterArgsNotInOrder = true;
replacements.push_back(forOp.getInits()[pos]);
}
}
// Bail out when the trip count is unknown and the loop returns any value
// defined outside of the loop or any iterArg out of order.
if (!tripCount.has_value() &&
(hasValDefinedOutsideLoop || iterArgsNotInOrder))
return failure();
// Bail out when the loop iterates more than once and it returns any iterArg
// out of order.
if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
return failure();
rewriter.replaceOp(forOp, replacements);
return success();
LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
bool folded = succeeded(foldLoopBounds(*this));
folded |= succeeded(canonicalizeLoopBounds(*this));
if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
// The initial values of the loop-carried variables (iter_args) are the
// results of the op. But this must be avoided for an affine.for op that
// does not return any results. Since ops that do not return results cannot
// be folded away, we would enter an infinite loop of folds on the same
// affine.for op.
results.assign(getInits().begin(), getInits().end());
folded = true;
}
};
} // namespace

void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<AffineForEmptyLoopFolder>(context);
SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder(*this);
if (!foldResults.empty()) {
results.assign(foldResults);
folded = true;
}
return success(folded);
}

OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
Expand Down Expand Up @@ -2606,27 +2616,6 @@ void AffineForOp::getSuccessorRegions(
regions.push_back(RegionSuccessor(getResults()));
}

/// Returns true if the affine.for has zero iterations in trivial cases.
static bool hasTrivialZeroTripCount(AffineForOp op) {
return getTrivialConstantTripCount(op) == 0;
}

LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
bool folded = succeeded(foldLoopBounds(*this));
folded |= succeeded(canonicalizeLoopBounds(*this));
if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
// The initial values of the loop-carried variables (iter_args) are the
// results of the op. But this must be avoided for an affine.for op that
// does not return any results. Since ops that do not return results cannot
// be folded away, we would enter an infinite loop of folds on the same
// affine.for op.
results.assign(getInits().begin(), getInits().end());
folded = true;
}
return success(folded);
}

AffineBound AffineForOp::getLowerBound() {
return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
}
Expand Down