Skip to content

Commit

Permalink
[CallPromotionUtils]Implement conditional indirect call promotion wit…
Browse files Browse the repository at this point in the history
…h vtable-based comparison (#81378)

* Given the code sequence
   ```
   bb:
     %vtable = load ptr, ptr %d, !prof !8
     %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
     %1 = load ptr, ptr %vfn
     %call = tail call i32 %1(ptr %d), !prof !9
  ```
   The transformation looks like

   ```
   bb:
    %vtable = load ptr, ptr %d, align 8
    %vfn = getelementptr inbounds i8, ptr %vtable, i64 8  <-- Inst 1
    %func-addr = load ptr, ptr %vfn, align 8  <-- Inst 2
    # compare loaded pointers with address point of vtables
    %1 = icmp eq ptr %vtable, getelementptr inbounds (i8, ptr @_ZTV<VTable>, i32 16)
    br i1 %1, label %if.true.direct_targ, label %if.false.orig_indirect, !prof !18

  if.true.direct_targ:                              ; preds = %bb
    %2 = tail call i32 @<direct-call>(ptr nonnull %d)
    br label %if.end.icp

  if.false.orig_indirect:                           ; preds = %bb
    %call = tail call i32 %func-addr(ptr nonnull %d)
    br label %if.end.icp

  if.end.icp:                                       ; preds = %if.false.orig_indirect, %if.true.direct_targ
    %4 = phi i32 [ %call, %if.false.orig_indirect ], [ %2, %if.true.direct_targ ]

   ```
It's intentional that `Inst 1` and `Inst 2` remain in `bb` (not in
`if.false.orig_indirect`). A follow-up patch will implement code to sink
them, similar to how `instcombine`
[sinks](https://github.com/llvm/llvm-project/blob/2fcfc9754a16805b81e541dc8222a8b5cf17a121/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp#L4293)
instructions along with their [debug
intrinsics](https://github.com/llvm/llvm-project/blob/2fcfc9754a16805b81e541dc8222a8b5cf17a121/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp#L4356-L4368)
where possible.

* The parent patch is #81181
  • Loading branch information
minglotus-6 committed May 19, 2024
1 parent 2f52bbe commit 5d3f296
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 10 deletions.
33 changes: 27 additions & 6 deletions llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H

namespace llvm {
template <typename T> class ArrayRef;
class Constant;
class CallBase;
class CastInst;
class Function;
class Instruction;
class MDNode;
class Value;

Expand All @@ -41,7 +44,9 @@ bool isLegalToPromote(const CallBase &CB, Function *Callee,
CallBase &promoteCall(CallBase &CB, Function *Callee,
CastInst **RetBitCast = nullptr);

/// Promote the given indirect call site to conditionally call \p Callee.
/// Promote the given indirect call site to conditionally call \p Callee. The
/// promoted direct call instruction is predicated on `CB.getCalledOperand() ==
/// Callee`.
///
/// This function creates an if-then-else structure at the location of the call
/// site. The original call site is moved into the "else" block. A clone of the
Expand All @@ -51,6 +56,22 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
MDNode *BranchWeights = nullptr);

/// This is similar to `promoteCallWithIfThenElse` except that the condition to
/// promote a virtual call is that \p VPtr is the same as any of \p
/// AddressPoints.
///
/// This function is expected to be used on virtual calls (a subset of indirect
/// calls). \p VPtr is the virtual table address stored in the objects, and
/// \p AddressPoints contains vtable address points. A vtable address point is
/// a location inside the vtable that's referenced by vpointer in C++ objects.
///
/// \p AddressPoints must not be empty; the implementation asserts this and
/// leaves it to the caller to guarantee. One equality comparison between
/// \p VPtr and each address point is emitted before \p CB, and the comparisons
/// are OR'ed together to form the condition that guards the direct call.
/// \p BranchWeights is passed through to the step that versions the call site
/// (presumably it is attached to the new conditional branch -- confirm against
/// the versioning helper).
///
/// \returns the promoted call instruction that directly calls \p Callee.
///
/// TODO: sink the address-calculation instructions of indirect callee to the
/// indirect call fallback after transformation.
CallBase &promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
Function *Callee,
ArrayRef<Constant *> AddressPoints,
MDNode *BranchWeights);

/// Try to promote (devirtualize) a virtual call on an Alloca. Return true on
/// success.
///
Expand All @@ -76,11 +97,11 @@ bool tryPromoteCall(CallBase &CB);

/// Predicate and clone the given call site.
///
/// This function creates an if-then-else structure at the location of the call
/// site. The "if" condition compares the call site's called value to the given
/// callee. The original call site is moved into the "else" block, and a clone
/// of the call site is placed in the "then" block. The cloned instruction is
/// returned.
/// This function creates an if-then-else structure at the location of the
/// call site. The "if" condition compares the call site's called value to
/// the given callee. The original call site is moved into the "else" block,
/// and a clone of the call site is placed in the "then" block. The cloned
/// instruction is returned.
CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights);

} // end namespace llvm
Expand Down
32 changes: 28 additions & 4 deletions llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
Expand Down Expand Up @@ -188,9 +190,9 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
/// Predicate and clone the given call site.
///
/// This function creates an if-then-else structure at the location of the call
/// site. The "if" condition is specified by `Cond`. The original call site is
/// moved into the "else" block, and a clone of the call site is placed in the
/// "then" block. The cloned instruction is returned.
/// site. The "if" condition is specified by `Cond`.
/// The original call site is moved into the "else" block, and a clone of the
/// call site is placed in the "then" block. The cloned instruction is returned.
///
/// For example, the call instruction below:
///
Expand Down Expand Up @@ -518,7 +520,8 @@ CallBase &llvm::promoteCall(CallBase &CB, Function *Callee,
Type *FormalTy = CalleeType->getParamType(ArgNo);
Type *ActualTy = Arg->getType();
if (FormalTy != ActualTy) {
auto *Cast = CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
auto *Cast =
CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
CB.setArgOperand(ArgNo, Cast);

// Remove any incompatible attributes for the argument.
Expand Down Expand Up @@ -568,6 +571,27 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
return promoteCall(NewInst, Callee);
}

CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
                                         Function *Callee,
                                         ArrayRef<Constant *> AddressPoints,
                                         MDNode *BranchWeights) {
  assert(!AddressPoints.empty() && "Caller should guarantee");

  // Emit one equality test per candidate address point, inserted right before
  // the indirect call site.
  IRBuilder<> Builder(&CB);
  SmallVector<Value *, 2> VTableMatches;
  for (Constant *Candidate : AddressPoints)
    VTableMatches.push_back(Builder.CreateICmpEQ(VPtr, Candidate));

  // The direct call is taken when the vptr matches any of the candidates.
  // TODO: Perform tree height reduction if the number of ICmps is high.
  Value *AnyMatch = Builder.CreateOr(VTableMatches);

  // Split the call site on 'AnyMatch': the returned clone runs when the
  // condition holds, and the original indirect call runs otherwise.
  CallBase &DirectCallSite =
      versionCallSiteWithCond(CB, AnyMatch, BranchWeights);

  // Rewrite the clone so that it calls 'Callee' directly.
  return promoteCall(DirectCallSite, Callee);
}

