Skip to content

Commit

Permalink
[Attributor] Introduce AAIndirectCallInfo
Browse files Browse the repository at this point in the history
AAIndirectCallInfo will collect information and specialize indirect call
sites. It is similar to our IndirectCallPromotion but runs as part of
the Attributor (so with assumed callee information). It also expands
more calls and lets the rest of the pipeline figure out what is UB, for
now. We use existing call promotion logic to improve the result,
otherwise we rely on the (implicit) function pointer cast.

This effectively "fixes" #60327 as it will undo the type punning early
enough for the inliner to work with the (now specialized, thus direct)
call.

Fixes: #60327
  • Loading branch information
jdoerfert committed Aug 18, 2023
1 parent 18b211c commit 9c08e76
Show file tree
Hide file tree
Showing 7 changed files with 434 additions and 70 deletions.
42 changes: 42 additions & 0 deletions llvm/include/llvm/Transforms/IPO/Attributor.h
Original file line number Diff line number Diff line change
Expand Up @@ -6109,6 +6109,48 @@ struct AAAddressSpace : public StateWrapper<BooleanState, AbstractAttribute> {
static const char ID;
};

/// An abstract interface for indirect call information inference.
struct AAIndirectCallInfo
    : public StateWrapper<BooleanState, AbstractAttribute> {
  AAIndirectCallInfo(const IRPosition &IRP, Attributor &A)
      : StateWrapper<BooleanState, AbstractAttribute>(IRP) {}

  /// A known callee is not required for the call base; deriving the callees
  /// is the whole purpose of this attribute.
  static bool requiresCalleeForCallBase() { return false; }

  /// See AbstractAttribute::isValidIRPositionForInit
  ///
  /// Only call site positions are accepted, and only if the underlying
  /// instruction has the `Call` opcode, is an indirect call, and is not a
  /// must-tail call.
  static bool isValidIRPositionForInit(Attributor &A, const IRPosition &IRP) {
    if (IRP.getPositionKind() != IRPosition::IRP_CALL_SITE)
      return false;
    auto *Call = cast<CallBase>(IRP.getCtxI());
    if (Call->getOpcode() != Instruction::Call)
      return false;
    if (!Call->isIndirectCall())
      return false;
    return !Call->isMustTailCall();
  }

  /// Create an abstract attribute view for the position \p IRP.
  static AAIndirectCallInfo &createForPosition(const IRPosition &IRP,
                                               Attributor &A);

  /// Invoke \p CB on each potential callee. Return true if all callees were
  /// known and \p CB returned true for every one of them; otherwise, return
  /// false.
  virtual bool foreachCallee(function_ref<bool(Function *)> CB) const = 0;

  /// See AbstractAttribute::getName()
  const std::string getName() const override { return "AAIndirectCallInfo"; }

  /// See AbstractAttribute::getIdAddr()
  const char *getIdAddr() const override { return &ID; }

  /// Return true if \p AA is an AAIndirectCallInfo, determined by comparing
  /// its ID address against the address of this class' ID.
  static bool classof(const AbstractAttribute *AA) {
    const char *AttrId = AA->getIdAddr();
    return AttrId == &ID;
  }

  /// Unique ID (due to the unique address)
  static const char ID;
};

raw_ostream &operator<<(raw_ostream &, const AAPointerInfo::Access &);

/// Run options, used by the pass manager.
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Transforms/IPO/Attributor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3460,8 +3460,10 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
Function *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand());
// TODO: Even if the callee is not known now we might be able to simplify
// the call/callee.
if (!Callee)
if (!Callee) {
getOrCreateAAFor<AAIndirectCallInfo>(CBFnPos);
return true;
}

// Every call site can track active assumptions.
getOrCreateAAFor<AAAssumptionInfo>(CBFnPos);
Expand Down
254 changes: 245 additions & 9 deletions llvm/lib/Transforms/IPO/AttributorAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cassert>
#include <numeric>
#include <optional>
#include <string>

using namespace llvm;

