diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h index 8e671c1d71bc4..fa0528bea5114 100644 --- a/flang/include/flang/Semantics/openmp-utils.h +++ b/flang/include/flang/Semantics/openmp-utils.h @@ -112,10 +112,6 @@ MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp); bool IsLoopTransforming(llvm::omp::Directive dir); bool IsFullUnroll(const parser::OpenMPLoopConstruct &x); -std::optional GetNumGeneratedNestsFrom( - const parser::ExecutionPartConstruct &epc, - std::optional nestedCount); - struct LoopSequence { LoopSequence( const parser::ExecutionPartConstruct &root, bool allowAllLoops = false); diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp index f1c53f302caf7..8b0214531ee25 100644 --- a/flang/lib/Semantics/openmp-utils.cpp +++ b/flang/lib/Semantics/openmp-utils.cpp @@ -720,14 +720,58 @@ bool IsTransformableLoop(const parser::ExecutionPartConstruct &epc) { return false; } -std::optional GetNumGeneratedNestsFrom( - const parser::ExecutionPartConstruct &epc, - std::optional nestedCount) { - if (parser::Unwrap(epc)) { +LoopSequence::LoopSequence( + const parser::ExecutionPartConstruct &root, bool allowAllLoops) + : allowAllLoops_(allowAllLoops) { + entry_ = createConstructEntry(root); + assert(entry_ && "Expecting loop like code"); + + createChildrenFromRange(entry_->location); + length_ = calculateLength(); +} + +LoopSequence::LoopSequence(std::unique_ptr entry, bool allowAllLoops) + : allowAllLoops_(allowAllLoops), entry_(std::move(entry)) { + createChildrenFromRange(entry_->location); + length_ = calculateLength(); +} + +std::unique_ptr LoopSequence::createConstructEntry( + const parser::ExecutionPartConstruct &code) { + if (auto *loop{parser::Unwrap(code)}) { + if (allowAllLoops_ || IsTransformableLoop(*loop)) { + auto &body{std::get(loop->t)}; + return std::make_unique(body, &code); + } + } else if (auto *omp{parser::Unwrap(code)}) { + if (IsTransformableLoop(*omp)) { + auto &body{std::get(omp->t)}; + return std::make_unique(body, &code); + } + } + + return nullptr; +} + +void LoopSequence::createChildrenFromRange( + ExecutionPartIterator::IteratorType begin, + ExecutionPartIterator::IteratorType end) { + for (auto &code : BlockRange(begin, end, BlockRange::Step::Over)) { + if (auto entry{createConstructEntry(code)}) { + children_.push_back(LoopSequence(std::move(entry), allowAllLoops_)); + } + } +} + +std::optional LoopSequence::calculateLength() const { + if (!entry_->owner) { + return sumOfChildrenLengths(); + } + if (parser::Unwrap(entry_->owner)) { return 1; } - auto &omp{DEREF(parser::Unwrap(epc))}; + auto &omp{DEREF(parser::Unwrap(*entry_->owner))}; const parser::OmpDirectiveSpecification &beginSpec{omp.BeginDir()}; llvm::omp::Directive dir{beginSpec.DirId()}; if (!IsLoopTransforming(dir)) { @@ -739,6 +783,8 @@ std::optional GetNumGeneratedNestsFrom( return std::nullopt; } + auto nestedCount{sumOfChildrenLengths()}; + if (dir == llvm::omp::Directive::OMPD_fuse) { // If there are no loops nested inside of FUSE, then the construct is // invalid. This case will be diagnosed when analyzing the body of the FUSE @@ -778,59 +824,6 @@ std::optional GetNumGeneratedNestsFrom( return 1; } -LoopSequence::LoopSequence( - const parser::ExecutionPartConstruct &root, bool allowAllLoops) - : allowAllLoops_(allowAllLoops) { - entry_ = createConstructEntry(root); - assert(entry_ && "Expecting loop like code"); - - createChildrenFromRange(entry_->location); - length_ = calculateLength(); -} - -LoopSequence::LoopSequence(std::unique_ptr entry, bool allowAllLoops) - : allowAllLoops_(allowAllLoops), entry_(std::move(entry)) { - createChildrenFromRange(entry_->location); - length_ = calculateLength(); -} - -std::unique_ptr LoopSequence::createConstructEntry( - const parser::ExecutionPartConstruct &code) { - if (auto *loop{parser::Unwrap(code)}) { - if (allowAllLoops_ || IsTransformableLoop(*loop)) { - auto &body{std::get(loop->t)}; - return std::make_unique(body, &code); - } - } else if (auto *omp{parser::Unwrap(code)}) { - if (IsTransformableLoop(*omp)) { - auto &body{std::get(omp->t)}; - return std::make_unique(body, &code); - } - } - - return nullptr; -} - -void LoopSequence::createChildrenFromRange( - ExecutionPartIterator::IteratorType begin, - ExecutionPartIterator::IteratorType end) { - for (auto &code : BlockRange(begin, end, BlockRange::Step::Over)) { - if (auto entry{createConstructEntry(code)}) { - children_.push_back(LoopSequence(std::move(entry), allowAllLoops_)); - } - } -} - -std::optional LoopSequence::calculateLength() const { - if (!entry_->owner) { - return sumOfChildrenLengths(); - } - if (parser::Unwrap(entry_->owner)) { - return 1; - } - return GetNumGeneratedNestsFrom(*entry_->owner, sumOfChildrenLengths()); -} - std::optional LoopSequence::sumOfChildrenLengths() const { int64_t sum{0}; for (auto &seq : children_) {