[Global opt] add flag to generalize batch matmul ops #17877

Merged (3 commits, Jul 12, 2024)
Changes from 1 commit
@@ -14,6 +14,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
@@ -24,6 +25,10 @@ namespace {
 struct GeneralizeLinalgNamedOpsPass
     : public GeneralizeLinalgNamedOpsBase<GeneralizeLinalgNamedOpsPass> {
 
+  GeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) {
+    this->generalizeLinalgMatmulOps = generalizeLinalgMatmulOps;
+  }
+
   void runOnOperation() override;
 };
 } // namespace
@@ -45,6 +50,11 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() {
             linalgOp.getOperation())) {
       namedOpCandidates.push_back(linalgOp);
     }
+    if (generalizeLinalgMatmulOps &&
+        isa_and_nonnull<linalg::BatchMatmulOp, linalg::BatchMatmulTransposeBOp>(
+            linalgOp)) {
+      namedOpCandidates.push_back(linalgOp);
+    }
   });
 
   IRRewriter rewriter(&getContext());
@@ -60,8 +70,9 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() {
 }
 
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createGeneralizeLinalgNamedOpsPass() {
-  return std::make_unique<GeneralizeLinalgNamedOpsPass>();
+createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) {
+  return std::make_unique<GeneralizeLinalgNamedOpsPass>(
+      generalizeLinalgMatmulOps);
 }
 
 } // namespace mlir::iree_compiler::GlobalOptimization
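
For context, generalizing a named op rewrites it into a linalg.generic with explicit indexing maps, iterator types, and a scalar body. A rough sketch of what linalg.batch_matmul turns into after this pass (illustrative only; SSA names and the exact printed form will differ):

  #mapA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
  #mapB = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
  #mapC = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
  %result = linalg.generic
      {indexing_maps = [#mapA, #mapB, #mapC],
       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs : tensor<1x128x128xf32>, tensor<1x128x128xf32>)
      outs(%init : tensor<1x128x128xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %prod = arith.mulf %a, %b : f32
    %acc = arith.addf %c, %prod : f32
    linalg.yield %acc : f32
  } -> tensor<1x128x128xf32>

Here d0 is the batch dimension, d1 and d2 the parallel M and N dimensions, and d3 the reduced K dimension; linalg.batch_matmul_transpose_b generalizes the same way with the B map swapped to (d0, d2, d3).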
12 changes: 10 additions & 2 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -23,6 +23,10 @@ static llvm::cl::opt<bool> clEnableQuantizedMatmulReassociation(
     llvm::cl::desc(
         "Enables reassociation of quantized matmul ops (experimental)."),
     llvm::cl::init(false));
+static llvm::cl::opt<bool>
+    clGeneralizeLinalgMatmulOps("enable-generalize-linalg-matmul-ops",

Reviewer (Contributor) commented on the flag name:

This needs to be codified somewhere, but there are some thoughts about adding -experimental to flag names that are not really intended to be user facing (i.e. papering over backend issues that are WIP), per this discussion: #17788 (comment). Maybe something like this, but open to suggestions.

Suggested change:
-clGeneralizeLinalgMatmulOps("enable-generalize-linalg-matmul-ops",
+clGeneralizeLinalgMatmulOps("iree-global-opt-generalize-linalg-matmul-ops-experimental",
IanWood1 (author) replied:

That's a good idea, added that to the name.

llvm::cl::desc("Generalize linalg MatMul ops"),
llvm::cl::init(false));
static llvm::cl::opt<bool> clEnableFuseSiluHorizontalMatmul(
"iree-global-opt-enable-fuse-silu-horizontal-matmul",
llvm::cl::desc(
@@ -122,7 +126,9 @@ void buildGlobalOptimizationPassPipeline(
       // dims as the unit dim folding pass updates indexing maps and is better
       // at working with generics. By this point we have already done any
       // specialized raising and the op names are no longer useful.
-      .addPass(createGeneralizeLinalgNamedOpsPass);
+      .addPass([&]() {
+        return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps);
+      });
 
   mainPassManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass());
   FunctionLikeNest(mainPassManager)
@@ -178,7 +184,9 @@
   }
   // Generalize transposes and any other remaining named linalg ops that can
   // now be represented as generics.
-  FunctionLikeNest(mainPassManager).addPass(createGeneralizeLinalgNamedOpsPass);
+  FunctionLikeNest(mainPassManager).addPass([&]() {
+    return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps);
+  });
 
   // Hoist loop invariants (e.g. from scf loops) with zero-trip-check.
   FunctionLikeNest(mainPassManager)
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/GlobalOptimization/Passes.h
@@ -86,7 +86,7 @@ createFuseSiluHorizontalMatmulPass();
 /// Generalizes some named Linalg ops into `linalg.generic` operations since the
 /// compiler can handle that better.
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createGeneralizeLinalgNamedOpsPass();
+createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps = false);
 
 /// Infers and inserts util.numeric.optional_narrow ops at points that may be
 /// beneficial.
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -105,6 +105,10 @@ def GeneralizeLinalgNamedOps :
InterfacePass<"iree-global-opt-generalize-linalg-named-ops", "mlir::FunctionOpInterface"> {
let summary = "Convert some Linalg named ops into linalg.generics.";
let constructor = "mlir::iree_compiler::GlobalOptimization::createGeneralizeLinalgNamedOpsPass()";
let options = [
Option<"generalizeLinalgMatmulOps", "enable-generalize-linalg-matmul-ops", "bool",
/*default=*/"false", "Generalize linalg batch MatMul ops">,
];
}

def InferNumericNarrowing :
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops{enable-generalize-linalg-matmul-ops=true}))" --split-input-file %s | FileCheck %s
 
 util.func public @generalize_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
@@ -34,3 +34,16 @@ util.func public @no_generalize_op_within_dispatch(%arg0 : tensor<?x?xf32>, %arg
 // CHECK:         %[[ADD:.+]] = linalg.add
 // CHECK:         flow.return %[[ADD]]
 // CHECK:       util.return %[[DISPATCH]]
+
+// -----
+
+util.func public @generalize_matmul(%arg0: tensor<1x128x128xf32>, %arg1: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> {
+  %0 = tensor.empty() : tensor<1x128x128xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x128x128xf32>, tensor<1x128x128xf32>) outs(%0 : tensor<1x128x128xf32>) -> tensor<1x128x128xf32>
+  util.return %1 : tensor<1x128x128xf32>
+}
+
+// CHECK-LABEL: util.func public @generalize_matmul
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<1x128x128xf32>, %[[ARG1:.+]]: tensor<1x128x128xf32>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:      %[[ARG0]], %[[ARG1]]
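
A stricter version of this test could also pin down the generalized body, e.g. with additional FileCheck lines along these lines (hypothetical, not part of this change):

  // CHECK:         arith.mulf
  // CHECK:         arith.addf
  // CHECK:         linalg.yield

These would verify that the resulting generic actually multiplies the two operands and accumulates into the init tensor, not just that some linalg.generic consumes the right arguments.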