Expand Down Expand Up @@ -188,6 +191,7 @@ PIPE_OPERATOR(AAPointerInfo)
PIPE_OPERATOR(AAAssumptionInfo)
PIPE_OPERATOR(AAUnderlyingObjects)
PIPE_OPERATOR(AAAddressSpace)
PIPE_OPERATOR(AAIndirectCallInfo)

#undef PIPE_OPERATOR

Expand Down Expand Up @@ -10560,15 +10564,12 @@ struct AACallEdgesCallSite : public AACallEdgesImpl {
return Change;
}

// Process callee metadata if available.
if (auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees)) {
for (const auto &Op : MD->operands()) {
Function *Callee = mdconst::dyn_extract_or_null<Function>(Op);
if (Callee)
addCalledFunction(Callee, Change);
}
return Change;
}
if (CB->isIndirectCall())
if (auto *IndirectCallAA = A.getAAFor<AAIndirectCallInfo>(
*this, getIRPosition(), DepClassTy::OPTIONAL))
if (IndirectCallAA->foreachCallee(
[&](Function *Fn) { return VisitValue(*Fn, CB); }))
return Change;

// The most simple case.
ProcessCalledOperand(CB->getCalledOperand(), CB);
Expand Down Expand Up @@ -12051,6 +12052,224 @@ struct AAUnderlyingObjectsFunction final : AAUnderlyingObjectsImpl {
};
} // namespace

/// ------------------------ Indirect Call Info -------------------------------
namespace {
struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo {
AAIndirectCallInfoCallSite(const IRPosition &IRP, Attributor &A)
: AAIndirectCallInfo(IRP, A) {}

/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees);
if (!MD)
return;
for (const auto &Op : MD->operands())
if (Function *Callee = mdconst::dyn_extract_or_null<Function>(Op))
PotentialCallees.insert(Callee);
}

ChangeStatus updateImpl(Attributor &A) override {
CallBase *CB = cast<CallBase>(getCtxI());
Value *FP = CB->getCalledOperand();

SmallSetVector<Function *, 4> AssumedCalleesNow;
bool AllCalleesKnownNow = AllCalleesKnown;

// Use simplification to find potential callees, if !callees was present,
// fallback to that set if necessary.
bool UsedAssumedInformation;
SmallVector<AA::ValueAndContext> Values;
if (!A.getAssumedSimplifiedValues(IRPosition::value(*FP), this, Values,
AA::ValueScope::AnyScope,
UsedAssumedInformation)) {
if (PotentialCallees.empty())
return indicatePessimisticFixpoint();
AssumedCalleesNow.set_union(PotentialCallees);
}

// Check simplification result, prune known UB callees, also restrict it to
// the !callees set, if present.
for (auto &VAC : Values) {
if (isa<UndefValue>(VAC.getValue()))
continue;
if (isa<ConstantPointerNull>(VAC.getValue()) &&
VAC.getValue()->getType()->getPointerAddressSpace() == 0)
continue;
// TODO: Check for known UB, e.g., poison + noundef.
if (auto *VACFn = dyn_cast<Function>(VAC.getValue())) {
if (PotentialCallees.empty() || PotentialCallees.count(VACFn))
AssumedCalleesNow.insert(VACFn);
continue;
}
if (!PotentialCallees.empty()) {
AssumedCalleesNow.set_union(PotentialCallees);
break;
}
AllCalleesKnownNow = false;
}

// If we can't specialize at all, give up now.
if (!AllCalleesKnownNow && AssumedCalleesNow.empty())
return indicatePessimisticFixpoint();

if (AssumedCalleesNow == AssumedCalles &&
AllCalleesKnown == AllCalleesKnownNow)
return ChangeStatus::UNCHANGED;

std::swap(AssumedCalles, AssumedCalleesNow);
AllCalleesKnown = AllCalleesKnownNow;
return ChangeStatus::CHANGED;
}

/// See AbstractAttribute::manifest(...).
ChangeStatus manifest(Attributor &A) override {

ChangeStatus Changed = ChangeStatus::UNCHANGED;
CallBase *CB = cast<CallBase>(getCtxI());
Value *FP = CB->getCalledOperand();

bool CBIsVoid = CB->getType()->isVoidTy();
Instruction *IP = CB;
FunctionType *CSFT = CB->getFunctionType();
SmallVector<Value *> CSArgs(CB->arg_begin(), CB->arg_end());

// If we know all callees and there are none, the call site is (effectively)
// dead (or UB).
if (AssumedCalles.empty()) {
assert(AllCalleesKnown &&
"Expected all callees to be known if there are none.");
A.changeToUnreachableAfterManifest(CB);
return ChangeStatus::CHANGED;
}

// Special handling for the single callee case.
if (AllCalleesKnown && AssumedCalles.size() == 1) {
auto *NewCallee = AssumedCalles.front();
if (isLegalToPromote(*CB, NewCallee)) {
promoteCall(*CB, NewCallee, nullptr);
return ChangeStatus::CHANGED;
}
Instruction *NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee),
CSArgs, CB->getName(), CB);
if (!CBIsVoid)
A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewCall);
A.deleteAfterManifest(*CB);
return ChangeStatus::CHANGED;
}

// For each potential value we create a conditional
//
// ```
// if (ptr == value) value(args);
// else ...
// ```
//
ICmpInst *LastCmp = nullptr;
SmallVector<std::pair<CallInst *, Instruction *>> NewCalls;
for (Function *NewCallee : AssumedCalles) {
LastCmp = new ICmpInst(IP, llvm::CmpInst::ICMP_EQ, FP, NewCallee);
Instruction *ThenTI =
SplitBlockAndInsertIfThen(LastCmp, IP, /* Unreachable */ false);
BasicBlock *CBBB = CB->getParent();
auto *SplitTI = cast<BranchInst>(LastCmp->getNextNode());
BasicBlock *ElseBB;
if (IP == CB) {
ElseBB = BasicBlock::Create(ThenTI->getContext(), "",
ThenTI->getFunction(), CBBB);
IP = BranchInst::Create(CBBB, ElseBB);
SplitTI->replaceUsesOfWith(CBBB, ElseBB);
} else {
ElseBB = IP->getParent();
ThenTI->replaceUsesOfWith(ElseBB, CBBB);
}
CastInst *RetBC = nullptr;
CallInst *NewCall = nullptr;
if (isLegalToPromote(*CB, NewCallee)) {
auto *CBClone = cast<CallBase>(CB->clone());
CBClone->insertBefore(ThenTI);
NewCall = &cast<CallInst>(promoteCall(*CBClone, NewCallee, &RetBC));
} else {
NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), CSArgs,
CB->getName(), ThenTI);
}
NewCalls.push_back({NewCall, RetBC});
}

// Check if we need the fallback indirect call still.
if (AllCalleesKnown) {
LastCmp->replaceAllUsesWith(ConstantInt::getTrue(LastCmp->getContext()));
LastCmp->eraseFromParent();
new UnreachableInst(IP->getContext(), IP);
IP->eraseFromParent();
} else {
auto *CBClone = cast<CallInst>(CB->clone());
CBClone->setName(CB->getName());
CBClone->insertBefore(IP);
NewCalls.push_back({CBClone, nullptr});
}

// Check if we need a PHI to merge the results.
if (!CBIsVoid) {
auto *PHI = PHINode::Create(CB->getType(), NewCalls.size(),
CB->getName() + ".phi",
&*CB->getParent()->getFirstInsertionPt());
for (auto &It : NewCalls) {
CallBase *NewCall = It.first;
Instruction *CallRet = It.second ? It.second : It.first;
if (CallRet->getType() == CB->getType())
PHI->addIncoming(CallRet, CallRet->getParent());
else if (NewCall->getType()->isVoidTy())
PHI->addIncoming(PoisonValue::get(CB->getType()),
NewCall->getParent());
else
llvm_unreachable("Call return should match or be void!");
}
A.changeAfterManifest(IRPosition::callsite_returned(*CB), *PHI);
}

A.deleteAfterManifest(*CB);
Changed = ChangeStatus::CHANGED;

return Changed;
}

