Skip to content

Commit

Permalink
[mlir][linalg] linalg.tiled_loop peeling
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D108270
  • Loading branch information
matthias-springer committed Sep 7, 2021
1 parent da3ef8b commit c57c4f8
Show file tree
Hide file tree
Showing 6 changed files with 455 additions and 15 deletions.
25 changes: 25 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Expand Up @@ -1036,6 +1036,31 @@ class ConvOpVectorization : public OpRewritePattern<ConvOp> {
PatternRewriter &rewriter) const override;
};

/// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly
/// into a TiledLoopOp where the step divides the iteration space evenly,
/// followed by another TiledLoopOp for the last (partial) iteration (if any).
/// This transformation is called "loop peeling".
///
/// This function peels the `idx`-th loop of the TiledLoopOp. To peel all loops
/// in the loop nest, this function must be called multiple times.
///
/// After loop peeling, this function tries to simplify/canonicalize affine.min
/// and affine.max ops in the body of the two TiledLoopOps. For more details,
/// refer to `mlir::scf::peelAndCanonicalizeForLoop`.
///
/// The return value indicates whether the loop was rewritten or not. Loops are
/// not rewritten if:
/// * `idx` is negative or not smaller than the number of loops, or
/// * The `idx`-th iterator is not a parallel iterator, or
/// * Loop step size is 1 or
/// * Loop bounds and step size are static, and step already divides the
/// iteration space evenly.
///
/// On success, `result` is set to the TiledLoopOp that executes the last
/// (partial) iteration. `result` is not written if the loop was not rewritten.
///
/// Note: This function rewrites the given TiledLoopOp in-place and clones the
/// TiledLoopOp operation for the last iteration. It replaces all uses of the
/// unpeeled TiledLoopOp with the results of the newly generated TiledLoopOp.
LogicalResult peelAndCanonicalizeTiledLoop(RewriterBase &rewriter,
TiledLoopOp loopOp, int64_t idx,
TiledLoopOp &result);

//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms.h
Expand Up @@ -111,6 +111,24 @@ void naivelyFuseParallelOps(Region &region);
LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp,
scf::IfOp &ifOp);

/// Try to simplify a min/max operation `op` after loop peeling. This function
/// can simplify min/max operations such as (ub is the previous upper bound of
/// the unpeeled loop):
/// ```
/// #map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
/// %r = affine.min #map(%iv)[%step, %ub]
/// ```
/// and rewrites them into (in the case of the peeled loop):
/// ```
/// %r = %step
/// ```
/// min/max operations inside the partial iteration are rewritten in a similar
/// way.
///
/// `map` and `operands` are the affine map and the operands of `op`; `isMin`
/// selects between min and max semantics. `ub` is the upper bound of the loop
/// before peeling. `insideLoop` must be true for min/max ops inside the peeled
/// loop and false for min/max ops inside the partial iteration.
///
/// Returns `success` if the given operation was replaced by a new operation;
/// `failure` otherwise.
LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op,
AffineMap map, ValueRange operands,
bool isMin, Value iv, Value ub, Value step,
bool insideLoop);

/// Tile a parallel loop of the form
/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
/// step (%arg4, %arg5)
Expand Down
114 changes: 114 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
Expand Down Expand Up @@ -633,6 +634,119 @@ struct LowerTiledLoopsToSCF
};
} // namespace

/// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly
/// into two TiledLoopOps: One where the step divides the iteration space
/// evenly, followed by another one for the last (partial) iteration (if any).
/// This function only rewrites the `idx`-th loop of the loop nest represented
/// by the TiledLoopOp. To peel the entire loop nest, this function must be
/// called multiple times.
///
/// This function rewrites the given TiledLoopOp in-place and creates a new
/// TiledLoopOp for the last iteration. It replaces all uses of the original
/// TiledLoopOp with the results of the newly generated one.
///
/// The newly generated TiledLoopOp is returned via `result`. The boundary
/// at which the loop is split (new upper bound) is returned via `splitBound`.
/// The return value indicates whether the TiledLoopOp was rewritten or not.
///
/// `splitBound` is assigned before the divisibility check, so it is written
/// even when this function returns failure. `result` is only written on
/// success.
static LogicalResult peelTiledLoop(RewriterBase &b, TiledLoopOp loopOp,
int64_t idx, TiledLoopOp &result,
Value &splitBound) {
// Bounds and step of the loop dimension that is being peeled.
Value lb = loopOp.lowerBound()[idx], ub = loopOp.upperBound()[idx],
step = loopOp.step()[idx];
// Constant value of `ub`, if it has one; used for the static check below.
auto ubInt = getConstantIntValue(ub);

auto loc = loopOp.getLoc();
// The map has no dims and three symbols, bound in the same order as the
// operand list below: lb, ub, step.
AffineExpr exprLb, exprUb, exprStep;
bindSymbols(b.getContext(), exprLb, exprUb, exprStep);
// New upper bound: %ub - (%ub - %lb) mod %step
auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)});
SmallVector<Value> operands{lb, ub, step};
mlir::canonicalizeMapAndOperands(&modMap, &operands);
modMap = mlir::simplifyAffineMap(modMap);
// Compute the split bound right before the loop that is being peeled.
RewriterBase::InsertionGuard guard(b);
b.setInsertionPoint(loopOp);
splitBound = b.createOrFold<AffineApplyOp>(loc, modMap, operands);
// No specialization necessary if step already divides upper bound evenly.
// This is the case if the split bound is the very same value as `ub`, or if
// both are constants with the same value.
// NOTE(review): the split bound computation may already have been inserted
// into the IR when failure is returned here; confirm callers tolerate that.
if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound)))
return failure();

// Create remainder loop.
b.setInsertionPointAfter(loopOp);
auto remainderLoop = cast<TiledLoopOp>(b.clone(*loopOp.getOperation()));
// Redirect users of the original loop to the remainder loop. This is done
// before the remainder loop is wired to the main loop's results below, so
// that those new uses are not redirected as well.
loopOp.replaceAllUsesWith(remainderLoop->getResults());
// Outputs: Take tensors from main loop's results. Take memrefs from main
// loop's outputs.
SmallVector<Value> remainderOutputs;
// `o` indexes the outputs; `t` indexes the results and only advances for
// outputs that are not memrefs.
for (unsigned o = 0, t = 0; o < loopOp.getNumOutputs(); ++o) {
remainderOutputs.push_back(loopOp.outputs()[o].getType().isa<MemRefType>()
? loopOp.outputs()[o]
: loopOp->getResult(t++));
}
remainderLoop.outputsMutable().assign(remainderOutputs);

// Set new loop bounds.
// The main loop now stops at the split bound. It is modified through the
// rewriter because it is a pre-existing op.
b.updateRootInPlace(loopOp, [&]() {
SmallVector<Value> ubs = loopOp.upperBound();
ubs[idx] = splitBound;
loopOp.upperBoundMutable().assign(ubs);
});
// The remainder loop starts at the split bound and keeps the original upper
// bound.
SmallVector<Value> lbs = remainderLoop.lowerBound();
lbs[idx] = splitBound;
remainderLoop.lowerBoundMutable().assign(lbs);

result = remainderLoop;
return success();
}

/// Try to simplify every `OpTy` op nested in the two loops produced by
/// peeling. `IsMin` tells the rewrite whether `OpTy` has min or max semantics.
/// Ops in `mainLoop` are rewritten relative to `mainIv` as "inside the loop";
/// ops in `remainderLoop` are rewritten relative to `remainderIv` as "not
/// inside the loop". `ub` and `step` are forwarded unchanged to the rewrite.
/// Ops that cannot be simplified are left as they are.
template <typename OpTy, bool IsMin>
static void
rewriteAffineOpAfterPeeling(RewriterBase &rewriter, TiledLoopOp mainLoop,
                            TiledLoopOp remainderLoop, Value mainIv,
                            Value remainderIv, Value ub, Value step) {
  // Visits all `OpTy` ops inside `loop` and attempts the rewrite on each one.
  auto simplifyWithin = [&](TiledLoopOp loop, Value iv, bool insideLoop) {
    loop.walk([&](OpTy minMaxOp) {
      // Read the map before the rewrite gets a chance to replace the op.
      AffineMap affineMap = minMaxOp.getAffineMap();
      // A failed rewrite just means "nothing to simplify"; ignore the result.
      (void)scf::rewritePeeledMinMaxOp(rewriter, minMaxOp, affineMap,
                                       minMaxOp.operands(), IsMin, iv, ub,
                                       step, insideLoop);
    });
  };

  simplifyWithin(mainLoop, mainIv, /*insideLoop=*/true);
  simplifyWithin(remainderLoop, remainderIv, /*insideLoop=*/false);
}

