Skip to content

Commit

Permalink
[flang][OpenMP] Added lowering support for sections construct
Browse files Browse the repository at this point in the history
This patch adds lowering support (from PFT to FIR) for sections construct

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D122302
  • Loading branch information
NimishMishra committed Mar 24, 2022
1 parent 431c142 commit 88d5289
Show file tree
Hide file tree
Showing 2 changed files with 265 additions and 2 deletions.
87 changes: 85 additions & 2 deletions flang/lib/Lower/OpenMP.cpp
Expand Up @@ -96,6 +96,38 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
}
}

static void
genAllocateClause(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpAllocateClause &ompAllocateClause,
SmallVector<Value> &allocatorOperands,
SmallVector<Value> &allocateOperands) {
auto &firOpBuilder = converter.getFirOpBuilder();
auto currentLocation = converter.getCurrentLocation();
Fortran::lower::StatementContext stmtCtx;

mlir::Value allocatorOperand;
const Fortran::parser::OmpObjectList &ompObjectList =
std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
const auto &allocatorValue =
std::get<std::optional<Fortran::parser::OmpAllocateClause::Allocator>>(
ompAllocateClause.t);
// Check if allocate clause has allocator specified. If so, add it
// to list of allocators, otherwise, add default allocator to
// list of allocators.
if (allocatorValue) {
allocatorOperand = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(allocatorValue->v), stmtCtx));
allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
allocatorOperand);
} else {
allocatorOperand = firOpBuilder.createIntegerConstant(
currentLocation, firOpBuilder.getI32Type(), 1);
allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
allocatorOperand);
}
genObjectList(ompObjectList, converter, allocateOperands);
}

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Expand Down Expand Up @@ -262,6 +294,57 @@ genOMP(Fortran::lower::AbstractConverter &converter,
createBodyOfOp<omp::CriticalOp>(criticalOp, firOpBuilder, currentLocation);
}

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {

auto &firOpBuilder = converter.getFirOpBuilder();
auto currentLocation = converter.getCurrentLocation();
mlir::omp::SectionOp sectionOp =
firOpBuilder.create<mlir::omp::SectionOp>(currentLocation);
createBodyOfOp<omp::SectionOp>(sectionOp, firOpBuilder, currentLocation);
}

// TODO: Add support for reduction
static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
auto &firOpBuilder = converter.getFirOpBuilder();
auto currentLocation = converter.getCurrentLocation();
SmallVector<Value> reductionVars, allocateOperands, allocatorOperands;
mlir::UnitAttr noWaitClauseOperand;
const auto &sectionsClauseList = std::get<Fortran::parser::OmpClauseList>(
std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t)
.t);
for (const Fortran::parser::OmpClause &clause : sectionsClauseList.v) {
if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
TODO(currentLocation, "OMPC_Reduction");
} else if (const auto &allocateClause =
std::get_if<Fortran::parser::OmpClause::Allocate>(
&clause.u)) {
genAllocateClause(converter, allocateClause->v, allocatorOperands,
allocateOperands);
}
}
const auto &endSectionsClauseList =
std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
const auto &clauseList =
std::get<Fortran::parser::OmpClauseList>(endSectionsClauseList.t);
for (const auto &clause : clauseList.v) {
if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
noWaitClauseOperand = firOpBuilder.getUnitAttr();
}
}

mlir::omp::SectionsOp sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
currentLocation, reductionVars, /*reductions = */ nullptr,
allocateOperands, allocatorOperands, noWaitClauseOperand);

createBodyOfOp<omp::SectionsOp>(sectionsOp, firOpBuilder, currentLocation);
}

