Skip to content

Commit

Permalink
[OMPIRBuilder] Support depend clause for task
Browse files Browse the repository at this point in the history
This patch adds support for the `depend` clause for the `task`
construct.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D135695
  • Loading branch information
psoni2628 authored and Prabhdeep Singh Soni (A) committed Oct 19, 2022
1 parent 607be38 commit 6149589
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 35 deletions.
55 changes: 27 additions & 28 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Expand Up @@ -4377,39 +4377,26 @@ CGOpenMPRuntime::emitTaskInit(CodeGenFunction &CGF, SourceLocation Loc,
return Result;
}

namespace {
/// Dependence kinds as encoded for the OpenMP runtime library (RTL). These
/// are the integer values stored into the `flags` field of a dependence
/// record.
enum RTLDependenceKindTy {
  DepIn = 0x01,
  DepInOut = 0x03,
  DepMutexInOutSet = 0x04,
  DepInOutSet = 0x08,
  DepOmpAllMem = 0x80,
};

/// Indices of the fields of the kmp_depend_info record, in declaration order.
enum RTLDependInfoFieldsTy {
  BaseAddr = 0,
  Len = 1,
  Flags = 2,
};
} // namespace

/// Translates internal dependency kind into the runtime kind.
static RTLDependenceKindTy translateDependencyKind(OpenMPDependClauseKind K) {
RTLDependenceKindTy DepKind;
switch (K) {
case OMPC_DEPEND_in:
DepKind = DepIn;
DepKind = RTLDependenceKindTy::DepIn;
break;
// Out and InOut dependencies must use the same code.
case OMPC_DEPEND_out:
case OMPC_DEPEND_inout:
DepKind = DepInOut;
DepKind = RTLDependenceKindTy::DepInOut;
break;
case OMPC_DEPEND_mutexinoutset:
DepKind = DepMutexInOutSet;
DepKind = RTLDependenceKindTy::DepMutexInOutSet;
break;
case OMPC_DEPEND_inoutset:
DepKind = DepInOutSet;
DepKind = RTLDependenceKindTy::DepInOutSet;
break;
case OMPC_DEPEND_outallmemory:
DepKind = DepOmpAllMem;
DepKind = RTLDependenceKindTy::DepOmpAllMem;
break;
case OMPC_DEPEND_source:
case OMPC_DEPEND_sink:
Expand Down Expand Up @@ -4457,7 +4444,9 @@ CGOpenMPRuntime::getDepobjElements(CodeGenFunction &CGF, LValue DepobjLVal,
DepObjAddr, KmpDependInfoTy, Base.getBaseInfo(), Base.getTBAAInfo());
// NumDeps = deps[i].base_addr;
LValue BaseAddrLVal = CGF.EmitLValueForField(
NumDepsBase, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
NumDepsBase,
*std::next(KmpDependInfoRD->field_begin(),
static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
llvm::Value *NumDeps = CGF.EmitLoadOfScalar(BaseAddrLVal, Loc);
return std::make_pair(NumDeps, Base);
}
Expand Down Expand Up @@ -4503,18 +4492,24 @@ static void emitDependData(CodeGenFunction &CGF, QualType &KmpDependInfoTy,
}
// deps[i].base_addr = &<Dependencies[i].second>;
LValue BaseAddrLVal = CGF.EmitLValueForField(
Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
Base,
*std::next(KmpDependInfoRD->field_begin(),
static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
CGF.EmitStoreOfScalar(Addr, BaseAddrLVal);
// deps[i].len = sizeof(<Dependencies[i].second>);
LValue LenLVal = CGF.EmitLValueForField(
Base, *std::next(KmpDependInfoRD->field_begin(), Len));
Base, *std::next(KmpDependInfoRD->field_begin(),
static_cast<unsigned int>(RTLDependInfoFields::Len)));
CGF.EmitStoreOfScalar(Size, LenLVal);
// deps[i].flags = <Dependencies[i].first>;
RTLDependenceKindTy DepKind = translateDependencyKind(Data.DepKind);
LValue FlagsLVal = CGF.EmitLValueForField(
Base, *std::next(KmpDependInfoRD->field_begin(), Flags));
CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind),
FlagsLVal);
Base,
*std::next(KmpDependInfoRD->field_begin(),
static_cast<unsigned int>(RTLDependInfoFields::Flags)));
CGF.EmitStoreOfScalar(
llvm::ConstantInt::get(LLVMFlagsTy, static_cast<unsigned int>(DepKind)),
FlagsLVal);
if (unsigned *P = Pos.dyn_cast<unsigned *>()) {
++(*P);
} else {
Expand Down Expand Up @@ -4790,7 +4785,9 @@ Address CGOpenMPRuntime::emitDepobjDependClause(
LValue Base = CGF.MakeAddrLValue(DependenciesArray, KmpDependInfoTy);
// deps[i].base_addr = NumDependencies;
LValue BaseAddrLVal = CGF.EmitLValueForField(
Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
Base,
*std::next(KmpDependInfoRD->field_begin(),
static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
CGF.EmitStoreOfScalar(NumDepsVal, BaseAddrLVal);
llvm::PointerUnion<unsigned *, LValue *> Pos;
unsigned Idx = 1;
Expand Down Expand Up @@ -4870,9 +4867,11 @@ void CGOpenMPRuntime::emitUpdateClause(CodeGenFunction &CGF, LValue DepobjLVal,
// deps[i].flags = NewDepKind;
RTLDependenceKindTy DepKind = translateDependencyKind(NewDepKind);
LValue FlagsLVal = CGF.EmitLValueForField(
Base, *std::next(KmpDependInfoRD->field_begin(), Flags));
CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind),
FlagsLVal);
Base, *std::next(KmpDependInfoRD->field_begin(),
static_cast<unsigned int>(RTLDependInfoFields::Flags)));
CGF.EmitStoreOfScalar(
llvm::ConstantInt::get(LLVMFlagsTy, static_cast<unsigned int>(DepKind)),
FlagsLVal);

// Shift the address forward by one element.
Address ElementNext =
Expand Down
13 changes: 13 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
Expand Up @@ -207,6 +207,19 @@ enum class OMPInteropType { Unknown, Target, TargetSync };
/// Atomic compare operations. Currently OpenMP only supports ==, >, and <.
enum class OMPAtomicCompareOp : unsigned { EQ, MIN, MAX };

/// Indices of the fields in a kmp_depend_info record, in declaration order.
/// Callers cast an enumerator to `unsigned` to select the matching field.
enum class RTLDependInfoFields {
  BaseAddr = 0,
  Len = 1,
  Flags = 2,
};

/// Dependence kinds as encoded for the OpenMP runtime library (RTL).
/// `DepUnknown` (zero) is the value a default-constructed
/// OpenMPIRBuilder::DependData carries.
enum class RTLDependenceKindTy {
  DepUnknown = 0x00,
  DepIn = 0x01,
  DepInOut = 0x03,
  DepMutexInOutSet = 0x04,
  DepInOutSet = 0x08,
  DepOmpAllMem = 0x80,
};

} // end namespace omp

} // end namespace llvm
Expand Down
14 changes: 13 additions & 1 deletion llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Expand Up @@ -645,6 +645,17 @@ class OpenMPIRBuilder {
/// \param Loc The location where the taskyield directive was encountered.
void createTaskyield(const LocationDescription &Loc);

/// A struct to pack the relevant information for an OpenMP depend clause.
///
/// createTask reads one of these per dependence and writes its contents into
/// one element of the dependence array passed to the runtime.
struct DependData {
  /// Runtime dependence kind; stored into the `Flags` field of the record.
  omp::RTLDependenceKindTy DepKind = omp::RTLDependenceKindTy::DepUnknown;
  /// Type of the depended-on value; its store size is recorded in the `Len`
  /// field of the record.
  Type *DepValueType = nullptr;
  /// Pointer to the depended-on value; converted to an integer and recorded
  /// in the `BaseAddr` field of the record.
  Value *DepVal = nullptr;

  /// Creates an empty entry: kind `DepUnknown` and null type/value. The
  /// pointers are explicitly nulled so that a default-constructed object
  /// never carries indeterminate values.
  explicit DependData() = default;

  /// Creates a fully specified entry.
  ///
  /// \param DepKind      Runtime dependence kind.
  /// \param DepValueType Type of the depended-on value.
  /// \param DepVal       Pointer to the depended-on value.
  DependData(omp::RTLDependenceKindTy DepKind, Type *DepValueType,
             Value *DepVal)
      : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {}
};

/// Generator for `#omp task`
///
/// \param Loc The location where the task construct was encountered.
Expand All @@ -662,7 +673,8 @@ class OpenMPIRBuilder {
InsertPointTy createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied = true, Value *Final = nullptr,
Value *IfCondition = nullptr);
Value *IfCondition = nullptr,
ArrayRef<DependData *> Dependencies = {});

/// Generator for the taskgroup construct
///
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
Expand Up @@ -92,6 +92,7 @@ __OMP_STRUCT_TYPE(OffloadEntry, __tgt_offload_entry, Int8Ptr, Int8Ptr, SizeTy,
__OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, Int32, Int32, VoidPtrPtr,
VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr, Int64)
__OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, Int8Ptr)
/// Dependence record used by createTask for the `depend` clause. The three
/// members are addressed through omp::RTLDependInfoFields (BaseAddr, Len,
/// Flags), in that order.
/// NOTE(review): createTask stores i64 values into the first two members
/// (ptrtoint to i64 and getInt64 of the store size) while they are declared
/// as SizeTy here -- confirm SizeTy is 64 bits wide on every supported target.
__OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, SizeTy, SizeTy, Int8)

