diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index 3b22a652d1fc1e..6e150ef4e8e82f 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -19,7 +19,6 @@
 #include <cinttypes>
 
 namespace mlir {
-class Value;
 class Operation;
 class Location;
 namespace omp {
@@ -30,7 +29,6 @@ enum class DeclareTargetCaptureClause : uint32_t;
 
 namespace fir {
 class FirOpBuilder;
-class ConvertOp;
 } // namespace fir
 
 namespace Fortran {
@@ -84,16 +82,6 @@ void genOpenMPSymbolProperties(AbstractConverter &converter,
 int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
 void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
 void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
-void genOpenMPReduction(AbstractConverter &,
-                        Fortran::semantics::SemanticsContext &,
-                        const Fortran::parser::OmpClauseList &clauseList);
-
-mlir::Operation *findReductionChain(mlir::Value, mlir::Value * = nullptr);
-fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value);
-void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
-                     mlir::Value, fir::ConvertOp * = nullptr);
-void removeStoreOp(mlir::Operation *, mlir::Value);
-
 bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
 bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
                                  Fortran::semantics::SemanticsContext &,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 5defffd738b4e8..340921c867246c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -237,6 +237,247 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
   return storeOp;
 }
 
+/// Locate the operation that consumes \p loadVal in a reduction update
+/// `X = X op Y`.
+///
+/// If a user of \p loadVal is a fir.convert, the first user of the converted
+/// value is returned. Otherwise a user is returned only if its result is
+/// written to \p *reductionVal by a fir.store or hlfir.assign; that
+/// store/assign is erased as a side effect. Callers that pass no
+/// \p reductionVal only get the fir.convert match.
+///
+/// \returns the matched operation, or nullptr if nothing matched.
+static mlir::Operation *
+findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal = nullptr) {
+  for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
+    if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
+      if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
+        for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
+          if (mlir::Operation *reductionOp = convertOperand.getOwner())
+            return reductionOp;
+        }
+      }
+      // Without a reduction variable there is no write-back to match, and
+      // dereferencing the null pointer below would be undefined behaviour.
+      if (!reductionVal)
+        continue;
+      for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
+        if (auto store =
+                mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
+          if (store.getMemref() == *reductionVal) {
+            store.erase();
+            return reductionOp;
+          }
+        }
+        if (auto assign =
+                mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
+          if (assign.getLhs() == *reductionVal) {
+            assign.erase();
+            return reductionOp;
+          }
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
+/// For a logical reduction `X = X op Y`, return the fir.convert that feeds `Y`
+/// into \p reductionOp (converting it from fir.logical<4> to i1), i.e. the
+/// first operand-defining fir.convert whose input is not \p loadVal.
+///
+/// \returns a null fir::ConvertOp if \p reductionOp has no such operand.
+static fir::ConvertOp getConvertFromReductionOp(mlir::Operation *reductionOp,
+                                                mlir::Value loadVal) {
+  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
+    // getDefiningOp<OpTy>() yields a null handle for block arguments, whereas
+    // mlir::dyn_cast asserts when handed their null defining op.
+    if (auto convertOp = reductionOperand.getDefiningOp<fir::ConvertOp>()) {
+      if (convertOp.getOperand() == loadVal)
+        continue;
+      return convertOp;
+    }
+  }
+  return nullptr;
+}
+
+/// Create an omp.reduction right before \p op that accumulates the
+/// non-reduction operand of the update into \p reductionVal.
+///
+/// The accumulated value is the input of \p *convertOp when \p convertOp is
+/// given; otherwise it is whichever of the first two operands of \p op is not
+/// \p loadVal. The builder's insertion point is restored on return.
+static void updateReduction(mlir::Operation *op,
+                            fir::FirOpBuilder &firOpBuilder,
+                            mlir::Value loadVal, mlir::Value reductionVal,
+                            fir::ConvertOp *convertOp = nullptr) {
+  mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
+  firOpBuilder.setInsertionPoint(op);
+
+  mlir::Value reductionOp;
+  if (convertOp)
+    reductionOp = convertOp->getOperand();
+  else if (op->getOperand(0) == loadVal)
+    reductionOp = op->getOperand(1);
+  else
+    reductionOp = op->getOperand(0);
+
+  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
+                                              reductionVal);
+  firOpBuilder.restoreInsertionPoint(insertPtDel);
+}
+
+/// Erase every fir.store / hlfir.assign that writes the fir.convert'ed result
+/// of \p reductionOp to \p symVal.
+static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
+  for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
+    auto convertReduction = mlir::dyn_cast<fir::ConvertOp>(reductionOpUse);
+    if (!convertReduction)
+      continue;
+    // Step the iterator past the current user before that user can be erased;
+    // erasing it first would leave the iterator pointing at freed memory.
+    auto convertUsers = convertReduction.getRes().getUsers();
+    for (auto it = convertUsers.begin(), end = convertUsers.end();
+         it != end;) {
+      mlir::Operation *convertReductionUse = *it++;
+      if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
+        if (storeOp.getMemref() == symVal)
+          storeOp.erase();
+      } else if (auto assignOp =
+                     mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
+        if (assignOp.getLhs() == symVal)
+          assignOp.erase();
+      }
+    }
+  }
+}
+
+// Generate an OpenMP reduction operation.
+// TODO: Currently assumes it is either an integer addition/multiplication
+// reduction, or a logical and reduction. Generalize this for various reduction
+// operation types.
+// TODO: Generate the reduction operation during lowering instead of creating
+// and removing operations since this is not a robust approach. Also, removing
+// ops in the builder (instead of a rewriter) is probably not the best approach.
+static void
+genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
+                   Fortran::semantics::SemanticsContext &semaCtx,
+                   const Fortran::parser::OmpClauseList &clauseList) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  List<Clause> clauses{makeClauses(clauseList, semaCtx)};
+
+  for (const Clause &clause : clauses) {
+    if (const auto &reductionClause =
+            std::get_if<clause::Reduction>(&clause.u)) {
+      const auto &redOperatorList{
+          std::get<clause::Reduction::ReductionIdentifiers>(
+              reductionClause->t)};
+      assert(redOperatorList.size() == 1 && "Expecting single operator");
+      const auto &redOperator = redOperatorList.front();
+      const auto &objects{std::get<ObjectList>(reductionClause->t)};
+      if (const auto *reductionOp =
+              std::get_if<clause::DefinedOperator>(&redOperator.u)) {
+        const auto &intrinsicOp{
+            std::get<clause::DefinedOperator::IntrinsicOperator>(
+                reductionOp->u)};
+
+        switch (intrinsicOp) {
+        case clause::DefinedOperator::IntrinsicOperator::Add:
+        case clause::DefinedOperator::IntrinsicOperator::Multiply:
+        case clause::DefinedOperator::IntrinsicOperator::AND:
+        case clause::DefinedOperator::IntrinsicOperator::EQV:
+        case clause::DefinedOperator::IntrinsicOperator::OR:
+        case clause::DefinedOperator::IntrinsicOperator::NEQV:
+          break;
+        default:
+          continue;
+        }
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            mlir::Type reductionType =
+                reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+            if (!reductionType.isa<fir::LogicalType>()) {
+              if (!reductionType.isIntOrIndexOrFloat())
+                continue;
+            }
+            for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+              if (auto loadOp =
+                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                if (reductionType.isa<fir::LogicalType>()) {
+                  mlir::Operation *reductionOp = findReductionChain(loadVal);
+                  // Skip loads that are not part of a recognizable
+                  // `X = X op Y` update instead of dereferencing null.
+                  if (!reductionOp)
+                    continue;
+                  fir::ConvertOp convertOp =
+                      getConvertFromReductionOp(reductionOp, loadVal);
+                  if (!convertOp)
+                    continue;
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal, &convertOp);
+                  removeStoreOp(reductionOp, reductionVal);
+                } else if (mlir::Operation *reductionOp =
+                               findReductionChain(loadVal, &reductionVal)) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+              }
+            }
+          }
+        }
+      } else if (const auto *reductionIntrinsic =
+                     std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
+        if (!ReductionProcessor::supportedIntrinsicProcReduction(
+                *reductionIntrinsic))
+          continue;
+        ReductionProcessor::ReductionIdentifier redId =
+            ReductionProcessor::getReductionType(*reductionIntrinsic);
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            for (const mlir::OpOperand &reductionValUse :
+                 reductionVal.getUses()) {
+              if (auto loadOp =
+                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                // Max is lowered as a compare -> select.
+                // Match the pattern here.
+                mlir::Operation *reductionOp =
+                    findReductionChain(loadVal, &reductionVal);
+                if (reductionOp == nullptr)
+                  continue;
+
+                if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
+                    redId == ReductionProcessor::ReductionIdentifier::MIN) {
+                  assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+                         "Selection Op not found in reduction intrinsic");
+                  mlir::Operation *compareOp =
+                      getCompareFromReductionOp(reductionOp, loadVal);
+                  updateReduction(compareOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+                if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IEOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IAND) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
 struct OpWithBodyGenInfo {
   /// A type for a code-gen callback function. This takes as argument the op for
   /// which the code is being generated and returns the arguments of the op's
@@ -2339,216 +2580,6 @@ void Fortran::lower::genDeclareTargetIntGlobal(
   }
 }
 
-// Generate an OpenMP reduction operation.
-// TODO: Currently assumes it is either an integer addition/multiplication
-// reduction, or a logical and reduction. Generalize this for various reduction
-// operation types.
-// TODO: Generate the reduction operation during lowering instead of creating
-// and removing operations since this is not a robust approach. Also, removing
-// ops in the builder (instead of a rewriter) is probably not the best approach.
-void Fortran::lower::genOpenMPReduction(
-    Fortran::lower::AbstractConverter &converter,
-    Fortran::semantics::SemanticsContext &semaCtx,
-    const Fortran::parser::OmpClauseList &clauseList) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  List<Clause> clauses{makeClauses(clauseList, semaCtx)};
-
-  for (const Clause &clause : clauses) {
-    if (const auto &reductionClause =
-            std::get_if<clause::Reduction>(&clause.u)) {
-      const auto &redOperatorList{
-          std::get<clause::Reduction::ReductionIdentifiers>(
-              reductionClause->t)};
-      assert(redOperatorList.size() == 1 && "Expecting single operator");
-      const auto &redOperator = redOperatorList.front();
-      const auto &objects{std::get<ObjectList>(reductionClause->t)};
-      if (const auto *reductionOp =
-              std::get_if<clause::DefinedOperator>(&redOperator.u)) {
-        const auto &intrinsicOp{
-            std::get<clause::DefinedOperator::IntrinsicOperator>(
-                reductionOp->u)};
-
-        switch (intrinsicOp) {
-        case clause::DefinedOperator::IntrinsicOperator::Add:
-        case clause::DefinedOperator::IntrinsicOperator::Multiply:
-        case clause::DefinedOperator::IntrinsicOperator::AND:
-        case clause::DefinedOperator::IntrinsicOperator::EQV:
-        case clause::DefinedOperator::IntrinsicOperator::OR:
-        case clause::DefinedOperator::IntrinsicOperator::NEQV:
-          break;
-        default:
-          continue;
-        }
-        for (const Object &object : objects) {
-          if (const Fortran::semantics::Symbol *symbol = object.id()) {
-            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-              reductionVal = declOp.getBase();
-            mlir::Type reductionType =
-                reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
-            if (!reductionType.isa<fir::LogicalType>()) {
-              if (!reductionType.isIntOrIndexOrFloat())
-                continue;
-            }
-            for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
-              if (auto loadOp =
-                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
-                mlir::Value loadVal = loadOp.getRes();
-                if (reductionType.isa<fir::LogicalType>()) {
-                  mlir::Operation *reductionOp = findReductionChain(loadVal);
-                  fir::ConvertOp convertOp =
-                      getConvertFromReductionOp(reductionOp, loadVal);
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal, &convertOp);
-                  removeStoreOp(reductionOp, reductionVal);
-                } else if (mlir::Operation *reductionOp =
-                               findReductionChain(loadVal, &reductionVal)) {
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-              }
-            }
-          }
-        }
-      } else if (const auto *reductionIntrinsic =
-                     std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
-        if (!ReductionProcessor::supportedIntrinsicProcReduction(
-                *reductionIntrinsic))
-          continue;
-        ReductionProcessor::ReductionIdentifier redId =
-            ReductionProcessor::getReductionType(*reductionIntrinsic);
-        for (const Object &object : objects) {
-          if (const Fortran::semantics::Symbol *symbol = object.id()) {
-            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-              reductionVal = declOp.getBase();
-            for (const mlir::OpOperand &reductionValUse :
-                 reductionVal.getUses()) {
-              if (auto loadOp =
-                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
-                mlir::Value loadVal = loadOp.getRes();
-                // Max is lowered as a compare -> select.
-                // Match the pattern here.
-                mlir::Operation *reductionOp =
-                    findReductionChain(loadVal, &reductionVal);
-                if (reductionOp == nullptr)
-                  continue;
-
-                if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
-                    redId == ReductionProcessor::ReductionIdentifier::MIN) {
-                  assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
-                         "Selection Op not found in reduction intrinsic");
-                  mlir::Operation *compareOp =
-                      getCompareFromReductionOp(reductionOp, loadVal);
-                  updateReduction(compareOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-                if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
-                    redId == ReductionProcessor::ReductionIdentifier::IEOR ||
-                    redId == ReductionProcessor::ReductionIdentifier::IAND) {
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-}
-
-mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal,
-                                                    mlir::Value *reductionVal) {
-  for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
-    if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
-      if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
-        for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
-          if (mlir::Operation *reductionOp = convertOperand.getOwner())
-            return reductionOp;
-        }
-      }
-      for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
-        if (auto store =
-                mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
-          if (store.getMemref() == *reductionVal) {
-            store.erase();
-            return reductionOp;
-          }
-        }
-        if (auto assign =
-                mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
-          if (assign.getLhs() == *reductionVal) {
-            assign.erase();
-            return reductionOp;
-          }
-        }
-      }
-    }
-  }
-  return nullptr;
-}
-
-// for a logical operator 'op' reduction X = X op Y
-// This function returns the operation responsible for converting Y from
-// fir.logical<4> to i1
-fir::ConvertOp
-Fortran::lower::getConvertFromReductionOp(mlir::Operation *reductionOp,
-                                          mlir::Value loadVal) {
-  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
-    if (auto convertOp =
-            mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
-      if (convertOp.getOperand() == loadVal)
-        continue;
-      return convertOp;
-    }
-  }
-  return nullptr;
-}
-
-void Fortran::lower::updateReduction(mlir::Operation *op,
-                                     fir::FirOpBuilder &firOpBuilder,
-                                     mlir::Value loadVal,
-                                     mlir::Value reductionVal,
-                                     fir::ConvertOp *convertOp) {
-  mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
-  firOpBuilder.setInsertionPoint(op);
-
-  mlir::Value reductionOp;
-  if (convertOp)
-    reductionOp = convertOp->getOperand();
-  else if (op->getOperand(0) == loadVal)
-    reductionOp = op->getOperand(1);
-  else
-    reductionOp = op->getOperand(0);
-
-  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
-                                              reductionVal);
-  firOpBuilder.restoreInsertionPoint(insertPtDel);
-}
-
-void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
-                                   mlir::Value symVal) {
-  for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
-    if (auto convertReduction =
-            mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
-      for (mlir::Operation *convertReductionUse :
-           convertReduction.getRes().getUsers()) {
-        if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
-          if (storeOp.getMemref() == symVal)
-            storeOp.erase();
-        }
-        if (auto assignOp =
-                mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
-          if (assignOp.getLhs() == symVal)
-            assignOp.erase();
-        }
-      }
-    }
-  }
-}
-
 bool Fortran::lower::isOpenMPTargetConstruct(
     const Fortran::parser::OpenMPConstruct &omp) {
   llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;