Skip to content

Conversation

linuxlonelyeagle
Copy link
Member

Added the eliminate-function-parameter pass. During the IR transformation process, function parameters may become unused, and this pass is used to remove unused parameters in functions.

@llvmbot
Copy link
Member

llvmbot commented Sep 25, 2025

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

Changes

Added the eliminate-function-parameter pass. During the IR transformation process, function parameters may become unused, and this pass is used to remove unused parameters in functions.


Full diff: https://github.com/llvm/llvm-project/pull/160654.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Func/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Func/Transforms/Passes.td (+8)
  • (modified) mlir/lib/Dialect/Func/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp (+88)
  • (added) mlir/test/Dialect/Func/eliminate-function-parameter.mlir (+32)
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
index 6fe9cc4bb2986..1369a3627d0f2 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
@@ -22,7 +22,7 @@ class RewritePatternSet;
 
 namespace func {
 
-#define GEN_PASS_DECL_DUPLICATEFUNCTIONELIMINATIONPASS
+#define GEN_PASS_DECL
 #include "mlir/Dialect/Func/Transforms/Passes.h.inc"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
index 4163997515bb0..0e57bc3e0da91 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
@@ -21,4 +21,12 @@ def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
   }];
 }
 
+def EliminateFunctionParameterPass : Pass<"eliminate-function-parameter",
+  "ModuleOp"> {
+  let summary = "Eliminate function parameter";
+  let description = [{
+    Eliminate function parameter is used to remove unnecessary parameters passed
+    to a function, then update the function call.
+  }];
+}
 #endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index 0bed59e109503..3553613543c86 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRFuncTransforms
+  EliminateFunctionParameter.cpp
   DuplicateFunctionElimination.cpp
   FuncConversions.cpp
 
diff --git a/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
new file mode 100644
index 0000000000000..c5412e117878b
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
@@ -0,0 +1,88 @@
+//===- EliminateFunctionParameter.cpp.cpp - Eliminate function Parameter --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+
+namespace mlir {
+namespace func {
+#define GEN_PASS_DEF_ELIMINATEFUNCTIONPARAMETERPASS
+#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
+} // namespace func
+
+/// This function eliminates unnecessary parameters within the function.
+static LogicalResult updateFunc(func::FuncOp funcOp, BitVector &arguemntNoUse) {
+  Block &entryBlock = funcOp.front();
+  bool change = false;
+  FunctionType origType = funcOp.getFunctionType();
+  llvm::ArrayRef<Type> origInputTypes = origType.getInputs();
+  SmallVector<Type, 4> newInputTypes;
+  for (auto iter : llvm::enumerate(funcOp.getArguments())) {
+    size_t position = iter.index();
+    if (!iter.value().use_empty()) {
+      newInputTypes.push_back(origInputTypes[position]);
+      continue;
+    }
+    arguemntNoUse.set(position);
+    entryBlock.eraseArgument(position);
+    change = true;
+  }
+
+  if (change) {
+    auto newFunctionType = FunctionType::get(funcOp.getContext(), newInputTypes,
+                                             origType.getResults());
+    funcOp.setFunctionType(newFunctionType);
+  }
+  return success(change);
+}
+
+/// After eliminating redundant parameters from the function, update the
+/// function calls.
+static LogicalResult updateCall(func::CallOp callOp,
+                                BitVector &argumentsNoUse) {
+  ValueRange origOperands = callOp.getOperands();
+  SmallVector<Value, 4> newOperands;
+  for (auto iter : llvm::enumerate(origOperands)) {
+    if (!argumentsNoUse[iter.index()])
+      newOperands.push_back(iter.value());
+  }
+  callOp->setOperands(newOperands);
+  return success();
+}
+
+namespace {
+struct EliminateFunctionParameterPass
+    : public func::impl::EliminateFunctionParameterPassBase<
+          EliminateFunctionParameterPass> {
+  using EliminateFunctionParameterPassBase<
+      EliminateFunctionParameterPass>::EliminateFunctionParameterPassBase;
+  void runOnOperation() override {
+    ModuleOp moduleOp = getOperation();
+    for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
+      size_t argumentSize = funcOp.getArguments().size();
+      if (!argumentSize)
+        continue;
+      BitVector argumentNoUse(argumentSize);
+      if (failed(updateFunc(funcOp, argumentNoUse)))
+        continue;
+
+      auto symbolOp = mlir::cast<SymbolOpInterface>(funcOp.getOperation());
+      auto users = symbolOp.getSymbolUses(moduleOp);
+
+      if (!users.has_value())
+        continue;
+      for (SymbolTable::SymbolUse user : *users) {
+        Operation *call = user.getUser();
+        (void)updateCall(mlir::cast<func::CallOp>(call), argumentNoUse);
+      }
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir
diff --git a/mlir/test/Dialect/Func/eliminate-function-parameter.mlir b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
new file mode 100644
index 0000000000000..0bd35ec6bd1c7
--- /dev/null
+++ b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s --split-input-file --eliminate-function-parameter | \
+// RUN: FileCheck %s
+
+func.func @single_parameter(%arg: index) {
+  return
+}
+
+func.func @mutl_parameter(%arg0 : index, %arg1 : index) -> index {
+  return %arg0 : index
+}
+
+func.func @eliminate_parameter(%arg0: index, %arg1: index) -> index {
+  func.call @single_parameter(%arg0) : (index) -> ()
+  %ret = func.call @mutl_parameter(%arg0, %arg0) : (index, index) -> (index)
+  return %ret : index
+}
+
+// CHECK-LABEL: func @single_parameter() {
+//       CHECK:   return
+//       CHECK: }
+
+// CHECK-LABEL: func @mutl_parameter(
+//  CHECK-SAME:   %[[ARG0:.*]]: index) -> index {
+//       CHECK:   return %[[ARG0]] : index
+//       CHECK: }
+
+// CHECK-LABEL: func @eliminate_parameter(
+//  CHECK-SAME:   %[[ARG0:.*]]: index) -> index {
+//       CHECK:   call @single_parameter() : () -> ()
+//       CHECK:   %[[RET:.*]] = call @mutl_parameter(%[[ARG0]]) : (index) -> index
+//       CHECK:   return %[[RET]] : index
+//       CHECK: }

@llvmbot
Copy link
Member

llvmbot commented Sep 25, 2025

@llvm/pr-subscribers-mlir-func

Author: lonely eagle (linuxlonelyeagle)

Changes

Added the eliminate-function-parameter pass. During the IR transformation process, function parameters may become unused, and this pass is used to remove unused parameters in functions.


Full diff: https://github.com/llvm/llvm-project/pull/160654.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Func/Transforms/Passes.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Func/Transforms/Passes.td (+8)
  • (modified) mlir/lib/Dialect/Func/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp (+88)
  • (added) mlir/test/Dialect/Func/eliminate-function-parameter.mlir (+32)
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
index 6fe9cc4bb2986..1369a3627d0f2 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
@@ -22,7 +22,7 @@ class RewritePatternSet;
 
 namespace func {
 
-#define GEN_PASS_DECL_DUPLICATEFUNCTIONELIMINATIONPASS
+#define GEN_PASS_DECL
 #include "mlir/Dialect/Func/Transforms/Passes.h.inc"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
index 4163997515bb0..0e57bc3e0da91 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
@@ -21,4 +21,12 @@ def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
   }];
 }
 
+def EliminateFunctionParameterPass : Pass<"eliminate-function-parameter",
+  "ModuleOp"> {
+  let summary = "Eliminate function parameter";
+  let description = [{
+    Eliminate function parameter is used to remove unnecessary parameters passed
+    to a function, then update the function call.
+  }];
+}
 #endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index 0bed59e109503..3553613543c86 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRFuncTransforms
+  EliminateFunctionParameter.cpp
   DuplicateFunctionElimination.cpp
   FuncConversions.cpp
 
diff --git a/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
new file mode 100644
index 0000000000000..c5412e117878b
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Transforms/EliminateFunctionParameter.cpp
@@ -0,0 +1,88 @@
+//===- EliminateFunctionParameter.cpp.cpp - Eliminate function Parameter --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+
+namespace mlir {
+namespace func {
+#define GEN_PASS_DEF_ELIMINATEFUNCTIONPARAMETERPASS
+#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
+} // namespace func
+
+/// This function eliminates unnecessary parameters within the function.
+static LogicalResult updateFunc(func::FuncOp funcOp, BitVector &arguemntNoUse) {
+  Block &entryBlock = funcOp.front();
+  bool change = false;
+  FunctionType origType = funcOp.getFunctionType();
+  llvm::ArrayRef<Type> origInputTypes = origType.getInputs();
+  SmallVector<Type, 4> newInputTypes;
+  for (auto iter : llvm::enumerate(funcOp.getArguments())) {
+    size_t position = iter.index();
+    if (!iter.value().use_empty()) {
+      newInputTypes.push_back(origInputTypes[position]);
+      continue;
+    }
+    arguemntNoUse.set(position);
+    entryBlock.eraseArgument(position);
+    change = true;
+  }
+
+  if (change) {
+    auto newFunctionType = FunctionType::get(funcOp.getContext(), newInputTypes,
+                                             origType.getResults());
+    funcOp.setFunctionType(newFunctionType);
+  }
+  return success(change);
+}
+
+/// After eliminating redundant parameters from the function, update the
+/// function calls.
+static LogicalResult updateCall(func::CallOp callOp,
+                                BitVector &argumentsNoUse) {
+  ValueRange origOperands = callOp.getOperands();
+  SmallVector<Value, 4> newOperands;
+  for (auto iter : llvm::enumerate(origOperands)) {
+    if (!argumentsNoUse[iter.index()])
+      newOperands.push_back(iter.value());
+  }
+  callOp->setOperands(newOperands);
+  return success();
+}
+
+namespace {
+struct EliminateFunctionParameterPass
+    : public func::impl::EliminateFunctionParameterPassBase<
+          EliminateFunctionParameterPass> {
+  using EliminateFunctionParameterPassBase<
+      EliminateFunctionParameterPass>::EliminateFunctionParameterPassBase;
+  void runOnOperation() override {
+    ModuleOp moduleOp = getOperation();
+    for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
+      size_t argumentSize = funcOp.getArguments().size();
+      if (!argumentSize)
+        continue;
+      BitVector argumentNoUse(argumentSize);
+      if (failed(updateFunc(funcOp, argumentNoUse)))
+        continue;
+
+      auto symbolOp = mlir::cast<SymbolOpInterface>(funcOp.getOperation());
+      auto users = symbolOp.getSymbolUses(moduleOp);
+
+      if (!users.has_value())
+        continue;
+      for (SymbolTable::SymbolUse user : *users) {
+        Operation *call = user.getUser();
+        (void)updateCall(mlir::cast<func::CallOp>(call), argumentNoUse);
+      }
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir
diff --git a/mlir/test/Dialect/Func/eliminate-function-parameter.mlir b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
new file mode 100644
index 0000000000000..0bd35ec6bd1c7
--- /dev/null
+++ b/mlir/test/Dialect/Func/eliminate-function-parameter.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s --split-input-file --eliminate-function-parameter | \
+// RUN: FileCheck %s
+
+func.func @single_parameter(%arg: index) {
+  return
+}
+
+func.func @mutl_parameter(%arg0 : index, %arg1 : index) -> index {
+  return %arg0 : index
+}
+
+func.func @eliminate_parameter(%arg0: index, %arg1: index) -> index {
+  func.call @single_parameter(%arg0) : (index) -> ()
+  %ret = func.call @mutl_parameter(%arg0, %arg0) : (index, index) -> (index)
+  return %ret : index
+}
+
+// CHECK-LABEL: func @single_parameter() {
+//       CHECK:   return
+//       CHECK: }
+
+// CHECK-LABEL: func @mutl_parameter(
+//  CHECK-SAME:   %[[ARG0:.*]]: index) -> index {
+//       CHECK:   return %[[ARG0]] : index
+//       CHECK: }
+
+// CHECK-LABEL: func @eliminate_parameter(
+//  CHECK-SAME:   %[[ARG0:.*]]: index) -> index {
+//       CHECK:   call @single_parameter() : () -> ()
+//       CHECK:   %[[RET:.*]] = call @mutl_parameter(%[[ARG0]]) : (index) -> index
+//       CHECK:   return %[[RET]] : index
+//       CHECK: }

@joker-eph
Copy link
Collaborator

Isn't "remove-dead-values" already doing this?

func.call @single_parameter(%arg0) : (index) -> ()
%ret = func.call @mutl_parameter(%arg0, %arg0) : (index, index) -> (index)
return %ret : index
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: all these functions are public, it is incorrect to change the signature of a public function: you can't know about call sites outside of the current module.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something I hadn't considered.Thank you for the reminder. It has helped me better understand the meaning of “public function.”

@linuxlonelyeagle
Copy link
Member Author

Isn't "remove-dead-values" already doing this?

After reviewing the definition of remove-dead-values, it does indeed perform this operation.

@linuxlonelyeagle
Copy link
Member Author

remove-dead-values

It appears that remove-dead-values does not work for this.

func.func private @single_parameter(%arg: index) {
  return
}

func.func private @mutl_parameter(%arg0 : index, %arg1 : index) -> index {
  return %arg0 : index
}

func.func private @eliminate_parameter(%arg0: index, %arg1: index) -> index {
  func.call @single_parameter(%arg0) : (index) -> ()
  %ret = func.call @mutl_parameter(%arg0, %arg0) : (index, index) -> (index)
  return %ret : index
}

run pass

module {
  func.func private @single_parameter(%arg0: index) {
    return
  }
  func.func private @mutl_parameter(%arg0: index, %arg1: index) {
    return
  }
  func.func private @eliminate_parameter(%arg0: index, %arg1: index) {
    call @single_parameter(%arg0) : (index) -> ()
    call @mutl_parameter(%arg0, %arg0) : (index, index) -> ()
    return
  }
}

@linuxlonelyeagle
Copy link
Member Author

I've reopened this PR. Maybe we should discuss it.

@linuxlonelyeagle
Copy link
Member Author

#160755

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants