Skip to content

Commit

Permalink
[OpenMP] Unified entry point for SPMD & generic kernels in the device…
Browse files Browse the repository at this point in the history
… RTL

In the spirit of TRegions [0], this patch provides a simpler and uniform
interface for a kernel to set up the device runtime. The OMPIRBuilder is
used for reuse in Flang. A custom state machine will be generated in the
follow up patch.

The "surplus" threads of the "master warp" will not exit early anymore
so we need to use non-aligned barriers. The new runtime will not have an
extra warp but also require these non-aligned barriers.

[0] https://link.springer.com/chapter/10.1007/978-3-030-28596-8_11

This was in parts extracted from D59319.

Reviewed By: ABataev, JonChesterfield

Differential Revision: https://reviews.llvm.org/D101976
  • Loading branch information
jdoerfert committed Jul 10, 2021
1 parent 5003ba2 commit 1d5711c
Show file tree
Hide file tree
Showing 13 changed files with 377 additions and 449 deletions.
353 changes: 24 additions & 329 deletions clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Large diffs are not rendered by default.

38 changes: 6 additions & 32 deletions clang/lib/CodeGen/CGOpenMPRuntimeGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,7 @@ class CGOpenMPRuntimeGPU : public CGOpenMPRuntime {
llvm::SmallVector<llvm::Function *, 16> Work;

struct EntryFunctionState {
llvm::BasicBlock *ExitBB = nullptr;
};

class WorkerFunctionState {
public:
llvm::Function *WorkerFn;
const CGFunctionInfo &CGFI;
SourceLocation Loc;

WorkerFunctionState(CodeGenModule &CGM, SourceLocation Loc);

private:
void createWorkerFunction(CodeGenModule &CGM);
};

ExecutionMode getExecutionMode() const;
Expand All @@ -60,20 +48,13 @@ class CGOpenMPRuntimeGPU : public CGOpenMPRuntime {
/// Get barrier to synchronize all threads in a block.
void syncCTAThreads(CodeGenFunction &CGF);

/// Emit the worker function for the current target region.
void emitWorkerFunction(WorkerFunctionState &WST);
/// Helper for target directive initialization.
void emitKernelInit(CodeGenFunction &CGF, EntryFunctionState &EST,
bool IsSPMD);

/// Helper for worker function. Emit body of worker loop.
void emitWorkerLoop(CodeGenFunction &CGF, WorkerFunctionState &WST);

/// Helper for non-SPMD target entry function. Guide the master and
/// worker threads to their respective locations.
void emitNonSPMDEntryHeader(CodeGenFunction &CGF, EntryFunctionState &EST,
WorkerFunctionState &WST);

/// Signal termination of OMP execution for non-SPMD target entry
/// function.
void emitNonSPMDEntryFooter(CodeGenFunction &CGF, EntryFunctionState &EST);
/// Helper for target directive finalization.
void emitKernelDeinit(CodeGenFunction &CGF, EntryFunctionState &EST,
bool IsSPMD);

/// Helper for generic variables globalization prolog.
void emitGenericVarsProlog(CodeGenFunction &CGF, SourceLocation Loc,
Expand All @@ -82,13 +63,6 @@ class CGOpenMPRuntimeGPU : public CGOpenMPRuntime {
/// Helper for generic variables globalization epilog.
void emitGenericVarsEpilog(CodeGenFunction &CGF, bool WithSPMDCheck = false);

/// Helper for SPMD mode target directive's entry function.
void emitSPMDEntryHeader(CodeGenFunction &CGF, EntryFunctionState &EST,
const OMPExecutableDirective &D);

/// Signal termination of SPMD mode execution.
void emitSPMDEntryFooter(CodeGenFunction &CGF, EntryFunctionState &EST);

//
// Base class overrides.
//
Expand Down
23 changes: 23 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,29 @@ class OpenMPIRBuilder {
llvm::ConstantInt *Size,
const llvm::Twine &Name = Twine(""));

/// The `omp target` interface
///
/// For more information about the usage of this interface,
/// \see openmp/libomptarget/deviceRTLs/common/include/target.h
///
///{

/// Create a runtime call for kmpc_target_init
///
/// \param Loc The insert and source location description.
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
/// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime);

/// Create a runtime call for kmpc_target_deinit
///
/// \param Loc The insert and source location description.
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
/// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
void createTargetDeinit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime);

///}

/// Declarations for LLVM-IR types (simple, array, function and structure) are
/// generated below. Their names are defined and used in OpenMPKinds.def. Here
/// we provide the declarations, the initializeTypes function will provide the
Expand Down
6 changes: 2 additions & 4 deletions llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,8 @@ __OMP_RTL(__kmpc_task_allow_completion_event, false, VoidPtr, IdentPtr,
/* Int */ Int32, /* kmp_task_t */ VoidPtr)

/// OpenMP Device runtime functions
__OMP_RTL(__kmpc_kernel_init, false, Void, Int32, Int16)
__OMP_RTL(__kmpc_kernel_deinit, false, Void, Int16)
__OMP_RTL(__kmpc_spmd_kernel_init, false, Void, Int32, Int16)
__OMP_RTL(__kmpc_spmd_kernel_deinit_v2, false, Void, Int16)
__OMP_RTL(__kmpc_target_init, false, Int32, IdentPtr, Int1, Int1, Int1)
__OMP_RTL(__kmpc_target_deinit, false, Void, IdentPtr, Int1, Int1)
__OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr)
__OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32,
VoidPtr, VoidPtr, VoidPtrPtr, SizeTy)
Expand Down
65 changes: 65 additions & 0 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
Expand Down Expand Up @@ -2191,6 +2192,70 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
return Builder.CreateCall(Fn, Args);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime) {
if (!updateToLocation(Loc))
return Loc.IP;

Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
Value *Ident = getOrCreateIdent(SrcLocStr);
ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD);
ConstantInt *UseGenericStateMachine =
ConstantInt::getBool(Int32->getContext(), !IsSPMD);
ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);

