diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index cbf0304d8c585..8e7ea21fb8f81 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -46,6 +46,9 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer); //===----------------------------------------------------------------------===// using LinalgLoops = SmallVector<Operation *, 4>; +void populatePadTensorTilingPatterns(RewritePatternSet &patterns, + const LinalgTilingOptions &options); + /// Populate patterns for vectorizing low-D convolution ops. This is a step in /// progressive lowering for convolution ops, it assume high-D convolution ops /// were decomposed previously. diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp index 859f3f8521b13..2a052c037380c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -100,6 +100,8 @@ struct LinalgStrategyTilePass filter); else tilingPattern.add<LinalgTilingPattern>(ctx, options, filter); + if (anchorOpName == linalg::PadTensorOp::getOperationName()) + populatePadTensorTilingPatterns(tilingPattern, options); (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 89ca83375c0f6..36bd434e845da 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -354,7 +354,9 @@ static LogicalResult tilePadTensorOp(RewriterBase &builder, PadTensorOp op, int64_t rank = op.getResultType().getRank(); SmallVector<Value, 4> tileSizes = options.tileSizeComputationFunction(builder, op); - assert(static_cast<int64_t>(tileSizes.size()) == rank); + // Normalize untiled padding dimensions to 0.
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); + tileSizes.append(rank - tileSizes.size(), zero); // Compute lower and upper bounds of the loop nest. SmallVector<Range> ranges = op.getIterationDomain(builder); SmallVector<Value> lbs, dims, allDims, steps; @@ -490,6 +492,12 @@ static void insertTilingPatterns(RewritePatternSet &patterns, patterns.add<PadTensorOpTilingPattern>(ctx, options); } +void mlir::linalg::populatePadTensorTilingPatterns( + RewritePatternSet &patterns, const LinalgTilingOptions &options) { + auto *ctx = patterns.getContext(); + patterns.add<PadTensorOpTilingPattern>(ctx, options); + } + static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { MLIRContext *ctx = funcOp.getContext(); RewritePatternSet patterns(ctx); diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir index 46fe369b66026..a83793544078c 100644 --- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir +++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir @@ -2,6 +2,8 @@ // RUN: FileCheck %s -check-prefix=TILE2 // RUN: mlir-opt %s -linalg-tile="tile-sizes=0,3" -resolve-shaped-type-result-dims -cse -split-input-file | \ // RUN: FileCheck %s -check-prefix=TILE1 +// This test only checks that tiling does not crash. +// RUN: mlir-opt %s -linalg-tile="tile-sizes=2" -resolve-shaped-type-result-dims -cse -split-input-file // TILE2-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)> // TILE2-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>