Skip to content

Commit

Permalink
[mlir] Add a pattern to fold single- and zero-iteration scf.forall ops.
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D145368
  • Loading branch information
pifon2a committed Mar 21, 2023
1 parent 0165b73 commit 3a8f161
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 58 deletions.
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCF.h
Expand Up @@ -62,6 +62,14 @@ ForallOp getForallOpThreadIndexOwner(Value val);
// TODO: Consider moving this functionality to RegionBranchOpInterface.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);

/// Promotes the loop body of a scf::ForallOp to its containing block if the
/// loop is known to have a single iteration, i.e. every loop dimension has a
/// constant trip count of exactly one. Returns success if the loop was
/// promoted, and failure (without promoting anything) otherwise.
LogicalResult promoteIfSingleIteration(PatternRewriter &rewriter,
scf::ForallOp forallOp);

/// Promotes the loop body of a scf::ForallOp to its containing block. This is
/// unconditional: no trip-count check is performed here. Induction variables
/// are replaced by the loop's lower bounds, the output block arguments by the
/// loop's outputs, and the loop results by tensor.insert_slice ops built from
/// the terminator's parallel_insert_slice ops.
void promote(PatternRewriter &rewriter, scf::ForallOp forallOp);

/// An owning vector of values, handy to return from functions.
using ValueVector = SmallVector<Value>;
using LoopVector = SmallVector<scf::ForOp>;
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
Expand Up @@ -128,6 +128,11 @@ SmallVector<int64_t>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
llvm::function_ref<bool(Attribute, Attribute)> compare);

