diff --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt index b13d415e02f1d..5577a60f1daea 100644 --- a/flang/lib/Lower/CMakeLists.txt +++ b/flang/lib/Lower/CMakeLists.txt @@ -24,7 +24,11 @@ add_flang_library(FortranLower LoweringOptions.cpp Mangler.cpp OpenACC.cpp - OpenMP.cpp + OpenMP/ClauseProcessor.cpp + OpenMP/DataSharingProcessor.cpp + OpenMP/OpenMP.cpp + OpenMP/ReductionProcessor.cpp + OpenMP/Utils.cpp PFTBuilder.cpp Runtime.cpp SymbolMap.cpp diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp new file mode 100644 index 0000000000000..4e3951492fb65 --- /dev/null +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -0,0 +1,880 @@ +//===-- ClauseProcessor.cpp -------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "ClauseProcessor.h" + +#include "flang/Lower/PFTBuilder.h" +#include "flang/Parser/tools.h" +#include "flang/Semantics/tools.h" + +namespace Fortran { +namespace lower { +namespace omp { + +/// Check for unsupported map operand types. +static void checkMapType(mlir::Location location, mlir::Type type) { + if (auto refType = type.dyn_cast()) + type = refType.getElementType(); + if (auto boxType = type.dyn_cast_or_null()) + if (!boxType.getElementType().isa()) + TODO(location, "OMPD_target_data MapOperand BoxType"); +} + +static mlir::omp::ScheduleModifier +translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) { + switch (m.v) { + case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic: + return mlir::omp::ScheduleModifier::monotonic; + case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic: + return mlir::omp::ScheduleModifier::nonmonotonic; + case Fortran::parser::OmpScheduleModifierType::ModType::Simd: + return mlir::omp::ScheduleModifier::simd; + } + return mlir::omp::ScheduleModifier::none; +} + +static mlir::omp::ScheduleModifier +getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) { + const auto &modifier = + std::get>(x.t); + // The input may have the modifier any order, so we look for one that isn't + // SIMD. If modifier is not set at all, fall down to the bottom and return + // "none". + if (modifier) { + const auto &modType1 = + std::get(modifier->t); + if (modType1.v.v == + Fortran::parser::OmpScheduleModifierType::ModType::Simd) { + const auto &modType2 = std::get< + std::optional>( + modifier->t); + if (modType2 && + modType2->v.v != + Fortran::parser::OmpScheduleModifierType::ModType::Simd) + return translateScheduleModifier(modType2->v); + + return mlir::omp::ScheduleModifier::none; + } + + return translateScheduleModifier(modType1.v); + } + return mlir::omp::ScheduleModifier::none; +} + +static mlir::omp::ScheduleModifier +getSimdModifier(const Fortran::parser::OmpScheduleClause &x) { + const auto &modifier = + std::get>(x.t); + // Either of the two possible modifiers in the input can be the SIMD modifier, + // so look in either one, and return simd if we find one. Not found = return + // "none". + if (modifier) { + const auto &modType1 = + std::get(modifier->t); + if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd) + return mlir::omp::ScheduleModifier::simd; + + const auto &modType2 = std::get< + std::optional>( + modifier->t); + if (modType2 && modType2->v.v == + Fortran::parser::OmpScheduleModifierType::ModType::Simd) + return mlir::omp::ScheduleModifier::simd; + } + return mlir::omp::ScheduleModifier::none; +} + +static void +genAllocateClause(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpAllocateClause &ompAllocateClause, + llvm::SmallVectorImpl &allocatorOperands, + llvm::SmallVectorImpl &allocateOperands) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + Fortran::lower::StatementContext stmtCtx; + + mlir::Value allocatorOperand; + const Fortran::parser::OmpObjectList &ompObjectList = + std::get(ompAllocateClause.t); + const auto &allocateModifier = std::get< + std::optional>( + ompAllocateClause.t); + + // If the allocate modifier is present, check if we only use the allocator + // submodifier. ALIGN in this context is unimplemented + const bool onlyAllocator = + allocateModifier && + std::holds_alternative< + Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( + allocateModifier->u); + + if (allocateModifier && !onlyAllocator) { + TODO(currentLocation, "OmpAllocateClause ALIGN modifier"); + } + + // Check if allocate clause has allocator specified. If so, add it + // to list of allocators, otherwise, add default allocator to + // list of allocators. + if (onlyAllocator) { + const auto &allocatorValue = std::get< + Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( + allocateModifier->u); + allocatorOperand = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx)); + allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), + allocatorOperand); + } else { + allocatorOperand = firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getI32Type(), 1); + allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), + allocatorOperand); + } + genObjectList(ompObjectList, converter, allocateOperands); +} + +static mlir::omp::ClauseProcBindKindAttr genProcBindKindAttr( + fir::FirOpBuilder &firOpBuilder, + const Fortran::parser::OmpClause::ProcBind *procBindClause) { + mlir::omp::ClauseProcBindKind procBindKind; + switch (procBindClause->v.v) { + case Fortran::parser::OmpProcBindClause::Type::Master: + procBindKind = mlir::omp::ClauseProcBindKind::Master; + break; + case Fortran::parser::OmpProcBindClause::Type::Close: + procBindKind = mlir::omp::ClauseProcBindKind::Close; + break; + case Fortran::parser::OmpProcBindClause::Type::Spread: + procBindKind = mlir::omp::ClauseProcBindKind::Spread; + break; + case Fortran::parser::OmpProcBindClause::Type::Primary: + procBindKind = mlir::omp::ClauseProcBindKind::Primary; + break; + } + return mlir::omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), + procBindKind); +} + +static mlir::omp::ClauseTaskDependAttr +genDependKindAttr(fir::FirOpBuilder &firOpBuilder, + const Fortran::parser::OmpClause::Depend *dependClause) { + mlir::omp::ClauseTaskDepend pbKind; + switch ( + std::get( + std::get(dependClause->v.u) + .t) + .v) { + case Fortran::parser::OmpDependenceType::Type::In: + pbKind = mlir::omp::ClauseTaskDepend::taskdependin; + break; + case Fortran::parser::OmpDependenceType::Type::Out: + pbKind = mlir::omp::ClauseTaskDepend::taskdependout; + break; + case Fortran::parser::OmpDependenceType::Type::Inout: + pbKind = mlir::omp::ClauseTaskDepend::taskdependinout; + break; + default: + llvm_unreachable("unknown parser task dependence type"); + break; + } + return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(), + pbKind); +} + +static mlir::Value getIfClauseOperand( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpClause::If *ifClause, + Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, + mlir::Location clauseLocation) { + // Only consider the clause if it's intended for the given directive. + auto &directive = std::get< + std::optional>( + ifClause->v.t); + if (directive && directive.value() != directiveName) + return nullptr; + + Fortran::lower::StatementContext stmtCtx; + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + auto &expr = std::get(ifClause->v.t); + mlir::Value ifVal = fir::getBase( + converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), + ifVal); +} + +static void +addUseDeviceClause(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpObjectList &useDeviceClause, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl + &useDeviceSymbols) { + genObjectList(useDeviceClause, converter, operands); + for (mlir::Value &operand : operands) { + checkMapType(operand.getLoc(), operand.getType()); + useDeviceTypes.push_back(operand.getType()); + useDeviceLocs.push_back(operand.getLoc()); + } + for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + useDeviceSymbols.push_back(sym); + } +} + +//===----------------------------------------------------------------------===// +// ClauseProcessor unique clauses +//===----------------------------------------------------------------------===// + +bool ClauseProcessor::processCollapse( + mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, + llvm::SmallVectorImpl &lowerBound, + llvm::SmallVectorImpl &upperBound, + llvm::SmallVectorImpl &step, + llvm::SmallVectorImpl &iv, + std::size_t &loopVarTypeSize) const { + bool found = false; + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + // Collect the loops to collapse. + Fortran::lower::pft::Evaluation *doConstructEval = + &eval.getFirstNestedEvaluation(); + if (doConstructEval->getIf() + ->IsDoConcurrent()) { + TODO(currentLocation, "Do Concurrent in Worksharing loop construct"); + } + + std::int64_t collapseValue = 1l; + if (auto *collapseClause = findUniqueClause()) { + const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); + collapseValue = Fortran::evaluate::ToInt64(*expr).value(); + found = true; + } + + loopVarTypeSize = 0; + do { + Fortran::lower::pft::Evaluation *doLoop = + &doConstructEval->getFirstNestedEvaluation(); + auto *doStmt = doLoop->getIf(); + assert(doStmt && "Expected do loop to be in the nested evaluation"); + const auto &loopControl = + std::get>(doStmt->t); + const Fortran::parser::LoopControl::Bounds *bounds = + std::get_if(&loopControl->u); + assert(bounds && "Expected bounds for worksharing do loop"); + Fortran::lower::StatementContext stmtCtx; + lowerBound.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); + upperBound.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); + if (bounds->step) { + step.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); + } else { // If `step` is not present, assume it as `1`. + step.push_back(firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getIntegerType(32), 1)); + } + iv.push_back(bounds->name.thing.symbol); + loopVarTypeSize = std::max(loopVarTypeSize, + bounds->name.thing.symbol->GetUltimate().size()); + collapseValue--; + doConstructEval = + &*std::next(doConstructEval->getNestedEvaluations().begin()); + } while (collapseValue > 0); + + return found; +} + +bool ClauseProcessor::processDefault() const { + if (auto *defaultClause = findUniqueClause()) { + // Private, Firstprivate, Shared, None + switch (defaultClause->v.v) { + case Fortran::parser::OmpDefaultClause::Type::Shared: + case Fortran::parser::OmpDefaultClause::Type::None: + // Default clause with shared or none do not require any handling since + // Shared is the default behavior in the IR and None is only required + // for semantic checks. + break; + case Fortran::parser::OmpDefaultClause::Type::Private: + // TODO Support default(private) + break; + case Fortran::parser::OmpDefaultClause::Type::Firstprivate: + // TODO Support default(firstprivate) + break; + } + return true; + } + return false; +} + +bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const { + const Fortran::parser::CharBlock *source = nullptr; + if (auto *deviceClause = findUniqueClause(&source)) { + mlir::Location clauseLocation = converter.genLocation(*source); + if (auto deviceModifier = std::get< + std::optional>( + deviceClause->v.t)) { + if (deviceModifier == + Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) { + TODO(clauseLocation, "OMPD_target Device Modifier Ancestor"); + } + } + if (const auto *deviceExpr = Fortran::semantics::GetExpr( + std::get(deviceClause->v.t))) { + result = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx)); + } + return true; + } + return false; +} + +bool ClauseProcessor::processDeviceType( + mlir::omp::DeclareTargetDeviceType &result) const { + if (auto *deviceTypeClause = findUniqueClause()) { + // Case: declare target ... device_type(any | host | nohost) + switch (deviceTypeClause->v.v) { + case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: + result = mlir::omp::DeclareTargetDeviceType::nohost; + break; + case Fortran::parser::OmpDeviceTypeClause::Type::Host: + result = mlir::omp::DeclareTargetDeviceType::host; + break; + case Fortran::parser::OmpDeviceTypeClause::Type::Any: + result = mlir::omp::DeclareTargetDeviceType::any; + break; + } + return true; + } + return false; +} + +bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const { + const Fortran::parser::CharBlock *source = nullptr; + if (auto *finalClause = findUniqueClause(&source)) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location clauseLocation = converter.genLocation(*source); + + mlir::Value finalVal = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(finalClause->v), stmtCtx)); + result = firOpBuilder.createConvert(clauseLocation, + firOpBuilder.getI1Type(), finalVal); + return true; + } + return false; +} + +bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const { + if (auto *hintClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + const auto *expr = Fortran::semantics::GetExpr(hintClause->v); + int64_t hintValue = *Fortran::evaluate::ToInt64(*expr); + result = firOpBuilder.getI64IntegerAttr(hintValue); + return true; + } + return false; +} + +bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const { + return markClauseOccurrence(result); +} + +bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const { + return markClauseOccurrence(result); +} + +bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const { + // TODO Get lower and upper bounds for num_teams when parser is updated to + // accept both. + if (auto *numTeamsClause = findUniqueClause()) { + result = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numTeamsClause->v), stmtCtx)); + return true; + } + return false; +} + +bool ClauseProcessor::processNumThreads( + Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + if (auto *numThreadsClause = findUniqueClause()) { + // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. + result = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); + return true; + } + return false; +} + +bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const { + if (auto *orderedClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + int64_t orderedClauseValue = 0l; + if (orderedClause->v.has_value()) { + const auto *expr = Fortran::semantics::GetExpr(orderedClause->v); + orderedClauseValue = *Fortran::evaluate::ToInt64(*expr); + } + result = firOpBuilder.getI64IntegerAttr(orderedClauseValue); + return true; + } + return false; +} + +bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const { + if (auto *priorityClause = findUniqueClause()) { + result = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx)); + return true; + } + return false; +} + +bool ClauseProcessor::processProcBind( + mlir::omp::ClauseProcBindKindAttr &result) const { + if (auto *procBindClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + result = genProcBindKindAttr(firOpBuilder, procBindClause); + return true; + } + return false; +} + +bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const { + if (auto *safelenClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + const auto *expr = Fortran::semantics::GetExpr(safelenClause->v); + const std::optional safelenVal = + Fortran::evaluate::ToInt64(*expr); + result = firOpBuilder.getI64IntegerAttr(*safelenVal); + return true; + } + return false; +} + +bool ClauseProcessor::processSchedule( + mlir::omp::ClauseScheduleKindAttr &valAttr, + mlir::omp::ScheduleModifierAttr &modifierAttr, + mlir::UnitAttr &simdModifierAttr) const { + if (auto *scheduleClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::MLIRContext *context = firOpBuilder.getContext(); + const Fortran::parser::OmpScheduleClause &scheduleType = scheduleClause->v; + const auto &scheduleClauseKind = + std::get( + scheduleType.t); + + mlir::omp::ClauseScheduleKind scheduleKind; + switch (scheduleClauseKind) { + case Fortran::parser::OmpScheduleClause::ScheduleType::Static: + scheduleKind = mlir::omp::ClauseScheduleKind::Static; + break; + case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic: + scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic; + break; + case Fortran::parser::OmpScheduleClause::ScheduleType::Guided: + scheduleKind = mlir::omp::ClauseScheduleKind::Guided; + break; + case Fortran::parser::OmpScheduleClause::ScheduleType::Auto: + scheduleKind = mlir::omp::ClauseScheduleKind::Auto; + break; + case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime: + scheduleKind = mlir::omp::ClauseScheduleKind::Runtime; + break; + } + + mlir::omp::ScheduleModifier scheduleModifier = + getScheduleModifier(scheduleClause->v); + + if (scheduleModifier != mlir::omp::ScheduleModifier::none) + modifierAttr = + mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier); + + if (getSimdModifier(scheduleClause->v) != mlir::omp::ScheduleModifier::none) + simdModifierAttr = firOpBuilder.getUnitAttr(); + + valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind); + return true; + } + return false; +} + +bool ClauseProcessor::processScheduleChunk( + Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + if (auto *scheduleClause = findUniqueClause()) { + if (const auto &chunkExpr = + std::get>( + scheduleClause->v.t)) { + if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) { + result = fir::getBase(converter.genExprValue(*expr, stmtCtx)); + } + } + return true; + } + return false; +} + +bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const { + if (auto *simdlenClause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + const auto *expr = Fortran::semantics::GetExpr(simdlenClause->v); + const std::optional simdlenVal = + Fortran::evaluate::ToInt64(*expr); + result = firOpBuilder.getI64IntegerAttr(*simdlenVal); + return true; + } + return false; +} + +bool ClauseProcessor::processThreadLimit( + Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + if (auto *threadLmtClause = findUniqueClause()) { + result = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx)); + return true; + } + return false; +} + +bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { + return markClauseOccurrence(result); +} + +//===----------------------------------------------------------------------===// +// ClauseProcessor repeatable clauses +//===----------------------------------------------------------------------===// + +bool ClauseProcessor::processAllocate( + llvm::SmallVectorImpl &allocatorOperands, + llvm::SmallVectorImpl &allocateOperands) const { + return findRepeatableClause( + [&](const ClauseTy::Allocate *allocateClause, + const Fortran::parser::CharBlock &) { + genAllocateClause(converter, allocateClause->v, allocatorOperands, + allocateOperands); + }); +} + +bool ClauseProcessor::processCopyin() const { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); + auto checkAndCopyHostAssociateVar = + [&](Fortran::semantics::Symbol *sym, + mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) { + assert(sym->has() && + "No host-association found"); + if (converter.isPresentShallowLookup(*sym)) + converter.copyHostAssociateVar(*sym, copyAssignIP); + }; + bool hasCopyin = findRepeatableClause( + [&](const ClauseTy::Copyin *copyinClause, + const Fortran::parser::CharBlock &) { + const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; + for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + if (const auto *commonDetails = + sym->detailsIf()) { + for (const auto &mem : commonDetails->objects()) + checkAndCopyHostAssociateVar(&*mem, &insPt); + break; + } + if (Fortran::semantics::IsAllocatableOrObjectPointer( + &sym->GetUltimate())) + TODO(converter.getCurrentLocation(), + "pointer or allocatable variables in Copyin clause"); + assert(sym->has() && + "No host-association found"); + checkAndCopyHostAssociateVar(sym); + } + }); + + // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to + // the execution of the associated structured block. Emit implicit barrier to + // synchronize threads and avoid data races on propagation master's thread + // values of threadprivate variables to local instances of that variables of + // all other implicit threads. + if (hasCopyin) + firOpBuilder.create(converter.getCurrentLocation()); + firOpBuilder.restoreInsertionPoint(insPt); + return hasCopyin; +} + +bool ClauseProcessor::processDepend( + llvm::SmallVectorImpl &dependTypeOperands, + llvm::SmallVectorImpl &dependOperands) const { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + return findRepeatableClause( + [&](const ClauseTy::Depend *dependClause, + const Fortran::parser::CharBlock &) { + const std::list &depVal = + std::get>( + std::get( + dependClause->v.u) + .t); + mlir::omp::ClauseTaskDependAttr dependTypeOperand = + genDependKindAttr(firOpBuilder, dependClause); + dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(), + dependTypeOperand); + for (const Fortran::parser::Designator &ompObject : depVal) { + Fortran::semantics::Symbol *sym = nullptr; + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::DataRef &designator) { + if (const Fortran::parser::Name *name = + std::get_if(&designator.u)) { + sym = name->symbol; + } else if (std::get_if>( + &designator.u)) { + TODO(converter.getCurrentLocation(), + "array sections not supported for task depend"); + } + }, + [&](const Fortran::parser::Substring &designator) { + TODO(converter.getCurrentLocation(), + "substring not supported for task depend"); + }}, + (ompObject).u); + const mlir::Value variable = converter.getSymbolAddress(*sym); + dependOperands.push_back(variable); + } + }); +} + +bool ClauseProcessor::processIf( + Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, + mlir::Value &result) const { + bool found = false; + findRepeatableClause( + [&](const ClauseTy::If *ifClause, + const Fortran::parser::CharBlock &source) { + mlir::Location clauseLocation = converter.genLocation(source); + mlir::Value operand = getIfClauseOperand(converter, ifClause, + directiveName, clauseLocation); + // Assume that, at most, a single 'if' clause will be applicable to the + // given directive. + if (operand) { + result = operand; + found = true; + } + }); + return found; +} + +bool ClauseProcessor::processLink( + llvm::SmallVectorImpl &result) const { + return findRepeatableClause( + [&](const ClauseTy::Link *linkClause, + const Fortran::parser::CharBlock &) { + // Case: declare target link(var1, var2)... + gatherFuncAndVarSyms( + linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result); + }); +} + +mlir::omp::MapInfoOp +createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name, + mlir::SmallVector bounds, + mlir::SmallVector members, uint64_t mapType, + mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, + bool isVal) { + if (auto boxTy = baseAddr.getType().dyn_cast()) { + baseAddr = builder.create(loc, baseAddr); + retTy = baseAddr.getType(); + } + + mlir::TypeAttr varType = mlir::TypeAttr::get( + llvm::cast(retTy).getElementType()); + + mlir::omp::MapInfoOp op = builder.create( + loc, retTy, baseAddr, varType, varPtrPtr, members, bounds, + builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + builder.getAttr(mapCaptureType), + builder.getStringAttr(name)); + + return op; +} + +bool ClauseProcessor::processMap( + mlir::Location currentLocation, const llvm::omp::Directive &directive, + Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl &mapOperands, + llvm::SmallVectorImpl *mapSymTypes, + llvm::SmallVectorImpl *mapSymLocs, + llvm::SmallVectorImpl *mapSymbols) + const { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + return findRepeatableClause( + [&](const ClauseTy::Map *mapClause, + const Fortran::parser::CharBlock &source) { + mlir::Location clauseLocation = converter.genLocation(source); + const auto &oMapType = + std::get>( + mapClause->v.t); + llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + // If the map type is specified, then process it else Tofrom is the + // default. + if (oMapType) { + const Fortran::parser::OmpMapType::Type &mapType = + std::get(oMapType->t); + switch (mapType) { + case Fortran::parser::OmpMapType::Type::To: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + break; + case Fortran::parser::OmpMapType::Type::From: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + break; + case Fortran::parser::OmpMapType::Type::Tofrom: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + break; + case Fortran::parser::OmpMapType::Type::Alloc: + case Fortran::parser::OmpMapType::Type::Release: + // alloc and release is the default map_type for the Target Data + // Ops, i.e. if no bits for map_type is supplied then alloc/release + // is implicitly assumed based on the target directive. Default + // value for Target Data and Enter Data is alloc and for Exit Data + // it is release. + break; + case Fortran::parser::OmpMapType::Type::Delete: + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + } + + if (std::get>( + oMapType->t)) + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + } else { + mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + } + + for (const Fortran::parser::OmpObject &ompObject : + std::get(mapClause->v.t).v) { + llvm::SmallVector bounds; + std::stringstream asFortran; + + Fortran::lower::AddrAndBoundsInfo info = + Fortran::lower::gatherDataOperandAddrAndBounds< + Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, + mlir::omp::DataBoundsType>( + converter, firOpBuilder, semaCtx, stmtCtx, ompObject, + clauseLocation, asFortran, bounds, treatIndexAsSection); + + auto origSymbol = + converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); + mlir::Value symAddr = info.addr; + if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) + symAddr = origSymbol; + + // Explicit map captures are captured ByRef by default, + // optimisation passes may alter this to ByCopy or other capture + // types to optimise + mlir::Value mapOp = createMapInfoOp( + firOpBuilder, clauseLocation, symAddr, mlir::Value{}, + asFortran.str(), bounds, {}, + static_cast< + std::underlying_type_t>( + mapTypeBits), + mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); + + mapOperands.push_back(mapOp); + if (mapSymTypes) + mapSymTypes->push_back(symAddr.getType()); + if (mapSymLocs) + mapSymLocs->push_back(symAddr.getLoc()); + + if (mapSymbols) + mapSymbols->push_back(getOmpObjectSymbol(ompObject)); + } + }); +} + +bool ClauseProcessor::processReduction( + mlir::Location currentLocation, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl *reductionSymbols) + const { + return findRepeatableClause( + [&](const ClauseTy::Reduction *reductionClause, + const Fortran::parser::CharBlock &) { + ReductionProcessor rp; + rp.addReductionDecl(currentLocation, converter, reductionClause->v, + reductionVars, reductionDeclSymbols, + reductionSymbols); + }); +} + +bool ClauseProcessor::processSectionsReduction( + mlir::Location currentLocation) const { + return findRepeatableClause( + [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) { + TODO(currentLocation, "OMPC_Reduction"); + }); +} + +bool ClauseProcessor::processTo( + llvm::SmallVectorImpl &result) const { + return findRepeatableClause( + [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) { + // Case: declare target to(func, var1, var2)... + gatherFuncAndVarSyms(toClause->v, + mlir::omp::DeclareTargetCaptureClause::to, result); + }); +} + +bool ClauseProcessor::processEnter( + llvm::SmallVectorImpl &result) const { + return findRepeatableClause( + [&](const ClauseTy::Enter *enterClause, + const Fortran::parser::CharBlock &) { + // Case: declare target enter(func, var1, var2)... + gatherFuncAndVarSyms(enterClause->v, + mlir::omp::DeclareTargetCaptureClause::enter, + result); + }); +} + +bool ClauseProcessor::processUseDeviceAddr( + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl &useDeviceSymbols) + const { + return findRepeatableClause( + [&](const ClauseTy::UseDeviceAddr *devAddrClause, + const Fortran::parser::CharBlock &) { + addUseDeviceClause(converter, devAddrClause->v, operands, + useDeviceTypes, useDeviceLocs, useDeviceSymbols); + }); +} + +bool ClauseProcessor::processUseDevicePtr( + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl &useDeviceSymbols) + const { + return findRepeatableClause( + [&](const ClauseTy::UseDevicePtr *devPtrClause, + const Fortran::parser::CharBlock &) { + addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes, + useDeviceLocs, useDeviceSymbols); + }); +} +} // namespace omp +} // namespace lower +} // namespace Fortran diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h new file mode 100644 index 0000000000000..312255112605e --- /dev/null +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -0,0 +1,305 @@ +//===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// +#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H +#define FORTRAN_LOWER_CLAUASEPROCESSOR_H + +#include "DirectivesCommon.h" +#include "ReductionProcessor.h" +#include "Utils.h" +#include "flang/Lower/AbstractConverter.h" +#include "flang/Lower/Bridge.h" +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Parser/dump-parse-tree.h" +#include "flang/Parser/parse-tree.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" + +namespace fir { +class FirOpBuilder; +} // namespace fir + +namespace Fortran { +namespace lower { +namespace omp { + +/// Class that handles the processing of OpenMP clauses. +/// +/// Its `process()` methods perform MLIR code generation for their +/// corresponding clause if it is present in the clause list. Otherwise, they +/// will return `false` to signal that the clause was not found. +/// +/// The intended use is of this class is to move clause processing outside of +/// construct processing, since the same clauses can appear attached to +/// different constructs and constructs can be combined, so that code +/// duplication is minimized. +/// +/// Each construct-lowering function only calls the `process()` +/// methods that relate to clauses that can impact the lowering of that +/// construct. +class ClauseProcessor { + using ClauseTy = Fortran::parser::OmpClause; + +public: + ClauseProcessor(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses) + : converter(converter), semaCtx(semaCtx), clauses(clauses) {} + + // 'Unique' clauses: They can appear at most once in the clause list. + bool + processCollapse(mlir::Location currentLocation, + Fortran::lower::pft::Evaluation &eval, + llvm::SmallVectorImpl &lowerBound, + llvm::SmallVectorImpl &upperBound, + llvm::SmallVectorImpl &step, + llvm::SmallVectorImpl &iv, + std::size_t &loopVarTypeSize) const; + bool processDefault() const; + bool processDevice(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const; + bool processFinal(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processHint(mlir::IntegerAttr &result) const; + bool processMergeable(mlir::UnitAttr &result) const; + bool processNowait(mlir::UnitAttr &result) const; + bool processNumTeams(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processNumThreads(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processOrdered(mlir::IntegerAttr &result) const; + bool processPriority(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const; + bool processSafelen(mlir::IntegerAttr &result) const; + bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr, + mlir::omp::ScheduleModifierAttr &modifierAttr, + mlir::UnitAttr &simdModifierAttr) const; + bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processSimdlen(mlir::IntegerAttr &result) const; + bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx, + mlir::Value &result) const; + bool processUntied(mlir::UnitAttr &result) const; + + // 'Repeatable' clauses: They can appear multiple times in the clause list. + bool + processAllocate(llvm::SmallVectorImpl &allocatorOperands, + llvm::SmallVectorImpl &allocateOperands) const; + bool processCopyin() const; + bool processDepend(llvm::SmallVectorImpl &dependTypeOperands, + llvm::SmallVectorImpl &dependOperands) const; + bool + processEnter(llvm::SmallVectorImpl &result) const; + bool + processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, + mlir::Value &result) const; + bool + processLink(llvm::SmallVectorImpl &result) const; + + // This method is used to process a map clause. + // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to + // store the original type, location and Fortran symbol for the map operands. + // They may be used later on to create the block_arguments for some of the + // target directives that require it. + bool processMap(mlir::Location currentLocation, + const llvm::omp::Directive &directive, + Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl &mapOperands, + llvm::SmallVectorImpl *mapSymTypes = nullptr, + llvm::SmallVectorImpl *mapSymLocs = nullptr, + llvm::SmallVectorImpl + *mapSymbols = nullptr) const; + bool + processReduction(mlir::Location currentLocation, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl + *reductionSymbols = nullptr) const; + bool processSectionsReduction(mlir::Location currentLocation) const; + bool processTo(llvm::SmallVectorImpl &result) const; + bool + processUseDeviceAddr(llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl + &useDeviceSymbols) const; + bool + processUseDevicePtr(llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl + &useDeviceSymbols) const; + + template + bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl &mapOperands); + + // Call this method for these clauses that should be supported but are not + // implemented yet. It triggers a compilation error if any of the given + // clauses is found. + template + void processTODO(mlir::Location currentLocation, + llvm::omp::Directive directive) const; + +private: + using ClauseIterator = std::list::const_iterator; + + /// Utility to find a clause within a range in the clause list. + template + static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end); + + /// Return the first instance of the given clause found in the clause list or + /// `nullptr` if not present. If more than one instance is expected, use + /// `findRepeatableClause` instead. + template + const T * + findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const; + + /// Call `callbackFn` for each occurrence of the given clause. Return `true` + /// if at least one instance was found. + template + bool findRepeatableClause( + std::function + callbackFn) const; + + /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. + template + bool markClauseOccurrence(mlir::UnitAttr &result) const; + + Fortran::lower::AbstractConverter &converter; + Fortran::semantics::SemanticsContext &semaCtx; + const Fortran::parser::OmpClauseList &clauses; +}; + +template +bool ClauseProcessor::processMotionClauses( + Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl &mapOperands) { + return findRepeatableClause( + [&](const T *motionClause, const Fortran::parser::CharBlock &source) { + mlir::Location clauseLocation = converter.genLocation(source); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + static_assert(std::is_same_v || + std::is_same_v); + + // TODO Support motion modifiers: present, mapper, iterator. + constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = + std::is_same_v + ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO + : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + + for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) { + llvm::SmallVector bounds; + std::stringstream asFortran; + Fortran::lower::AddrAndBoundsInfo info = + Fortran::lower::gatherDataOperandAddrAndBounds< + Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, + mlir::omp::DataBoundsType>( + converter, firOpBuilder, semaCtx, stmtCtx, ompObject, + clauseLocation, asFortran, bounds, treatIndexAsSection); + + auto origSymbol = + converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); + mlir::Value symAddr = info.addr; + if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) + symAddr = origSymbol; + + // Explicit map captures are captured ByRef by default, + // optimisation passes may alter this to ByCopy or other capture + // types to optimise + mlir::Value mapOp = createMapInfoOp( + firOpBuilder, clauseLocation, symAddr, mlir::Value{}, + asFortran.str(), bounds, {}, + static_cast< + std::underlying_type_t>( + mapTypeBits), + mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); + + mapOperands.push_back(mapOp); + } + }); +} + +template +void ClauseProcessor::processTODO(mlir::Location currentLocation, + llvm::omp::Directive directive) const { + auto checkUnhandledClause = [&](const auto *x) { + if (!x) + return; + TODO(currentLocation, + "Unhandled clause " + + llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x)) + .upper() + + " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() + + " construct"); + }; + + for (ClauseIterator it = clauses.v.begin(); it != clauses.v.end(); ++it) + (checkUnhandledClause(std::get_if(&it->u)), ...); +} + +template +ClauseProcessor::ClauseIterator +ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) { + for (ClauseIterator it = begin; it != end; ++it) { + if (std::get_if(&it->u)) + return it; + } + + return end; +} + +template +const T *ClauseProcessor::findUniqueClause( + const Fortran::parser::CharBlock **source) const { + ClauseIterator it = findClause(clauses.v.begin(), clauses.v.end()); + if (it != clauses.v.end()) { + if (source) + *source = &it->source; + return &std::get(it->u); + } + return nullptr; +} + +template +bool ClauseProcessor::findRepeatableClause( + std::function + callbackFn) const { + bool found = false; + ClauseIterator nextIt, endIt = clauses.v.end(); + for (ClauseIterator it = clauses.v.begin(); it != endIt; it = nextIt) { + nextIt = findClause(it, endIt); + + if (nextIt != endIt) { + callbackFn(&std::get(nextIt->u), nextIt->source); + found = true; + ++nextIt; + } + } + return found; +} + +template +bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const { + if (findUniqueClause()) { + result = converter.getFirOpBuilder().getUnitAttr(); + return true; + } + return false; +} + +} // namespace omp +} // namespace lower +} // namespace Fortran + +#endif // FORTRAN_LOWER_CLAUASEPROCESSOR_H diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp new file mode 100644 index 0000000000000..136bda0b582ee --- /dev/null +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -0,0 +1,350 @@ +//===-- DataSharingProcessor.cpp --------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "DataSharingProcessor.h" + +#include "Utils.h" +#include "flang/Lower/PFTBuilder.h" +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Semantics/tools.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" + +namespace Fortran { +namespace lower { +namespace omp { + +void DataSharingProcessor::processStep1() { + collectSymbolsForPrivatization(); + collectDefaultSymbols(); + privatize(); + defaultPrivatize(); + insertBarrier(); +} + +void DataSharingProcessor::processStep2(mlir::Operation *op, bool isLoop) { + insPt = firOpBuilder.saveInsertionPoint(); + copyLastPrivatize(op); + firOpBuilder.restoreInsertionPoint(insPt); + + if (isLoop) { + // push deallocs out of the loop + firOpBuilder.setInsertionPointAfter(op); + insertDeallocs(); + } else { + // insert dummy instruction to mark the insertion position + mlir::Value undefMarker = firOpBuilder.create( + op->getLoc(), firOpBuilder.getIndexType()); + insertDeallocs(); + firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp()); + } +} + +void DataSharingProcessor::insertDeallocs() { + for (const Fortran::semantics::Symbol *sym : privatizedSymbols) + if (Fortran::semantics::IsAllocatable(sym->GetUltimate())) { + converter.createHostAssociateVarCloneDealloc(*sym); + } +} + +void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) { + // Privatization for symbols which are pre-determined (like loop index + // variables) happen separately, for everything else privatize here. + if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined)) + return; + bool success = converter.createHostAssociateVarClone(*sym); + (void)success; + assert(success && "Privatization failed due to existing binding"); +} + +void DataSharingProcessor::copyFirstPrivateSymbol( + const Fortran::semantics::Symbol *sym) { + if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate)) + converter.copyHostAssociateVar(*sym); +} + +void DataSharingProcessor::copyLastPrivateSymbol( + const Fortran::semantics::Symbol *sym, + [[maybe_unused]] mlir::OpBuilder::InsertPoint *lastPrivIP) { + if (sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate)) + converter.copyHostAssociateVar(*sym, lastPrivIP); +} + +void DataSharingProcessor::collectOmpObjectListSymbol( + const Fortran::parser::OmpObjectList &ompObjectList, + llvm::SetVector &symbolSet) { + for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + symbolSet.insert(sym); + } +} + +void DataSharingProcessor::collectSymbolsForPrivatization() { + bool hasCollapse = false; + for (const Fortran::parser::OmpClause &clause : opClauseList.v) { + if (const auto &privateClause = + std::get_if(&clause.u)) { + collectOmpObjectListSymbol(privateClause->v, privatizedSymbols); + } else if (const auto &firstPrivateClause = + std::get_if( + &clause.u)) { + collectOmpObjectListSymbol(firstPrivateClause->v, privatizedSymbols); + } else if (const auto &lastPrivateClause = + std::get_if( + &clause.u)) { + collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols); + hasLastPrivateOp = true; + } else if (std::get_if(&clause.u)) { + hasCollapse = true; + } + } + + if (hasCollapse && hasLastPrivateOp) + TODO(converter.getCurrentLocation(), "Collapse clause with lastprivate"); +} + +bool DataSharingProcessor::needBarrier() { + for (const Fortran::semantics::Symbol *sym : privatizedSymbols) { + if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate) && + sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate)) + return true; + } + return false; +} + +void DataSharingProcessor::insertBarrier() { + // Emit implicit barrier to synchronize threads and avoid data races on + // initialization of firstprivate variables and post-update of lastprivate + // variables. + // FIXME: Emit barrier for lastprivate clause when 'sections' directive has + // 'nowait' clause. Otherwise, emit barrier when 'sections' directive has + // both firstprivate and lastprivate clause. + // Emit implicit barrier for linear clause. Maybe on somewhere else. + if (needBarrier()) + firOpBuilder.create(converter.getCurrentLocation()); +} + +void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) { + bool cmpCreated = false; + mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint(); + for (const Fortran::parser::OmpClause &clause : opClauseList.v) { + if (std::get_if(&clause.u)) { + // TODO: Add lastprivate support for simd construct + if (mlir::isa(op)) { + if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) { + // For `omp.sections`, lastprivatized variables occur in + // lexically final `omp.section` operation. The following FIR + // shall be generated for the same: + // + // omp.sections lastprivate(...) { + // omp.section {...} + // omp.section {...} + // omp.section { + // fir.allocate for `private`/`firstprivate` + // + // fir.if %true { + // ^%lpv_update_blk + // } + // } + // } + // + // To keep code consistency while handling privatization + // through this control flow, add a `fir.if` operation + // that always evaluates to true, in order to create + // a dedicated sub-region in `omp.section` where + // lastprivate FIR can reside. Later canonicalizations + // will optimize away this operation. + if (!eval.lowerAsUnstructured()) { + auto ifOp = firOpBuilder.create( + op->getLoc(), + firOpBuilder.createIntegerConstant( + op->getLoc(), firOpBuilder.getIntegerType(1), 0x1), + /*else*/ false); + firOpBuilder.setInsertionPointToStart( + &ifOp.getThenRegion().front()); + + const Fortran::parser::OpenMPConstruct *parentOmpConstruct = + eval.parentConstruct->getIf(); + assert(parentOmpConstruct && + "Expected a valid enclosing OpenMP construct"); + const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct = + std::get_if( + &parentOmpConstruct->u); + assert(sectionsConstruct && + "Expected an enclosing omp.sections construct"); + const Fortran::parser::OmpClauseList §ionsEndClauseList = + std::get( + std::get( + sectionsConstruct->t) + .t); + for (const Fortran::parser::OmpClause &otherClause : + sectionsEndClauseList.v) + if (std::get_if( + &otherClause.u)) + // Emit implicit barrier to synchronize threads and avoid data + // races on post-update of lastprivate variables when `nowait` + // clause is present. + firOpBuilder.create( + converter.getCurrentLocation()); + firOpBuilder.setInsertionPointToStart( + &ifOp.getThenRegion().front()); + lastPrivIP = firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPoint(ifOp); + insPt = firOpBuilder.saveInsertionPoint(); + } else { + // Lastprivate operation is inserted at the end + // of the lexically last section in the sections + // construct + mlir::OpBuilder::InsertPoint unstructuredSectionsIP = + firOpBuilder.saveInsertionPoint(); + mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); + firOpBuilder.setInsertionPoint(lastOper); + lastPrivIP = firOpBuilder.saveInsertionPoint(); + firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP); + } + } + } else if (mlir::isa(op)) { + // Update the original variable just before exiting the worksharing + // loop. Conversion as follows: + // + // omp.wsloop { + // omp.wsloop { ... + // ... store + // store ===> %v = arith.addi %iv, %step + // omp.yield %cmp = %step < 0 ? %v < %ub : %v > %ub + // } fir.if %cmp { + // fir.store %v to %loopIV + // ^%lpv_update_blk: + // } + // omp.yield + // } + // + + // Only generate the compare once in presence of multiple LastPrivate + // clauses. + if (cmpCreated) + continue; + cmpCreated = true; + + mlir::Location loc = op->getLoc(); + mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); + firOpBuilder.setInsertionPoint(lastOper); + + mlir::Value iv = op->getRegion(0).front().getArguments()[0]; + mlir::Value ub = + mlir::dyn_cast(op).getUpperBound()[0]; + mlir::Value step = mlir::dyn_cast(op).getStep()[0]; + + // v = iv + step + // cmp = step < 0 ? v < ub : v > ub + mlir::Value v = firOpBuilder.create(loc, iv, step); + mlir::Value zero = + firOpBuilder.createIntegerConstant(loc, step.getType(), 0); + mlir::Value negativeStep = firOpBuilder.create( + loc, mlir::arith::CmpIPredicate::slt, step, zero); + mlir::Value vLT = firOpBuilder.create( + loc, mlir::arith::CmpIPredicate::slt, v, ub); + mlir::Value vGT = firOpBuilder.create( + loc, mlir::arith::CmpIPredicate::sgt, v, ub); + mlir::Value cmpOp = firOpBuilder.create( + loc, negativeStep, vLT, vGT); + + auto ifOp = firOpBuilder.create(loc, cmpOp, /*else*/ false); + firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + assert(loopIV && "loopIV was not set"); + firOpBuilder.create(op->getLoc(), v, loopIV); + lastPrivIP = firOpBuilder.saveInsertionPoint(); + } else { + TODO(converter.getCurrentLocation(), + "lastprivate clause in constructs other than " + "simd/worksharing-loop"); + } + } + } + firOpBuilder.restoreInsertionPoint(localInsPt); +} + +void DataSharingProcessor::collectSymbols( + Fortran::semantics::Symbol::Flag flag) { + converter.collectSymbolSet(eval, defaultSymbols, flag, + /*collectSymbols=*/true, + /*collectHostAssociatedSymbols=*/true); + for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) { + if (e.hasNestedEvaluations()) + converter.collectSymbolSet(e, symbolsInNestedRegions, flag, + /*collectSymbols=*/true, + /*collectHostAssociatedSymbols=*/false); + else + converter.collectSymbolSet(e, symbolsInParentRegions, flag, + /*collectSymbols=*/false, + /*collectHostAssociatedSymbols=*/true); + } +} + +void DataSharingProcessor::collectDefaultSymbols() { + for (const Fortran::parser::OmpClause &clause : opClauseList.v) { + if (const auto &defaultClause = + std::get_if(&clause.u)) { + if (defaultClause->v.v == + Fortran::parser::OmpDefaultClause::Type::Private) + collectSymbols(Fortran::semantics::Symbol::Flag::OmpPrivate); + else if (defaultClause->v.v == + Fortran::parser::OmpDefaultClause::Type::Firstprivate) + collectSymbols(Fortran::semantics::Symbol::Flag::OmpFirstPrivate); + } + } +} + +void DataSharingProcessor::privatize() { + for (const Fortran::semantics::Symbol *sym : privatizedSymbols) { + if (const auto *commonDet = + sym->detailsIf()) { + for (const auto &mem : commonDet->objects()) { + cloneSymbol(&*mem); + copyFirstPrivateSymbol(&*mem); + } + } else { + cloneSymbol(sym); + copyFirstPrivateSymbol(sym); + } + } +} + +void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) { + insertLastPrivateCompare(op); + for (const Fortran::semantics::Symbol *sym : privatizedSymbols) + if (const auto *commonDet = + sym->detailsIf()) { + for (const auto &mem : commonDet->objects()) { + copyLastPrivateSymbol(&*mem, &lastPrivIP); + } + } else { + copyLastPrivateSymbol(sym, &lastPrivIP); + } +} + +void DataSharingProcessor::defaultPrivatize() { + for (const Fortran::semantics::Symbol *sym : defaultSymbols) { + if (!Fortran::semantics::IsProcedure(*sym) && + !sym->GetUltimate().has() && + !sym->GetUltimate().has() && + !symbolsInNestedRegions.contains(sym) && + !symbolsInParentRegions.contains(sym) && + !privatizedSymbols.contains(sym)) { + cloneSymbol(sym); + copyFirstPrivateSymbol(sym); + } + } +} + +} // namespace omp +} // namespace lower +} // namespace Fortran diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h new file mode 100644 index 0000000000000..10c0a30c09c39 --- /dev/null +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h @@ -0,0 +1,89 @@ +//===-- Lower/OpenMP/DataSharingProcessor.h ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// +#ifndef FORTRAN_LOWER_DATASHARINGPROCESSOR_H +#define FORTRAN_LOWER_DATASHARINGPROCESSOR_H + +#include "flang/Lower/AbstractConverter.h" +#include "flang/Lower/OpenMP.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Parser/parse-tree.h" +#include "flang/Semantics/symbol.h" + +namespace Fortran { +namespace lower { +namespace omp { + +class DataSharingProcessor { + bool hasLastPrivateOp; + mlir::OpBuilder::InsertPoint lastPrivIP; + mlir::OpBuilder::InsertPoint insPt; + mlir::Value loopIV; + // Symbols in private, firstprivate, and/or lastprivate clauses. + llvm::SetVector privatizedSymbols; + llvm::SetVector defaultSymbols; + llvm::SetVector symbolsInNestedRegions; + llvm::SetVector symbolsInParentRegions; + Fortran::lower::AbstractConverter &converter; + fir::FirOpBuilder &firOpBuilder; + const Fortran::parser::OmpClauseList &opClauseList; + Fortran::lower::pft::Evaluation &eval; + + bool needBarrier(); + void collectSymbols(Fortran::semantics::Symbol::Flag flag); + void collectOmpObjectListSymbol( + const Fortran::parser::OmpObjectList &ompObjectList, + llvm::SetVector &symbolSet); + void collectSymbolsForPrivatization(); + void insertBarrier(); + void collectDefaultSymbols(); + void privatize(); + void defaultPrivatize(); + void copyLastPrivatize(mlir::Operation *op); + void insertLastPrivateCompare(mlir::Operation *op); + void cloneSymbol(const Fortran::semantics::Symbol *sym); + void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym); + void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym, + mlir::OpBuilder::InsertPoint *lastPrivIP); + void insertDeallocs(); + +public: + DataSharingProcessor(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpClauseList &opClauseList, + Fortran::lower::pft::Evaluation &eval) + : hasLastPrivateOp(false), converter(converter), + firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList), + eval(eval) {} + // Privatisation is split into two steps. + // Step1 performs cloning of all privatisation clauses and copying for + // firstprivates. Step1 is performed at the place where process/processStep1 + // is called. This is usually inside the Operation corresponding to the OpenMP + // construct, for looping constructs this is just before the Operation. The + // split into two steps was performed basically to be able to call + // privatisation for looping constructs before the operation is created since + // the bounds of the MLIR OpenMP operation can be privatised. + // Step2 performs the copying for lastprivates and requires knowledge of the + // MLIR operation to insert the last private update. Step2 adds + // dealocation code as well. + void processStep1(); + void processStep2(mlir::Operation *op, bool isLoop); + + void setLoopIV(mlir::Value iv) { + assert(!loopIV && "Loop iteration variable already set"); + loopIV = iv; + } +}; + +} // namespace omp +} // namespace lower +} // namespace Fortran + +#endif // FORTRAN_LOWER_DATASHARINGPROCESSOR_H diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp similarity index 55% rename from flang/lib/Lower/OpenMP.cpp rename to flang/lib/Lower/OpenMP/OpenMP.cpp index 9397af8b8bd05..3aefad6cf0ec1 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -11,109 +11,36 @@ //===----------------------------------------------------------------------===// #include "flang/Lower/OpenMP.h" + +#include "ClauseProcessor.h" +#include "DataSharingProcessor.h" #include "DirectivesCommon.h" +#include "ReductionProcessor.h" #include "flang/Common/idioms.h" #include "flang/Lower/Bridge.h" #include "flang/Lower/ConvertExpr.h" #include "flang/Lower/ConvertVariable.h" -#include "flang/Lower/PFTBuilder.h" #include "flang/Lower/StatementContext.h" #include "flang/Lower/SymbolMap.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" -#include "flang/Parser/dump-parse-tree.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" -#include "llvm/Support/CommandLine.h" - -static llvm::cl::opt treatIndexAsSection( - "openmp-treat-index-as-section", - llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."), - llvm::cl::init(true)); -using DeclareTargetCapturePair = - std::pair; +using namespace Fortran::lower::omp; //===----------------------------------------------------------------------===// -// Common helper functions +// Code generation helper functions //===----------------------------------------------------------------------===// -static Fortran::semantics::Symbol * -getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { - Fortran::semantics::Symbol *sym = nullptr; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (auto *arrayEle = - Fortran::parser::Unwrap( - designator)) { - sym = GetFirstName(arrayEle->base).symbol; - } else if (auto *structComp = Fortran::parser::Unwrap< - Fortran::parser::StructureComponent>(designator)) { - sym = structComp->component.symbol; - } else if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - sym = name->symbol; - } - }, - [&](const Fortran::parser::Name &name) { sym = name.symbol; }}, - ompObject.u); - return sym; -} - -static void genObjectList(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl &operands) { - auto addOperands = [&](Fortran::lower::SymbolRef sym) { - const mlir::Value variable = converter.getSymbolAddress(sym); - if (variable) { - operands.push_back(variable); - } else { - if (const auto *details = - sym->detailsIf()) { - operands.push_back(converter.getSymbolAddress(details->symbol())); - converter.copySymbolBinding(details->symbol(), sym); - } - } - }; - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - addOperands(*sym); - } -} - -static void gatherFuncAndVarSyms( - const Fortran::parser::OmpObjectList &objList, - mlir::omp::DeclareTargetCaptureClause clause, - llvm::SmallVectorImpl &symbolAndClause) { - for (const Fortran::parser::OmpObject &ompObject : objList.v) { - Fortran::common::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - symbolAndClause.emplace_back(clause, *name->symbol); - } - }, - [&](const Fortran::parser::Name &name) { - symbolAndClause.emplace_back(clause, *name.symbol); - }}, - ompObject.u); - } -} - static Fortran::lower::pft::Evaluation * getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) { // Return the Evaluation of the innermost collapsed loop, or the current one @@ -142,1961 +69,6 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter, converter.genEval(e); } -//===----------------------------------------------------------------------===// -// DataSharingProcessor -//===----------------------------------------------------------------------===// - -class DataSharingProcessor { - bool hasLastPrivateOp; - mlir::OpBuilder::InsertPoint lastPrivIP; - mlir::OpBuilder::InsertPoint insPt; - mlir::Value loopIV; - // Symbols in private, firstprivate, and/or lastprivate clauses. - llvm::SetVector privatizedSymbols; - llvm::SetVector defaultSymbols; - llvm::SetVector symbolsInNestedRegions; - llvm::SetVector symbolsInParentRegions; - Fortran::lower::AbstractConverter &converter; - fir::FirOpBuilder &firOpBuilder; - const Fortran::parser::OmpClauseList &opClauseList; - Fortran::lower::pft::Evaluation &eval; - - bool needBarrier(); - void collectSymbols(Fortran::semantics::Symbol::Flag flag); - void collectOmpObjectListSymbol( - const Fortran::parser::OmpObjectList &ompObjectList, - llvm::SetVector &symbolSet); - void collectSymbolsForPrivatization(); - void insertBarrier(); - void collectDefaultSymbols(); - void privatize(); - void defaultPrivatize(); - void copyLastPrivatize(mlir::Operation *op); - void insertLastPrivateCompare(mlir::Operation *op); - void cloneSymbol(const Fortran::semantics::Symbol *sym); - void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym); - void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym, - mlir::OpBuilder::InsertPoint *lastPrivIP); - void insertDeallocs(); - -public: - DataSharingProcessor(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClauseList &opClauseList, - Fortran::lower::pft::Evaluation &eval) - : hasLastPrivateOp(false), converter(converter), - firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList), - eval(eval) {} - // Privatisation is split into two steps. - // Step1 performs cloning of all privatisation clauses and copying for - // firstprivates. Step1 is performed at the place where process/processStep1 - // is called. This is usually inside the Operation corresponding to the OpenMP - // construct, for looping constructs this is just before the Operation. The - // split into two steps was performed basically to be able to call - // privatisation for looping constructs before the operation is created since - // the bounds of the MLIR OpenMP operation can be privatised. - // Step2 performs the copying for lastprivates and requires knowledge of the - // MLIR operation to insert the last private update. Step2 adds - // dealocation code as well. - void processStep1(); - void processStep2(mlir::Operation *op, bool isLoop); - - void setLoopIV(mlir::Value iv) { - assert(!loopIV && "Loop iteration variable already set"); - loopIV = iv; - } -}; - -void DataSharingProcessor::processStep1() { - collectSymbolsForPrivatization(); - collectDefaultSymbols(); - privatize(); - defaultPrivatize(); - insertBarrier(); -} - -void DataSharingProcessor::processStep2(mlir::Operation *op, bool isLoop) { - insPt = firOpBuilder.saveInsertionPoint(); - copyLastPrivatize(op); - firOpBuilder.restoreInsertionPoint(insPt); - - if (isLoop) { - // push deallocs out of the loop - firOpBuilder.setInsertionPointAfter(op); - insertDeallocs(); - } else { - // insert dummy instruction to mark the insertion position - mlir::Value undefMarker = firOpBuilder.create( - op->getLoc(), firOpBuilder.getIndexType()); - insertDeallocs(); - firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp()); - } -} - -void DataSharingProcessor::insertDeallocs() { - for (const Fortran::semantics::Symbol *sym : privatizedSymbols) - if (Fortran::semantics::IsAllocatable(sym->GetUltimate())) { - converter.createHostAssociateVarCloneDealloc(*sym); - } -} - -void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) { - // Privatization for symbols which are pre-determined (like loop index - // variables) happen separately, for everything else privatize here. - if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined)) - return; - bool success = converter.createHostAssociateVarClone(*sym); - (void)success; - assert(success && "Privatization failed due to existing binding"); -} - -void DataSharingProcessor::copyFirstPrivateSymbol( - const Fortran::semantics::Symbol *sym) { - if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate)) - converter.copyHostAssociateVar(*sym); -} - -void DataSharingProcessor::copyLastPrivateSymbol( - const Fortran::semantics::Symbol *sym, - [[maybe_unused]] mlir::OpBuilder::InsertPoint *lastPrivIP) { - if (sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate)) - converter.copyHostAssociateVar(*sym, lastPrivIP); -} - -void DataSharingProcessor::collectOmpObjectListSymbol( - const Fortran::parser::OmpObjectList &ompObjectList, - llvm::SetVector &symbolSet) { - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - symbolSet.insert(sym); - } -} - -void DataSharingProcessor::collectSymbolsForPrivatization() { - bool hasCollapse = false; - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { - if (const auto &privateClause = - std::get_if(&clause.u)) { - collectOmpObjectListSymbol(privateClause->v, privatizedSymbols); - } else if (const auto &firstPrivateClause = - std::get_if( - &clause.u)) { - collectOmpObjectListSymbol(firstPrivateClause->v, privatizedSymbols); - } else if (const auto &lastPrivateClause = - std::get_if( - &clause.u)) { - collectOmpObjectListSymbol(lastPrivateClause->v, privatizedSymbols); - hasLastPrivateOp = true; - } else if (std::get_if(&clause.u)) { - hasCollapse = true; - } - } - - if (hasCollapse && hasLastPrivateOp) - TODO(converter.getCurrentLocation(), "Collapse clause with lastprivate"); -} - -bool DataSharingProcessor::needBarrier() { - for (const Fortran::semantics::Symbol *sym : privatizedSymbols) { - if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate) && - sym->test(Fortran::semantics::Symbol::Flag::OmpLastPrivate)) - return true; - } - return false; -} - -void DataSharingProcessor::insertBarrier() { - // Emit implicit barrier to synchronize threads and avoid data races on - // initialization of firstprivate variables and post-update of lastprivate - // variables. - // FIXME: Emit barrier for lastprivate clause when 'sections' directive has - // 'nowait' clause. Otherwise, emit barrier when 'sections' directive has - // both firstprivate and lastprivate clause. - // Emit implicit barrier for linear clause. Maybe on somewhere else. - if (needBarrier()) - firOpBuilder.create(converter.getCurrentLocation()); -} - -void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) { - bool cmpCreated = false; - mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint(); - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { - if (std::get_if(&clause.u)) { - // TODO: Add lastprivate support for simd construct - if (mlir::isa(op)) { - if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) { - // For `omp.sections`, lastprivatized variables occur in - // lexically final `omp.section` operation. The following FIR - // shall be generated for the same: - // - // omp.sections lastprivate(...) { - // omp.section {...} - // omp.section {...} - // omp.section { - // fir.allocate for `private`/`firstprivate` - // - // fir.if %true { - // ^%lpv_update_blk - // } - // } - // } - // - // To keep code consistency while handling privatization - // through this control flow, add a `fir.if` operation - // that always evaluates to true, in order to create - // a dedicated sub-region in `omp.section` where - // lastprivate FIR can reside. Later canonicalizations - // will optimize away this operation. - if (!eval.lowerAsUnstructured()) { - auto ifOp = firOpBuilder.create( - op->getLoc(), - firOpBuilder.createIntegerConstant( - op->getLoc(), firOpBuilder.getIntegerType(1), 0x1), - /*else*/ false); - firOpBuilder.setInsertionPointToStart( - &ifOp.getThenRegion().front()); - - const Fortran::parser::OpenMPConstruct *parentOmpConstruct = - eval.parentConstruct->getIf(); - assert(parentOmpConstruct && - "Expected a valid enclosing OpenMP construct"); - const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct = - std::get_if( - &parentOmpConstruct->u); - assert(sectionsConstruct && - "Expected an enclosing omp.sections construct"); - const Fortran::parser::OmpClauseList §ionsEndClauseList = - std::get( - std::get( - sectionsConstruct->t) - .t); - for (const Fortran::parser::OmpClause &otherClause : - sectionsEndClauseList.v) - if (std::get_if( - &otherClause.u)) - // Emit implicit barrier to synchronize threads and avoid data - // races on post-update of lastprivate variables when `nowait` - // clause is present. - firOpBuilder.create( - converter.getCurrentLocation()); - firOpBuilder.setInsertionPointToStart( - &ifOp.getThenRegion().front()); - lastPrivIP = firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPoint(ifOp); - insPt = firOpBuilder.saveInsertionPoint(); - } else { - // Lastprivate operation is inserted at the end - // of the lexically last section in the sections - // construct - mlir::OpBuilder::InsertPoint unstructuredSectionsIP = - firOpBuilder.saveInsertionPoint(); - mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); - firOpBuilder.setInsertionPoint(lastOper); - lastPrivIP = firOpBuilder.saveInsertionPoint(); - firOpBuilder.restoreInsertionPoint(unstructuredSectionsIP); - } - } - } else if (mlir::isa(op)) { - // Update the original variable just before exiting the worksharing - // loop. Conversion as follows: - // - // omp.wsloop { - // omp.wsloop { ... - // ... store - // store ===> %v = arith.addi %iv, %step - // omp.yield %cmp = %step < 0 ? %v < %ub : %v > %ub - // } fir.if %cmp { - // fir.store %v to %loopIV - // ^%lpv_update_blk: - // } - // omp.yield - // } - // - - // Only generate the compare once in presence of multiple LastPrivate - // clauses. - if (cmpCreated) - continue; - cmpCreated = true; - - mlir::Location loc = op->getLoc(); - mlir::Operation *lastOper = op->getRegion(0).back().getTerminator(); - firOpBuilder.setInsertionPoint(lastOper); - - mlir::Value iv = op->getRegion(0).front().getArguments()[0]; - mlir::Value ub = - mlir::dyn_cast(op).getUpperBound()[0]; - mlir::Value step = mlir::dyn_cast(op).getStep()[0]; - - // v = iv + step - // cmp = step < 0 ? v < ub : v > ub - mlir::Value v = firOpBuilder.create(loc, iv, step); - mlir::Value zero = - firOpBuilder.createIntegerConstant(loc, step.getType(), 0); - mlir::Value negativeStep = firOpBuilder.create( - loc, mlir::arith::CmpIPredicate::slt, step, zero); - mlir::Value vLT = firOpBuilder.create( - loc, mlir::arith::CmpIPredicate::slt, v, ub); - mlir::Value vGT = firOpBuilder.create( - loc, mlir::arith::CmpIPredicate::sgt, v, ub); - mlir::Value cmpOp = firOpBuilder.create( - loc, negativeStep, vLT, vGT); - - auto ifOp = firOpBuilder.create(loc, cmpOp, /*else*/ false); - firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - assert(loopIV && "loopIV was not set"); - firOpBuilder.create(op->getLoc(), v, loopIV); - lastPrivIP = firOpBuilder.saveInsertionPoint(); - } else { - TODO(converter.getCurrentLocation(), - "lastprivate clause in constructs other than " - "simd/worksharing-loop"); - } - } - } - firOpBuilder.restoreInsertionPoint(localInsPt); -} - -void DataSharingProcessor::collectSymbols( - Fortran::semantics::Symbol::Flag flag) { - converter.collectSymbolSet(eval, defaultSymbols, flag, - /*collectSymbols=*/true, - /*collectHostAssociatedSymbols=*/true); - for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) { - if (e.hasNestedEvaluations()) - converter.collectSymbolSet(e, symbolsInNestedRegions, flag, - /*collectSymbols=*/true, - /*collectHostAssociatedSymbols=*/false); - else - converter.collectSymbolSet(e, symbolsInParentRegions, flag, - /*collectSymbols=*/false, - /*collectHostAssociatedSymbols=*/true); - } -} - -void DataSharingProcessor::collectDefaultSymbols() { - for (const Fortran::parser::OmpClause &clause : opClauseList.v) { - if (const auto &defaultClause = - std::get_if(&clause.u)) { - if (defaultClause->v.v == - Fortran::parser::OmpDefaultClause::Type::Private) - collectSymbols(Fortran::semantics::Symbol::Flag::OmpPrivate); - else if (defaultClause->v.v == - Fortran::parser::OmpDefaultClause::Type::Firstprivate) - collectSymbols(Fortran::semantics::Symbol::Flag::OmpFirstPrivate); - } - } -} - -void DataSharingProcessor::privatize() { - for (const Fortran::semantics::Symbol *sym : privatizedSymbols) { - if (const auto *commonDet = - sym->detailsIf()) { - for (const auto &mem : commonDet->objects()) { - cloneSymbol(&*mem); - copyFirstPrivateSymbol(&*mem); - } - } else { - cloneSymbol(sym); - copyFirstPrivateSymbol(sym); - } - } -} - -void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) { - insertLastPrivateCompare(op); - for (const Fortran::semantics::Symbol *sym : privatizedSymbols) - if (const auto *commonDet = - sym->detailsIf()) { - for (const auto &mem : commonDet->objects()) { - copyLastPrivateSymbol(&*mem, &lastPrivIP); - } - } else { - copyLastPrivateSymbol(sym, &lastPrivIP); - } -} - -void DataSharingProcessor::defaultPrivatize() { - for (const Fortran::semantics::Symbol *sym : defaultSymbols) { - if (!Fortran::semantics::IsProcedure(*sym) && - !sym->GetUltimate().has() && - !sym->GetUltimate().has() && - !symbolsInNestedRegions.contains(sym) && - !symbolsInParentRegions.contains(sym) && - !privatizedSymbols.contains(sym)) { - cloneSymbol(sym); - copyFirstPrivateSymbol(sym); - } - } -} - -//===----------------------------------------------------------------------===// -// ClauseProcessor -//===----------------------------------------------------------------------===// - -/// Class that handles the processing of OpenMP clauses. -/// -/// Its `process()` methods perform MLIR code generation for their -/// corresponding clause if it is present in the clause list. Otherwise, they -/// will return `false` to signal that the clause was not found. -/// -/// The intended use is of this class is to move clause processing outside of -/// construct processing, since the same clauses can appear attached to -/// different constructs and constructs can be combined, so that code -/// duplication is minimized. -/// -/// Each construct-lowering function only calls the `process()` -/// methods that relate to clauses that can impact the lowering of that -/// construct. -class ClauseProcessor { - using ClauseTy = Fortran::parser::OmpClause; - -public: - ClauseProcessor(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses) - : converter(converter), semaCtx(semaCtx), clauses(clauses) {} - - // 'Unique' clauses: They can appear at most once in the clause list. - bool - processCollapse(mlir::Location currentLocation, - Fortran::lower::pft::Evaluation &eval, - llvm::SmallVectorImpl &lowerBound, - llvm::SmallVectorImpl &upperBound, - llvm::SmallVectorImpl &step, - llvm::SmallVectorImpl &iv, - std::size_t &loopVarTypeSize) const; - bool processDefault() const; - bool processDevice(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const; - bool processFinal(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processHint(mlir::IntegerAttr &result) const; - bool processMergeable(mlir::UnitAttr &result) const; - bool processNowait(mlir::UnitAttr &result) const; - bool processNumTeams(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processNumThreads(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processOrdered(mlir::IntegerAttr &result) const; - bool processPriority(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const; - bool processSafelen(mlir::IntegerAttr &result) const; - bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr, - mlir::omp::ScheduleModifierAttr &modifierAttr, - mlir::UnitAttr &simdModifierAttr) const; - bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processSimdlen(mlir::IntegerAttr &result) const; - bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processUntied(mlir::UnitAttr &result) const; - - // 'Repeatable' clauses: They can appear multiple times in the clause list. - bool - processAllocate(llvm::SmallVectorImpl &allocatorOperands, - llvm::SmallVectorImpl &allocateOperands) const; - bool processCopyin() const; - bool processDepend(llvm::SmallVectorImpl &dependTypeOperands, - llvm::SmallVectorImpl &dependOperands) const; - bool - processEnter(llvm::SmallVectorImpl &result) const; - bool - processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Value &result) const; - bool - processLink(llvm::SmallVectorImpl &result) const; - - // This method is used to process a map clause. - // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to - // store the original type, location and Fortran symbol for the map operands. - // They may be used later on to create the block_arguments for some of the - // target directives that require it. - bool processMap(mlir::Location currentLocation, - const llvm::omp::Directive &directive, - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands, - llvm::SmallVectorImpl *mapSymTypes = nullptr, - llvm::SmallVectorImpl *mapSymLocs = nullptr, - llvm::SmallVectorImpl - *mapSymbols = nullptr) const; - bool - processReduction(mlir::Location currentLocation, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols, - llvm::SmallVectorImpl - *reductionSymbols = nullptr) const; - bool processSectionsReduction(mlir::Location currentLocation) const; - bool processTo(llvm::SmallVectorImpl &result) const; - bool - processUseDeviceAddr(llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl - &useDeviceSymbols) const; - bool - processUseDevicePtr(llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl - &useDeviceSymbols) const; - - template - bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands); - - // Call this method for these clauses that should be supported but are not - // implemented yet. It triggers a compilation error if any of the given - // clauses is found. - template - void processTODO(mlir::Location currentLocation, - llvm::omp::Directive directive) const; - -private: - using ClauseIterator = std::list::const_iterator; - - /// Utility to find a clause within a range in the clause list. - template - static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end) { - for (ClauseIterator it = begin; it != end; ++it) { - if (std::get_if(&it->u)) - return it; - } - - return end; - } - - /// Return the first instance of the given clause found in the clause list or - /// `nullptr` if not present. If more than one instance is expected, use - /// `findRepeatableClause` instead. - template - const T * - findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const { - ClauseIterator it = findClause(clauses.v.begin(), clauses.v.end()); - if (it != clauses.v.end()) { - if (source) - *source = &it->source; - return &std::get(it->u); - } - return nullptr; - } - - /// Call `callbackFn` for each occurrence of the given clause. Return `true` - /// if at least one instance was found. - template - bool findRepeatableClause( - std::function - callbackFn) const { - bool found = false; - ClauseIterator nextIt, endIt = clauses.v.end(); - for (ClauseIterator it = clauses.v.begin(); it != endIt; it = nextIt) { - nextIt = findClause(it, endIt); - - if (nextIt != endIt) { - callbackFn(&std::get(nextIt->u), nextIt->source); - found = true; - ++nextIt; - } - } - return found; - } - - /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. - template - bool markClauseOccurrence(mlir::UnitAttr &result) const { - if (findUniqueClause()) { - result = converter.getFirOpBuilder().getUnitAttr(); - return true; - } - return false; - } - - Fortran::lower::AbstractConverter &converter; - Fortran::semantics::SemanticsContext &semaCtx; - const Fortran::parser::OmpClauseList &clauses; -}; - -//===----------------------------------------------------------------------===// -// ClauseProcessor helper functions -//===----------------------------------------------------------------------===// - -/// Check for unsupported map operand types. -static void checkMapType(mlir::Location location, mlir::Type type) { - if (auto refType = type.dyn_cast()) - type = refType.getElementType(); - if (auto boxType = type.dyn_cast_or_null()) - if (!boxType.getElementType().isa()) - TODO(location, "OMPD_target_data MapOperand BoxType"); -} - -class ReductionProcessor { -public: - // TODO: Move this enumeration to the OpenMP dialect - enum ReductionIdentifier { - ID, - USER_DEF_OP, - ADD, - SUBTRACT, - MULTIPLY, - AND, - OR, - EQV, - NEQV, - MAX, - MIN, - IAND, - IOR, - IEOR - }; - static ReductionIdentifier - getReductionType(const Fortran::parser::ProcedureDesignator &pd) { - auto redType = llvm::StringSwitch>( - getRealName(pd).ToString()) - .Case("max", ReductionIdentifier::MAX) - .Case("min", ReductionIdentifier::MIN) - .Case("iand", ReductionIdentifier::IAND) - .Case("ior", ReductionIdentifier::IOR) - .Case("ieor", ReductionIdentifier::IEOR) - .Default(std::nullopt); - assert(redType && "Invalid Reduction"); - return *redType; - } - - static ReductionIdentifier getReductionType( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - return ReductionIdentifier::ADD; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: - return ReductionIdentifier::SUBTRACT; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - return ReductionIdentifier::MULTIPLY; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - return ReductionIdentifier::AND; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - return ReductionIdentifier::EQV; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - return ReductionIdentifier::OR; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: - return ReductionIdentifier::NEQV; - default: - llvm_unreachable("unexpected intrinsic operator in reduction"); - } - } - - static bool supportedIntrinsicProcReduction( - const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - if (!name->symbol->GetUltimate().attrs().test( - Fortran::semantics::Attr::INTRINSIC)) - return false; - auto redType = llvm::StringSwitch(getRealName(name).ToString()) - .Case("max", true) - .Case("min", true) - .Case("iand", true) - .Case("ior", true) - .Case("ieor", true) - .Default(false); - return redType; - } - - static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::Name *name) { - return name->symbol->GetUltimate().name(); - } - - static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - return getRealName(name); - } - - static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { - return (llvm::Twine(name) + - (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + - llvm::Twine(ty.getIntOrFloatBitWidth())) - .str(); - } - - static std::string getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty) { - std::string reductionName; - - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - reductionName = "add_reduction"; - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - reductionName = "multiply_reduction"; - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - return "and_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - return "eqv_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - return "or_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: - return "neqv_reduction"; - default: - reductionName = "other_reduction"; - break; - } - - return getReductionName(reductionName, ty); - } - - /// This function returns the identity value of the operator \p - /// reductionOpName. For example: - /// 0 + x = x, - /// 1 * x = x - static int getOperationIdentity(ReductionIdentifier redId, - mlir::Location loc) { - switch (redId) { - case ReductionIdentifier::ADD: - case ReductionIdentifier::OR: - case ReductionIdentifier::NEQV: - return 0; - case ReductionIdentifier::MULTIPLY: - case ReductionIdentifier::AND: - case ReductionIdentifier::EQV: - return 1; - default: - TODO(loc, "Reduction of some intrinsic operators is not supported"); - } - } - - static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type, - ReductionIdentifier redId, - fir::FirOpBuilder &builder) { - assert((fir::isa_integer(type) || fir::isa_real(type) || - type.isa()) && - "only integer, logical and real types are currently supported"); - switch (redId) { - case ReductionIdentifier::MAX: { - if (auto ty = type.dyn_cast()) { - const llvm::fltSemantics &sem = ty.getFloatSemantics(); - return builder.createRealConstant( - loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); - } - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, minInt); - } - case ReductionIdentifier::MIN: { - if (auto ty = type.dyn_cast()) { - const llvm::fltSemantics &sem = ty.getFloatSemantics(); - return builder.createRealConstant( - loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); - } - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, maxInt); - } - case ReductionIdentifier::IOR: { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, zeroInt); - } - case ReductionIdentifier::IEOR: { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, zeroInt); - } - case ReductionIdentifier::IAND: { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, allOnInt); - } - case ReductionIdentifier::ADD: - case ReductionIdentifier::MULTIPLY: - case ReductionIdentifier::AND: - case ReductionIdentifier::OR: - case ReductionIdentifier::EQV: - case ReductionIdentifier::NEQV: - if (type.isa()) - return builder.create( - loc, type, - builder.getFloatAttr(type, - (double)getOperationIdentity(redId, loc))); - - if (type.isa()) { - mlir::Value intConst = builder.create( - loc, builder.getI1Type(), - builder.getIntegerAttr(builder.getI1Type(), - getOperationIdentity(redId, loc))); - return builder.createConvert(loc, type, intConst); - } - - return builder.create( - loc, type, - builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); - case ReductionIdentifier::ID: - case ReductionIdentifier::USER_DEF_OP: - case ReductionIdentifier::SUBTRACT: - TODO(loc, "Reduction of some identifier types is not supported"); - } - llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); - } - - template - static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, - mlir::Type type, mlir::Location loc, - mlir::Value op1, mlir::Value op2) { - assert(type.isIntOrIndexOrFloat() && - "only integer and float types are currently supported"); - if (type.isIntOrIndex()) - return builder.create(loc, op1, op2); - return builder.create(loc, op1, op2); - } - - static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder, - mlir::Location loc, - ReductionIdentifier redId, - mlir::Type type, mlir::Value op1, - mlir::Value op2) { - mlir::Value reductionOp; - switch (redId) { - case ReductionIdentifier::MAX: - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - break; - case ReductionIdentifier::MIN: - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - break; - case ReductionIdentifier::IOR: - assert((type.isIntOrIndex()) && "only integer is expected"); - reductionOp = builder.create(loc, op1, op2); - break; - case ReductionIdentifier::IEOR: - assert((type.isIntOrIndex()) && "only integer is expected"); - reductionOp = builder.create(loc, op1, op2); - break; - case ReductionIdentifier::IAND: - assert((type.isIntOrIndex()) && "only integer is expected"); - reductionOp = builder.create(loc, op1, op2); - break; - case ReductionIdentifier::ADD: - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - break; - case ReductionIdentifier::MULTIPLY: - reductionOp = - getReductionOperation( - builder, type, loc, op1, op2); - break; - case ReductionIdentifier::AND: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - - mlir::Value andiOp = - builder.create(loc, op1I1, op2I1); - - reductionOp = builder.createConvert(loc, type, andiOp); - break; - } - case ReductionIdentifier::OR: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - - mlir::Value oriOp = builder.create(loc, op1I1, op2I1); - - reductionOp = builder.createConvert(loc, type, oriOp); - break; - } - case ReductionIdentifier::EQV: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - - mlir::Value cmpiOp = builder.create( - loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); - - reductionOp = builder.createConvert(loc, type, cmpiOp); - break; - } - case ReductionIdentifier::NEQV: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - - mlir::Value cmpiOp = builder.create( - loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); - - reductionOp = builder.createConvert(loc, type, cmpiOp); - break; - } - default: - TODO(loc, "Reduction of some intrinsic operators is not supported"); - } - - return reductionOp; - } - - /// Creates an OpenMP reduction declaration and inserts it into the provided - /// symbol table. The declaration has a constant initializer with the neutral - /// value `initValue`, and the reduction combiner carried over from `reduce`. - /// TODO: Generalize this for non-integer types, add atomic region. - static mlir::omp::ReductionDeclareOp createReductionDecl( - fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, - const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) { - mlir::OpBuilder::InsertionGuard guard(builder); - mlir::ModuleOp module = builder.getModule(); - - auto decl = - module.lookupSymbol(reductionOpName); - if (decl) - return decl; - - mlir::OpBuilder modBuilder(module.getBodyRegion()); - - decl = modBuilder.create( - loc, reductionOpName, type); - builder.createBlock(&decl.getInitializerRegion(), - decl.getInitializerRegion().end(), {type}, {loc}); - builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); - mlir::Value init = getReductionInitValue(loc, type, redId, builder); - builder.create(loc, init); - - builder.createBlock(&decl.getReductionRegion(), - decl.getReductionRegion().end(), {type, type}, - {loc, loc}); - - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); - mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); - mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - - mlir::Value reductionOp = - createScalarCombiner(builder, loc, redId, type, op1, op2); - builder.create(loc, reductionOp); - - return decl; - } - - /// Creates a reduction declaration and associates it with an OpenMP block - /// directive. - static void - addReductionDecl(mlir::Location currentLocation, - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols, - llvm::SmallVectorImpl - *reductionSymbols = nullptr) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::omp::ReductionDeclareOp decl; - const auto &redOperator{ - std::get(reduction.t)}; - const auto &objectList{ - std::get(reduction.t)}; - if (const auto &redDefinedOp = - std::get_if(&redOperator.u)) { - const auto &intrinsicOp{ - std::get( - redDefinedOp->u)}; - ReductionIdentifier redId = getReductionType(intrinsicOp); - switch (redId) { - case ReductionIdentifier::ADD: - case ReductionIdentifier::MULTIPLY: - case ReductionIdentifier::AND: - case ReductionIdentifier::EQV: - case ReductionIdentifier::OR: - case ReductionIdentifier::NEQV: - break; - default: - TODO(currentLocation, - "Reduction of some intrinsic operators is not supported"); - break; - } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - if (reductionSymbols) - reductionSymbols->push_back(symbol); - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast().getEleTy(); - reductionVars.push_back(symVal); - if (redType.isa()) - decl = createReductionDecl( - firOpBuilder, - getReductionName(intrinsicOp, firOpBuilder.getI1Type()), - redId, redType, currentLocation); - else if (redType.isIntOrIndexOrFloat()) { - decl = createReductionDecl(firOpBuilder, - getReductionName(intrinsicOp, redType), - redId, redType, currentLocation); - } else { - TODO(currentLocation, "Reduction of some types is not supported"); - } - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } else if (const auto *reductionIntrinsic = - std::get_if( - &redOperator.u)) { - if (ReductionProcessor::supportedIntrinsicProcReduction( - *reductionIntrinsic)) { - ReductionProcessor::ReductionIdentifier redId = - ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - if (reductionSymbols) - reductionSymbols->push_back(symbol); - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast().getEleTy(); - reductionVars.push_back(symVal); - assert(redType.isIntOrIndexOrFloat() && - "Unsupported reduction type"); - decl = createReductionDecl( - firOpBuilder, - getReductionName(getRealName(*reductionIntrinsic).ToString(), - redType), - redId, redType, currentLocation); - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } - } - } -}; - -static mlir::omp::ScheduleModifier -translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) { - switch (m.v) { - case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic: - return mlir::omp::ScheduleModifier::monotonic; - case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic: - return mlir::omp::ScheduleModifier::nonmonotonic; - case Fortran::parser::OmpScheduleModifierType::ModType::Simd: - return mlir::omp::ScheduleModifier::simd; - } - return mlir::omp::ScheduleModifier::none; -} - -static mlir::omp::ScheduleModifier -getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) { - const auto &modifier = - std::get>(x.t); - // The input may have the modifier any order, so we look for one that isn't - // SIMD. If modifier is not set at all, fall down to the bottom and return - // "none". - if (modifier) { - const auto &modType1 = - std::get(modifier->t); - if (modType1.v.v == - Fortran::parser::OmpScheduleModifierType::ModType::Simd) { - const auto &modType2 = std::get< - std::optional>( - modifier->t); - if (modType2 && - modType2->v.v != - Fortran::parser::OmpScheduleModifierType::ModType::Simd) - return translateScheduleModifier(modType2->v); - - return mlir::omp::ScheduleModifier::none; - } - - return translateScheduleModifier(modType1.v); - } - return mlir::omp::ScheduleModifier::none; -} - -static mlir::omp::ScheduleModifier -getSimdModifier(const Fortran::parser::OmpScheduleClause &x) { - const auto &modifier = - std::get>(x.t); - // Either of the two possible modifiers in the input can be the SIMD modifier, - // so look in either one, and return simd if we find one. Not found = return - // "none". - if (modifier) { - const auto &modType1 = - std::get(modifier->t); - if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd) - return mlir::omp::ScheduleModifier::simd; - - const auto &modType2 = std::get< - std::optional>( - modifier->t); - if (modType2 && modType2->v.v == - Fortran::parser::OmpScheduleModifierType::ModType::Simd) - return mlir::omp::ScheduleModifier::simd; - } - return mlir::omp::ScheduleModifier::none; -} - -static void -genAllocateClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpAllocateClause &ompAllocateClause, - llvm::SmallVectorImpl &allocatorOperands, - llvm::SmallVectorImpl &allocateOperands) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - Fortran::lower::StatementContext stmtCtx; - - mlir::Value allocatorOperand; - const Fortran::parser::OmpObjectList &ompObjectList = - std::get(ompAllocateClause.t); - const auto &allocateModifier = std::get< - std::optional>( - ompAllocateClause.t); - - // If the allocate modifier is present, check if we only use the allocator - // submodifier. ALIGN in this context is unimplemented - const bool onlyAllocator = - allocateModifier && - std::holds_alternative< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); - - if (allocateModifier && !onlyAllocator) { - TODO(currentLocation, "OmpAllocateClause ALIGN modifier"); - } - - // Check if allocate clause has allocator specified. If so, add it - // to list of allocators, otherwise, add default allocator to - // list of allocators. - if (onlyAllocator) { - const auto &allocatorValue = std::get< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); - allocatorOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx)); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); - } else { - allocatorOperand = firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getI32Type(), 1); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); - } - genObjectList(ompObjectList, converter, allocateOperands); -} - -static mlir::omp::ClauseProcBindKindAttr genProcBindKindAttr( - fir::FirOpBuilder &firOpBuilder, - const Fortran::parser::OmpClause::ProcBind *procBindClause) { - mlir::omp::ClauseProcBindKind procBindKind; - switch (procBindClause->v.v) { - case Fortran::parser::OmpProcBindClause::Type::Master: - procBindKind = mlir::omp::ClauseProcBindKind::Master; - break; - case Fortran::parser::OmpProcBindClause::Type::Close: - procBindKind = mlir::omp::ClauseProcBindKind::Close; - break; - case Fortran::parser::OmpProcBindClause::Type::Spread: - procBindKind = mlir::omp::ClauseProcBindKind::Spread; - break; - case Fortran::parser::OmpProcBindClause::Type::Primary: - procBindKind = mlir::omp::ClauseProcBindKind::Primary; - break; - } - return mlir::omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), - procBindKind); -} - -static mlir::omp::ClauseTaskDependAttr -genDependKindAttr(fir::FirOpBuilder &firOpBuilder, - const Fortran::parser::OmpClause::Depend *dependClause) { - mlir::omp::ClauseTaskDepend pbKind; - switch ( - std::get( - std::get(dependClause->v.u) - .t) - .v) { - case Fortran::parser::OmpDependenceType::Type::In: - pbKind = mlir::omp::ClauseTaskDepend::taskdependin; - break; - case Fortran::parser::OmpDependenceType::Type::Out: - pbKind = mlir::omp::ClauseTaskDepend::taskdependout; - break; - case Fortran::parser::OmpDependenceType::Type::Inout: - pbKind = mlir::omp::ClauseTaskDepend::taskdependinout; - break; - default: - llvm_unreachable("unknown parser task dependence type"); - break; - } - return mlir::omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(), - pbKind); -} - -static mlir::Value getIfClauseOperand( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClause::If *ifClause, - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Location clauseLocation) { - // Only consider the clause if it's intended for the given directive. - auto &directive = std::get< - std::optional>( - ifClause->v.t); - if (directive && directive.value() != directiveName) - return nullptr; - - Fortran::lower::StatementContext stmtCtx; - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto &expr = std::get(ifClause->v.t); - mlir::Value ifVal = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); - return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), - ifVal); -} - -static void -addUseDeviceClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpObjectList &useDeviceClause, - llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl - &useDeviceSymbols) { - genObjectList(useDeviceClause, converter, operands); - for (mlir::Value &operand : operands) { - checkMapType(operand.getLoc(), operand.getType()); - useDeviceTypes.push_back(operand.getType()); - useDeviceLocs.push_back(operand.getLoc()); - } - for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - useDeviceSymbols.push_back(sym); - } -} - -//===----------------------------------------------------------------------===// -// ClauseProcessor unique clauses -//===----------------------------------------------------------------------===// - -bool ClauseProcessor::processCollapse( - mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, - llvm::SmallVectorImpl &lowerBound, - llvm::SmallVectorImpl &upperBound, - llvm::SmallVectorImpl &step, - llvm::SmallVectorImpl &iv, - std::size_t &loopVarTypeSize) const { - bool found = false; - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - // Collect the loops to collapse. - Fortran::lower::pft::Evaluation *doConstructEval = - &eval.getFirstNestedEvaluation(); - if (doConstructEval->getIf() - ->IsDoConcurrent()) { - TODO(currentLocation, "Do Concurrent in Worksharing loop construct"); - } - - std::int64_t collapseValue = 1l; - if (auto *collapseClause = findUniqueClause()) { - const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); - collapseValue = Fortran::evaluate::ToInt64(*expr).value(); - found = true; - } - - loopVarTypeSize = 0; - do { - Fortran::lower::pft::Evaluation *doLoop = - &doConstructEval->getFirstNestedEvaluation(); - auto *doStmt = doLoop->getIf(); - assert(doStmt && "Expected do loop to be in the nested evaluation"); - const auto &loopControl = - std::get>(doStmt->t); - const Fortran::parser::LoopControl::Bounds *bounds = - std::get_if(&loopControl->u); - assert(bounds && "Expected bounds for worksharing do loop"); - Fortran::lower::StatementContext stmtCtx; - lowerBound.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); - upperBound.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); - if (bounds->step) { - step.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); - } else { // If `step` is not present, assume it as `1`. - step.push_back(firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getIntegerType(32), 1)); - } - iv.push_back(bounds->name.thing.symbol); - loopVarTypeSize = std::max(loopVarTypeSize, - bounds->name.thing.symbol->GetUltimate().size()); - collapseValue--; - doConstructEval = - &*std::next(doConstructEval->getNestedEvaluations().begin()); - } while (collapseValue > 0); - - return found; -} - -bool ClauseProcessor::processDefault() const { - if (auto *defaultClause = findUniqueClause()) { - // Private, Firstprivate, Shared, None - switch (defaultClause->v.v) { - case Fortran::parser::OmpDefaultClause::Type::Shared: - case Fortran::parser::OmpDefaultClause::Type::None: - // Default clause with shared or none do not require any handling since - // Shared is the default behavior in the IR and None is only required - // for semantic checks. - break; - case Fortran::parser::OmpDefaultClause::Type::Private: - // TODO Support default(private) - break; - case Fortran::parser::OmpDefaultClause::Type::Firstprivate: - // TODO Support default(firstprivate) - break; - } - return true; - } - return false; -} - -bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { - const Fortran::parser::CharBlock *source = nullptr; - if (auto *deviceClause = findUniqueClause(&source)) { - mlir::Location clauseLocation = converter.genLocation(*source); - if (auto deviceModifier = std::get< - std::optional>( - deviceClause->v.t)) { - if (deviceModifier == - Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) { - TODO(clauseLocation, "OMPD_target Device Modifier Ancestor"); - } - } - if (const auto *deviceExpr = Fortran::semantics::GetExpr( - std::get(deviceClause->v.t))) { - result = fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx)); - } - return true; - } - return false; -} - -bool ClauseProcessor::processDeviceType( - mlir::omp::DeclareTargetDeviceType &result) const { - if (auto *deviceTypeClause = findUniqueClause()) { - // Case: declare target ... device_type(any | host | nohost) - switch (deviceTypeClause->v.v) { - case Fortran::parser::OmpDeviceTypeClause::Type::Nohost: - result = mlir::omp::DeclareTargetDeviceType::nohost; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Host: - result = mlir::omp::DeclareTargetDeviceType::host; - break; - case Fortran::parser::OmpDeviceTypeClause::Type::Any: - result = mlir::omp::DeclareTargetDeviceType::any; - break; - } - return true; - } - return false; -} - -bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { - const Fortran::parser::CharBlock *source = nullptr; - if (auto *finalClause = findUniqueClause(&source)) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location clauseLocation = converter.genLocation(*source); - - mlir::Value finalVal = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(finalClause->v), stmtCtx)); - result = firOpBuilder.createConvert(clauseLocation, - firOpBuilder.getI1Type(), finalVal); - return true; - } - return false; -} - -bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const { - if (auto *hintClause = findUniqueClause()) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - const auto *expr = Fortran::semantics::GetExpr(hintClause->v); - int64_t hintValue = *Fortran::evaluate::ToInt64(*expr); - result = firOpBuilder.getI64IntegerAttr(hintValue); - return true; - } - return false; -} - -bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); -} - -bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); -} - -bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { - // TODO Get lower and upper bounds for num_teams when parser is updated to - // accept both. - if (auto *numTeamsClause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numTeamsClause->v), stmtCtx)); - return true; - } - return false; -} - -bool ClauseProcessor::processNumThreads( - Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { - if (auto *numThreadsClause = findUniqueClause()) { - // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. - result = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); - return true; - } - return false; -} - -bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const { - if (auto *orderedClause = findUniqueClause()) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - int64_t orderedClauseValue = 0l; - if (orderedClause->v.has_value()) { - const auto *expr = Fortran::semantics::GetExpr(orderedClause->v); - orderedClauseValue = *Fortran::evaluate::ToInt64(*expr); - } - result = firOpBuilder.getI64IntegerAttr(orderedClauseValue); - return true; - } - return false; -} - -bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { - if (auto *priorityClause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx)); - return true; - } - return false; -} - -bool ClauseProcessor::processProcBind( - mlir::omp::ClauseProcBindKindAttr &result) const { - if (auto *procBindClause = findUniqueClause()) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - result = genProcBindKindAttr(firOpBuilder, procBindClause); - return true; - } - return false; -} - -bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const { - if (auto *safelenClause = findUniqueClause()) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - const auto *expr = Fortran::semantics::GetExpr(safelenClause->v); - const std::optional safelenVal = - Fortran::evaluate::ToInt64(*expr); - result = firOpBuilder.getI64IntegerAttr(*safelenVal); - return true; - } - return false; -} - -bool ClauseProcessor::processSchedule( - mlir::omp::ClauseScheduleKindAttr &valAttr, - mlir::omp::ScheduleModifierAttr &modifierAttr, - mlir::UnitAttr &simdModifierAttr) const { - if (auto *scheduleClause = findUniqueClause()) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::MLIRContext *context = firOpBuilder.getContext(); - const Fortran::parser::OmpScheduleClause &scheduleType = scheduleClause->v; - const auto &scheduleClauseKind = - std::get( - scheduleType.t); - - mlir::omp::ClauseScheduleKind scheduleKind; - switch (scheduleClauseKind) { - case Fortran::parser::OmpScheduleClause::ScheduleType::Static: - scheduleKind = mlir::omp::ClauseScheduleKind::Static; - break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic: - scheduleKind = mlir::omp::ClauseScheduleKind::Dynamic; - break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Guided: - scheduleKind = mlir::omp::ClauseScheduleKind::Guided; - break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Auto: - scheduleKind = mlir::omp::ClauseScheduleKind::Auto; - break; - case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime: - scheduleKind = mlir::omp::ClauseScheduleKind::Runtime; - break; - } - - mlir::omp::ScheduleModifier scheduleModifier = - getScheduleModifier(scheduleClause->v); - - if (scheduleModifier != mlir::omp::ScheduleModifier::none) - modifierAttr = - mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier); - - if (getSimdModifier(scheduleClause->v) != mlir::omp::ScheduleModifier::none) - simdModifierAttr = firOpBuilder.getUnitAttr(); - - valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind); - return true; - } - return false; -} - -bool ClauseProcessor::processScheduleChunk( - Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { - if (auto *scheduleClause = findUniqueClause()) { - if (const auto &chunkExpr = - std::get>( - scheduleClause->v.t)) { - if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) { - result = fir::getBase(converter.genExprValue(*expr, stmtCtx)); - } - } - return true; - } - return false; -} - -bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const { - if (auto *simdlenClause = findUniqueClause()) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - const auto *expr = Fortran::semantics::GetExpr(simdlenClause->v); - const std::optional simdlenVal = - Fortran::evaluate::ToInt64(*expr); - result = firOpBuilder.getI64IntegerAttr(*simdlenVal); - return true; - } - return false; -} - -bool ClauseProcessor::processThreadLimit( - Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { - if (auto *threadLmtClause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx)); - return true; - } - return false; -} - -bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); -} - -//===----------------------------------------------------------------------===// -// ClauseProcessor repeatable clauses -//===----------------------------------------------------------------------===// - -bool ClauseProcessor::processAllocate( - llvm::SmallVectorImpl &allocatorOperands, - llvm::SmallVectorImpl &allocateOperands) const { - return findRepeatableClause( - [&](const ClauseTy::Allocate *allocateClause, - const Fortran::parser::CharBlock &) { - genAllocateClause(converter, allocateClause->v, allocatorOperands, - allocateOperands); - }); -} - -bool ClauseProcessor::processCopyin() const { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); - auto checkAndCopyHostAssociateVar = - [&](Fortran::semantics::Symbol *sym, - mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr) { - assert(sym->has() && - "No host-association found"); - if (converter.isPresentShallowLookup(*sym)) - converter.copyHostAssociateVar(*sym, copyAssignIP); - }; - bool hasCopyin = findRepeatableClause( - [&](const ClauseTy::Copyin *copyinClause, - const Fortran::parser::CharBlock &) { - const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - if (const auto *commonDetails = - sym->detailsIf()) { - for (const auto &mem : commonDetails->objects()) - checkAndCopyHostAssociateVar(&*mem, &insPt); - break; - } - if (Fortran::semantics::IsAllocatableOrObjectPointer( - &sym->GetUltimate())) - TODO(converter.getCurrentLocation(), - "pointer or allocatable variables in Copyin clause"); - assert(sym->has() && - "No host-association found"); - checkAndCopyHostAssociateVar(sym); - } - }); - - // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to - // the execution of the associated structured block. Emit implicit barrier to - // synchronize threads and avoid data races on propagation master's thread - // values of threadprivate variables to local instances of that variables of - // all other implicit threads. - if (hasCopyin) - firOpBuilder.create(converter.getCurrentLocation()); - firOpBuilder.restoreInsertionPoint(insPt); - return hasCopyin; -} - -bool ClauseProcessor::processDepend( - llvm::SmallVectorImpl &dependTypeOperands, - llvm::SmallVectorImpl &dependOperands) const { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - return findRepeatableClause( - [&](const ClauseTy::Depend *dependClause, - const Fortran::parser::CharBlock &) { - const std::list &depVal = - std::get>( - std::get( - dependClause->v.u) - .t); - mlir::omp::ClauseTaskDependAttr dependTypeOperand = - genDependKindAttr(firOpBuilder, dependClause); - dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(), - dependTypeOperand); - for (const Fortran::parser::Designator &ompObject : depVal) { - Fortran::semantics::Symbol *sym = nullptr; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::DataRef &designator) { - if (const Fortran::parser::Name *name = - std::get_if(&designator.u)) { - sym = name->symbol; - } else if (std::get_if>( - &designator.u)) { - TODO(converter.getCurrentLocation(), - "array sections not supported for task depend"); - } - }, - [&](const Fortran::parser::Substring &designator) { - TODO(converter.getCurrentLocation(), - "substring not supported for task depend"); - }}, - (ompObject).u); - const mlir::Value variable = converter.getSymbolAddress(*sym); - dependOperands.push_back(variable); - } - }); -} - -bool ClauseProcessor::processIf( - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Value &result) const { - bool found = false; - findRepeatableClause( - [&](const ClauseTy::If *ifClause, - const Fortran::parser::CharBlock &source) { - mlir::Location clauseLocation = converter.genLocation(source); - mlir::Value operand = getIfClauseOperand(converter, ifClause, - directiveName, clauseLocation); - // Assume that, at most, a single 'if' clause will be applicable to the - // given directive. - if (operand) { - result = operand; - found = true; - } - }); - return found; -} - -bool ClauseProcessor::processLink( - llvm::SmallVectorImpl &result) const { - return findRepeatableClause( - [&](const ClauseTy::Link *linkClause, - const Fortran::parser::CharBlock &) { - // Case: declare target link(var1, var2)... - gatherFuncAndVarSyms( - linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result); - }); -} - -static mlir::omp::MapInfoOp -createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name, - mlir::SmallVector bounds, - mlir::SmallVector members, uint64_t mapType, - mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, - bool isVal = false) { - if (auto boxTy = baseAddr.getType().dyn_cast()) { - baseAddr = builder.create(loc, baseAddr); - retTy = baseAddr.getType(); - } - - mlir::TypeAttr varType = mlir::TypeAttr::get( - llvm::cast(retTy).getElementType()); - - mlir::omp::MapInfoOp op = builder.create( - loc, retTy, baseAddr, varType, varPtrPtr, members, bounds, - builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), - builder.getAttr(mapCaptureType), - builder.getStringAttr(name)); - - return op; -} - -bool ClauseProcessor::processMap( - mlir::Location currentLocation, const llvm::omp::Directive &directive, - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands, - llvm::SmallVectorImpl *mapSymTypes, - llvm::SmallVectorImpl *mapSymLocs, - llvm::SmallVectorImpl *mapSymbols) - const { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause( - [&](const ClauseTy::Map *mapClause, - const Fortran::parser::CharBlock &source) { - mlir::Location clauseLocation = converter.genLocation(source); - const auto &oMapType = - std::get>( - mapClause->v.t); - llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; - // If the map type is specified, then process it else Tofrom is the - // default. - if (oMapType) { - const Fortran::parser::OmpMapType::Type &mapType = - std::get(oMapType->t); - switch (mapType) { - case Fortran::parser::OmpMapType::Type::To: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; - break; - case Fortran::parser::OmpMapType::Type::From: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - break; - case Fortran::parser::OmpMapType::Type::Tofrom: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - break; - case Fortran::parser::OmpMapType::Type::Alloc: - case Fortran::parser::OmpMapType::Type::Release: - // alloc and release is the default map_type for the Target Data - // Ops, i.e. if no bits for map_type is supplied then alloc/release - // is implicitly assumed based on the target directive. Default - // value for Target Data and Enter Data is alloc and for Exit Data - // it is release. - break; - case Fortran::parser::OmpMapType::Type::Delete: - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; - } - - if (std::get>( - oMapType->t)) - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; - } else { - mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - } - - for (const Fortran::parser::OmpObject &ompObject : - std::get(mapClause->v.t).v) { - llvm::SmallVector bounds; - std::stringstream asFortran; - - Fortran::lower::AddrAndBoundsInfo info = - Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, - mlir::omp::DataBoundsType>( - converter, firOpBuilder, semaCtx, stmtCtx, ompObject, - clauseLocation, asFortran, bounds, treatIndexAsSection); - - auto origSymbol = - converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); - mlir::Value symAddr = info.addr; - if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) - symAddr = origSymbol; - - // Explicit map captures are captured ByRef by default, - // optimisation passes may alter this to ByCopy or other capture - // types to optimise - mlir::Value mapOp = createMapInfoOp( - firOpBuilder, clauseLocation, symAddr, mlir::Value{}, - asFortran.str(), bounds, {}, - static_cast< - std::underlying_type_t>( - mapTypeBits), - mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); - - mapOperands.push_back(mapOp); - if (mapSymTypes) - mapSymTypes->push_back(symAddr.getType()); - if (mapSymLocs) - mapSymLocs->push_back(symAddr.getLoc()); - - if (mapSymbols) - mapSymbols->push_back(getOmpObjectSymbol(ompObject)); - } - }); -} - -bool ClauseProcessor::processReduction( - mlir::Location currentLocation, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionDeclSymbols, - llvm::SmallVectorImpl *reductionSymbols) - const { - return findRepeatableClause( - [&](const ClauseTy::Reduction *reductionClause, - const Fortran::parser::CharBlock &) { - ReductionProcessor rp; - rp.addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols, - reductionSymbols); - }); -} - -bool ClauseProcessor::processSectionsReduction( - mlir::Location currentLocation) const { - return findRepeatableClause( - [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) { - TODO(currentLocation, "OMPC_Reduction"); - }); -} - -bool ClauseProcessor::processTo( - llvm::SmallVectorImpl &result) const { - return findRepeatableClause( - [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) { - // Case: declare target to(func, var1, var2)... - gatherFuncAndVarSyms(toClause->v, - mlir::omp::DeclareTargetCaptureClause::to, result); - }); -} - -bool ClauseProcessor::processEnter( - llvm::SmallVectorImpl &result) const { - return findRepeatableClause( - [&](const ClauseTy::Enter *enterClause, - const Fortran::parser::CharBlock &) { - // Case: declare target enter(func, var1, var2)... - gatherFuncAndVarSyms(enterClause->v, - mlir::omp::DeclareTargetCaptureClause::enter, - result); - }); -} - -bool ClauseProcessor::processUseDeviceAddr( - llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSymbols) - const { - return findRepeatableClause( - [&](const ClauseTy::UseDeviceAddr *devAddrClause, - const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devAddrClause->v, operands, - useDeviceTypes, useDeviceLocs, useDeviceSymbols); - }); -} - -bool ClauseProcessor::processUseDevicePtr( - llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSymbols) - const { - return findRepeatableClause( - [&](const ClauseTy::UseDevicePtr *devPtrClause, - const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes, - useDeviceLocs, useDeviceSymbols); - }); -} - -template -bool ClauseProcessor::processMotionClauses( - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands) { - return findRepeatableClause( - [&](const T *motionClause, const Fortran::parser::CharBlock &source) { - mlir::Location clauseLocation = converter.genLocation(source); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - static_assert(std::is_same_v || - std::is_same_v); - - // TODO Support motion modifiers: present, mapper, iterator. - constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = - std::is_same_v - ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO - : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; - - for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) { - llvm::SmallVector bounds; - std::stringstream asFortran; - Fortran::lower::AddrAndBoundsInfo info = - Fortran::lower::gatherDataOperandAddrAndBounds< - Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, - mlir::omp::DataBoundsType>( - converter, firOpBuilder, semaCtx, stmtCtx, ompObject, - clauseLocation, asFortran, bounds, treatIndexAsSection); - - auto origSymbol = - converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); - mlir::Value symAddr = info.addr; - if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) - symAddr = origSymbol; - - // Explicit map captures are captured ByRef by default, - // optimisation passes may alter this to ByCopy or other capture - // types to optimise - mlir::Value mapOp = createMapInfoOp( - firOpBuilder, clauseLocation, symAddr, mlir::Value{}, - asFortran.str(), bounds, {}, - static_cast< - std::underlying_type_t>( - mapTypeBits), - mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); - - mapOperands.push_back(mapOp); - } - }); -} - -template -void ClauseProcessor::processTODO(mlir::Location currentLocation, - llvm::omp::Directive directive) const { - auto checkUnhandledClause = [&](const auto *x) { - if (!x) - return; - TODO(currentLocation, - "Unhandled clause " + - llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x)) - .upper() + - " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() + - " construct"); - }; - - for (ClauseIterator it = clauses.v.begin(); it != clauses.v.end(); ++it) - (checkUnhandledClause(std::get_if(&it->u)), ...); -} - -//===----------------------------------------------------------------------===// -// Code generation helper functions -//===----------------------------------------------------------------------===// - static fir::GlobalOp globalInitialization( Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, const Fortran::semantics::Symbol &sym, diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp new file mode 100644 index 0000000000000..a8b98f3f56724 --- /dev/null +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -0,0 +1,431 @@ +//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "ReductionProcessor.h" + +#include "flang/Lower/AbstractConverter.h" +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Parser/tools.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" + +namespace Fortran { +namespace lower { +namespace omp { + +ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( + const Fortran::parser::ProcedureDesignator &pd) { + auto redType = llvm::StringSwitch>( + ReductionProcessor::getRealName(pd).ToString()) + .Case("max", ReductionIdentifier::MAX) + .Case("min", ReductionIdentifier::MIN) + .Case("iand", ReductionIdentifier::IAND) + .Case("ior", ReductionIdentifier::IOR) + .Case("ieor", ReductionIdentifier::IEOR) + .Default(std::nullopt); + assert(redType && "Invalid Reduction"); + return *redType; +} + +ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + return ReductionIdentifier::ADD; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: + return ReductionIdentifier::SUBTRACT; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + return ReductionIdentifier::MULTIPLY; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + return ReductionIdentifier::AND; + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + return ReductionIdentifier::EQV; + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + return ReductionIdentifier::OR; + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + return ReductionIdentifier::NEQV; + default: + llvm_unreachable("unexpected intrinsic operator in reduction"); + } +} + +bool ReductionProcessor::supportedIntrinsicProcReduction( + const Fortran::parser::ProcedureDesignator &pd) { + const auto *name{Fortran::parser::Unwrap(pd)}; + assert(name && "Invalid Reduction Intrinsic."); + if (!name->symbol->GetUltimate().attrs().test( + Fortran::semantics::Attr::INTRINSIC)) + return false; + auto redType = llvm::StringSwitch(getRealName(name).ToString()) + .Case("max", true) + .Case("min", true) + .Case("iand", true) + .Case("ior", true) + .Case("ieor", true) + .Default(false); + return redType; +} + +std::string ReductionProcessor::getReductionName(llvm::StringRef name, + mlir::Type ty) { + return (llvm::Twine(name) + + (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + + llvm::Twine(ty.getIntOrFloatBitWidth())) + .str(); +} + +std::string ReductionProcessor::getReductionName( + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty) { + std::string reductionName; + + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + reductionName = "add_reduction"; + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + reductionName = "multiply_reduction"; + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + return "and_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + return "eqv_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + return "or_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + return "neqv_reduction"; + default: + reductionName = "other_reduction"; + break; + } + + return getReductionName(reductionName, ty); +} + +mlir::Value +ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, + ReductionIdentifier redId, + fir::FirOpBuilder &builder) { + assert((fir::isa_integer(type) || fir::isa_real(type) || + type.isa()) && + "only integer, logical and real types are currently supported"); + switch (redId) { + case ReductionIdentifier::MAX: { + if (auto ty = type.dyn_cast()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant( + loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); + } + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, minInt); + } + case ReductionIdentifier::MIN: { + if (auto ty = type.dyn_cast()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant( + loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); + } + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, maxInt); + } + case ReductionIdentifier::IOR: { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, zeroInt); + } + case ReductionIdentifier::IEOR: { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, zeroInt); + } + case ReductionIdentifier::IAND: { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, allOnInt); + } + case ReductionIdentifier::ADD: + case ReductionIdentifier::MULTIPLY: + case ReductionIdentifier::AND: + case ReductionIdentifier::OR: + case ReductionIdentifier::EQV: + case ReductionIdentifier::NEQV: + if (type.isa()) + return builder.create( + loc, type, + builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); + + if (type.isa()) { + mlir::Value intConst = builder.create( + loc, builder.getI1Type(), + builder.getIntegerAttr(builder.getI1Type(), + getOperationIdentity(redId, loc))); + return builder.createConvert(loc, type, intConst); + } + + return builder.create( + loc, type, + builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); + case ReductionIdentifier::ID: + case ReductionIdentifier::USER_DEF_OP: + case ReductionIdentifier::SUBTRACT: + TODO(loc, "Reduction of some identifier types is not supported"); + } + llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); +} + +mlir::Value ReductionProcessor::createScalarCombiner( + fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, + mlir::Type type, mlir::Value op1, mlir::Value op2) { + mlir::Value reductionOp; + switch (redId) { + case ReductionIdentifier::MAX: + reductionOp = + getReductionOperation( + builder, type, loc, op1, op2); + break; + case ReductionIdentifier::MIN: + reductionOp = + getReductionOperation( + builder, type, loc, op1, op2); + break; + case ReductionIdentifier::IOR: + assert((type.isIntOrIndex()) && "only integer is expected"); + reductionOp = builder.create(loc, op1, op2); + break; + case ReductionIdentifier::IEOR: + assert((type.isIntOrIndex()) && "only integer is expected"); + reductionOp = builder.create(loc, op1, op2); + break; + case ReductionIdentifier::IAND: + assert((type.isIntOrIndex()) && "only integer is expected"); + reductionOp = builder.create(loc, op1, op2); + break; + case ReductionIdentifier::ADD: + reductionOp = + getReductionOperation( + builder, type, loc, op1, op2); + break; + case ReductionIdentifier::MULTIPLY: + reductionOp = + getReductionOperation( + builder, type, loc, op1, op2); + break; + case ReductionIdentifier::AND: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + + mlir::Value andiOp = builder.create(loc, op1I1, op2I1); + + reductionOp = builder.createConvert(loc, type, andiOp); + break; + } + case ReductionIdentifier::OR: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + + mlir::Value oriOp = builder.create(loc, op1I1, op2I1); + + reductionOp = builder.createConvert(loc, type, oriOp); + break; + } + case ReductionIdentifier::EQV: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + + mlir::Value cmpiOp = builder.create( + loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); + + reductionOp = builder.createConvert(loc, type, cmpiOp); + break; + } + case ReductionIdentifier::NEQV: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + + mlir::Value cmpiOp = builder.create( + loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); + + reductionOp = builder.createConvert(loc, type, cmpiOp); + break; + } + default: + TODO(loc, "Reduction of some intrinsic operators is not supported"); + } + + return reductionOp; +} + +mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl( + fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, + const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) { + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::ModuleOp module = builder.getModule(); + + auto decl = + module.lookupSymbol(reductionOpName); + if (decl) + return decl; + + mlir::OpBuilder modBuilder(module.getBodyRegion()); + + decl = modBuilder.create(loc, reductionOpName, + type); + builder.createBlock(&decl.getInitializerRegion(), + decl.getInitializerRegion().end(), {type}, {loc}); + builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); + mlir::Value init = getReductionInitValue(loc, type, redId, builder); + builder.create(loc, init); + + builder.createBlock(&decl.getReductionRegion(), + decl.getReductionRegion().end(), {type, type}, + {loc, loc}); + + builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); + mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); + mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); + + mlir::Value reductionOp = + createScalarCombiner(builder, loc, redId, type, op1, op2); + builder.create(loc, reductionOp); + + return decl; +} + +void ReductionProcessor::addReductionDecl( + mlir::Location currentLocation, + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpReductionClause &reduction, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl + *reductionSymbols) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::omp::ReductionDeclareOp decl; + const auto &redOperator{ + std::get(reduction.t)}; + const auto &objectList{std::get(reduction.t)}; + if (const auto &redDefinedOp = + std::get_if(&redOperator.u)) { + const auto &intrinsicOp{ + std::get( + redDefinedOp->u)}; + ReductionIdentifier redId = getReductionType(intrinsicOp); + switch (redId) { + case ReductionIdentifier::ADD: + case ReductionIdentifier::MULTIPLY: + case ReductionIdentifier::AND: + case ReductionIdentifier::EQV: + case ReductionIdentifier::OR: + case ReductionIdentifier::NEQV: + break; + default: + TODO(currentLocation, + "Reduction of some intrinsic operators is not supported"); + break; + } + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap(ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast().getEleTy(); + reductionVars.push_back(symVal); + if (redType.isa()) + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId, + redType, currentLocation); + else if (redType.isIntOrIndexOrFloat()) { + decl = createReductionDecl(firOpBuilder, + getReductionName(intrinsicOp, redType), + redId, redType, currentLocation); + } else { + TODO(currentLocation, "Reduction of some types is not supported"); + } + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } + } else if (const auto *reductionIntrinsic = + std::get_if( + &redOperator.u)) { + if (ReductionProcessor::supportedIntrinsicProcReduction( + *reductionIntrinsic)) { + ReductionProcessor::ReductionIdentifier redId = + ReductionProcessor::getReductionType(*reductionIntrinsic); + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap(ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast().getEleTy(); + reductionVars.push_back(symVal); + assert(redType.isIntOrIndexOrFloat() && + "Unsupported reduction type"); + decl = createReductionDecl( + firOpBuilder, + getReductionName(getRealName(*reductionIntrinsic).ToString(), + redType), + redId, redType, currentLocation); + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } + } + } +} + +const Fortran::semantics::SourceName +ReductionProcessor::getRealName(const Fortran::parser::Name *name) { + return name->symbol->GetUltimate().name(); +} + +const Fortran::semantics::SourceName ReductionProcessor::getRealName( + const Fortran::parser::ProcedureDesignator &pd) { + const auto *name{Fortran::parser::Unwrap(pd)}; + assert(name && "Invalid Reduction Intrinsic."); + return getRealName(name); +} + +int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, + mlir::Location loc) { + switch (redId) { + case ReductionIdentifier::ADD: + case ReductionIdentifier::OR: + case ReductionIdentifier::NEQV: + return 0; + case ReductionIdentifier::MULTIPLY: + case ReductionIdentifier::AND: + case ReductionIdentifier::EQV: + return 1; + default: + TODO(loc, "Reduction of some intrinsic operators is not supported"); + } +} + +} // namespace omp +} // namespace lower +} // namespace Fortran diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h new file mode 100644 index 0000000000000..00770fe81d1ef --- /dev/null +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h @@ -0,0 +1,138 @@ +//===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H +#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H + +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Parser/parse-tree.h" +#include "flang/Semantics/symbol.h" +#include "flang/Semantics/type.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace omp { +class ReductionDeclareOp; +} // namespace omp +} // namespace mlir + +namespace Fortran { +namespace lower { +class AbstractConverter; +} // namespace lower +} // namespace Fortran + +namespace Fortran { +namespace lower { +namespace omp { + +class ReductionProcessor { +public: + // TODO: Move this enumeration to the OpenMP dialect + enum ReductionIdentifier { + ID, + USER_DEF_OP, + ADD, + SUBTRACT, + MULTIPLY, + AND, + OR, + EQV, + NEQV, + MAX, + MIN, + IAND, + IOR, + IEOR + }; + + static ReductionIdentifier + getReductionType(const Fortran::parser::ProcedureDesignator &pd); + + static ReductionIdentifier getReductionType( + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp); + + static bool supportedIntrinsicProcReduction( + const Fortran::parser::ProcedureDesignator &pd); + + static const Fortran::semantics::SourceName + getRealName(const Fortran::parser::Name *name); + + static const Fortran::semantics::SourceName + getRealName(const Fortran::parser::ProcedureDesignator &pd); + + static std::string getReductionName(llvm::StringRef name, mlir::Type ty); + + static std::string getReductionName( + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty); + + /// This function returns the identity value of the operator \p + /// reductionOpName. For example: + /// 0 + x = x, + /// 1 * x = x + static int getOperationIdentity(ReductionIdentifier redId, + mlir::Location loc); + + static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type, + ReductionIdentifier redId, + fir::FirOpBuilder &builder); + + template + static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, + mlir::Type type, mlir::Location loc, + mlir::Value op1, mlir::Value op2); + + static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder, + mlir::Location loc, + ReductionIdentifier redId, + mlir::Type type, mlir::Value op1, + mlir::Value op2); + + /// Creates an OpenMP reduction declaration and inserts it into the provided + /// symbol table. The declaration has a constant initializer with the neutral + /// value `initValue`, and the reduction combiner carried over from `reduce`. + /// TODO: Generalize this for non-integer types, add atomic region. + static mlir::omp::ReductionDeclareOp createReductionDecl( + fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, + const ReductionIdentifier redId, mlir::Type type, mlir::Location loc); + + /// Creates a reduction declaration and associates it with an OpenMP block + /// directive. + static void + addReductionDecl(mlir::Location currentLocation, + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpReductionClause &reduction, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl + *reductionSymbols = nullptr); +}; + +template +mlir::Value +ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, + mlir::Type type, mlir::Location loc, + mlir::Value op1, mlir::Value op2) { + assert(type.isIntOrIndexOrFloat() && + "only integer and float types are currently supported"); + if (type.isIntOrIndex()) + return builder.create(loc, op1, op2); + return builder.create(loc, op1, op2); +} + +} // namespace omp +} // namespace lower +} // namespace Fortran + +#endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp new file mode 100644 index 0000000000000..31b15257d1868 --- /dev/null +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -0,0 +1,99 @@ +//===-- Utils..cpp ----------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "Utils.h" + +#include +#include +#include +#include +#include +#include + +llvm::cl::opt treatIndexAsSection( + "openmp-treat-index-as-section", + llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."), + llvm::cl::init(true)); + +namespace Fortran { +namespace lower { +namespace omp { + +void genObjectList(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl &operands) { + auto addOperands = [&](Fortran::lower::SymbolRef sym) { + const mlir::Value variable = converter.getSymbolAddress(sym); + if (variable) { + operands.push_back(variable); + } else { + if (const auto *details = + sym->detailsIf()) { + operands.push_back(converter.getSymbolAddress(details->symbol())); + converter.copySymbolBinding(details->symbol(), sym); + } + } + }; + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + addOperands(*sym); + } +} + +void gatherFuncAndVarSyms( + const Fortran::parser::OmpObjectList &objList, + mlir::omp::DeclareTargetCaptureClause clause, + llvm::SmallVectorImpl &symbolAndClause) { + for (const Fortran::parser::OmpObject &ompObject : objList.v) { + Fortran::common::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::Designator &designator) { + if (const Fortran::parser::Name *name = + Fortran::semantics::getDesignatorNameIfDataRef( + designator)) { + symbolAndClause.emplace_back(clause, *name->symbol); + } + }, + [&](const Fortran::parser::Name &name) { + symbolAndClause.emplace_back(clause, *name.symbol); + }}, + ompObject.u); + } +} + +Fortran::semantics::Symbol * +getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { + Fortran::semantics::Symbol *sym = nullptr; + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::Designator &designator) { + if (auto *arrayEle = + Fortran::parser::Unwrap( + designator)) { + sym = GetFirstName(arrayEle->base).symbol; + } else if (auto *structComp = Fortran::parser::Unwrap< + Fortran::parser::StructureComponent>(designator)) { + sym = structComp->component.symbol; + } else if (const Fortran::parser::Name *name = + Fortran::semantics::getDesignatorNameIfDataRef( + designator)) { + sym = name->symbol; + } + }, + [&](const Fortran::parser::Name &name) { sym = name.symbol; }}, + ompObject.u); + return sym; +} + +} // namespace omp +} // namespace lower +} // namespace Fortran diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h new file mode 100644 index 0000000000000..c346f891f0797 --- /dev/null +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -0,0 +1,68 @@ +//===-- Lower/OpenMP/Utils.h ------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_LOWER_OPENMPUTILS_H +#define FORTRAN_LOWER_OPENMPUTILS_H + +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" +#include "llvm/Support/CommandLine.h" + +extern llvm::cl::opt treatIndexAsSection; + +namespace fir { +class FirOpBuilder; +} // namespace fir + +namespace Fortran { + +namespace semantics { +class Symbol; +} // namespace semantics + +namespace parser { +struct OmpObject; +struct OmpObjectList; +} // namespace parser + +namespace lower { + +class AbstractConverter; + +namespace omp { + +using DeclareTargetCapturePair = + std::pair; + +mlir::omp::MapInfoOp +createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name, + mlir::SmallVector bounds, + mlir::SmallVector members, uint64_t mapType, + mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, + bool isVal = false); + +void gatherFuncAndVarSyms( + const Fortran::parser::OmpObjectList &objList, + mlir::omp::DeclareTargetCaptureClause clause, + llvm::SmallVectorImpl &symbolAndClause); + +Fortran::semantics::Symbol * +getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject); + +void genObjectList(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl &operands); + +} // namespace omp +} // namespace lower +} // namespace Fortran + +#endif // FORTRAN_LOWER_OPENMPUTILS_H