Skip to content

Commit

Permalink
[OpenMPIRBuilder] Add ThreadLimit and NumTeams clauses to teams const…
Browse files Browse the repository at this point in the history
…ruct (#68364)

This patch adds support for `thread_limit` and bounds on `num_teams`
clause for the teams construct in OpenMP.

Added testcases for the same.
  • Loading branch information
shraiysh committed Oct 11, 2023
1 parent 315ab3c commit e41eaf4
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 2 deletions.
11 changes: 10 additions & 1 deletion llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1917,8 +1917,17 @@ class OpenMPIRBuilder {
///
/// \param Loc The location where the teams construct was encountered.
/// \param BodyGenCB Callback that will generate the region code.
/// \param NumTeamsLower Lower bound on number of teams. If this is nullptr,
/// it is as if lower bound is specified as equal to upperbound. If
/// this is non-null, then upperbound must also be non-null.
/// \param NumTeamsUpper Upper bound on the number of teams.
/// \param ThreadLimit on the number of threads that may participate in a
/// contention group created by each team.
InsertPointTy createTeams(const LocationDescription &Loc,
BodyGenCallbackTy BodyGenCB);
BodyGenCallbackTy BodyGenCB,
Value *NumTeamsLower = nullptr,
Value *NumTeamsUpper = nullptr,
Value *ThreadLimit = nullptr);

/// Generate conditional branch and relevant BasicBlocks through which private
/// threads copy the 'copyin' variables from Master copy to threadprivate
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ __OMP_RTL(__kmpc_cancellationpoint, false, Int32, IdentPtr, Int32, Int32)

__OMP_RTL(__kmpc_fork_teams, true, Void, IdentPtr, Int32, ParallelTaskPtr)
__OMP_RTL(__kmpc_push_num_teams, false, Void, IdentPtr, Int32, Int32, Int32)
__OMP_RTL(__kmpc_push_num_teams_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32)
__OMP_RTL(__kmpc_set_thread_limit, false, Void, IdentPtr, Int32, Int32)

__OMP_RTL(__kmpc_copyprivate, false, Void, IdentPtr, Int32, SizeTy, VoidPtr,
Expand Down
23 changes: 22 additions & 1 deletion llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5733,7 +5733,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
BodyGenCallbackTy BodyGenCB) {
BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
Value *NumTeamsUpper, Value *ThreadLimit) {
if (!updateToLocation(Loc))
return InsertPointTy();

Expand Down Expand Up @@ -5771,6 +5772,26 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
BasicBlock *AllocaBB =
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");

// Push num_teams
if (NumTeamsLower || NumTeamsUpper || ThreadLimit) {
assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
"if lowerbound is non-null, then upperbound must also be non-null "
"for bounds on num_teams");

if (NumTeamsUpper == nullptr)
NumTeamsUpper = Builder.getInt32(0);

if (NumTeamsLower == nullptr)
NumTeamsLower = NumTeamsUpper;

if (ThreadLimit == nullptr)
ThreadLimit = Builder.getInt32(0);

Value *ThreadNum = getOrCreateThreadID(Ident);
Builder.CreateCall(
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
{Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
}
// Generate the body of teams.
InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
Expand Down
210 changes: 210 additions & 0 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4074,6 +4074,216 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
[](Instruction &inst) { return isa<ICmpInst>(&inst); }));
}

TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> &Builder = OMPBuilder.Builder;
Builder.SetInsertPoint(BB);

Function *FakeFunction =
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::ExternalLinkage, "fakeFunction", M.get());

auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
Builder.restoreIP(CodeGenIP);
Builder.CreateCall(FakeFunction, {});
};

// `F` has an argument - an integer, so we use that as the thread limit.
Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
/*NumTeamsUpper=*/nullptr,
/*ThreadLimit=*/F->arg_begin()));

Builder.CreateRetVoid();
OMPBuilder.finalize();

ASSERT_FALSE(verifyModule(*M));

CallInst *PushNumTeamsCallInst =
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
ASSERT_NE(PushNumTeamsCallInst, nullptr);

EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), Builder.getInt32(0));
EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0));
EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), &*F->arg_begin());