Function *Fn = getOrCreateRuntimeFunctionPtr(
omp::RuntimeFunction::OMPRTL___kmpc_target_init);

CallInst *ThreadKind =
Builder.CreateCall(Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal});

Value *ExecUserCode = Builder.CreateICmpEQ(
ThreadKind, ConstantInt::get(ThreadKind->getType(), -1), "exec_user_code");

// ThreadKind = __kmpc_target_init(...)
// if (ThreadKind == -1)
// user_code
// else
// return;

auto *UI = Builder.CreateUnreachable();
BasicBlock *CheckBB = UI->getParent();
BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");

BasicBlock *WorkerExitBB = BasicBlock::Create(
CheckBB->getContext(), "worker.exit", CheckBB->getParent());
Builder.SetInsertPoint(WorkerExitBB);
Builder.CreateRetVoid();

auto *CheckBBTI = CheckBB->getTerminator();
Builder.SetInsertPoint(CheckBBTI);
Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);

CheckBBTI->eraseFromParent();
UI->eraseFromParent();

// Continue in the "user_code" block, see diagram above and in
// openmp/libomptarget/deviceRTLs/common/include/target.h .
return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}

void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
bool IsSPMD, bool RequiresFullRuntime) {
if (!updateToLocation(Loc))
return;

Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
Value *Ident = getOrCreateIdent(SrcLocStr);
ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD);
ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);

Function *Fn = getOrCreateRuntimeFunctionPtr(
omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);

Builder.CreateCall(Fn, {Ident, IsSPMDVal, RequiresFullRuntimeVal});
}

std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
StringRef FirstSeparator,
StringRef Separator) {
Expand Down
49 changes: 18 additions & 31 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
Expand All @@ -37,7 +34,6 @@
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"

using namespace llvm::PatternMatch;
using namespace llvm;
using namespace omp;

Expand Down Expand Up @@ -2341,10 +2337,12 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
AllCallSitesKnown))
SingleThreadedBBs.erase(&F->getEntryBlock());

// Check if the edge into the successor block compares a thread-id function to
// a constant zero.
// TODO: Use AAValueSimplify to simplify and propogate constants.
// TODO: Check more than a single use for thread ID's.
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];

// Check if the edge into the successor block compares the __kmpc_target_init
// result with -1. If we are in non-SPMD-mode that signals only the main
// thread will execute the edge.
auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
if (!Edge || !Edge->isConditional())
return false;
Expand All @@ -2355,31 +2353,20 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
return false;

