From 1e1d51eec53399ec5a02fc2406cfe238da1aabd6 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Thu, 11 Jul 2024 21:54:51 +0000 Subject: [PATCH 1/3] add flag to generalize batch matmul ops Signed-off-by: Ian Wood --- .../GeneralizeLinalgNamedOps.cpp | 15 +++++++++++++-- .../iree/compiler/GlobalOptimization/Passes.cpp | 12 ++++++++++-- .../src/iree/compiler/GlobalOptimization/Passes.h | 2 +- .../iree/compiler/GlobalOptimization/Passes.td | 4 ++++ .../test/generalize_named_ops.mlir | 15 ++++++++++++++- 5 files changed, 42 insertions(+), 6 deletions(-) diff --git a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp index 5f30902c7e2f..b74c435f0b95 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" @@ -24,6 +25,10 @@ namespace { struct GeneralizeLinalgNamedOpsPass : public GeneralizeLinalgNamedOpsBase { + GeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) { + this->generalizeLinalgMatmulOps = generalizeLinalgMatmulOps; + } + void runOnOperation() override; }; } // namespace @@ -45,6 +50,11 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() { linalgOp.getOperation())) { namedOpCandidates.push_back(linalgOp); } + if (generalizeLinalgMatmulOps && + isa_and_nonnull( + linalgOp)) { + namedOpCandidates.push_back(linalgOp); + } }); IRRewriter rewriter(&getContext()); @@ -60,8 +70,9 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() { } std::unique_ptr> -createGeneralizeLinalgNamedOpsPass() { - return std::make_unique(); +createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) { + return std::make_unique( + generalizeLinalgMatmulOps); } } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index bab60f67bb91..9bbfa7e7baf2 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -23,6 +23,10 @@ static llvm::cl::opt clEnableQuantizedMatmulReassociation( llvm::cl::desc( "Enables reassociation of quantized matmul ops (experimental)."), llvm::cl::init(false)); +static llvm::cl::opt + clGeneralizeLinalgMatmulOps("enable-generalize-linalg-matmul-ops", + llvm::cl::desc("Generalize linalg MatMul ops"), + llvm::cl::init(false)); static llvm::cl::opt clEnableFuseSiluHorizontalMatmul( "iree-global-opt-enable-fuse-silu-horizontal-matmul", llvm::cl::desc( @@ -122,7 +126,9 @@ void buildGlobalOptimizationPassPipeline( // dims as the unit dim folding pass updates indexing maps and is better // at working with generics. By this point we have already done any // specialized raising and the op names are no longer useful. - .addPass(createGeneralizeLinalgNamedOpsPass); + .addPass([&]() { + return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps); + }); mainPassManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass()); FunctionLikeNest(mainPassManager) @@ -178,7 +184,9 @@ void buildGlobalOptimizationPassPipeline( } // Generalize transposes and any other remaining named linalg ops that can // now be represented as generics. - FunctionLikeNest(mainPassManager).addPass(createGeneralizeLinalgNamedOpsPass); + FunctionLikeNest(mainPassManager).addPass([&]() { + return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps); + }); // Hoist loop invariants (e.g. from scf loops) with zero-trip-check. FunctionLikeNest(mainPassManager) diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h index 6ddbeaba1cb3..8156faa470b3 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h @@ -86,7 +86,7 @@ createFuseSiluHorizontalMatmulPass(); /// Generalizes some named Linalg ops into `linalg.generic` operations since the /// compiler can handle that better. std::unique_ptr> -createGeneralizeLinalgNamedOpsPass(); +createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps = false); /// Infers and inserts util.numeric.optional_narrow ops at points that may be /// beneficial. diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td index 0f3bcd336229..c80a0612154f 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td @@ -105,6 +105,10 @@ def GeneralizeLinalgNamedOps : InterfacePass<"iree-global-opt-generalize-linalg-named-ops", "mlir::FunctionOpInterface"> { let summary = "Convert some Linalg named ops into linalg.generics."; let constructor = "mlir::iree_compiler::GlobalOptimization::createGeneralizeLinalgNamedOpsPass()"; + let options = [ + Option<"generalizeLinalgMatmulOps", "enable-generalize-linalg-matmul-ops", "bool", + /*default=*/"false", "Generalize linalg batch MatMul ops">, + ]; } def InferNumericNarrowing : diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir index 5111152b7b0d..0c371df3ee58 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops{enable-generalize-linalg-matmul-ops=true}))" --split-input-file %s | FileCheck %s util.func public @generalize_op(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index @@ -34,3 +34,16 @@ util.func public @no_generalize_op_within_dispatch(%arg0 : tensor, %arg // CHECK: %[[ADD:.+]] = linalg.add // CHECK: flow.return %[[ADD]] // CHECK: util.return %[[DISPATCH]] + +// ----- + +util.func public @generalize_matmul(%arg0: tensor<1x128x128xf32>, %arg1: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { + %0 = tensor.empty() : tensor<1x128x128xf32> + %1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x128x128xf32>, tensor<1x128x128xf32>) outs(%0 : tensor<1x128x128xf32>) -> tensor<1x128x128xf32> + util.return %1 : tensor<1x128x128xf32> +} + +// CHECK-LABEL: util.func public @generalize_matmul +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x128x128xf32>, %[[ARG1:.+]]: tensor<1x128x128xf32> +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: %[[ARG0]], %[[ARG1]] From db0a3a2a63f4889f2bf62ad346bc18696ffe9d9e Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Thu, 11 Jul 2024 23:27:55 +0000 Subject: [PATCH 2/3] Revert "add flag to generalize batch matmul ops" This reverts commit 1e1d51eec53399ec5a02fc2406cfe238da1aabd6. Signed-off-by: Ian Wood --- .../GeneralizeLinalgNamedOps.cpp | 15 ++------------- .../iree/compiler/GlobalOptimization/Passes.cpp | 12 ++---------- .../src/iree/compiler/GlobalOptimization/Passes.h | 2 +- .../iree/compiler/GlobalOptimization/Passes.td | 4 ---- .../test/generalize_named_ops.mlir | 15 +-------------- 5 files changed, 6 insertions(+), 42 deletions(-) diff --git a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp index b74c435f0b95..5f30902c7e2f 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp @@ -14,7 +14,6 @@ #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/GlobalOptimization/PassDetail.h" #include "iree/compiler/GlobalOptimization/Passes.h" -#include "llvm/Support/Casting.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" @@ -25,10 +24,6 @@ namespace { struct GeneralizeLinalgNamedOpsPass : public GeneralizeLinalgNamedOpsBase { - GeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) { - this->generalizeLinalgMatmulOps = generalizeLinalgMatmulOps; - } - void runOnOperation() override; }; } // namespace @@ -50,11 +45,6 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() { linalgOp.getOperation())) { namedOpCandidates.push_back(linalgOp); } - if (generalizeLinalgMatmulOps && - isa_and_nonnull( - linalgOp)) { - namedOpCandidates.push_back(linalgOp); - } }); IRRewriter rewriter(&getContext()); @@ -70,9 +60,8 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() { } std::unique_ptr> -createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) { - return std::make_unique( - generalizeLinalgMatmulOps); +createGeneralizeLinalgNamedOpsPass() { + return std::make_unique(); } } // namespace mlir::iree_compiler::GlobalOptimization diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index 9bbfa7e7baf2..bab60f67bb91 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -23,10 +23,6 @@ static llvm::cl::opt clEnableQuantizedMatmulReassociation( llvm::cl::desc( "Enables reassociation of quantized matmul ops (experimental)."), llvm::cl::init(false)); -static llvm::cl::opt - clGeneralizeLinalgMatmulOps("enable-generalize-linalg-matmul-ops", - llvm::cl::desc("Generalize linalg MatMul ops"), - llvm::cl::init(false)); static llvm::cl::opt clEnableFuseSiluHorizontalMatmul( "iree-global-opt-enable-fuse-silu-horizontal-matmul", llvm::cl::desc( @@ -126,9 +122,7 @@ void buildGlobalOptimizationPassPipeline( // dims as the unit dim folding pass updates indexing maps and is better // at working with generics. By this point we have already done any // specialized raising and the op names are no longer useful. - .addPass([&]() { - return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps); - }); + .addPass(createGeneralizeLinalgNamedOpsPass); mainPassManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass()); FunctionLikeNest(mainPassManager) @@ -184,9 +178,7 @@ void buildGlobalOptimizationPassPipeline( } // Generalize transposes and any other remaining named linalg ops that can // now be represented as generics. - FunctionLikeNest(mainPassManager).addPass([&]() { - return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps); - }); + FunctionLikeNest(mainPassManager).addPass(createGeneralizeLinalgNamedOpsPass); // Hoist loop invariants (e.g. from scf loops) with zero-trip-check. FunctionLikeNest(mainPassManager) diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h index 8156faa470b3..6ddbeaba1cb3 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h @@ -86,7 +86,7 @@ createFuseSiluHorizontalMatmulPass(); /// Generalizes some named Linalg ops into `linalg.generic` operations since the /// compiler can handle that better. std::unique_ptr> -createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps = false); +createGeneralizeLinalgNamedOpsPass(); /// Infers and inserts util.numeric.optional_narrow ops at points that may be /// beneficial. diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td index c80a0612154f..0f3bcd336229 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td @@ -105,10 +105,6 @@ def GeneralizeLinalgNamedOps : InterfacePass<"iree-global-opt-generalize-linalg-named-ops", "mlir::FunctionOpInterface"> { let summary = "Convert some Linalg named ops into linalg.generics."; let constructor = "mlir::iree_compiler::GlobalOptimization::createGeneralizeLinalgNamedOpsPass()"; - let options = [ - Option<"generalizeLinalgMatmulOps", "enable-generalize-linalg-matmul-ops", "bool", - /*default=*/"false", "Generalize linalg batch MatMul ops">, - ]; } def InferNumericNarrowing : diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir index 0c371df3ee58..5111152b7b0d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops{enable-generalize-linalg-matmul-ops=true}))" --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s util.func public @generalize_op(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index @@ -34,16 +34,3 @@ util.func public @no_generalize_op_within_dispatch(%arg0 : tensor, %arg // CHECK: %[[ADD:.+]] = linalg.add // CHECK: flow.return %[[ADD]] // CHECK: util.return %[[DISPATCH]] - -// ----- - -util.func public @generalize_matmul(%arg0: tensor<1x128x128xf32>, %arg1: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { - %0 = tensor.empty() : tensor<1x128x128xf32> - %1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x128x128xf32>, tensor<1x128x128xf32>) outs(%0 : tensor<1x128x128xf32>) -> tensor<1x128x128xf32> - util.return %1 : tensor<1x128x128xf32> -} - -// CHECK-LABEL: util.func public @generalize_matmul -// CHECK-SAME: %[[ARG0:.+]]: tensor<1x128x128xf32>, %[[ARG1:.+]]: tensor<1x128x128xf32> -// CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK-SAME: %[[ARG0]], %[[ARG1]] From 367b7c0ec50f9cac2cd410a4d8da0653984d0d44 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Fri, 12 Jul 2024 00:02:11 +0000 Subject: [PATCH 3/3] move to preprocessing and change cli flag name Signed-off-by: Ian Wood --- .../compiler/Preprocessing/Common/BUILD.bazel | 1 + .../Preprocessing/Common/CMakeLists.txt | 1 + .../Common/GeneralizeLinalgMatMul.cpp | 54 +++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.td | 8 +++ .../Preprocessing/Common/test/BUILD.bazel | 1 + .../Preprocessing/Common/test/CMakeLists.txt | 1 + .../Common/test/generalize_linalg_matmul.mlir | 12 +++++ 7 files changed, 78 insertions(+) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/test/generalize_linalg_matmul.mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index e004a550f728..1692c78bf800 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -33,6 +33,7 @@ iree_compiler_cc_library( "ApplyPDLPatterns.cpp", "ConvertConv2DToImg2Col.cpp", "ConvertConvToChannelsLast.cpp", + "GeneralizeLinalgMatMul.cpp", "InterpreterPass.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index c9c127ccca57..4613d4bb404b 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -29,6 +29,7 @@ iree_cc_library( "ApplyPDLPatterns.cpp" "ConvertConv2DToImg2Col.cpp" "ConvertConvToChannelsLast.cpp" + "GeneralizeLinalgMatMul.cpp" "InterpreterPass.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp new file mode 100644 index 000000000000..a533339875e0 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp @@ -0,0 +1,54 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::Preprocessing { + +#define GEN_PASS_DEF_GENERALIZELINALGMATMULPASS +#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export + +namespace { + +struct GeneralizeLinalgMatMulPass + : public iree_compiler::Preprocessing::impl::GeneralizeLinalgMatMulPassBase< + GeneralizeLinalgMatMulPass> { + using iree_compiler::Preprocessing::impl::GeneralizeLinalgMatMulPassBase< + GeneralizeLinalgMatMulPass>::GeneralizeLinalgMatMulPassBase; + void runOnOperation() override { + auto funcOp = getOperation(); + SmallVector namedOpCandidates; + funcOp.walk([&](linalg::LinalgOp linalgOp) { + if (!IREE::Flow::isNonNullAndOutsideDispatch(linalgOp)) { + return; + } + if (isa_and_nonnull(linalgOp)) { + namedOpCandidates.push_back(linalgOp); + } + }); + + IRRewriter rewriter(&getContext()); + + for (auto linalgOp : namedOpCandidates) { + rewriter.setInsertionPoint(linalgOp); + FailureOr generalizedOp = + linalg::generalizeNamedOp(rewriter, linalgOp); + if (failed(generalizedOp)) { + linalgOp->emitOpError("failed to generalize operation"); + return signalPassFailure(); + } + } + } +}; +} // namespace +} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index e3316f09a653..e4921b81fe88 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -132,4 +132,12 @@ def TransposeMatmulPass : Pass<"iree-preprocessing-transpose-matmul-pass"> { ]; } +def GeneralizeLinalgMatMulPass : + InterfacePass<"iree-preprocessing-generalize-linalg-matmul-experimental", "mlir::FunctionOpInterface"> { + let summary = "Convert linalg matmul ops to linalg.generics."; + let dependentDialects = [ + "mlir::linalg::LinalgDialect", + ]; +} + #endif // IREE_PREPROCESSING_COMMON_PASSES diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel index 3a5324f80696..54ebb1176caa 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel @@ -18,6 +18,7 @@ iree_lit_test_suite( [ "conv2d_to_img2col.mlir", "conv_to_channels_last.mlir", + "generalize_linalg_matmul.mlir", "make_single_dispatch_for_function.mlir", "pad_linalg_ops.mlir", "pad_to_intrinsics_mfma.mlir", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt index a09c135dfe09..03c92b7423bc 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "conv2d_to_img2col.mlir" "conv_to_channels_last.mlir" + "generalize_linalg_matmul.mlir" "make_single_dispatch_for_function.mlir" "pad_linalg_ops.mlir" "pad_to_intrinsics_mfma.mlir" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/generalize_linalg_matmul.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/generalize_linalg_matmul.mlir new file mode 100644 index 000000000000..bb1949c17ef1 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/generalize_linalg_matmul.mlir @@ -0,0 +1,12 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" --verify-each --split-input-file %s | FileCheck %s + +util.func public @generalize_matmul(%arg0: tensor<1x128x128xf32>, %arg1: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { + %0 = tensor.empty() : tensor<1x128x128xf32> + %1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x128x128xf32>, tensor<1x128x128xf32>) outs(%0 : tensor<1x128x128xf32>) -> tensor<1x128x128xf32> + util.return %1 : tensor<1x128x128xf32> +} + +// CHECK-LABEL: util.func public @generalize_matmul +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x128x128xf32>, %[[ARG1:.+]]: tensor<1x128x128xf32> +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: %[[ARG0]], %[[ARG1]]