Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,38 @@ __OMP_PROC_BIND_KIND(unknown, 7)

///}


/// Callback information in OpenMP Runtime Functions
///
///{

#define Indices(...) ArrayRef<int>({__VA_ARGS__})

#ifndef OMP_RTL_CB_INFO
#define OMP_RTL_CB_INFO(Enum, Str, ArgNo, ArgIndices, IsVarArg)
#endif

#define __OMP_RTL_CB_INFO(Name, ArgNo, ArgIndices, IsVarArg) \
OMP_RTL_CB_INFO(OMPRTL_##Name, #Name, ArgNo, ArgIndices, IsVarArg)

__OMP_RTL_CB_INFO(__kmpc_distribute_static_loop_4, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_distribute_static_loop_4u, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_distribute_static_loop_8, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_distribute_static_loop_8u, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_distribute_for_static_loop_4, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_distribute_for_static_loop_4u, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_distribute_for_static_loop_8, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_distribute_for_static_loop_8u, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_for_static_loop_4, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_for_static_loop_4u, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_for_static_loop_8, 1, Indices(-1, 2), false)
__OMP_RTL_CB_INFO(__kmpc_for_static_loop_8u, 1, Indices(-1, 2), false)

#undef __OMP_RTL_CB_INFO
#undef OMP_RTL_CB_INFO

///}

/// OpenMP context related definitions:
/// - trait set selector
/// - trait selector
Expand Down
140 changes: 124 additions & 16 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -544,6 +545,24 @@ struct OMPInformationCache : public InformationCache {
collectUses(RFI, /*CollectStats*/ false);
}

void setCallbackMetadata(Function *F, unsigned ArgNo, ArrayRef<int> Indices,
bool IsVarArg) {
if (!F)
return;

LLVMContext &Ctx = F->getContext();
MDBuilder MDB(Ctx);

// Create the new callback encoding for this runtime function
MDNode *NewCallbackEncoding =
MDB.createCallbackEncoding(ArgNo, Indices, IsVarArg);

if (!F->getMetadata(LLVMContext::MD_callback))
// No existing metadata, create new with single entry
F->addMetadata(LLVMContext::MD_callback,
*MDNode::get(Ctx, {NewCallbackEncoding}));
}

// Helper function to recollect uses of all runtime functions.
void recollectUses() {
for (int Idx = 0; Idx < RFIs.size(); ++Idx)
Expand Down Expand Up @@ -627,8 +646,13 @@ struct OMPInformationCache : public InformationCache {
}); \
} \
}
#include "llvm/Frontend/OpenMP/OMPKinds.def"
#define OMP_RTL_CB_INFO(_Enum, _Name, _ArgNo, _ArgIndices, _IsVarArg) \
{ \
Function *F = M.getFunction(_Name); \
setCallbackMetadata(F, _ArgNo, _ArgIndices, _IsVarArg); \
}

#include "llvm/Frontend/OpenMP/OMPKinds.def"
// Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
// functions, except if `optnone` is present.
if (isOpenMPDevice(M)) {
Expand Down Expand Up @@ -4752,6 +4776,60 @@ struct AAKernelInfoFunction : AAKernelInfo {
bool AllSPMDStatesWereFixed = true;
auto CheckCallInst = [&](Instruction &I) {
auto &CB = cast<CallBase>(I);
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
Function *Callee = CB.getCalledFunction();
const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
if (It != OMPInfoCache.RuntimeFunctionIDMap.end()) {
MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
// If this runtime function has callbacks, we need to look at them
// to find potential parallel regions.
if (CallbackMD && CallbackMD->getNumOperands() > 0) {
// TODO: Handle multiple callbacks?
MDNode *OpMD = cast<MDNode>(CallbackMD->getOperand(0).get());
if (OpMD && OpMD->getNumOperands() > 0) {
auto *CBArgCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
const unsigned int ArgNo =
cast<ConstantInt>(CBArgCM->getValue())->getZExtValue();
auto *LoopRegion = dyn_cast<Function>(
CB.getArgOperand(ArgNo)->stripPointerCasts());
// Only analyze the callback if we have a concrete function
// definition. Declarations cannot be analyzed interprocedurally.
if (LoopRegion && !LoopRegion->isDeclaration()) {
LLVM_DEBUG(dbgs() << "[OpenMPOpt] Analyzing callback function: "
<< LoopRegion->getName() << "\n");
auto *FnAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*LoopRegion),
DepClassTy::OPTIONAL);
if (FnAA) {
getState() ^= FnAA->getState();
AllSPMDStatesWereFixed &=
FnAA->SPMDCompatibilityTracker.isAtFixpoint();
AllParallelRegionStatesWereFixed &=
FnAA->ReachedKnownParallelRegions.isAtFixpoint();
AllParallelRegionStatesWereFixed &=
FnAA->ReachedUnknownParallelRegions.isAtFixpoint();
}
} else {
LLVM_DEBUG({
if (LoopRegion && LoopRegion->isDeclaration()) {
dbgs() << "[OpenMPOpt] Skipping callback analysis for "
"declaration-only function: "
<< LoopRegion->getName() << "\n";
}
});
}
#ifndef NDEBUG
// Verify our assumption: if it has callback metadata, we should
// typically be able to resolve the callback function.
if (!LoopRegion) {
LLVM_DEBUG(dbgs() << "[OpenMPOpt] Warning: Could not resolve "
"callback function for runtime call to "
<< Callee->getName() << "\n");
}
#endif
}
}
}
auto *CBAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
if (!CBAA)
Expand Down Expand Up @@ -4923,14 +5001,28 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// won't be any change so we indicate a fixpoint.
indicateOptimisticFixpoint();
}
// If the callee is known and can be used in IPO, we will update the
// state based on the callee state in updateImpl.
return;
}
if (NumCallees > 1) {
indicatePessimisticFixpoint();
return;
// If the callee is known and can be used in IPO, we will update the
// state based on the callee state in updateImpl.
return;
}
// Check if we have multiple possible callees. This usually indicates an
// indirect call where we don't know the target, requiring a pessimistic
// fixpoint. However, for callback functions, multiple edges are expected:
// one to the runtime function and edges through callback parameters. These
// are analyzable, so we exclude them from the pessimistic check.
if (NumCallees > 1 && !Callee->hasMetadata(LLVMContext::MD_callback)) {
LLVM_DEBUG(dbgs() << "[OpenMPOpt] Multiple callees found, forcing "
"pessimistic fixpoint\n");
indicatePessimisticFixpoint();
return;
}
LLVM_DEBUG({
if (NumCallees > 1 && Callee->hasMetadata(LLVMContext::MD_callback)) {
dbgs() << "[OpenMPOpt] Allowing multiple callees for callback "
"function: "
<< Callee->getName() << "\n";
}
});

