Skip to content

Commit

Permalink
[mlir] Add peeling xform to Codegen Strategy
Browse files Browse the repository at this point in the history
This patch adds the knobs to use peeling in the codegen strategy
infrastructure.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D126842
  • Loading branch information
dcaballe committed Jun 3, 2022
1 parent 5ac2615 commit 9a79b1b
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 0 deletions.
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.h
Expand Up @@ -127,6 +127,13 @@ createLinalgStrategyInterchangePass(
const linalg::LinalgTransformationFilter &filter =
linalg::LinalgTransformationFilter());

/// Create a LinalgStrategyPeelPass.
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyPeelPass(
StringRef opName = "",
linalg::LinalgPeelOptions opt = linalg::LinalgPeelOptions(),
const linalg::LinalgTransformationFilter &filter =
linalg::LinalgTransformationFilter());

/// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyVectorizePass(
StringRef opName = "",
Expand Down
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Expand Up @@ -272,6 +272,22 @@ def LinalgStrategyInterchangePass
];
}

def LinalgStrategyPeelPass
: Pass<"linalg-strategy-peel-pass", "func::FuncOp"> {
let summary = "Configurable pass to apply pattern-based linalg peeling.";
let constructor = "mlir::createLinalgStrategyPeelPass()";
let dependentDialects = [
"linalg::LinalgDialect",
"scf::SCFDialect"
];
let options = [
Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
"Which func op is the anchor to latch on.">,
Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
"Which linalg op within the func is the anchor to latch on.">,
];
}

def LinalgStrategyVectorizePass
: Pass<"linalg-strategy-vectorize-pass", "func::FuncOp"> {
let summary = "Configurable pass to apply pattern-based linalg vectorization.";
Expand Down
34 changes: 34 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
Expand Up @@ -141,6 +141,26 @@ struct Decompose : public Transformation {
}
};

/// Represent one application of createLinalgStrategyPeelPass.
struct Peel : public Transformation {
explicit Peel(linalg::LinalgPeelOptions options,
LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(std::move(f)), opName(), options(options) {}

Peel(StringRef name, linalg::LinalgPeelOptions options,
LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(std::move(f)), opName(name), options(options) {}

void addToPassPipeline(OpPassManager &pm,
LinalgTransformationFilter m) const override {
pm.addPass(createLinalgStrategyPeelPass(opName, options, m));
}

private:
std::string opName;
linalg::LinalgPeelOptions options;
};

/// Represent one application of createLinalgStrategyVectorizePass.
struct Vectorize : public Transformation {
explicit Vectorize(linalg::LinalgVectorizationOptions options,
Expand Down Expand Up @@ -288,6 +308,20 @@ struct CodegenStrategy {
decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? decompose(std::move(f)) : *this;
}
/// Append a pattern to peel 'LinalgOpType'.
CodegenStrategy &
peel(StringRef opName, const LinalgPeelOptions &options,
const LinalgTransformationFilter::FilterFunction &f = nullptr) {
transformationSequence.emplace_back(
std::make_unique<Peel>(opName, options, f));
return *this;
}
/// Conditionally append a pattern to peel 'LinalgOpType'.
CodegenStrategy &
peelIf(bool b, StringRef opName, const LinalgPeelOptions &options,
LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? peel(opName, options, std::move(f)) : *this;
}
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
CodegenStrategy &
vectorize(StringRef opName,
Expand Down
46 changes: 46 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Expand Up @@ -129,6 +129,9 @@ struct TiledLinalgOp {
FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
const LinalgTilingOptions &options);

/// Peel and canonicalize 'loops'.
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);

/// Peel the loops of a TiledLinalgOp.
void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
ArrayRef<int64_t> peeledLoops,
Expand Down Expand Up @@ -965,6 +968,49 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
: LinalgBasePromotionPattern(opName, context, options, f, benefit) {}
};

///
/// Linalg peeling patterns.
///

/// Compute the loops to peel and return them in a SmallVector. Loops will be
/// peeled in order of appearance in the SmallVector. This order will impact the
/// output IR. If an inner-to-outer order is provided, the peeled iterations of
/// the outer loops will also contain the peeled inner loops. If an
/// outer-to-inner order is provided, the peeled iterations of the outer loops
/// will not contain any peeled inner loops.
using LoopsToPeelComputationFunction = std::function<void(
OpBuilder &, Operation *, SmallVectorImpl<scf::ForOp> &)>;

struct LinalgPeelOptions {
LoopsToPeelComputationFunction loopsToPeelComputationFunction = nullptr;
};

/// `filter` controls LinalgTransformMarker matching and update when specified.
struct LinalgPeelingPattern : public OpInterfaceRewritePattern<LinalgOp> {
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgPeelingPattern(
MLIRContext *context,
LinalgTransformationFilter f = LinalgTransformationFilter(),
LinalgPeelOptions options = LinalgPeelOptions(),
PatternBenefit benefit = 1);

/// Construct a pattern specifically applied to `opName`.
LinalgPeelingPattern(
StringRef opName, MLIRContext *context,
LinalgPeelOptions options = LinalgPeelOptions(),
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);

LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override;

private:
/// LinalgTransformMarker handles special attribute manipulations.
const LinalgTransformationFilter filter;
/// Peeling options.
const LinalgPeelOptions options;
};

///
/// Linalg vectorization patterns.
///
Expand Down
41 changes: 41 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
Expand Up @@ -262,6 +262,40 @@ struct LinalgStrategyPromotePass
LinalgTransformationFilter filter;
};

/// Configurable pass to apply pattern-based linalg peeling.
struct LinalgStrategyPeelPass
: public LinalgStrategyPeelPassBase<LinalgStrategyPeelPass> {

LinalgStrategyPeelPass() = default;

LinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt,
LinalgTransformationFilter filt)
: options(opt), filter(std::move(filt)) {
this->anchorOpName.setValue(opName.str());
}

void runOnOperation() override {
auto funcOp = getOperation();
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
return;

RewritePatternSet peelingPatterns(funcOp.getContext());
if (!anchorOpName.empty()) {
peelingPatterns.add<LinalgPeelingPattern>(
anchorOpName, funcOp.getContext(), options, filter);
} else {
peelingPatterns.add<LinalgPeelingPattern>(funcOp.getContext(), filter,
options);
}
if (failed(
applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns))))
return signalPassFailure();
}

