Skip to content

Commit

Permalink
[mlir][Linalg] Drop filter-based splitReduction
Browse files Browse the repository at this point in the history
This transformation is available and tested via the transform dialect.

Differential Revision: https://reviews.llvm.org/D135767
  • Loading branch information
nicolasvasilache committed Oct 12, 2022
1 parent bbe4441 commit e0cea16
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 271 deletions.
9 changes: 0 additions & 9 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,6 @@ using ControlSplitReductionFn =
void populateSplitReductionPattern(
RewritePatternSet &patterns,
const ControlSplitReductionFn &controlSplitReductionFn,
const LinalgTransformationFilter &f = LinalgTransformationFilter(),
bool useAlloc = false);

/// Apply transformation to split the single linalg op reduction into a parallel
Expand Down Expand Up @@ -1094,14 +1093,6 @@ void populateSplitReductionPattern(
/// linalg.yield %5 : f32
/// } -> tensor<f32>
/// ```
FailureOr<LinalgOp>
splitReduction(PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn,
const LinalgTransformationFilter &f, bool useAlloc = false);

/// Filterless version of the above.
/// Returns both the new linalg ops as well as the fillOp needed to initialize
/// the temporary expanded tensor with the proper neutral element.
struct SplitReductionResult {
Operation *initOrAlloc;
FillOp fillOp;
Expand Down
34 changes: 5 additions & 29 deletions mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,6 @@ static Attribute getNeutralElement(Operation *op) {
return Attribute();
}

FailureOr<LinalgOp> mlir::linalg::splitReduction(
PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn,
const LinalgTransformationFilter &filter, bool useAlloc) {
if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
!op.hasOnlyProjectedPermutations())
return b.notifyMatchFailure(op, "precondition not met");

FailureOr<SplitReductionResult> res =
splitReduction(b, op, controlSplitReductionFn, useAlloc);
if (failed(res))
return failure();

filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);

return res->splitLinalgOp;
}

FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
Expand Down Expand Up @@ -481,30 +461,26 @@ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgSplitReduction(MLIRContext *context,
ControlSplitReductionFn controlSplitReductionFn,
LinalgTransformationFilter f, bool useAlloc = false,
PatternBenefit benefit = 1)
bool useAlloc = false, PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlSplitReductionFn(std::move(controlSplitReductionFn)),
useAlloc(useAlloc), filter(std::move(f)) {}
useAlloc(useAlloc) {}

LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
return splitReduction(rewriter, op, controlSplitReductionFn, filter,
useAlloc);
return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
}

private:
ControlSplitReductionFn controlSplitReductionFn;
bool useAlloc;
LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateSplitReductionPattern(
RewritePatternSet &patterns,
const ControlSplitReductionFn &controlSplitReductionFn,
const LinalgTransformationFilter &f, bool useAlloc) {
const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
patterns.add<LinalgSplitReduction>(patterns.getContext(),
controlSplitReductionFn, f, useAlloc);
controlSplitReductionFn, useAlloc);
}
193 changes: 0 additions & 193 deletions mlir/test/Dialect/Linalg/split_reduction.mlir

This file was deleted.

40 changes: 0 additions & 40 deletions mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,6 @@ struct TestLinalgTransforms
llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
"tensor.pad(subtensor)"),
llvm::cl::init(false)};
Option<bool> testSplitReduction{
*this, "test-split-reduction",
llvm::cl::desc("Test split reduction transformation"),
llvm::cl::init(false)};
Option<bool> testSplitReductionInnerParallel{
*this, "test-split-reduction-inner-parallel",
llvm::cl::desc("Test split reduction with inner parallel transformation"),
llvm::cl::init(false)};
ListOption<int64_t> peeledLoops{
*this, "peeled-loops",
llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
Expand Down Expand Up @@ -176,34 +168,6 @@ static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applySplitReduction(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
linalg::populateSplitReductionPattern(
patterns,
[](LinalgOp op) {
unsigned insertDimIndex = op.getNumLoops() - 1;
return SplitReductionOptions{4, insertDimIndex, false};
},
LinalgTransformationFilter(
ArrayRef<StringAttr>{},
StringAttr::get(funcOp.getContext(), "SPLIT")));
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applySplitReductionInnerParallel(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
linalg::populateSplitReductionPattern(
patterns,
[](LinalgOp op) {
unsigned insertDimIndex = op.getNumLoops() - 1;
return SplitReductionOptions{4, insertDimIndex, true};
},
LinalgTransformationFilter(
ArrayRef<StringAttr>{},
StringAttr::get(funcOp.getContext(), "SPLIT")));
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
populateBubbleUpExtractSliceOpPatterns(patterns);
Expand Down Expand Up @@ -237,10 +201,6 @@ void TestLinalgTransforms::runOnOperation() {
return applyGeneralizePadTensorPatterns(getOperation());
if (testSwapSubTensorPadTensor)
return applyExtractSliceOfPadTensorSwapPattern(getOperation());
if (testSplitReduction)
return applySplitReduction(getOperation());
if (testSplitReductionInnerParallel)
return applySplitReductionInnerParallel(getOperation());
if (testBubbleUpExtractSliceOpPattern)
return applyBubbleUpExtractSliceOpPattern(getOperation());
if (testSwapExtractSliceWithFill)
Expand Down

0 comments on commit e0cea16

Please sign in to comment.