void Fortran::lower::genOpenMPConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Expand All @@ -275,10 +358,10 @@ void Fortran::lower::genOpenMPConstruct(
},
[&](const Fortran::parser::OpenMPSectionsConstruct
&sectionsConstruct) {
TODO(converter.getCurrentLocation(), "OpenMPSectionsConstruct");
genOMP(converter, eval, sectionsConstruct);
},
[&](const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
TODO(converter.getCurrentLocation(), "OpenMPSectionConstruct");
genOMP(converter, eval, sectionConstruct);
},
[&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
TODO(converter.getCurrentLocation(), "OpenMPLoopConstruct");
Expand Down
180 changes: 180 additions & 0 deletions flang/test/Lower/OpenMP/sections.f90
@@ -0,0 +1,180 @@
! This test checks the lowering of OpenMP sections construct with several clauses present

! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefix="FIRDialect"
! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefix="LLVMDialect"
! TODO before (%flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | tco | FileCheck %s --check-prefix="LLVMIR"):
! ensure allocate clause lowering

!FIRDialect: func @_QQmain() {
!FIRDialect: %[[COUNT:.*]] = fir.address_of(@_QFEcount) : !fir.ref<i32>
!FIRDialect: %[[DOUBLE_COUNT:.*]] = fir.address_of(@_QFEdouble_count) : !fir.ref<i32>
!FIRDialect: %[[ETA:.*]] = fir.alloca f32 {bindc_name = "eta", uniq_name = "_QFEeta"}
!FIRDialect: %[[CONST_1:.*]] = arith.constant 1 : i32
!FIRDialect: omp.sections allocate(%[[CONST_1]] : i32 -> %0 : !fir.ref<i32>) {
!FIRDialect: omp.section {
!FIRDialect: {{.*}} = arith.constant 5 : i32
!FIRDialect: fir.store {{.*}} to {{.*}} : !fir.ref<i32>
!FIRDialect: {{.*}} = fir.load %[[COUNT]] : !fir.ref<i32>
!FIRDialect: {{.*}} = fir.load %[[DOUBLE_COUNT]] : !fir.ref<i32>
!FIRDialect: {{.*}} = arith.muli {{.*}}, {{.*}} : i32
!FIRDialect: {{.*}} = fir.convert {{.*}} : (i32) -> f32
!FIRDialect: fir.store {{.*}} to %[[ETA]] : !fir.ref<f32>
!FIRDialect: omp.terminator
!FIRDialect: }
!FIRDialect: omp.section {
!FIRDialect: {{.*}} = fir.load %[[DOUBLE_COUNT]] : !fir.ref<i32>
!FIRDialect: {{.*}} = arith.constant 1 : i32
!FIRDialect: {{.*}} = arith.addi {{.*}} : i32
!FIRDialect: fir.store {{.*}} to %[[DOUBLE_COUNT]] : !fir.ref<i32>
!FIRDialect: omp.terminator
!FIRDialect: }
!FIRDialect: omp.section {
!FIRDialect: {{.*}} = fir.load %[[ETA]] : !fir.ref<f32>
!FIRDialect: {{.*}} = arith.constant 7.000000e+00 : f32
!FIRDialect: {{.*}} = arith.subf {{.*}} : f32
!FIRDialect: fir.store {{.*}} to %[[ETA]] : !fir.ref<f32>
!FIRDialect: {{.*}} = fir.load %[[COUNT]] : !fir.ref<i32>
!FIRDialect: {{.*}} = fir.convert {{.*}} : (i32) -> f32
!FIRDialect: {{.*}} = fir.load %[[ETA]] : !fir.ref<f32>
!FIRDialect: {{.*}} = arith.mulf {{.*}}, {{.*}} : f32
!FIRDialect: {{.*}} = fir.convert {{.*}} : (f32) -> i32
!FIRDialect: fir.store {{.*}} to %[[COUNT]] : !fir.ref<i32>
!FIRDialect: {{.*}} = fir.load %[[COUNT]] : !fir.ref<i32>
!FIRDialect: {{.*}} = fir.convert {{.*}} : (i32) -> f32
!FIRDialect: {{.*}} = fir.load %[[ETA]] : !fir.ref<f32>
!FIRDialect: {{.*}} = arith.subf {{.*}}, {{.*}} : f32
!FIRDialect: {{.*}} = fir.convert {{.*}} : (f32) -> i32
!FIRDialect: fir.store {{.*}} to %[[DOUBLE_COUNT]] : !fir.ref<i32>
!FIRDialect: omp.terminator
!FIRDialect: }
!FIRDialect: omp.terminator
!FIRDialect: }
!FIRDialect: omp.sections nowait {
!FIRDialect: omp.terminator
!FIRDialect: }
!FIRDialect: return
!FIRDialect: }

!LLVMDialect: llvm.func @_QQmain() {
!LLVMDialect: %[[COUNT:.*]] = llvm.mlir.addressof @_QFEcount : !llvm.ptr<i32>
!LLVMDialect: {{.*}} = builtin.unrealized_conversion_cast %[[COUNT]] : !llvm.ptr<i32> to !fir.ref<i32>
!LLVMDialect: %[[DOUBLE_COUNT:.*]] = llvm.mlir.addressof @_QFEdouble_count : !llvm.ptr<i32>
!LLVMDialect: %[[ALLOCATOR:.*]] = llvm.mlir.constant(1 : i64) : i64
!LLVMDialect: %[[ETA:.*]] = llvm.alloca %[[ALLOCATOR]] x f32 {bindc_name = "eta", in_type = f32, operand_segment_sizes = dense<0> : vector<2xi32>, uniq_name = "_QFEeta"} : (i64) -> !llvm.ptr<f32>
!LLVMDialect: %[[CONSTANT:.*]] = llvm.mlir.constant(1 : i32) : i32
!LLVMDialect: omp.sections allocate(%[[CONSTANT]] : i32 -> %1 : !fir.ref<i32>) {
!LLVMDialect: omp.section {
!LLVMDialect: {{.*}} = llvm.mlir.constant(5 : i32) : i32
!LLVMDialect: llvm.store {{.*}}, %[[COUNT]] : !llvm.ptr<i32>
!LLVMDialect: {{.*}} = llvm.load %[[COUNT]] : !llvm.ptr<i32>
!LLVMDialect: {{.*}} = llvm.load %[[DOUBLE_COUNT]] : !llvm.ptr<i32>
!LLVMDialect: {{.*}} = llvm.mul {{.*}}, {{.*}} : i32
!LLVMDialect: {{.*}} = llvm.sitofp {{.*}} : i32 to f32
!LLVMDialect: llvm.store {{.*}}, %[[ETA]] : !llvm.ptr<f32>
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: omp.section {
!LLVMDialect: {{.*}} = llvm.load %[[DOUBLE_COUNT]] : !llvm.ptr<i32>
!LLVMDialect: {{.*}} = llvm.mlir.constant(1 : i32) : i32
!LLVMDialect: {{.*}} = llvm.add {{.*}}, {{.*}} : i32
!LLVMDialect: llvm.store {{.*}}, %[[DOUBLE_COUNT]] : !llvm.ptr<i32>
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: omp.section {
!LLVMDialect: {{.*}} = llvm.load %[[ETA]] : !llvm.ptr<f32>
!LLVMDialect: {{.*}} = llvm.mlir.constant(7.000000e+00 : f32) : f32
!LLVMDialect: {{.*}} = llvm.fsub {{.*}}, {{.*}} : f32
!LLVMDialect: llvm.store {{.*}}, %[[ETA]] : !llvm.ptr<f32>
!LLVMDialect: {{.*}} = llvm.load %[[COUNT]] : !llvm.ptr<i32>
!LLVMDialect: {{.*}} = llvm.sitofp {{.*}} : i32 to f32
!LLVMDialect: {{.*}} = llvm.load %[[ETA]] : !llvm.ptr<f32>
!LLVMDialect: {{.*}} = llvm.fmul {{.*}}, {{.*}} : f32
!LLVMDialect: {{.*}} = llvm.fptosi {{.*}} : f32 to i32
!LLVMDialect: llvm.store {{.*}}, %[[COUNT]] : !llvm.ptr<i32>
!LLVMDialect: {{.*}} = llvm.load %[[COUNT]] : !llvm.ptr<i32>
!LLVMDialect: {{.*}} = llvm.sitofp {{.*}} : i32 to f32
!LLVMDialect: {{.*}} = llvm.load %[[ETA]] : !llvm.ptr<f32>
!LLVMDialect: {{.*}} = llvm.fsub {{.*}}, {{.*}} : f32
!LLVMDialect: {{.*}} = llvm.fptosi {{.*}} : f32 to i32
!LLVMDialect: llvm.store {{.*}}, %[[DOUBLE_COUNT]] : !llvm.ptr<i32>
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: omp.sections nowait {
!LLVMDialect: omp.section {
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: llvm.return
!LLVMDialect: }

program sample
use omp_lib
integer :: count = 0, double_count = 1
!$omp sections private (eta, double_count) allocate(omp_high_bw_mem_alloc: count)
!$omp section
count = 1 + 4
eta = count * double_count
!$omp section
double_count = double_count + 1
!$omp section
eta = eta - 7
count = count * eta
double_count = count - eta
!$omp end sections

!$omp sections
!$omp end sections nowait
end program sample

!FIRDialect: func @_QPfirstprivate(%[[ARG:.*]]: !fir.ref<f32> {fir.bindc_name = "alpha"}) {
!FIRDialect: omp.sections {
!FIRDialect: omp.section {
!FIRDialect: omp.terminator
!FIRDialect: }
!FIRDialect: omp.terminator
!FIRDialect: }
!FIRDialect: omp.sections {
!FIRDialect: omp.section {
!FIRDialect: %[[PRIVATE_VAR:.*]] = fir.load %[[ARG]] : !fir.ref<f32>
!FIRDialect: %[[CONSTANT:.*]] = arith.constant 5.000000e+00 : f32
!FIRDialect: %[[PRIVATE_VAR_2:.*]] = arith.mulf %[[PRIVATE_VAR]], %[[CONSTANT]] : f32
!FIRDialect: fir.store %[[PRIVATE_VAR_2]] to %[[ARG]] : !fir.ref<f32>
!FIRDialect: omp.terminator
!FIRDialect: }
!FIRDialect: omp.terminator
!FIRDialect: }
!FIRDialect: return
!FIRDialect: }

!LLVMDialect: llvm.func @_QPfirstprivate(%[[ARG:.*]]: !llvm.ptr<f32> {fir.bindc_name = "alpha"}) {
!LLVMDialect: omp.sections {
!LLVMDialect: omp.section {
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: omp.sections {
!LLVMDialect: omp.section {
!LLVMDialect: {{.*}} = llvm.load %[[ARG]] : !llvm.ptr<f32>
!LLVMDialect: {{.*}} = llvm.mlir.constant(5.000000e+00 : f32) : f32
!LLVMDialect: {{.*}} = llvm.fmul {{.*}}, {{.*}} : f32
!LLVMDialect: llvm.store {{.*}}, %[[ARG]] : !llvm.ptr<f32>
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: omp.terminator
!LLVMDialect: }
!LLVMDialect: llvm.return
!LLVMDialect: }

subroutine firstprivate(alpha)
real :: alpha
!$omp sections firstprivate(alpha)
!$omp end sections

!$omp sections
alpha = alpha * 5
!$omp end sections
end subroutine

0 comments on commit 88d5289

Please sign in to comment.