Skip to content

Commit

Permalink
[Flang][OpenMP] Improve support for if clause on combined constructs
Browse files Browse the repository at this point in the history
This patch adds support for matching multiple OpenMP `if` clauses to their
specified directive in a combined construct. It also enables this clause to be
attached by name to `simd` and `teams` directives, in addition to the others
that were already supported.

This patch on its own cannot yet be tested because there is currently no
lowering to MLIR support for any combined construct containing two or more
OpenMP directives that can have an `if` clause attached.

Depends on D155981.

Differential Revision: https://reviews.llvm.org/D156313
  • Loading branch information
skatrak committed Aug 4, 2023
1 parent 27a0a74 commit 65e80d6
Show file tree
Hide file tree
Showing 9 changed files with 885 additions and 49 deletions.
4 changes: 2 additions & 2 deletions flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -3462,8 +3462,8 @@ struct OmpDeviceTypeClause {
// 2.12 if-clause -> IF ([ directive-name-modifier :] scalar-logical-expr)
struct OmpIfClause {
TUPLE_CLASS_BOILERPLATE(OmpIfClause);
ENUM_CLASS(DirectiveNameModifier, Parallel, Target, TargetEnterData,
TargetExitData, TargetData, TargetUpdate, Taskloop, Task)
ENUM_CLASS(DirectiveNameModifier, Parallel, Simd, Target, TargetData,
TargetEnterData, TargetExitData, TargetUpdate, Task, Taskloop, Teams)
std::tuple<std::optional<DirectiveNameModifier>, ScalarLogicalExpr> t;
};

Expand Down
112 changes: 90 additions & 22 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,10 @@ class ClauseProcessor {
bool processCopyin() const;
bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
bool processIf(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool
processIf(Fortran::lower::StatementContext &stmtCtx,
Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
mlir::Value &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processMap(llvm::SmallVectorImpl<mlir::Value> &mapOperands,
Expand Down Expand Up @@ -1049,11 +1051,19 @@ genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
pbKind);
}

static mlir::Value
getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::OmpClause::If *ifClause,
mlir::Location clauseLocation) {
static mlir::Value getIfClauseOperand(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::OmpClause::If *ifClause,
Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
mlir::Location clauseLocation) {
// Only consider the clause if it's intended for the given directive.
auto &directive = std::get<
std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>(
ifClause->v.t);
if (directive && directive.value() != directiveName)
return nullptr;

fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
mlir::Value ifVal = fir::getBase(
Expand Down Expand Up @@ -1572,17 +1582,25 @@ bool ClauseProcessor::processDepend(
});
}

bool ClauseProcessor::processIf(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
return findRepeatableClause<ClauseTy::If>(
bool ClauseProcessor::processIf(
Fortran::lower::StatementContext &stmtCtx,
Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
mlir::Value &result) const {
bool found = false;
findRepeatableClause<ClauseTy::If>(
[&](const ClauseTy::If *ifClause,
const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
// TODO Consider DirectiveNameModifier of the `ifClause` to only search
// for an applicable 'if' clause.
result =
getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation);
mlir::Value operand = getIfClauseOperand(converter, stmtCtx, ifClause,
directiveName, clauseLocation);
// Assume that, at most, a single 'if' clause will be applicable to the
// given directive.
if (operand) {
result = operand;
found = true;
}
});
return found;
}

bool ClauseProcessor::processLink(
Expand Down Expand Up @@ -2109,8 +2127,30 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Location> useDeviceLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;

Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
switch (directive) {
case llvm::omp::Directive::OMPD_target:
directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::Target;
break;
case llvm::omp::Directive::OMPD_target_data:
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData;
break;
case llvm::omp::Directive::OMPD_target_enter_data:
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
break;
case llvm::omp::Directive::OMPD_target_exit_data:
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
break;
default:
TODO(currentLocation, "OMPD_target directive unknown");
break;
}

ClauseProcessor cp(converter, opClauseList);
cp.processIf(stmtCtx, ifClauseOperand);
cp.processIf(stmtCtx, directiveName, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLmtOperand);
cp.processNowait(nowaitAttr);
Expand Down Expand Up @@ -2157,8 +2197,6 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
firOpBuilder.create<mlir::omp::ExitDataOp>(currentLocation, ifClauseOperand,
deviceOperand, nowaitAttr,
mapOperands, mapTypesArrayAttr);
} else {
TODO(currentLocation, "OMPD_target directive unknown");
}
}

Expand Down Expand Up @@ -2256,7 +2294,9 @@ createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
// 1. default
// Note: rest of the clauses are handled when the inner operation is created
ClauseProcessor cp(converter, opClauseList);
cp.processIf(stmtCtx, ifClauseOperand);
cp.processIf(stmtCtx,
Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
ifClauseOperand);
cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
cp.processProcBind(procBindKindAttr);

Expand Down Expand Up @@ -2315,7 +2355,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
cp.processCollapse(currentLocation, eval, lowerBound, upperBound, step, iv,
loopVarTypeSize);
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processIf(stmtCtx, ifClauseOperand);
cp.processIf(stmtCtx,
Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
ifClauseOperand);
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
cp.processSimdlen(simdlenClauseOperand);
cp.processSafelen(safelenClauseOperand);
Expand Down Expand Up @@ -2416,10 +2458,38 @@ genOMP(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> dependTypeOperands, reductionDeclSymbols;
mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;

// Use placeholder value to avoid uninitialized `directiveName` compiler
// errors. The 'if clause' obtained won't be used for these directives.
Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel;
switch (blockDirective.v) {
case llvm::omp::OMPD_parallel:
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel;
break;
case llvm::omp::OMPD_task:
directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::Task;
break;
// Target-related 'if' clauses handled by createTargetOp().
case llvm::omp::OMPD_target:
case llvm::omp::OMPD_target_data:
// These block directives do not accept an 'if' clause.
case llvm::omp::OMPD_master:
case llvm::omp::OMPD_single:
case llvm::omp::OMPD_ordered:
case llvm::omp::OMPD_taskgroup:
break;
default:
TODO(currentLocation,
"Unhandled block directive (" +
llvm::omp::getOpenMPDirectiveName(blockDirective.v) + ")");
break;
}

const auto &opClauseList =
std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
ClauseProcessor cp(converter, opClauseList);
cp.processIf(stmtCtx, ifClauseOperand);
cp.processIf(stmtCtx, directiveName, ifClauseOperand);
cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
cp.processProcBind(procBindKindAttr);
cp.processAllocate(allocatorOperands, allocateOperands);
Expand Down Expand Up @@ -2524,8 +2594,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
} else if (blockDirective.v == llvm::omp::OMPD_target_data) {
createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
&eval);
} else {
TODO(currentLocation, "Unhandled block directive");
}
}

Expand Down
4 changes: 3 additions & 1 deletion flang/lib/Parser/openmp-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ TYPE_PARSER(construct<OmpDeviceTypeClause>(
TYPE_PARSER(construct<OmpIfClause>(
maybe(
("PARALLEL" >> pure(OmpIfClause::DirectiveNameModifier::Parallel) ||
"SIMD" >> pure(OmpIfClause::DirectiveNameModifier::Simd) ||
"TARGET ENTER DATA" >>
pure(OmpIfClause::DirectiveNameModifier::TargetEnterData) ||
"TARGET EXIT DATA" >>
Expand All @@ -125,7 +126,8 @@ TYPE_PARSER(construct<OmpIfClause>(
pure(OmpIfClause::DirectiveNameModifier::TargetUpdate) ||
"TARGET" >> pure(OmpIfClause::DirectiveNameModifier::Target) ||
"TASK"_id >> pure(OmpIfClause::DirectiveNameModifier::Task) ||
"TASKLOOP" >> pure(OmpIfClause::DirectiveNameModifier::Taskloop)) /
"TASKLOOP" >> pure(OmpIfClause::DirectiveNameModifier::Taskloop) ||
"TEAMS" >> pure(OmpIfClause::DirectiveNameModifier::Teams)) /
":"),
scalarLogicalExpr))

Expand Down
22 changes: 19 additions & 3 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2395,19 +2395,35 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Defaultmap &x) {
void OmpStructureChecker::Enter(const parser::OmpClause::If &x) {
CheckAllowed(llvm::omp::Clause::OMPC_if);
using dirNameModifier = parser::OmpIfClause::DirectiveNameModifier;
// TODO Check that, when multiple 'if' clauses are applied to a combined
// construct, at most one of them applies to each directive.
// Need to define set here because llvm::omp::teamSet does not include target
// teams combined constructs.
OmpDirectiveSet teamSet{llvm::omp::Directive::OMPD_target_teams,
llvm::omp::Directive::OMPD_target_teams_distribute,
llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do,
llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do_simd,
llvm::omp::Directive::OMPD_target_teams_distribute_simd,
llvm::omp::Directive::OMPD_teams,
llvm::omp::Directive::OMPD_teams_distribute,
llvm::omp::Directive::OMPD_teams_distribute_parallel_do,
llvm::omp::Directive::OMPD_teams_distribute_parallel_do_simd,
llvm::omp::Directive::OMPD_teams_distribute_simd};
static std::unordered_map<dirNameModifier, OmpDirectiveSet>
dirNameModifierMap{{dirNameModifier::Parallel, llvm::omp::parallelSet},
{dirNameModifier::Simd, llvm::omp::simdSet},
{dirNameModifier::Target, llvm::omp::targetSet},
{dirNameModifier::TargetData,
{llvm::omp::Directive::OMPD_target_data}},
{dirNameModifier::TargetEnterData,
{llvm::omp::Directive::OMPD_target_enter_data}},
{dirNameModifier::TargetExitData,
{llvm::omp::Directive::OMPD_target_exit_data}},
{dirNameModifier::TargetData,
{llvm::omp::Directive::OMPD_target_data}},
{dirNameModifier::TargetUpdate,
{llvm::omp::Directive::OMPD_target_update}},
{dirNameModifier::Task, {llvm::omp::Directive::OMPD_task}},
{dirNameModifier::Taskloop, llvm::omp::taskloopSet}};
{dirNameModifier::Taskloop, llvm::omp::taskloopSet},
{dirNameModifier::Teams, teamSet}};
if (const auto &directiveName{
std::get<std::optional<dirNameModifier>>(x.v.t)}) {
auto search{dirNameModifierMap.find(*directiveName)};
Expand Down
13 changes: 6 additions & 7 deletions flang/lib/Semantics/check-omp-structure.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,21 @@ static OmpDirectiveSet targetSet{Directive::OMPD_target,
Directive::OMPD_target_parallel, Directive::OMPD_target_parallel_do,
Directive::OMPD_target_parallel_do_simd, Directive::OMPD_target_simd,
Directive::OMPD_target_teams, Directive::OMPD_target_teams_distribute,
Directive::OMPD_target_teams_distribute_parallel_do,
Directive::OMPD_target_teams_distribute_parallel_do_simd,
Directive::OMPD_target_teams_distribute_simd};
static OmpDirectiveSet simdSet{Directive::OMPD_distribute_parallel_do_simd,
Directive::OMPD_distribute_simd, Directive::OMPD_parallel_do_simd,
Directive::OMPD_do_simd, Directive::OMPD_simd,
Directive::OMPD_target_parallel_do_simd,
Directive::OMPD_distribute_simd, Directive::OMPD_do_simd,
Directive::OMPD_parallel_do_simd, Directive::OMPD_simd,
Directive::OMPD_target_parallel_do_simd, Directive::OMPD_target_simd,
Directive::OMPD_target_teams_distribute_parallel_do_simd,
Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_simd,
Directive::OMPD_taskloop_simd,
Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_taskloop_simd,
Directive::OMPD_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_simd};
static OmpDirectiveSet teamSet{Directive::OMPD_teams,
Directive::OMPD_teams_distribute,
Directive::OMPD_teams_distribute_parallel_do,
Directive::OMPD_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_parallel_for,
Directive::OMPD_teams_distribute_parallel_for_simd,
Directive::OMPD_teams_distribute_simd};
static OmpDirectiveSet taskGeneratingSet{
OmpDirectiveSet{Directive::OMPD_task} | taskloopSet};
Expand Down
61 changes: 61 additions & 0 deletions flang/test/Parser/OpenMP/if-clause-unparse.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
! RUN: %flang_fc1 -fdebug-unparse-no-sema -fopenmp %s | FileCheck %s
! Check Unparsing of OpenMP IF clause

program if_unparse
logical :: cond
integer :: i

! CHECK: !$OMP TARGET UPDATE
! CHECK-SAME: IF(cond)
!$omp target update if(cond)

! CHECK: !$OMP TARGET UPDATE
! CHECK-SAME: IF(TARGETUPDATE:cond)
!$omp target update if(target update: cond)

! CHECK: !$OMP TARGET UPDATE
! CHECK-SAME: IF(TARGETUPDATE:cond)
!$omp target update if(targetupdate: cond)

! CHECK: !$OMP TARGET ENTER DATA
! CHECK-SAME: IF(TARGETENTERDATA:cond)
!$omp target enter data map(to: i) if(target enter data: cond)

! CHECK: !$OMP TARGET EXIT DATA
! CHECK-SAME: IF(TARGETEXITDATA:cond)
!$omp target exit data map(from: i) if(target exit data: cond)

! CHECK: !$OMP TARGET DATA
! CHECK-SAME: IF(TARGETDATA:cond)
!$omp target data map(tofrom: i) if(target data: cond)
!$omp end target data

! CHECK: !$OMP TARGET
! CHECK-SAME: IF(TARGET:cond)
!$omp target if(target: cond)
!$omp end target

! CHECK: !$OMP TEAMS
! CHECK-SAME: IF(TEAMS:cond)
!$omp teams if(teams: cond)
!$omp end teams

! CHECK: !$OMP PARALLEL DO SIMD
! CHECK-SAME: IF(PARALLEL:i<10) IF(SIMD:.FALSE.)
!$omp parallel do simd if(parallel: i < 10) if(simd: .false.)
do i = 1, 10
end do
!$omp end parallel do simd

! CHECK: !$OMP TASK
! CHECK-SAME: IF(TASK:cond)
!$omp task if(task: cond)
!$omp end task

! CHECK: !$OMP TASKLOOP
! CHECK-SAME: IF(TASKLOOP:cond)
!$omp taskloop if(taskloop: cond)
do i = 1, 10
end do
!$omp end taskloop
end program if_unparse

0 comments on commit 65e80d6

Please sign in to comment.