[AbstractAttributor] Fold function calls to `__kmpc_is_spmd_exec_mode` if possible

In the device runtime there are many calls to `__kmpc_is_spmd_exec_mode`
that query the execution mode of the current kernel. In many cases, user
programs only contain target regions executing in a single mode, so those
runtime calls always return the same value. If we can get rid of these
calls during compilation, it can potentially improve performance.
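
To make the effect concrete, here is a minimal sketch of how such a call is
typically consumed (not from this patch; `runtime_helper` is hypothetical,
and the declaration mirrors the device runtime's `int8_t` query):

```cpp
#include <cstdint>

// OpenMP device runtime query: returns nonzero in SPMD mode.
extern "C" int8_t __kmpc_is_spmd_exec_mode();

void runtime_helper() {
  if (__kmpc_is_spmd_exec_mode()) {
    // SPMD path: every thread executes the region directly.
  } else {
    // Generic-mode path: the main thread drives a state machine.
  }
}
```

If every kernel that can reach `runtime_helper` is known to execute in SPMD
mode, the call folds to `1`, the generic-mode branch becomes dead code, and
later passes can remove it.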

In this patch, we use `AAKernelInfo` to analyze kernel execution. Basically, for
each device function `F`, we collect the set of kernel entries `K` that can
reach `F`. A new AA, `AAFoldRuntimeCall`, is created for each call site. On each
iteration it checks all reaching kernel entries and updates the folded value
accordingly.
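
As a hedged, hypothetical example (the OpenMP source and names below are ours,
not part of the patch), folding is blocked when a shared callee is reachable
from kernels executing in different modes:

```cpp
// `helper` is reachable from two kernel entries in different modes, so any
// __kmpc_is_spmd_exec_mode call reachable through it would have to fold to
// both 1 and 0; AAFoldRuntimeCall then hits a pessimistic fixpoint.
#pragma omp declare target
void helper(int *a, int i) { a[i] += 1; }
#pragma omp end declare target

void run(int *a, int n) {
  // SPMD-amenable kernel entry.
  #pragma omp target teams distribute parallel for
  for (int i = 0; i < n; ++i)
    helper(a, i);

  // Generic-mode kernel entry: a serial target region.
  #pragma omp target
  helper(a, 0);
}
```

If only the first kernel existed, `ReachingKernelEntries` for `helper` would
contain a single SPMD entry and the query would fold to `1` (an `i8` constant,
matching the runtime function's return type).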

In the future we will support more functions.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D105787
shiltian committed Jul 15, 2021
1 parent 4e3dc6b commit ca66229
Showing 3 changed files with 439 additions and 2 deletions.
257 changes: 257 additions & 0 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -497,6 +497,12 @@ struct KernelInfoState : AbstractState {
/// one we abort as the kernel is malformed.
CallBase *KernelDeinitCB = nullptr;

/// Flag to indicate if the associated function is a kernel entry.
bool IsKernelEntry = false;

/// State to track what kernel entries can reach the associated function.
BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

/// Abstract State interface
///{

@@ -537,6 +543,8 @@ struct KernelInfoState : AbstractState {
return false;
if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
return false;
if (ReachingKernelEntries != RHS.ReachingKernelEntries)
return false;
return true;
}

@@ -2729,6 +2737,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
if (!OMPInfoCache.Kernels.count(Fn))
return;

// Add itself to the reaching kernel entries and set IsKernelEntry.
ReachingKernelEntries.insert(Fn);
IsKernelEntry = true;

OMPInformationCache::RuntimeFunctionInfo &InitRFI =
OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
@@ -3213,6 +3225,9 @@ struct AAKernelInfoFunction : AAKernelInfo {
CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
SPMDCompatibilityTracker.indicatePessimisticFixpoint();

if (!IsKernelEntry)
updateReachingKernelEntries(A);

// Callback to check a call instruction.
auto CheckCallInst = [&](Instruction &I) {
auto &CB = cast<CallBase>(I);
@@ -3231,6 +3246,35 @@
return StateBefore == getState() ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
}

private:
/// Update info regarding reaching kernels.
void updateReachingKernelEntries(Attributor &A) {
auto PredCallSite = [&](AbstractCallSite ACS) {
Function *Caller = ACS.getInstruction()->getFunction();

assert(Caller && "Caller is nullptr");

auto &CAA =
A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
if (CAA.ReachingKernelEntries.isValidState()) {
ReachingKernelEntries ^= CAA.ReachingKernelEntries;
return true;
}

// We lost track of the caller of the associated function; any kernel
// could reach it now.
ReachingKernelEntries.indicatePessimisticFixpoint();

return true;
};

bool AllCallSitesKnown;
if (!A.checkForAllCallSites(PredCallSite, *this,
true /* RequireAllCallSites */,
AllCallSitesKnown))
ReachingKernelEntries.indicatePessimisticFixpoint();
}
};

/// The call site kernel info abstract attribute, basically, what can we say
@@ -3377,6 +3421,186 @@ struct AAKernelInfoCallSite : AAKernelInfo {
}
};

struct AAFoldRuntimeCall
: public StateWrapper<BooleanState, AbstractAttribute> {
using Base = StateWrapper<BooleanState, AbstractAttribute>;

AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

/// Statistics are tracked as part of manifest for now.
void trackStatistics() const override {}

/// Create an abstract attribute view for the position \p IRP.
static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
Attributor &A);

/// See AbstractAttribute::getName()
const std::string getName() const override { return "AAFoldRuntimeCall"; }

/// See AbstractAttribute::getIdAddr()
const char *getIdAddr() const override { return &ID; }

/// This function should return true if the type of the \p AA is
/// AAFoldRuntimeCall
static bool classof(const AbstractAttribute *AA) {
return (AA->getIdAddr() == &ID);
}

static const char ID;
};

struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
: AAFoldRuntimeCall(IRP, A) {}

/// See AbstractAttribute::getAsStr()
const std::string getAsStr() const override {
if (!isValidState())
return "<invalid>";

std::string Str("simplified value: ");

if (!SimplifiedValue.hasValue())
return Str + std::string("none");

if (!SimplifiedValue.getValue())
return Str + std::string("nullptr");

if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
return Str + std::to_string(CI->getSExtValue());

return Str + std::string("unknown");
}

void initialize(Attributor &A) override {
Function *Callee = getAssociatedFunction();

auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
"Expected a known OpenMP runtime function");

RFKind = It->getSecond();

CallBase &CB = cast<CallBase>(getAssociatedValue());
A.registerSimplificationCallback(
IRPosition::callsite_returned(CB),
[&](const IRPosition &IRP, const AbstractAttribute *AA,
bool &UsedAssumedInformation) -> Optional<Value *> {
assert((isValidState() || (SimplifiedValue.hasValue() &&
SimplifiedValue.getValue() == nullptr)) &&
"Unexpected invalid state!");

if (!isAtFixpoint()) {
UsedAssumedInformation = true;
if (AA)
A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
}
return SimplifiedValue;
});
}

ChangeStatus updateImpl(Attributor &A) override {
ChangeStatus Changed = ChangeStatus::UNCHANGED;

switch (RFKind) {
case OMPRTL___kmpc_is_spmd_exec_mode:
Changed = Changed | foldIsSPMDExecMode(A);
break;
default:
llvm_unreachable("Unhandled OpenMP runtime function!");
}

return Changed;
}

ChangeStatus manifest(Attributor &A) override {
ChangeStatus Changed = ChangeStatus::UNCHANGED;

if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
Instruction &CB = *getCtxI();
A.changeValueAfterManifest(CB, **SimplifiedValue);
A.deleteAfterManifest(CB);
Changed = ChangeStatus::CHANGED;
}

return Changed;
}

ChangeStatus indicatePessimisticFixpoint() override {
SimplifiedValue = nullptr;
return AAFoldRuntimeCall::indicatePessimisticFixpoint();
}

private:
/// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
ChangeStatus foldIsSPMDExecMode(Attributor &A) {
Optional<Value *> SimplifiedValueBefore = SimplifiedValue;

unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
return indicatePessimisticFixpoint();

for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
DepClassTy::REQUIRED);

if (!AA.isValidState()) {
SimplifiedValue = nullptr;
return indicatePessimisticFixpoint();
}

if (AA.SPMDCompatibilityTracker.isAssumed()) {
if (AA.SPMDCompatibilityTracker.isAtFixpoint())
++KnownSPMDCount;
else
++AssumedSPMDCount;
} else {
if (AA.SPMDCompatibilityTracker.isAtFixpoint())
++KnownNonSPMDCount;
else
++AssumedNonSPMDCount;
}
}

if (KnownSPMDCount && KnownNonSPMDCount)
return indicatePessimisticFixpoint();

if (AssumedSPMDCount && AssumedNonSPMDCount)
return indicatePessimisticFixpoint();

auto &Ctx = getAnchorValue().getContext();
if (KnownSPMDCount || AssumedSPMDCount) {
assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
"Expected only SPMD kernels!");
// All reaching kernels are in SPMD mode. Update all function calls to
// __kmpc_is_spmd_exec_mode to 1.
SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
} else {
assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
"Expected only non-SPMD kernels!");
// All reaching kernels are in non-SPMD mode. Update all function
// calls to __kmpc_is_spmd_exec_mode to 0.
SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
}

return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
}

/// An optional value the associated value is assumed to fold to. That is, we
/// assume the associated value (which is a call) can be replaced by this
/// simplified value.
Optional<Value *> SimplifiedValue;

/// The runtime function kind of the callee of the associated call site.
RuntimeFunction RFKind;
};

} // namespace

void OpenMPOpt::registerAAs(bool IsModulePass) {
@@ -3393,6 +3617,18 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false);

auto &IsSPMDRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode];
IsSPMDRFI.foreachUse(SCC, [&](Use &U, Function &) {
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsSPMDRFI);
if (!CI)
return false;
A.getOrCreateAAFor<AAFoldRuntimeCall>(
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false);
return false;
});
}

// Create CallSite AA for all Getters.
@@ -3436,6 +3672,7 @@ const char AAICVTracker::ID = 0;
const char AAKernelInfo::ID = 0;
const char AAExecutionDomain::ID = 0;
const char AAHeapToShared::ID = 0;
const char AAFoldRuntimeCall::ID = 0;

AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
Attributor &A) {
@@ -3527,6 +3764,26 @@ AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
return *AA;
}

AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
Attributor &A) {
AAFoldRuntimeCall *AA = nullptr;
switch (IRP.getPositionKind()) {
case IRPosition::IRP_INVALID:
case IRPosition::IRP_FLOAT:
case IRPosition::IRP_ARGUMENT:
case IRPosition::IRP_RETURNED:
case IRPosition::IRP_FUNCTION:
case IRPosition::IRP_CALL_SITE:
case IRPosition::IRP_CALL_SITE_ARGUMENT:
llvm_unreachable("KernelInfo can only be created for call site position!");
case IRPosition::IRP_CALL_SITE_RETURNED:
AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
break;
}

return *AA;
}

PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
if (!containsOpenMP(M))
return PreservedAnalyses::all();
4 changes: 2 additions & 2 deletions llvm/test/Transforms/OpenMP/custom_state_machines.ll
@@ -1713,8 +1713,8 @@ attributes #10 = { convergent nounwind readonly willreturn }
; CHECK: if.end:
; CHECK-NEXT: [[TMP1:%.*]] = load i32, i32* [[A_ADDR]], align 4
; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[TMP1]], 1
; CHECK-NEXT: call void @simple_state_machine_interprocedural_nested_recursive_after.internalized(i32 [[SUB]]) #[[ATTR8]]
; CHECK-NEXT: call void @simple_state_machine_interprocedural_nested_recursive_after_after.internalized() #[[ATTR8]]
; CHECK-NEXT: call void @simple_state_machine_interprocedural_nested_recursive_after.internalized(i32 [[SUB]]) #[[ATTR7]]
; CHECK-NEXT: call void @simple_state_machine_interprocedural_nested_recursive_after_after.internalized() #[[ATTR7]]
; CHECK-NEXT: br label [[RETURN]]
; CHECK: return:
; CHECK-NEXT: ret void
