Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MergeFunctions] Add support to run the pass over a set of function pointers #111045

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Casperento
Copy link

This modification will enable the usage of MergeFunctions as a standalone library. Currently, MergeFunctions can only be applied to an entire module. By adopting this change, developers will gain the flexibility to reuse the MergeFunctions code within their own projects, choosing which functions to merge; hence, promoting code reusability. Notice that this modification will not break backward compatibility, because MergeFunctions will still work as a pass after the modification.

Summary of Changes:

  • Modified the MergeFunctionsPass to allow running the pass over a set of function pointers.
  • This behavior is optional and doesn't interfere with the existing functionality of running the pass on the entire Module.
  • Added unit tests to assert the correctness of the updated implementation, ensuring that function merging works as expected when run on both sets of pointers and full modules.

remove templates

unit tests added

format

updated data structures

format
Copy link

github-actions bot commented Oct 3, 2024

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 3, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Rafael Eckstein (Casperento)

Changes

This modification will enable the usage of MergeFunctions as a standalone library. Currently, MergeFunctions can only be applied to an entire module. By adopting this change, developers will gain the flexibility to reuse the MergeFunctions code within their own projects, choosing which functions to merge; hence, promoting code reusability. Notice that this modification will not break backward compatibility, because MergeFunctions will still work as a pass after the modification.

Summary of Changes:

  • Modified the MergeFunctionsPass to allow running the pass over a set of function pointers.
  • This behavior is optional and doesn't interfere with the existing functionality of running the pass on the entire Module.
  • Added unit tests to assert the correctness of the updated implementation, ensuring that function merging works as expected when run on both sets of pointers and full modules.

Full diff: https://github.com/llvm/llvm-project/pull/111045.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Transforms/IPO/MergeFunctions.h (+7)
  • (modified) llvm/lib/Transforms/IPO/MergeFunctions.cpp (+61-2)
  • (modified) llvm/unittests/Transforms/Utils/CMakeLists.txt (+1)
  • (added) llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp (+271)
  • (modified) llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn (+1)
diff --git a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
index 822f0fd99188d0..1b3b1d22f11e28 100644
--- a/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
+++ b/llvm/include/llvm/Transforms/IPO/MergeFunctions.h
@@ -15,7 +15,10 @@
 #ifndef LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H
 #define LLVM_TRANSFORMS_IPO_MERGEFUNCTIONS_H
 
+#include "llvm/IR/Function.h"
 #include "llvm/IR/PassManager.h"
