Skip to content

Commit

Permalink
[OpenMP] Codegen support for 'target teams' on the host.
Browse files Browse the repository at this point in the history
This patch adds support for codegen of 'target teams' on the host.
This combined directive has two captured statements, one for the
'target' region, and the other for the 'teams' region.

This target teams region is offloaded using the __tgt_target_teams()
call. The patch sets the number of teams as an argument to
this call.

Reviewers: ABataev
Differential Revision: https://reviews.llvm.org/D29084

llvm-svn: 293001
  • Loading branch information
arpith-jacob committed Jan 25, 2017
1 parent 2f3f985 commit 4dbf368
Show file tree
Hide file tree
Showing 9 changed files with 1,424 additions and 41 deletions.
5 changes: 4 additions & 1 deletion clang/lib/Basic/OpenMPKinds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,8 +875,11 @@ void clang::getOpenMPCaptureRegions(
case OMPD_parallel_sections:
CaptureRegions.push_back(OMPD_parallel);
break;
case OMPD_teams:
case OMPD_target_teams:
CaptureRegions.push_back(OMPD_target);
CaptureRegions.push_back(OMPD_teams);
break;
case OMPD_teams:
case OMPD_simd:
case OMPD_for:
case OMPD_for_simd:
Expand Down
75 changes: 53 additions & 22 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4911,18 +4911,28 @@ emitNumTeamsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
"teams directive expected to be "
"emitted only for the host!");

auto &Bld = CGF.Builder;

// If the target directive is combined with a teams directive:
// Return the value in the num_teams clause, if any.
// Otherwise, return 0 to denote the runtime default.
if (isOpenMPTeamsDirective(D.getDirectiveKind())) {
if (const auto *NumTeamsClause = D.getSingleClause<OMPNumTeamsClause>()) {
CodeGenFunction::RunCleanupsScope NumTeamsScope(CGF);
auto NumTeams = CGF.EmitScalarExpr(NumTeamsClause->getNumTeams(),
/*IgnoreResultAssign*/ true);
return Bld.CreateIntCast(NumTeams, CGF.Int32Ty,
/*IsSigned=*/true);
}

// The default value is 0.
return Bld.getInt32(0);
}

// If the target directive is combined with a parallel directive but not a
// teams directive, start one team.
if (isOpenMPParallelDirective(D.getDirectiveKind()) &&
!isOpenMPTeamsDirective(D.getDirectiveKind()))
return CGF.Builder.getInt32(1);

// FIXME: For the moment we do not support combined directives with target and
// teams, so we do not expect to get any num_teams clause in the provided
// directive. Once we support that, this assertion can be replaced by the
// actual emission of the clause expression.
assert(D.getSingleClause<OMPNumTeamsClause>() == nullptr &&
"Not expecting clause in directive.");
if (isOpenMPParallelDirective(D.getDirectiveKind()))
return Bld.getInt32(1);

// If the current target region has a teams region enclosed, we need to get
// the number of teams to pass to the runtime function call. This is done
Expand All @@ -4940,13 +4950,13 @@ emitNumTeamsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
CGOpenMPInnerExprInfo CGInfo(CGF, CS);
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
llvm::Value *NumTeams = CGF.EmitScalarExpr(NTE->getNumTeams());
return CGF.Builder.CreateIntCast(NumTeams, CGF.Int32Ty,
/*IsSigned=*/true);
return Bld.CreateIntCast(NumTeams, CGF.Int32Ty,
/*IsSigned=*/true);
}

// If we have an enclosed teams directive but no num_teams clause we use
// the default value 0.
return CGF.Builder.getInt32(0);
return Bld.getInt32(0);
}

// No teams associated with the directive.
Expand Down Expand Up @@ -4986,9 +4996,20 @@ emitNumThreadsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
//
// If this is not a teams directive return nullptr.