#undef __OMP_STRUCT_TYPE
#undef OMP_STRUCT_TYPE
Expand Down
68 changes: 62 additions & 6 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Expand Up @@ -1290,7 +1290,8 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied, Value *Final, Value *IfCondition) {
bool Tied, Value *Final, Value *IfCondition,
ArrayRef<DependData *> Dependencies) {
if (!updateToLocation(Loc))
return InsertPointTy();

Expand Down Expand Up @@ -1322,8 +1323,8 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
OI.EntryBB = TaskAllocaBB;
OI.OuterAllocaBB = AllocaIP.getBlock();
OI.ExitBB = TaskExitBB;
OI.PostOutlineCB = [this, Ident, Tied, Final,
IfCondition](Function &OutlinedFn) {
OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
Dependencies](Function &OutlinedFn) {
// The input IR here looks like the following-
// ```
// func @current_fn() {
Expand Down Expand Up @@ -1433,6 +1434,49 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
TaskSize);
}

Value *DepArrayPtr = nullptr;
if (Dependencies.size()) {
InsertPointTy OldIP = Builder.saveIP();
Builder.SetInsertPoint(
&OldIP.getBlock()->getParent()->getEntryBlock().back());

Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
Value *DepArray =
Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");

unsigned P = 0;
for (DependData *Dep : Dependencies) {
Value *Base =
Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
// Store the pointer to the variable
Value *Addr = Builder.CreateStructGEP(
DependInfo, Base,
static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
Value *DepValPtr =
Builder.CreatePtrToInt(Dep->DepVal, Builder.getInt64Ty());
Builder.CreateStore(DepValPtr, Addr);
// Store the size of the variable
Value *Size = Builder.CreateStructGEP(
DependInfo, Base,
static_cast<unsigned int>(RTLDependInfoFields::Len));
Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize(
Dep->DepValueType)),
Size);
// Store the dependency kind
Value *Flags = Builder.CreateStructGEP(
DependInfo, Base,
static_cast<unsigned int>(RTLDependInfoFields::Flags));
Builder.CreateStore(
ConstantInt::get(Builder.getInt8Ty(),
static_cast<unsigned int>(Dep->DepKind)),
Flags);
++P;
}