/// See AbstractAttribute::getAsStr().
const std::string getAsStr(Attributor *A) const override {
return std::string(AllCalleesKnown ? "eliminate" : "specialize") +
" indirect call site with " + std::to_string(AssumedCalles.size()) +
" functions";
}

void trackStatistics() const override {
if (AllCalleesKnown) {
STATS_DECLTRACK(
Eliminated, CallSites,
"Number of indirect call sites eliminated via specialization")
} else {
STATS_DECLTRACK(Specialized, CallSites,
"Number of indirect call sites specialized")
}
}

bool foreachCallee(function_ref<bool(Function *)> CB) const override {
return isValidState() && AllCalleesKnown && all_of(AssumedCalles, CB);
}

private:
/// If the !callee metadata was present, this set will contain all potential
/// callees (superset).
SmallSetVector<Function *, 4> PotentialCallees;

/// This set contains all currently assumed calllees, which might grow over
/// time.
SmallSetVector<Function *, 4> AssumedCalles;

/// Flag to indicate if all possible callees are in the AssumedCalles set or
/// if there could be others.
bool AllCalleesKnown = true;
};
} // namespace

/// ------------------------ Address Space ------------------------------------
namespace {
struct AAAddressSpaceImpl : public AAAddressSpace {
Expand Down Expand Up @@ -12259,6 +12478,7 @@ const char AAPointerInfo::ID = 0;
const char AAAssumptionInfo::ID = 0;
const char AAUnderlyingObjects::ID = 0;
const char AAAddressSpace::ID = 0;
const char AAIndirectCallInfo::ID = 0;

// Macro magic to create the static generator function for attributes that
// follow the naming scheme.
Expand Down Expand Up @@ -12305,6 +12525,18 @@ const char AAAddressSpace::ID = 0;
return *AA; \
}

// Define CLASS::createForPosition for an abstract attribute that can only be
// created at a single position kind. The case for POS is produced by
// SWITCH_PK_CREATE (defined elsewhere in this file); every other position
// kind hits llvm_unreachable.
#define CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(POS, SUFFIX, CLASS)         \
  CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) {      \
    CLASS *AA = nullptr;                                                       \
    switch (IRP.getPositionKind()) {                                           \
      SWITCH_PK_CREATE(CLASS, IRP, POS, SUFFIX)                                \
    default:                                                                   \
      llvm_unreachable("Cannot create " #CLASS                                 \
                       " for position other than " #POS " position!");         \
    }                                                                          \
    return *AA;                                                                \
  }

#define CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS) \
CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) { \
CLASS *AA = nullptr; \
Expand Down Expand Up @@ -12383,6 +12615,9 @@ CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead)
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree)
CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUnderlyingObjects)

CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(IRP_CALL_SITE, CallSite,
AAIndirectCallInfo)

CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior)
CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonConvergent)
Expand All @@ -12396,5 +12631,6 @@ CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior)
#undef CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION
#undef CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION
#undef CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION
#undef CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION
#undef SWITCH_PK_CREATE
#undef SWITCH_PK_INV
5 changes: 5 additions & 0 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5419,6 +5419,11 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
UsedAssumedInformation, AA::Interprocedural);
continue;
}
if (auto *CI = dyn_cast<CallBase>(&I)) {
if (CI->isIndirectCall())
A.getOrCreateAAFor<AAIndirectCallInfo>(
IRPosition::callsite_function(*CI));
}
if (auto *SI = dyn_cast<StoreInst>(&I)) {
A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
continue;
Expand Down

2 comments on commit 9c08e76

@thurstond
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's been a buildbot breakage at https://lab.llvm.org/buildbot/#/builders/5/builds/35950
involving

  LLVM :: Transforms/Attributor/liveness.ll
  LLVM :: Transforms/Attributor/misc.ll
  LLVM :: Transforms/Attributor/value-simplify.ll

This patch, being in Attributor, seems to be the most relevant out of the patches introduced in that build.
(Bisection experiment underway)

@jdoerfert
Copy link
Member Author

@jdoerfert jdoerfert commented on 9c08e76 Aug 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thurstond Found it. Fix in 2 min. Thx

Please sign in to comment.