Skip to content

Commit a8f69be

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Expose flag to control nofold attribute when padding.
Setting the nofold attribute enables packing an operand. At the moment, the attribute is set by default. The pack introduces a callback to control the flag. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D111718
1 parent 0b48b01 commit a8f69be

File tree

4 files changed

+44
-10
lines changed

4 files changed

+44
-10
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,10 @@ using TileSizeComputationFunction =
452452
using PaddingValueComputationFunction =
453453
std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>;
454454

455+
/// Callback returning true if the pad tensor operation defining the given
456+
/// OpOperand shall be marked as nofold to enable packing.
457+
using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>;
458+
455459
struct LinalgTilingOptions {
456460
/// Computation function that returns the tile sizes for each operation.
457461
/// Delayed construction of constant tile sizes should occur to interoperate
@@ -526,6 +530,18 @@ struct LinalgTilingOptions {
526530
return *this;
527531
}
528532

533+
/// Callback returning true if the pad tensor operation defining the given
534+
/// OpOperand shall be marked as nofold to enable packing. A padding operation
535+
/// is only marked nofold if `paddingNoFoldComputationFunction` is set and
536+
/// returns true. Otherwise, the nofold attribute is set to false.
537+
PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;
538+
539+
LinalgTilingOptions &
540+
setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) {
541+
paddingNoFoldComputationFunction = std::move(fun);
542+
return *this;
543+
}
544+
529545
/// Peel the specified loops.
530546
SmallVector<int64_t> peeledLoops;
531547

@@ -999,6 +1015,7 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
9991015
LogicalResult
10001016
rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
10011017
const PaddingValueComputationFunction &paddingFunc,
1018+
const PaddingNoFoldComputationFunction &nofoldFunc,
10021019
LinalgOp &paddedOp);
10031020

10041021
using OptimizeCopyFn =

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
153153
/// padded to a static shape.
154154
static LogicalResult padOperandToSmallestStaticBoundingBox(
155155
PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
156-
const PaddingValueComputationFunction &paddingFunc, Value &result) {
156+
const PaddingValueComputationFunction &paddingFunc,
157+
const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
157158
// Can't pad scalars.
158159
if (opToPad.getShape(opOperand).empty())
159160
return success();
@@ -181,15 +182,17 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
181182
}
182183
auto staticTensorType = RankedTensorType::get(
183184
staticSizes, getElementTypeOrSelf(opOperand->get()));
185+
bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
184186
result = linalg::PadTensorOp::createPadHighOp(
185187
staticTensorType, opOperand->get(), paddingValue.getValue(),
186-
/*nofold=*/true, opToPad->getLoc(), rewriter);
188+
/*nofold=*/nofold, opToPad->getLoc(), rewriter);
187189
return success();
188190
}
189191

190192
LogicalResult
191193
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
192194
const PaddingValueComputationFunction &paddingFunc,
195+
const PaddingNoFoldComputationFunction &nofoldFunc,
193196
LinalgOp &paddedOp) {
194197
Location loc = opToPad->getLoc();
195198

@@ -208,7 +211,8 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
208211
// If padding was requested but the shape cannot be bounded statically then
209212
// the pattern fails to apply.
210213
if (failed(padOperandToSmallestStaticBoundingBox(
211-
rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
214+
rewriter, opToPad, opOperand, paddingFunc, nofoldFunc,
215+
paddedOperand)))
212216
return failure();
213217
newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
214218
}
@@ -341,9 +345,9 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
341345
// Try to pad on the fly by rewriting res->op as a padded op. If successful,
342346
// `res.op` is rewritten in static form with padded operands.
343347
LinalgOp paddedOp;
344-
if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
345-
options.paddingValueComputationFunction,
346-
paddedOp))) {
348+
if (succeeded(rewriteAsPaddedOp(
349+
rewriter, res->op, options.paddingValueComputationFunction,
350+
options.paddingNoFoldComputationFunction, paddedOp))) {
347351
filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
348352
res->op = paddedOp;
349353
result = *res;

mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 tile-sizes=2,3,4" -canonicalize | FileCheck %s
2-
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
1+
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 nofold-operands=0,1 tile-sizes=2,3,4" -canonicalize | FileCheck %s
2+
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 nofold-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
33

44
// CHECK-LABEL: func @matmul_tensors(
55
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
@@ -24,7 +24,7 @@ func @matmul_tensors(
2424
// CHECK: : tensor<?x?xi8> to tensor<2x4xi8>
2525
// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
2626
// CHECK: : tensor<?x?xi8> to tensor<4x3xi8>
27-
// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
27+
// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
2828
// CHECK: : tensor<?x?xi32> to tensor<2x3xi32>
2929
// CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
3030
// CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ struct TestLinalgTransforms
113113
*this, "padded-operands",
114114
llvm::cl::desc("Operands to pad when test-tile-pattern"),
115115
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
116+
ListOption<int64_t> nofoldOperands{
117+
*this, "nofold-operands",
118+
llvm::cl::desc("Operands to set nofold when test-tile-pattern"),
119+
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
116120
ListOption<int64_t> peeledLoops{
117121
*this, "peeled-loops",
118122
llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -581,6 +585,7 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
581585
static void applyTilePattern(FuncOp funcOp, std::string loopType,
582586
ArrayRef<int64_t> tileSizes,
583587
ArrayRef<int64_t> paddedOperands,
588+
ArrayRef<int64_t> nofoldOperands,
584589
ArrayRef<int64_t> peeledLoops,
585590
bool scalarizeDynamicDims) {
586591
MLIRContext *context = funcOp.getContext();
@@ -608,7 +613,13 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType,
608613
return failure();
609614
return getNeutralOfLinalgOp(b, opOperand);
610615
};
616+
auto nofoldFunc = [&](OpOperand &opOperand) {
617+
if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0)
618+
return true;
619+
return false;
620+
};
611621
linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
622+
linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc);
612623
}
613624
tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
614625
linalg::LinalgTilingPattern<linalg::GenericOp>>(
@@ -743,9 +754,11 @@ void TestLinalgTransforms::runOnFunction() {
743754
skipPartial);
744755
if (testTilePattern)
745756
return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
746-
peeledLoops, /*scalarizeDynamicDims=*/false);
757+
nofoldOperands, peeledLoops,
758+
/*scalarizeDynamicDims=*/false);
747759
if (testTileScalarizeDynamicDims)
748760
return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
761+
nofoldOperands,
749762
/*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
750763
if (testHoistPadding) {
751764
getFunction().walk([&](linalg::PadTensorOp padTensorOp) {

0 commit comments

Comments
 (0)