Skip to content

Commit

Permalink
[flang][OpenMP] Decompose compound constructs, do recursive lowering (#…
Browse files Browse the repository at this point in the history
…90098)

A compound construct with a list of clauses is broken up into individual
leaf/composite constructs. Each such construct has the list of clauses
that apply to it based on the OpenMP spec.

Each lowering function (i.e. a function that generates MLIR ops) is now
responsible for generating its body as described below.

Functions that receive AST nodes extract the construct, and the clauses
from the node. They then create a work queue consisting of individual
constructs, and invoke a common dispatch function to process (lower) the
queue.

The dispatch function examines the current position in the queue, and
invokes the appropriate lowering function. Each lowering function
receives the queue as well, and once it needs to generate its body, it
either invokes the dispatch function on the rest of the queue (if any),
or processes nested evaluations if the work queue is at the end.
  • Loading branch information
kparzysz committed May 13, 2024
1 parent f059058 commit ca1bd59
Show file tree
Hide file tree
Showing 16 changed files with 3,226 additions and 452 deletions.
1 change: 1 addition & 0 deletions flang/lib/Lower/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_flang_library(FortranLower
OpenMP/ClauseProcessor.cpp
OpenMP/Clauses.cpp
OpenMP/DataSharingProcessor.cpp
OpenMP/Decomposer.cpp
OpenMP/OpenMP.cpp
OpenMP/ReductionProcessor.cpp
OpenMP/Utils.cpp
Expand Down
23 changes: 23 additions & 0 deletions flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1227,4 +1227,27 @@ List<Clause> makeClauses(const parser::OmpClauseList &clauses,
return makeClause(s, semaCtx);
});
}

bool transferLocations(const List<Clause> &from, List<Clause> &to) {
bool allDone = true;

for (Clause &clause : to) {
if (!clause.source.empty())
continue;
auto found =
llvm::find_if(from, [&](const Clause &c) { return c.id == clause.id; });
// This is not completely accurate, but should be good enough for now.
// It can be improved in the future if necessary, but in cases of
// synthesized clauses getting accurate location may be impossible.
if (found != from.end()) {
clause.source = found->source;
} else {
// Found a clause that won't have "source".
allDone = false;
}
}

return allDone;
}

} // namespace Fortran::lower::omp
13 changes: 11 additions & 2 deletions flang/lib/Lower/OpenMP/Clauses.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@

namespace Fortran::lower::omp {
using namespace Fortran;
using SomeType = evaluate::SomeType;
using SomeExpr = semantics::SomeExpr;
using MaybeExpr = semantics::MaybeExpr;

using TypeTy = SomeType;
// evaluate::SomeType doesn't provide == operation. It's not really used in
// flang's clauses so far, so a trivial implementation is sufficient.
struct TypeTy : public evaluate::SomeType {
bool operator==(const TypeTy &t) const { return true; }
};

using IdTy = semantics::Symbol *;
using ExprTy = SomeExpr;

Expand Down Expand Up @@ -222,6 +226,8 @@ using When = tomp::clause::WhenT<TypeTy, IdTy, ExprTy>;
using Write = tomp::clause::WriteT<TypeTy, IdTy, ExprTy>;
} // namespace clause

using tomp::type::operator==;

struct CancellationConstructType {
using EmptyTrait = std::true_type;
};
Expand All @@ -244,6 +250,7 @@ using ClauseBase = tomp::ClauseT<TypeTy, IdTy, ExprTy,
MemoryOrder, Threadprivate>;

struct Clause : public ClauseBase {
// "source" will be ignored by tomp::type::operator==.
parser::CharBlock source;
};

Expand All @@ -258,6 +265,8 @@ Clause makeClause(const Fortran::parser::OmpClause &cls,

List<Clause> makeClauses(const parser::OmpClauseList &clauses,
semantics::SemanticsContext &semaCtx);

bool transferLocations(const List<Clause> &from, List<Clause> &to);
} // namespace Fortran::lower::omp

#endif // FORTRAN_LOWER_OPENMP_CLAUSES_H
126 changes: 126 additions & 0 deletions flang/lib/Lower/OpenMP/Decomposer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
//===-- Decomposer.cpp -- Compound directive decomposition ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "Decomposer.h"

#include "Clauses.h"
#include "Utils.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Semantics/semantics.h"
#include "flang/Tools/CrossToolHelpers.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/OpenMP/ClauseT.h"
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>
#include <utility>
#include <variant>

