Skip to content

Commit

Permalink
[AA] Pass AAResults through AAQueryInfo
Browse files Browse the repository at this point in the history
Currently, AAResultBase (from which alias analysis providers inherit)
stores a reference back to the AAResults aggregation it is part of,
so it can perform recursive alias analysis queries via
getBestAAResults().

This patch removes the back-reference from AAResultBase to AAResults,
and instead passes the used aggregation through the AAQueryInfo.
This can be used to perform recursive AA queries using the full
aggregation.

Differential Revision: https://reviews.llvm.org/D94363
  • Loading branch information
nikic committed Oct 6, 2022
1 parent d1f13c5 commit c5bf452
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 156 deletions.
112 changes: 16 additions & 96 deletions llvm/include/llvm/Analysis/AliasAnalysis.h
Expand Up @@ -458,6 +458,8 @@ template <> struct DenseMapInfo<AACacheLoc> {
}
};

class AAResults;

/// This class stores info we want to provide to or retain within an alias
/// query. By default, the root query is stateless and starts with a freshly
/// constructed info object. Specific alias analyses can use this query info to
Expand All @@ -477,6 +479,11 @@ class AAQueryInfo {
/// Whether this is a definitive (non-assumption) result.
bool isDefinitive() const { return NumAssumptionUses < 0; }
};

// Alias analysis result aggregration using which this query is performed.
// Can be used to perform recursive queries.
AAResults &AAR;

using AliasCacheT = SmallDenseMap<LocPair, CacheEntry, 8>;
AliasCacheT AliasCache;

Expand All @@ -493,13 +500,13 @@ class AAQueryInfo {
/// assumption is disproven.
SmallVector<AAQueryInfo::LocPair, 4> AssumptionBasedResults;

AAQueryInfo(CaptureInfo *CI) : CI(CI) {}
AAQueryInfo(AAResults &AAR, CaptureInfo *CI) : AAR(AAR), CI(CI) {}

/// Create a new AAQueryInfo based on this one, but with the cache cleared.
/// This is used for recursive queries across phis, where cache results may
/// not be valid.
AAQueryInfo withEmptyCache() {
AAQueryInfo NewAAQI(CI);
AAQueryInfo NewAAQI(AAR, CI);
NewAAQI.Depth = Depth;
return NewAAQI;
}
Expand All @@ -510,7 +517,7 @@ class SimpleAAQueryInfo : public AAQueryInfo {
SimpleCaptureInfo CI;

public:
SimpleAAQueryInfo() : AAQueryInfo(&CI) {}
SimpleAAQueryInfo(AAResults &AAR) : AAQueryInfo(AAR, &CI) {}
};

class BatchAAResults;
Expand Down Expand Up @@ -784,7 +791,7 @@ class AAResults {
/// helpers above.
ModRefInfo getModRefInfo(const Instruction *I,
const Optional<MemoryLocation> &OptLoc) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(I, OptLoc, AAQIP);
}

Expand All @@ -809,7 +816,7 @@ class AAResults {
ModRefInfo callCapturesBefore(const Instruction *I,
const MemoryLocation &MemLoc,
DominatorTree *DT) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return callCapturesBefore(I, MemLoc, DT, AAQIP);
}

Expand Down Expand Up @@ -850,7 +857,6 @@ class AAResults {
return canInstructionRangeModRef(I1, I2, MemoryLocation(Ptr, Size), Mode);
}

private:
AliasResult alias(const MemoryLocation &LocA, const MemoryLocation &LocB,
AAQueryInfo &AAQI);
bool pointsToConstantMemory(const MemoryLocation &Loc, AAQueryInfo &AAQI,
Expand Down Expand Up @@ -886,6 +892,7 @@ class AAResults {
FunctionModRefBehavior getModRefBehavior(const CallBase *Call,
AAQueryInfo &AAQI);

private:
class Concept;

template <typename T> class Model;
Expand Down Expand Up @@ -913,8 +920,8 @@ class BatchAAResults {
SimpleCaptureInfo SimpleCI;

public:
BatchAAResults(AAResults &AAR) : AA(AAR), AAQI(&SimpleCI) {}
BatchAAResults(AAResults &AAR, CaptureInfo *CI) : AA(AAR), AAQI(CI) {}
BatchAAResults(AAResults &AAR) : AA(AAR), AAQI(AAR, &SimpleCI) {}
BatchAAResults(AAResults &AAR, CaptureInfo *CI) : AA(AAR), AAQI(AAR, CI) {}

AliasResult alias(const MemoryLocation &LocA, const MemoryLocation &LocB) {
return AA.alias(LocA, LocB, AAQI);
Expand Down Expand Up @@ -973,10 +980,6 @@ class AAResults::Concept {
public:
virtual ~Concept() = 0;

/// An update API used internally by the AAResults to provide
/// a handle back to the top level aggregation.
virtual void setAAResults(AAResults *NewAAR) = 0;

//===--------------------------------------------------------------------===//
/// \name Alias Queries
/// @{
Expand Down Expand Up @@ -1038,13 +1041,9 @@ template <typename AAResultT> class AAResults::Model final : public Concept {
AAResultT &Result;

public:
explicit Model(AAResultT &Result, AAResults &AAR) : Result(Result) {
Result.setAAResults(&AAR);
}
explicit Model(AAResultT &Result, AAResults &AAR) : Result(Result) {}
~Model() override = default;

void setAAResults(AAResults *NewAAR) override { Result.setAAResults(NewAAR); }

AliasResult alias(const MemoryLocation &LocA, const MemoryLocation &LocB,
AAQueryInfo &AAQI) override {
return Result.alias(LocA, LocB, AAQI);
Expand Down Expand Up @@ -1093,93 +1092,14 @@ template <typename AAResultT> class AAResults::Model final : public Concept {
/// use virtual anywhere, the CRTP base class does static dispatch to the
/// derived type passed into it.
template <typename DerivedT> class AAResultBase {
// Expose some parts of the interface only to the AAResults::Model
// for wrapping. Specifically, this allows the model to call our
// setAAResults method without exposing it as a fully public API.
friend class AAResults::Model<DerivedT>;

/// A pointer to the AAResults object that this AAResult is
/// aggregated within. May be null if not aggregated.
AAResults *AAR = nullptr;

/// Helper to dispatch calls back through the derived type.
DerivedT &derived() { return static_cast<DerivedT &>(*this); }

/// A setter for the AAResults pointer, which is used to satisfy the
/// AAResults::Model contract.
void setAAResults(AAResults *NewAAR) { AAR = NewAAR; }

protected:
/// This proxy class models a common pattern where we delegate to either the
/// top-level \c AAResults aggregation if one is registered, or to the
/// current result if none are registered.
class AAResultsProxy {
AAResults *AAR;
DerivedT &CurrentResult;

public:
AAResultsProxy(AAResults *AAR, DerivedT &CurrentResult)
: AAR(AAR), CurrentResult(CurrentResult) {}

AliasResult alias(const MemoryLocation &LocA, const MemoryLocation &LocB,
AAQueryInfo &AAQI) {
return AAR ? AAR->alias(LocA, LocB, AAQI)
: CurrentResult.alias(LocA, LocB, AAQI);
}

bool pointsToConstantMemory(const MemoryLocation &Loc, AAQueryInfo &AAQI,
bool OrLocal) {
return AAR ? AAR->pointsToConstantMemory(Loc, AAQI, OrLocal)
: CurrentResult.pointsToConstantMemory(Loc, AAQI, OrLocal);
}

ModRefInfo getArgModRefInfo(const CallBase *Call, unsigned ArgIdx) {
return AAR ? AAR->getArgModRefInfo(Call, ArgIdx)
: CurrentResult.getArgModRefInfo(Call, ArgIdx);
}

FunctionModRefBehavior getModRefBehavior(const CallBase *Call,
AAQueryInfo &AAQI) {
return AAR ? AAR->getModRefBehavior(Call, AAQI)
: CurrentResult.getModRefBehavior(Call, AAQI);
}

FunctionModRefBehavior getModRefBehavior(const Function *F) {
return AAR ? AAR->getModRefBehavior(F) : CurrentResult.getModRefBehavior(F);
}

ModRefInfo getModRefInfo(const CallBase *Call, const MemoryLocation &Loc,
AAQueryInfo &AAQI) {
return AAR ? AAR->getModRefInfo(Call, Loc, AAQI)
: CurrentResult.getModRefInfo(Call, Loc, AAQI);
}

ModRefInfo getModRefInfo(const CallBase *Call1, const CallBase *Call2,
AAQueryInfo &AAQI) {
return AAR ? AAR->getModRefInfo(Call1, Call2, AAQI)
: CurrentResult.getModRefInfo(Call1, Call2, AAQI);
}
};

explicit AAResultBase() = default;

// Provide all the copy and move constructors so that derived types aren't
// constrained.
AAResultBase(const AAResultBase &Arg) {}
AAResultBase(AAResultBase &&Arg) {}

/// Get a proxy for the best AA result set to query at this time.
///
/// When this result is part of a larger aggregation, this will proxy to that
/// aggregation. When this result is used in isolation, it will just delegate
/// back to the derived class's implementation.
///
/// Note that callers of this need to take considerable care to not cause
/// performance problems when they use this routine, in the case of a large
/// number of alias analyses being aggregated, it can be expensive to walk
/// back across the chain.
AAResultsProxy getBestAAResults() { return AAResultsProxy(AAR, derived()); }

public:
AliasResult alias(const MemoryLocation &LocA, const MemoryLocation &LocB,
AAQueryInfo &AAQI) {
Expand Down
47 changes: 18 additions & 29 deletions llvm/lib/Analysis/AliasAnalysis.cpp
Expand Up @@ -76,21 +76,9 @@ static const bool EnableAATrace = false;
#endif

AAResults::AAResults(AAResults &&Arg)
: TLI(Arg.TLI), AAs(std::move(Arg.AAs)), AADeps(std::move(Arg.AADeps)) {
for (auto &AA : AAs)
AA->setAAResults(this);
}
: TLI(Arg.TLI), AAs(std::move(Arg.AAs)), AADeps(std::move(Arg.AADeps)) {}

AAResults::~AAResults() {
// FIXME; It would be nice to at least clear out the pointers back to this
// aggregation here, but we end up with non-nesting lifetimes in the legacy
// pass manager that prevent this from working. In the legacy pass manager
// we'll end up with dangling references here in some cases.
#if 0
for (auto &AA : AAs)
AA->setAAResults(nullptr);
#endif
}
AAResults::~AAResults() {}

bool AAResults::invalidate(Function &F, const PreservedAnalyses &PA,
FunctionAnalysisManager::Invalidator &Inv) {
Expand Down Expand Up @@ -118,7 +106,7 @@ bool AAResults::invalidate(Function &F, const PreservedAnalyses &PA,

AliasResult AAResults::alias(const MemoryLocation &LocA,
const MemoryLocation &LocB) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return alias(LocA, LocB, AAQIP);
}

Expand Down Expand Up @@ -161,7 +149,7 @@ AliasResult AAResults::alias(const MemoryLocation &LocA,

bool AAResults::pointsToConstantMemory(const MemoryLocation &Loc,
bool OrLocal) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return pointsToConstantMemory(Loc, AAQIP, OrLocal);
}

Expand Down Expand Up @@ -189,7 +177,7 @@ ModRefInfo AAResults::getArgModRefInfo(const CallBase *Call, unsigned ArgIdx) {
}

ModRefInfo AAResults::getModRefInfo(Instruction *I, const CallBase *Call2) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(I, Call2, AAQIP);
}

Expand All @@ -216,7 +204,7 @@ ModRefInfo AAResults::getModRefInfo(Instruction *I, const CallBase *Call2,

ModRefInfo AAResults::getModRefInfo(const CallBase *Call,
const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(Call, Loc, AAQIP);
}

Expand Down Expand Up @@ -276,7 +264,7 @@ ModRefInfo AAResults::getModRefInfo(const CallBase *Call,

ModRefInfo AAResults::getModRefInfo(const CallBase *Call1,
const CallBase *Call2) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(Call1, Call2, AAQIP);
}

Expand Down Expand Up @@ -403,7 +391,7 @@ FunctionModRefBehavior AAResults::getModRefBehavior(const CallBase *Call,
}

FunctionModRefBehavior AAResults::getModRefBehavior(const CallBase *Call) {
SimpleAAQueryInfo AAQI;
SimpleAAQueryInfo AAQI(*this);
return getModRefBehavior(Call, AAQI);
}

Expand Down Expand Up @@ -484,7 +472,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, FunctionModRefBehavior FMRB) {

ModRefInfo AAResults::getModRefInfo(const LoadInst *L,
const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(L, Loc, AAQIP);
}
ModRefInfo AAResults::getModRefInfo(const LoadInst *L,
Expand All @@ -507,7 +495,7 @@ ModRefInfo AAResults::getModRefInfo(const LoadInst *L,

ModRefInfo AAResults::getModRefInfo(const StoreInst *S,
const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(S, Loc, AAQIP);
}
ModRefInfo AAResults::getModRefInfo(const StoreInst *S,
Expand All @@ -534,8 +522,9 @@ ModRefInfo AAResults::getModRefInfo(const StoreInst *S,
return ModRefInfo::Mod;
}

ModRefInfo AAResults::getModRefInfo(const FenceInst *S, const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP;
ModRefInfo AAResults::getModRefInfo(const FenceInst *S,
const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(S, Loc, AAQIP);
}

Expand All @@ -551,7 +540,7 @@ ModRefInfo AAResults::getModRefInfo(const FenceInst *S,

ModRefInfo AAResults::getModRefInfo(const VAArgInst *V,
const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(V, Loc, AAQIP);
}

Expand All @@ -577,7 +566,7 @@ ModRefInfo AAResults::getModRefInfo(const VAArgInst *V,

ModRefInfo AAResults::getModRefInfo(const CatchPadInst *CatchPad,
const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(CatchPad, Loc, AAQIP);
}

Expand All @@ -597,7 +586,7 @@ ModRefInfo AAResults::getModRefInfo(const CatchPadInst *CatchPad,

ModRefInfo AAResults::getModRefInfo(const CatchReturnInst *CatchRet,
const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(CatchRet, Loc, AAQIP);
}

Expand All @@ -617,7 +606,7 @@ ModRefInfo AAResults::getModRefInfo(const CatchReturnInst *CatchRet,

ModRefInfo AAResults::getModRefInfo(const AtomicCmpXchgInst *CX,
const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(CX, Loc, AAQIP);
}

Expand All @@ -641,7 +630,7 @@ ModRefInfo AAResults::getModRefInfo(const AtomicCmpXchgInst *CX,

ModRefInfo AAResults::getModRefInfo(const AtomicRMWInst *RMW,
const MemoryLocation &Loc) {
SimpleAAQueryInfo AAQIP;
SimpleAAQueryInfo AAQIP(*this);
return getModRefInfo(RMW, Loc, AAQIP);
}

Expand Down

0 comments on commit c5bf452

Please sign in to comment.