if (isOpenMPParallelDirective(D.getDirectiveKind())) {
if (isOpenMPTeamsDirective(D.getDirectiveKind()) ||
isOpenMPParallelDirective(D.getDirectiveKind())) {
llvm::Value *DefaultThreadLimitVal = Bld.getInt32(0);
llvm::Value *NumThreadsVal = nullptr;
llvm::Value *ThreadLimitVal = nullptr;

if (const auto *ThreadLimitClause =
D.getSingleClause<OMPThreadLimitClause>()) {
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
auto ThreadLimit = CGF.EmitScalarExpr(ThreadLimitClause->getThreadLimit(),
/*IgnoreResultAssign*/ true);
ThreadLimitVal = Bld.CreateIntCast(ThreadLimit, CGF.Int32Ty,
/*IsSigned=*/true);
}

if (const auto *NumThreadsClause =
D.getSingleClause<OMPNumThreadsClause>()) {
Expand All @@ -5000,15 +5021,21 @@ emitNumThreadsForTargetDirective(CGOpenMPRuntime &OMPRuntime,
Bld.CreateIntCast(NumThreads, CGF.Int32Ty, /*IsSigned=*/true);
}

return NumThreadsVal ? NumThreadsVal : DefaultThreadLimitVal;
}
// Select the lesser of thread_limit and num_threads.
if (NumThreadsVal)
ThreadLimitVal = ThreadLimitVal
? Bld.CreateSelect(Bld.CreateICmpSLT(NumThreadsVal,
ThreadLimitVal),
NumThreadsVal, ThreadLimitVal)
: NumThreadsVal;

// FIXME: For the moment we do not support combined directives with target and
// teams, so we do not expect to get any thread_limit clause in the provided
// directive. Once we support that, this assertion can be replaced by the
// actual emission of the clause expression.
assert(D.getSingleClause<OMPThreadLimitClause>() == nullptr &&
"Not expecting clause in directive.");
// Set default value passed to the runtime if either teams or a target
// parallel type directive is found but no clause is specified.
if (!ThreadLimitVal)
ThreadLimitVal = DefaultThreadLimitVal;

return ThreadLimitVal;
}

// If the current target region has a teams region enclosed, we need to get
// the thread limit to pass to the runtime function call. This is done
Expand Down Expand Up @@ -6217,6 +6244,10 @@ void CGOpenMPRuntime::scanForTargetRegionsFunctions(const Stmt *S,
CodeGenFunction::EmitOMPTargetParallelDeviceFunction(
CGM, ParentName, cast<OMPTargetParallelDirective>(*S));
break;
case Stmt::OMPTargetTeamsDirectiveClass:
CodeGenFunction::EmitOMPTargetTeamsDeviceFunction(
CGM, ParentName, cast<OMPTargetTeamsDirective>(*S));
break;
default:
llvm_unreachable("Unknown target directive for OpenMP device codegen.");
}
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ getExecutionModeForDirective(CodeGenModule &CGM,
OpenMPDirectiveKind DirectiveKind = D.getDirectiveKind();
switch (DirectiveKind) {
case OMPD_target:
case OMPD_target_teams:
return CGOpenMPRuntimeNVPTX::ExecutionMode::Generic;
case OMPD_target_parallel:
return CGOpenMPRuntimeNVPTX::ExecutionMode::Spmd;
Expand Down
65 changes: 52 additions & 13 deletions clang/lib/CodeGen/CGStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,22 @@ class OMPParallelScope final : public OMPLexicalScope {
/*EmitPreInitStmt=*/EmitPreInitStmt(S)) {}
};

/// Lexical scope for OpenMP teams construct, that handles correct codegen
/// for captured expressions.
class OMPTeamsScope final : public OMPLexicalScope {
bool EmitPreInitStmt(const OMPExecutableDirective &S) {
OpenMPDirectiveKind Kind = S.getDirectiveKind();
return !isOpenMPTargetExecutionDirective(Kind) &&
isOpenMPTeamsDirective(Kind);
}

public:
OMPTeamsScope(CodeGenFunction &CGF, const OMPExecutableDirective &S)
: OMPLexicalScope(CGF, S,
/*AsInlined=*/false,
/*EmitPreInitStmt=*/EmitPreInitStmt(S)) {}
};

/// Private scope for OpenMP loop-based directives, that supports capturing
/// of used expression from loop statement.
class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
Expand Down Expand Up @@ -2018,15 +2034,6 @@ void CodeGenFunction::EmitOMPTeamsDistributeParallelForDirective(
});
}