/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`. Returns std::nullopt if the trip count cannot
/// be determined statically from the given bounds and step.
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
OpFoldResult step);

} // namespace mlir

#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
212 changes: 165 additions & 47 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Expand Up @@ -534,6 +534,61 @@ void ForOp::getSuccessorRegions(std::optional<unsigned> index,
regions.push_back(RegionSuccessor(getResults()));
}

/// Inlines the body of `forallOp` into its containing block, provided that
/// every loop dimension has a statically known trip count of exactly one.
/// Returns failure, without promoting anything, as soon as one dimension does
/// not satisfy that condition.
LogicalResult mlir::scf::promoteIfSingleIteration(PatternRewriter &rewriter,
                                                  scf::ForallOp forallOp) {
  for (auto [lowerBound, upperBound, stride] :
       llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                 forallOp.getMixedStep())) {
    auto iterations = constantTripCount(lowerBound, upperBound, stride);
    // Only dimensions that provably execute once are acceptable.
    const bool runsExactlyOnce = iterations.has_value() && *iterations == 1;
    if (!runsExactlyOnce)
      return failure();
  }

  // All dimensions are single-iteration: the loop can be inlined.
  promote(rewriter, forallOp);
  return success();
}

/// Promotes the loop body of a scf::ForallOp to its containing block.
///
/// The body operations (without the terminator) are cloned right before the
/// loop, with the induction variables replaced by the loop's lower bounds and
/// the output block arguments replaced by the loop's outputs. Every
/// tensor.parallel_insert_slice in the terminator is turned into a
/// tensor.insert_slice whose result replaces the corresponding loop result.
///
/// No trip-count check is performed here; callers are responsible for making
/// sure that executing the body exactly once is valid.
void mlir::scf::promote(PatternRewriter &rewriter, scf::ForallOp forallOp) {
  // Materialize everything immediately before the loop, independently of
  // where the caller left the insertion point, and restore that point on exit.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);

  // Values defined by the loop's block arguments are rewired to values that
  // are available outside of the loop.
  IRMapping mapping;
  mapping.map(forallOp.getInductionVars(), forallOp.getLowerBound(rewriter));
  mapping.map(forallOp.getOutputBlockArguments(), forallOp.getOutputs());
  for (auto &bodyOp : forallOp.getBody()->without_terminator())
    rewriter.clone(bodyOp, mapping);

  // Translates values used inside the loop body to their promoted
  // counterparts; values defined outside of the loop map to themselves.
  auto getMappedValues = [&](ValueRange values) {
    return llvm::to_vector(llvm::map_range(
        values, [&](Value value) { return mapping.lookupOrDefault(value); }));
  };

  SmallVector<Value> results;
  results.reserve(forallOp.getResults().size());
  scf::InParallelOp terminator = forallOp.getTerminator();
  for (auto &yieldingOp : terminator.getYieldingOps()) {
    auto parallelInsertSliceOp =
        cast<tensor::ParallelInsertSliceOp>(yieldingOp);

    Value dst = parallelInsertSliceOp.getDest();
    Value src = parallelInsertSliceOp.getSource();

    Value srcVal = mapping.lookupOrDefault(src);
    // NOTE(review): a non-tensor source produces no replacement value here, so
    // `results` would end up shorter than the loop's result list; presumably
    // the source is always a tensor - confirm.
    if (srcVal.getType().isa<TensorType>()) {
      results.push_back(rewriter.create<tensor::InsertSliceOp>(
          forallOp.getLoc(), dst.getType(), srcVal,
          mapping.lookupOrDefault(dst),
          getMappedValues(parallelInsertSliceOp.getOffsets()),
          getMappedValues(parallelInsertSliceOp.getSizes()),
          getMappedValues(parallelInsertSliceOp.getStrides()),
          parallelInsertSliceOp.getStaticOffsets(),
          parallelInsertSliceOp.getStaticSizes(),
          parallelInsertSliceOp.getStaticStrides()));
    }
  }
  rewriter.replaceOp(forallOp, results);
}

LoopNest mlir::scf::buildLoopNest(
OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
ValueRange steps, ValueRange iterArgs,
Expand Down Expand Up @@ -1452,16 +1507,99 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
op.getDynamicStepMutable().assign(dynamicStep);
op.setStaticStep(staticStep);

op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
rewriter.getDenseI32ArrayAttr(
{static_cast<int32_t>(dynamicLowerBound.size()),
static_cast<int32_t>(dynamicUpperBound.size()),
static_cast<int32_t>(dynamicStep.size()),
static_cast<int32_t>(op.getNumResults())}));
});
return success();
}
};

/// Canonicalization pattern for scf.forall ops that are not mapped to
/// processing units. Based on the statically known trip count of each
/// dimension it
///   * replaces the whole loop by its outputs if some dimension performs zero
///     iterations,
///   * drops every dimension that performs exactly one iteration, substituting
///     the lower bound for its induction variable, and
///   * inlines the loop body via promote() if no dimension remains.
struct ForallOpSingleOrZeroIterationDimsFolder
    : public OpRewritePattern<ForallOp> {
  using OpRewritePattern<ForallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForallOp op,
                                PatternRewriter &rewriter) const override {
    // Dimensions that are mapped to processing units are never folded.
    if (op.getMapping().has_value())
      return failure();
    Location loc = op.getLoc();

    // Bounds of the dimensions that survive, together with the replacement
    // values for the induction variables of the dimensions that do not.
    SmallVector<OpFoldResult> keptLowerBounds, keptUpperBounds, keptSteps;
    IRMapping ivReplacements;
    for (auto [lb, ub, step, iv] :
         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                   op.getMixedStep(), op.getInductionVars())) {
      auto tripCount = constantTripCount(lb, ub, step);

      // A single empty dimension makes the whole loop a no-op.
      const bool isEmptyDim = tripCount.has_value() && *tripCount == 0;
      if (isEmptyDim) {
        rewriter.replaceOp(op, op.getOutputs());
        return success();
      }

      // A single-iteration dimension only ever sees its lower bound.
      const bool isUnitDim = tripCount.has_value() && *tripCount == 1;
      if (isUnitDim) {
        ivReplacements.map(iv,
                           getValueOrCreateConstantIndexOp(rewriter, loc, lb));
        continue;
      }

      // Unknown or larger trip count: the dimension is carried over as is.
      keptLowerBounds.push_back(lb);
      keptUpperBounds.push_back(ub);
      keptSteps.push_back(step);
    }

    // Nothing to do if every dimension was carried over.
    if (keptLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
      return rewriter.notifyMatchFailure(
          op, "no dimensions have 0 or 1 iterations");
    }

    // No dimension is left: the body runs exactly once and can be inlined.
    if (keptLowerBounds.empty()) {
      promote(rewriter, op);
      return success();
    }

    // Otherwise build a loop over the remaining dimensions only. The block
    // created by the builder is dropped; the body is cloned in further below.
    ForallOp newOp = rewriter.create<ForallOp>(
        loc, keptLowerBounds, keptUpperBounds, keptSteps, op.getOutputs(),
        std::nullopt, nullptr);
    newOp.getBodyRegion().getBlocks().clear();

    // Carry over all attributes of the old loop, except for the ones that
    // encode the old iteration domain ("operand_segment_sizes" and the static
    // bound/step attributes), which were already set for the new domain.
    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
                                        newOp.getStaticLowerBoundAttrName(),
                                        newOp.getStaticUpperBoundAttrName(),
                                        newOp.getStaticStepAttrName()};
    for (const auto &namedAttr : op->getAttrs()) {
      if (!llvm::is_contained(elidedAttrs, namedAttr.getName())) {
        rewriter.updateRootInPlace(newOp, [&]() {
          newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
        });
      }
    }

    // Clone the old body into the new loop, remapping the induction variables
    // of the dropped dimensions, and take over the results.
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), ivReplacements);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};

} // namespace

void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfForallOp, ForallOpControlOperandsFolder>(context);
results.add<DimOfForallOp, ForallOpControlOperandsFolder,
ForallOpSingleOrZeroIterationDimsFolder>(context);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2615,41 +2753,37 @@ ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {

namespace {
// Collapse loop dimensions that perform a single iteration.
struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
struct ParallelOpSingleOrZeroIterationDimsFolder
: public OpRewritePattern<ParallelOp> {
using OpRewritePattern<ParallelOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ParallelOp op,
PatternRewriter &rewriter) const override {
IRMapping mapping;
Location loc = op.getLoc();

// Compute new loop bounds that omit all single-iteration loop dimensions.
SmallVector<Value, 2> newLowerBounds;
SmallVector<Value, 2> newUpperBounds;
SmallVector<Value, 2> newSteps;
newLowerBounds.reserve(op.getLowerBound().size());
newUpperBounds.reserve(op.getUpperBound().size());
newSteps.reserve(op.getStep().size());
for (auto [lowerBound, upperBound, step, iv] :
SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
IRMapping mapping;
for (auto [lb, ub, step, iv] :
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
op.getInductionVars())) {
// Collect the statically known loop bounds.
auto lowerBoundConstant =
dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
auto upperBoundConstant =
dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
auto stepConstant =
dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
// Replace the loop induction variable by the lower bound if the loop
// performs a single iteration. Otherwise, copy the loop bounds.
if (lowerBoundConstant && upperBoundConstant && stepConstant &&
(upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
(upperBoundConstant.value() - lowerBoundConstant.value()) <=
stepConstant.value()) {
mapping.map(iv, lowerBound);
} else {
newLowerBounds.push_back(lowerBound);
newUpperBounds.push_back(upperBound);
newSteps.push_back(step);
auto numIterations = constantTripCount(lb, ub, step);
if (numIterations.has_value()) {
// Remove the loop if it performs zero iterations.
if (*numIterations == 0) {
rewriter.replaceOp(op, op.getInitVals());
return success();
}
// Replace the loop induction variable by the lower bound if the loop
// performs a single iteration. Otherwise, copy the loop bounds.
if (*numIterations == 1) {
mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
continue;
}
}
newLowerBounds.push_back(lb);
newUpperBounds.push_back(ub);
newSteps.push_back(step);
}
// Exit if none of the loop dimensions perform a single iteration.
if (newLowerBounds.size() == op.getLowerBound().size())
Expand Down Expand Up @@ -2694,23 +2828,6 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
}
};

/// Folds away parallel loops with an empty iteration domain: if the lower and
/// the upper bound of any dimension are the same value, the loop is replaced
/// by its init values.
struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    for (auto [lowerBound, upperBound] :
         llvm::zip(op.getLowerBound(), op.getUpperBound())) {
      // Dimensions with distinct bound values are not provably empty.
      if (lowerBound != upperBound)
        continue;
      rewriter.replaceOp(op, op.getInitVals());
      return success();
    }
    return failure();
  }
};

struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
using OpRewritePattern<ParallelOp>::OpRewritePattern;

Expand Down Expand Up @@ -2773,8 +2890,9 @@ struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {

void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
MergeNestedParallelLoops>(context);
results
.add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
context);
}

//===----------------------------------------------------------------------===//
Expand Down
14 changes: 4 additions & 10 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Expand Up @@ -381,18 +381,12 @@ static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
ubCstOp.value() < 0 || stepCstOp.value() < 0)
return failure();
int64_t tripCount =
mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
if (tripCount != 1)
std::optional<int64_t> tripCount = constantTripCount(
forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
if (!tripCount.has_value() || tripCount != 1)
return failure();
auto iv = forOp.getInductionVar();
iv.replaceAllUsesWith(lbCstOp);
iv.replaceAllUsesWith(forOp.getLowerBound());

replaceIterArgsAndYieldResults(forOp);

Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/APSInt.h"

namespace mlir {
Expand Down Expand Up @@ -228,4 +229,24 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
return getValuesSortedByKeyImpl(keys, values, compare);
}

/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
///
/// Returns 0 if `lb` and `ub` compare equal as OpFoldResults (no constant
/// values are required in that case). Returns std::nullopt if any of the three
/// operands is not a known constant integer, or if the constant step is not
/// strictly positive. Otherwise returns ceilDiv(ub - lb, step); the result is
/// not clamped, so callers should not assume that it is non-negative.
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
                                         OpFoldResult step) {
  if (lb == ub)
    return 0;

  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
  if (!lbConstant)
    return std::nullopt;
  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
  if (!ubConstant)
    return std::nullopt;
  std::optional<int64_t> stepConstant = getConstantIntValue(step);
  // A zero step would divide by zero below, and mlir::ceilDiv expects a
  // positive divisor; treat such loops as having an unknown trip count.
  if (!stepConstant || *stepConstant <= 0)
    return std::nullopt;

  return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
}

} // namespace mlir

0 comments on commit 3a8f161

Please sign in to comment.