-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CallPromotionUtils]Implement conditional indirect call promotion with vtable-based comparison #81378
Conversation
Created using spr 1.3.4 [skip ci]
…hCond from versionCallSite
…h vtable-based comparison
@llvm/pr-subscribers-llvm-transforms Author: Mingming Liu (minglotus-6) Changes
Full diff: https://github.com/llvm/llvm-project/pull/81378.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
index fcb384ec361339..32b252d132c04c 100644
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -14,10 +14,16 @@
#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
+#include <cstdint>
+
namespace llvm {
+template <typename T> class ArrayRef;
+class Constant;
class CallBase;
class CastInst;
class Function;
+class GlobalVariable;
+class Instruction;
class MDNode;
class Value;
@@ -41,7 +47,9 @@ bool isLegalToPromote(const CallBase &CB, Function *Callee,
CallBase &promoteCall(CallBase &CB, Function *Callee,
CastInst **RetBitCast = nullptr);
-/// Promote the given indirect call site to conditionally call \p Callee.
+/// Promote the given indirect call site to conditionally call \p Callee. The
+/// promoted direct call instruction is predicated on `CB.getCalledOperand() ==
+/// Callee`.
///
/// This function creates an if-then-else structure at the location of the call
/// site. The original call site is moved into the "else" block. A clone of the
@@ -51,6 +59,31 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
MDNode *BranchWeights = nullptr);
+/// This is similar to `promoteCallWithIfThenElse` except that the condition to
+/// promote a virtual call is that \p VPtr is the same as any of \p
+/// AddressPoints.
+///
+/// This function is expected to be used on virtual calls (a subset of indirect
+/// calls). \p VPtr is the virtual table address stored in the objects, and
+/// \p AddressPoints contains address points of vtables to be compared with.
+///
+/// It's the responsibility of caller to guarantee the transformation
+/// correctness (by specifying \p VPtr and \p AddressPoints properly).
+///
+/// This function doesn't sink the address-calculation instructions of indirect
+/// callee to the indirect call fallback. The subsequent passes (e.g.
+/// inst-combine) should sink them if possible and handle the sink of debug
+/// intrinsics together.
+CallBase &promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
+ Function *Callee,
+ ArrayRef<Constant *> AddressPoints,
+ MDNode *BranchWeights);
+
+/// Returns a constant representing the vtable's address point specified by the
+/// offset. Caller should ensure \p AddressPointOffset is valid.
+Constant *getVTableAddressPointOffset(GlobalVariable *VTable,
+ uint32_t AddressPointOffset);
+
/// Try to promote (devirtualize) a virtual call on an Alloca. Return true on
/// success.
///
@@ -74,13 +107,17 @@ CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
///
bool tryPromoteCall(CallBase &CB);
+/// Predicate and clone the given call site using the given condition.
+CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond,
+ MDNode *BranchWeights);
+
/// Predicate and clone the given call site.
///
-/// This function creates an if-then-else structure at the location of the call
-/// site. The "if" condition compares the call site's called value to the given
-/// callee. The original call site is moved into the "else" block, and a clone
-/// of the call site is placed in the "then" block. The cloned instruction is
-/// returned.
+/// This function creates an if-then-else structure at the location of the
+/// call site. The "if" condition compares the call site's called value to
+/// the given callee. The original call site is moved into the "else" block,
+/// and a clone of the call site is placed in the "then" block. The cloned
+/// instruction is returned.
CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights);
} // end namespace llvm
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index d0cf0792eface0..ea855b9a4d8416 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -12,9 +12,11 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
+#include "llvm/IR/Constant.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -185,6 +187,24 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
U->replaceUsesOfWith(&CB, Cast);
}
+// Returns the or result of all icmp instructions.
+static Value *getOrResult(const SmallVector<Value *, 2> &ICmps,
+ IRBuilder<> &Builder) {
+ assert(!ICmps.empty() && "Must have at least one icmp instructions");
+ if (ICmps.size() == 1)
+ return ICmps[0];
+
+ SmallVector<Value *, 2> OrResults;
+ int i = 0, NumICmp = ICmps.size();
+ for (i = 0; i + 1 < NumICmp; i += 2)
+ OrResults.push_back(Builder.CreateOr(ICmps[i], ICmps[i + 1], "icmp-or"));
+
+ if (i < NumICmp)
+ OrResults.push_back(ICmps[i]);
+
+ return getOrResult(OrResults, Builder);
+}
+
/// Predicate and clone the given call site.
///
/// This function creates an if-then-else structure at the location of the call
@@ -276,8 +296,8 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
/// ; The original call instruction stays in its original block.
/// %t0 = musttail call i32 %ptr()
/// ret %t0
-static CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond,
- MDNode *BranchWeights) {
+CallBase &llvm::versionCallSiteWithCond(CallBase &CB, Value *Cond,
+ MDNode *BranchWeights) {
IRBuilder<> Builder(&CB);
CallBase *OrigInst = &CB;
@@ -565,6 +585,46 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
return promoteCall(NewInst, Callee);
}
+Constant *llvm::getVTableAddressPointOffset(GlobalVariable *VTable,
+ uint32_t AddressPointOffset) {
+ Module &M = *VTable->getParent();
+ const DataLayout &DL = M.getDataLayout();
+ LLVMContext &Context = M.getContext();
+ Type *VTableType = VTable->getValueType();
+ assert(AddressPointOffset < DL.getTypeAllocSize(VTableType) &&
+ "Out-of-bound access");
+ APInt AddressPointOffsetAPInt(32, AddressPointOffset, false);
+ SmallVector<APInt> Indices =
+ DL.getGEPIndicesForOffset(VTableType, AddressPointOffsetAPInt);
+ SmallVector<llvm::Constant *> GEPIndices;
+ for (const auto &Index : Indices)
+ GEPIndices.push_back(llvm::ConstantInt::get(Type::getInt32Ty(Context),
+ Index.getZExtValue()));
+
+ return ConstantExpr::getInBoundsGetElementPtr(VTable->getValueType(), VTable,
+ GEPIndices);
+}
+
+CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
+ Function *Callee,
+ ArrayRef<Constant *> AddressPoints,
+ MDNode *BranchWeights) {
+ assert(!AddressPoints.empty() && "Caller should guarantee");
+ IRBuilder<> Builder(&CB);
+ SmallVector<Value *, 2> ICmps;
+ for (auto &AddressPoint : AddressPoints)
+ ICmps.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint));
+
+ Value *Cond = getOrResult(ICmps, Builder);
+
+ // Version the indirect call site. If Cond is true, 'NewInst' will be
+ // executed, otherwise the original call site will be executed.
+ CallBase &NewInst = versionCallSiteWithCond(CB, Cond, BranchWeights);
+
+ // Promote 'NewInst' so that it directly calls the desired function.
+ return promoteCall(NewInst, Callee);
+}
+
bool llvm::tryPromoteCall(CallBase &CB) {
assert(!CB.getCalledFunction());
Module *M = CB.getCaller()->getParent();
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index eff8e27d36d641..227156378369b5 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -8,9 +8,12 @@
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/NoFolder.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"
@@ -368,3 +371,119 @@ declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
bool IsPromoted = tryPromoteCall(*CI);
EXPECT_FALSE(IsPromoted);
}
+
+TEST(CallPromotionUtilsTest, getVTableAddressPointOffset) {
+ LLVMContext C;
+ std::unique_ptr<Module> M = parseIR(C,
+ R"IR(
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }
+
+declare i32 @_ZN5Base15func1Ev(ptr)
+declare i32 @_ZN5Base25func2Ev(ptr)
+declare i32 @_ZN5Base15func0Ev(ptr)
+declare void @_ZN5Base35func3Ev(ptr)
+)IR");
+ GlobalVariable *GV = M->getGlobalVariable("_ZTV8Derived2");
+
+ for (auto [AddressPointOffset, Index] :
+ {std::pair{16, 0}, {40, 1}, {64, 2}}) {
+ Constant *AddressPoint =
+ getVTableAddressPointOffset(GV, AddressPointOffset);
+
+ ConstantExpr *GEP = dyn_cast<ConstantExpr>(AddressPoint);
+ ASSERT_TRUE(GEP);
+ SmallVector<Constant *> Indices = {
+ llvm::ConstantInt::get(Type::getInt32Ty(C), 0U),
+ llvm::ConstantInt::get(Type::getInt32Ty(C), Index),
+ llvm::ConstantInt::get(Type::getInt32Ty(C), 2U)};
+ EXPECT_EQ(GEP, ConstantExpr::getInBoundsGetElementPtr(GV->getValueType(),
+ GV, Indices));
+ }
+}
+
+TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) {
+ LLVMContext C;
+ std::unique_ptr<Module> M = parseIR(C,
+ R"IR(
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0
+@_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2
+@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6
+
+define i32 @testfunc(ptr %d) {
+entry:
+ %vtable = load ptr, ptr %d, !prof !7
+ %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
+ %0 = load ptr, ptr %vfn
+ %call = tail call i32 %0(ptr %d), !prof !8
+ ret i32 %call
+}
+
+define i32 @_ZN5Base15func1Ev(ptr %this) {
+entry:
+ ret i32 2
+}
+
+declare i32 @_ZN5Base25func2Ev(ptr)
+declare i32 @_ZN5Base15func0Ev(ptr)
+declare void @_ZN5Base35func3Ev(ptr)
+
+!0 = !{i64 16, !"_ZTS5Base1"}
+!1 = !{i64 48, !"_ZTS5Base2"}
+!2 = !{i64 16, !"_ZTS8Derived1"}
+!3 = !{i64 64, !"_ZTS5Base1"}
+!4 = !{i64 40, !"_ZTS5Base2"}
+!5 = !{i64 16, !"_ZTS5Base3"}
+!6 = !{i64 16, !"_ZTS8Derived2"}
+!7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300}
+!8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR");
+
+ Function *F = M->getFunction("testfunc");
+ ASSERT_TRUE(F);
+ CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin()));
+ ASSERT_TRUE(CI && CI->isIndirectCall());
+
+ LoadInst *FuncPtr = dyn_cast<LoadInst>(CI->getCalledOperand());
+ ASSERT_TRUE(FuncPtr);
+
+ GetElementPtrInst *GEP =
+ dyn_cast<GetElementPtrInst>(FuncPtr->getPointerOperand());
+ ASSERT_TRUE(GEP);
+
+ // Create the constant and the branch weights
+ SmallVector<Constant *, 3> VTableAddressPoints;
+
+ for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16},
+ {"_ZTV8Derived1", 16},
+ {"_ZTV8Derived2", 64}})
+ VTableAddressPoints.push_back(getVTableAddressPointOffset(
+ M->getGlobalVariable(VTableName), AddressPointOffset));
+
+ MDBuilder MDB(C);
+ MDNode *BranchWeights = MDB.createBranchWeights(1600, 0);
+
+ size_t OrigEntryBBSize = F->front().size();
+
+ LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin());
+
+ Function *Callee = M->getFunction("_ZN5Base15func1Ev");
+ // Tests that promoted direct call is returned.
+ CallBase &DirectCB = promoteCallWithVTableCmp(
+ *CI, VPtr, Callee, VTableAddressPoints, BranchWeights);
+ EXPECT_EQ(DirectCB.getCalledOperand(), Callee);
+
+ // GEP and FuncPtr remains in the original block. `promoteCallWithVTableCmp`
+ // doesn't sink them to the basic block of indirect fallback.
+ BasicBlock *EntryBB = &F->front();
+ EXPECT_EQ(EntryBB, GEP->getParent());
+ EXPECT_EQ(EntryBB, FuncPtr->getParent());
+
+ // Promotion inserts 3 icmp instructions and 2 or instructions, and removes
+ // 1 call instruction from the entry block.
+ EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
+}
|
Sorry for the spam! |
/// | ||
/// This function is expected to be used on virtual calls (a subset of indirect | ||
/// calls). \p VPtr is the virtual table address stored in the objects, and | ||
/// \p AddressPoints contains address points of vtables to be compared with. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clarify address points: an address point of a vtable is the starting address of function pointer entries in the table.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
@@ -185,12 +187,30 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { | |||
U->replaceUsesOfWith(&CB, Cast); | |||
} | |||
|
|||
// Returns the or result of all icmp instructions. | |||
static Value *getOrResult(const SmallVector<Value *, 2> &ICmps, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: change the name to getORResult or just getORs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
*CI, VPtr, Callee, VTableAddressPoints, BranchWeights); | ||
EXPECT_EQ(DirectCB.getCalledOperand(), Callee); | ||
|
||
// GEP and FuncPtr remains in the original block. `promoteCallWithVTableCmp` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this required behavior? If yes, add a comment about it how later pass to sink them to the fallback branch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this required behavior? If yes, add a comment about it how later pass to sink them to the fallback branch.
Not really. They exist to highlight these two instructions doesn't sink (yet), and partially explain the number of instructions in the entry basic block. I removed them since they doesn't matter much for the test.
"Out-of-bound access"); | ||
APInt AddressPointOffsetAPInt(32, AddressPointOffset, false); | ||
SmallVector<APInt> Indices = | ||
DL.getGEPIndicesForOffset(VTableType, AddressPointOffsetAPInt); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't do this. Emit the GEP with i8 element type instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nikic thanks for taking a look!
Before I make the change, do you mean getelementptr i8, ptr %p, i64 40
should be preferred over getelementptr [[i32 x 10], x 10], ptr %p, i64 0, i64 1, i64 0
as mentioned in https://groups.google.com/g/llvm-dev/c/U7D6z7ZnKy8?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, GEP i8 is canonical if the offset is constant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
With i8 GEP, there isn't much value to have a unit test for getVTableAddressPointOffset
. So make this a static function in CallPromotionUtils.cpp
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found CallPromotionUtils.cpp
doesn't use getVTableAddressPointOffset
either, so moved it to CallPromotionUtilsTest.cpp
if (i < NumICmp) | ||
OrResults.push_back(ICmps[i]); | ||
|
||
return getOrs(OrResults, Builder); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is emitting a tree reduction at this point important? If not, you can just use Builder.CreateOr(ICmps)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For vtable-based comparison to be efficient (compared with function-based comparison), the number of icmp
s needs to be gated. I tend to make the threshold 2(i.e., at most two icmp
s) after tuning this on various workload. Thereby will use Builder.CreateOr(ICmps)
as suggested.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. If later tree balancer does not do the expected for high threshold, it can be revisited.
@nikic any other comments? thanks! |
if (i < NumICmp) | ||
OrResults.push_back(ICmps[i]); | ||
|
||
return getOrs(OrResults, Builder); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. If later tree balancer does not do the expected for high threshold, it can be revisited.
Given the code sequence
The transformation looks like
It's intentional that
Inst 1
andInst2
remains inbb
(not inif.false.orig_indirect
). A follow up patch will implement code to sink them (something like howinstcombine
would sink instructions along with debug intrinsics if possible)The parent patch is [NFC][CallPromotionUtils]Extract a helper function versionCallSiteWithCond from versionCallSite #81181