Skip to content

Commit

Permalink
[OpenMP][Flang][MLIR] Lowering of OpenMP requires directive from pars…
Browse files Browse the repository at this point in the history
…e tree to MLIR

This patch implements the lowering of the OpenMP 'requires' directive
from Flang parse tree to MLIR attributes attached to the top-level
module.

Target-related 'requires' clauses are gathered and combined for each top-level
unit during semantics. Lastly, a single module-level `omp.requires` attribute
is attached to the MLIR module with that information at the end of the process.

The `atomic_default_mem_order` clause is not addressed by this patch, but
rather it will come as a separate patch and follow a different approach.

Depends on D147214, D150328, D150329 and D157983.

Differential Revision: https://reviews.llvm.org/D147218
  • Loading branch information
skatrak committed Sep 14, 2023
1 parent 094a63a commit 29aa749
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 34 deletions.
11 changes: 11 additions & 0 deletions flang/include/flang/Lower/OpenMP.h
Expand Up @@ -34,6 +34,10 @@ struct OmpEndLoopDirective;
struct OmpClauseList;
} // namespace parser

namespace semantics {
class Symbol;
} // namespace semantics

namespace lower {

class AbstractConverter;
Expand Down Expand Up @@ -62,6 +66,13 @@ fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value);
void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
mlir::Value, fir::ConvertOp * = nullptr);
void removeStoreOp(mlir::Operation *, mlir::Value);

bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
Fortran::lower::pft::Evaluation &,
const parser::OpenMPDeclarativeConstruct &);
void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *);

} // namespace lower
} // namespace Fortran

Expand Down
33 changes: 32 additions & 1 deletion flang/lib/Lower/Bridge.cpp
Expand Up @@ -50,6 +50,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Runtime/iostat.h"
#include "flang/Semantics/runtime-type-info.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -294,20 +295,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// that they are available before lowering any function that may use
// them.
bool hasMainProgram = false;
const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
std::visit(Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
if (!globalOmpRequiresSymbol)
globalOmpRequiresSymbol = f.getScope().symbol();
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
for (Fortran::lower::pft::FunctionLikeUnit &f :
m.nestedFunctions)
declareFunction(f);
},
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
globalOmpRequiresSymbol = b.symTab.symbol();
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
},
u);
Expand Down Expand Up @@ -352,6 +359,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
});

finalizeOpenACCLowering();
finalizeOpenMPLowering(globalOmpRequiresSymbol);
}

/// Declare a function.
Expand Down Expand Up @@ -2347,10 +2355,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {

localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);

// Register if a target region was found
ompDeviceCodeFound =
ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
}

void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
// Register if a declare target construct intended for a target device was
// found
ompDeviceCodeFound =
ompDeviceCodeFound ||
Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl);
genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
genFIR(e);
Expand Down Expand Up @@ -4758,6 +4775,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
accRoutineInfos);
}

/// Performing OpenMP lowering actions that were deferred to the end of
/// lowering.
void finalizeOpenMPLowering(
const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
// Set the module attribute related to OpenMP requires directives
if (ompDeviceCodeFound)
Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
globalOmpRequiresSymbol);
}

//===--------------------------------------------------------------------===//

Fortran::lower::LoweringBridge &bridge;
Expand Down Expand Up @@ -4804,6 +4831,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {

/// Deferred OpenACC routine attachment.
Fortran::lower::AccRoutineInfoMappingList accRoutineInfos;

/// Whether an OpenMP target region or declare target function/subroutine
/// intended for device offloading has been detected
bool ompDeviceCodeFound = false;
};

} // namespace
Expand Down
174 changes: 141 additions & 33 deletions flang/lib/Lower/OpenMP.cpp
Expand Up @@ -78,9 +78,7 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
static void gatherFuncAndVarSyms(
const Fortran::parser::OmpObjectList &objList,
mlir::omp::DeclareTargetCaptureClause clause,
llvm::SmallVectorImpl<std::pair<mlir::omp::DeclareTargetCaptureClause,
Fortran::semantics::Symbol>>
&symbolAndClause) {
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
for (const Fortran::parser::OmpObject &ompObject : objList.v) {
Fortran::common::visit(
Fortran::common::visitors{
Expand Down Expand Up @@ -2453,6 +2451,71 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
reductionDeclSymbols));
}

/// Extract the list of function and variable symbols affected by the given
/// 'declare target' directive and return the intended device type for them.
static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {

// The default capture type
mlir::omp::DeclareTargetDeviceType deviceType =
mlir::omp::DeclareTargetDeviceType::any;
const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
declareTargetConstruct.t);

if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
// Case: declare target(func, var1, var2)
gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
symbolAndClause);
} else if (const auto *clauseList{
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
spec.u)}) {
if (clauseList->v.empty()) {
// Case: declare target, implicit capture of function
symbolAndClause.emplace_back(
mlir::omp::DeclareTargetCaptureClause::to,
eval.getOwningProcedure()->getSubprogramSymbol());
}

ClauseProcessor cp(converter, *clauseList);
cp.processTo(symbolAndClause);
cp.processLink(symbolAndClause);
cp.processDeviceType(deviceType);
cp.processTODO<Fortran::parser::OmpClause::Indirect>(
converter.getCurrentLocation(),
llvm::omp::Directive::OMPD_declare_target);
}

return deviceType;
}

static std::optional<mlir::omp::DeclareTargetDeviceType>
getDeclareTargetFunctionDevice(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
converter, eval, declareTargetConstruct, symbolAndClause);

// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
mlir::Operation *op = mod.lookupSymbol(
converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));

if (mlir::isa<mlir::func::FuncOp>(op))
return deviceType;
}

return std::nullopt;
}

//===----------------------------------------------------------------------===//
// genOMP() Code generation helper functions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2973,35 +3036,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
&declareTargetConstruct) {
llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();

// The default capture type
mlir::omp::DeclareTargetDeviceType deviceType =
mlir::omp::DeclareTargetDeviceType::any;
const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
declareTargetConstruct.t);
if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
// Case: declare target(func, var1, var2)
gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
symbolAndClause);
} else if (const auto *clauseList{
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
spec.u)}) {
if (clauseList->v.empty()) {
// Case: declare target, implicit capture of function
symbolAndClause.emplace_back(
mlir::omp::DeclareTargetCaptureClause::to,
eval.getOwningProcedure()->getSubprogramSymbol());
}

ClauseProcessor cp(converter, *clauseList);
cp.processTo(symbolAndClause);
cp.processLink(symbolAndClause);
cp.processDeviceType(deviceType);
cp.processTODO<Fortran::parser::OmpClause::Indirect>(
converter.getCurrentLocation(),
llvm::omp::Directive::OMPD_declare_target);
}
mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
converter, eval, declareTargetConstruct, symbolAndClause);

for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
mlir::Operation *op = mod.lookupSymbol(
Expand Down Expand Up @@ -3130,7 +3166,10 @@ void Fortran::lower::genOpenMPDeclarativeConstruct(
},
[&](const Fortran::parser::OpenMPRequiresConstruct
&requiresConstruct) {
TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct");
// Requires directives are gathered and processed in semantics and
// then combined in the lowering bridge before triggering codegen
// just once. Hence, there is no need to lower each individual
// occurrence here.
},
[&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
// The directive is lowered when instantiating the variable to
Expand Down Expand Up @@ -3444,3 +3483,72 @@ void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
}
}
}

bool Fortran::lower::isOpenMPTargetConstruct(
const Fortran::parser::OpenMPConstruct &omp) {
llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;
if (const auto *block =
std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u)) {
const auto &begin =
std::get<Fortran::parser::OmpBeginBlockDirective>(block->t);
dir = std::get<Fortran::parser::OmpBlockDirective>(begin.t).v;
} else if (const auto *loop =
std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)) {
const auto &begin =
std::get<Fortran::parser::OmpBeginLoopDirective>(loop->t);
dir = std::get<Fortran::parser::OmpLoopDirective>(begin.t).v;
}
return llvm::omp::allTargetSet.test(dir);
}

