Skip to content

Commit

Permalink
[CodeExtractor] Update function's assumption cache after extracting b…
Browse files Browse the repository at this point in the history
…locks from it

Summary: Assumption cache's self-updating mechanism does not correctly handle the case when blocks are extracted from the function by the CodeExtractor. As a result function's assumption cache may have stale references to the llvm.assume calls that were moved to the outlined function. This patch fixes this problem by removing extracted llvm.assume calls from the function’s assumption cache.

Reviewers: hfinkel, vsk, fhahn, davidxl, sanjoy

Reviewed By: hfinkel, vsk

Subscribers: llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D57215

llvm-svn: 353500
  • Loading branch information
sndmitriev committed Feb 8, 2019
1 parent df6770f commit 807960e
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 33 deletions.
8 changes: 8 additions & 0 deletions llvm/include/llvm/Analysis/AssumptionCache.h
Expand Up @@ -103,6 +103,10 @@ class AssumptionCache {
/// not already be in the cache.
void registerAssumption(CallInst *CI);

/// Remove an \@llvm.assume intrinsic from this function's cache if it has
/// been added to the cache earlier.
void unregisterAssumption(CallInst *CI);

/// Update the cache of values being affected by this assumption (i.e.
/// the values about which this assumption provides information).
void updateAffectedValues(CallInst *CI);
Expand Down Expand Up @@ -208,6 +212,10 @@ class AssumptionCacheTracker : public ImmutablePass {
/// existing cache will be returned.
AssumptionCache &getAssumptionCache(Function &F);

/// Return the cached assumptions for a function if it has already been
/// scanned. Otherwise return nullptr.
AssumptionCache *lookupAssumptionCache(Function &F);

AssumptionCacheTracker();
~AssumptionCacheTracker() override;

Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Transforms/Utils/CodeExtractor.h
Expand Up @@ -26,6 +26,7 @@ class BasicBlock;
class BlockFrequency;
class BlockFrequencyInfo;
class BranchProbabilityInfo;
class AssumptionCache;
class CallInst;
class DominatorTree;
class Function;
Expand Down Expand Up @@ -56,6 +57,7 @@ class Value;
const bool AggregateArgs;
BlockFrequencyInfo *BFI;
BranchProbabilityInfo *BPI;
AssumptionCache *AC;

// If true, varargs functions can be extracted.
bool AllowVarArgs;
Expand Down Expand Up @@ -84,6 +86,7 @@ class Value;
CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
BranchProbabilityInfo *BPI = nullptr,
AssumptionCache *AC = nullptr,
bool AllowVarArgs = false, bool AllowAlloca = false,
std::string Suffix = "");

Expand All @@ -94,6 +97,7 @@ class Value;
CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false,
BlockFrequencyInfo *BFI = nullptr,
BranchProbabilityInfo *BPI = nullptr,
AssumptionCache *AC = nullptr,
std::string Suffix = "");

/// Perform the extraction, returning the new function.
Expand Down
28 changes: 26 additions & 2 deletions llvm/lib/Analysis/AssumptionCache.cpp
Expand Up @@ -53,11 +53,11 @@ AssumptionCache::getOrInsertAffectedValues(Value *V) {
return AVIP.first->second;
}

void AssumptionCache::updateAffectedValues(CallInst *CI) {
static void findAffectedValues(CallInst *CI,
SmallVectorImpl<Value *> &Affected) {
// Note: This code must be kept in-sync with the code in
// computeKnownBitsFromAssume in ValueTracking.

SmallVector<Value *, 16> Affected;
auto AddAffected = [&Affected](Value *V) {
if (isa<Argument>(V)) {
Affected.push_back(V);
Expand Down Expand Up @@ -108,6 +108,11 @@ void AssumptionCache::updateAffectedValues(CallInst *CI) {
AddAffectedFromEq(B);
}
}
}

void AssumptionCache::updateAffectedValues(CallInst *CI) {
SmallVector<Value *, 16> Affected;
findAffectedValues(CI, Affected);

for (auto &AV : Affected) {
auto &AVV = getOrInsertAffectedValues(AV);
Expand All @@ -116,6 +121,18 @@ void AssumptionCache::updateAffectedValues(CallInst *CI) {
}
}

void AssumptionCache::unregisterAssumption(CallInst *CI) {
SmallVector<Value *, 16> Affected;
findAffectedValues(CI, Affected);

for (auto &AV : Affected) {
auto AVI = AffectedValues.find_as(AV);
if (AVI != AffectedValues.end())
AffectedValues.erase(AVI);
}
remove_if(AssumeHandles, [CI](WeakTrackingVH &VH) { return CI == VH; });
}

void AssumptionCache::AffectedValueCallbackVH::deleted() {
auto AVI = AC->AffectedValues.find(getValPtr());
if (AVI != AC->AffectedValues.end())
Expand Down Expand Up @@ -240,6 +257,13 @@ AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) {
return *IP.first->second;
}

AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) {
auto I = AssumptionCaches.find_as(&F);
if (I != AssumptionCaches.end())
return I->second.get();
return nullptr;
}

void AssumptionCacheTracker::verifyAnalysis() const {
// FIXME: In the long term the verifier should not be controllable with a
// flag. We should either fix all passes to correctly update the assumption
Expand Down
33 changes: 21 additions & 12 deletions llvm/lib/Transforms/IPO/HotColdSplitting.cpp
Expand Up @@ -173,8 +173,9 @@ class HotColdSplitting {
HotColdSplitting(ProfileSummaryInfo *ProfSI,
function_ref<BlockFrequencyInfo *(Function &)> GBFI,
function_ref<TargetTransformInfo &(Function &)> GTTI,
std::function<OptimizationRemarkEmitter &(Function &)> *GORE)
: PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE) {}
std::function<OptimizationRemarkEmitter &(Function &)> *GORE,
function_ref<AssumptionCache *(Function &)> LAC)
: PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC) {}
bool run(Module &M);

private:
Expand All @@ -183,11 +184,13 @@ class HotColdSplitting {
bool outlineColdRegions(Function &F, bool HasProfileSummary);
Function *extractColdRegion(const BlockSequence &Region, DominatorTree &DT,
BlockFrequencyInfo *BFI, TargetTransformInfo &TTI,
OptimizationRemarkEmitter &ORE, unsigned Count);
OptimizationRemarkEmitter &ORE,
AssumptionCache *AC, unsigned Count);
ProfileSummaryInfo *PSI;
function_ref<BlockFrequencyInfo *(Function &)> GetBFI;
function_ref<TargetTransformInfo &(Function &)> GetTTI;
std::function<OptimizationRemarkEmitter &(Function &)> *GetORE;
function_ref<AssumptionCache *(Function &)> LookupAC;
};

class HotColdSplittingLegacyPass : public ModulePass {
Expand All @@ -198,10 +201,10 @@ class HotColdSplittingLegacyPass : public ModulePass {
}

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<BlockFrequencyInfoWrapperPass>();
AU.addRequired<ProfileSummaryInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addUsedIfAvailable<AssumptionCacheTracker>();
}

bool runOnModule(Module &M) override;
Expand Down Expand Up @@ -316,12 +319,13 @@ Function *HotColdSplitting::extractColdRegion(const BlockSequence &Region,
BlockFrequencyInfo *BFI,
TargetTransformInfo &TTI,
OptimizationRemarkEmitter &ORE,
AssumptionCache *AC,
unsigned Count) {
assert(!Region.empty());

// TODO: Pass BFI and BPI to update profile information.
CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr,
/* BPI */ nullptr, /* AllowVarArgs */ false,
/* BPI */ nullptr, AC, /* AllowVarArgs */ false,
/* AllowAlloca */ false,
/* Suffix */ "cold." + std::to_string(Count));

Expand Down Expand Up @@ -577,6 +581,7 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {

TargetTransformInfo &TTI = GetTTI(F);
OptimizationRemarkEmitter &ORE = (*GetORE)(F);
AssumptionCache *AC = LookupAC(F);

// Find all cold regions.
for (BasicBlock *BB : RPOT) {
Expand Down Expand Up @@ -638,8 +643,8 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) {
BB->dump();
});