/// Peel the `idx`-th loop of `loopOp` and then try to simplify the affine.min
/// and affine.max ops inside both the main loop and the remainder loop.
/// Returns failure without writing `result` if `idx` is out of range, if the
/// `idx`-th iterator is not parallel, or if peeling itself fails. On success,
/// `result` holds the remainder loop.
LogicalResult mlir::linalg::peelAndCanonicalizeTiledLoop(RewriterBase &rewriter,
                                                         TiledLoopOp loopOp,
                                                         int64_t idx,
                                                         TiledLoopOp &result) {
  // `idx` must designate one of the loops of the nest.
  int64_t loopCount = loopOp.iterator_types().size();
  bool idxInRange = 0 <= idx && idx < loopCount;
  if (!idxInRange)
    return failure();
  // Only parallel iterator supported.
  if (!isParallelIterator(loopOp.iterator_types()[idx]))
    return failure();

  // Capture the upper bound now: peeling overwrites the main loop's upper
  // bound with the split bound, but the simplification below needs the
  // original one.
  Value originalUb = loopOp.upperBound()[idx];
  Value splitBound;
  TiledLoopOp remainderLoop;
  if (failed(peelTiledLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
    return failure();

  // Rewrite affine.min and affine.max ops.
  Value step = loopOp.step()[idx];
  Value mainIv = loopOp.getInductionVars()[idx];
  Value remainderIv = remainderLoop.getInductionVars()[idx];

  rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
      rewriter, loopOp, remainderLoop, mainIv, remainderIv, originalUb, step);
  rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
      rewriter, loopOp, remainderLoop, mainIv, remainderIv, originalUb, step);

  result = remainderLoop;
  return success();
}

/// Add `TiledLoopToSCFPattern` to `patterns`, constructed with the context of
/// the pattern set.
void mlir::linalg::populateTiledLoopToSCFPattern(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<TiledLoopToSCFPattern>(context);
}
Expand Down
32 changes: 17 additions & 15 deletions mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
Expand Up @@ -324,25 +324,25 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
/// ```
/// %r = %step
/// ```
/// min/max operations inside the generated scf.if operation are rewritten in
/// a similar way.
/// min/max operations inside the partial iteration are rewritten in a similar
/// way.
///
/// This function builds up a set of constraints, capable of proving that:
/// * Inside the peeled loop: min(step, ub - iv) == step
/// * Inside the scf.if operation: min(step, ub - iv) == ub - iv
/// * Inside the partial iteration: min(step, ub - iv) == ub - iv
///
/// Returns `success` if the given operation was replaced by a new operation;
/// `failure` otherwise.
///
/// Note: `ub` is the previous upper bound of the loop (before peeling).
/// `insideLoop` must be true for min/max ops inside the loop and false for
/// affine.min ops inside the scf.for op. For an explanation of the other
/// affine.min ops inside the partial iteration. For an explanation of the other
/// parameters, see comment of `canonicalizeMinMaxOpInLoop`.
static LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter,
Operation *op, AffineMap map,
ValueRange operands, bool isMin,
Value iv, Value ub, Value step,
bool insideLoop) {
LogicalResult mlir::scf::rewritePeeledMinMaxOp(RewriterBase &rewriter,
Operation *op, AffineMap map,
ValueRange operands, bool isMin,
Value iv, Value ub, Value step,
bool insideLoop) {
FlatAffineValueConstraints constraints;
constraints.appendDimId({iv, ub, step});
if (auto constUb = getConstantIntValue(ub))
Expand Down Expand Up @@ -374,14 +374,16 @@ static void
rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp, scf::IfOp ifOp,
Value iv, Value splitBound, Value ub, Value step) {
forOp.walk([&](OpTy affineOp) {
(void)rewritePeeledMinMaxOp(rewriter, affineOp, affineOp.getAffineMap(),
affineOp.operands(), IsMin, iv, ub, step,
/*insideLoop=*/true);
AffineMap map = affineOp.getAffineMap();
(void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
affineOp.operands(), IsMin, iv, ub, step,
/*insideLoop=*/true);
});
ifOp.walk([&](OpTy affineOp) {
(void)rewritePeeledMinMaxOp(rewriter, affineOp, affineOp.getAffineMap(),
affineOp.operands(), IsMin, splitBound, ub,
step, /*insideLoop=*/false);
AffineMap map = affineOp.getAffineMap();
(void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
affineOp.operands(), IsMin, splitBound, ub,
step, /*insideLoop=*/false);
});
}

Expand Down

0 comments on commit c57c4f8

Please sign in to comment.