DepArrayPtr = Builder.CreateBitCast(DepArray, Builder.getInt8PtrTy());
Builder.restoreIP(OldIP);
}

// In the presence of the `if` clause, the following IR is generated:
// ...
// %data = call @__kmpc_omp_task_alloc(...)
Expand Down Expand Up @@ -1471,9 +1515,21 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
Builder.SetInsertPoint(ThenTI);
}
// Emit the @__kmpc_omp_task runtime call to spawn the task
Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});

if (Dependencies.size()) {
Function *TaskFn =
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
Builder.CreateCall(
TaskFn,
{Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()),
DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0),
ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))});

} else {
// Emit the @__kmpc_omp_task runtime call to spawn the task
Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
}

StaleCI->eraseFromParent();

Expand Down
75 changes: 75 additions & 0 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Expand Up @@ -5092,6 +5092,81 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}

/// Checks the code createTask emits for a task with a single `in` dependence:
/// the task must be spawned through __kmpc_omp_task_with_deps, and the
/// dependence array passed to it must hold the address, size, and kind of the
/// dependence.
TEST_F(OpenMPIRBuilderTest, CreateTaskDepend) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> Builder(BB);
// The task body is empty; the checks below only look at dependence handling.
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
BasicBlock *AllocaBB = Builder.GetInsertBlock();
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
OpenMPIRBuilder::LocationDescription Loc(
InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
// A single i32 alloca is the depended-on variable, tagged as a DepIn
// dependence.
AllocaInst *InDep = Builder.CreateAlloca(Type::getInt32Ty(M->getContext()));
OpenMPIRBuilder::DependData DDIn(RTLDependenceKindTy::DepIn,
Type::getInt32Ty(M->getContext()), InDep);
// createTask takes pointers to DependData, so DDIn must stay alive until
// finalize() has run.
SmallVector<OpenMPIRBuilder::DependData *, 4> DDS;
DDS.push_back(&DDIn);
Builder.restoreIP(OMPBuilder.createTask(
Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB,
/*Tied=*/false, /*Final*/ nullptr, /*IfCondition*/ nullptr, DDS));
OMPBuilder.finalize();
Builder.CreateRetVoid();

// Check for the `NumDeps` argument. Despite its name, `TaskAllocCall` is the
// call to __kmpc_omp_task_with_deps: the last user of that runtime function.
CallInst *TaskAllocCall = dyn_cast<CallInst>(
OMPBuilder
.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps)
->user_back());
ASSERT_NE(TaskAllocCall, nullptr);
// Operand 3 is the dependence count; exactly one dependence was supplied.
ConstantInt *NumDeps = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
ASSERT_NE(NumDeps, nullptr);
EXPECT_EQ(NumDeps->getZExtValue(), 1U);