LinalgPeelOptions options;
LinalgTransformationFilter filter;
};

/// Configurable pass to apply pattern-based linalg vectorization.
struct LinalgStrategyVectorizePass
: public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
Expand Down Expand Up @@ -506,6 +540,13 @@ mlir::createLinalgStrategyInterchangePass(
filter);
}

/// Create a LinalgStrategyPeelPass.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt,
const LinalgTransformationFilter &filter) {
return std::make_unique<LinalgStrategyPeelPass>(opName, opt, filter);
}

/// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLinalgStrategyVectorizePass(
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Expand Up @@ -323,6 +323,15 @@ static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
.Default([&](Operation *op) { return op->getResults(); });
}

/// Peel and canonicalize 'loops'.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
ArrayRef<scf::ForOp> loops) {
for (auto loopOp : loops) {
SmallVector<Value, 4> loopResults;
loopResults = peelLoop(rewriter, loopOp);
}
}

/// Peel loops after tiling.
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
ArrayRef<int64_t> peeledLoops,
Expand Down Expand Up @@ -716,6 +725,35 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
return success();
}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
MLIRContext *context, LinalgTransformationFilter f,
LinalgPeelOptions options, PatternBenefit benefit)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
filter(std::move(f)), options(std::move(options)) {}

mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
StringRef opName, MLIRContext *context, LinalgPeelOptions options,
LinalgTransformationFilter f, PatternBenefit benefit)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
filter(f.addOpNameFilter(opName)), options(std::move(options)) {}

LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
LinalgOp linalgOp, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();

// Increase marker counter even if peeling doesn't happen for this op.
filter.replaceLinalgTransformationFilter(rewriter, linalgOp);

if (!options.loopsToPeelComputationFunction)
return failure();

SmallVector<scf::ForOp, 4> loopsToPeel;
options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel);
peelLoops(rewriter, loopsToPeel);
return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
MLIRContext *context, LinalgTransformationFilter f,
LinalgVectorizationOptions options, PatternBenefit benefit)
Expand Down

0 comments on commit 9a79b1b

Please sign in to comment.