From 09668215fae1dfd6e08200ccd4de0d37be57a964 Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Thu, 25 Sep 2025 06:09:25 +0000 Subject: [PATCH] add eliminate-function-parameter pass. --- .../mlir/Dialect/Func/Transforms/Passes.h | 2 +- .../mlir/Dialect/Func/Transforms/Passes.td | 8 ++ .../Dialect/Func/Transforms/CMakeLists.txt | 1 + .../Transforms/EliminateFunctionParameter.cpp | 88 +++++++++++++++++++ .../Func/eliminate-function-parameter.mlir | 32 +++++++ 5 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp create mode 100644 mlir/test/Dialect/Func/eliminate-function-parameter.mlir diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h index 6fe9cc4bb2986..1369a3627d0f2 100644 --- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h @@ -22,7 +22,7 @@ class RewritePatternSet; namespace func { -#define GEN_PASS_DECL_DUPLICATEFUNCTIONELIMINATIONPASS +#define GEN_PASS_DECL #include "mlir/Dialect/Func/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td index 4163997515bb0..0e57bc3e0da91 100644 --- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td @@ -21,4 +21,12 @@ def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination", }]; } +def EliminateFunctionParameterPass : Pass<"eliminate-function-parameter", + "ModuleOp"> { + let summary = "Eliminate function parameter"; + let description = [{ + Eliminate function parameter is used to remove unnecessary parameters passed + to a function, then update the function call. + }]; +} #endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt index 0bed59e109503..3553613543c86 100644 --- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRFuncTransforms + EliminateFunctionParameter.cpp DuplicateFunctionElimination.cpp FuncConversions.cpp diff --git a/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp new file mode 100644 index 0000000000000..c5412e117878b --- /dev/null +++ b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp @@ -0,0 +1,88 @@ +//===- EliminateFunctionParameter.cpp.cpp - Eliminate function Parameter --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" + +namespace mlir { +namespace func { +#define GEN_PASS_DEF_ELIMINATEFUNCTIONPARAMETERPASS +#include "mlir/Dialect/Func/Transforms/Passes.h.inc" +} // namespace func + +/// This function eliminates unnecessary parameters within the function. +static LogicalResult updateFunc(func::FuncOp funcOp, BitVector &arguemntNoUse) { + Block &entryBlock = funcOp.front(); + bool change = false; + FunctionType origType = funcOp.getFunctionType(); + llvm::ArrayRef origInputTypes = origType.getInputs(); + SmallVector newInputTypes; + for (auto iter : llvm::enumerate(funcOp.getArguments())) { + size_t position = iter.index(); + if (!iter.value().use_empty()) { + newInputTypes.push_back(origInputTypes[position]); + continue; + } + arguemntNoUse.set(position); + entryBlock.eraseArgument(position); + change = true; + } + + if (change) { + auto newFunctionType = FunctionType::get(funcOp.getContext(), newInputTypes, + origType.getResults()); + funcOp.setFunctionType(newFunctionType); + } + return success(change); +} + +/// After eliminating redundant parameters from the function, update the +/// function calls. +static LogicalResult updateCall(func::CallOp callOp, + BitVector &argumentsNoUse) { + ValueRange origOperands = callOp.getOperands(); + SmallVector newOperands; + for (auto iter : llvm::enumerate(origOperands)) { + if (!argumentsNoUse[iter.index()]) + newOperands.push_back(iter.value()); + } + callOp->setOperands(newOperands); + return success(); +} + +namespace { +struct EliminateFunctionParameterPass + : public func::impl::EliminateFunctionParameterPassBase< + EliminateFunctionParameterPass> { + using EliminateFunctionParameterPassBase< + EliminateFunctionParameterPass>::EliminateFunctionParameterPassBase; + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + for (auto funcOp : moduleOp.getOps()) { + size_t argumentSize = funcOp.getArguments().size(); + if (!argumentSize) + continue; + BitVector argumentNoUse(argumentSize); + if (failed(updateFunc(funcOp, argumentNoUse))) + continue; + + auto symbolOp = mlir::cast(funcOp.getOperation()); + auto users = symbolOp.getSymbolUses(moduleOp); + + if (!users.has_value()) + continue; + for (SymbolTable::SymbolUse user : *users) { + Operation *call = user.getUser(); + (void)updateCall(mlir::cast(call), argumentNoUse); + } + } + } +}; + +} // namespace +} // namespace mlir diff --git a/mlir/test/Dialect/Func/eliminate-function-parameter.mlir b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir new file mode 100644 index 0000000000000..0bd35ec6bd1c7 --- /dev/null +++ b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-opt %s --split-input-file --eliminate-function-parameter | \ +// RUN: FileCheck %s + +func.func @single_parameter(%arg: index) { + return +} + +func.func @mutl_parameter(%arg0 : index, %arg1 : index) -> index { + return %arg0 : index +} + +func.func @eliminate_parameter(%arg0: index, %arg1: index) -> index { + func.call @single_parameter(%arg0) : (index) -> () + %ret = func.call @mutl_parameter(%arg0, %arg0) : (index, index) -> (index) + return %ret : index +} + +// CHECK-LABEL: func @single_parameter() { +// CHECK: return +// CHECK: } + +// CHECK-LABEL: func @mutl_parameter( +// CHECK-SAME: %[[ARG0:.*]]: index) -> index { +// CHECK: return %[[ARG0]] : index +// CHECK: } + +// CHECK-LABEL: func @eliminate_parameter( +// CHECK-SAME: %[[ARG0:.*]]: index) -> index { +// CHECK: call @single_parameter() : () -> () +// CHECK: %[[RET:.*]] = call @mutl_parameter(%[[ARG0]]) : (index) -> index +// CHECK: return %[[RET]] : index +// CHECK: }