-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[flang][OpenMP] Implement COMBINER clause #172036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This adds parsing and lowering of the COMBINER clause. It utilizes the existing lowering code for combiner-expression to lower the COMBINER clause as well.
|
@llvm/pr-subscribers-flang-parser @llvm/pr-subscribers-flang-semantics Author: Krzysztof Parzyszek (kparzysz) ChangesThis adds parsing and lowering of the COMBINER clause. It utilizes the existing lowering code for combiner-expression to lower the COMBINER clause as well. Patch is 29.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172036.diff 13 Files Affected:
diff --git a/flang/include/flang/Lower/OpenMP/Clauses.h b/flang/include/flang/Lower/OpenMP/Clauses.h
index 455eda2738e6d..5f03877624be7 100644
--- a/flang/include/flang/Lower/OpenMP/Clauses.h
+++ b/flang/include/flang/Lower/OpenMP/Clauses.h
@@ -104,6 +104,7 @@ struct hash<Fortran::lower::omp::IdTy> {
namespace Fortran::lower::omp {
using Object = tomp::ObjectT<IdTy, ExprTy>;
using ObjectList = tomp::ObjectListT<IdTy, ExprTy>;
+using StylizedInstance = tomp::type::StylizedInstanceT<IdTy, ExprTy>;
Object makeObject(const parser::OmpObject &object,
semantics::SemanticsContext &semaCtx);
@@ -173,8 +174,10 @@ std::optional<ResultTy> maybeApplyToV(FuncTy &&func, const ArgTy *arg) {
std::optional<Object> getBaseObject(const Object &object,
semantics::SemanticsContext &semaCtx);
+StylizedInstance makeStylizedInstance(const parser::OmpStylizedInstance &inp,
+ semantics::SemanticsContext &semaCtx);
+
namespace clause {
-using StylizedInstance = tomp::type::StylizedInstanceT<IdTy, ExprTy>;
using Range = tomp::type::RangeT<ExprTy>;
using Mapper = tomp::type::MapperT<IdTy, ExprTy>;
using Iterator = tomp::type::IteratorT<TypeTy, IdTy, ExprTy>;
@@ -208,6 +211,7 @@ using Bind = tomp::clause::BindT<TypeTy, IdTy, ExprTy>;
using Capture = tomp::clause::CaptureT<TypeTy, IdTy, ExprTy>;
using Collapse = tomp::clause::CollapseT<TypeTy, IdTy, ExprTy>;
using Collector = tomp::clause::CollectorT<TypeTy, IdTy, ExprTy>;
+using Combiner = tomp::clause::CombinerT<TypeTy, IdTy, ExprTy>;
using Compare = tomp::clause::CompareT<TypeTy, IdTy, ExprTy>;
using Contains = tomp::clause::ContainsT<TypeTy, IdTy, ExprTy>;
using Copyin = tomp::clause::CopyinT<TypeTy, IdTy, ExprTy>;
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 252e156d2d459..dbc2b6541dd75 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -562,6 +562,7 @@ class ParseTreeDumper {
NODE(parser, OmpClauseList)
NODE(parser, OmpCloseModifier)
NODE_ENUM(OmpCloseModifier, Value)
+ NODE(parser, OmpCombinerClause)
NODE(parser, OmpCombinerExpression)
NODE(parser, OmpContainsClause)
NODE(parser, OmpContextSelectorSpecification)
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 0fc7dbd29d6aa..bd200558e4c59 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -226,9 +226,9 @@ const BlockConstruct *GetFortranBlockConstruct(
const Block &GetInnermostExecPart(const Block &block);
bool IsStrictlyStructuredBlock(const Block &block);
-const OmpCombinerExpression *GetCombinerExpr(
- const OmpReductionSpecifier &rspec);
-const OmpInitializerExpression *GetInitializerExpr(const OmpClause &init);
+const OmpCombinerExpression *GetCombinerExpr(const OmpReductionSpecifier &x);
+const OmpCombinerExpression *GetCombinerExpr(const OmpClause &x);
+const OmpInitializerExpression *GetInitializerExpr(const OmpClause &x);
struct OmpAllocateInfo {
std::vector<const OmpAllocateDirective *> dirs;
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 93743709f10d2..b00d25373f801 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4395,6 +4395,14 @@ struct OmpCancellationConstructTypeClause {
std::tuple<OmpDirectiveName, std::optional<ScalarLogicalExpr>> t;
};
+// Ref: [6.0:262]
+//
+// combiner-clause -> // since 6.0
+// COMBINER(combiner-expr)
+struct OmpCombinerClause {
+ WRAPPER_CLASS_BOILERPLATE(OmpCombinerClause, OmpCombinerExpression);
+};
+
// Ref: [5.2:214]
//
// contains-clause ->
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 3c31b3a07f57f..b923e415231d6 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -390,10 +390,10 @@ bool ClauseProcessor::processInitializer(
mlir::Type type, mlir::Value ompOrig) {
lower::SymMapScope scope(symMap);
mlir::Value ompPrivVar;
- const clause::StylizedInstance &inst = clause->v.front();
+ const StylizedInstance &inst = clause->v.front();
for (const Object &object :
- std::get<clause::StylizedInstance::Variables>(inst.t)) {
+ std::get<StylizedInstance::Variables>(inst.t)) {
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
fir::StoreOp::create(builder, loc, ompOrig, addr);
fir::FortranVariableFlagsEnum extraFlags = {};
@@ -412,7 +412,7 @@ bool ClauseProcessor::processInitializer(
// Lower the expression/function call
lower::StatementContext stmtCtx;
const semantics::SomeExpr &initExpr =
- std::get<clause::StylizedInstance::Instance>(inst.t);
+ std::get<StylizedInstance::Instance>(inst.t);
mlir::Value result = common::visit(
common::visitors{
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
@@ -439,7 +439,9 @@ bool ClauseProcessor::processInitializer(
};
return true;
}
- return false;
+ TODO(converter.getCurrentLocation(),
+ "declare reduction without an initializer clause is not yet "
+ "supported");
}
bool ClauseProcessor::processMergeable(
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index 9ea4e8fcd6c0e..d53054f005dea 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -197,6 +197,24 @@ std::optional<Object> getBaseObject(const Object &object,
return std::nullopt;
}
+StylizedInstance makeStylizedInstance(const parser::OmpStylizedInstance &inp,
+ semantics::SemanticsContext &semaCtx) {
+ ObjectList variables;
+ llvm::transform(std::get<std::list<parser::OmpStylizedDeclaration>>(inp.t),
+ std::back_inserter(variables),
+ [&](const parser::OmpStylizedDeclaration &s) {
+ return makeObject(s.var, semaCtx);
+ });
+
+ SomeExpr instance = [&]() {
+ if (auto &&expr = semantics::omp::MakeEvaluateExpr(inp))
+ return std::move(*expr);
+ llvm_unreachable("Expecting expression instance");
+ }();
+
+ return StylizedInstance{{std::move(variables), std::move(instance)}};
+}
+
// Helper macros
#define MAKE_EMPTY_CLASS(cls, from_cls) \
cls make(const parser::OmpClause::from_cls &, \
@@ -551,6 +569,17 @@ Collapse make(const parser::OmpClause::Collapse &inp,
return Collapse{/*N=*/makeExpr(inp.v, semaCtx)};
}
+Combiner make(const parser::OmpClause::Combiner &inp,
+ semantics::SemanticsContext &semaCtx) {
+ const parser::OmpCombinerExpression &cexpr = inp.v.v;
+ Combiner combiner;
+
+ for (const parser::OmpStylizedInstance &sinst : cexpr.v)
+ combiner.v.push_back(makeStylizedInstance(sinst, semaCtx));
+
+ return combiner;
+}
+
// Compare: empty
Contains make(const parser::OmpClause::Contains &inp,
@@ -988,24 +1017,8 @@ Initializer make(const parser::OmpClause::Initializer &inp,
const parser::OmpInitializerExpression &iexpr = inp.v.v;
Initializer initializer;
- for (const parser::OmpStylizedInstance &sinst : iexpr.v) {
- ObjectList variables;
- llvm::transform(
- std::get<std::list<parser::OmpStylizedDeclaration>>(sinst.t),
- std::back_inserter(variables),
- [&](const parser::OmpStylizedDeclaration &s) {
- return makeObject(s.var, semaCtx);
- });
-
- SomeExpr instance = [&]() {
- if (auto &&expr = semantics::omp::MakeEvaluateExpr(sinst))
- return std::move(*expr);
- llvm_unreachable("Expecting expression instance");
- }();
-
- initializer.v.push_back(
- StylizedInstance{{std::move(variables), std::move(instance)}});
- }
+ for (const parser::OmpStylizedInstance &sinst : iexpr.v)
+ initializer.v.push_back(makeStylizedInstance(sinst, semaCtx));
return initializer;
}
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 38ab42076f559..7965119764e5d 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3602,57 +3602,28 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
}
-static ReductionProcessor::GenCombinerCBTy
-processReductionCombiner(lower::AbstractConverter &converter,
- lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx,
- const parser::OmpReductionSpecifier &specifier) {
+static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
+ lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, const clause::Combiner &combiner) {
ReductionProcessor::GenCombinerCBTy genCombinerCB;
- const auto &combinerExpression =
- std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
- .value();
- const parser::OmpStylizedInstance &combinerInstance =
- combinerExpression.v.front();
- const parser::OmpStylizedInstance::Instance &instance =
- std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
-
- std::optional<semantics::SomeExpr> evalExprOpt;
- if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
- auto &expr = std::get<parser::Expr>(as->t);
- evalExprOpt = makeExpr(expr, semaCtx);
- } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
- if (call->typedCall) {
- const auto &procRef = *call->typedCall;
- evalExprOpt = semantics::SomeExpr{procRef};
- } else {
- TODO(converter.getCurrentLocation(),
- "CallStmt without typedCall is not yet supported");
- }
- } else {
- TODO(converter.getCurrentLocation(), "Unsupported combiner instance type");
- }
-
- assert(evalExprOpt.has_value() && "evalExpr must be initialized");
- semantics::SomeExpr evalExpr = *evalExprOpt;
+ const StylizedInstance &inst = combiner.v.front();
+ semantics::SomeExpr evalExpr = std::get<StylizedInstance::Instance>(inst.t);
genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value lhs,
mlir::Value rhs, bool isByRef) {
lower::SymMapScope scope(symTable);
- const std::list<parser::OmpStylizedDeclaration> &declList =
- std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
mlir::Value ompOutVar;
- for (const parser::OmpStylizedDeclaration &decl : declList) {
- auto &name = std::get<parser::ObjectName>(decl.var.t);
+ for (const Object &object : std::get<StylizedInstance::Variables>(inst.t)) {
mlir::Value addr = lhs;
mlir::Type type = lhs.getType();
- bool isRhs = name.ToString() == std::string("omp_in");
+ std::string name = object.sym()->name().ToString();
+ bool isRhs = name == "omp_in";
if (isRhs) {
addr = rhs;
type = rhs.getType();
}
- assert(name.symbol && "Reduction object name does not have a symbol");
if (!fir::conformsWithPassByRef(type)) {
addr = builder.createTemporary(loc, type);
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
@@ -3660,13 +3631,13 @@ processReductionCombiner(lower::AbstractConverter &converter,
fir::FortranVariableFlagsEnum extraFlags = {};
fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes(builder.getContext(),
- *name.symbol, extraFlags);
+ *object.sym(), extraFlags);
auto declareOp =
- hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
- {}, nullptr, nullptr, 0, attributes);
- if (name.ToString() == "omp_out")
+ hlfir::DeclareOp::create(builder, loc, addr, name, nullptr, {},
+ nullptr, nullptr, 0, attributes);
+ if (name == "omp_out")
ompOutVar = declareOp.getResult(0);
- symTable.addVariableDefinition(*name.symbol, declareOp);
+ symTable.addVariableDefinition(*object.sym(), declareOp);
}
lower::StatementContext stmtCtx;
@@ -3740,46 +3711,69 @@ getReductionType(lower::AbstractConverter &converter,
return reductionType;
}
-static void genOMP(
- lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
- const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
+// Represent the reduction combiner as a clause, return reference to it.
+// If there is a "combiner" clause already present, do nothing. Otherwise
+// manufacture a combiner clause from the combiner expression on the reduction
+// specifier and append it to the list of clauses.
+static const clause::Combiner &
+appendCombiner(const parser::OpenMPDeclareReductionConstruct &construct,
+ List<Clause> &clauses, semantics::SemanticsContext &semaCtx) {
+ for (const Clause &clause : clauses) {
+ if (clause.id == llvm::omp::Clause::OMPC_combiner)
+ return std::get<clause::Combiner>(clause.u);
+ }
+
+ using namespace parser::omp;
+ const parser::OmpDirectiveSpecification &dirSpec = construct.v;
+ auto *specifier = GetFirstArgument<parser::OmpReductionSpecifier>(dirSpec);
+ assert(specifier && "Expecting reduction specifier");
+ if (auto *expr = GetCombinerExpr(*specifier)) {
+ clause::Combiner combiner;
+ for (const parser::OmpStylizedInstance &sinst : expr->v)
+ combiner.v.push_back(makeStylizedInstance(sinst, semaCtx));
+ clauses.push_back(makeClause(llvm::omp::Clause::OMPC_combiner,
+ std::move(combiner), expr->source));
+ return std::get<clause::Combiner>(clauses.back().u);
+ }
+
+ llvm_unreachable("Expecting reduction combiner");
+}
+
+static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval,
+ const parser::OpenMPDeclareReductionConstruct &construct) {
if (semaCtx.langOptions().OpenMPSimd)
return;
- const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
- const parser::OmpArgument &arg{args.v.front()};
- const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
-
+ const auto &specifier =
+ DEREF(parser::omp::GetFirstArgument<parser::OmpReductionSpecifier>(
+ construct.v));
if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
TODO(converter.getCurrentLocation(),
"multiple types in declare reduction is not yet supported");
mlir::Type reductionType = getReductionType(converter, specifier);
+ List<Clause> clauses = makeClauses(construct.v.Clauses(), semaCtx);
+ const clause::Combiner &combiner =
+ appendCombiner(construct, clauses, semaCtx);
+
ReductionProcessor::GenCombinerCBTy genCombinerCB =
- processReductionCombiner(converter, symTable, semaCtx, specifier);
- const parser::OmpClauseList &initializer =
- declareReductionConstruct.v.Clauses();
- if (initializer.v.size() > 0) {
- List<Clause> clauses = makeClauses(initializer, semaCtx);
- ReductionProcessor::GenInitValueCBTy genInitValueCB;
- ClauseProcessor cp(converter, semaCtx, clauses);
- cp.processInitializer(symTable, genInitValueCB);
- const auto &identifier =
- std::get<parser::OmpReductionIdentifier>(specifier.t);
- const auto &designator =
- std::get<parser::ProcedureDesignator>(identifier.u);
- const auto &reductionName = std::get<parser::Name>(designator.u);
- bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
- ReductionProcessor::createDeclareReductionHelper<
- mlir::omp::DeclareReductionOp>(
- converter, reductionName.ToString(), reductionType,
- converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
- } else {
- TODO(converter.getCurrentLocation(),
- "declare reduction without an initializer clause is not yet "
- "supported");
- }
+ processReductionCombiner(converter, symTable, semaCtx, combiner);
+
+ ReductionProcessor::GenInitValueCBTy genInitValueCB;
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processInitializer(symTable, genInitValueCB);
+
+ const auto &identifier =
+ std::get<parser::OmpReductionIdentifier>(specifier.t);
+ const auto &designator = std::get<parser::ProcedureDesignator>(identifier.u);
+ const auto &reductionName = std::get<parser::Name>(designator.u);
+ bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+ ReductionProcessor::createDeclareReductionHelper<
+ mlir::omp::DeclareReductionOp>(
+ converter, reductionName.ToString(), reductionType,
+ converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
}
static void
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 24bdef9f88ed4..1f0d1be9adc00 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -488,27 +488,30 @@ static void InstantiateDeclareReduction(OmpDirectiveSpecification &spec) {
return;
}
- const OmpTypeNameList *typeNames{nullptr};
+ const OmpTypeNameList &typeNames{std::get<OmpTypeNameList>(rspec->t)};
if (auto *cexpr{
const_cast<OmpCombinerExpression *>(GetCombinerExpr(*rspec))}) {
- typeNames = &std::get<OmpTypeNameList>(rspec->t);
-
- InstantiateForTypes(*cexpr, *typeNames, OmpCombinerExpression::Variables());
+ InstantiateForTypes(*cexpr, typeNames, OmpCombinerExpression::Variables());
delete cexpr->state;
cexpr->state = nullptr;
- } else {
- // If there are no types, there is nothing else to do.
- return;
}
for (const OmpClause &clause : spec.Clauses().v) {
llvm::omp::Clause id{clause.Id()};
- if (id == llvm::omp::Clause::OMPC_initializer) {
+ if (id == llvm::omp::Clause::OMPC_combiner) {
+ if (auto *cexpr{
+ const_cast<OmpCombinerExpression *>(GetCombinerExpr(clause))}) {
+ InstantiateForTypes(
+ *cexpr, typeNames, OmpCombinerExpression::Variables());
+ delete cexpr->state;
+ cexpr->state = nullptr;
+ }
+ } else if (id == llvm::omp::Clause::OMPC_initializer) {
if (auto *iexpr{const_cast<OmpInitializerExpression *>(
GetInitializerExpr(clause))}) {
InstantiateForTypes(
- *iexpr, *typeNames, OmpInitializerExpression::Variables());
+ *iexpr, typeNames, OmpInitializerExpression::Variables());
delete iexpr->state;
iexpr->state = nullptr;
}
@@ -1316,6 +1319,8 @@ TYPE_PARSER(construct<OmpDetachClause>(Parser<OmpObject>{}))
TYPE_PARSER(construct<OmpHintClause>(scalarIntConstantExpr))
+TYPE_PARSER(construct<OmpCombinerClause>(Parser<OmpCombinerExpression>{}))
+
// init clause
TYPE_PARSER(construct<OmpInitClause>(
maybe(nonemptyList(Parser<OmpInitClause::Modifier>{}) / ":"),
@@ -1426,6 +1431,8 @@ TYPE_PARSER( //
"CAPTURE" >> construct<OmpClause>(construct<OmpClause::Capture>()) ||
"COLLAPSE" >> construct<OmpClause>(construct<OmpClause::Collapse>(
parenthesized(scalarIntConstantExpr))) ||
+ "COMBINER" >> construct<OmpClause>(construct<OmpClause::Combiner>(
+ parenthesized(Parser<OmpCombinerClause>{}))) ||
"COMPARE" >> construct<OmpClause>(construct<OmpClause::Compare>()) ||
"CONTAINS" >> construct<OmpClause>(construct<OmpClause::Contains>(
parenthesized(Parser<OmpContainsClause>{}))) ||
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index f96a5fca778e1..a9dbb55819b1e 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/open...
[truncated]
|
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Krzysztof Parzyszek (kparzysz) ChangesThis adds parsing and lowering of the COMBINER clause. It utilizes the existing lowering code for combiner-expression to lower the COMBINER clause as well. Patch is 29.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172036.diff 13 Files Affected:
diff --git a/flang/include/flang/Lower/OpenMP/Clauses.h b/flang/include/flang/Lower/OpenMP/Clauses.h
index 455eda2738e6d..5f03877624be7 100644
--- a/flang/include/flang/Lower/OpenMP/Clauses.h
+++ b/flang/include/flang/Lower/OpenMP/Clauses.h
@@ -104,6 +104,7 @@ struct hash<Fortran::lower::omp::IdTy> {
namespace Fortran::lower::omp {
using Object = tomp::ObjectT<IdTy, ExprTy>;
using ObjectList = tomp::ObjectListT<IdTy, ExprTy>;
+using StylizedInstance = tomp::type::StylizedInstanceT<IdTy, ExprTy>;
Object makeObject(const parser::OmpObject &object,
semantics::SemanticsContext &semaCtx);
@@ -173,8 +174,10 @@ std::optional<ResultTy> maybeApplyToV(FuncTy &&func, const ArgTy *arg) {
std::optional<Object> getBaseObject(const Object &object,
semantics::SemanticsContext &semaCtx);
+StylizedInstance makeStylizedInstance(const parser::OmpStylizedInstance &inp,
+ semantics::SemanticsContext &semaCtx);
+
namespace clause {
-using StylizedInstance = tomp::type::StylizedInstanceT<IdTy, ExprTy>;
using Range = tomp::type::RangeT<ExprTy>;
using Mapper = tomp::type::MapperT<IdTy, ExprTy>;
using Iterator = tomp::type::IteratorT<TypeTy, IdTy, ExprTy>;
@@ -208,6 +211,7 @@ using Bind = tomp::clause::BindT<TypeTy, IdTy, ExprTy>;
using Capture = tomp::clause::CaptureT<TypeTy, IdTy, ExprTy>;
using Collapse = tomp::clause::CollapseT<TypeTy, IdTy, ExprTy>;
using Collector = tomp::clause::CollectorT<TypeTy, IdTy, ExprTy>;
+using Combiner = tomp::clause::CombinerT<TypeTy, IdTy, ExprTy>;
using Compare = tomp::clause::CompareT<TypeTy, IdTy, ExprTy>;
using Contains = tomp::clause::ContainsT<TypeTy, IdTy, ExprTy>;
using Copyin = tomp::clause::CopyinT<TypeTy, IdTy, ExprTy>;
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 252e156d2d459..dbc2b6541dd75 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -562,6 +562,7 @@ class ParseTreeDumper {
NODE(parser, OmpClauseList)
NODE(parser, OmpCloseModifier)
NODE_ENUM(OmpCloseModifier, Value)
+ NODE(parser, OmpCombinerClause)
NODE(parser, OmpCombinerExpression)
NODE(parser, OmpContainsClause)
NODE(parser, OmpContextSelectorSpecification)
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 0fc7dbd29d6aa..bd200558e4c59 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -226,9 +226,9 @@ const BlockConstruct *GetFortranBlockConstruct(
const Block &GetInnermostExecPart(const Block &block);
bool IsStrictlyStructuredBlock(const Block &block);
-const OmpCombinerExpression *GetCombinerExpr(
- const OmpReductionSpecifier &rspec);
-const OmpInitializerExpression *GetInitializerExpr(const OmpClause &init);
+const OmpCombinerExpression *GetCombinerExpr(const OmpReductionSpecifier &x);
+const OmpCombinerExpression *GetCombinerExpr(const OmpClause &x);
+const OmpInitializerExpression *GetInitializerExpr(const OmpClause &x);
struct OmpAllocateInfo {
std::vector<const OmpAllocateDirective *> dirs;
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 93743709f10d2..b00d25373f801 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4395,6 +4395,14 @@ struct OmpCancellationConstructTypeClause {
std::tuple<OmpDirectiveName, std::optional<ScalarLogicalExpr>> t;
};
+// Ref: [6.0:262]
+//
+// combiner-clause -> // since 6.0
+// COMBINER(combiner-expr)
+struct OmpCombinerClause {
+ WRAPPER_CLASS_BOILERPLATE(OmpCombinerClause, OmpCombinerExpression);
+};
+
// Ref: [5.2:214]
//
// contains-clause ->
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 3c31b3a07f57f..b923e415231d6 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -390,10 +390,10 @@ bool ClauseProcessor::processInitializer(
mlir::Type type, mlir::Value ompOrig) {
lower::SymMapScope scope(symMap);
mlir::Value ompPrivVar;
- const clause::StylizedInstance &inst = clause->v.front();
+ const StylizedInstance &inst = clause->v.front();
for (const Object &object :
- std::get<clause::StylizedInstance::Variables>(inst.t)) {
+ std::get<StylizedInstance::Variables>(inst.t)) {
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
fir::StoreOp::create(builder, loc, ompOrig, addr);
fir::FortranVariableFlagsEnum extraFlags = {};
@@ -412,7 +412,7 @@ bool ClauseProcessor::processInitializer(
// Lower the expression/function call
lower::StatementContext stmtCtx;
const semantics::SomeExpr &initExpr =
- std::get<clause::StylizedInstance::Instance>(inst.t);
+ std::get<StylizedInstance::Instance>(inst.t);
mlir::Value result = common::visit(
common::visitors{
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
@@ -439,7 +439,9 @@ bool ClauseProcessor::processInitializer(
};
return true;
}
- return false;
+ TODO(converter.getCurrentLocation(),
+ "declare reduction without an initializer clause is not yet "
+ "supported");
}
bool ClauseProcessor::processMergeable(
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index 9ea4e8fcd6c0e..d53054f005dea 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -197,6 +197,24 @@ std::optional<Object> getBaseObject(const Object &object,
return std::nullopt;
}
+StylizedInstance makeStylizedInstance(const parser::OmpStylizedInstance &inp,
+ semantics::SemanticsContext &semaCtx) {
+ ObjectList variables;
+ llvm::transform(std::get<std::list<parser::OmpStylizedDeclaration>>(inp.t),
+ std::back_inserter(variables),
+ [&](const parser::OmpStylizedDeclaration &s) {
+ return makeObject(s.var, semaCtx);
+ });
+
+ SomeExpr instance = [&]() {
+ if (auto &&expr = semantics::omp::MakeEvaluateExpr(inp))
+ return std::move(*expr);
+ llvm_unreachable("Expecting expression instance");
+ }();
+
+ return StylizedInstance{{std::move(variables), std::move(instance)}};
+}
+
// Helper macros
#define MAKE_EMPTY_CLASS(cls, from_cls) \
cls make(const parser::OmpClause::from_cls &, \
@@ -551,6 +569,17 @@ Collapse make(const parser::OmpClause::Collapse &inp,
return Collapse{/*N=*/makeExpr(inp.v, semaCtx)};
}
+Combiner make(const parser::OmpClause::Combiner &inp,
+ semantics::SemanticsContext &semaCtx) {
+ const parser::OmpCombinerExpression &cexpr = inp.v.v;
+ Combiner combiner;
+
+ for (const parser::OmpStylizedInstance &sinst : cexpr.v)
+ combiner.v.push_back(makeStylizedInstance(sinst, semaCtx));
+
+ return combiner;
+}
+
// Compare: empty
Contains make(const parser::OmpClause::Contains &inp,
@@ -988,24 +1017,8 @@ Initializer make(const parser::OmpClause::Initializer &inp,
const parser::OmpInitializerExpression &iexpr = inp.v.v;
Initializer initializer;
- for (const parser::OmpStylizedInstance &sinst : iexpr.v) {
- ObjectList variables;
- llvm::transform(
- std::get<std::list<parser::OmpStylizedDeclaration>>(sinst.t),
- std::back_inserter(variables),
- [&](const parser::OmpStylizedDeclaration &s) {
- return makeObject(s.var, semaCtx);
- });
-
- SomeExpr instance = [&]() {
- if (auto &&expr = semantics::omp::MakeEvaluateExpr(sinst))
- return std::move(*expr);
- llvm_unreachable("Expecting expression instance");
- }();
-
- initializer.v.push_back(
- StylizedInstance{{std::move(variables), std::move(instance)}});
- }
+ for (const parser::OmpStylizedInstance &sinst : iexpr.v)
+ initializer.v.push_back(makeStylizedInstance(sinst, semaCtx));
return initializer;
}
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 38ab42076f559..7965119764e5d 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3602,57 +3602,28 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
}
-static ReductionProcessor::GenCombinerCBTy
-processReductionCombiner(lower::AbstractConverter &converter,
- lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx,
- const parser::OmpReductionSpecifier &specifier) {
+static ReductionProcessor::GenCombinerCBTy processReductionCombiner(
+ lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, const clause::Combiner &combiner) {
ReductionProcessor::GenCombinerCBTy genCombinerCB;
- const auto &combinerExpression =
- std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
- .value();
- const parser::OmpStylizedInstance &combinerInstance =
- combinerExpression.v.front();
- const parser::OmpStylizedInstance::Instance &instance =
- std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
-
- std::optional<semantics::SomeExpr> evalExprOpt;
- if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
- auto &expr = std::get<parser::Expr>(as->t);
- evalExprOpt = makeExpr(expr, semaCtx);
- } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
- if (call->typedCall) {
- const auto &procRef = *call->typedCall;
- evalExprOpt = semantics::SomeExpr{procRef};
- } else {
- TODO(converter.getCurrentLocation(),
- "CallStmt without typedCall is not yet supported");
- }
- } else {
- TODO(converter.getCurrentLocation(), "Unsupported combiner instance type");
- }
-
- assert(evalExprOpt.has_value() && "evalExpr must be initialized");
- semantics::SomeExpr evalExpr = *evalExprOpt;
+ const StylizedInstance &inst = combiner.v.front();
+ semantics::SomeExpr evalExpr = std::get<StylizedInstance::Instance>(inst.t);
genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value lhs,
mlir::Value rhs, bool isByRef) {
lower::SymMapScope scope(symTable);
- const std::list<parser::OmpStylizedDeclaration> &declList =
- std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
mlir::Value ompOutVar;
- for (const parser::OmpStylizedDeclaration &decl : declList) {
- auto &name = std::get<parser::ObjectName>(decl.var.t);
+ for (const Object &object : std::get<StylizedInstance::Variables>(inst.t)) {
mlir::Value addr = lhs;
mlir::Type type = lhs.getType();
- bool isRhs = name.ToString() == std::string("omp_in");
+ std::string name = object.sym()->name().ToString();
+ bool isRhs = name == "omp_in";
if (isRhs) {
addr = rhs;
type = rhs.getType();
}
- assert(name.symbol && "Reduction object name does not have a symbol");
if (!fir::conformsWithPassByRef(type)) {
addr = builder.createTemporary(loc, type);
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
@@ -3660,13 +3631,13 @@ processReductionCombiner(lower::AbstractConverter &converter,
fir::FortranVariableFlagsEnum extraFlags = {};
fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes(builder.getContext(),
- *name.symbol, extraFlags);
+ *object.sym(), extraFlags);
auto declareOp =
- hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
- {}, nullptr, nullptr, 0, attributes);
- if (name.ToString() == "omp_out")
+ hlfir::DeclareOp::create(builder, loc, addr, name, nullptr, {},
+ nullptr, nullptr, 0, attributes);
+ if (name == "omp_out")
ompOutVar = declareOp.getResult(0);
- symTable.addVariableDefinition(*name.symbol, declareOp);
+ symTable.addVariableDefinition(*object.sym(), declareOp);
}
lower::StatementContext stmtCtx;
@@ -3740,46 +3711,69 @@ getReductionType(lower::AbstractConverter &converter,
return reductionType;
}
-static void genOMP(
- lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
- const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
+// Represent the reduction combiner as a clause, return reference to it.
+// If there is a "combiner" clause already present, do nothing. Otherwise
+// manufacture a combiner clause from the combiner expression on the reduction
+// specifier and append it to the list of clauses.
+static const clause::Combiner &
+appendCombiner(const parser::OpenMPDeclareReductionConstruct &construct,
+ List<Clause> &clauses, semantics::SemanticsContext &semaCtx) {
+ for (const Clause &clause : clauses) {
+ if (clause.id == llvm::omp::Clause::OMPC_combiner)
+ return std::get<clause::Combiner>(clause.u);
+ }
+
+ using namespace parser::omp;
+ const parser::OmpDirectiveSpecification &dirSpec = construct.v;
+ auto *specifier = GetFirstArgument<parser::OmpReductionSpecifier>(dirSpec);
+ assert(specifier && "Expecting reduction specifier");
+ if (auto *expr = GetCombinerExpr(*specifier)) {
+ clause::Combiner combiner;
+ for (const parser::OmpStylizedInstance &sinst : expr->v)
+ combiner.v.push_back(makeStylizedInstance(sinst, semaCtx));
+ clauses.push_back(makeClause(llvm::omp::Clause::OMPC_combiner,
+ std::move(combiner), expr->source));
+ return std::get<clause::Combiner>(clauses.back().u);
+ }
+
+ llvm_unreachable("Expecting reduction combiner");
+}
+
+static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval,
+ const parser::OpenMPDeclareReductionConstruct &construct) {
if (semaCtx.langOptions().OpenMPSimd)
return;
- const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
- const parser::OmpArgument &arg{args.v.front()};
- const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
-
+ const auto &specifier =
+ DEREF(parser::omp::GetFirstArgument<parser::OmpReductionSpecifier>(
+ construct.v));
if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
TODO(converter.getCurrentLocation(),
"multiple types in declare reduction is not yet supported");
mlir::Type reductionType = getReductionType(converter, specifier);
+ List<Clause> clauses = makeClauses(construct.v.Clauses(), semaCtx);
+ const clause::Combiner &combiner =
+ appendCombiner(construct, clauses, semaCtx);
+
ReductionProcessor::GenCombinerCBTy genCombinerCB =
- processReductionCombiner(converter, symTable, semaCtx, specifier);
- const parser::OmpClauseList &initializer =
- declareReductionConstruct.v.Clauses();
- if (initializer.v.size() > 0) {
- List<Clause> clauses = makeClauses(initializer, semaCtx);
- ReductionProcessor::GenInitValueCBTy genInitValueCB;
- ClauseProcessor cp(converter, semaCtx, clauses);
- cp.processInitializer(symTable, genInitValueCB);
- const auto &identifier =
- std::get<parser::OmpReductionIdentifier>(specifier.t);
- const auto &designator =
- std::get<parser::ProcedureDesignator>(identifier.u);
- const auto &reductionName = std::get<parser::Name>(designator.u);
- bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
- ReductionProcessor::createDeclareReductionHelper<
- mlir::omp::DeclareReductionOp>(
- converter, reductionName.ToString(), reductionType,
- converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
- } else {
- TODO(converter.getCurrentLocation(),
- "declare reduction without an initializer clause is not yet "
- "supported");
- }
+ processReductionCombiner(converter, symTable, semaCtx, combiner);
+
+ ReductionProcessor::GenInitValueCBTy genInitValueCB;
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processInitializer(symTable, genInitValueCB);
+
+ const auto &identifier =
+ std::get<parser::OmpReductionIdentifier>(specifier.t);
+ const auto &designator = std::get<parser::ProcedureDesignator>(identifier.u);
+ const auto &reductionName = std::get<parser::Name>(designator.u);
+ bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+ ReductionProcessor::createDeclareReductionHelper<
+ mlir::omp::DeclareReductionOp>(
+ converter, reductionName.ToString(), reductionType,
+ converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
}
static void
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index 24bdef9f88ed4..1f0d1be9adc00 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -488,27 +488,30 @@ static void InstantiateDeclareReduction(OmpDirectiveSpecification &spec) {
return;
}
- const OmpTypeNameList *typeNames{nullptr};
+ const OmpTypeNameList &typeNames{std::get<OmpTypeNameList>(rspec->t)};
if (auto *cexpr{
const_cast<OmpCombinerExpression *>(GetCombinerExpr(*rspec))}) {
- typeNames = &std::get<OmpTypeNameList>(rspec->t);
-
- InstantiateForTypes(*cexpr, *typeNames, OmpCombinerExpression::Variables());
+ InstantiateForTypes(*cexpr, typeNames, OmpCombinerExpression::Variables());
delete cexpr->state;
cexpr->state = nullptr;
- } else {
- // If there are no types, there is nothing else to do.
- return;
}
for (const OmpClause &clause : spec.Clauses().v) {
llvm::omp::Clause id{clause.Id()};
- if (id == llvm::omp::Clause::OMPC_initializer) {
+ if (id == llvm::omp::Clause::OMPC_combiner) {
+ if (auto *cexpr{
+ const_cast<OmpCombinerExpression *>(GetCombinerExpr(clause))}) {
+ InstantiateForTypes(
+ *cexpr, typeNames, OmpCombinerExpression::Variables());
+ delete cexpr->state;
+ cexpr->state = nullptr;
+ }
+ } else if (id == llvm::omp::Clause::OMPC_initializer) {
if (auto *iexpr{const_cast<OmpInitializerExpression *>(
GetInitializerExpr(clause))}) {
InstantiateForTypes(
- *iexpr, *typeNames, OmpInitializerExpression::Variables());
+ *iexpr, typeNames, OmpInitializerExpression::Variables());
delete iexpr->state;
iexpr->state = nullptr;
}
@@ -1316,6 +1319,8 @@ TYPE_PARSER(construct<OmpDetachClause>(Parser<OmpObject>{}))
TYPE_PARSER(construct<OmpHintClause>(scalarIntConstantExpr))
+TYPE_PARSER(construct<OmpCombinerClause>(Parser<OmpCombinerExpression>{}))
+
// init clause
TYPE_PARSER(construct<OmpInitClause>(
maybe(nonemptyList(Parser<OmpInitClause::Modifier>{}) / ":"),
@@ -1426,6 +1431,8 @@ TYPE_PARSER( //
"CAPTURE" >> construct<OmpClause>(construct<OmpClause::Capture>()) ||
"COLLAPSE" >> construct<OmpClause>(construct<OmpClause::Collapse>(
parenthesized(scalarIntConstantExpr))) ||
+ "COMBINER" >> construct<OmpClause>(construct<OmpClause::Combiner>(
+ parenthesized(Parser<OmpCombinerClause>{}))) ||
"COMPARE" >> construct<OmpClause>(construct<OmpClause::Compare>()) ||
"CONTAINS" >> construct<OmpClause>(construct<OmpClause::Contains>(
parenthesized(Parser<OmpContainsClause>{}))) ||
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index f96a5fca778e1..a9dbb55819b1e 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/open...
[truncated]
|
This adds parsing and lowering of the COMBINER clause. It utilizes the existing lowering code for combiner-expression to lower the COMBINER clause as well.