Function *Outlined =
extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, OutlinedFunctionID);
Function *Outlined = extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, AC,
OutlinedFunctionID);
if (Outlined) {
++OutlinedFunctionID;
Changed = true;
Expand Down Expand Up @@ -698,17 +703,21 @@ bool HotColdSplittingLegacyPass::runOnModule(Module &M) {
ORE.reset(new OptimizationRemarkEmitter(&F));
return *ORE.get();
};
auto LookupAC = [this](Function &F) -> AssumptionCache * {
if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>())
return ACT->lookupAssumptionCache(F);
return nullptr;
};

return HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M);
return HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M);
}

PreservedAnalyses
HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) {
auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();

std::function<AssumptionCache &(Function &)> GetAssumptionCache =
[&FAM](Function &F) -> AssumptionCache & {
return FAM.getResult<AssumptionAnalysis>(F);
auto LookupAC = [&FAM](Function &F) -> AssumptionCache * {
return FAM.getCachedResult<AssumptionAnalysis>(F);
};

auto GBFI = [&FAM](Function &F) {
Expand All @@ -729,7 +738,7 @@ HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) {

ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);

if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M))
if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M))
return PreservedAnalyses::none();
return PreservedAnalyses::all();
}
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/Transforms/IPO/LoopExtractor.cpp
Expand Up @@ -14,6 +14,7 @@
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
Expand Down Expand Up @@ -50,6 +51,7 @@ namespace {
AU.addRequiredID(LoopSimplifyID);
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addUsedIfAvailable<AssumptionCacheTracker>();
}
};
}
Expand Down Expand Up @@ -138,7 +140,10 @@ bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) {
if (ShouldExtractLoop) {
if (NumLoops == 0) return Changed;
--NumLoops;
CodeExtractor Extractor(DT, *L);
AssumptionCache *AC = nullptr;
if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>())
AC = ACT->lookupAssumptionCache(*L->getHeader()->getParent());
CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
if (Extractor.extractCodeRegion() != nullptr) {
Changed = true;
// After extraction, the loop is replaced by a function call, so
Expand Down
46 changes: 33 additions & 13 deletions llvm/lib/Transforms/IPO/PartialInlining.cpp
Expand Up @@ -199,10 +199,12 @@ struct PartialInlinerImpl {

PartialInlinerImpl(
std::function<AssumptionCache &(Function &)> *GetAC,
function_ref<AssumptionCache *(Function &)> LookupAC,
std::function<TargetTransformInfo &(Function &)> *GTTI,
Optional<function_ref<BlockFrequencyInfo &(Function &)>> GBFI,
ProfileSummaryInfo *ProfSI)
: GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {}
: GetAssumptionCache(GetAC), LookupAssumptionCache(LookupAC),
GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {}

bool run(Module &M);
// Main part of the transformation that calls helper functions to find
Expand All @@ -222,9 +224,11 @@ struct PartialInlinerImpl {
// Two constructors, one for single region outlining, the other for
// multi-region outlining.
FunctionCloner(Function *F, FunctionOutliningInfo *OI,
OptimizationRemarkEmitter &ORE);
OptimizationRemarkEmitter &ORE,
function_ref<AssumptionCache *(Function &)> LookupAC);
FunctionCloner(Function *F, FunctionOutliningMultiRegionInfo *OMRI,
OptimizationRemarkEmitter &ORE);
OptimizationRemarkEmitter &ORE,
function_ref<AssumptionCache *(Function &)> LookupAC);
~FunctionCloner();

// Prepare for function outlining: making sure there is only
Expand Down Expand Up @@ -260,11 +264,13 @@ struct PartialInlinerImpl {
std::unique_ptr<FunctionOutliningMultiRegionInfo> ClonedOMRI = nullptr;
std::unique_ptr<BlockFrequencyInfo> ClonedFuncBFI = nullptr;
OptimizationRemarkEmitter &ORE;
function_ref<AssumptionCache *(Function &)> LookupAC;
};

private:
int NumPartialInlining = 0;
std::function<AssumptionCache &(Function &)> *GetAssumptionCache;
function_ref<AssumptionCache *(Function &)> LookupAssumptionCache;
std::function<TargetTransformInfo &(Function &)> *GetTTI;
Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI;
ProfileSummaryInfo *PSI;
Expand Down Expand Up @@ -365,12 +371,17 @@ struct PartialInlinerLegacyPass : public ModulePass {
return ACT->getAssumptionCache(F);
};

auto LookupAssumptionCache = [ACT](Function &F) -> AssumptionCache * {
return ACT->lookupAssumptionCache(F);
};

std::function<TargetTransformInfo &(Function &)> GetTTI =
[&TTIWP](Function &F) -> TargetTransformInfo & {
return TTIWP->getTTI(F);
};

return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, NoneType::None, PSI)
return PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache,
&GetTTI, NoneType::None, PSI)
.run(M);
}
};
Expand Down Expand Up @@ -948,8 +959,9 @@ void PartialInlinerImpl::computeCallsiteToProfCountMap(
}