RuntimeFunction RF = It->getSecond();
switch (RF) {
Expand All @@ -4948,6 +5040,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
case OMPRTL___kmpc_barrier:
case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
case OMPRTL___kmpc_reduction_get_fixed_buffer:
case OMPRTL___kmpc_error:
case OMPRTL___kmpc_flush:
case OMPRTL___kmpc_get_hardware_thread_id_in_block:
Expand Down Expand Up @@ -5043,8 +5136,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// kernel from being SPMD-izable. We mark it as such because we need
// further changes in order to also consider the contents of the
// callbacks passed to them.
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
SPMDCompatibilityTracker.insert(&CB);
// SPMDCompatibilityTracker.indicatePessimisticFixpoint();
// SPMDCompatibilityTracker.insert(&CB);
break;
default:
// Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
Expand Down Expand Up @@ -5092,13 +5185,28 @@ struct AAKernelInfoCallSite : AAKernelInfo {
A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
if (!FnAA)
return indicatePessimisticFixpoint();
if (getState() == FnAA->getState())
return ChangeStatus::UNCHANGED;
getState() = FnAA->getState();
return ChangeStatus::CHANGED;
if (getState() == FnAA->getState())
return ChangeStatus::UNCHANGED;
getState() = FnAA->getState();
return ChangeStatus::CHANGED;
}
// Check if we have multiple possible callees. This usually indicates an
// indirect call where we don't know the target, requiring a pessimistic
// fixpoint. However, for callback functions, multiple edges are expected:
// one to the runtime function and edges through callback parameters. These
// are analyzable, so we exclude them from the pessimistic check.
if (NumCallees > 1 && !F->hasMetadata(LLVMContext::MD_callback)) {
LLVM_DEBUG(dbgs() << "[OpenMPOpt] Multiple callees in update, forcing "
"pessimistic fixpoint\n");
return indicatePessimisticFixpoint();
}
LLVM_DEBUG({
if (NumCallees > 1 && F->hasMetadata(LLVMContext::MD_callback)) {
dbgs() << "[OpenMPOpt] Allowing multiple callees for callback "
"function in update: "
<< F->getName() << "\n";
}
if (NumCallees > 1)
return indicatePessimisticFixpoint();
});

CallBase &CB = cast<CallBase>(getAssociatedValue());
if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
Expand Down
90 changes: 90 additions & 0 deletions llvm/test/Transforms/OpenMP/callback_guards.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
; RUN: opt -passes=openmp-opt -S < %s | FileCheck %s

%struct.ident_t = type { i32, i32, i32, i32, ptr }
%struct.DynamicEnvironmentTy = type { i16 }
%struct.KernelEnvironmentTy = type { %struct.ConfigurationEnvironmentTy, ptr, ptr }
%struct.ConfigurationEnvironmentTy = type { i8, i8, i8, i32, i32, i32, i32, i32, i32 }

@0 = private unnamed_addr addrspace(1) constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
@1 = private unnamed_addr addrspace(1) constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr addrspacecast (ptr addrspace(1) @0 to ptr) }, align 8
@__omp_offloading_10303_1849aab__QQmain_l22_exec_mode = weak protected addrspace(1) constant i8 1
@__omp_offloading_10303_1849aab__QQmain_l22_dynamic_environment = weak_odr protected addrspace(1) global %struct.DynamicEnvironmentTy zeroinitializer
@__omp_offloading_10303_1849aab__QQmain_l22_kernel_environment = weak_odr protected addrspace(1) constant %struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1, i32 1, i32 256, i32 0, i32 0, i32 4, i32 1024 }, ptr addrspacecast (ptr addrspace(1) @1 to ptr), ptr addrspacecast (ptr addrspace(1) @__omp_offloading_10303_1849aab__QQmain_l22_dynamic_environment to ptr) }

