Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CallPromotionUtil] See through function alias when devirtualizing a virtual call on an alloca. #80736

Merged
merged 3 commits into from
Feb 6, 2024

Conversation

minglotus-6
Copy link
Contributor

@minglotus-6 minglotus-6 commented Feb 5, 2024

  • Extract utility function from DevirtModule::tryFindVirtualCallTargets, which sees through an alias to a function. Call this utility function in the WPD callsite.
  • For type profiling work, this helper function will be used by indirect-call-promotion pass to find the function pointer at a specified vtable offset (an example in this line)

…tVTableOffset that finds functions through alias. Use it in CallPromotionUtils which didn't promote aliasee previously

- The utility function is extracted from the implementation of
  'DevirtModule::tryFindVirtualCallTargets'.
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 5, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Mingming Liu (minglotus-6)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/80736.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TypeMetadataUtils.h (+5)
  • (modified) llvm/lib/Analysis/TypeMetadataUtils.cpp (+17)
  • (modified) llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp (+4-6)
  • (modified) llvm/lib/Transforms/Utils/CallPromotionUtils.cpp (+4-7)
diff --git a/llvm/include/llvm/Analysis/TypeMetadataUtils.h b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
index dab67aad1ab0e..9f8c364b49375 100644
--- a/llvm/include/llvm/Analysis/TypeMetadataUtils.h
+++ b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_ANALYSIS_TYPEMETADATAUTILS_H
 #define LLVM_ANALYSIS_TYPEMETADATAUTILS_H
 