PartialInlinerImpl::FunctionCloner::FunctionCloner(
Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE)
: OrigFunc(F), ORE(ORE) {
Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE,
function_ref<AssumptionCache *(Function &)> LookupAC)
: OrigFunc(F), ORE(ORE), LookupAC(LookupAC) {
ClonedOI = llvm::make_unique<FunctionOutliningInfo>();

// Clone the function, so that we can hack away on it.
Expand All @@ -972,8 +984,9 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner(

PartialInlinerImpl::FunctionCloner::FunctionCloner(
Function *F, FunctionOutliningMultiRegionInfo *OI,
OptimizationRemarkEmitter &ORE)
: OrigFunc(F), ORE(ORE) {
OptimizationRemarkEmitter &ORE,
function_ref<AssumptionCache *(Function &)> LookupAC)
: OrigFunc(F), ORE(ORE), LookupAC(LookupAC) {
ClonedOMRI = llvm::make_unique<FunctionOutliningMultiRegionInfo>();

// Clone the function, so that we can hack away on it.
Expand Down Expand Up @@ -1111,7 +1124,9 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() {
int CurrentOutlinedRegionCost = ComputeRegionCost(RegionInfo.Region);

CodeExtractor CE(RegionInfo.Region, &DT, /*AggregateArgs*/ false,
ClonedFuncBFI.get(), &BPI, /* AllowVarargs */ false);
ClonedFuncBFI.get(), &BPI,
LookupAC(*RegionInfo.EntryBlock->getParent()),
/* AllowVarargs */ false);

CE.findInputsOutputs(Inputs, Outputs, Sinks);

Expand Down Expand Up @@ -1193,7 +1208,7 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() {
// Extract the body of the if.
Function *OutlinedFunc =
CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false,
ClonedFuncBFI.get(), &BPI,
ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc),
/* AllowVarargs */ true)
.extractCodeRegion();

Expand Down Expand Up @@ -1257,7 +1272,7 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) {
std::unique_ptr<FunctionOutliningMultiRegionInfo> OMRI =
computeOutliningColdRegionsInfo(F, ORE);
if (OMRI) {
FunctionCloner Cloner(F, OMRI.get(), ORE);
FunctionCloner Cloner(F, OMRI.get(), ORE, LookupAssumptionCache);

#ifndef NDEBUG
if (TracePartialInlining) {
Expand Down Expand Up @@ -1290,7 +1305,7 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) {
if (!OI)
return {false, nullptr};

FunctionCloner Cloner(F, OI.get(), ORE);
FunctionCloner Cloner(F, OI.get(), ORE, LookupAssumptionCache);
Cloner.NormalizeReturnBlock();

Function *OutlinedFunction = Cloner.doSingleRegionFunctionOutlining();
Expand Down Expand Up @@ -1484,6 +1499,10 @@ PreservedAnalyses PartialInlinerPass::run(Module &M,
return FAM.getResult<AssumptionAnalysis>(F);
};

auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * {
return FAM.getCachedResult<AssumptionAnalysis>(F);
};

std::function<BlockFrequencyInfo &(Function &)> GetBFI =
[&FAM](Function &F) -> BlockFrequencyInfo & {
return FAM.getResult<BlockFrequencyAnalysis>(F);
Expand All @@ -1496,7 +1515,8 @@ PreservedAnalyses PartialInlinerPass::run(Module &M,

ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);

if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI)
if (PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, &GetTTI,
{GetBFI}, PSI)
.run(M))
return PreservedAnalyses::none();
return PreservedAnalyses::all();
Expand Down

0 comments on commit 807960e

Please sign in to comment.