Skip to content

Commit

Permalink
[Flang][OpenMP][Lower] Support lowering of teams directive to MLIR
Browse files Browse the repository at this point in the history
This patch adds support for translating `teams` OpenMP directives to MLIR, when
appearing as either loop or block constructs and as part of combined constructs
or on its own.

The current Fortran parser does not allow the specification of the optional
lower bound for the "num_teams" clause, so only the `num_teams_upper` MLIR
argument is set by this patch.

Depends on D156809

Differential Revision: https://reviews.llvm.org/D156884
  • Loading branch information
skatrak committed Aug 15, 2023
1 parent 8c177ae commit 211ed03
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 6 deletions.
59 changes: 55 additions & 4 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,8 @@ class ClauseProcessor {
bool processHint(mlir::IntegerAttr &result) const;
bool processMergeable(mlir::UnitAttr &result) const;
bool processNowait(mlir::UnitAttr &result) const;
bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool processOrdered(mlir::IntegerAttr &result) const;
Expand Down Expand Up @@ -1347,6 +1349,18 @@ bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
return markClauseOccurrence<ClauseTy::Nowait>(result);
}

bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const {
// TODO Get lower and upper bounds for num_teams when parser is updated to
// accept both.
if (auto *numTeamsClause = findUniqueClause<ClauseTy::NumTeams>()) {
result = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(numTeamsClause->v), stmtCtx));
return true;
}
return false;
}

bool ClauseProcessor::processNumThreads(
Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
if (auto *numThreadsClause = findUniqueClause<ClauseTy::NumThreads>()) {
Expand Down Expand Up @@ -2359,6 +2373,40 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
mapOperands, mapTypesArrayAttr);
}

static mlir::omp::TeamsOp
genTeamsOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand;
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;

ClauseProcessor cp(converter, clauseList);
cp.processIf(stmtCtx,
Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processDefault();
cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
cp.processThreadLimit(stmtCtx, threadLimitClauseOperand);
if (cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols))
TODO(currentLocation, "Reduction of TEAMS directive");

return genOpWithBody<mlir::omp::TeamsOp>(
converter, eval, currentLocation, outerCombined, &clauseList,
/*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
threadLimitClauseOperand, allocateOperands, allocatorOperands,
reductionVars,
reductionDeclSymbols.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols));
}

//===----------------------------------------------------------------------===//
// genOMP() Code generation helper functions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2483,7 +2531,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
.test(ompDirective)) {
validDirective = true;
TODO(currentLocation, "Teams construct");
genTeamsOp(converter, eval, currentLocation, loopOpClauseList,
/*outerCombined=*/true);
}
if (llvm::omp::allDistributeSet.test(ompDirective)) {
validDirective = true;
Expand Down Expand Up @@ -2628,7 +2677,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
!std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u)) {
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
Expand Down Expand Up @@ -2667,7 +2717,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
genTaskGroupOp(converter, eval, currentLocation, beginClauseList);
break;
case llvm::omp::Directive::OMPD_teams:
TODO(currentLocation, "Teams construct");
genTeamsOp(converter, eval, currentLocation, beginClauseList,
/*outerCombined=*/false);
break;
case llvm::omp::Directive::OMPD_workshare:
TODO(currentLocation, "Workshare construct");
Expand All @@ -2683,7 +2734,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
}
if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
.test(directive.v)) {
TODO(currentLocation, "Teams construct");
genTeamsOp(converter, eval, currentLocation, beginClauseList);
combinedDirective = true;
}
if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)
Expand Down
12 changes: 12 additions & 0 deletions flang/test/Lower/OpenMP/Todo/reduction-teams.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s

! CHECK: not yet implemented: Reduction of TEAMS directive
subroutine reduction_teams()
integer :: i
i = 0

!$omp teams reduction(+:i)
i = i + 1
!$omp end teams
end subroutine reduction_teams
72 changes: 70 additions & 2 deletions flang/test/Lower/OpenMP/if-clause.f90
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ program main
! - PARALLEL SECTIONS
! - PARALLEL WORKSHARE
! - TARGET PARALLEL
! - TARGET TEAMS
! - TARGET TEAMS DISTRIBUTE
! - TARGET TEAMS DISTRIBUTE PARALLEL DO
! - TARGET TEAMS DISTRIBUTE PARALLEL DO SIMD
! - TARGET TEAMS DISTRIBUTE SIMD
! - TARGET UPDATE
! - TASKLOOP
! - TASKLOOP SIMD
! - TEAMS
! - TEAMS DISTRIBUTE
! - TEAMS DISTRIBUTE PARALLEL DO
! - TEAMS DISTRIBUTE PARALLEL DO SIMD
Expand Down Expand Up @@ -416,6 +414,54 @@ program main
end do
!$omp end target simd

! ----------------------------------------------------------------------------
! TARGET TEAMS
! ----------------------------------------------------------------------------

