Skip to content

Commit

Permalink
add flag to generalize batch matmul ops
Browse files Browse the repository at this point in the history
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 committed Jul 11, 2024
1 parent 534928d commit 1e1d51e
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -24,6 +25,10 @@ namespace {
/// Pass that rewrites supported Linalg named ops into `linalg.generic` ops.
/// When the flag is set, batch matmul named ops are also made candidates for
/// generalization (see runOnOperation in this file).
struct GeneralizeLinalgNamedOpsPass
: public GeneralizeLinalgNamedOpsBase<GeneralizeLinalgNamedOpsPass> {

// Wires the pipeline-level flag into the pass option of the same name
// declared in Passes.td, so both the C++ constructor and the textual
// pipeline spec can enable it.
GeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) {
this->generalizeLinalgMatmulOps = generalizeLinalgMatmulOps;
}

void runOnOperation() override;
};
} // namespace
Expand All @@ -45,6 +50,11 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() {
linalgOp.getOperation())) {
namedOpCandidates.push_back(linalgOp);
}
if (generalizeLinalgMatmulOps &&
isa_and_nonnull<linalg::BatchMatmulOp, linalg::BatchMatmulTransposeBOp>(
linalgOp)) {
namedOpCandidates.push_back(linalgOp);
}
});

IRRewriter rewriter(&getContext());
Expand All @@ -60,8 +70,9 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() {
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGeneralizeLinalgNamedOpsPass() {
return std::make_unique<GeneralizeLinalgNamedOpsPass>();
createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps) {
return std::make_unique<GeneralizeLinalgNamedOpsPass>(
generalizeLinalgMatmulOps);
}

} // namespace mlir::iree_compiler::GlobalOptimization
12 changes: 10 additions & 2 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ static llvm::cl::opt<bool> clEnableQuantizedMatmulReassociation(
llvm::cl::desc(
"Enables reassociation of quantized matmul ops (experimental)."),
llvm::cl::init(false));
// Pipeline-level escape hatch forwarded into GeneralizeLinalgNamedOpsPass.
// NOTE(review): sibling flags in this file carry an "iree-global-opt-" prefix
// (e.g. "iree-global-opt-enable-fuse-silu-horizontal-matmul"); consider
// renaming this one for consistency — confirm no users depend on the
// current un-prefixed name before changing it.
static llvm::cl::opt<bool>
clGeneralizeLinalgMatmulOps("enable-generalize-linalg-matmul-ops",
llvm::cl::desc("Generalize linalg MatMul ops"),
llvm::cl::init(false));
static llvm::cl::opt<bool> clEnableFuseSiluHorizontalMatmul(
"iree-global-opt-enable-fuse-silu-horizontal-matmul",
llvm::cl::desc(
Expand Down Expand Up @@ -122,7 +126,9 @@ void buildGlobalOptimizationPassPipeline(
// dims as the unit dim folding pass updates indexing maps and is better
// at working with generics. By this point we have already done any
// specialized raising and the op names are no longer useful.
.addPass(createGeneralizeLinalgNamedOpsPass);
.addPass([&]() {
return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps);
});

mainPassManager.addPass(IREE::Flow::createFoldUnitExtentDimsPass());
FunctionLikeNest(mainPassManager)
Expand Down Expand Up @@ -178,7 +184,9 @@ void buildGlobalOptimizationPassPipeline(
}
// Generalize transposes and any other remaining named linalg ops that can
// now be represented as generics.
FunctionLikeNest(mainPassManager).addPass(createGeneralizeLinalgNamedOpsPass);
FunctionLikeNest(mainPassManager).addPass([&]() {
return createGeneralizeLinalgNamedOpsPass(clGeneralizeLinalgMatmulOps);
});

// Hoist loop invariants (e.g. from scf loops) with zero-trip-check.
FunctionLikeNest(mainPassManager)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/GlobalOptimization/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ createFuseSiluHorizontalMatmulPass();
/// Generalizes some named Linalg ops into `linalg.generic` operations since the
/// compiler can handle that better. When `generalizeLinalgMatmulOps` is true,
/// batch matmul named ops are generalized as well; the default preserves the
/// previous behavior for existing callers.
// NOTE(review): the diff rendering left the stale zero-argument declaration
// next to the new one; only the post-change declaration is kept.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGeneralizeLinalgNamedOpsPass(bool generalizeLinalgMatmulOps = false);

/// Infers and inserts util.numeric.optional_narrow ops at points that may be
/// beneficial.
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def GeneralizeLinalgNamedOps :
InterfacePass<"iree-global-opt-generalize-linalg-named-ops", "mlir::FunctionOpInterface"> {
let summary = "Convert some Linalg named ops into linalg.generics.";
let constructor = "mlir::iree_compiler::GlobalOptimization::createGeneralizeLinalgNamedOpsPass()";
let options = [
Option<"generalizeLinalgMatmulOps", "enable-generalize-linalg-matmul-ops", "bool",
/*default=*/"false", "Generalize linalg batch MatMul ops">,
];
}

def InferNumericNarrowing :
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops{enable-generalize-linalg-matmul-ops=true}))" --split-input-file %s | FileCheck %s

util.func public @generalize_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -34,3 +34,16 @@ util.func public @no_generalize_op_within_dispatch(%arg0 : tensor<?x?xf32>, %arg
// CHECK: %[[ADD:.+]] = linalg.add
// CHECK: flow.return %[[ADD]]
// CHECK: util.return %[[DISPATCH]]

// -----

// linalg.batch_matmul should be rewritten into a linalg.generic when the
// pass runs with enable-generalize-linalg-matmul-ops=true.
util.func public @generalize_matmul(%arg0: tensor<1x128x128xf32>, %arg1: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> {
%0 = tensor.empty() : tensor<1x128x128xf32>
%1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x128x128xf32>, tensor<1x128x128xf32>) outs(%0 : tensor<1x128x128xf32>) -> tensor<1x128x128xf32>
util.return %1 : tensor<1x128x128xf32>
}

// CHECK-LABEL: util.func public @generalize_matmul
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x128x128xf32>, %[[ARG1:.+]]: tensor<1x128x128xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: %[[ARG0]], %[[ARG1]]

0 comments on commit 1e1d51e

Please sign in to comment.