void CodeGenFunction::EmitOMPTargetTeamsDirective(
const OMPTargetTeamsDirective &S) {
CGM.getOpenMPRuntime().emitInlinedDirective(
*this, OMPD_target_teams, [&S](CodeGenFunction &CGF, PrePostActionTy &) {
CGF.EmitStmt(
cast<CapturedStmt>(S.getAssociatedStmt())->getCapturedStmt());
});
}

void CodeGenFunction::EmitOMPTargetTeamsDistributeDirective(
const OMPTargetTeamsDistributeDirective &S) {
CGM.getOpenMPRuntime().emitInlinedDirective(
Expand Down Expand Up @@ -3519,9 +3526,8 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
auto OutlinedFn = CGF.CGM.getOpenMPRuntime().emitTeamsOutlinedFunction(
S, *CS->getCapturedDecl()->param_begin(), InnermostKind, CodeGen);

const OMPTeamsDirective &TD = *dyn_cast<OMPTeamsDirective>(&S);
const OMPNumTeamsClause *NT = TD.getSingleClause<OMPNumTeamsClause>();
const OMPThreadLimitClause *TL = TD.getSingleClause<OMPThreadLimitClause>();
const OMPNumTeamsClause *NT = S.getSingleClause<OMPNumTeamsClause>();
const OMPThreadLimitClause *TL = S.getSingleClause<OMPThreadLimitClause>();
if (NT || TL) {
Expr *NumTeams = (NT) ? NT->getNumTeams() : nullptr;
Expr *ThreadLimit = (TL) ? TL->getThreadLimit() : nullptr;
Expand All @@ -3530,7 +3536,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
S.getLocStart());
}

OMPLexicalScope Scope(CGF, S);
OMPTeamsScope Scope(CGF, S);
llvm::SmallVector<llvm::Value *, 16> CapturedVars;
CGF.GenerateOpenMPCapturedVars(*CS, CapturedVars);
CGF.CGM.getOpenMPRuntime().emitTeamsCall(CGF, S, S.getLocStart(), OutlinedFn,
Expand All @@ -3549,6 +3555,39 @@ void CodeGenFunction::EmitOMPTeamsDirective(const OMPTeamsDirective &S) {
emitCommonOMPTeamsDirective(*this, S, OMPD_teams, CodeGen);
}

/// Emit the body of a 'target teams' region.
///
/// Runs the pre-action \p Action on \p CGF, then emits the statement captured
/// for the 'teams' region of \p S through emitCommonOMPTeamsDirective.
static void emitTargetTeamsRegion(CodeGenFunction &CGF, PrePostActionTy &Action,
                                  const OMPTargetTeamsDirective &S) {
  // Look up the captured statement that belongs to the 'teams' part of the
  // combined directive.
  auto *TeamsCS = S.getCapturedStmt(OMPD_teams);
  Action.Enter(CGF);

  auto &&EmitTeamsBody = [TeamsCS](CodeGenFunction &CGF, PrePostActionTy &) {
    // TODO: Add support for clauses.
    CGF.EmitStmt(TeamsCS->getCapturedStmt());
  };
  emitCommonOMPTeamsDirective(CGF, S, OMPD_teams, EmitTeamsBody);
}

/// Emit the outlined function for the 'target teams' directive \p S as an
/// offload entry.
///
/// \param CGM Module the outlined function is emitted into.
/// \param ParentName Name forwarded to emitTargetOutlinedFunction together
/// with the directive.
/// \param S The 'target teams' directive to emit.
void CodeGenFunction::EmitOMPTargetTeamsDeviceFunction(
    CodeGenModule &CGM, StringRef ParentName,
    const OMPTargetTeamsDirective &S) {
  auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
    emitTargetTeamsRegion(CGF, Action, S);
  };
  // Start from null so the assertion below tests a well-defined value even
  // if the runtime call does not assign one of the results.
  llvm::Function *Fn = nullptr;
  llvm::Constant *Addr = nullptr;
  // Emit target region as a standalone region.
  CGM.getOpenMPRuntime().emitTargetOutlinedFunction(
      S, ParentName, Fn, Addr, /*IsOffloadEntry=*/true, CodeGen);
  assert(Fn && Addr && "Target device function emission failed.");
}

/// Emit the 'target teams' directive \p S.
///
/// Hands emitCommonOMPTargetDirective a region callback that emits the
/// enclosed teams region via emitTargetTeamsRegion.
void CodeGenFunction::EmitOMPTargetTeamsDirective(
    const OMPTargetTeamsDirective &S) {
  auto &&EmitRegion = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
    emitTargetTeamsRegion(CGF, Action, S);
  };
  emitCommonOMPTargetDirective(*this, S, EmitRegion);
}