// Temporarily match the pattern generated by clang for teams regions.
// TODO: Remove this once the new runtime is in place.
ConstantInt *One, *NegOne;
CmpInst::Predicate Pred;
auto &&m_ThreadID = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_tid_x>();
auto &&m_WarpSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_warpsize>();
auto &&m_BlockSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_ntid_x>();
if (match(Cmp, m_Cmp(Pred, m_ThreadID,
m_And(m_Sub(m_BlockSize, m_ConstantInt(One)),
m_Xor(m_Sub(m_WarpSize, m_ConstantInt(One)),
m_ConstantInt(NegOne))))))
if (One->isOne() && NegOne->isMinusOne() &&
Pred == CmpInst::Predicate::ICMP_EQ)
return true;

ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
if (!C || !C->isZero())
if (!C)
return false;

if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
return true;
if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
return true;
// Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
if (C->isAllOnesValue()) {
auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
if (!CB || CB->getCalledFunction() != RFI.Declaration)
return false;
const int InitIsSPMDArgNo = 1;
auto *IsSPMDModeCI =
dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
return IsSPMDModeCI && IsSPMDModeCI->isZero();
}

return false;
};
Expand All @@ -2394,7 +2381,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
PredBB != PredEndBB; ++PredBB) {
if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
BB))
BB))
IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
}

Expand Down
44 changes: 31 additions & 13 deletions llvm/test/Transforms/OpenMP/replace_globalization.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64"

%struct.ident_t = type { i32, i32, i32, i32, i8* }

@S = external local_unnamed_addr global i8*
@0 = private unnamed_addr constant [113 x i8] c";llvm/test/Transforms/OpenMP/custom_state_machines_remarks.c;__omp_offloading_2a_d80d3d_test_fallback_l11;11;1;;\00", align 1
@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([113 x i8], [113 x i8]* @0, i32 0, i32 0) }, align 8

; CHECK-REMARKS: remark: replace_globalization.c:5:7: Replaced globalized variable with 16 bytes of shared memory
; CHECK-REMARKS: remark: replace_globalization.c:5:14: Replaced globalized variable with 4 bytes of shared memory
; CHECK-REMARKS-NOT: 6 bytes

; CHECK: [[SHARED_X:@.+]] = internal addrspace(3) global [16 x i8] undef
; CHECK: [[SHARED_Y:@.+]] = internal addrspace(3) global [4 x i8] undef

Expand All @@ -25,14 +31,15 @@ entry:
define void @bar() {
call void @baz()
call void @qux()
call void @negative_qux_spmd()
ret void
}

; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([16 x i8], [16 x i8] addrspace(3)* [[SHARED_X]], i32 0, i32 0) to i8*))
define internal void @baz() {
entry:
%tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%cmp = icmp eq i32 %tid, 0
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 true)
%cmp = icmp eq i32 %call, -1
br i1 %cmp, label %master, label %exit
master:
%x = call i8* @__kmpc_alloc_shared(i64 16), !dbg !11
Expand All @@ -48,20 +55,30 @@ exit:
; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([4 x i8], [4 x i8] addrspace(3)* [[SHARED_Y]], i32 0, i32 0) to i8*))
define internal void @qux() {
entry:
%tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%warpsize = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
%0 = sub nuw i32 %warpsize, 1
%1 = sub nuw i32 %ntid, 1
%2 = xor i32 %0, -1
%master_tid = and i32 %1, %2
%3 = icmp eq i32 %tid, %master_tid
br i1 %3, label %master, label %exit
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 true, i1 true)
%0 = icmp eq i32 %call, -1
br i1 %0, label %master, label %exit
master:
%y = call i8* @__kmpc_alloc_shared(i64 4), !dbg !12
%y_on_stack = bitcast i8* %y to [4 x i32]*
%4 = bitcast [4 x i32]* %y_on_stack to i8*
call void @use(i8* %4)
%1 = bitcast [4 x i32]* %y_on_stack to i8*
call void @use(i8* %1)
call void @__kmpc_free_shared(i8* %y)
br label %exit
exit:
ret void
}

define internal void @negative_qux_spmd() {
entry:
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 true, i1 true, i1 true)
%0 = icmp eq i32 %call, -1
br i1 %0, label %master, label %exit
master:
%y = call i8* @__kmpc_alloc_shared(i64 6), !dbg !12
%y_on_stack = bitcast i8* %y to [6 x i32]*
%1 = bitcast [6 x i32]* %y_on_stack to i8*
call void @use(i8* %1)
call void @__kmpc_free_shared(i8* %y)
br label %exit
exit:
Expand All @@ -85,6 +102,7 @@ declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()

declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()

declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1)

!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!3, !4, !5, !6}
Expand Down
Loading

0 comments on commit 1d5711c

Please sign in to comment.