; Function Attrs: nounwind
define internal void @parallel_func_..omp_par.3(ptr noalias noundef %tid.addr.ascast, ptr noalias noundef %zero.addr.ascast, ptr %0) #1 {
omp.par.entry:
ret void
}

; Function Attrs: mustprogress
define weak_odr protected amdgpu_kernel void @__omp_offloading_10303_1849aab__QQmain_l22(ptr %0, ptr %1, ptr %2) #4 {
entry:
%7 = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @__omp_offloading_10303_1849aab__QQmain_l22_kernel_environment to ptr), ptr %0)
%exec_user_code = icmp eq i32 %7, -1
br i1 %exec_user_code, label %user_code.entry, label %worker.exit

user_code.entry: ; preds = %entry
call void @__kmpc_distribute_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @1 to ptr), ptr @__omp_offloading_10303_1849aab__QQmain_l22..omp_par, ptr %2, i32 100, i32 0, i8 0)
call void @__kmpc_target_deinit()
br label %worker.exit

worker.exit: ; preds = %entry
ret void
}


define internal void @__omp_offloading_10303_1849aab__QQmain_l22..omp_par(i32 %0, ptr %1) {
omp_loop.body:
%gep = getelementptr { ptr, ptr }, ptr %1, i32 0, i32 1
%p = load ptr, ptr %gep, align 8
%5 = add i32 %0, 1
store i32 %5, ptr %p, align 4
%omp_global_thread_num = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @1 to ptr))
call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @1 to ptr), i32 %omp_global_thread_num, i32 1, i32 -1, i32 -1, ptr @parallel_func_..omp_par.3, ptr @parallel_func_..omp_par.3.wrapper, ptr %1, i64 1)
%6 = load i32, ptr %p, align 4
%7 = add i32 %6, 1
store i32 %7, ptr %p, align 4
ret void
}

define internal void @parallel_func_..omp_par.3.wrapper(i16 noundef zeroext %0, i32 noundef %1) {
entry:
%addr = alloca i32, align 4, addrspace(5)
%addr.ascast = addrspacecast ptr addrspace(5) %addr to ptr
%zero = alloca i32, align 4, addrspace(5)
%zero.ascast = addrspacecast ptr addrspace(5) %zero to ptr
%global_args = alloca ptr, align 8, addrspace(5)
%global_args.ascast = addrspacecast ptr addrspace(5) %global_args to ptr
store i32 %1, ptr %addr.ascast, align 4
store i32 0, ptr %zero.ascast, align 4
call void @__kmpc_get_shared_variables(ptr %global_args.ascast)
%2 = load ptr, ptr %global_args.ascast, align 8
%3 = getelementptr inbounds ptr, ptr %2, i64 0
%structArg = load ptr, ptr %3, align 8
call void @parallel_func_..omp_par.3(ptr %addr.ascast, ptr %zero.ascast, ptr %structArg)
ret void
}


declare void @__kmpc_get_shared_variables(ptr)
declare i32 @__kmpc_target_init(ptr, ptr)
declare noalias ptr @__kmpc_alloc_shared(i64)
declare void @__kmpc_target_deinit()
declare i32 @__kmpc_global_thread_num(ptr)
declare void @__kmpc_parallel_51(ptr, i32, i32, i32, i32, ptr, ptr, ptr, i64)
declare void @__kmpc_distribute_static_loop_4u(ptr, ptr, ptr, i32, i32, i8)

attributes #1 = { nounwind "frame-pointer"="all" }
attributes #4 = { "kernel" }

!llvm.module.flags = !{!0, !1}

!0 = !{i32 7, !"openmp-device", i32 52}
!1 = !{i32 7, !"openmp", i32 52}

; CHECK: @__omp_offloading_{{.*}}_kernel_environment = {{.*}}%struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 0, i8 0, i8 3,
; CHECK: define internal void @__omp_offloading_10303_1849aab__QQmain_l22..omp_par(
; CHECK: region.guarded:
; CHECK: region.guarded{{[0-9]+}}:
; CHECK: ret void
Loading