bool llvm::tryPromoteCall(CallBase &CB) {
assert(!CB.getCalledFunction());
Module *M = CB.getCaller()->getParent();
Expand Down
88 changes: 88 additions & 0 deletions llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@

#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"

Expand All @@ -24,6 +27,21 @@ static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
return Mod;
}

// Builds the constant expression for the address point that lies
// `AddressPointOffset` bytes into `VTable`.
static Constant *getVTableAddressPointOffset(GlobalVariable *VTable,
                                             uint32_t AddressPointOffset) {
  Module &ParentModule = *VTable->getParent();
  assert(AddressPointOffset < ParentModule.getDataLayout().getTypeAllocSize(
                                  VTable->getValueType()) &&
         "Out-of-bound access");

  LLVMContext &Ctx = ParentModule.getContext();
  // Index the vtable through an i8 element type so the offset is in bytes.
  Type *ByteTy = Type::getInt8Ty(Ctx);
  Constant *ByteOffset =
      llvm::ConstantInt::get(Type::getInt32Ty(Ctx), AddressPointOffset);
  return ConstantExpr::getInBoundsGetElementPtr(ByteTy, VTable, ByteOffset);
}

TEST(CallPromotionUtilsTest, TryPromoteCall) {
LLVMContext C;
std::unique_ptr<Module> M = parseIR(C,
Expand Down Expand Up @@ -368,3 +386,73 @@ declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
bool IsPromoted = tryPromoteCall(*CI);
EXPECT_FALSE(IsPromoted);
}

// Checks that promoteCallWithVTableCmp guards the promoted direct call with
// comparisons against the given vtable address points and returns the direct
// call. Every IR entity looked up from the fixture is null-checked first so a
// fixture mismatch fails the test instead of crashing the test binary.
TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) {
  LLVMContext C;
  std::unique_ptr<Module> M = parseIR(C,
                                      R"IR(
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
@_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0
@_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2
@_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6
define i32 @testfunc(ptr %d) {
entry:
%vtable = load ptr, ptr %d, !prof !7
%vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
%0 = load ptr, ptr %vfn
%call = tail call i32 %0(ptr %d), !prof !8
ret i32 %call
}
define i32 @_ZN5Base15func1Ev(ptr %this) {
entry:
ret i32 2
}
declare i32 @_ZN5Base25func2Ev(ptr)
declare i32 @_ZN5Base15func0Ev(ptr)
declare void @_ZN5Base35func3Ev(ptr)
!0 = !{i64 16, !"_ZTS5Base1"}
!1 = !{i64 48, !"_ZTS5Base2"}
!2 = !{i64 16, !"_ZTS8Derived1"}
!3 = !{i64 64, !"_ZTS5Base1"}
!4 = !{i64 40, !"_ZTS5Base2"}
!5 = !{i64 16, !"_ZTS5Base3"}
!6 = !{i64 16, !"_ZTS8Derived2"}
!7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300}
!8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR");
  ASSERT_TRUE(M);

  Function *F = M->getFunction("testfunc");
  ASSERT_NE(F, nullptr);
  // The indirect call is the second-to-last instruction of the entry block
  // (it is immediately followed by the `ret`).
  CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin()));
  ASSERT_TRUE(CI && CI->isIndirectCall());

  // Create the constant and the branch weights
  SmallVector<Constant *, 3> VTableAddressPoints;

  for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16},
                                                 {"_ZTV8Derived1", 16},
                                                 {"_ZTV8Derived2", 64}}) {
    GlobalVariable *VTable = M->getGlobalVariable(VTableName);
    ASSERT_NE(VTable, nullptr) << "missing vtable " << VTableName;
    VTableAddressPoints.push_back(
        getVTableAddressPointOffset(VTable, AddressPointOffset));
  }

  MDBuilder MDB(C);
  MDNode *BranchWeights = MDB.createBranchWeights(1600, 0);

  size_t OrigEntryBBSize = F->front().size();

  // The vtable load is the first instruction of the entry block.
  LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin());
  ASSERT_NE(VPtr, nullptr);

  Function *Callee = M->getFunction("_ZN5Base15func1Ev");
  ASSERT_NE(Callee, nullptr);

  // Tests that promoted direct call is returned.
  CallBase &DirectCB = promoteCallWithVTableCmp(
      *CI, VPtr, Callee, VTableAddressPoints, BranchWeights);
  EXPECT_EQ(DirectCB.getCalledOperand(), Callee);

  // Promotion inserts 3 icmp instructions and 2 or instructions into the entry
  // block and moves the original call out of it, so the entry block grows by
  // a net of 4 instructions.
  EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
}

0 comments on commit 5d3f296

Please sign in to comment.