bool Fortran::lower::isOpenMPDeviceDeclareTarget(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
return std::visit(
Fortran::common::visitors{
[&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) {
mlir::omp::DeclareTargetDeviceType targetType =
getDeclareTargetFunctionDevice(converter, eval, ompReq)
.value_or(mlir::omp::DeclareTargetDeviceType::host);
return targetType != mlir::omp::DeclareTargetDeviceType::host;
},
[&](const auto &) { return false; },
},
ompDecl.u);
}

void Fortran::lower::genOpenMPRequires(
mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) {
using MlirRequires = mlir::omp::ClauseRequires;
using SemaRequires = Fortran::semantics::WithOmpDeclarative::RequiresFlag;

if (auto offloadMod =
llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) {
Fortran::semantics::WithOmpDeclarative::RequiresFlags semaFlags;
if (symbol) {
Fortran::common::visit(
[&](const auto &details) {
if constexpr (std::is_base_of_v<
Fortran::semantics::WithOmpDeclarative,
std::decay_t<decltype(details)>>) {
if (details.has_ompRequires())
semaFlags = *details.ompRequires();
}
},
symbol->details());
}

MlirRequires mlirFlags = MlirRequires::none;
if (semaFlags.test(SemaRequires::ReverseOffload))
mlirFlags = mlirFlags | MlirRequires::reverse_offload;
if (semaFlags.test(SemaRequires::UnifiedAddress))
mlirFlags = mlirFlags | MlirRequires::unified_address;
if (semaFlags.test(SemaRequires::UnifiedSharedMemory))
mlirFlags = mlirFlags | MlirRequires::unified_shared_memory;
if (semaFlags.test(SemaRequires::DynamicAllocators))
mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;

offloadMod.setRequires(mlirFlags);
}
}
23 changes: 23 additions & 0 deletions flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90
@@ -0,0 +1,23 @@
! This test checks the lowering of REQUIRES inside of an unnamed BLOCK DATA.
! The symbol of the `symTab` scope of the `BlockDataUnit` PFT node is null in
! this case, resulting in the inability to store the REQUIRES flags gathered in
! it.

! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
! XFAIL: *

!CHECK: module attributes {
!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
block data
!$omp requires unified_shared_memory
integer :: x
common /block/ x
data x / 10 /
end

subroutine f
!$omp declare target
end subroutine f
19 changes: 19 additions & 0 deletions flang/test/Lower/OpenMP/requires-common.f90
@@ -0,0 +1,19 @@
! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s

! This test checks the lowering of requires into MLIR

!CHECK: module attributes {
!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
block data init
!$omp requires unified_shared_memory
integer :: x
common /block/ x
data x / 10 /
end

subroutine f
!$omp declare target
end subroutine f
14 changes: 14 additions & 0 deletions flang/test/Lower/OpenMP/requires-notarget.f90
@@ -0,0 +1,14 @@
! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s

! This test checks that requires lowering into MLIR skips creating the
! omp.requires attribute with target-related clauses if there are no device
! functions in the compilation unit

!CHECK: module attributes {
!CHECK-NOT: omp.requires
program requires
!$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
end program requires
14 changes: 14 additions & 0 deletions flang/test/Lower/OpenMP/requires.f90
@@ -0,0 +1,14 @@
! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s

! This test checks the lowering of requires into MLIR

!CHECK: module attributes {
!CHECK-SAME: omp.requires = #omp<clause_requires reverse_offload|unified_shared_memory>
program requires
!$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
!$omp target
!$omp end target
end program requires

0 comments on commit 29aa749

Please sign in to comment.