// Verifying that the next instruction to execute is kmpc_fork_teams
BranchInst *BrInst =
dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
ASSERT_NE(BrInst, nullptr);
ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
Instruction *NextInstruction =
BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
ASSERT_NE(ForkTeamsCI, nullptr);
EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
}

TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> &Builder = OMPBuilder.Builder;
Builder.SetInsertPoint(BB);

Function *FakeFunction =
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::ExternalLinkage, "fakeFunction", M.get());

auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
Builder.restoreIP(CodeGenIP);
Builder.CreateCall(FakeFunction, {});
};

// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
/*NumTeamsUpper=*/F->arg_begin()));

Builder.CreateRetVoid();
OMPBuilder.finalize();

ASSERT_FALSE(verifyModule(*M));

CallInst *PushNumTeamsCallInst =
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
ASSERT_NE(PushNumTeamsCallInst, nullptr);

EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), &*F->arg_begin());
EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin());
EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), Builder.getInt32(0));

// Verifying that the next instruction to execute is kmpc_fork_teams
BranchInst *BrInst =
dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
ASSERT_NE(BrInst, nullptr);
ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
Instruction *NextInstruction =
BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
ASSERT_NE(ForkTeamsCI, nullptr);
EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
}

TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> &Builder = OMPBuilder.Builder;
Builder.SetInsertPoint(BB);

Function *FakeFunction =
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::ExternalLinkage, "fakeFunction", M.get());

Value *NumTeamsLower =
Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5), "numTeamsLower");
Value *NumTeamsUpper =
Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");

auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
Builder.restoreIP(CodeGenIP);
Builder.CreateCall(FakeFunction, {});
};

// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(
OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));

Builder.CreateRetVoid();
OMPBuilder.finalize();

ASSERT_FALSE(verifyModule(*M));

CallInst *PushNumTeamsCallInst =
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
ASSERT_NE(PushNumTeamsCallInst, nullptr);

EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsLower);
EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), NumTeamsUpper);
EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), Builder.getInt32(0));

// Verifying that the next instruction to execute is kmpc_fork_teams
BranchInst *BrInst =
dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
ASSERT_NE(BrInst, nullptr);
ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
Instruction *NextInstruction =
BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
ASSERT_NE(ForkTeamsCI, nullptr);
EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
}

TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> &Builder = OMPBuilder.Builder;
Builder.SetInsertPoint(BB);

BasicBlock *CodegenBB = splitBB(Builder, true);
Builder.SetInsertPoint(CodegenBB);

// Generate values for `num_teams` and `thread_limit` using the first argument
// of the testing function.
Value *NumTeamsLower =
Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5), "numTeamsLower");
Value *NumTeamsUpper =
Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper");
Value *ThreadLimit =
Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20), "threadLimit");

Function *FakeFunction =
Function::Create(FunctionType::get(Builder.getVoidTy(), false),
GlobalValue::ExternalLinkage, "fakeFunction", M.get());

auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
Builder.restoreIP(CodeGenIP);
Builder.CreateCall(FakeFunction, {});
};

OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
NumTeamsUpper, ThreadLimit));

Builder.CreateRetVoid();
OMPBuilder.finalize();

ASSERT_FALSE(verifyModule(*M));

CallInst *PushNumTeamsCallInst =
findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
ASSERT_NE(PushNumTeamsCallInst, nullptr);

EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsLower);
EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), NumTeamsUpper);
EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), ThreadLimit);

// Verifying that the next instruction to execute is kmpc_fork_teams
BranchInst *BrInst =
dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction());
ASSERT_NE(BrInst, nullptr);
ASSERT_EQ(BrInst->getNumSuccessors(), 1U);
Instruction *NextInstruction =
BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime();
CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction);
ASSERT_NE(ForkTeamsCI, nullptr);
EXPECT_EQ(ForkTeamsCI->getCalledFunction(),
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
}

/// Returns the single instruction of InstTy type in BB that uses the value V.
/// If there is more than one such instruction, returns null.
template <typename InstTy>
Expand Down

0 comments on commit e41eaf4

Please sign in to comment.