diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index fb24c8d1fe3eb..ae0d8bd37228d 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -162,14 +162,13 @@ getIfClauseOperand(Fortran::lower::AbstractConverter &converter, ifVal); } -static void -addUseDeviceClause(Fortran::lower::AbstractConverter &converter, - const omp::ObjectList &objects, - llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl - &useDeviceSymbols) { +static void addUseDeviceClause( + Fortran::lower::AbstractConverter &converter, + const omp::ObjectList &objects, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl &useDeviceSyms) { genObjectList(objects, converter, operands); for (mlir::Value &operand : operands) { checkMapType(operand.getLoc(), operand.getType()); @@ -177,25 +176,24 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter, useDeviceLocs.push_back(operand.getLoc()); } for (const omp::Object &object : objects) - useDeviceSymbols.push_back(object.id()); + useDeviceSyms.push_back(object.id()); } static void convertLoopBounds(Fortran::lower::AbstractConverter &converter, mlir::Location loc, - llvm::SmallVectorImpl &lowerBound, - llvm::SmallVectorImpl &upperBound, - llvm::SmallVectorImpl &step, + mlir::omp::CollapseClauseOps &result, std::size_t loopVarTypeSize) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); // The types of lower bound, upper bound, and step are converted into the // type of the loop variable if necessary. 
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); - for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) { - lowerBound[it] = - firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]); - upperBound[it] = - firOpBuilder.createConvert(loc, loopVarType, upperBound[it]); - step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]); + for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) { + result.loopLBVar[it] = + firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]); + result.loopUBVar[it] = + firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]); + result.loopStepVar[it] = + firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]); } } @@ -205,9 +203,7 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter, bool ClauseProcessor::processCollapse( mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, - llvm::SmallVectorImpl &lowerBound, - llvm::SmallVectorImpl &upperBound, - llvm::SmallVectorImpl &step, + mlir::omp::CollapseClauseOps &result, llvm::SmallVectorImpl &iv) const { bool found = false; fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -238,15 +234,15 @@ bool ClauseProcessor::processCollapse( std::get_if(&loopControl->u); assert(bounds && "Expected bounds for worksharing do loop"); Fortran::lower::StatementContext stmtCtx; - lowerBound.push_back(fir::getBase(converter.genExprValue( + result.loopLBVar.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); - upperBound.push_back(fir::getBase(converter.genExprValue( + result.loopUBVar.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); if (bounds->step) { - step.push_back(fir::getBase(converter.genExprValue( + result.loopStepVar.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); } else { // If `step` is not present, 
assume it as `1`. - step.push_back(firOpBuilder.createIntegerConstant( + result.loopStepVar.push_back(firOpBuilder.createIntegerConstant( currentLocation, firOpBuilder.getIntegerType(32), 1)); } iv.push_back(bounds->name.thing.symbol); @@ -257,8 +253,7 @@ bool ClauseProcessor::processCollapse( &*std::next(doConstructEval->getNestedEvaluations().begin()); } while (collapseValue > 0); - convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step, - loopVarTypeSize); + convertLoopBounds(converter, currentLocation, result, loopVarTypeSize); return found; } @@ -286,7 +281,7 @@ bool ClauseProcessor::processDefault() const { } bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { + mlir::omp::DeviceClauseOps &result) const { const Fortran::parser::CharBlock *source = nullptr; if (auto *clause = findUniqueClause(&source)) { mlir::Location clauseLocation = converter.genLocation(*source); @@ -298,25 +293,26 @@ bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx, } } const auto &deviceExpr = std::get(clause->t); - result = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx)); + result.deviceVar = + fir::getBase(converter.genExprValue(deviceExpr, stmtCtx)); return true; } return false; } bool ClauseProcessor::processDeviceType( - mlir::omp::DeclareTargetDeviceType &result) const { + mlir::omp::DeviceTypeClauseOps &result) const { if (auto *clause = findUniqueClause()) { // Case: declare target ... 
device_type(any | host | nohost) switch (clause->v) { case omp::clause::DeviceType::DeviceTypeDescription::Nohost: - result = mlir::omp::DeclareTargetDeviceType::nohost; + result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost; break; case omp::clause::DeviceType::DeviceTypeDescription::Host: - result = mlir::omp::DeclareTargetDeviceType::host; + result.deviceType = mlir::omp::DeclareTargetDeviceType::host; break; case omp::clause::DeviceType::DeviceTypeDescription::Any: - result = mlir::omp::DeclareTargetDeviceType::any; + result.deviceType = mlir::omp::DeclareTargetDeviceType::any; break; } return true; @@ -325,7 +321,7 @@ bool ClauseProcessor::processDeviceType( } bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { + mlir::omp::FinalClauseOps &result) const { const Fortran::parser::CharBlock *source = nullptr; if (auto *clause = findUniqueClause(&source)) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -333,100 +329,108 @@ bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx, mlir::Value finalVal = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); - result = firOpBuilder.createConvert(clauseLocation, - firOpBuilder.getI1Type(), finalVal); + result.finalVar = firOpBuilder.createConvert( + clauseLocation, firOpBuilder.getI1Type(), finalVal); return true; } return false; } -bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const { +bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v); - result = firOpBuilder.getI64IntegerAttr(hintValue); + result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue); return true; } return false; } -bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); +bool 
ClauseProcessor::processMergeable( + mlir::omp::MergeableClauseOps &result) const { + return markClauseOccurrence(result.mergeableAttr); } -bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); +bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const { + return markClauseOccurrence(result.nowaitAttr); } -bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { +bool ClauseProcessor::processNumTeams( + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::NumTeamsClauseOps &result) const { // TODO Get lower and upper bounds for num_teams when parser is updated to // accept both. if (auto *clause = findUniqueClause()) { // auto lowerBound = std::get>(clause->t); auto &upperBound = std::get(clause->t); - result = fir::getBase(converter.genExprValue(upperBound, stmtCtx)); + result.numTeamsUpperVar = + fir::getBase(converter.genExprValue(upperBound, stmtCtx)); return true; } return false; } bool ClauseProcessor::processNumThreads( - Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::NumThreadsClauseOps &result) const { if (auto *clause = findUniqueClause()) { // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. 
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + result.numThreadsVar = + fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; } -bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const { +bool ClauseProcessor::processOrdered( + mlir::omp::OrderedClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); int64_t orderedClauseValue = 0l; if (clause->v.has_value()) orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v); - result = firOpBuilder.getI64IntegerAttr(orderedClauseValue); + result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue); return true; } return false; } -bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { +bool ClauseProcessor::processPriority( + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::PriorityClauseOps &result) const { if (auto *clause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + result.priorityVar = + fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; } bool ClauseProcessor::processProcBind( - mlir::omp::ClauseProcBindKindAttr &result) const { + mlir::omp::ProcBindClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - result = genProcBindKindAttr(firOpBuilder, *clause); + result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause); return true; } return false; } -bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const { +bool ClauseProcessor::processSafelen( + mlir::omp::SafelenClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); const std::optional safelenVal = Fortran::evaluate::ToInt64(clause->v); - result = 
firOpBuilder.getI64IntegerAttr(*safelenVal); + result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal); return true; } return false; } bool ClauseProcessor::processSchedule( - mlir::omp::ClauseScheduleKindAttr &valAttr, - mlir::omp::ScheduleModifierAttr &modifierAttr, - mlir::UnitAttr &simdModifierAttr) const { + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::ScheduleClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::MLIRContext *context = firOpBuilder.getContext(); @@ -451,53 +455,51 @@ bool ClauseProcessor::processSchedule( break; } - mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause); + result.scheduleValAttr = + mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind); + mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause); if (scheduleModifier != mlir::omp::ScheduleModifier::none) - modifierAttr = + result.scheduleModAttr = mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier); if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none) - simdModifierAttr = firOpBuilder.getUnitAttr(); + result.scheduleSimdAttr = firOpBuilder.getUnitAttr(); - valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind); - return true; - } - return false; -} - -bool ClauseProcessor::processScheduleChunk( - Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { - if (auto *clause = findUniqueClause()) { if (const auto &chunkExpr = std::get(clause->t)) - result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx)); + result.scheduleChunkVar = + fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx)); + return true; } return false; } -bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const { +bool ClauseProcessor::processSimdlen( + mlir::omp::SimdlenClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = 
converter.getFirOpBuilder(); const std::optional simdlenVal = Fortran::evaluate::ToInt64(clause->v); - result = firOpBuilder.getI64IntegerAttr(*simdlenVal); + result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal); return true; } return false; } bool ClauseProcessor::processThreadLimit( - Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::ThreadLimitClauseOps &result) const { if (auto *clause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + result.threadLimitVar = + fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; } -bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); +bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const { + return markClauseOccurrence(result.untiedAttr); } //===----------------------------------------------------------------------===// @@ -505,13 +507,12 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { //===----------------------------------------------------------------------===// bool ClauseProcessor::processAllocate( - llvm::SmallVectorImpl &allocatorOperands, - llvm::SmallVectorImpl &allocateOperands) const { + mlir::omp::AllocateClauseOps &result) const { return findRepeatableClause( [&](const omp::clause::Allocate &clause, const Fortran::parser::CharBlock &) { - genAllocateClause(converter, clause, allocatorOperands, - allocateOperands); + genAllocateClause(converter, clause, result.allocatorVars, + result.allocateVars); }); } @@ -660,10 +661,9 @@ createCopyFunc(mlir::Location loc, Fortran::lower::AbstractConverter &converter, return funcOp; } -bool ClauseProcessor::processCopyPrivate( +bool ClauseProcessor::processCopyprivate( mlir::Location currentLocation, - llvm::SmallVectorImpl &copyPrivateVars, - llvm::SmallVectorImpl &copyPrivateFuncs) const { + 
mlir::omp::CopyprivateClauseOps &result) const { auto addCopyPrivateVar = [&](Fortran::semantics::Symbol *sym) { mlir::Value symVal = converter.getSymbolAddress(*sym); auto declOp = symVal.getDefiningOp(); @@ -690,10 +690,10 @@ bool ClauseProcessor::processCopyPrivate( cpVar = alloca; } - copyPrivateVars.push_back(cpVar); + result.copyprivateVars.push_back(cpVar); mlir::func::FuncOp funcOp = createCopyFunc(currentLocation, converter, cpVar.getType(), attrs); - copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp)); + result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp)); }; bool hasCopyPrivate = findRepeatableClause( @@ -714,9 +714,7 @@ bool ClauseProcessor::processCopyPrivate( return hasCopyPrivate; } -bool ClauseProcessor::processDepend( - llvm::SmallVectorImpl &dependTypeOperands, - llvm::SmallVectorImpl &dependOperands) const { +bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); return findRepeatableClause( @@ -731,7 +729,7 @@ bool ClauseProcessor::processDepend( mlir::omp::ClauseTaskDependAttr dependTypeOperand = genDependKindAttr(firOpBuilder, kind); - dependTypeOperands.append(objects.size(), dependTypeOperand); + result.dependTypeAttrs.append(objects.size(), dependTypeOperand); for (const omp::Object &object : objects) { assert(object.ref() && "Expecting designator"); @@ -746,13 +744,13 @@ bool ClauseProcessor::processDepend( Fortran::semantics::Symbol *sym = object.id(); const mlir::Value variable = converter.getSymbolAddress(*sym); - dependOperands.push_back(variable); + result.dependVars.push_back(variable); } }); } bool ClauseProcessor::processHasDeviceAddr( - llvm::SmallVectorImpl &operands, + mlir::omp::HasDeviceAddrClauseOps &result, llvm::SmallVectorImpl &isDeviceTypes, llvm::SmallVectorImpl &isDeviceLocs, llvm::SmallVectorImpl &isDeviceSymbols) @@ -760,14 +758,14 @@ bool ClauseProcessor::processHasDeviceAddr( return findRepeatableClause( 
[&](const omp::clause::HasDeviceAddr &devAddrClause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devAddrClause.v, operands, isDeviceTypes, - isDeviceLocs, isDeviceSymbols); + addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars, + isDeviceTypes, isDeviceLocs, isDeviceSymbols); }); } bool ClauseProcessor::processIf( omp::clause::If::DirectiveNameModifier directiveName, - mlir::Value &result) const { + mlir::omp::IfClauseOps &result) const { bool found = false; findRepeatableClause( [&](const omp::clause::If &clause, @@ -778,7 +776,7 @@ bool ClauseProcessor::processIf( // Assume that, at most, a single 'if' clause will be applicable to the // given directive. if (operand) { - result = operand; + result.ifVar = operand; found = true; } }); @@ -786,7 +784,7 @@ bool ClauseProcessor::processIf( } bool ClauseProcessor::processIsDevicePtr( - llvm::SmallVectorImpl &operands, + mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl &isDeviceTypes, llvm::SmallVectorImpl &isDeviceLocs, llvm::SmallVectorImpl &isDeviceSymbols) @@ -794,8 +792,8 @@ bool ClauseProcessor::processIsDevicePtr( return findRepeatableClause( [&](const omp::clause::IsDevicePtr &devPtrClause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devPtrClause.v, operands, isDeviceTypes, - isDeviceLocs, isDeviceSymbols); + addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars, + isDeviceTypes, isDeviceLocs, isDeviceSymbols); }); } @@ -835,12 +833,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, bool ClauseProcessor::processMap( mlir::Location currentLocation, const llvm::omp::Directive &directive, - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands, - llvm::SmallVectorImpl *mapSymTypes, + Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result, + llvm::SmallVectorImpl *mapSyms, llvm::SmallVectorImpl *mapSymLocs, - llvm::SmallVectorImpl *mapSymbols) - const 
{ + llvm::SmallVectorImpl *mapSymTypes) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); return findRepeatableClause( [&](const omp::clause::Map &clause, @@ -915,25 +911,23 @@ bool ClauseProcessor::processMap( mapTypeBits), mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); - mapOperands.push_back(mapOp); - if (mapSymTypes) - mapSymTypes->push_back(symAddr.getType()); + result.mapVars.push_back(mapOp); + + if (mapSyms) + mapSyms->push_back(object.id()); if (mapSymLocs) mapSymLocs->push_back(symAddr.getLoc()); - - if (mapSymbols) - mapSymbols->push_back(object.id()); + if (mapSymTypes) + mapSymTypes->push_back(symAddr.getType()); } }); } bool ClauseProcessor::processReduction( - mlir::Location currentLocation, - llvm::SmallVectorImpl &outReductionVars, - llvm::SmallVectorImpl &outReductionTypes, - llvm::SmallVectorImpl &outReductionDeclSymbols, - llvm::SmallVectorImpl - *outReductionSymbols) const { + mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, + llvm::SmallVectorImpl *outReductionTypes, + llvm::SmallVectorImpl *outReductionSyms) + const { return findRepeatableClause( [&](const omp::clause::Reduction &clause, const Fortran::parser::CharBlock &) { @@ -943,30 +937,31 @@ bool ClauseProcessor::processReduction( // whether to do the reduction byref. llvm::SmallVector reductionVars; llvm::SmallVector reductionDeclSymbols; - llvm::SmallVector reductionSymbols; + llvm::SmallVector reductionSyms; ReductionProcessor rp; rp.addDeclareReduction(currentLocation, converter, clause, reductionVars, reductionDeclSymbols, - outReductionSymbols ? &reductionSymbols - : nullptr); + outReductionSyms ? &reductionSyms : nullptr); // Copy local lists into the output. 
- llvm::copy(reductionVars, std::back_inserter(outReductionVars)); + llvm::copy(reductionVars, std::back_inserter(result.reductionVars)); llvm::copy(reductionDeclSymbols, - std::back_inserter(outReductionDeclSymbols)); - if (outReductionSymbols) - llvm::copy(reductionSymbols, - std::back_inserter(*outReductionSymbols)); - - outReductionTypes.reserve(outReductionTypes.size() + - reductionVars.size()); - llvm::transform(reductionVars, std::back_inserter(outReductionTypes), - [](mlir::Value v) { return v.getType(); }); + std::back_inserter(result.reductionDeclSymbols)); + + if (outReductionTypes) { + outReductionTypes->reserve(outReductionTypes->size() + + reductionVars.size()); + llvm::transform(reductionVars, std::back_inserter(*outReductionTypes), + [](mlir::Value v) { return v.getType(); }); + } + + if (outReductionSyms) + llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms)); }); } bool ClauseProcessor::processSectionsReduction( - mlir::Location currentLocation) const { + mlir::Location currentLocation, mlir::omp::ReductionClauseOps &) const { return findRepeatableClause( [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) { TODO(currentLocation, "OMPC_Reduction"); @@ -995,30 +990,30 @@ bool ClauseProcessor::processEnter( } bool ClauseProcessor::processUseDeviceAddr( - llvm::SmallVectorImpl &operands, + mlir::omp::UseDeviceClauseOps &result, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSymbols) + llvm::SmallVectorImpl &useDeviceSyms) const { return findRepeatableClause( [&](const omp::clause::UseDeviceAddr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, - useDeviceLocs, useDeviceSymbols); + addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars, + useDeviceTypes, useDeviceLocs, useDeviceSyms); }); } bool ClauseProcessor::processUseDevicePtr( - llvm::SmallVectorImpl &operands, + 
mlir::omp::UseDeviceClauseOps &result, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSymbols) + llvm::SmallVectorImpl &useDeviceSyms) const { return findRepeatableClause( [&](const omp::clause::UseDevicePtr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, - useDeviceLocs, useDeviceSymbols); + addUseDeviceClause(converter, clause.v, result.useDevicePtrVars, + useDeviceTypes, useDeviceLocs, useDeviceSyms); }); } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index df8f4f5310fcb..aa2c14b61e756 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -37,7 +37,7 @@ namespace omp { /// corresponding clause if it is present in the clause list. Otherwise, they /// will return `false` to signal that the clause was not found. /// -/// The intended use is of this class is to move clause processing outside of +/// The intended use of this class is to move clause processing outside of /// construct processing, since the same clauses can appear attached to /// different constructs and constructs can be combined, so that code /// duplication is minimized. @@ -56,61 +56,51 @@ class ClauseProcessor { // 'Unique' clauses: They can appear at most once in the clause list. 
bool processCollapse( mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, - llvm::SmallVectorImpl &lowerBound, - llvm::SmallVectorImpl &upperBound, - llvm::SmallVectorImpl &step, + mlir::omp::CollapseClauseOps &result, llvm::SmallVectorImpl &iv) const; bool processDefault() const; bool processDevice(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const; + mlir::omp::DeviceClauseOps &result) const; + bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; bool processFinal(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; + mlir::omp::FinalClauseOps &result) const; bool - processHasDeviceAddr(llvm::SmallVectorImpl &operands, + processHasDeviceAddr(mlir::omp::HasDeviceAddrClauseOps &result, llvm::SmallVectorImpl &isDeviceTypes, llvm::SmallVectorImpl &isDeviceLocs, llvm::SmallVectorImpl &isDeviceSymbols) const; - bool processHint(mlir::IntegerAttr &result) const; - bool processMergeable(mlir::UnitAttr &result) const; - bool processNowait(mlir::UnitAttr &result) const; + bool processHint(mlir::omp::HintClauseOps &result) const; + bool processMergeable(mlir::omp::MergeableClauseOps &result) const; + bool processNowait(mlir::omp::NowaitClauseOps &result) const; bool processNumTeams(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; + mlir::omp::NumTeamsClauseOps &result) const; bool processNumThreads(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processOrdered(mlir::IntegerAttr &result) const; + mlir::omp::NumThreadsClauseOps &result) const; + bool processOrdered(mlir::omp::OrderedClauseOps &result) const; bool processPriority(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const; - bool processSafelen(mlir::IntegerAttr &result) const; - bool 
processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr, - mlir::omp::ScheduleModifierAttr &modifierAttr, - mlir::UnitAttr &simdModifierAttr) const; - bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processSimdlen(mlir::IntegerAttr &result) const; + mlir::omp::PriorityClauseOps &result) const; + bool processProcBind(mlir::omp::ProcBindClauseOps &result) const; + bool processSafelen(mlir::omp::SafelenClauseOps &result) const; + bool processSchedule(Fortran::lower::StatementContext &stmtCtx, + mlir::omp::ScheduleClauseOps &result) const; + bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const; bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processUntied(mlir::UnitAttr &result) const; + mlir::omp::ThreadLimitClauseOps &result) const; + bool processUntied(mlir::omp::UntiedClauseOps &result) const; // 'Repeatable' clauses: They can appear multiple times in the clause list. 
- bool - processAllocate(llvm::SmallVectorImpl &allocatorOperands, - llvm::SmallVectorImpl &allocateOperands) const; + bool processAllocate(mlir::omp::AllocateClauseOps &result) const; bool processCopyin() const; - bool processCopyPrivate( - mlir::Location currentLocation, - llvm::SmallVectorImpl &copyPrivateVars, - llvm::SmallVectorImpl &copyPrivateFuncs) const; - bool processDepend(llvm::SmallVectorImpl &dependTypeOperands, - llvm::SmallVectorImpl &dependOperands) const; + bool processCopyprivate(mlir::Location currentLocation, + mlir::omp::CopyprivateClauseOps &result) const; + bool processDepend(mlir::omp::DependClauseOps &result) const; bool processEnter(llvm::SmallVectorImpl &result) const; bool processIf(omp::clause::If::DirectiveNameModifier directiveName, - mlir::Value &result) const; + mlir::omp::IfClauseOps &result) const; bool - processIsDevicePtr(llvm::SmallVectorImpl &operands, + processIsDevicePtr(mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl &isDeviceTypes, llvm::SmallVectorImpl &isDeviceLocs, llvm::SmallVectorImpl @@ -119,43 +109,42 @@ class ClauseProcessor { processLink(llvm::SmallVectorImpl &result) const; // This method is used to process a map clause. - // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to + // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to // store the original type, location and Fortran symbol for the map operands. // They may be used later on to create the block_arguments for some of the // target directives that require it. 
- bool processMap(mlir::Location currentLocation, - const llvm::omp::Directive &directive, - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands, - llvm::SmallVectorImpl *mapSymTypes = nullptr, - llvm::SmallVectorImpl *mapSymLocs = nullptr, - llvm::SmallVectorImpl - *mapSymbols = nullptr) const; - bool - processReduction(mlir::Location currentLocation, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionTypes, - llvm::SmallVectorImpl &reductionDeclSymbols, - llvm::SmallVectorImpl - *reductionSymbols = nullptr) const; - bool processSectionsReduction(mlir::Location currentLocation) const; + bool processMap( + mlir::Location currentLocation, const llvm::omp::Directive &directive, + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::MapClauseOps &result, + llvm::SmallVectorImpl *mapSyms = + nullptr, + llvm::SmallVectorImpl *mapSymLocs = nullptr, + llvm::SmallVectorImpl *mapSymTypes = nullptr) const; + bool processReduction( + mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, + llvm::SmallVectorImpl *reductionTypes = nullptr, + llvm::SmallVectorImpl *reductionSyms = + nullptr) const; + bool processSectionsReduction(mlir::Location currentLocation, + mlir::omp::ReductionClauseOps &result) const; bool processTo(llvm::SmallVectorImpl &result) const; bool - processUseDeviceAddr(llvm::SmallVectorImpl &operands, + processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl - &useDeviceSymbols) const; + &useDeviceSyms) const; bool - processUseDevicePtr(llvm::SmallVectorImpl &operands, + processUseDevicePtr(mlir::omp::UseDeviceClauseOps &result, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl - &useDeviceSymbols) const; + &useDeviceSyms) const; template bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl 
&mapOperands); + mlir::omp::MapClauseOps &result); // Call this method for these clauses that should be supported but are not // implemented yet. It triggers a compilation error if any of the given @@ -197,7 +186,7 @@ class ClauseProcessor { template bool ClauseProcessor::processMotionClauses( Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands) { + mlir::omp::MapClauseOps &result) { return findRepeatableClause( [&](const T &clause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); @@ -239,7 +228,7 @@ bool ClauseProcessor::processMotionClauses( mapTypeBits), mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); - mapOperands.push_back(mapOp); + result.mapVars.push_back(mapOp); } }); } diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp index e114ab9f4548a..5a42e6a6aa417 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -23,11 +23,13 @@ namespace Fortran { namespace lower { namespace omp { -void DataSharingProcessor::processStep1() { +void DataSharingProcessor::processStep1( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms) { collectSymbolsForPrivatization(); collectDefaultSymbols(); - privatize(); - defaultPrivatize(); + privatize(clauseOps, privateSyms); + defaultPrivatize(clauseOps, privateSyms); insertBarrier(); } @@ -299,14 +301,16 @@ void DataSharingProcessor::collectDefaultSymbols() { } } -void DataSharingProcessor::privatize() { +void DataSharingProcessor::privatize( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms) { for (const Fortran::semantics::Symbol *sym : privatizedSymbols) { if (const auto *commonDet = sym->detailsIf()) { for (const auto &mem : commonDet->objects()) - doPrivatize(&*mem); + doPrivatize(&*mem, clauseOps, privateSyms); } else - doPrivatize(sym); + doPrivatize(sym, 
clauseOps, privateSyms); } } @@ -323,7 +327,9 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) { } } -void DataSharingProcessor::defaultPrivatize() { +void DataSharingProcessor::defaultPrivatize( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms) { for (const Fortran::semantics::Symbol *sym : defaultSymbols) { if (!Fortran::semantics::IsProcedure(*sym) && !sym->GetUltimate().has() && @@ -331,11 +337,14 @@ void DataSharingProcessor::defaultPrivatize() { !symbolsInNestedRegions.contains(sym) && !symbolsInParentRegions.contains(sym) && !privatizedSymbols.contains(sym)) - doPrivatize(sym); + doPrivatize(sym, clauseOps, privateSyms); } } -void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) { +void DataSharingProcessor::doPrivatize( + const Fortran::semantics::Symbol *sym, + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms) { if (!useDelayedPrivatization) { cloneSymbol(sym); copyFirstPrivateSymbol(sym); @@ -419,10 +428,13 @@ void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) { return result; }(); - delayedPrivatizationInfo.privatizers.push_back( - mlir::SymbolRefAttr::get(privatizerOp)); - delayedPrivatizationInfo.originalAddresses.push_back(hsb.getAddr()); - delayedPrivatizationInfo.symbols.push_back(sym); + if (clauseOps) { + clauseOps->privatizers.push_back(mlir::SymbolRefAttr::get(privatizerOp)); + clauseOps->privateVars.push_back(hsb.getAddr()); + } + + if (privateSyms) + privateSyms->push_back(sym); } } // namespace omp diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h index 1cbc825fd5e11..c11ee299c5d08 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h @@ -19,28 +19,17 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/symbol.h" +namespace mlir { +namespace omp { +struct PrivateClauseOps; +} // namespace 
omp +} // namespace mlir + namespace Fortran { namespace lower { namespace omp { class DataSharingProcessor { -public: - /// Collects all the information needed for delayed privatization. This can be - /// used by ops with data-sharing clauses to properly generate their regions - /// (e.g. add region arguments) and map the original SSA values to their - /// corresponding OMP region operands. - struct DelayedPrivatizationInfo { - // The list of symbols referring to delayed privatizer ops (i.e. - // `omp.private` ops). - llvm::SmallVector privatizers; - // SSA values that correspond to "original" values being privatized. - // "Original" here means the SSA value outside the OpenMP region from which - // a clone is created inside the region. - llvm::SmallVector originalAddresses; - // Fortran symbols corresponding to the above SSA values. - llvm::SmallVector symbols; - }; - private: bool hasLastPrivateOp; mlir::OpBuilder::InsertPoint lastPrivIP; @@ -57,7 +46,6 @@ class DataSharingProcessor { Fortran::lower::pft::Evaluation &eval; bool useDelayedPrivatization; Fortran::lower::SymMap *symTable; - DelayedPrivatizationInfo delayedPrivatizationInfo; bool needBarrier(); void collectSymbols(Fortran::semantics::Symbol::Flag flag); @@ -67,9 +55,16 @@ class DataSharingProcessor { void collectSymbolsForPrivatization(); void insertBarrier(); void collectDefaultSymbols(); - void privatize(); - void defaultPrivatize(); - void doPrivatize(const Fortran::semantics::Symbol *sym); + void privatize( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms); + void defaultPrivatize( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms); + void doPrivatize( + const Fortran::semantics::Symbol *sym, + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms); void copyLastPrivatize(mlir::Operation *op); void insertLastPrivateCompare(mlir::Operation *op); void cloneSymbol(const Fortran::semantics::Symbol *sym); @@ -103,17 
+98,15 @@ class DataSharingProcessor { // Step2 performs the copying for lastprivates and requires knowledge of the // MLIR operation to insert the last private update. Step2 adds // dealocation code as well. - void processStep1(); + void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr, + llvm::SmallVectorImpl + *privateSyms = nullptr); void processStep2(mlir::Operation *op, bool isLoop); void setLoopIV(mlir::Value iv) { assert(!loopIV && "Loop iteration variable already set"); loopIV = iv; } - - const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const { - return delayedPrivatizationInfo; - } }; } // namespace omp diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 50ad889052ab0..3dcfe0fd775dc 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -730,19 +730,25 @@ genMasterOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation) { return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested), - /*resultTypes=*/mlir::TypeRange()); + .setGenNested(genNested)); } static mlir::omp::OrderedRegionOp genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation) { + mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &clauseList) { + mlir::omp::OrderedRegionClauseOps clauseOps; + + ClauseProcessor cp(converter, semaCtx, clauseList); + cp.processTODO(currentLocation, + llvm::omp::Directive::OMPD_ordered); + return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested), - /*simd=*/false); + clauseOps); } static mlir::omp::ParallelOp @@ -753,77 +759,62 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList, bool 
outerCombined = false) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, numThreadsClauseOperand; - mlir::omp::ClauseProcBindKindAttr procBindKindAttr; - llvm::SmallVector allocateOperands, allocatorOperands, - reductionVars; + mlir::omp::ParallelClauseOps clauseOps; + llvm::SmallVector privateSyms; llvm::SmallVector reductionTypes; - llvm::SmallVector reductionDeclSymbols; - llvm::SmallVector reductionSymbols; + llvm::SmallVector reductionSyms; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_parallel, ifClauseOperand); - cp.processNumThreads(stmtCtx, numThreadsClauseOperand); - cp.processProcBind(procBindKindAttr); + cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); + cp.processNumThreads(stmtCtx, clauseOps); + cp.processProcBind(clauseOps); cp.processDefault(); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processAllocate(clauseOps); + if (!outerCombined) - cp.processReduction(currentLocation, reductionVars, reductionTypes, - reductionDeclSymbols, &reductionSymbols); + cp.processReduction(currentLocation, clauseOps, &reductionTypes, + &reductionSyms); + + if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) + clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); auto reductionCallback = [&](mlir::Operation *op) { - llvm::SmallVector locs(reductionVars.size(), + llvm::SmallVector locs(clauseOps.reductionVars.size(), currentLocation); - auto *block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {}, - reductionTypes, locs); + auto *block = + firOpBuilder.createBlock(&op->getRegion(0), {}, reductionTypes, locs); for (auto [arg, prv] : - llvm::zip_equal(reductionSymbols, block->getArguments())) { + llvm::zip_equal(reductionSyms, block->getArguments())) { converter.bindSymbol(*arg, prv); } - return reductionSymbols; + return reductionSyms; }; - mlir::UnitAttr byrefAttr; - 
if (ReductionProcessor::doReductionByRef(reductionVars)) - byrefAttr = converter.getFirOpBuilder().getUnitAttr(); - OpWithBodyGenInfo genInfo = OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) .setClauses(&clauseList) - .setReductions(&reductionSymbols, &reductionTypes) + .setReductions(&reductionSyms, &reductionTypes) .setGenRegionEntryCb(reductionCallback); - if (!enableDelayedPrivatization) { - return genOpWithBody( - genInfo, - /*resultTypes=*/mlir::TypeRange(), ifClauseOperand, - numThreadsClauseOperand, allocateOperands, allocatorOperands, - reductionVars, - reductionDeclSymbols.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - reductionDeclSymbols), - procBindKindAttr, /*private_vars=*/llvm::SmallVector{}, - /*privatizers=*/nullptr, byrefAttr); - } + if (!enableDelayedPrivatization) + return genOpWithBody(genInfo, clauseOps); bool privatize = !outerCombined; DataSharingProcessor dsp(converter, semaCtx, clauseList, eval, /*useDelayedPrivatization=*/true, &symTable); if (privatize) - dsp.processStep1(); - - const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo(); + dsp.processStep1(&clauseOps, &privateSyms); auto genRegionEntryCB = [&](mlir::Operation *op) { auto parallelOp = llvm::cast(op); - llvm::SmallVector reductionLocs(reductionVars.size(), - currentLocation); + llvm::SmallVector reductionLocs( + clauseOps.reductionVars.size(), currentLocation); mlir::OperandRange privateVars = parallelOp.getPrivateVars(); mlir::Region ®ion = parallelOp.getRegion(); @@ -838,12 +829,12 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::transform(privateVars, std::back_inserter(privateVarLocs), [](mlir::Value v) { return v.getLoc(); }); - converter.getFirOpBuilder().createBlock(®ion, /*insertPt=*/{}, - privateVarTypes, privateVarLocs); + firOpBuilder.createBlock(®ion, /*insertPt=*/{}, privateVarTypes, + privateVarLocs); 
llvm::SmallVector allSymbols = - reductionSymbols; - allSymbols.append(delayedPrivatizationInfo.symbols); + reductionSyms; + allSymbols.append(privateSyms); for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) { converter.bindSymbol(*arg, prv); } @@ -853,26 +844,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, // TODO Merge with the reduction CB. genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp); - - llvm::SmallVector privatizers( - delayedPrivatizationInfo.privatizers.begin(), - delayedPrivatizationInfo.privatizers.end()); - - return genOpWithBody( - genInfo, - /*resultTypes=*/mlir::TypeRange(), ifClauseOperand, - numThreadsClauseOperand, allocateOperands, allocatorOperands, - reductionVars, - reductionDeclSymbols.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - reductionDeclSymbols), - procBindKindAttr, delayedPrivatizationInfo.originalAddresses, - delayedPrivatizationInfo.privatizers.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - privatizers), - byrefAttr); + return genOpWithBody(genInfo, clauseOps); } static mlir::omp::SectionOp @@ -896,28 +868,21 @@ genSingleOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &beginClauseList, const Fortran::parser::OmpClauseList &endClauseList) { - llvm::SmallVector allocateOperands, allocatorOperands; - llvm::SmallVector copyPrivateVars; - llvm::SmallVector copyPrivateFuncs; - mlir::UnitAttr nowaitAttr; + mlir::omp::SingleClauseOps clauseOps; ClauseProcessor cp(converter, semaCtx, beginClauseList); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processAllocate(clauseOps); + // TODO Support delayed privatization. 
ClauseProcessor ecp(converter, semaCtx, endClauseList); - ecp.processNowait(nowaitAttr); - ecp.processCopyPrivate(currentLocation, copyPrivateVars, copyPrivateFuncs); + ecp.processNowait(clauseOps); + ecp.processCopyprivate(currentLocation, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(&beginClauseList), - allocateOperands, allocatorOperands, copyPrivateVars, - copyPrivateFuncs.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - copyPrivateFuncs), - nowaitAttr); + clauseOps); } static mlir::omp::TaskOp @@ -927,21 +892,19 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand; - mlir::UnitAttr untiedAttr, mergeableAttr; - llvm::SmallVector dependTypeOperands; - llvm::SmallVector allocateOperands, allocatorOperands, - dependOperands; + mlir::omp::TaskClauseOps clauseOps; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_task, ifClauseOperand); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps); + cp.processAllocate(clauseOps); cp.processDefault(); - cp.processFinal(stmtCtx, finalClauseOperand); - cp.processUntied(untiedAttr); - cp.processMergeable(mergeableAttr); - cp.processPriority(stmtCtx, priorityClauseOperand); - cp.processDepend(dependTypeOperands, dependOperands); + cp.processFinal(stmtCtx, clauseOps); + cp.processUntied(clauseOps); + cp.processMergeable(clauseOps); + cp.processPriority(stmtCtx, clauseOps); + cp.processDepend(clauseOps); + // TODO Support delayed privatization. 
+ cp.processTODO( currentLocation, llvm::omp::Directive::OMPD_task); @@ -949,14 +912,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(&clauseList), - ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr, - /*in_reduction_vars=*/mlir::ValueRange(), - /*in_reductions=*/nullptr, priorityClauseOperand, - dependTypeOperands.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - dependTypeOperands), - dependOperands, allocateOperands, allocatorOperands); + clauseOps); } static mlir::omp::TaskgroupOp @@ -965,17 +921,18 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { - llvm::SmallVector allocateOperands, allocatorOperands; + mlir::omp::TaskgroupClauseOps clauseOps; + ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processAllocate(clauseOps); cp.processTODO(currentLocation, llvm::omp::Directive::OMPD_taskgroup); + return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(&clauseList), - /*task_reduction_vars=*/mlir::ValueRange(), - /*task_reductions=*/nullptr, allocateOperands, allocatorOperands); + clauseOps); } // This helper function implements the functionality of "promoting" @@ -996,8 +953,7 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter, // clause. Support for such list items in a use_device_ptr clause // is deprecated." 
static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( - llvm::SmallVectorImpl &devicePtrOperands, - llvm::SmallVectorImpl &deviceAddrOperands, + mlir::omp::UseDeviceClauseOps &clauseOps, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl @@ -1010,9 +966,10 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( // Iterate over our use_device_ptr list and shift all non-cptr arguments into // use_device_addr. - for (auto *it = devicePtrOperands.begin(); it != devicePtrOperands.end();) { + for (auto *it = clauseOps.useDevicePtrVars.begin(); + it != clauseOps.useDevicePtrVars.end();) { if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) { - deviceAddrOperands.push_back(*it); + clauseOps.useDeviceAddrVars.push_back(*it); // We have to shuffle the symbols around as well, to maintain // the correct Input -> BlockArg for use_device_ptr/use_device_addr. // NOTE: However, as map's do not seem to be included currently @@ -1020,11 +977,11 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( // future alterations. I believe the reason they are not currently // is that the BlockArg assign/lowering needs to be extended // to a greater set of types. 
- auto idx = std::distance(devicePtrOperands.begin(), it); + auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it); moveElementToBack(idx, useDeviceTypes); moveElementToBack(idx, useDeviceLocs); moveElementToBack(idx, useDeviceSymbols); - it = devicePtrOperands.erase(it); + it = clauseOps.useDevicePtrVars.erase(it); continue; } ++it; @@ -1038,20 +995,19 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, deviceOperand; - llvm::SmallVector mapOperands, devicePtrOperands, - deviceAddrOperands; + mlir::omp::TargetDataClauseOps clauseOps; llvm::SmallVector useDeviceTypes; llvm::SmallVector useDeviceLocs; - llvm::SmallVector useDeviceSymbols; + llvm::SmallVector useDeviceSyms; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_target_data, ifClauseOperand); - cp.processDevice(stmtCtx, deviceOperand); - cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs, - useDeviceSymbols); - cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs, - useDeviceSymbols); + cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); + cp.processDevice(stmtCtx, clauseOps); + cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs, + useDeviceSyms); + cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs, + useDeviceSyms); + // This function implements the deprecated functionality of use_device_ptr // that allows users to provide non-CPTR arguments to it with the caveat // that the compiler will treat them as use_device_addr. A lot of legacy @@ -1063,17 +1019,16 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter, // ordering. // TODO: Perhaps create a user provideable compiler option that will // re-introduce a hard-error rather than a warning in these cases. 
- promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( - devicePtrOperands, deviceAddrOperands, useDeviceTypes, useDeviceLocs, - useDeviceSymbols); + promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes, + useDeviceLocs, useDeviceSyms); cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data, - stmtCtx, mapOperands); + stmtCtx, clauseOps); auto dataOp = converter.getFirOpBuilder().create( - currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands, - deviceAddrOperands, mapOperands); + currentLocation, clauseOps); + genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp, - useDeviceTypes, useDeviceLocs, useDeviceSymbols, + useDeviceTypes, useDeviceLocs, useDeviceSyms, currentLocation); return dataOp; } @@ -1086,10 +1041,7 @@ static OpTy genTargetEnterExitDataUpdateOp( const Fortran::parser::OmpClauseList &clauseList) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, deviceOperand; - mlir::UnitAttr nowaitAttr; - llvm::SmallVector mapOperands, dependOperands; - llvm::SmallVector dependTypeOperands; + mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; // GCC 9.3.0 emits a (probably) bogus warning about an unused variable. 
[[maybe_unused]] llvm::omp::Directive directive; @@ -1104,25 +1056,19 @@ static OpTy genTargetEnterExitDataUpdateOp( } ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(directive, ifClauseOperand); - cp.processDevice(stmtCtx, deviceOperand); - cp.processDepend(dependTypeOperands, dependOperands); - cp.processNowait(nowaitAttr); + cp.processIf(directive, clauseOps); + cp.processDevice(stmtCtx, clauseOps); + cp.processDepend(clauseOps); + cp.processNowait(clauseOps); if constexpr (std::is_same_v) { - cp.processMotionClauses(stmtCtx, mapOperands); - cp.processMotionClauses(stmtCtx, mapOperands); + cp.processMotionClauses(stmtCtx, clauseOps); + cp.processMotionClauses(stmtCtx, clauseOps); } else { - cp.processMap(currentLocation, directive, stmtCtx, mapOperands); + cp.processMap(currentLocation, directive, stmtCtx, clauseOps); } - return firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, - dependTypeOperands.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - dependTypeOperands), - dependOperands, nowaitAttr, mapOperands); + return firOpBuilder.create(currentLocation, clauseOps); } // This functions creates a block for the body of the targetOp's region. It adds @@ -1132,9 +1078,9 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::omp::TargetOp &targetOp, - llvm::ArrayRef mapSymTypes, + llvm::ArrayRef mapSyms, llvm::ArrayRef mapSymLocs, - llvm::ArrayRef mapSymbols, + llvm::ArrayRef mapSymTypes, const mlir::Location ¤tLocation) { assert(mapSymTypes.size() == mapSymLocs.size()); @@ -1163,7 +1109,7 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter, }; // Bind the symbols to their corresponding block arguments. 
- for (auto [argIndex, argSymbol] : llvm::enumerate(mapSymbols)) { + for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { const mlir::BlockArgument &arg = region.getArgument(argIndex); // Avoid capture of a reference to a structured binding. const Fortran::semantics::Symbol *sym = argSymbol; @@ -1287,31 +1233,25 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpClauseList &clauseList, llvm::omp::Directive directive, bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand; - mlir::UnitAttr nowaitAttr; - llvm::SmallVector dependTypeOperands; - llvm::SmallVector mapOperands, dependOperands; - llvm::SmallVector mapSymTypes; - llvm::SmallVector mapSymLocs; - llvm::SmallVector mapSymbols; - llvm::SmallVector devicePtrOperands, deviceAddrOperands; - llvm::SmallVector devicePtrTypes, deviceAddrTypes; - llvm::SmallVector devicePtrLocs, deviceAddrLocs; - llvm::SmallVector devicePtrSymbols, - deviceAddrSymbols; + mlir::omp::TargetClauseOps clauseOps; + llvm::SmallVector mapTypes, devicePtrTypes, deviceAddrTypes; + llvm::SmallVector mapLocs, devicePtrLocs, deviceAddrLocs; + llvm::SmallVector mapSyms, devicePtrSyms, + deviceAddrSyms; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_target, ifClauseOperand); - cp.processDevice(stmtCtx, deviceOperand); - cp.processThreadLimit(stmtCtx, threadLimitOperand); - cp.processDepend(dependTypeOperands, dependOperands); - cp.processNowait(nowaitAttr); - cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes, - &mapSymLocs, &mapSymbols); - cp.processIsDevicePtr(devicePtrOperands, devicePtrTypes, devicePtrLocs, - devicePtrSymbols); - cp.processHasDeviceAddr(deviceAddrOperands, deviceAddrTypes, deviceAddrLocs, - deviceAddrSymbols); + cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); + cp.processDevice(stmtCtx, clauseOps); + 
cp.processThreadLimit(stmtCtx, clauseOps); + cp.processDepend(clauseOps); + cp.processNowait(clauseOps); + cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms, + &mapLocs, &mapTypes); + cp.processIsDevicePtr(clauseOps, devicePtrTypes, devicePtrLocs, + devicePtrSyms); + cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs, + deviceAddrSyms); + // TODO Support delayed privatization. cp.processTODO( - currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand, - dependTypeOperands.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - dependTypeOperands), - dependOperands, nowaitAttr, devicePtrOperands, deviceAddrOperands, - mapOperands); + currentLocation, clauseOps); - genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes, - mapSymLocs, mapSymbols, currentLocation); + genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms, + mapLocs, mapTypes, currentLocation); return targetOp; } @@ -1426,17 +1360,16 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpClauseList &clauseList, bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; - mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand; - llvm::SmallVector allocateOperands, allocatorOperands, - reductionVars; - llvm::SmallVector reductionDeclSymbols; + mlir::omp::TeamsClauseOps clauseOps; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_teams, ifClauseOperand); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); + cp.processAllocate(clauseOps); cp.processDefault(); - cp.processNumTeams(stmtCtx, numTeamsClauseOperand); - cp.processThreadLimit(stmtCtx, threadLimitClauseOperand); + cp.processNumTeams(stmtCtx, clauseOps); + cp.processThreadLimit(stmtCtx, clauseOps); + // TODO Support delayed privatization. 
+ cp.processTODO(currentLocation, llvm::omp::Directive::OMPD_teams); @@ -1445,30 +1378,20 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, .setGenNested(genNested) .setOuterCombined(outerCombined) .setClauses(&clauseList), - /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand, - threadLimitClauseOperand, allocateOperands, allocatorOperands, - reductionVars, - reductionDeclSymbols.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - reductionDeclSymbols)); + clauseOps); } /// Extract the list of function and variable symbols affected by the given /// 'declare target' directive and return the intended device type for them. -static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( +static void getDeclareTargetInfo( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + mlir::omp::DeclareTargetClauseOps &clauseOps, llvm::SmallVectorImpl &symbolAndClause) { - - // The default capture type - mlir::omp::DeclareTargetDeviceType deviceType = - mlir::omp::DeclareTargetDeviceType::any; const auto &spec = std::get( declareTargetConstruct.t); - if (const auto *objectList{ Fortran::parser::Unwrap(spec.u)}) { ObjectList objects{makeObjects(*objectList, semaCtx)}; @@ -1489,12 +1412,10 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( cp.processTo(symbolAndClause); cp.processEnter(symbolAndClause); cp.processLink(symbolAndClause); - cp.processDeviceType(deviceType); + cp.processDeviceType(clauseOps); cp.processTODO(converter.getCurrentLocation(), llvm::omp::Directive::OMPD_declare_target); } - - return deviceType; } static void collectDeferredDeclareTargets( @@ -1504,9 +1425,10 @@ static void collectDeferredDeclareTargets( const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, llvm::SmallVectorImpl &deferredDeclareTarget) 
{ + mlir::omp::DeclareTargetClauseOps clauseOps; llvm::SmallVector symbolAndClause; - mlir::omp::DeclareTargetDeviceType devType = getDeclareTargetInfo( - converter, semaCtx, eval, declareTargetConstruct, symbolAndClause); + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); // Return the device type only if at least one of the targets for the // directive is a function or subroutine mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); @@ -1516,8 +1438,9 @@ static void collectDeferredDeclareTargets( std::get(symClause))); if (!op) { - deferredDeclareTarget.push_back( - {std::get<0>(symClause), devType, std::get<1>(symClause)}); + deferredDeclareTarget.push_back({std::get<0>(symClause), + clauseOps.deviceType, + std::get<1>(symClause)}); } } } @@ -1529,9 +1452,10 @@ getDeclareTargetFunctionDevice( Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { + mlir::omp::DeclareTargetClauseOps clauseOps; llvm::SmallVector symbolAndClause; - mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( - converter, semaCtx, eval, declareTargetConstruct, symbolAndClause); + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); // Return the device type only if at least one of the targets for the // directive is a function or subroutine @@ -1541,7 +1465,7 @@ getDeclareTargetFunctionDevice( std::get(symClause))); if (mlir::isa_and_nonnull(op)) - return deviceType; + return clauseOps.deviceType; } return std::nullopt; @@ -1571,12 +1495,14 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_barrier: firOpBuilder.create(currentLocation); break; - case llvm::omp::Directive::OMPD_taskwait: - ClauseProcessor(converter, semaCtx, opClauseList) - .processTODO( - currentLocation, llvm::omp::Directive::OMPD_taskwait); - firOpBuilder.create(currentLocation); + 
case llvm::omp::Directive::OMPD_taskwait: { + mlir::omp::TaskwaitClauseOps clauseOps; + ClauseProcessor cp(converter, semaCtx, opClauseList); + cp.processTODO( + currentLocation, llvm::omp::Directive::OMPD_taskwait); + firOpBuilder.create(currentLocation, clauseOps); break; + } case llvm::omp::Directive::OMPD_taskyield: firOpBuilder.create(currentLocation); break; @@ -1711,32 +1637,21 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; - mlir::Value scheduleChunkClauseOperand, ifClauseOperand; - llvm::SmallVector lowerBound, upperBound, step, reductionVars; - llvm::SmallVector alignedVars, nontemporalVars; + mlir::omp::SimdLoopClauseOps clauseOps; llvm::SmallVector iv; - llvm::SmallVector reductionTypes; - llvm::SmallVector reductionDeclSymbols; - mlir::omp::ClauseOrderKindAttr orderClauseOperand; - mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand; ClauseProcessor cp(converter, semaCtx, loopOpClauseList); - cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv); - cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); - cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols); - cp.processIf(llvm::omp::Directive::OMPD_simd, ifClauseOperand); - cp.processSimdlen(simdlenClauseOperand); - cp.processSafelen(safelenClauseOperand); + cp.processCollapse(loc, eval, clauseOps, iv); + cp.processReduction(loc, clauseOps); + cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps); + cp.processSimdlen(clauseOps); + cp.processSafelen(clauseOps); + clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); + // TODO Support delayed privatization. 
+ cp.processTODO(loc, ompDirective); - mlir::TypeRange resultType; - auto simdLoopOp = firOpBuilder.create( - loc, resultType, lowerBound, upperBound, step, alignedVars, - /*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars, - orderClauseOperand, simdlenClauseOperand, safelenClauseOperand, - /*inclusive=*/firOpBuilder.getUnitAttr()); - auto *nestedEval = getCollapsedLoopEval( eval, Fortran::lower::getCollapseValue(loopOpClauseList)); @@ -1744,11 +1659,12 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, return genLoopVars(op, converter, loc, iv); }; - createBodyOfOp( - simdLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&loopOpClauseList) - .setDataSharingProcessor(&dsp) - .setGenRegionEntryCb(ivCallback)); + genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) + .setClauses(&loopOpClauseList) + .setDataSharingProcessor(&dsp) + .setGenRegionEntryCb(ivCallback), + clauseOps); } static void createWsloop(Fortran::lower::AbstractConverter &converter, @@ -1763,77 +1679,50 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter, dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; - mlir::Value scheduleChunkClauseOperand; - llvm::SmallVector lowerBound, upperBound, step, reductionVars; - llvm::SmallVector linearVars, linearStepVars; + mlir::omp::WsloopClauseOps clauseOps; llvm::SmallVector iv; llvm::SmallVector reductionTypes; - llvm::SmallVector reductionDeclSymbols; - llvm::SmallVector reductionSymbols; - mlir::omp::ClauseOrderKindAttr orderClauseOperand; - mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand; - mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand; - mlir::IntegerAttr orderedClauseOperand; - mlir::omp::ScheduleModifierAttr scheduleModClauseOperand; + llvm::SmallVector reductionSyms; ClauseProcessor cp(converter, semaCtx, beginClauseList); - cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv); - 
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); - cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols, - &reductionSymbols); - cp.processTODO(loc, ompDirective); - - if (ReductionProcessor::doReductionByRef(reductionVars)) - byrefOperand = firOpBuilder.getUnitAttr(); - - auto wsLoopOp = firOpBuilder.create( - loc, lowerBound, upperBound, step, linearVars, linearStepVars, - reductionVars, - reductionDeclSymbols.empty() - ? nullptr - : mlir::ArrayAttr::get(firOpBuilder.getContext(), - reductionDeclSymbols), - scheduleValClauseOperand, scheduleChunkClauseOperand, - /*schedule_modifiers=*/nullptr, - /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand, - orderedClauseOperand, orderClauseOperand, - /*inclusive=*/firOpBuilder.getUnitAttr()); - - // Handle attribute based clauses. - if (cp.processOrdered(orderedClauseOperand)) - wsLoopOp.setOrderedValAttr(orderedClauseOperand); - - if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand, - scheduleSimdClauseOperand)) { - wsLoopOp.setScheduleValAttr(scheduleValClauseOperand); - wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand); - wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand); - } + cp.processCollapse(loc, eval, clauseOps, iv); + cp.processSchedule(stmtCtx, clauseOps); + cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processOrdered(clauseOps); + clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); + // TODO Support delayed privatization. + + if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) + clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); + + cp.processTODO(loc, + ompDirective); + // In FORTRAN `nowait` clause occur at the end of `omp do` directive. 
// i.e // !$omp do // <...> // !$omp end do nowait if (endClauseList) { - if (ClauseProcessor(converter, semaCtx, *endClauseList) - .processNowait(nowaitClauseOperand)) - wsLoopOp.setNowaitAttr(nowaitClauseOperand); + ClauseProcessor ecp(converter, semaCtx, *endClauseList); + ecp.processNowait(clauseOps); } auto *nestedEval = getCollapsedLoopEval( eval, Fortran::lower::getCollapseValue(beginClauseList)); auto ivCallback = [&](mlir::Operation *op) { - return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols, + return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms, reductionTypes); }; - createBodyOfOp( - wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&beginClauseList) - .setDataSharingProcessor(&dsp) - .setReductions(&reductionSymbols, &reductionTypes) - .setGenRegionEntryCb(ivCallback)); + genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) + .setClauses(&beginClauseList) + .setDataSharingProcessor(&dsp) + .setReductions(&reductionSyms, &reductionTypes) + .setGenRegionEntryCb(ivCallback), + clauseOps); } static void createSimdWsloop( @@ -1921,10 +1810,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { + mlir::omp::DeclareTargetClauseOps clauseOps; llvm::SmallVector symbolAndClause; mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( - converter, semaCtx, eval, declareTargetConstruct, symbolAndClause); + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); for (const DeclareTargetCapturePair &symClause : symbolAndClause) { mlir::Operation *op = mod.lookupSymbol(converter.mangleName( @@ -1938,7 +1828,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, markDeclareTarget( op, converter, - 
std::get(symClause), deviceType); + std::get(symClause), + clauseOps.deviceType); } } @@ -2072,7 +1963,8 @@ genOMP(Fortran::lower::AbstractConverter &converter, !std::get_if(&clause.u) && !std::get_if(&clause.u) && !std::get_if(&clause.u) && - !std::get_if(&clause.u)) { + !std::get_if(&clause.u) && + !std::get_if(&clause.u)) { TODO(clauseLocation, "OpenMP Block construct clause"); } } @@ -2092,7 +1984,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_ordered: genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true, - currentLocation); + currentLocation, beginClauseList); break; case llvm::omp::Directive::OMPD_parallel: genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true, @@ -2183,7 +2075,6 @@ genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::Location currentLocation = converter.getCurrentLocation(); - mlir::IntegerAttr hintClauseOp; std::string name; const Fortran::parser::OmpCriticalDirective &cd = std::get(criticalConstruct.t); @@ -2192,21 +2083,28 @@ genOMP(Fortran::lower::AbstractConverter &converter, std::get>(cd.t).value().ToString(); } - const auto &clauseList = std::get(cd.t); - ClauseProcessor(converter, semaCtx, clauseList).processHint(hintClauseOp); - mlir::omp::CriticalOp criticalOp = [&]() { if (name.empty()) { return firOpBuilder.create( currentLocation, mlir::FlatSymbolRefAttr()); } + mlir::ModuleOp module = firOpBuilder.getModule(); mlir::OpBuilder modBuilder(module.getBodyRegion()); auto global = module.lookupSymbol(name); - if (!global) - global = modBuilder.create( - currentLocation, - mlir::StringAttr::get(firOpBuilder.getContext(), name), hintClauseOp); + if (!global) { + mlir::omp::CriticalClauseOps clauseOps; + const auto &clauseList = std::get(cd.t); + + ClauseProcessor cp(converter, semaCtx, clauseList); + 
+ cp.processHint(clauseOps); + clauseOps.nameAttr = + mlir::StringAttr::get(firOpBuilder.getContext(), name); + + global = modBuilder.create(currentLocation, + clauseOps); + } + return firOpBuilder.create( currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), global.getSymName())); @@ -2323,8 +2221,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) { mlir::Location currentLocation = converter.getCurrentLocation(); - llvm::SmallVector allocateOperands, allocatorOperands; - mlir::UnitAttr nowaitClauseOperand; + mlir::omp::SectionsClauseOps clauseOps; const auto &beginSectionsDirective = std::get(sectionsConstruct.t); const auto &sectionsClauseList = @@ -2333,8 +2230,9 @@ genOMP(Fortran::lower::AbstractConverter &converter, // Process clauses before optional omp.parallel, so that new variables are // allocated outside of the parallel region ClauseProcessor cp(converter, semaCtx, sectionsClauseList); - cp.processSectionsReduction(currentLocation); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processSectionsReduction(currentLocation, clauseOps); + cp.processAllocate(clauseOps); + // TODO Support delayed privatization. 
llvm::omp::Directive dir = std::get(beginSectionsDirective.t) @@ -2351,16 +2249,14 @@ genOMP(Fortran::lower::AbstractConverter &converter, const auto &endSectionsClauseList = std::get(endSectionsDirective.t); ClauseProcessor(converter, semaCtx, endSectionsClauseList) - .processNowait(nowaitClauseOperand); + .processNowait(clauseOps); } // SECTIONS construct genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(false), - /*reduction_vars=*/mlir::ValueRange(), - /*reductions=*/nullptr, allocateOperands, allocatorOperands, - nowaitClauseOperand); + clauseOps); const auto &sectionBlocks = std::get(sectionsConstruct.t); diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h index 4ce7e47da046b..304a9740d91ed 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h @@ -81,7 +81,7 @@ struct GrainsizeClauseOps { Value grainsizeVar; }; -struct HasDeviceAddrOps { +struct HasDeviceAddrClauseOps { llvm::SmallVector hasDeviceAddrVars; }; struct HintClauseOps { @@ -97,7 +97,7 @@ struct InReductionClauseOps { llvm::SmallVector inReductionDeclSymbols; }; -struct IsDevicePtrOps { +struct IsDevicePtrClauseOps { llvm::SmallVector isDevicePtrVars; }; @@ -234,6 +234,8 @@ using DistributeClauseOps = detail::Clauses; +using LoopNestClauseOps = detail::Clauses; + // TODO `filter` clause. 
using MaskedClauseOps = detail::Clauses<>; @@ -261,8 +263,8 @@ using SingleClauseOps = detail::Clauses; using TargetDataClauseOps = detail::Clauses:$step, UnitAttr:$inclusive); + let builders = [ + OpBuilder<(ins CArg<"const LoopNestClauseOps &">:$clauses)> + ]; + let regions = (region AnyRegion:$region); let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 4ab78bf331f80..35fb174046a3a 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1920,6 +1920,12 @@ void LoopNestOp::print(OpAsmPrinter &p) { p.printRegion(region, /*printEntryBlockArgs=*/false); } +void LoopNestOp::build(OpBuilder &builder, OperationState &state, + const LoopNestClauseOps &clauses) { + LoopNestOp::build(builder, state, clauses.loopLBVar, clauses.loopUBVar, + clauses.loopStepVar, clauses.loopInclusiveAttr); +} + LogicalResult LoopNestOp::verify() { if (getLowerBound().size() != getIVs().size()) return emitOpError() << "number of range arguments and IVs do not match";