Skip to content

Commit

Permalink
[mlir][linalg] Retire Linalg's Vectorization Pattern
Browse files Browse the repository at this point in the history
This revision retires the LinalgCodegenStrategy vectorization pattern. Please see the context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785.
This revision improves the transform dialect's VectorizeOp in different ways below:
- Adds LinalgDialect as a dependent dialect. When `transform.structured.vectorize` vectorizes `tensor.pad`, it generates `linalg.init_tensor`. In this case, linalg dialect must be registered.
- Inserts CopyVectorizationPattern in order to vectorize `memref.copy`.
- Creates two attributes: `disable_multi_reduction_to_contract_patterns` and `disable_transfer_permutation_map_lowering_patterns`. They are limiting the power of vectorization and are currently intended for testing purposes.

It also removes some of the "CHECK: vector.transfer_write" in the vectorization.mlir test. They are redundant writes, at the end of the code there is a rewrite to the same place. Transform dialect no longer generates them.

Depends on D133684 that retires the LinalgCodegenStrategy vectorization pass.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133699
  • Loading branch information
grypp committed Sep 15, 2022
1 parent 51e0946 commit 5279e11
Show file tree
Hide file tree
Showing 6 changed files with 499 additions and 77 deletions.
Expand Up @@ -767,6 +767,10 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
Note that this transformation is invalidating the handles to any payload IR
operation that is contained inside the vectorization target.

`disable_multi_reduction_to_contract_patterns` and
`disable_transfer_permutation_map_lowering_patterns` limits the power of
vectorization. They are currently intended for testing purposes.

#### Return modes:

This operation produces `definiteFailure` if vectorization fails for any
Expand All @@ -776,7 +780,9 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
}];

let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding);
DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding,
DefaultValuedAttr<BoolAttr, "false">:$disable_multi_reduction_to_contract_patterns,
DefaultValuedAttr<BoolAttr, "false">:$disable_transfer_permutation_map_lowering_patterns);
let results = (outs PDL_Operation:$transformed);

let assemblyFormat = "$target attr-dict";
Expand Down
37 changes: 0 additions & 37 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Expand Up @@ -926,31 +926,6 @@ struct LinalgPeelingPattern : public OpInterfaceRewritePattern<LinalgOp> {
/// Empty for now, used for SFINAE purposes only.
struct LinalgVectorizationOptions {};

/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgVectorizationPattern(
MLIRContext *context,
LinalgTransformationFilter f = LinalgTransformationFilter(),
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
PatternBenefit benefit = 1);

/// Construct a pattern specifically applied to `opName`.
LinalgVectorizationPattern(
StringRef opName, MLIRContext *context,
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);

LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override;

private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
};

/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
Expand Down Expand Up @@ -1335,18 +1310,6 @@ class VectorizationPatterns<> {
const LinalgTransformationFilter &f) {}
};

template <typename OpTy, typename... OpTypes>
class VectorizationPatterns<OpTy, OpTypes...> {
public:
static void insert(RewritePatternSet &patterns,
const LinalgVectorizationOptions &options,
const LinalgTransformationFilter &f) {
patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
patterns.getContext(), options, f);
VectorizationPatterns<OpTypes...>::insert(patterns, options, f);
}
};

template <typename... OpTypes>
class TilingPatterns;

Expand Down
31 changes: 27 additions & 4 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Expand Up @@ -1166,6 +1166,22 @@ LogicalResult TileToForeachThreadOp::verify() {
// VectorizeOp
//===----------------------------------------------------------------------===//

namespace {
/// This is an helper only to call vectorize via a pattern inside of
/// VectorizeOp::applyToOne.
struct VectorizationPattern : public RewritePattern {
explicit VectorizationPattern(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
return vectorize(rewriter, linalgOp);
}
};
} // namespace

DiagnosedSilenceableFailure
transform::VectorizeOp::applyToOne(Operation *target,
SmallVectorImpl<Operation *> &results,
Expand All @@ -1178,15 +1194,22 @@ transform::VectorizeOp::applyToOne(Operation *target,

MLIRContext *ctx = getContext();
RewritePatternSet patterns(ctx);
patterns.add<LinalgVectorizationPattern>(ctx);
patterns.add<VectorizationPattern>(ctx);

if (!getDisableTransferPermutationMapLoweringPatterns())
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);

if (!getDisableMultiReductionToContractPatterns())
vector::populateVectorReductionToContractPatterns(patterns);

vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
vector::populateVectorReductionToContractPatterns(patterns);
patterns.add<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(ctx,
/*benefit=*/2);
vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);

patterns.add<CopyVectorizationPattern>(ctx);

if (getVectorizePadding())
linalg::populatePadOpVectorizationPatterns(patterns);

Expand All @@ -1212,7 +1235,7 @@ class LinalgTransformDialectExtension

void init() {
declareDependentDialect<pdl::PDLDialect>();

declareDependentDialect<LinalgDialect>();
declareGeneratedDialect<AffineDialect>();
declareGeneratedDialect<arith::ArithmeticDialect>();
declareGeneratedDialect<scf::SCFDialect>();
Expand Down
19 changes: 0 additions & 19 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Expand Up @@ -590,25 +590,6 @@ LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
return success();
}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
MLIRContext *context, LinalgTransformationFilter f,
LinalgVectorizationOptions options, PatternBenefit benefit)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
filter(std::move(f)) {}

mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
LinalgTransformationFilter f, PatternBenefit benefit)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
filter(f.addOpNameFilter(opName)) {}

LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
LinalgOp linalgOp, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
return vectorize(rewriter, linalgOp);
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
memref::CopyOp copyOp, PatternRewriter &rewriter) const {
return vectorizeCopy(rewriter, copyOp);
Expand Down

0 comments on commit 5279e11

Please sign in to comment.