! CHECK: omp.target
! CHECK-NOT: if({{.*}})
! CHECK-SAME: {
! CHECK: omp.teams
! CHECK-NOT: if({{.*}})
! CHECK-SAME: {
!$omp target teams
i = 1
!$omp end target teams

! CHECK: omp.target
! CHECK-SAME: if({{.*}})
! CHECK: omp.teams
! CHECK-SAME: if({{.*}})
!$omp target teams if(.true.)
i = 1
!$omp end target teams

! CHECK: omp.target
! CHECK-SAME: if({{.*}})
! CHECK: omp.teams
! CHECK-SAME: if({{.*}})
!$omp target teams if(target: .true.) if(teams: .false.)
i = 1
!$omp end target teams

! CHECK: omp.target
! CHECK-SAME: if({{.*}})
! CHECK: omp.teams
! CHECK-NOT: if({{.*}})
! CHECK-SAME: {
!$omp target teams if(target: .true.)
i = 1
!$omp end target teams

! CHECK: omp.target
! CHECK-NOT: if({{.*}})
! CHECK-SAME: {
! CHECK: omp.teams
! CHECK-SAME: if({{.*}})
!$omp target teams if(teams: .true.)
i = 1
!$omp end target teams

! ----------------------------------------------------------------------------
! TASK
! ----------------------------------------------------------------------------
Expand All @@ -434,4 +480,26 @@ program main
! CHECK-SAME: if({{.*}})
!$omp task if(task: .true.)
!$omp end task

! ----------------------------------------------------------------------------
! TEAMS
! ----------------------------------------------------------------------------
! CHECK: omp.teams
! CHECK-NOT: if({{.*}})
! CHECK-SAME: {
!$omp teams
i = 1
!$omp end teams

! CHECK: omp.teams
! CHECK-SAME: if({{.*}})
!$omp teams if(.true.)
i = 1
!$omp end teams

! CHECK: omp.teams
! CHECK-SAME: if({{.*}})
!$omp teams if(teams: .true.)
i = 1
!$omp end teams
end program main
114 changes: 114 additions & 0 deletions flang/test/Lower/OpenMP/teams.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s

! CHECK-LABEL: func @_QPteams_simple
subroutine teams_simple()
! CHECK: omp.teams
!$omp teams
! CHECK: fir.call
call f1()
! CHECK: omp.terminator
!$omp end teams
end subroutine teams_simple

!===============================================================================
! `num_teams` clause
!===============================================================================

! CHECK-LABEL: func @_QPteams_numteams
subroutine teams_numteams(num_teams)
integer, intent(inout) :: num_teams

! CHECK: omp.teams
! CHECK-SAME: num_teams( to %{{.*}}: i32)
!$omp teams num_teams(4)
! CHECK: fir.call
call f1()
! CHECK: omp.terminator
!$omp end teams

! CHECK: omp.teams
! CHECK-SAME: num_teams( to %{{.*}}: i32)
!$omp teams num_teams(num_teams)
! CHECK: fir.call
call f2()
! CHECK: omp.terminator
!$omp end teams

end subroutine teams_numteams

!===============================================================================
! `if` clause
!===============================================================================

! CHECK-LABEL: func @_QPteams_if
subroutine teams_if(alpha)
integer, intent(in) :: alpha
logical :: condition

! CHECK: omp.teams
! CHECK-SAME: if(%{{.*}})
!$omp teams if(.false.)
! CHECK: fir.call
call f1()
! CHECK: omp.terminator
!$omp end teams

! CHECK: omp.teams
! CHECK-SAME: if(%{{.*}})
!$omp teams if(alpha .le. 0)
! CHECK: fir.call
call f2()
! CHECK: omp.terminator
!$omp end teams

! CHECK: omp.teams
! CHECK-SAME: if(%{{.*}})
!$omp teams if(condition)
! CHECK: fir.call
call f3()
! CHECK: omp.terminator
!$omp end teams
end subroutine teams_if

!===============================================================================
! `thread_limit` clause
!===============================================================================

! CHECK-LABEL: func @_QPteams_threadlimit
subroutine teams_threadlimit(thread_limit)
integer, intent(inout) :: thread_limit

! CHECK: omp.teams
! CHECK-SAME: thread_limit(%{{.*}}: i32)
!$omp teams thread_limit(4)
! CHECK: fir.call
call f1()
! CHECK: omp.terminator
!$omp end teams

! CHECK: omp.teams
! CHECK-SAME: thread_limit(%{{.*}}: i32)
!$omp teams thread_limit(thread_limit)
! CHECK: fir.call
call f2()
! CHECK: omp.terminator
!$omp end teams

end subroutine teams_threadlimit

!===============================================================================
! `allocate` clause
!===============================================================================

! CHECK-LABEL: func @_QPteams_allocate
subroutine teams_allocate()
use omp_lib
integer :: x
! CHECK: omp.teams
! CHECK-SAME: allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref<i32>)
!$omp teams allocate(omp_high_bw_mem_alloc: x) private(x)
! CHECK: arith.addi
x = x + 12
! CHECK: omp.terminator
!$omp end teams
end subroutine teams_allocate

0 comments on commit 211ed03

Please sign in to comment.