Skip to content

Commit

Permalink
[WPD] Fix incorrect devirtualization after indirect call promotion
Browse files Browse the repository at this point in the history
Summary:
Add a dominance check to ensure that the possibly devirtualizable
call is actually dominated by the type test/checked load intrinsic being
analyzed. With PGO, after indirect call promotion is performed during
the compile step, followed by inlining, we may have a type test in the
promoted and inlined sequence that allows an indirect call in that
sequence to be devirtualized. That indirect call (inserted by inlining
after promotion) will share the same vtable pointer as the fallback
indirect call that cannot be devirtualized.

Before this patch the code was incorrectly devirtualizing the fallback
indirect call.

See the new test and the example described there for more details.

Reviewers: pcc, vitalybuka

Subscribers: mehdi_amini, Prazek, eraman, steven_wu, dexonsmith, llvm-commits

Differential Revision: https://reviews.llvm.org/D52514

llvm-svn: 343226
  • Loading branch information
teresajohnson committed Sep 27, 2018
1 parent a9a5eee commit f24136f
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 48 deletions.
7 changes: 5 additions & 2 deletions llvm/include/llvm/Analysis/TypeMetadataUtils.h
Expand Up @@ -20,6 +20,8 @@

namespace llvm {

class DominatorTree;

/// The type of CFI jumptable needed for a function.
enum CfiFunctionLinkage {
CFL_Definition = 0,
Expand All @@ -39,15 +41,16 @@ struct DevirtCallSite {
/// call sites based on the call and return them in DevirtCalls.
void findDevirtualizableCallsForTypeTest(
SmallVectorImpl<DevirtCallSite> &DevirtCalls,
SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI);
SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI,
DominatorTree &DT);

/// Given a call to the intrinsic \@llvm.type.checked.load, find all
/// devirtualizable call sites based on the call and return them in DevirtCalls.
void findDevirtualizableCallsForTypeCheckedLoad(
SmallVectorImpl<DevirtCallSite> &DevirtCalls,
SmallVectorImpl<Instruction *> &LoadedPtrs,
SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses,
const CallInst *CI);
const CallInst *CI, DominatorTree &DT);
}

#endif
23 changes: 12 additions & 11 deletions llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
Expand Up @@ -147,7 +147,8 @@ static void addIntrinsicToSummary(
SetVector<FunctionSummary::VFuncId> &TypeTestAssumeVCalls,
SetVector<FunctionSummary::VFuncId> &TypeCheckedLoadVCalls,
SetVector<FunctionSummary::ConstVCall> &TypeTestAssumeConstVCalls,
SetVector<FunctionSummary::ConstVCall> &TypeCheckedLoadConstVCalls) {
SetVector<FunctionSummary::ConstVCall> &TypeCheckedLoadConstVCalls,
DominatorTree &DT) {
switch (CI->getCalledFunction()->getIntrinsicID()) {
case Intrinsic::type_test: {
auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1));
Expand All @@ -172,7 +173,7 @@ static void addIntrinsicToSummary(

SmallVector<DevirtCallSite, 4> DevirtCalls;
SmallVector<CallInst *, 4> Assumes;
findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);
findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);
for (auto &Call : DevirtCalls)
addVCallToSet(Call, Guid, TypeTestAssumeVCalls,
TypeTestAssumeConstVCalls);
Expand All @@ -192,7 +193,7 @@ static void addIntrinsicToSummary(
SmallVector<Instruction *, 4> Preds;
bool HasNonCallUses = false;
findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
HasNonCallUses, CI);
HasNonCallUses, CI, DT);
// Any non-call uses of the result of llvm.type.checked.load will
// prevent us from optimizing away the llvm.type.test.
if (HasNonCallUses)
Expand All @@ -208,11 +209,10 @@ static void addIntrinsicToSummary(
}
}

static void
computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M,
const Function &F, BlockFrequencyInfo *BFI,
ProfileSummaryInfo *PSI, bool HasLocalsInUsedOrAsm,
DenseSet<GlobalValue::GUID> &CantBePromoted) {
static void computeFunctionSummary(
ModuleSummaryIndex &Index, const Module &M, const Function &F,
BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, DominatorTree &DT,
bool HasLocalsInUsedOrAsm, DenseSet<GlobalValue::GUID> &CantBePromoted) {
// Summary not currently supported for anonymous functions, they should
// have been named.
assert(F.hasName());
Expand Down Expand Up @@ -273,7 +273,7 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M,
if (CI && CalledFunction->isIntrinsic()) {
addIntrinsicToSummary(
CI, TypeTests, TypeTestAssumeVCalls, TypeCheckedLoadVCalls,
TypeTestAssumeConstVCalls, TypeCheckedLoadConstVCalls);
TypeTestAssumeConstVCalls, TypeCheckedLoadConstVCalls, DT);
continue;
}
// We should have named any anonymous globals
Expand Down Expand Up @@ -488,18 +488,19 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex(
if (F.isDeclaration())
continue;

DominatorTree DT(const_cast<Function &>(F));
BlockFrequencyInfo *BFI = nullptr;
std::unique_ptr<BlockFrequencyInfo> BFIPtr;
if (GetBFICallback)
BFI = GetBFICallback(F);
else if (F.hasProfileData()) {
LoopInfo LI{DominatorTree(const_cast<Function &>(F))};
LoopInfo LI{DT};
BranchProbabilityInfo BPI{F, LI};
BFIPtr = llvm::make_unique<BlockFrequencyInfo>(F, BPI, LI);
BFI = BFIPtr.get();
}

computeFunctionSummary(Index, M, F, BFI, PSI,
computeFunctionSummary(Index, M, F, BFI, PSI, DT,
!LocalsUsed.empty() || HasLocalInlineAsmSymbol,
CantBePromoted);
}
Expand Down
42 changes: 27 additions & 15 deletions llvm/lib/Analysis/TypeMetadataUtils.cpp
Expand Up @@ -14,6 +14,7 @@

#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"

Expand All @@ -22,11 +23,21 @@ using namespace llvm;
// Search for virtual calls that call FPtr and add them to DevirtCalls.
static void
findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
bool *HasNonCallUses, Value *FPtr, uint64_t Offset) {
bool *HasNonCallUses, Value *FPtr, uint64_t Offset,
const CallInst *CI, DominatorTree &DT) {
for (const Use &U : FPtr->uses()) {
Value *User = U.getUser();
Instruction *User = cast<Instruction>(U.getUser());
// Ignore this instruction if it is not dominated by the type intrinsic
// being analyzed. Otherwise we may transform a call sharing the same
// vtable pointer incorrectly. Specifically, this situation can arise
// after indirect call promotion and inlining, where we may have uses
// of the vtable pointer guarded by a function pointer check, and a fallback
// indirect call.
if (!DT.dominates(CI, User))
continue;
if (isa<BitCastInst>(User)) {
findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset);
findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI,
DT);
} else if (auto CI = dyn_cast<CallInst>(User)) {
DevirtCalls.push_back({Offset, CI});
} else if (auto II = dyn_cast<InvokeInst>(User)) {
Expand All @@ -38,31 +49,32 @@ findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
}

// Search for virtual calls that load from VPtr and add them to DevirtCalls.
static void
findLoadCallsAtConstantOffset(const Module *M,
SmallVectorImpl<DevirtCallSite> &DevirtCalls,
Value *VPtr, int64_t Offset) {
// Search for virtual calls that load from VPtr and add them to DevirtCalls.
//
// Walks every use of the vtable pointer VPtr, following bitcasts and
// constant-offset GEPs, and hands any loaded function pointer to
// findCallsAtConstantOffset. CI (the type intrinsic being analyzed) and DT
// are threaded through so the leaf search can reject uses not dominated by
// the intrinsic.
static void findLoadCallsAtConstantOffset(
    const Module *M, SmallVectorImpl<DevirtCallSite> &DevirtCalls, Value *VPtr,
    int64_t Offset, const CallInst *CI, DominatorTree &DT) {
  for (const Use &U : VPtr->uses()) {
    Value *User = U.getUser();
    if (isa<BitCastInst>(User)) {
      // A bitcast of the vtable pointer is still the vtable pointer; recurse
      // with the same offset.
      findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset, CI, DT);
    } else if (isa<LoadInst>(User)) {
      // The loaded value is a function pointer; search its uses for calls.
      findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset, CI, DT);
    } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
      // Take into account the GEP offset.
      if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
        SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
        int64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType(
            GEP->getSourceElementType(), Indices);
        findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset,
                                      CI, DT);
      }
    }
  }
}

void llvm::findDevirtualizableCallsForTypeTest(
SmallVectorImpl<DevirtCallSite> &DevirtCalls,
SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI) {
SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI,
DominatorTree &DT) {
assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test);

const Module *M = CI->getParent()->getParent()->getParent();
Expand All @@ -79,15 +91,15 @@ void llvm::findDevirtualizableCallsForTypeTest(
// If we found any, search for virtual calls based on %p and add them to
// DevirtCalls.
if (!Assumes.empty())
findLoadCallsAtConstantOffset(M, DevirtCalls,
CI->getArgOperand(0)->stripPointerCasts(), 0);
findLoadCallsAtConstantOffset(
M, DevirtCalls, CI->getArgOperand(0)->stripPointerCasts(), 0, CI, DT);
}

void llvm::findDevirtualizableCallsForTypeCheckedLoad(
SmallVectorImpl<DevirtCallSite> &DevirtCalls,
SmallVectorImpl<Instruction *> &LoadedPtrs,
SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses,
const CallInst *CI) {
const CallInst *CI, DominatorTree &DT) {
assert(CI->getCalledFunction()->getIntrinsicID() ==
Intrinsic::type_checked_load);

Expand All @@ -114,5 +126,5 @@ void llvm::findDevirtualizableCallsForTypeCheckedLoad(

for (Value *LoadedPtr : LoadedPtrs)
findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr,
Offset->getZExtValue());
Offset->getZExtValue(), CI, DT);
}
61 changes: 41 additions & 20 deletions llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
Expand Up @@ -58,6 +58,7 @@
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalVariable.h"
Expand Down Expand Up @@ -406,6 +407,7 @@ void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
struct DevirtModule {
Module &M;
function_ref<AAResults &(Function &)> AARGetter;
function_ref<DominatorTree &(Function &)> LookupDomTree;

ModuleSummaryIndex *ExportSummary;
const ModuleSummaryIndex *ImportSummary;
Expand Down Expand Up @@ -433,10 +435,12 @@ struct DevirtModule {

DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
function_ref<DominatorTree &(Function &)> LookupDomTree,
ModuleSummaryIndex *ExportSummary,
const ModuleSummaryIndex *ImportSummary)
: M(M), AARGetter(AARGetter), ExportSummary(ExportSummary),
ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())),
: M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
ExportSummary(ExportSummary), ImportSummary(ImportSummary),
Int8Ty(Type::getInt8Ty(M.getContext())),
Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
Int32Ty(Type::getInt32Ty(M.getContext())),
Int64Ty(Type::getInt64Ty(M.getContext())),
Expand Down Expand Up @@ -533,9 +537,10 @@ struct DevirtModule {

// Lower the module using the action and summary passed as command line
// arguments. For testing purposes only.
static bool runForTesting(
Module &M, function_ref<AAResults &(Function &)> AARGetter,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter);
static bool
runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
function_ref<DominatorTree &(Function &)> LookupDomTree);
};

struct WholeProgramDevirt : public ModulePass {
Expand Down Expand Up @@ -572,17 +577,23 @@ struct WholeProgramDevirt : public ModulePass {
return *ORE;
};

auto LookupDomTree = [this](Function &F) -> DominatorTree & {
return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
};

if (UseCommandLine)
return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter);
return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter,
LookupDomTree);

return DevirtModule(M, LegacyAARGetter(*this), OREGetter, ExportSummary,
ImportSummary)
return DevirtModule(M, LegacyAARGetter(*this), OREGetter, LookupDomTree,
ExportSummary, ImportSummary)
.run();
}

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.addRequired<DominatorTreeWrapperPass>();
}
};

Expand All @@ -592,6 +603,7 @@ INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
"Whole program devirtualization", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
"Whole program devirtualization", false, false)
char WholeProgramDevirt::ID = 0;
Expand All @@ -611,15 +623,20 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
};
if (!DevirtModule(M, AARGetter, OREGetter, ExportSummary, ImportSummary)
auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
return FAM.getResult<DominatorTreeAnalysis>(F);
};
if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary,
ImportSummary)
.run())
return PreservedAnalyses::all();
return PreservedAnalyses::none();
}

bool DevirtModule::runForTesting(
Module &M, function_ref<AAResults &(Function &)> AARGetter,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
function_ref<DominatorTree &(Function &)> LookupDomTree) {
ModuleSummaryIndex Summary(/*HaveGVs=*/false);

// Handle the command-line summary arguments. This code is for testing
Expand All @@ -637,7 +654,7 @@ bool DevirtModule::runForTesting(

bool Changed =
DevirtModule(
M, AARGetter, OREGetter,
M, AARGetter, OREGetter, LookupDomTree,
ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
.run();
Expand Down Expand Up @@ -1342,7 +1359,7 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
// points to a member of the type identifier %md. Group calls by (type ID,
// offset) pair (effectively the identity of the virtual function) and store
// to CallSlots.
DenseSet<Value *> SeenPtrs;
DenseSet<CallSite> SeenCallSites;
for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
I != E;) {
auto CI = dyn_cast<CallInst>(I->getUser());
Expand All @@ -1353,19 +1370,22 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
// Search for virtual calls based on %p and add them to DevirtCalls.
SmallVector<DevirtCallSite, 1> DevirtCalls;
SmallVector<CallInst *, 1> Assumes;
findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);
auto &DT = LookupDomTree(*CI->getFunction());
findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);

// If we found any, add them to CallSlots. Only do this if we haven't seen
// the vtable pointer before, as it may have been CSE'd with pointers from
// other call sites, and we don't want to process call sites multiple times.
// If we found any, add them to CallSlots.
if (!Assumes.empty()) {
Metadata *TypeId =
cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
if (SeenPtrs.insert(Ptr).second) {
for (DevirtCallSite Call : DevirtCalls) {
for (DevirtCallSite Call : DevirtCalls) {
// Only add this CallSite if we haven't seen it before. The vtable
// pointer may have been CSE'd with pointers from other call sites,
// and we don't want to process call sites multiple times. We can't
// just skip the vtable Ptr if it has been seen before, however, since
// it may be shared by type tests that dominate different calls.
if (SeenCallSites.insert(Call.CS).second)
CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr);
}
}
}

Expand Down Expand Up @@ -1399,8 +1419,9 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
SmallVector<Instruction *, 1> LoadedPtrs;
SmallVector<Instruction *, 1> Preds;
bool HasNonCallUses = false;
auto &DT = LookupDomTree(*CI->getFunction());
findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
HasNonCallUses, CI);
HasNonCallUses, CI, DT);

// Start by generating "pessimistic" code that explicitly loads the function
// pointer from the vtable and performs the type check. If possible, we will
Expand Down

0 comments on commit f24136f

Please sign in to comment.