+#include "llvm/IR/GlobalVariable.h"
 #include <cstdint>
 
 namespace llvm {
@@ -77,6 +78,10 @@ void findDevirtualizableCallsForTypeCheckedLoad(
 Constant *getPointerAtOffset(Constant *I, uint64_t Offset, Module &M,
                              Constant *TopLevelGlobal = nullptr);
 
+// Given a vtable, returns the function pointer specified by Offset.
+std::pair<Function *, Constant *>
+getFunctionAtVTableOffset(GlobalVariable *GV, uint64_t Offset, Module &M);
+
 /// Finds the same "relative pointer" pattern as described above, where the
 /// target is `F`, and replaces the entire pattern with a constant zero.
 void replaceRelativePointerUsersWithZero(Function *F);
diff --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp
index bbaee06ed8a55..0e6859f622a7d 100644
--- a/llvm/lib/Analysis/TypeMetadataUtils.cpp
+++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp
@@ -201,6 +201,23 @@ Constant *llvm::getPointerAtOffset(Constant *I, uint64_t Offset, Module &M,
   return nullptr;
 }
 
+std::pair<Function *, Constant *>
+llvm::getFunctionAtVTableOffset(GlobalVariable *GV, uint64_t Offset,
+                                Module &M) {
+  Constant *Ptr = getPointerAtOffset(GV->getInitializer(), Offset, M, GV);
+  if (!Ptr)
+    return std::pair<Function *, Constant *>(nullptr, nullptr);
+
+  auto C = Ptr->stripPointerCasts();
+  // Make sure this is a function or alias to a function.
+  auto Fn = dyn_cast<Function>(C);
+  auto A = dyn_cast<GlobalAlias>(C);
+  if (!Fn && A)
+    Fn = dyn_cast<Function>(A->getAliasee());
+
+  return std::pair<Function *, Constant *>(Fn, C);
+}
+
 void llvm::replaceRelativePointerUsersWithZero(Function *F) {
   for (auto *U : F->users()) {
     auto *PtrExpr = dyn_cast<ConstantExpr>(U);
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 01aba47cdbfff..154ff876a53cd 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -1071,12 +1071,10 @@ bool DevirtModule::tryFindVirtualCallTargets(
     if (!Ptr)
       return false;
 
-    auto C = Ptr->stripPointerCasts();
-    // Make sure this is a function or alias to a function.
-    auto Fn = dyn_cast<Function>(C);
-    auto A = dyn_cast<GlobalAlias>(C);
-    if (!Fn && A)
-      Fn = dyn_cast<Function>(A->getAliasee());
+    Function *Fn = nullptr;
+    Constant *C = nullptr;
+    std::tie(Fn, C) =
+        getFunctionAtVTableOffset(TM.Bits->GV, TM.Offset + ByteOffset, M);
 
     if (!Fn)
       return false;
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index e42cdab64446e..4e84927f1cfc9 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -597,16 +597,13 @@ bool llvm::tryPromoteCall(CallBase &CB) {
     // Not in the form of a global constant variable with an initializer.
     return false;
 
-  Constant *VTableGVInitializer = GV->getInitializer();
   APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset;
   if (!(VTableGVOffset.getActiveBits() <= 64))
     return false; // Out of range.
-  Constant *Ptr = getPointerAtOffset(VTableGVInitializer,
-                                     VTableGVOffset.getZExtValue(),
-                                     *M);
-  if (!Ptr)
-    return false; // No constant (function) pointer found.
-  Function *DirectCallee = dyn_cast<Function>(Ptr->stripPointerCasts());
+
+  Function *DirectCallee = nullptr;
+  std::tie(DirectCallee, std::ignore) =
+      getFunctionAtVTableOffset(GV, VTableGVOffset.getZExtValue(), *M);
   if (!DirectCallee)
     return false; // No function pointer found.
 

Copy link
Contributor

@teresajohnson teresajohnson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest making the PR title shorter and moving some of the details there currently to the description. Also, I think this can be marked NFC?

@@ -14,6 +14,7 @@
#ifndef LLVM_ANALYSIS_TYPEMETADATAUTILS_H
#define LLVM_ANALYSIS_TYPEMETADATAUTILS_H

#include "llvm/IR/GlobalVariable.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you instead just use a forward declaration for GlobalVariable, like we do for other types below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done (along with #include <utility>).

TIL that forward declaration is almost always better,thanks!

@@ -1071,12 +1071,10 @@ bool DevirtModule::tryFindVirtualCallTargets(
if (!Ptr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this earlier call to getPointerAtOffset be removed now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah good catch. done.

@@ -77,6 +78,10 @@ void findDevirtualizableCallsForTypeCheckedLoad(
Constant *getPointerAtOffset(Constant *I, uint64_t Offset, Module &M,
Constant *TopLevelGlobal = nullptr);

// Given a vtable, returns the function pointer specified by Offset.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the pair contents.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done. Also updated function impl to return a pair of nullptr when the not-null pointer cannot be casted to function and is not an alias to a function.

@minglotus-6 minglotus-6 changed the title [TypeMetadataUtil][CallPromtionUtil]Add utility function getFunctionAtVTableOffset that finds functions through alias. Use it in CallPromotionUtils which didn't promote aliasee previously [TypeMetadataUtil] Add utility function getFunctionAtVTableOffset that finds functions through alias. Feb 5, 2024
@minglotus-6 minglotus-6 changed the title [TypeMetadataUtil] Add utility function getFunctionAtVTableOffset that finds functions through alias. [CallPromotionUtil] Devirtualize a virtual call on an alloca if the pointer at specified offset of the vtable is an alias to a function. Feb 5, 2024
@minglotus-6 minglotus-6 changed the title [CallPromotionUtil] Devirtualize a virtual call on an alloca if the pointer at specified offset of the vtable is an alias to a function. [CallPromotionUtil] See through function alias when devirtualizing a virtual call on an alloca. Feb 5, 2024
Copy link

github-actions bot commented Feb 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@minglotus-6
Copy link
Contributor Author

Suggest making the PR title shorter and moving some of the details there currently to the description. Also, I think this can be marked NFC?

Makes sense to make PR title shorter and move implementation details. It's not strictly an NFC (not stated in a simple way in prior title), rephrased the title.

@minglotus-6 minglotus-6 merged commit 8ea858b into llvm:main Feb 6, 2024
4 checks passed
@minglotus-6 minglotus-6 deleted the pointeroffset branch February 6, 2024 17:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants