Skip to content

Commit

Permalink
[mlir][linalg] NFC: minor cleanups after moving pad to tensor dialect
Browse files Browse the repository at this point in the history
Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D120627
  • Loading branch information
antiagainst committed Mar 3, 2022
1 parent 5aeaabf commit 7d249df
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
7 changes: 3 additions & 4 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Expand Up @@ -103,10 +103,9 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);

/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its
/// source, if the producer is a `linalg` operation with all parallel iterator
/// types.
void populateFusePadTensorWithProducerLinalgOpPatterns(
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
RewritePatternSet &patterns);

/// Patterns to convert from one named op to another. These can be seen as
Expand Down
22 changes: 12 additions & 10 deletions mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
@@ -1,13 +1,14 @@
//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
//===- PadOpInterchange.cpp - Interchange tensor.pad with linalg producer -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that interchange a generic op -> pad_tensor
// pattern into extract_slice -> generic_op.
// This file implements patterns that interchange a linalg.generic -> tensor.pad
// op chain into a tensor.extract_slice -> linalg.generic -> tensor.insert_slice
// op chain.
//
//===----------------------------------------------------------------------===//

Expand All @@ -17,15 +18,14 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// A sequence of operations
///
/// ```mlir
/// %0 = linalg. ...
/// %1 = linalg.pad_tensor %0 ...
/// %1 = tensor.pad %0 ...
/// ```
///
/// can be replaced with
Expand All @@ -40,6 +40,7 @@ namespace {
/// if the `linalg.generic` has all parallel iterator types.
struct FusePadOp : OpRewritePattern<tensor::PadOp> {
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
// Only works on padding op that sets the padded value to a constant.
Expand All @@ -50,7 +51,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// This pattern could work for any Linalg op. For now restrict it to generic
// ops.
Value source = padOp.source();
auto linalgOp = source.getDefiningOp<GenericOp>();
auto linalgOp = source.getDefiningOp<linalg::GenericOp>();
if (!linalgOp) {
return rewriter.notifyMatchFailure(
padOp, "expected source to be linalg.generic op");
Expand All @@ -75,14 +76,14 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// Create the tensor of same size as output of the pad op.
RankedTensorType padResultType = padOp.getResultType();
auto resultSizes = getAsOpFoldResult(resultShape[0]);
auto initTensor = rewriter.create<InitTensorOp>(
auto initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultSizes, padResultType.getElementType());

// Fill the tensor with the pad value.
// TODO: There is an option to fill only the boundaries. For now just
// filling the whole tensor.
auto fillTensor =
rewriter.create<FillOp>(loc, padValue, initTensor.getResult());
rewriter.create<linalg::FillOp>(loc, padValue, initTensor.getResult());

// Construct a slice of the fill result that is to be replaced with the
// result of the generic op. The low pad values are the offsets, the size of
Expand All @@ -107,7 +108,8 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
loc, fillTensor.getResult(0), offsets, sizes, strides);

// Clone the generic op.
auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
auto clonedOp =
cast<linalg::GenericOp>(rewriter.clone(*linalgOp.getOperation()));
clonedOp.setOutputOperand(resultNumber, slice.getResult());

// Insert it back into the result of the fill.
Expand All @@ -119,7 +121,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
};
} // namespace

void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
void mlir::linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(
RewritePatternSet &patterns) {
patterns.add<FusePadOp>(patterns.getContext());
}
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
Expand Up @@ -34,7 +34,7 @@ struct TestPadFusionPass
MLIRContext *context = &getContext();
FuncOp funcOp = getOperation();
RewritePatternSet patterns(context);
linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
return signalPassFailure();
Expand Down

0 comments on commit 7d249df

Please sign in to comment.