diff --git a/flang/include/flang/Lower/Support/ReductionProcessor.h b/flang/include/flang/Lower/Support/ReductionProcessor.h index 66f26b3b55630..bd0447360f089 100644 --- a/flang/include/flang/Lower/Support/ReductionProcessor.h +++ b/flang/include/flang/Lower/Support/ReductionProcessor.h @@ -40,6 +40,13 @@ namespace omp { class ReductionProcessor { public: + using GenInitValueCBTy = + std::function; + using GenCombinerCBTy = std::function; + // TODO: Move this enumeration to the OpenMP dialect enum ReductionIdentifier { ID, @@ -58,6 +65,9 @@ class ReductionProcessor { IEOR }; + static bool doReductionByRef(mlir::Type reductionType); + static bool doReductionByRef(mlir::Value reductionVar); + static ReductionIdentifier getReductionType(const omp::clause::ProcedureDesignator &pd); @@ -109,6 +119,14 @@ class ReductionProcessor { ReductionIdentifier redId, mlir::Type type, mlir::Value op1, mlir::Value op2); + /// Creates an OpenMP reduction declaration and inserts it into the provided + /// symbol table. The init and combiner regions are generated by the callback + /// functions genCombinerCB and genInitValueCB. + template + static DeclareRedType createDeclareReductionHelper( + AbstractConverter &converter, llvm::StringRef reductionOpName, + mlir::Type type, mlir::Location loc, bool isByRef, + GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB); /// Creates an OpenMP reduction declaration and inserts it into the provided /// symbol table. The declaration has a constant initializer with the neutral diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index e018a2d937435..fadfb29b07a28 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -13,6 +13,7 @@ #include "ClauseProcessor.h" #include "Utils.h" +#include "flang/Lower/ConvertCall.h" #include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/OpenMP/Clauses.h" #include "flang/Lower/PFTBuilder.h" @@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive( return false; } +bool ClauseProcessor::processInitializer( + lower::SymMap &symMap, const parser::OmpClause::Initializer &inp, + ReductionProcessor::GenInitValueCBTy &genInitValueCB) const { + if (auto *clause = findUniqueClause()) { + genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value ompOrig) { + lower::SymMapScope scope(symMap); + const parser::OmpInitializerExpression &iexpr = inp.v.v; + const parser::OmpStylizedInstance &styleInstance = iexpr.v.front(); + const std::list &declList = + std::get>(styleInstance.t); + mlir::Value ompPrivVar; + for (const parser::OmpStylizedDeclaration &decl : declList) { + auto &name = std::get(decl.var.t); + assert(name.symbol && "Name does not have a symbol"); + mlir::Value addr = builder.createTemporary(loc, ompOrig.getType()); + fir::StoreOp::create(builder, loc, ompOrig, addr); + fir::FortranVariableFlagsEnum extraFlags = {}; + fir::FortranVariableFlagsAttr attributes = + Fortran::lower::translateSymbolAttributes(builder.getContext(), + *name.symbol, extraFlags); + auto declareOp = hlfir::DeclareOp::create( + builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr, + 0, attributes); + if (name.ToString() == "omp_priv") + ompPrivVar = declareOp.getResult(0); + symMap.addVariableDefinition(*name.symbol, declareOp); + } + // Lower the expression/function call + lower::StatementContext stmtCtx; + mlir::Value result = common::visit( + common::visitors{ + [&](const evaluate::ProcedureRef &procRef) -> mlir::Value { + convertCallToHLFIR(loc, converter, procRef, std::nullopt, + symMap, stmtCtx); + auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar); + return privVal; + }, + [&](const auto &expr) -> mlir::Value { + mlir::Value exprResult = fir::getBase(convertExprToValue( + loc, converter, clause->v, symMap, stmtCtx)); + // Conversion can either give a value or a refrence to a value, + // we need to return the reduction type, so an optional load may + // be generated. + if (auto refType = llvm::dyn_cast( + exprResult.getType())) + if (ompPrivVar.getType() == refType) + exprResult = fir::LoadOp::create(builder, loc, exprResult); + return exprResult; + }}, + clause->v.u); + stmtCtx.finalizeAndPop(); + return result; + }; + return true; + } + return false; +} + bool ClauseProcessor::processMergeable( mlir::omp::MergeableClauseOps &result) const { return markClauseOccurrence(result.mergeable); diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index d524b4ddc8ac4..529b871330052 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -18,6 +18,7 @@ #include "flang/Lower/Bridge.h" #include "flang/Lower/DirectivesCommon.h" #include "flang/Lower/OpenMP/Clauses.h" +#include "flang/Lower/Support/ReductionProcessor.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Parser/dump-parse-tree.h" #include "flang/Parser/parse-tree.h" @@ -88,6 +89,9 @@ class ClauseProcessor { bool processHint(mlir::omp::HintClauseOps &result) const; bool processInclusive(mlir::Location currentLocation, mlir::omp::InclusiveClauseOps &result) const; + bool processInitializer( + lower::SymMap &symMap, const parser::OmpClause::Initializer &inp, + ReductionProcessor::GenInitValueCBTy &genInitValueCB) const; bool processMergeable(mlir::omp::MergeableClauseOps &result) const; bool processNogroup(mlir::omp::NogroupClauseOps &result) const; bool processNowait(mlir::omp::NowaitClauseOps &result) const; diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index b1a3c3d3c5439..dc49a8118b0a5 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp, Initializer make(const parser::OmpClause::Initializer &inp, semantics::SemanticsContext &semaCtx) { - llvm_unreachable("Empty: initializer"); + const parser::OmpInitializerExpression &iexpr = inp.v.v; + const parser::OmpStylizedInstance &styleInstance = iexpr.v.front(); + const parser::OmpStylizedInstance::Instance &instance = + std::get(styleInstance.t); + if (const auto *as = std::get_if(&instance.u)) { + auto &expr = std::get(as->t); + return Initializer{makeExpr(expr, semaCtx)}; + } else if (const auto *call = std::get_if(&instance.u)) { + if (call->typedCall) { + const auto &procRef = *call->typedCall; + semantics::SomeExpr evalProcRef{procRef}; + return Initializer{evalProcRef}; + } + } + + llvm_unreachable("Unexpected initializer"); } InReduction make(const parser::OmpClause::InReduction &inp, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index f822fe3c8dd71..bf1fd541f4273 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -18,12 +18,15 @@ #include "Decomposer.h" #include "Utils.h" #include "flang/Common/idioms.h" +#include "flang/Evaluate/type.h" #include "flang/Lower/Bridge.h" #include "flang/Lower/ConvertExpr.h" +#include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/ConvertVariable.h" #include "flang/Lower/DirectivesCommon.h" #include "flang/Lower/OpenMP/Clauses.h" #include "flang/Lower/StatementContext.h" +#include "flang/Lower/Support/ReductionProcessor.h" #include "flang/Lower/SymbolMap.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" @@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // TODO: Add private syms and vars. args.reduction.syms = reductionSyms; args.reduction.vars = clauseOps.reductionVars; - return genOpWithBody( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_teams) @@ -3563,12 +3565,137 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective"); } +static ReductionProcessor::GenCombinerCBTy +processReductionCombiner(lower::AbstractConverter &converter, + lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + const parser::OmpReductionSpecifier &specifier) { + ReductionProcessor::GenCombinerCBTy genCombinerCB; + const auto &combinerExpression = + std::get>(specifier.t) + .value(); + const parser::OmpStylizedInstance &combinerInstance = + combinerExpression.v.front(); + const parser::OmpStylizedInstance::Instance &instance = + std::get(combinerInstance.t); + + const auto *as = std::get_if(&instance.u); + if (!as) { + TODO(converter.getCurrentLocation(), + "A combiner that is a subroutine call is not yet supported"); + } + auto &expr = std::get(as->t); + genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value lhs, mlir::Value rhs, + bool isByRef) { + const auto &evalExpr = makeExpr(expr, semaCtx); + lower::SymMapScope scope(symTable); + const std::list &declList = + std::get>(combinerInstance.t); + for (const parser::OmpStylizedDeclaration &decl : declList) { + auto &name = std::get(decl.var.t); + mlir::Value addr = lhs; + mlir::Type type = lhs.getType(); + bool isRhs = name.ToString() == std::string("omp_in"); + if (isRhs) { + addr = rhs; + type = rhs.getType(); + } + + assert(name.symbol && "Reduction object name does not have a symbol"); + if (!fir::conformsWithPassByRef(type)) { + addr = builder.createTemporary(loc, type); + fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr); + } + fir::FortranVariableFlagsEnum extraFlags = {}; + fir::FortranVariableFlagsAttr attributes = + Fortran::lower::translateSymbolAttributes(builder.getContext(), + *name.symbol, extraFlags); + auto declareOp = + hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr, + {}, nullptr, nullptr, 0, attributes); + symTable.addVariableDefinition(*name.symbol, declareOp); + } + + lower::StatementContext stmtCtx; + mlir::Value result = fir::getBase( + convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx)); + if (auto refType = llvm::dyn_cast(result.getType())) + if (lhs.getType() == refType.getElementType()) + result = fir::LoadOp::create(builder, loc, result); + stmtCtx.finalizeAndPop(); + if (isByRef) { + fir::StoreOp::create(builder, loc, result, lhs); + mlir::omp::YieldOp::create(builder, loc, lhs); + } else { + mlir::omp::YieldOp::create(builder, loc, result); + } + }; + return genCombinerCB; +} + +// Getting the type from a symbol compared to a DeclSpec is simpler since we do +// not need to consider derived vs intrinsic types. Semantics is guaranteed to +// generate these symbols. +static mlir::Type +getReductionType(lower::AbstractConverter &converter, + const parser::OmpReductionSpecifier &specifier) { + const auto &combinerExpression = + std::get>(specifier.t) + .value(); + const parser::OmpStylizedInstance &combinerInstance = + combinerExpression.v.front(); + const std::list &declList = + std::get>(combinerInstance.t); + const parser::OmpStylizedDeclaration &decl = declList.front(); + const auto &name = std::get(decl.var.t); + const auto &symbol = semantics::SymbolRef(*name.symbol); + mlir::Type reductionType = converter.genType(symbol); + return reductionType; +} + static void genOMP( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) { - if (!semaCtx.langOptions().OpenMPSimd) - TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct"); + if (semaCtx.langOptions().OpenMPSimd) + return; + + const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()}; + const parser::OmpArgument &arg{args.v.front()}; + const auto &specifier = std::get(arg.u); + + if (std::get(specifier.t).v.size() > 1) + TODO(converter.getCurrentLocation(), + "multiple types in declare reduction is not yet supported"); + + mlir::Type reductionType = getReductionType(converter, specifier); + ReductionProcessor::GenCombinerCBTy genCombinerCB = + processReductionCombiner(converter, symTable, semaCtx, specifier); + const parser::OmpClauseList &initializer = + declareReductionConstruct.v.Clauses(); + if (initializer.v.size() > 0) { + List clauses = makeClauses(initializer, semaCtx); + ReductionProcessor::GenInitValueCBTy genInitValueCB; + ClauseProcessor cp(converter, semaCtx, clauses); + const parser::OmpClause::Initializer &iclause{ + std::get(initializer.v.front().u)}; + cp.processInitializer(symTable, iclause, genInitValueCB); + const auto &identifier = + std::get(specifier.t); + const auto &designator = + std::get(identifier.u); + const auto &reductionName = std::get(designator.u); + bool isByRef = ReductionProcessor::doReductionByRef(reductionType); + ReductionProcessor::createDeclareReductionHelper< + mlir::omp::DeclareReductionOp>( + converter, reductionName.ToString(), reductionType, + converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB); + } else { + TODO(converter.getCurrentLocation(), + "declare reduction without an initializer clause is not yet " + "supported"); + } } static void diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp index 605a5b6b20b94..12c8b8f33c414 100644 --- a/flang/lib/Lower/Support/ReductionProcessor.cpp +++ b/flang/lib/Lower/Support/ReductionProcessor.cpp @@ -462,7 +462,7 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, bool isByRef) { ty = fir::unwrapRefType(ty); - if (fir::isa_trivial(ty)) { + if (fir::isa_trivial(ty) || fir::isa_derived(ty)) { mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); @@ -501,7 +501,7 @@ static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) { template static void createReductionAllocAndInitRegions( AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl, - const ReductionProcessor::ReductionIdentifier redId, mlir::Type type, + ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type, bool isByRef) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); auto yield = [&](mlir::Value ret) { genYield(builder, loc, ret); }; @@ -523,9 +523,8 @@ static void createReductionAllocAndInitRegions( mlir::Type ty = fir::unwrapRefType(type); builder.setInsertionPointToEnd(initBlock); - mlir::Value initValue = ReductionProcessor::getReductionInitValue( - loc, unwrapSeqOrBoxedType(ty), redId, builder); - + mlir::Value initValue = + genInitValueCB(builder, loc, ty, initBlock->getArgument(0)); if (isByRef) { populateByRefInitAndCleanupRegions( converter, loc, type, initValue, initBlock, @@ -536,7 +535,7 @@ static void createReductionAllocAndInitRegions( /*isDoConcurrent*/ std::is_same_v); } - if (fir::isa_trivial(ty)) { + if (fir::isa_trivial(ty) || fir::isa_derived(ty)) { if (isByRef) { // alloc region builder.setInsertionPointToEnd(allocBlock); @@ -556,18 +555,18 @@ static void createReductionAllocAndInitRegions( yield(boxAlloca); } -template -OpType ReductionProcessor::createDeclareReduction( +template +DeclareRedType ReductionProcessor::createDeclareReductionHelper( AbstractConverter &converter, llvm::StringRef reductionOpName, - const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, - bool isByRef) { + mlir::Type type, mlir::Location loc, bool isByRef, + GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::OpBuilder::InsertionGuard guard(builder); mlir::ModuleOp module = builder.getModule(); assert(!reductionOpName.empty()); - auto decl = module.lookupSymbol(reductionOpName); + auto decl = module.lookupSymbol(reductionOpName); if (decl) return decl; @@ -576,23 +575,54 @@ OpType ReductionProcessor::createDeclareReduction( if (!isByRef) type = valTy; - decl = OpType::create(modBuilder, loc, reductionOpName, type); - createReductionAllocAndInitRegions(converter, loc, decl, redId, type, + decl = DeclareRedType::create(modBuilder, loc, reductionOpName, type); + createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB, type, isByRef); - builder.createBlock(&decl.getReductionRegion(), decl.getReductionRegion().end(), {type, type}, {loc, loc}); - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - genCombiner(builder, loc, redId, type, op1, op2, isByRef); - + genCombinerCB(builder, loc, type, op1, op2, isByRef); return decl; } -static bool doReductionByRef(mlir::Value reductionVar) { +template +OpType ReductionProcessor::createDeclareReduction( + AbstractConverter &converter, llvm::StringRef reductionOpName, + const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, + bool isByRef) { + auto genInitValueCB = [&](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value val) { + mlir::Type ty = fir::unwrapRefType(type); + mlir::Value initValue = ReductionProcessor::getReductionInitValue( + loc, unwrapSeqOrBoxedType(ty), redId, builder); + return initValue; + }; + auto genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value op1, mlir::Value op2, + bool isByRef) { + genCombiner(builder, loc, redId, type, op1, op2, isByRef); + }; + + return createDeclareReductionHelper(converter, reductionOpName, type, + loc, isByRef, genCombinerCB, + genInitValueCB); +} + +bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) { + if (forceByrefReduction) + return true; + + if (!fir::isa_trivial(fir::unwrapRefType(reductionType)) && + !fir::isa_derived(fir::unwrapRefType(reductionType))) + return true; + + return false; +} + +bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) { if (forceByrefReduction) return true; @@ -600,10 +630,7 @@ static bool doReductionByRef(mlir::Value reductionVar) { mlir::dyn_cast(reductionVar.getDefiningOp())) reductionVar = declare.getMemref(); - if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) - return true; - - return false; + return doReductionByRef(reductionVar.getType()); } template @@ -614,6 +641,8 @@ bool ReductionProcessor::processReductionArguments( llvm::SmallVectorImpl &reduceVarByRef, llvm::SmallVectorImpl &reductionDeclSymbols, const llvm::SmallVectorImpl &reductionSymbols) { + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + if constexpr (std::is_same_v) { // For OpenMP reduction clauses, check if the reduction operator is @@ -627,7 +656,13 @@ bool ReductionProcessor::processReductionArguments( std::get_if(&redOperator.u)) { if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) { - return false; + // If not an intrinsic is has to be a custom reduction op, and should + // be available in the module. + semantics::Symbol *sym = reductionIntrinsic->v.sym(); + mlir::ModuleOp module = builder.getModule(); + auto decl = module.lookupSymbol(getRealName(sym).ToString()); + if (!decl) + return false; } } else { return false; @@ -637,7 +672,6 @@ bool ReductionProcessor::processReductionArguments( // Reduction variable processing common to both intrinsic operators and // procedure designators - fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::OpBuilder::InsertPoint dcIP; constexpr bool isDoConcurrent = std::is_same_v; @@ -741,7 +775,13 @@ bool ReductionProcessor::processReductionArguments( &redOperator.u)) { if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) { - TODO(currentLocation, "Unsupported intrinsic proc reduction"); + // Custom reductions we can just add to the symbols without + // generating the declare reduction op. + semantics::Symbol *sym = reductionIntrinsic->v.sym(); + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + builder.getContext(), sym->name().ToString())); + ++idx; + continue; } redId = getReductionType(*reductionIntrinsic); reductionName = diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp index 0b0e6bd9ecf34..5fa77fb2080df 100644 --- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp +++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp @@ -21,6 +21,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/TypeSwitch.h" namespace flangomp { #define GEN_PASS_DEF_MARKDECLARETARGETPASS @@ -31,9 +32,93 @@ namespace { class MarkDeclareTargetPass : public flangomp::impl::MarkDeclareTargetPassBase { - void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy, - mlir::omp::DeclareTargetCaptureClause parentCapClause, - bool parentAutomap, mlir::Operation *currOp, + struct ParentInfo { + mlir::omp::DeclareTargetDeviceType devTy; + mlir::omp::DeclareTargetCaptureClause capClause; + bool automap; + }; + + void processSymbolRef(mlir::SymbolRefAttr symRef, ParentInfo parentInfo, + llvm::SmallPtrSet visited) { + if (auto currFOp = + getOperation().lookupSymbol(symRef)) { + auto current = llvm::dyn_cast( + currFOp.getOperation()); + + if (current.isDeclareTarget()) { + auto currentDt = current.getDeclareTargetDeviceType(); + + // Found the same function twice, with different device_types, + // mark as Any as it belongs to both + if (currentDt != parentInfo.devTy && + currentDt != mlir::omp::DeclareTargetDeviceType::any) { + current.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any, + current.getDeclareTargetCaptureClause(), + current.getDeclareTargetAutomap()); + } + } else { + current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause, + parentInfo.automap); + } + + markNestedFuncs(parentInfo, currFOp, visited); + } + } + + void processReductionRefs(std::optional symRefs, + ParentInfo parentInfo, + llvm::SmallPtrSet visited) { + if (!symRefs) + return; + + for (auto symRef : symRefs->getAsRange()) { + if (auto declareReductionOp = + getOperation().lookupSymbol( + symRef)) { + markNestedFuncs(parentInfo, declareReductionOp, visited); + } + } + } + + void + processReductionClauses(mlir::Operation *op, ParentInfo parentInfo, + llvm::SmallPtrSet visited) { + llvm::TypeSwitch(*op) + .Case([&](mlir::omp::LoopOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::ParallelOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::SectionsOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::SimdOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TargetOp op) { + processReductionRefs(op.getInReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TaskgroupOp op) { + processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TaskloopOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + processReductionRefs(op.getInReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TaskOp op) { + processReductionRefs(op.getInReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::TeamsOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Case([&](mlir::omp::WsloopOp op) { + processReductionRefs(op.getReductionSyms(), parentInfo, visited); + }) + .Default([](mlir::Operation &) {}); + } + + void markNestedFuncs(ParentInfo parentInfo, mlir::Operation *currOp, llvm::SmallPtrSet visited) { if (visited.contains(currOp)) return; @@ -43,33 +128,10 @@ class MarkDeclareTargetPass if (auto callOp = llvm::dyn_cast(op)) { if (auto symRef = llvm::dyn_cast_if_present( callOp.getCallableForCallee())) { - if (auto currFOp = - getOperation().lookupSymbol(symRef)) { - auto current = llvm::dyn_cast( - currFOp.getOperation()); - - if (current.isDeclareTarget()) { - auto currentDt = current.getDeclareTargetDeviceType(); - - // Found the same function twice, with different device_types, - // mark as Any as it belongs to both - if (currentDt != parentDevTy && - currentDt != mlir::omp::DeclareTargetDeviceType::any) { - current.setDeclareTarget( - mlir::omp::DeclareTargetDeviceType::any, - current.getDeclareTargetCaptureClause(), - current.getDeclareTargetAutomap()); - } - } else { - current.setDeclareTarget(parentDevTy, parentCapClause, - parentAutomap); - } - - markNestedFuncs(parentDevTy, parentCapClause, parentAutomap, - currFOp, visited); - } + processSymbolRef(symRef, parentInfo, visited); } } + processReductionClauses(op, parentInfo, visited); }); } @@ -82,10 +144,10 @@ class MarkDeclareTargetPass functionOp.getOperation()); if (declareTargetOp.isDeclareTarget()) { llvm::SmallPtrSet visited; - markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(), - declareTargetOp.getDeclareTargetCaptureClause(), - declareTargetOp.getDeclareTargetAutomap(), functionOp, - visited); + ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(), + declareTargetOp.getDeclareTargetCaptureClause(), + declareTargetOp.getDeclareTargetAutomap()}; + markNestedFuncs(parentInfo, functionOp, visited); } } @@ -96,12 +158,13 @@ class MarkDeclareTargetPass // the contents of the device clause getOperation()->walk([&](mlir::omp::TargetOp tarOp) { llvm::SmallPtrSet visited; - markNestedFuncs( - /*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost, - /*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to, - /*parentAutomap=*/false, tarOp, visited); + ParentInfo parentInfo = { + /*devTy=*/mlir::omp::DeclareTargetDeviceType::nohost, + /*capClause=*/mlir::omp::DeclareTargetCaptureClause::to, + /*automap=*/false, + }; + markNestedFuncs(parentInfo, tarOp, visited); }); } }; - } // namespace diff --git a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90 b/flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90 deleted file mode 100644 index 30630465490b2..0000000000000 --- a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90 +++ /dev/null @@ -1,28 +0,0 @@ -! This test checks lowering of OpenMP declare reduction Directive, with initialization -! via a subroutine. This functionality is currently not implemented. - -! RUN: not flang -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s - -!CHECK: not yet implemented: OpenMPDeclareReductionConstruct -subroutine initme(x,n) - integer x,n - x=n -end subroutine initme - -function func(x, n, init) - integer func - integer x(n) - integer res - interface - subroutine initme(x,n) - integer x,n - end subroutine initme - end interface -!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,0)) - res=init -!$omp simd reduction(red_add:res) - do i=1,n - res=res+x(i) - enddo - func=res -end function func diff --git a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 b/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 deleted file mode 100644 index db50c9ac8ee9d..0000000000000 --- a/flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 +++ /dev/null @@ -1,10 +0,0 @@ -! This test checks lowering of OpenMP declare reduction Directive. - -! RUN: not flang -fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s - -subroutine declare_red() - integer :: my_var - !CHECK: not yet implemented: OpenMPDeclareReductionConstruct - !$omp declare reduction (my_red : integer : omp_out = omp_in) initializer (omp_priv = 0) - my_var = 0 -end subroutine declare_red diff --git a/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90 b/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90 new file mode 100644 index 0000000000000..66697ef6bbe70 --- /dev/null +++ b/flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90 @@ -0,0 +1,37 @@ +!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s +!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 -fopenmp-is-device %s -o - | FileCheck %s + +program main + use, intrinsic :: iso_c_binding + implicit none + interface + subroutine myinit(priv, orig) bind(c,name="myinit") + use, intrinsic :: iso_c_binding + implicit none + integer::priv, orig + end subroutine myinit + + function mycombine(lhs, rhs) bind(c,name="mycombine") + use, intrinsic :: iso_c_binding + implicit none + integer::lhs, rhs, mycombine + end function mycombine + end interface + !$omp declare reduction(myreduction:integer:omp_out = mycombine(omp_out, omp_in)) initializer(myinit(omp_priv, omp_orig)) + + integer :: i, s, a(10) + !$omp target + s = 0 + !$omp do reduction(myreduction:s) + do i = 1, 10 + s = mycombine(s, a(i)) + enddo + !$omp end do + !$omp end target + end program main + +!CHECK: func.func {{.*}} @myinit(!fir.ref, !fir.ref) +!CHECK-SAME: {{.*}}, omp.declare_target = #omp.declaretarget{{.*}} +!CHECK-LABEL: func.func {{.*}} @mycombine(!fir.ref, !fir.ref) +!CHECK-SAME: {{.*}}, omp.declare_target = #omp.declaretarget{{.*}} + diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 new file mode 100644 index 0000000000000..36bb131e677a3 --- /dev/null +++ b/flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 @@ -0,0 +1,112 @@ +! This test checks lowering of OpenMP declare reduction Directive, with initialization +! via a subroutine. This functionality is currently not implemented. + +!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s +module maxtype_mod + implicit none + + type maxtype + integer::sumval + integer::maxval + end type maxtype + +contains + + subroutine initme(x,n) + type(maxtype) :: x,n + x%sumval=0 + x%maxval=0 + end subroutine initme + + function mycombine(lhs, rhs) + type(maxtype) :: lhs, rhs + type(maxtype) :: mycombine + mycombine%sumval = lhs%sumval + rhs%sumval + mycombine%maxval = max(lhs%maxval, rhs%maxval) + end function mycombine + + function func(x, n, init) + type(maxtype) :: func + integer :: n, i + type(maxtype) :: x(n) + type(maxtype) :: init + type(maxtype) :: res +!$omp declare reduction(red_add_max:maxtype:omp_out=mycombine(omp_out,omp_in)) initializer(initme(omp_priv,omp_orig)) + res=init +!$omp simd reduction(red_add_max:res) + do i=1,n + res=mycombine(res,x(i)) + enddo + func=res + end function func + +end module maxtype_mod +!CHECK: omp.declare_reduction @red_add_max : [[MAXTYPE:.*]] init { +!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: [[MAXTYPE]]): +!CHECK: %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]] +!CHECK: %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]] +!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]> +!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<[[MAXTYPE]]> +!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: fir.call @_QMmaxtype_modPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> () +!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref<[[MAXTYPE]]> +!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : [[MAXTYPE]]) +!CHECK: } combiner { +!CHECK: ^bb0(%[[LHS_ARG:.*]]: [[MAXTYPE]], %[[RHS_ARG:.*]]: [[MAXTYPE]]): +!CHECK: %[[RESULT:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = ".result"} +!CHECK: %[[OMP_OUT:.*]] = fir.alloca [[MAXTYPE]] +!CHECK: %[[OMP_IN:.*]] = fir.alloca [[MAXTYPE]] +!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<[[MAXTYPE]]> +!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<[[MAXTYPE]]> +!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: %[[COMBINE_RESULT:.*]] = fir.call @_QMmaxtype_modPmycombine(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath : (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) -> [[MAXTYPE]] +!CHECK: fir.save_result %[[COMBINE_RESULT]] to %[[RESULT]] : [[MAXTYPE]], !fir.ref<[[MAXTYPE]]> +!CHECK: %[[TMPRESULT:.*]]:2 = hlfir.declare %[[RESULT]] {uniq_name = ".tmp.func_result"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: %false = arith.constant false +!CHECK: %[[EXPRRESULT:.*]] = hlfir.as_expr %[[TMPRESULT]]#0 move %false : (!fir.ref<[[MAXTYPE]]>, i1) -> !hlfir.expr<[[MAXTYPE]]> +!CHECK: %[[ASSOCIATE:.*]]:3 = hlfir.associate %[[EXPRRESULT]] {adapt.valuebyref} : (!hlfir.expr<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>, i1) +!CHECK: %[[RESULT_VAL:.*]] = fir.load %[[ASSOCIATE]]#0 : !fir.ref<[[MAXTYPE]]> +!CHECK: hlfir.end_associate %[[ASSOCIATE]]#1, %[[ASSOCIATE]]#2 : !fir.ref<[[MAXTYPE]]>, i1 +!CHECK: omp.yield(%[[RESULT_VAL]] : [[MAXTYPE]]) +!CHECK: } + +!CHECK: func.func @_QMmaxtype_modPinitme(%[[X_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "x"}, %[[N_ARG:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "n"}) { +!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope +!CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %[[N_ARG]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFinitmeEn"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_ARG]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFinitmeEx"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: %[[ZERO_0:.*]] = arith.constant 0 : i32 +!CHECK: %[[X_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref +!CHECK: hlfir.assign %[[ZERO_0]] to %[[X_DESIGNATE_SUMVAL]] : i32, !fir.ref +!CHECK: %[[ZERO_1:.*]] = arith.constant 0 : i32 +!CHECK: %[[X_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[X_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref +!CHECK: hlfir.assign %[[ZERO_1]] to %[[X_DESIGNATE_MAXVAL]] : i32, !fir.ref +!CHECK: return +!CHECK: } + + +!CHECK: func.func @_QMmaxtype_modPmycombine(%[[LHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "lhs"}, %[[RHS:.*]]: !fir.ref<[[MAXTYPE]]> {fir.bindc_name = "rhs"}) -> [[MAXTYPE]] { +!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope +!CHECK: %[[LHS_DECL:.*]]:2 = hlfir.declare %[[LHS]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QMmaxtype_modFmycombineElhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: %[[RESULT_ALLOC:.*]] = fir.alloca [[MAXTYPE]] {bindc_name = "mycombine", uniq_name = "_QMmaxtype_modFmycombineEmycombine"} +!CHECK: %[[RESULT_DECL:.*]]:2 = hlfir.declare %[[RESULT_ALLOC]] {uniq_name = "_QMmaxtype_modFmycombineEmycombine"} : (!fir.ref<[[MAXTYPE]]>) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: %[[RHS_DECL:.*]]:2 = hlfir.declare %[[RHS]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QMmaxtype_modFmycombineErhs"} : (!fir.ref<[[MAXTYPE]]>, !fir.dscope) -> (!fir.ref<[[MAXTYPE]]>, !fir.ref<[[MAXTYPE]]>) +!CHECK: %[[LHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref +!CHECK: %[[LHS_SUMVAL:.*]] = fir.load %[[LHS_DESIGNATE_SUMVAL]] : !fir.ref +!CHECK: %[[RHS_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref +!CHECK: %[[RHS_SUMVAL:.*]] = fir.load %[[RHS_DESIGNATE_SUMVAL]] : !fir.ref +!CHECK: %[[SUM:.*]] = arith.addi %[[LHS_SUMVAL]], %[[RHS_SUMVAL]] : i32 +!CHECK: %[[RESULT_DESIGNATE_SUMVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"sumval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref +!CHECK: hlfir.assign %[[SUM]] to %[[RESULT_DESIGNATE_SUMVAL]] : i32, !fir.ref +!CHECK: %[[LHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[LHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref +!CHECK: %[[LHS_MAXVAL:.*]] = fir.load %[[LHS_DESIGNATE_MAXVAL]] : !fir.ref +!CHECK: %[[RHS_DESIGNATE_MAXVAL:.*]] = hlfir.designate %[[RHS_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref +!CHECK: %[[RHS_MAXVAL:.*]] = fir.load %[[RHS_DESIGNATE_MAXVAL]] : !fir.ref +!CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32 +!CHECK: %[[MAX_VAL:.*]] = arith.select %[[CMP]], %[[LHS_MAXVAL]], %[[RHS_MAXVAL]] : i32 +!CHECK: %[[RESULT_DESIGNAGE_MAXVAL:.*]] = hlfir.designate %[[RESULT_DECL]]#0{"maxval"} : (!fir.ref<[[MAXTYPE]]>) -> !fir.ref +!CHECK: hlfir.assign %[[MAX_VAL]] to %[[RESULT_DESIGNAGE_MAXVAL]] : i32, !fir.ref +!CHECK: %[[RESULT:.*]] = fir.load %[[RESULT_DECL]]#0 : !fir.ref<[[MAXTYPE]]> +!CHECK: return %[[RESULT]] : [[MAXTYPE]] +!CHECK: } diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90 new file mode 100644 index 0000000000000..4aacc7cb2efba --- /dev/null +++ b/flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90 @@ -0,0 +1,59 @@ +! This test checks lowering of OpenMP declare reduction Directive, with initialization +! via a subroutine. This functionality is currently not implemented. + +!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s + +subroutine initme(x,n) + integer x,n + x=0 +end subroutine initme + +function func(x, n, init) + integer func + integer x(n) + integer res + interface + subroutine initme(x,n) + integer x,n + end subroutine initme + end interface +!CHECK: omp.declare_reduction @red_add : i32 init { +!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32): +!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32 +!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32 +!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref +!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref +!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: fir.call @_QPinitme(%[[OMP_PRIV_DECL]]#0, %[[OMP_ORIG_DECL]]#0) fastmath : (!fir.ref, !fir.ref) -> () +!CHECK: %[[OMP_PRIV_VAL:.*]] = fir.load %[[OMP_PRIV_DECL]]#0 : !fir.ref +!CHECK: omp.yield(%[[OMP_PRIV_VAL]] : i32) +!CHECK: } combiner { +!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32): +!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32 +!CHECK: %[[OMP_IN:.*]] = fir.alloca i32 +!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref +!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref +!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref +!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref +!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32 +!CHECK: omp.yield(%[[SUM]] : i32) +!CHECK: } +!CHECK: func.func @_QPinitme(%[[X:.*]]: !fir.ref {fir.bindc_name = "x"}, %[[N:.*]]: !fir.ref {fir.bindc_name = "n"}) { +!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope +!CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %[[N]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QFinitmeEn"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) +!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] dummy_scope %[[OMP_OUT]] arg 1 {uniq_name = "_QFinitmeEx"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) +!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32 +!CHECK: hlfir.assign %[[CONST_0]] to %[[X_DECL]]#0 : i32, !fir.ref +!CHECK: return +!CHECK: } +!$omp declare reduction(red_add:integer(4):omp_out=omp_out+omp_in) initializer(initme(omp_priv,omp_orig)) + res=init +!$omp simd reduction(red_add:res) + do i=1,n + res=res+x(i) + enddo + func=res +end function func diff --git a/flang/test/Lower/OpenMP/omp-declare-reduction.f90 b/flang/test/Lower/OpenMP/omp-declare-reduction.f90 new file mode 100644 index 0000000000000..a41f6b214b9d8 --- /dev/null +++ b/flang/test/Lower/OpenMP/omp-declare-reduction.f90 @@ -0,0 +1,33 @@ +! This test checks lowering of OpenMP declare reduction Directive. + +!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s + +subroutine declare_red() + integer :: my_var +!CHECK: omp.declare_reduction @my_red : i32 init { +!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32): +!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32 +!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32 +!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref +!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref +!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32 +!CHECK: omp.yield(%[[CONST_0]] : i32) +!CHECK: } combiner { +!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32): +!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32 +!CHECK: %[[OMP_IN:.*]] = fir.alloca i32 +!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref +!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref +!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref +!CHECK: %[[OMP_IN_VAL:.*]] = fir.load %[[OMP_IN_DECL]]#0 : !fir.ref +!CHECK: %[[SUM:.*]] = arith.addi %[[OMP_OUT_VAL]], %[[OMP_IN_VAL]] : i32 +!CHECK: omp.yield(%[[SUM]] : i32) +!CHECK: } + + !$omp declare reduction (my_red : integer : omp_out = omp_out + omp_in) initializer (omp_priv = 0) + my_var = 0 +end subroutine declare_red diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8edec990eaaba..0dcf0cf17f4d7 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1170,6 +1170,7 @@ allocReductionVars(T loop, ArrayRef reductionArgs, template static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase &builder, SmallVectorImpl &reductionDecls, DenseMap &reductionVariableMap, unsigned i) { @@ -1180,8 +1181,17 @@ mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, mlir::Value mlirSource = loop.getReductionVars()[i]; llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource); - assert(llvmSource && "lookup reduction var"); - moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource); + llvm::Value *origVal = llvmSource; + // If a non-pointer value is expected, load the value from the source pointer. + if (!isa( + reduction.getInitializerMoldArg().getType()) && + isa(mlirSource.getType())) { + origVal = + builder.CreateLoad(moduleTranslation.convertType( + reduction.getInitializerMoldArg().getType()), + llvmSource, "omp_orig"); + } + moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal); if (entry.getNumArguments() > 1) { llvm::Value *allocation = @@ -1254,7 +1264,7 @@ initReductionVars(OP op, ArrayRef reductionArgs, SmallVector phis; // map block argument to initializer region - mapInitializationArgs(op, moduleTranslation, reductionDecls, + mapInitializationArgs(op, moduleTranslation, builder, reductionDecls, reductionVariableMap, i); // TODO In some cases (specially on the GPU), the init regions may diff --git a/offload/test/offloading/fortran/target-custom-reduction-derivedtype.f90 b/offload/test/offloading/fortran/target-custom-reduction-derivedtype.f90 new file mode 100644 index 0000000000000..cc390cf0881f3 --- /dev/null +++ b/offload/test/offloading/fortran/target-custom-reduction-derivedtype.f90 @@ -0,0 +1,88 @@ +! Basic offloading test with custom OpenMP reduction on derived type +! REQUIRES: flang, amdgpu +! +! RUN: %libomptarget-compile-fortran-generic +! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic +module maxtype_mod + implicit none + + type maxtype + integer::sumval + integer::maxval + end type maxtype + +contains + + subroutine initme(x,n) + type(maxtype) :: x,n + x%sumval=0 + x%maxval=0 + end subroutine initme + + function mycombine(lhs, rhs) + type(maxtype) :: lhs, rhs + type(maxtype) :: mycombine + mycombine%sumval = lhs%sumval + rhs%sumval + mycombine%maxval = max(lhs%maxval, rhs%maxval) + end function mycombine + +end module maxtype_mod + +program main + use maxtype_mod + implicit none + + integer :: n = 100 + integer :: i + integer :: error = 0 + type(maxtype) :: x(100) + type(maxtype) :: res + integer :: expected_sum, expected_max + +!$omp declare reduction(red_add_max:maxtype:omp_out=mycombine(omp_out,omp_in)) initializer(initme(omp_priv,omp_orig)) + + ! Initialize array with test data + do i = 1, n + x(i)%sumval = i + x(i)%maxval = i + end do + + ! Initialize reduction variable + res%sumval = 0 + res%maxval = 0 + + ! Perform reduction in target region + !$omp target parallel do map(to:x) reduction(red_add_max:res) + do i = 1, n + res = mycombine(res, x(i)) + end do + !$omp end target parallel do + + ! Compute expected values + expected_sum = 0 + expected_max = 0 + do i = 1, n + expected_sum = expected_sum + i + expected_max = max(expected_max, i) + end do + + ! Check results + if (res%sumval /= expected_sum) then + error = 1 + endif + + if (res%maxval /= expected_max) then + error = 1 + endif + + if (error == 0) then + print *,"PASSED" + else + print *,"FAILED" + endif + +end program main + +! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}} +! CHECK: PASSED +