diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index caec229207ea6..12e228bcaeefa 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -143,6 +143,28 @@ def ApplyFoldPackUnpackIntoEmptyPatternsOp : Op]> { + let description = [{ + Collection of patterns to bubble up or down data layout ops across other + operations. + }]; + + let arguments = (ins DefaultValuedAttr:$poison_padding); + let assemblyFormat = "attr-dict"; +} + +def ApplyExtractSliceSinkingPatternsOp : Op]> { + let description = [{ + Patterns to sink extract slice across other operations. + }]; + + let assemblyFormat = "attr-dict"; +} + //===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 309a4d989465d..ae65afac6b54c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -272,6 +272,26 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns( linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns); } +void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::ControlPropagationFn defaultControlFn = [](OpOperand *operand) { + return true; + }; + linalg::populateDataLayoutPropagationPatterns(patterns, defaultControlFn, + getPoisonPadding()); +} + +void transform::ApplyExtractSliceSinkingPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::ControlPropagationFn defaultControlFn = + [](OpOperand *opOperand) -> bool { + Operation *producer = opOperand->get().getDefiningOp(); + Operation *consumer = opOperand->getOwner(); + return consumer->getBlock() == producer->getBlock(); + }; + linalg::populateExtractSliceSinkingPatterns(patterns, defaultControlFn); +} + //===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index ec34a02096d5f..af6f70637d657 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1,4 +1,7 @@ -// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s +// RUN: mlir-opt %s -split-input-file \ +// RUN: -transform-preload-library='transform-library-paths=%p/td/propagate-data-layout.mlir' \ +// RUN: -transform-interpreter=entry-point=propagate_data_layout \ +// RUN: | FileCheck %s #map0 = affine_map<(d0, d1) -> (d0, d1)> func.func @dynamic_elem_pack(%arg0: tensor, %dest: tensor) -> tensor diff --git a/mlir/test/Dialect/Linalg/td/propagate-data-layout.mlir b/mlir/test/Dialect/Linalg/td/propagate-data-layout.mlir new file mode 100644 index 0000000000000..5167cbe0f1f2e --- /dev/null +++ b/mlir/test/Dialect/Linalg/td/propagate-data-layout.mlir @@ -0,0 +1,12 @@ +module @transforms attributes { transform.with_named_sequence } { + transform.named_sequence @propagate_data_layout(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %func { + transform.apply_patterns.linalg.data_layout_propagation {poison_padding = true} + transform.apply_patterns.linalg.extract_slice_sinking + } : !transform.any_op + + transform.yield + } +} diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt index eb6f581252181..6549c1fec10a7 100644 --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -1,6 +1,5 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRLinalgTestPasses - TestDataLayoutPropagation.cpp TestLinalgDecomposeOps.cpp TestLinalgDropUnitDims.cpp TestLinalgElementwiseFusion.cpp diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp deleted file mode 100644 index d45aaf788f9c2..0000000000000 --- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp +++ /dev/null @@ -1,57 +0,0 @@ -//===- TestDataLayoutPropagation.cpp --------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace { -struct TestDataLayoutPropagationPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass) - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - StringRef getArgument() const final { - return "test-linalg-data-layout-propagation"; - } - StringRef getDescription() const final { - return "Test data layout propagation"; - } - - void runOnOperation() override { - MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - linalg::populateDataLayoutPropagationPatterns( - patterns, [](OpOperand *opOperand) { return true; }, - /*poisonPaddingOk=*/true); - linalg::ControlPropagationFn controlExtract = - [](OpOperand *opOperand) -> bool { - Operation *producer = opOperand->get().getDefiningOp(); - Operation *consumer = opOperand->getOwner(); - return consumer->getBlock() == producer->getBlock(); - }; - linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) - return signalPassFailure(); - } -}; -} // namespace - -namespace mlir { -namespace test { -void registerTestDataLayoutPropagation() { - PassRegistration(); -} -} // namespace test -} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index a427132247e6d..7116441657c72 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -89,7 +89,6 @@ void registerTestComposeSubView(); void registerTestCompositePass(); void registerTestControlFlowSink(); void registerTestConvertToSPIRVPass(); -void registerTestDataLayoutPropagation(); void registerTestDataLayoutQuery(); void registerTestDeadCodeAnalysisPass(); void registerTestDecomposeCallGraphTypes(); @@ -239,7 +238,6 @@ static void registerTestPasses() { mlir::test::registerTestCompositePass(); mlir::test::registerTestControlFlowSink(); mlir::test::registerTestConvertToSPIRVPass(); - mlir::test::registerTestDataLayoutPropagation(); mlir::test::registerTestDataLayoutQuery(); mlir::test::registerTestDeadCodeAnalysisPass(); mlir::test::registerTestDecomposeCallGraphTypes();