// Check for the `DepInfo` array argument: operand 4 must be a bitcast of the
// alloca that holds the dependence array.
BitCastInst *DepArrayPtr =
dyn_cast<BitCastInst>(TaskAllocCall->getOperand(4));
ASSERT_NE(DepArrayPtr, nullptr);
AllocaInst *DepArray = dyn_cast<AllocaInst>(DepArrayPtr->getOperand(0));
ASSERT_NE(DepArray, nullptr);
// The first user of the array is expected to be the bitcast above; the next
// user is treated as the address of array element 0, and that element's
// users are then walked in the order Flags, Len, BaseAddr.
// NOTE(review): this depends on use-list iteration order (presumably
// most-recently-added first) -- confirm if the emission order ever changes.
Value::user_iterator DepArrayI = DepArray->user_begin();
EXPECT_EQ(*DepArrayI, DepArrayPtr);
++DepArrayI;
Value::user_iterator DepInfoI = DepArrayI->user_begin();
// Check for the `DependKind` flag in the `DepInfo` array.
// findStoredValue is a helper defined outside this view; presumably it
// returns the value stored through the given GEP.
Value *Flag = findStoredValue<GetElementPtrInst>(*DepInfoI);
ASSERT_NE(Flag, nullptr);
ConstantInt *FlagInt = dyn_cast<ConstantInt>(Flag);
ASSERT_NE(FlagInt, nullptr);
EXPECT_EQ(FlagInt->getZExtValue(),
static_cast<unsigned int>(RTLDependenceKindTy::DepIn));
++DepInfoI;
// Check for the size in the `DepInfo` array: 4 bytes for the i32 dependence.
Value *Size = findStoredValue<GetElementPtrInst>(*DepInfoI);
ASSERT_NE(Size, nullptr);
ConstantInt *SizeInt = dyn_cast<ConstantInt>(Size);
ASSERT_NE(SizeInt, nullptr);
EXPECT_EQ(SizeInt->getZExtValue(), 4U);
++DepInfoI;
// Check for the variable address in the `DepInfo` array: the stored value
// must be a ptrtoint of the original alloca.
Value *AddrStored = findStoredValue<GetElementPtrInst>(*DepInfoI);
ASSERT_NE(AddrStored, nullptr);
PtrToIntInst *AddrInt = dyn_cast<PtrToIntInst>(AddrStored);
ASSERT_NE(AddrInt, nullptr);
Value *Addr = AddrInt->getPointerOperand();
EXPECT_EQ(Addr, InDep);

// The two trailing arguments must be a constant zero count and a null i8*
// pointer.
ConstantInt *NumDepsNoAlias =
dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(5));
ASSERT_NE(NumDepsNoAlias, nullptr);
EXPECT_EQ(NumDepsNoAlias->getZExtValue(), 0U);
EXPECT_EQ(TaskAllocCall->getOperand(6),
ConstantPointerNull::get(Type::getInt8PtrTy(M->getContext())));

// The resulting module must still pass the IR verifier.
EXPECT_FALSE(verifyModule(*M, &errs()));
}

TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
Expand Down

0 comments on commit 6149589

Please sign in to comment.