+#include <map>
+#include <set>
 
 namespace llvm {
 
@@ -25,6 +28,10 @@ class Module;
 class MergeFunctionsPass : public PassInfoMixin<MergeFunctionsPass> {
 public:
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+
+  static bool runOnModule(Module &M);
+  static std::pair<bool, std::map<Function *, Function *>>
+  runOnFunctions(std::set<Function *> &F);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index b50a700e09038f..a434d7920b6ccf 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -123,6 +123,7 @@
 #include <algorithm>
 #include <cassert>
 #include <iterator>
+#include <map>
 #include <set>
 #include <utility>
 #include <vector>
@@ -198,6 +199,8 @@ class MergeFunctions {
   }
 
   bool runOnModule(Module &M);
+  bool runOnFunctions(std::set<Function *> &F);
+  std::map<Function *, Function *> &getDelToNewMap();
 
 private:
   // The function comparison operator is provided here so that FunctionNodes do
@@ -298,17 +301,31 @@ class MergeFunctions {
   // dangling iterators into FnTree. The invariant that preserves this is that
   // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
   DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
+
+  /// Deleted-New functions mapping
+  std::map<Function *, Function *> DelToNewMap;
 };
 } // end anonymous namespace
 
 PreservedAnalyses MergeFunctionsPass::run(Module &M,
                                           ModuleAnalysisManager &AM) {
-  MergeFunctions MF;
-  if (!MF.runOnModule(M))
+  if (!MergeFunctionsPass::runOnModule(M))
     return PreservedAnalyses::all();
   return PreservedAnalyses::none();
 }
 
+bool MergeFunctionsPass::runOnModule(Module &M) {
+  MergeFunctions MF;
+  return MF.runOnModule(M);
+}
+
+std::pair<bool, std::map<Function *, Function *>>
+MergeFunctionsPass::runOnFunctions(std::set<Function *> &F) {
+  MergeFunctions MF;
+  bool MergeResult = MF.runOnFunctions(F);
+  return {MergeResult, MF.getDelToNewMap()};
+}
+
 #ifndef NDEBUG
 bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
   if (const unsigned Max = NumFunctionsForVerificationCheck) {
@@ -468,6 +485,47 @@ bool MergeFunctions::runOnModule(Module &M) {
   return Changed;
 }
 
+bool MergeFunctions::runOnFunctions(std::set<Function *> &F) {
+  bool Changed = false;
+  std::vector<std::pair<IRHash, Function *>> HashedFuncs;
+  for (Function *Func : F) {
+    if (isEligibleForMerging(*Func)) {
+      HashedFuncs.push_back({StructuralHash(*Func), Func});
+    }
+  }
+  llvm::stable_sort(HashedFuncs, less_first());
+  auto S = HashedFuncs.begin();
+  for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) {
+    if ((I != S && std::prev(I)->first == I->first) ||
+        (std::next(I) != IE && std::next(I)->first == I->first)) {
+      Deferred.push_back(WeakTrackingVH(I->second));
+    }
+  }
+  do {
+    std::vector<WeakTrackingVH> Worklist;
+    Deferred.swap(Worklist);
+    LLVM_DEBUG(dbgs() << "size of function: " << F.size() << '\n');
+    LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n');
+    for (WeakTrackingVH &I : Worklist) {
+      if (!I)
+        continue;
+      Function *F = cast<Function>(I);
+      if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage()) {
+        Changed |= insert(F);
+      }
+    }
+    LLVM_DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n');
+  } while (!Deferred.empty());
+  FnTree.clear();
+  FNodesInTree.clear();
+  GlobalNumbers.clear();
+  return Changed;
+}
+
+std::map<Function *, Function *> &MergeFunctions::getDelToNewMap() {
+  return this->DelToNewMap;
+}
+
 // Replace direct callers of Old with New.
 void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
   for (Use &U : llvm::make_early_inc_range(Old->uses())) {
@@ -1004,6 +1062,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
 
   Function *DeleteF = NewFunction;
   mergeTwoFunctions(OldF.getFunc(), DeleteF);
+  this->DelToNewMap.emplace(DeleteF, OldF.getFunc());
   return true;
 }
 
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index 5c7ec28709c169..7effa5d8e7d6d2 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -26,6 +26,7 @@ add_llvm_unittest(UtilsTests
   LoopUtilsTest.cpp
   MemTransferLowering.cpp
   ModuleUtilsTest.cpp
+  MergeFunctionsTest.cpp
   ScalarEvolutionExpanderTest.cpp
   SizeOptsTest.cpp
   SSAUpdaterBulkTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
new file mode 100644
index 00000000000000..696c5391ef4f68
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/MergeFunctionsTest.cpp
@@ -0,0 +1,271 @@
+//===- MergeFunctionsTest.cpp - Unit tests for
+//MergeFunctionsPass-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/MergeFunctions.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+
+TEST(MergeFunctions, TrueOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) #0 {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...) #1
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #2 = { nounwind willreturn }
+    )invalid",
+                                                Err, Ctx));
+
+  // Expects true after merging _slice_add10 and _slice_add10_alt
+  EXPECT_TRUE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, TrueOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) #0 {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...) #1
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #2 = { nounwind willreturn }
+    )invalid",
+                                                Err, Ctx));
+
+  std::set<Function *> FunctionsSet;
+  for (Function &F : *M)
+    FunctionsSet.insert(&F);
+
+  std::pair<bool, std::map<Function *, Function *>> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet);
+
+  // Expects true after merging _slice_add10 and _slice_add10_alt
+  EXPECT_TRUE(MergeResult.first);
+
+  // Expects that both functions (_slice_add10 and _slice_add10_alt)
+  // be mapped to the same new function
+  EXPECT_TRUE(MergeResult.second.size() > 0);
+  std::map<Function *, Function *> DelToNew = MergeResult.second;
+  Function *NewFunction = M->getFunction("_slice_add10");
+  for (auto P : DelToNew)
+    if (P.second)
+      EXPECT_EQ(P.second, NewFunction);
+}
+
+TEST(MergeFunctions, FalseOutputModuleTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) #0 {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...) #1
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %0
+        }
+
+        attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #2 = { nounwind willreturn }
+    )invalid",
+                                                Err, Ctx));
+
+  // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+  EXPECT_FALSE(MergeFunctionsPass::runOnModule(*M));
+}
+
+TEST(MergeFunctions, FalseOutputFunctionsTest) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+        @.str = private unnamed_addr constant [10 x i8] c"On f: %d\0A\00", align 1
+        @.str.1 = private unnamed_addr constant [13 x i8] c"On main: %d\0A\00", align 1
+
+        define dso_local i32 @f(i32 noundef %arg) #0 {
+            entry:
+                %add109 = call i32 @_slice_add10(i32 %arg)
+                %call = call i32 (ptr, ...) @printf(ptr noundef @.str, i32 noundef %add109)
+                ret i32 %add109
+        }
+
+        declare i32 @printf(ptr noundef, ...) #1
+
+        define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
+            entry:
+                %add99 = call i32 @_slice_add10(i32 %argc)
+                %call = call i32 @f(i32 noundef 2)
+                %sub = sub nsw i32 %call, 6
+                %call10 = call i32 (ptr, ...) @printf(ptr noundef @.str.1, i32 noundef %add99)
+                ret i32 %add99
+        }
+
+        define internal i32 @_slice_add10(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %4
+        }
+
+        define internal i32 @_slice_add10_alt(i32 %arg) #2 {
+            sliceclone_entry:
+                %0 = mul nsw i32 %arg, %arg
+                %1 = mul nsw i32 %0, 2
+                %2 = mul nsw i32 %1, 2
+                %3 = mul nsw i32 %2, 2
+                %4 = add nsw i32 %3, 2
+                ret i32 %0
+        }
+
+        attributes #0 = { noinline nounwind uwtable "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #1 = { "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cmov,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
+        attributes #2 = { nounwind willreturn }
+    )invalid",
+                                                Err, Ctx));
+
+  std::set<Function *> FunctionsSet;
+  for (Function &F : *M)
+    FunctionsSet.insert(&F);
+
+  std::pair<bool, std::map<Function *, Function *>> MergeResult =
+      MergeFunctionsPass::runOnFunctions(FunctionsSet);
+
+  for (auto P : MergeResult.second)
+    std::cout << P.first << " " << P.second << "\n";
+
+  // Expects false after trying to merge _slice_add10 and _slice_add10_alt
+  EXPECT_FALSE(MergeResult.first);
+
+  // Expects empty map
+  EXPECT_EQ(MergeResult.second.size(), 0u);
+}
+
+} // namespace
\ No newline at end of file
diff --git a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
index 380ed71a2bc010..fcea55c91f083c 100644
--- a/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/unittests/Transforms/Utils/BUILD.gn
@@ -27,6 +27,7 @@ unittest("UtilsTests") {
     "LoopUtilsTest.cpp",
     "MemTransferLowering.cpp",
     "ModuleUtilsTest.cpp",
+    "MergeFunctionsTest.cpp",
     "ProfDataUtilTest.cpp",
     "SSAUpdaterBulkTest.cpp",
     "ScalarEvolutionExpanderTest.cpp",

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants