Skip to content

Commit

Permalink
[Attributor][NFCI] Use pointers to pass around AAs
Browse files Browse the repository at this point in the history
This will make it easier to create less trivial AAs in the future as we
can simply return `nullptr` rather than an AA with an invalid state.
  • Loading branch information
jdoerfert committed Jun 24, 2023
1 parent 52c799b commit e9fc399
Show file tree
Hide file tree
Showing 6 changed files with 620 additions and 503 deletions.
21 changes: 11 additions & 10 deletions llvm/include/llvm/Transforms/IPO/Attributor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1545,7 +1545,7 @@ struct Attributor {
/// attribute is used for reasoning. To record the dependences explicitly use
/// the `Attributor::recordDependence` method.
template <typename AAType>
const AAType &getAAFor(const AbstractAttribute &QueryingAA,
const AAType *getAAFor(const AbstractAttribute &QueryingAA,
const IRPosition &IRP, DepClassTy DepClass) {
return getOrCreateAAFor<AAType>(IRP, &QueryingAA, DepClass,
/* ForceUpdate */ false);
Expand All @@ -1557,7 +1557,7 @@ struct Attributor {
/// possible/useful that were not happening before as the abstract attribute
/// was assumed dead.
template <typename AAType>
const AAType &getAndUpdateAAFor(const AbstractAttribute &QueryingAA,
const AAType *getAndUpdateAAFor(const AbstractAttribute &QueryingAA,
const IRPosition &IRP, DepClassTy DepClass) {
return getOrCreateAAFor<AAType>(IRP, &QueryingAA, DepClass,
/* ForceUpdate */ true);
Expand All @@ -1569,7 +1569,7 @@ struct Attributor {
/// function.
/// NOTE: ForceUpdate is ignored in any stage other than the update stage.
template <typename AAType>
const AAType &getOrCreateAAFor(IRPosition IRP,
const AAType *getOrCreateAAFor(IRPosition IRP,
const AbstractAttribute *QueryingAA,
DepClassTy DepClass, bool ForceUpdate = false,
bool UpdateAfterInit = true) {
Expand All @@ -1580,7 +1580,7 @@ struct Attributor {
/* AllowInvalidState */ true)) {
if (ForceUpdate && Phase == AttributorPhase::UPDATE)
updateAA(*AAPtr);
return *AAPtr;
return AAPtr;
}

// No matching attribute found, create one.
Expand All @@ -1594,7 +1594,7 @@ struct Attributor {
// If we are currently seeding attributes, enforce seeding rules.
if (Phase == AttributorPhase::SEEDING && !shouldSeedAttribute(AA)) {
AA.getState().indicatePessimisticFixpoint();
return AA;
return &AA;
}

// For now we ignore naked and optnone functions.
Expand All @@ -1616,7 +1616,7 @@ struct Attributor {
// Allowed we will not perform updates at all.
if (Invalidate) {
AA.getState().indicatePessimisticFixpoint();
return AA;
return &AA;
}

{
Expand All @@ -1631,15 +1631,15 @@ struct Attributor {
if ((AnchorFn && !isRunOn(const_cast<Function *>(AnchorFn))) &&
!isRunOn(IRP.getAssociatedFunction())) {
AA.getState().indicatePessimisticFixpoint();
return AA;
return &AA;
}

// If this is queried in the manifest stage, we force the AA to indicate
// pessimistic fixpoint immediately.
if (Phase == AttributorPhase::MANIFEST ||
Phase == AttributorPhase::CLEANUP) {
AA.getState().indicatePessimisticFixpoint();
return AA;
return &AA;
}

// Allow seeded attributes to declare dependencies.
Expand All @@ -1656,10 +1656,11 @@ struct Attributor {
if (QueryingAA && AA.getState().isValidState())
recordDependence(AA, const_cast<AbstractAttribute &>(*QueryingAA),
DepClass);
return AA;
return &AA;
}

template <typename AAType>
const AAType &getOrCreateAAFor(const IRPosition &IRP) {
const AAType *getOrCreateAAFor(const IRPosition &IRP) {
return getOrCreateAAFor<AAType>(IRP, /* QueryingAA */ nullptr,
DepClassTy::NONE);
}
Expand Down
64 changes: 38 additions & 26 deletions llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,13 @@ struct AAUniformWorkGroupSizeFunction : public AAUniformWorkGroupSize {
LLVM_DEBUG(dbgs() << "[AAUniformWorkGroupSize] Call " << Caller->getName()
<< "->" << getAssociatedFunction()->getName() << "\n");

const auto &CallerInfo = A.getAAFor<AAUniformWorkGroupSize>(
const auto *CallerInfo = A.getAAFor<AAUniformWorkGroupSize>(
*this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
if (!CallerInfo)
return false;

Change = Change | clampStateAndIndicateChange(this->getState(),
CallerInfo.getState());
CallerInfo->getState());

return true;
};
Expand Down Expand Up @@ -433,9 +435,9 @@ struct AAAMDAttributesFunction : public AAAMDAttributes {
auto OrigAssumed = getAssumed();

// Check for Intrinsics and propagate attributes.
const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
const AACallEdges *AAEdges = A.getAAFor<AACallEdges>(
*this, this->getIRPosition(), DepClassTy::REQUIRED);
if (AAEdges.hasNonAsmUnknownCallee())
if (!AAEdges || AAEdges->hasNonAsmUnknownCallee())
return indicatePessimisticFixpoint();

bool IsNonEntryFunc = !AMDGPU::isEntryFunctionCC(F->getCallingConv());
Expand All @@ -446,12 +448,14 @@ struct AAAMDAttributesFunction : public AAAMDAttributes {
bool SupportsGetDoorbellID = InfoCache.supportsGetDoorbellID(*F);
unsigned COV = InfoCache.getCodeObjectVersion();

for (Function *Callee : AAEdges.getOptimisticEdges()) {
for (Function *Callee : AAEdges->getOptimisticEdges()) {
Intrinsic::ID IID = Callee->getIntrinsicID();
if (IID == Intrinsic::not_intrinsic) {
const AAAMDAttributes &AAAMD = A.getAAFor<AAAMDAttributes>(
*this, IRPosition::function(*Callee), DepClassTy::REQUIRED);
*this &= AAAMD;
const AAAMDAttributes *AAAMD = A.getAAFor<AAAMDAttributes>(
*this, IRPosition::function(*Callee), DepClassTy::REQUIRED);
if (!AAAMD)
return indicatePessimisticFixpoint();
*this &= *AAAMD;
continue;
}

Expand Down Expand Up @@ -641,10 +645,12 @@ struct AAAMDAttributesFunction : public AAAMDAttributes {
if (Call.getIntrinsicID() != Intrinsic::amdgcn_implicitarg_ptr)
return true;

const auto &PointerInfoAA = A.getAAFor<AAPointerInfo>(
const auto *PointerInfoAA = A.getAAFor<AAPointerInfo>(
*this, IRPosition::callsite_returned(Call), DepClassTy::REQUIRED);
if (!PointerInfoAA)
return false;

return PointerInfoAA.forallInterferingAccesses(
return PointerInfoAA->forallInterferingAccesses(
Range, [](const AAPointerInfo::Access &Acc, bool IsExact) {
return Acc.getRemoteInst()->isDroppable();
});
Expand Down Expand Up @@ -696,11 +702,13 @@ struct AAAMDSizeRangeAttribute
LLVM_DEBUG(dbgs() << '[' << getName() << "] Call " << Caller->getName()
<< "->" << getAssociatedFunction()->getName() << '\n');

const auto &CallerInfo = A.getAAFor<AttributeImpl>(
const auto *CallerInfo = A.getAAFor<AttributeImpl>(
*this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
if (!CallerInfo)
return false;

Change |=
clampStateAndIndicateChange(this->getState(), CallerInfo.getState());
clampStateAndIndicateChange(this->getState(), CallerInfo->getState());

return true;
};
Expand Down Expand Up @@ -813,15 +821,17 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
Function *F = getAssociatedFunction();
auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());

const auto &AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
*this, IRPosition::function(*F), DepClassTy::REQUIRED);
unsigned Min, Max;
std::tie(Min, Max) = InfoCache.getWavesPerEU(
*F, {AssumedGroupSize.getAssumed().getLower().getZExtValue(),
AssumedGroupSize.getAssumed().getUpper().getZExtValue() - 1});
if (const auto *AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
*this, IRPosition::function(*F), DepClassTy::REQUIRED)) {

ConstantRange Range(APInt(32, Min), APInt(32, Max + 1));
intersectKnown(Range);
unsigned Min, Max;
std::tie(Min, Max) = InfoCache.getWavesPerEU(
*F, {AssumedGroupSize->getAssumed().getLower().getZExtValue(),
AssumedGroupSize->getAssumed().getUpper().getZExtValue() - 1});

ConstantRange Range(APInt(32, Min), APInt(32, Max + 1));
intersectKnown(Range);
}

if (AMDGPU::isEntryFunctionCC(F->getCallingConv()))
indicatePessimisticFixpoint();
Expand All @@ -837,18 +847,20 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
LLVM_DEBUG(dbgs() << '[' << getName() << "] Call " << Caller->getName()
<< "->" << Func->getName() << '\n');

const auto &CallerInfo = A.getAAFor<AAAMDWavesPerEU>(
const auto *CallerInfo = A.getAAFor<AAAMDWavesPerEU>(
*this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
const auto &AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
const auto *AssumedGroupSize = A.getAAFor<AAAMDFlatWorkGroupSize>(
*this, IRPosition::function(*Func), DepClassTy::REQUIRED);
if (!CallerInfo || !AssumedGroupSize)
return false;

unsigned Min, Max;
std::tie(Min, Max) = InfoCache.getEffectiveWavesPerEU(
*Caller,
{CallerInfo.getAssumed().getLower().getZExtValue(),
CallerInfo.getAssumed().getUpper().getZExtValue() - 1},
{AssumedGroupSize.getAssumed().getLower().getZExtValue(),
AssumedGroupSize.getAssumed().getUpper().getZExtValue() - 1});
{CallerInfo->getAssumed().getLower().getZExtValue(),
CallerInfo->getAssumed().getUpper().getZExtValue() - 1},
{AssumedGroupSize->getAssumed().getLower().getZExtValue(),
AssumedGroupSize->getAssumed().getUpper().getZExtValue() - 1});
ConstantRange CallerRange(APInt(32, Min), APInt(32, Max + 1));
IntegerRangeState CallerRangeState(CallerRange);
Change |= clampStateAndIndicateChange(this->getState(), CallerRangeState);
Expand Down

0 comments on commit e9fc399

Please sign in to comment.