-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][func] Add eliminate-function-parameter pass #160654
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][func] Add eliminate-function-parameter pass #160654
Conversation
@llvm/pr-subscribers-mlir Author: lonely eagle (linuxlonelyeagle) ChangesAdded the eliminate-function-parameter pass. During the IR transformation process, function parameters may become unused, and this pass is used to remove unused parameters in functions. Full diff: https://github.com/llvm/llvm-project/pull/160654.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
index 6fe9cc4bb2986..1369a3627d0f2 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
@@ -22,7 +22,7 @@ class RewritePatternSet;
namespace func {
-#define GEN_PASS_DECL_DUPLICATEFUNCTIONELIMINATIONPASS
+#define GEN_PASS_DECL
#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
index 4163997515bb0..0e57bc3e0da91 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
@@ -21,4 +21,12 @@ def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
}];
}
+def EliminateFunctionParameterPass : Pass<"eliminate-function-parameter",
+ "ModuleOp"> {
+ let summary = "Eliminate function parameter";
+ let description = [{
+ Eliminate function parameter is used to remove unnecessary parameters passed
+ to a function, then update the function call.
+ }];
+}
#endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index 0bed59e109503..3553613543c86 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRFuncTransforms
+ EliminateFunctionParameter.cpp
DuplicateFunctionElimination.cpp
FuncConversions.cpp
diff --git a/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
new file mode 100644
index 0000000000000..c5412e117878b
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
@@ -0,0 +1,88 @@
+//===- EliminateFunctionParameter.cpp.cpp - Eliminate function Parameter --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+
+namespace mlir {
+namespace func {
+#define GEN_PASS_DEF_ELIMINATEFUNCTIONPARAMETERPASS
+#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
+} // namespace func
+
+/// This function eliminates unnecessary parameters within the function.
+static LogicalResult updateFunc(func::FuncOp funcOp, BitVector &arguemntNoUse) {
+ Block &entryBlock = funcOp.front();
+ bool change = false;
+ FunctionType origType = funcOp.getFunctionType();
+ llvm::ArrayRef<Type> origInputTypes = origType.getInputs();
+ SmallVector<Type, 4> newInputTypes;
+ for (auto iter : llvm::enumerate(funcOp.getArguments())) {
+ size_t position = iter.index();
+ if (!iter.value().use_empty()) {
+ newInputTypes.push_back(origInputTypes[position]);
+ continue;
+ }
+ arguemntNoUse.set(position);
+ entryBlock.eraseArgument(position);
+ change = true;
+ }
+
+ if (change) {
+ auto newFunctionType = FunctionType::get(funcOp.getContext(), newInputTypes,
+ origType.getResults());
+ funcOp.setFunctionType(newFunctionType);
+ }
+ return success(change);
+}
+
+/// After eliminating redundant parameters from the function, update the
+/// function calls.
+static LogicalResult updateCall(func::CallOp callOp,
+ BitVector &argumentsNoUse) {
+ ValueRange origOperands = callOp.getOperands();
+ SmallVector<Value, 4> newOperands;
+ for (auto iter : llvm::enumerate(origOperands)) {
+ if (!argumentsNoUse[iter.index()])
+ newOperands.push_back(iter.value());
+ }
+ callOp->setOperands(newOperands);
+ return success();
+}
+
+namespace {
+struct EliminateFunctionParameterPass
+ : public func::impl::EliminateFunctionParameterPassBase<
+ EliminateFunctionParameterPass> {
+ using EliminateFunctionParameterPassBase<
+ EliminateFunctionParameterPass>::EliminateFunctionParameterPassBase;
+ void runOnOperation() override {
+ ModuleOp moduleOp = getOperation();
+ for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
+ size_t argumentSize = funcOp.getArguments().size();
+ if (!argumentSize)
+ continue;
+ BitVector argumentNoUse(argumentSize);
+ if (failed(updateFunc(funcOp, argumentNoUse)))
+ continue;
+
+ auto symbolOp = mlir::cast<SymbolOpInterface>(funcOp.getOperation());
+ auto users = symbolOp.getSymbolUses(moduleOp);
+
+ if (!users.has_value())
+ continue;
+ for (SymbolTable::SymbolUse user : *users) {
+ Operation *call = user.getUser();
+ (void)updateCall(mlir::cast<func::CallOp>(call), argumentNoUse);
+ }
+ }
+ }
+};
+
+} // namespace
+} // namespace mlir
diff --git a/mlir/test/Dialect/Func/eliminate-function-parameter.mlir b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
new file mode 100644
index 0000000000000..0bd35ec6bd1c7
--- /dev/null
+++ b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s --split-input-file --eliminate-function-parameter | \
+// RUN: FileCheck %s
+
+func.func @single_parameter(%arg: index) {
+ return
+}
+
+func.func @mutl_parameter(%arg0 : index, %arg1 : index) -> index {
+ return %arg0 : index
+}
+
+func.func @eliminate_parameter(%arg0: index, %arg1: index) -> index {
+ func.call @single_parameter(%arg0) : (index) -> ()
+ %ret = func.call @mutl_parameter(%arg0, %arg0) : (index, index) -> (index)
+ return %ret : index
+}
+
+// CHECK-LABEL: func @single_parameter() {
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func @mutl_parameter(
+// CHECK-SAME: %[[ARG0:.*]]: index) -> index {
+// CHECK: return %[[ARG0]] : index
+// CHECK: }
+
+// CHECK-LABEL: func @eliminate_parameter(
+// CHECK-SAME: %[[ARG0:.*]]: index) -> index {
+// CHECK: call @single_parameter() : () -> ()
+// CHECK: %[[RET:.*]] = call @mutl_parameter(%[[ARG0]]) : (index) -> index
+// CHECK: return %[[RET]] : index
+// CHECK: }
|
@llvm/pr-subscribers-mlir-func Author: lonely eagle (linuxlonelyeagle) ChangesAdded the eliminate-function-parameter pass. During the IR transformation process, function parameters may become unused, and this pass is used to remove unused parameters in functions. Full diff: https://github.com/llvm/llvm-project/pull/160654.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
index 6fe9cc4bb2986..1369a3627d0f2 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
@@ -22,7 +22,7 @@ class RewritePatternSet;
namespace func {
-#define GEN_PASS_DECL_DUPLICATEFUNCTIONELIMINATIONPASS
+#define GEN_PASS_DECL
#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
index 4163997515bb0..0e57bc3e0da91 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
@@ -21,4 +21,12 @@ def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
}];
}
+def EliminateFunctionParameterPass : Pass<"eliminate-function-parameter",
+ "ModuleOp"> {
+ let summary = "Eliminate function parameter";
+ let description = [{
+ Eliminate function parameter is used to remove unnecessary parameters passed
+ to a function, then update the function call.
+ }];
+}
#endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index 0bed59e109503..3553613543c86 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRFuncTransforms
+ EliminateFunctionParameter.cpp
DuplicateFunctionElimination.cpp
FuncConversions.cpp
diff --git a/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
new file mode 100644
index 0000000000000..c5412e117878b
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
@@ -0,0 +1,88 @@
+//===- EliminateFunctionParameter.cpp.cpp - Eliminate function Parameter --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+
+namespace mlir {
+namespace func {
+#define GEN_PASS_DEF_ELIMINATEFUNCTIONPARAMETERPASS
+#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
+} // namespace func
+
+/// This function eliminates unnecessary parameters within the function.
+static LogicalResult updateFunc(func::FuncOp funcOp, BitVector &arguemntNoUse) {
+ Block &entryBlock = funcOp.front();
+ bool change = false;
+ FunctionType origType = funcOp.getFunctionType();
+ llvm::ArrayRef<Type> origInputTypes = origType.getInputs();
+ SmallVector<Type, 4> newInputTypes;
+ for (auto iter : llvm::enumerate(funcOp.getArguments())) {
+ size_t position = iter.index();
+ if (!iter.value().use_empty()) {
+ newInputTypes.push_back(origInputTypes[position]);
+ continue;
+ }
+ arguemntNoUse.set(position);
+ entryBlock.eraseArgument(position);
+ change = true;
+ }
+
+ if (change) {
+ auto newFunctionType = FunctionType::get(funcOp.getContext(), newInputTypes,
+ origType.getResults());
+ funcOp.setFunctionType(newFunctionType);
+ }
+ return success(change);
+}
+
+/// After eliminating redundant parameters from the function, update the
+/// function calls.
+static LogicalResult updateCall(func::CallOp callOp,
+ BitVector &argumentsNoUse) {
+ ValueRange origOperands = callOp.getOperands();
+ SmallVector<Value, 4> newOperands;
+ for (auto iter : llvm::enumerate(origOperands)) {
+ if (!argumentsNoUse[iter.index()])
+ newOperands.push_back(iter.value());
+ }
+ callOp->setOperands(newOperands);
+ return success();
+}
+
+namespace {
+struct EliminateFunctionParameterPass
+ : public func::impl::EliminateFunctionParameterPassBase<
+ EliminateFunctionParameterPass> {
+ using EliminateFunctionParameterPassBase<
+ EliminateFunctionParameterPass>::EliminateFunctionParameterPassBase;
+ void runOnOperation() override {
+ ModuleOp moduleOp = getOperation();
+ for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
+ size_t argumentSize = funcOp.getArguments().size();
+ if (!argumentSize)
+ continue;
+ BitVector argumentNoUse(argumentSize);
+ if (failed(updateFunc(funcOp, argumentNoUse)))
+ continue;
+
+ auto symbolOp = mlir::cast<SymbolOpInterface>(funcOp.getOperation());
+ auto users = symbolOp.getSymbolUses(moduleOp);
+
+ if (!users.has_value())
+ continue;
+ for (SymbolTable::SymbolUse user : *users) {
+ Operation *call = user.getUser();
+ (void)updateCall(mlir::cast<func::CallOp>(call), argumentNoUse);
+ }
+ }
+ }
+};
+
+} // namespace
+} // namespace mlir
diff --git a/mlir/test/Dialect/Func/eliminate-function-parameter.mlir b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
new file mode 100644
index 0000000000000..0bd35ec6bd1c7
--- /dev/null
+++ b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s --split-input-file --eliminate-function-parameter | \
+// RUN: FileCheck %s
+
+func.func @single_parameter(%arg: index) {
+ return
+}
+
+func.func @mutl_parameter(%arg0 : index, %arg1 : index) -> index {
+ return %arg0 : index
+}
+
+func.func @eliminate_parameter(%arg0: index, %arg1: index) -> index {
+ func.call @single_parameter(%arg0) : (index) -> ()
+ %ret = func.call @mutl_parameter(%arg0, %arg0) : (index, index) -> (index)
+ return %ret : index
+}
+
+// CHECK-LABEL: func @single_parameter() {
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func @mutl_parameter(
+// CHECK-SAME: %[[ARG0:.*]]: index) -> index {
+// CHECK: return %[[ARG0]] : index
+// CHECK: }
+
+// CHECK-LABEL: func @eliminate_parameter(
+// CHECK-SAME: %[[ARG0:.*]]: index) -> index {
+// CHECK: call @single_parameter() : () -> ()
+// CHECK: %[[RET:.*]] = call @mutl_parameter(%[[ARG0]]) : (index) -> index
+// CHECK: return %[[RET]] : index
+// CHECK: }
|
Isn't "remove-dead-values" already doing this? |
func.call @single_parameter(%arg0) : (index) -> () | ||
%ret = func.call @mutl_parameter(%arg0, %arg0) : (index, index) -> (index) | ||
return %ret : index | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: all these functions are public, it is incorrect to change the signature of a public function: you can't know about call sites outside of the current module.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is something I hadn't considered.Thank you for the reminder. It has helped me better understand the meaning of “public function.”
After reviewing the definition of |
It appears that remove-dead-values does not work for this.
run pass
|
I've reopened this PR. Maybe we should discuss it. |
Added the eliminate-function-parameter pass. During the IR transformation process, function parameters may become unused, and this pass is used to remove unused parameters in functions.