Skip to content

Commit

Permalink
[Attributor] AAFunctionReachability, Instruction reachability.
Browse files Browse the repository at this point in the history
This patch implements instruction reachability for the AAFunctionReachability
attribute. It is used to tell if a certain instruction can reach a function
transitively.

NOTE: I created a new commit based on D106720 and set the author back to
      Kuter. Other metadata, etc. is wrong. I also addressed the
      remaining review comments and fixed the unit test.

Differential Revision: https://reviews.llvm.org/D106720
  • Loading branch information
kuterd authored and jdoerfert committed Feb 1, 2022
1 parent ac3ec22 commit b2d1ae0
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 63 deletions.
14 changes: 11 additions & 3 deletions llvm/include/llvm/Transforms/IPO/Attributor.h
Expand Up @@ -4616,17 +4616,25 @@ struct AAFunctionReachability
AAFunctionReachability(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

/// If the function represented by this position can reach \p Fn.
virtual bool canReach(Attributor &A, Function *Fn) const = 0;
virtual bool canReach(Attributor &A, const Function &Fn) const = 0;

/// Can \p CB reach \p Fn
virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
virtual bool canReach(Attributor &A, CallBase &CB,
const Function &Fn) const = 0;

/// Can \p Inst reach \p Fn, i.e., can a call that is reachable from \p Inst
/// end up (transitively) in \p Fn?
///
/// \p UseBackwards defaults to true. In the AAFunctionReachabilityFunction
/// implementation, passing false restricts the query to the forward
/// direction (call edges reachable from \p Inst); with true, a failed forward
/// query is additionally retried "backwards" through the call sites of the
/// function.
virtual bool instructionCanReach(Attributor &A, const Instruction &Inst,
const Function &Fn,
bool UseBackwards = true) const = 0;

/// Create an abstract attribute view for the position \p IRP.
static AAFunctionReachability &createForPosition(const IRPosition &IRP,
Attributor &A);

/// See AbstractAttribute::getName()
const std::string getName() const override { return "AAFunctionReachability"; }
const std::string getName() const override {
return "AAFunctionReachability";
}

/// See AbstractAttribute::getIdAddr()
const char *getIdAddr() const override { return &ID; }
Expand Down
223 changes: 183 additions & 40 deletions llvm/lib/Transforms/IPO/AttributorAttributes.cpp
Expand Up @@ -656,7 +656,7 @@ struct AACallSiteReturnedFromReturned : public BaseType {
if (!AssociatedFunction)
return S.indicatePessimisticFixpoint();

CallBase &CBContext = static_cast<CallBase &>(this->getAnchorValue());
CallBase &CBContext = cast<CallBase>(this->getAnchorValue());
if (IntroduceCallBaseContext)
LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:"
<< CBContext << "\n");
Expand Down Expand Up @@ -2468,7 +2468,7 @@ struct AANoRecurseFunction final : AANoRecurseImpl {
const AAFunctionReachability &EdgeReachability =
A.getAAFor<AAFunctionReachability>(*this, getIRPosition(),
DepClassTy::REQUIRED);
if (EdgeReachability.canReach(A, getAnchorScope()))
if (EdgeReachability.canReach(A, *getAnchorScope()))
return indicatePessimisticFixpoint();
return ChangeStatus::UNCHANGED;
}
Expand Down Expand Up @@ -9482,7 +9482,7 @@ struct AACallEdgesCallSite : public AACallEdgesImpl {
}
};

CallBase *CB = static_cast<CallBase *>(getCtxI());
CallBase *CB = cast<CallBase>(getCtxI());

