[mlir][linalg] Adapt the decompose patterns to use a filter (NFC).
This revision updates the convolution decomposition patterns to take a linalg transformation filter. The transformation filter makes it possible, in a later revision, to use the patterns from CodegenStrategy.

Depends On D114690

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114797
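
For reference, a minimal sketch of how the updated entry point can be driven; this is not part of the commit, the helper name is invented, and the body simply mirrors what LinalgStrategyDecomposePass does below:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect the decomposition patterns and run them greedily on `funcOp`.
// A default-constructed filter matches every op, preserving the old behavior.
static LogicalResult
decomposeConvolutions(FuncOp funcOp,
                      linalg::LinalgTransformationFilter filter) {
  RewritePatternSet patterns(funcOp.getContext());
  linalg::populateDecomposeConvolutionPatterns(patterns, filter);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}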
gysit committed Nov 30, 2021
1 parent 1e82864 commit 98dbcff
Showing 5 changed files with 67 additions and 36 deletions.
7 changes: 4 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -112,9 +112,10 @@ createLinalgStrategyGeneralizePass(StringRef opName = "",
                                    linalg::LinalgTransformationFilter());
 
 /// Create a LinalgStrategyDecomposePass.
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
-std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyDecomposePass();
+// TODO: if/when we need finer control add an `opName` parameter.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyDecomposePass(linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter());
 
 /// Create a LinalgStrategyInterchangePass.
 std::unique_ptr<OperationPass<FuncOp>>
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -286,8 +286,7 @@ def LinalgStrategyGeneralizePass
   ];
 }
 
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
+// TODO: if/when we need finer control add an anchorOp option.
 def LinalgStrategyDecomposePass
     : FunctionPass<"linalg-strategy-decompose-pass"> {
   let summary = "Configurable pass to apply pattern-based generalization.";
16 changes: 10 additions & 6 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -49,12 +49,6 @@ void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
-/// Populates patterns to decompose high-D convolution ops into low-D ones. This
-/// is a step in progressive lowering for convolution ops, afterwards we can
-/// vectorize the low-D convolution ops.
-void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
-
 /// Populates patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
@@ -1178,6 +1172,16 @@ void populateLinalgNamedOpsGeneralizationPatterns(
     RewritePatternSet &patterns,
     LinalgTransformationFilter filter = LinalgTransformationFilter());
 
+/// Linalg decompose convolutions patterns
+
+/// Populates patterns to decompose high-D convolution ops into low-D ones. This
+/// is a step in progressive lowering for convolution ops, afterwards we can
+/// vectorize the low-D convolution ops.
+void populateDecomposeConvolutionPatterns(
+    RewritePatternSet &patterns,
+    LinalgTransformationFilter filter = LinalgTransformationFilter(),
+    PatternBenefit benefit = 1);
+
 /// Linalg distribution patterns
 //
 /// Populates `patterns` with patterns to distribute linalg.tiled_loop.
15 changes: 10 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -191,16 +191,21 @@ struct LinalgStrategyDecomposePass
 
   LinalgStrategyDecomposePass() = default;
 
+  LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
+      : filter(filter) {}
+
   void runOnFunction() override {
     auto funcOp = getFunction();
     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
       return;
     RewritePatternSet decompositionPattern(funcOp.getContext());
-    populateDecomposeConvolutionPatterns(decompositionPattern);
+    populateDecomposeConvolutionPatterns(decompositionPattern, filter);
     if (failed(applyPatternsAndFoldGreedily(funcOp,
                                             std::move(decompositionPattern))))
       signalPassFailure();
   }
+
+  LinalgTransformationFilter filter;
 };
 
 /// Configurable pass to apply pattern-based linalg generalization.
@@ -478,12 +483,12 @@ mlir::createLinalgStrategyGeneralizePass(StringRef opName,
                                          LinalgTransformationFilter filter) {
   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
 }
 
 /// Create a LinalgStrategyDecomposePass.
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
+// TODO: if/when we need finer control add an `opName` parameter.
 std::unique_ptr<OperationPass<FuncOp>>
-mlir::createLinalgStrategyDecomposePass() {
-  return std::make_unique<LinalgStrategyDecomposePass>();
+mlir::createLinalgStrategyDecomposePass(LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyDecomposePass>(filter);
 }
+
 /// Create a LinalgStrategyInterchangePass.
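At the pass level, the new factory can be wired into a pipeline roughly as follows; a sketch only, assuming a pre-built filter and a pass manager nesting on FuncOp:

#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Schedule the decompose pass on every function; the filter decides which
// convolutions the pass may rewrite. Passing no filter keeps the match-all
// behavior of the previous API.
static void addDecomposeStep(PassManager &pm,
                             linalg::LinalgTransformationFilter filter) {
  pm.addNestedPass<FuncOp>(createLinalgStrategyDecomposePass(filter));
}
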
62 changes: 42 additions & 20 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -929,31 +929,36 @@ namespace {
 /// convolution ops.
 struct DownscaleSizeOneWindowed2DConvolution final
     : public OpRewritePattern<Conv2DNhwcHwcfOp> {
-  using OpRewritePattern::OpRewritePattern;
+  DownscaleSizeOneWindowed2DConvolution(
+      MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit), filter(filter) {}
 
   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
-    if (linalgOp.hasBufferSemantics())
+    if (failed(filter.checkAndNotify(rewriter, convOp)))
+      return failure();
+    if (convOp.hasBufferSemantics())
       return failure(); // To be implemented
 
     Value input = convOp.inputs().front();
-    Value filter = convOp.inputs().back();
+    Value kernel = convOp.inputs().back();
     Value output = convOp.outputs().front();
 
     auto inputType = input.getType().dyn_cast<RankedTensorType>();
-    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
+    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
     auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
-    auto filterShape = filterType.getShape();
+    auto kernelShape = kernelType.getShape();
     auto outputShape = outputType.getShape();
 
     // Only handle the case where at least one of the window dimensions is
     // of size 1. Other cases can rely on tiling to reduce to such cases.
-    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
+    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
     int64_t ohSize = outputShape[1], owSize = outputShape[2];
-    bool removeH = (fhSize == 1 && ohSize == 1);
-    bool removeW = (fwSize == 1 && owSize == 1);
+    bool removeH = (khSize == 1 && ohSize == 1);
+    bool removeW = (kwSize == 1 && owSize == 1);
     if (!removeH && !removeW)
       return failure();
 
@@ -962,17 +967,17 @@ struct DownscaleSizeOneWindowed2DConvolution final
     using RTTBuilder = RankedTensorType::Builder;
     RankedTensorType newInputType =
         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    RankedTensorType newFilterType =
-        RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
+    RankedTensorType newKernelType =
+        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
     RankedTensorType newOutputType =
         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
 
     // Rank-reduce operands.
     Location loc = convOp.getLoc();
     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
         rewriter, loc, input, newInputType);
-    Value newFilter = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, filter, newFilterType);
+    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, kernel, newKernelType);
     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
         rewriter, loc, output, newOutputType);
 
@@ -988,28 +993,39 @@ struct DownscaleSizeOneWindowed2DConvolution final
     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
-        loc, newOutputType, ValueRange{newInput, newFilter},
+        loc, newOutputType, ValueRange{newInput, newKernel},
         ValueRange{newOutput}, stridesAttr, dilationsAttr);
 
     // Insert back.
     Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
         rewriter, loc, conv1DOp.getResult(0), output);
     rewriter.replaceOp(convOp, inserted);
 
+    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
     return success();
   };
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
 };
 
 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
 /// dimensions into 1-D depthwise convolution ops.
 struct DownscaleDepthwiseConv2DNhwcHwcOp final
     : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
-  using OpRewritePattern::OpRewritePattern;
+  DownscaleDepthwiseConv2DNhwcHwcOp(
+      MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
+        filter(filter) {}
 
   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
-    if (linalgOp.hasBufferSemantics())
+    if (failed(filter.checkAndNotify(rewriter, convOp)))
+      return failure();
+    if (convOp.hasBufferSemantics())
       return failure(); // To be implemented
 
     Value input = convOp.inputs().front();
@@ -1071,15 +1087,21 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
         rewriter, loc, conv1DOp.getResult(0), output);
     rewriter.replaceOp(convOp, inserted);
 
+    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
     return success();
   };
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
 };
 
 } // namespace
 
-void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
-                                                  PatternBenefit benefit) {
+void linalg::populateDecomposeConvolutionPatterns(
+    RewritePatternSet &patterns, LinalgTransformationFilter filter,
+    PatternBenefit benefit) {
   patterns.add<DownscaleSizeOneWindowed2DConvolution,
-               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
+               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                   benefit);
 }
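
The checkAndNotify / replaceLinalgTransformationFilter pair above lets the decomposition be staged through marker attributes. A hedged sketch, assuming the filter constructor taking a match-attribute list and a replacement attribute; the marker names "decompose" and "vectorize" are invented for illustration:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

// Only ops whose `__internal_linalg_transform__` attribute equals "decompose"
// pass checkAndNotify; the rewritten 1-D convolutions are re-marked
// "vectorize" so a later stage can pick them up.
static linalg::LinalgTransformationFilter
makeDecomposeFilter(MLIRContext *ctx) {
  return linalg::LinalgTransformationFilter(
      /*matchDisjunction=*/StringAttr::get(ctx, "decompose"),
      /*replacement=*/StringAttr::get(ctx, "vectorize"));
}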
