diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index c89fc59c91830..d00183a1e16a1 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -653,6 +653,9 @@ struct PadTilingInterfaceResult { // interpreted as the bounding box (dynamic) value to pad to. /// * Use "options.paddingValues" to set the padding value of the created // tensor::PadOp. +// +// The transformation assumes that the insertion point is set after the +// operation to pad. FailureOr rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad, PadTilingInterfaceOptions options, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 794dda96d1dfa..8b89244486339 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2464,6 +2464,8 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter, .setPaddingSizes(getMixedPaddingSizes()) .setPadToMultipleOf(getPadToMultipleOf()); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(targetOp); auto maybePadOps = rewriteAsPaddedOp( rewriter, cast(targetOp.getOperation()), options); if (failed(maybePadOps)) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 3e787a2ad0ef5..52ab92f180575 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -288,10 +288,6 @@ FailureOr linalg::rewriteAsPaddedOp( return failure(); } - OpBuilder::InsertionGuard g(builder); - // Set IP after toPad because we also take the dims of toPad's output. - builder.setInsertionPointAfter(toPad); - // 1. Get the loopUpperBounds from the TilingInterface. SmallVector iterationDomain = toPad.getIterationDomain(builder);