if (CB->isInlineAsm()) {
setHasUnknownCallee(false, Change);
Expand Down Expand Up @@ -9521,7 +9521,7 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
ChangeStatus Change = ChangeStatus::UNCHANGED;

auto ProcessCallInst = [&](Instruction &Inst) {
CallBase &CB = static_cast<CallBase &>(Inst);
CallBase &CB = cast<CallBase>(Inst);

auto &CBEdges = A.getAAFor<AACallEdges>(
*this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
Expand Down Expand Up @@ -9552,11 +9552,39 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
struct AAFunctionReachabilityFunction : public AAFunctionReachability {
private:
struct QuerySet {
void markReachable(Function *Fn) {
Reachable.insert(Fn);
Unreachable.erase(Fn);
void markReachable(const Function &Fn) {
Reachable.insert(&Fn);
Unreachable.erase(&Fn);
}

/// Look up \p Fn in the cached query results.
///
/// \returns true if \p Fn is recorded as reachable (or an unknown callee is
/// reachable), false if it is recorded as unreachable, and None when nothing
/// has been recorded for \p Fn yet.
Optional<bool> isCachedReachable(const Function &Fn) {
  // An unknown callee may call anything, so every function counts as
  // reachable in that case.
  // TODO: Be more specific with the unknown callee.
  if (CanReachUnknownCallee || Reachable.count(&Fn))
    return true;

  if (Unreachable.count(&Fn))
    return false;

  // No cached answer for this function.
  return llvm::None;
}

/// Set of functions that we know for sure is reachable.
DenseSet<const Function *> Reachable;

/// Set of functions that are unreachable, but might become reachable.
DenseSet<const Function *> Unreachable;

/// If we can reach a function with a call to an unknown function we assume
/// that we can reach any function.
bool CanReachUnknownCallee = false;
};

struct QueryResolver : public QuerySet {
ChangeStatus update(Attributor &A, const AAFunctionReachability &AA,
ArrayRef<const AACallEdges *> AAEdgesList) {
ChangeStatus Change = ChangeStatus::UNCHANGED;
Expand All @@ -9570,31 +9598,25 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
}
}

for (Function *Fn : make_early_inc_range(Unreachable)) {
if (checkIfReachable(A, AA, AAEdgesList, Fn)) {
for (const Function *Fn : make_early_inc_range(Unreachable)) {
if (checkIfReachable(A, AA, AAEdgesList, *Fn)) {
Change = ChangeStatus::CHANGED;
markReachable(Fn);
markReachable(*Fn);
}
}
return Change;
}

bool isReachable(Attributor &A, const AAFunctionReachability &AA,
ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
// Assume that we can reach the function.
// TODO: Be more specific with the unknown callee.
if (CanReachUnknownCallee)
return true;

if (Reachable.count(Fn))
return true;

if (Unreachable.count(Fn))
return false;
ArrayRef<const AACallEdges *> AAEdgesList,
const Function &Fn) {
Optional<bool> Cached = isCachedReachable(Fn);
if (Cached.hasValue())
return Cached.getValue();

// We need to assume that this function can't reach Fn to prevent
// an infinite loop if this function is recursive.
Unreachable.insert(Fn);
Unreachable.insert(&Fn);

bool Result = checkIfReachable(A, AA, AAEdgesList, Fn);
if (Result)
Expand All @@ -9604,13 +9626,13 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {

bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA,
ArrayRef<const AACallEdges *> AAEdgesList,
Function *Fn) const {
const Function &Fn) const {

// Handle the most trivial case first.
for (auto *AAEdges : AAEdgesList) {
const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();

if (Edges.count(Fn))
if (Edges.count(const_cast<Function *>(&Fn)))
return true;
}

Expand All @@ -9631,28 +9653,80 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
}

// The result is false for now, set dependencies and leave.
for (auto Dep : Deps)
A.recordDependence(AA, *Dep, DepClassTy::REQUIRED);
for (auto *Dep : Deps)
A.recordDependence(*Dep, AA, DepClassTy::REQUIRED);

return false;
}
};

/// Set of functions that we know for sure is reachable.
DenseSet<Function *> Reachable;
/// Get call edges that can be reached by this instruction.
bool getReachableCallEdges(Attributor &A, const AAReachability &Reachability,
const Instruction &Inst,
SmallVector<const AACallEdges *> &Result) const {
// Determine call like instructions that we can reach from the inst.
auto CheckCallBase = [&](Instruction &CBInst) {
if (!Reachability.isAssumedReachable(A, Inst, CBInst))
return true;

/// Set of functions that are unreachable, but might become reachable.
DenseSet<Function *> Unreachable;
const auto &CB = cast<CallBase>(CBInst);
const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
*this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);

/// If we can reach a function with a call to a unknown function we assume
/// that we can reach any function.
bool CanReachUnknownCallee = false;
};
Result.push_back(&AAEdges);
return true;
};

bool UsedAssumedInformation = false;
return A.checkForAllCallLikeInstructions(CheckCallBase, *this,
UsedAssumedInformation);
}

/// Try to resolve the functions still listed in \p Set.Unreachable by
/// continuing the query in the callers of this function.
///
/// For every call site handed to the callback, the instruction following the
/// call is asked (forward only) whether it can reach each pending function;
/// functions that can be reached are moved to the reachable set of \p Set.
/// If checkForAllCallSites fails, every entry of InstQueriesBackwards is
/// marked as reaching an unknown callee.
///
/// \returns CHANGED if a function was newly marked reachable or a query set
/// was newly marked as reaching an unknown callee, UNCHANGED otherwise.
ChangeStatus checkReachableBackwards(Attributor &A, QuerySet &Set) {
  ChangeStatus Change = ChangeStatus::UNCHANGED;

  // For all remaining instruction queries, check
  // callers. A call inside that function might satisfy the query.
  auto CheckCallSite = [&](AbstractCallSite CallSite) {
    CallBase *CB = CallSite.getInstruction();
    if (!CB)
      return false;

    if (isa<InvokeInst>(CB))
      return false;

    // The query continues at the instruction after the call. A call that
    // ends its basic block (e.g., a callbr) has no such instruction, so
    // getNextNonDebugInstruction() yields null. Treat it like the invoke
    // case above instead of dereferencing a null pointer.
    Instruction *Inst = CB->getNextNonDebugInstruction();
    if (!Inst)
      return false;

    const AAFunctionReachability &AA = A.getAAFor<AAFunctionReachability>(
        *this, IRPosition::function(*Inst->getFunction()),
        DepClassTy::REQUIRED);
    // markReachable erases from Set.Unreachable, hence the early-inc range.
    for (const Function *Fn : make_early_inc_range(Set.Unreachable)) {
      if (AA.instructionCanReach(A, *Inst, *Fn, /* UseBackwards */ false)) {
        Set.markReachable(*Fn);
        Change = ChangeStatus::CHANGED;
      }
    }
    return true;
  };

  bool NoUnknownCall = true;
  if (A.checkForAllCallSites(CheckCallSite, *this, true, NoUnknownCall))
    return Change;

  // If we don't know all callsites we have to assume that we can reach fn.
  for (auto &QSet : InstQueriesBackwards) {
    if (!QSet.second.CanReachUnknownCallee)
      Change = ChangeStatus::CHANGED;
    QSet.second.CanReachUnknownCallee = true;
  }

  return Change;
}

public:
AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
: AAFunctionReachability(IRP, A) {}

bool canReach(Attributor &A, Function *Fn) const override {
bool canReach(Attributor &A, const Function &Fn) const override {
const AACallEdges &AAEdges =
A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);

Expand All @@ -9668,7 +9742,8 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
}

/// Can \p CB reach \p Fn
bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override {
bool canReach(Attributor &A, CallBase &CB,
const Function &Fn) const override {
const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
*this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);

Expand All @@ -9677,13 +9752,52 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
// a const_cast.
// This is a hack for us to be able to cache queries.
auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
QuerySet &CBQuery = NonConstThis->CBQueries[&CB];
QueryResolver &CBQuery = NonConstThis->CBQueries[&CB];

bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn);

return Result;
}

/// Can \p Inst reach \p Fn? The forward direction (calls reachable from
/// \p Inst) is tried first; if that fails and \p UseBackwards is set, the
/// callers of this function are consulted as well. Results are cached per
/// instruction.
bool instructionCanReach(Attributor &A, const Instruction &Inst,
                         const Function &Fn,
                         bool UseBackwards) const override {
  const AAReachability &IntraReachability = A.getAAFor<AAReachability>(
      *this, IRPosition::function(*getAssociatedFunction()),
      DepClassTy::REQUIRED);

  SmallVector<const AACallEdges *> ReachableEdges;
  const bool SawAllCalls =
      getReachableCallEdges(A, IntraReachability, Inst, ReachableEdges);

  // The Attributor hands out attributes as const, so this query has to be
  // const as well. Casting the constness away is a hack that lets us keep
  // the per-instruction caches up to date from here.
  auto &Self = const_cast<AAFunctionReachabilityFunction &>(*this);

  QueryResolver &ForwardQuery = Self.InstQueries[&Inst];
  if (!SawAllCalls)
    ForwardQuery.CanReachUnknownCallee = true;

  if (ForwardQuery.isReachable(A, *this, ReachableEdges, Fn))
    return true;

  // Without the backwards search the forward answer is final.
  if (!UseBackwards)
    return false;

  QuerySet &BackwardQuery = Self.InstQueriesBackwards[&Inst];

  Optional<bool> Known = BackwardQuery.isCachedReachable(Fn);
  if (Known.hasValue())
    return Known.getValue();

  // Start from "unreachable" so the lookup after the backwards check
  // always has an answer.
  BackwardQuery.Unreachable.insert(&Fn);

  // Check backwards reachability; this may move Fn to the reachable set.
  Self.checkReachableBackwards(A, BackwardQuery);
  return BackwardQuery.isCachedReachable(Fn).getValue();
}

/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
const AACallEdges &AAEdges =
Expand All @@ -9692,14 +9806,37 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {

Change |= WholeFunction.update(A, *this, {&AAEdges});

for (auto CBPair : CBQueries) {
for (auto &CBPair : CBQueries) {
const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
*this, IRPosition::callsite_function(*CBPair.first),
DepClassTy::REQUIRED);

Change |= CBPair.second.update(A, *this, {&AAEdges});
}

// Update the Instruction queries.
const AAReachability *Reachability;
if (!InstQueries.empty()) {
Reachability = &A.getAAFor<AAReachability>(
*this, IRPosition::function(*getAssociatedFunction()),
DepClassTy::REQUIRED);
}

// Check for local callbases first.
for (auto &InstPair : InstQueries) {
SmallVector<const AACallEdges *> CallEdges;
bool AllKnown =
getReachableCallEdges(A, *Reachability, *InstPair.first, CallEdges);
// Update will return CHANGED if this affects any queries.
if (!AllKnown)
InstPair.second.CanReachUnknownCallee = true;
Change |= InstPair.second.update(A, *this, CallEdges);
}

// Update backwards queries.
for (auto &QueryPair : InstQueriesBackwards)
Change |= checkReachableBackwards(A, QueryPair.second);

return Change;
}

Expand All @@ -9720,11 +9857,17 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
}

/// Used to answer if the whole function can reach a specific function.
QuerySet WholeFunction;
QueryResolver WholeFunction;

/// Used to answer if a call base inside this function can reach a specific
/// function.
DenseMap<CallBase *, QuerySet> CBQueries;
DenseMap<const CallBase *, QueryResolver> CBQueries;

/// This is for instruction queries that scan "forward".
DenseMap<const Instruction *, QueryResolver> InstQueries;

/// This is for instruction queries that scan "backward".
DenseMap<const Instruction *, QuerySet> InstQueriesBackwards;
};

/// ---------------------- Assumption Propagation ------------------------------
Expand Down

0 comments on commit b2d1ae0

Please sign in to comment.