void CodeGenFunction::EmitOMPCancellationPointDirective(
const OMPCancellationPointDirective &S) {
CGM.getOpenMPRuntime().emitCancellationPointCall(*this, S.getLocStart(),
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2711,6 +2711,9 @@ class CodeGenFunction : public CodeGenTypeCache {
static void
EmitOMPTargetParallelDeviceFunction(CodeGenModule &CGM, StringRef ParentName,
const OMPTargetParallelDirective &S);
/// \brief Emit the outlined function for a 'target teams' directive as an
/// offload entry.
///
/// \param CGM Module used to reach the OpenMP runtime that emits the function.
/// \param ParentName Name forwarded to the runtime together with \p S.
/// \param S The 'target teams' directive to emit.
static void
EmitOMPTargetTeamsDeviceFunction(CodeGenModule &CGM, StringRef ParentName,
const OMPTargetTeamsDirective &S);
/// \brief Emit inner loop of the worksharing/simd construct.
///
/// \param S Directive, for which the inner loop must be emitted.
Expand Down
11 changes: 6 additions & 5 deletions clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1594,8 +1594,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
case OMPD_parallel_for:
case OMPD_parallel_for_simd:
case OMPD_parallel_sections:
case OMPD_teams:
case OMPD_target_teams: {
case OMPD_teams: {
QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1);
QualType KmpInt32PtrTy =
Context.getPointerType(KmpInt32Ty).withConst().withRestrict();
Expand All @@ -1608,6 +1607,7 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
Params);
break;
}
case OMPD_target_teams:
case OMPD_target_parallel: {
Sema::CapturedParamNameType ParamsTarget[] = {
std::make_pair(StringRef(), QualType()) // __context with shared vars
Expand All @@ -1618,14 +1618,15 @@ void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
QualType KmpInt32Ty = Context.getIntTypeForBitwidth(32, 1);
QualType KmpInt32PtrTy =
Context.getPointerType(KmpInt32Ty).withConst().withRestrict();
Sema::CapturedParamNameType ParamsParallel[] = {
Sema::CapturedParamNameType ParamsTeamsOrParallel[] = {
std::make_pair(".global_tid.", KmpInt32PtrTy),
std::make_pair(".bound_tid.", KmpInt32PtrTy),
std::make_pair(StringRef(), QualType()) // __context with shared vars
};
// Start a captured region for 'parallel'.
// Start a captured region for 'teams' or 'parallel'. Both regions have
// the same implicit parameters.
ActOnCapturedRegionStart(DSAStack->getConstructLoc(), CurScope, CR_OpenMP,
ParamsParallel);
ParamsTeamsOrParallel);
break;
}
case OMPD_simd:
Expand Down
Loading

0 comments on commit 4dbf368

Please sign in to comment.