Expand Up
@@ -120,10 +120,6 @@ struct TestLinalgTransforms
*this , " tile-sizes" ,
llvm::cl::desc (" Linalg tile sizes for test-tile-pattern" ),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
ListOption<unsigned > testTiledLoopPeeling{
*this , " test-tiled-loop-peeling" ,
llvm::cl::desc (" Test peeling of linalg.tiled_loop ops" ),
llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
Option<bool > skipPartial{
*this , " skip-partial" ,
llvm::cl::desc (" Skip loops inside partial iterations during peeling" ),
Expand Down
Expand Up
@@ -605,8 +601,7 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
llvm::StringSwitch<LinalgTilingLoopType>(loopType)
.Case (" for" , LinalgTilingLoopType::Loops)
.Case (" affine" , LinalgTilingLoopType::AffineLoops)
.Case (" parallel" , LinalgTilingLoopType::ParallelLoops)
.Case (" tiled_loop" , LinalgTilingLoopType::TiledLoops);
.Case (" parallel" , LinalgTilingLoopType::ParallelLoops);
auto linalgTilingOptions = linalg::LinalgTilingOptions ()
.setPeeledLoops (peeledLoops)
.setLoopType (type);
Expand All
@@ -626,76 +621,6 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
static constexpr char kPeeledLoopsLabel [] = " __peeled_loops__" ;
static constexpr char kPartialIterationLabel [] = " __partial_iteration__" ;
namespace {
// / Peel TiledLoopOps, i.e., split them into two loops: One loop where the
// / `idx`-th loop contains only "full" iterations and a second loop for the
// / remaining partial iteration (if any).
struct TiledLoopPeelingPattern : public OpRewritePattern <TiledLoopOp> {
TiledLoopPeelingPattern (MLIRContext *ctx, int64_t idx, bool skipPartial)
: OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
}
LogicalResult matchAndRewrite (TiledLoopOp loopOp,
PatternRewriter &rewriter) const override {
SmallVector<int64_t > peeledLoops;
if (loopOp->hasAttr (kPeeledLoopsLabel )) {
auto attr = loopOp->getAttr (kPeeledLoopsLabel ).cast <ArrayAttr>();
peeledLoops =
llvm::to_vector<4 >(llvm::map_range (attr, [](Attribute attr) {
return attr.cast <IntegerAttr>().getInt ();
}));
// Check if the loop was already peeled.
if (llvm::find (peeledLoops, idx) != peeledLoops.end ())
return failure ();
}
if (skipPartial && loopOp->hasAttr (kPartialIterationLabel ))
// No peeling of loop nests with a partial iteration.
return failure ();
if (static_cast <int64_t >(loopOp.iterator_types ().size ()) <= idx)
return failure ();
// Peel loop and canonicalize.
TiledLoopOp result;
if (failed (linalg::peelAndCanonicalizeTiledLoop (rewriter, loopOp, idx,
result)))
return failure ();
// Apply label, so that the same loop is not rewritten a second time.
peeledLoops.push_back (idx);
rewriter.updateRootInPlace (loopOp, [&]() {
loopOp->setAttr (kPeeledLoopsLabel , rewriter.getI64ArrayAttr (peeledLoops));
});
result->setAttr (kPeeledLoopsLabel , rewriter.getI64ArrayAttr (peeledLoops));
result->setAttr (kPartialIterationLabel , rewriter.getUnitAttr ());
return success ();
}
// / Index of loop to peel.
int64_t idx;
// / If set to true, do not peel TiledLoopOps with a partial iteration.
bool skipPartial;
};
} // namespace
static void applyTiledLoopPeelingPattern (FuncOp funcOp,
ArrayRef<unsigned > loops,
bool skipPartial) {
MLIRContext *ctx = funcOp.getContext ();
RewritePatternSet patterns (ctx);
for (unsigned idx : loops)
patterns.add <TiledLoopPeelingPattern>(ctx, idx, skipPartial);
(void )applyPatternsAndFoldGreedily (funcOp, std::move (patterns));
// Drop the markers.
funcOp.walk ([](TiledLoopOp op) {
op->removeAttr (kPeeledLoopsLabel );
op->removeAttr (kPartialIterationLabel );
});
}
// / Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation () {
auto lambda = [&](void *) {
Expand Down
Expand Up
@@ -739,9 +664,6 @@ void TestLinalgTransforms::runOnOperation() {
return applyGeneralizePadTensorPatterns (getOperation ());
if (testSwapSubTensorPadTensor)
return applyExtractSliceOfPadTensorSwapPattern (getOperation ());
if (testTiledLoopPeeling.hasValue ())
return applyTiledLoopPeelingPattern (getOperation (), testTiledLoopPeeling,
skipPartial);
if (testTilePattern)
return applyTilePattern (getOperation (), loopType, tileSizes, peeledLoops,
/* scalarizeDynamicDims=*/ false );
Expand Down