using namespace Fortran;

namespace {
using namespace Fortran::lower::omp;

struct ConstructDecomposition {
ConstructDecomposition(mlir::ModuleOp modOp,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &ev,
llvm::omp::Directive compound,
const List<Clause> &clauses)
: semaCtx(semaCtx), mod(modOp), eval(ev) {
tomp::ConstructDecompositionT decompose(getOpenMPVersionAttribute(modOp),
*this, compound,
llvm::ArrayRef(clauses));
output = std::move(decompose.output);
}

// Given an object, return its base object if one exists.
std::optional<Object> getBaseObject(const Object &object) {
return lower::omp::getBaseObject(object, semaCtx);
}

// Return the iteration variable of the associated loop if any.
std::optional<Object> getLoopIterVar() {
if (semantics::Symbol *symbol = getIterationVariableSymbol(eval))
return Object{symbol, /*designator=*/{}};
return std::nullopt;
}

semantics::SemanticsContext &semaCtx;
mlir::ModuleOp mod;
lower::pft::Evaluation &eval;
List<UnitConstruct> output;
};
} // namespace

static UnitConstruct mergeConstructs(uint32_t version,
llvm::ArrayRef<UnitConstruct> units) {
tomp::ConstructCompositionT compose(version, units);
return compose.merged;
}

namespace Fortran::lower::omp {
LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const UnitConstruct &uc) {
os << llvm::omp::getOpenMPDirectiveName(uc.id);
for (auto [index, clause] : llvm::enumerate(uc.clauses)) {
os << (index == 0 ? '\t' : ' ');
os << llvm::omp::getOpenMPClauseName(clause.id);
}
return os;
}

ConstructQueue buildConstructQueue(
mlir::ModuleOp modOp, Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, const parser::CharBlock &source,
llvm::omp::Directive compound, const List<Clause> &clauses) {

List<UnitConstruct> constructs;

ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
assert(!decompose.output.empty() && "Construct decomposition failed");

llvm::SmallVector<llvm::omp::Directive> loweringUnits;
std::ignore =
llvm::omp::getLeafOrCompositeConstructs(compound, loweringUnits);
uint32_t version = getOpenMPVersionAttribute(modOp);

int leafIndex = 0;
for (llvm::omp::Directive dir_id : loweringUnits) {
llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
llvm::omp::getLeafConstructsOrSelf(dir_id);
size_t numLeafs = leafsOrSelf.size();

llvm::ArrayRef<UnitConstruct> toMerge{&decompose.output[leafIndex],
numLeafs};
auto &uc = constructs.emplace_back(mergeConstructs(version, toMerge));

if (!transferLocations(clauses, uc.clauses)) {
// If some clauses are left without source information, use the
// directive's source.
for (auto &clause : uc.clauses) {
if (clause.source.empty())
clause.source = source;
}
}
leafIndex += numLeafs;
}

return constructs;
}
} // namespace Fortran::lower::omp
51 changes: 51 additions & 0 deletions flang/lib/Lower/OpenMP/Decomposer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===-- Decomposer.h -- Compound directive decomposition ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef FORTRAN_LOWER_OPENMP_DECOMPOSER_H
#define FORTRAN_LOWER_OPENMP_DECOMPOSER_H

#include "Clauses.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/Compiler.h"

namespace llvm {
class raw_ostream;
}

namespace Fortran {
namespace semantics {
class SemanticsContext;
}
namespace lower::pft {
struct Evaluation;
}
} // namespace Fortran

namespace Fortran::lower::omp {
using UnitConstruct = tomp::DirectiveWithClauses<lower::omp::Clause>;
using ConstructQueue = List<UnitConstruct>;

LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const UnitConstruct &uc);

// Given a potentially compound construct with a list of clauses that
// apply to it, break it up into individual sub-constructs each with
// the subset of applicable clauses (plus implicit clauses, if any).
// From that create a work queue where each work item corresponds to
// the sub-construct with its clauses.
ConstructQueue buildConstructQueue(mlir::ModuleOp modOp,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
const parser::CharBlock &source,
llvm::omp::Directive compound,
const List<Clause> &clauses);
} // namespace Fortran::lower::omp

#endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H
Loading

0 comments on commit ca1bd59

Please sign in to comment.