Skip to content
18 changes: 18 additions & 0 deletions flang/include/flang/Lower/Support/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ namespace omp {

class ReductionProcessor {
public:
using GenInitValueCBTy =
std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value ompOrig)>;
using GenCombinerCBTy = std::function<void(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
mlir::Value op1, mlir::Value op2, bool isByRef)>;

// TODO: Move this enumeration to the OpenMP dialect
enum ReductionIdentifier {
ID,
Expand All @@ -58,6 +65,9 @@ class ReductionProcessor {
IEOR
};

static bool doReductionByRef(mlir::Type reductionType);
static bool doReductionByRef(mlir::Value reductionVar);

static ReductionIdentifier
getReductionType(const omp::clause::ProcedureDesignator &pd);

Expand Down Expand Up @@ -109,6 +119,14 @@ class ReductionProcessor {
ReductionIdentifier redId,
mlir::Type type, mlir::Value op1,
mlir::Value op2);
/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The init and combiner regions are generated by the callback
/// functions genCombinerCB and genInitValueCB.
template <typename DeclareRedType>
static DeclareRedType createDeclareReductionHelper(
AbstractConverter &converter, llvm::StringRef reductionOpName,
mlir::Type type, mlir::Location loc, bool isByRef,
GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);

/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The declaration has a constant initializer with the neutral
Expand Down
60 changes: 60 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "ClauseProcessor.h"
#include "Utils.h"

#include "flang/Lower/ConvertCall.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/PFTBuilder.h"
Expand Down Expand Up @@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive(
return false;
}

bool ClauseProcessor::processInitializer(
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
if (auto *clause = findUniqueClause<omp::clause::Initializer>()) {
genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value ompOrig) {
lower::SymMapScope scope(symMap);
const parser::OmpInitializerExpression &iexpr = inp.v.v;
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
const std::list<parser::OmpStylizedDeclaration> &declList =
std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
mlir::Value ompPrivVar;
for (const parser::OmpStylizedDeclaration &decl : declList) {
auto &name = std::get<parser::ObjectName>(decl.var.t);
assert(name.symbol && "Name does not have a symbol");
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
fir::StoreOp::create(builder, loc, ompOrig, addr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the guarantee that ompOrig is a primitive type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't have to be primitive. It can be a derived type, as long as all the data can be contained in one unit. I would like to restrict the current PR to these simpler types. I can add a check that the reduction type to not include boxed types/references and put a TODO if that acceptable?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay yeah that works for me. Thanks!

fir::FortranVariableFlagsEnum extraFlags = {};
fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes(builder.getContext(),
*name.symbol, extraFlags);
auto declareOp = hlfir::DeclareOp::create(
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
0, attributes);
if (name.ToString() == "omp_priv")
ompPrivVar = declareOp.getResult(0);
symMap.addVariableDefinition(*name.symbol, declareOp);
}
// Lower the expression/function call
lower::StatementContext stmtCtx;
mlir::Value result = common::visit(
common::visitors{
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
convertCallToHLFIR(loc, converter, procRef, std::nullopt,
symMap, stmtCtx);
auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar);
return privVal;
},
[&](const auto &expr) -> mlir::Value {
mlir::Value exprResult = fir::getBase(convertExprToValue(
loc, converter, clause->v, symMap, stmtCtx));
// Conversion can either give a value or a refrence to a value,
// we need to return the reduction type, so an optional load may
// be generated.
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
exprResult.getType()))
if (ompPrivVar.getType() == refType)
exprResult = fir::LoadOp::create(builder, loc, exprResult);
return exprResult;
}},
clause->v.u);
stmtCtx.finalizeAndPop();
return result;
};
return true;
}
return false;
}

bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
Expand Down
4 changes: 4 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "flang/Lower/Bridge.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
Expand Down Expand Up @@ -88,6 +89,9 @@ class ClauseProcessor {
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processInclusive(mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const;
bool processInitializer(
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNogroup(mlir::omp::NogroupClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
Expand Down
17 changes: 16 additions & 1 deletion flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp,

Initializer make(const parser::OmpClause::Initializer &inp,
semantics::SemanticsContext &semaCtx) {
llvm_unreachable("Empty: initializer");
const parser::OmpInitializerExpression &iexpr = inp.v.v;
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
const parser::OmpStylizedInstance::Instance &instance =
std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
auto &expr = std::get<parser::Expr>(as->t);
return Initializer{makeExpr(expr, semaCtx)};
} else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
if (call->typedCall) {
const auto &procRef = *call->typedCall;
semantics::SomeExpr evalProcRef{procRef};
return Initializer{evalProcRef};
}
}

llvm_unreachable("Unexpected initializer");
}

InReduction make(const parser::OmpClause::InReduction &inp,
Expand Down
152 changes: 149 additions & 3 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
#include "Decomposer.h"
#include "Utils.h"
#include "flang/Common/idioms.h"
#include "flang/Evaluate/type.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
Expand Down Expand Up @@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// TODO: Add private syms and vars.
args.reduction.syms = reductionSyms;
args.reduction.vars = clauseOps.reductionVars;

return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_teams)
Expand Down Expand Up @@ -3563,12 +3565,156 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
}

static ReductionProcessor::GenCombinerCBTy
processReductionCombiner(lower::AbstractConverter &converter,
lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
const parser::OmpReductionSpecifier &specifier) {
ReductionProcessor::GenCombinerCBTy genCombinerCB;
const auto &combinerExpression =
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
.value();
const parser::OmpStylizedInstance &combinerInstance =
combinerExpression.v.front();
const parser::OmpStylizedInstance::Instance &instance =
std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);

const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
if (!as) {
TODO(converter.getCurrentLocation(),
"A combiner that is a subroutine call is not yet supported");
}
auto &expr = std::get<parser::Expr>(as->t);
genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value lhs, mlir::Value rhs,
bool isByRef) {
const auto &evalExpr = makeExpr(expr, semaCtx);
lower::SymMapScope scope(symTable);
const std::list<parser::OmpStylizedDeclaration> &declList =
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
for (const parser::OmpStylizedDeclaration &decl : declList) {
auto &name = std::get<parser::ObjectName>(decl.var.t);
mlir::Value addr = lhs;
mlir::Type type = lhs.getType();
bool isRhs = name.ToString() == std::string("omp_in");
if (isRhs) {
addr = rhs;
type = rhs.getType();
}

assert(name.symbol && "Reduction object name does not have a symbol");
if (!fir::conformsWithPassByRef(type)) {
addr = builder.createTemporary(loc, type);
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
}
fir::FortranVariableFlagsEnum extraFlags = {};
fir::FortranVariableFlagsAttr attributes =
Fortran::lower::translateSymbolAttributes(builder.getContext(),
*name.symbol, extraFlags);
auto declareOp =
hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
{}, nullptr, nullptr, 0, attributes);
symTable.addVariableDefinition(*name.symbol, declareOp);
}

lower::StatementContext stmtCtx;
mlir::Value result = fir::getBase(
convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
if (lhs.getType() == refType.getElementType())
result = fir::LoadOp::create(builder, loc, result);
stmtCtx.finalizeAndPop();
if (isByRef) {
fir::StoreOp::create(builder, loc, result, lhs);
mlir::omp::YieldOp::create(builder, loc, lhs);
} else {
mlir::omp::YieldOp::create(builder, loc, result);
}
};
return genCombinerCB;
}

// Checks that the reduction type is either a trivial type or a derived type of
// trivial types.
static bool isSimpleReductionType(mlir::Type reductionType) {
if (fir::isa_trivial(reductionType))
return true;
if (auto recordTy = mlir::dyn_cast<fir::RecordType>(reductionType)) {
for (auto [_, fieldType] : recordTy.getTypeList()) {
if (!fir::isa_trivial(fieldType))
return false;
}
}
return true;
}

// Getting the type from a symbol compared to a DeclSpec is simpler since we do
// not need to consider derived vs intrinsic types. Semantics is guaranteed to
// generate these symbols.
static mlir::Type
getReductionType(lower::AbstractConverter &converter,
const parser::OmpReductionSpecifier &specifier) {
const auto &combinerExpression =
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
.value();
const parser::OmpStylizedInstance &combinerInstance =
combinerExpression.v.front();
const std::list<parser::OmpStylizedDeclaration> &declList =
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
const parser::OmpStylizedDeclaration &decl = declList.front();
const auto &name = std::get<parser::ObjectName>(decl.var.t);
const auto &symbol = semantics::SymbolRef(*name.symbol);
mlir::Type reductionType = converter.genType(symbol);

if (!isSimpleReductionType(reductionType))
TODO(converter.getCurrentLocation(),
"declare reduction currently only supports trival types or derived "
"types containing trivial types");
return reductionType;
}

static void genOMP(
lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
if (!semaCtx.langOptions().OpenMPSimd)
TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
if (semaCtx.langOptions().OpenMPSimd)
return;

const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
const parser::OmpArgument &arg{args.v.front()};
const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);

if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
TODO(converter.getCurrentLocation(),
"multiple types in declare reduction is not yet supported");

mlir::Type reductionType = getReductionType(converter, specifier);
ReductionProcessor::GenCombinerCBTy genCombinerCB =
processReductionCombiner(converter, symTable, semaCtx, specifier);
const parser::OmpClauseList &initializer =
declareReductionConstruct.v.Clauses();
if (initializer.v.size() > 0) {
List<Clause> clauses = makeClauses(initializer, semaCtx);
ReductionProcessor::GenInitValueCBTy genInitValueCB;
ClauseProcessor cp(converter, semaCtx, clauses);
const parser::OmpClause::Initializer &iclause{
std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
cp.processInitializer(symTable, iclause, genInitValueCB);
const auto &identifier =
std::get<parser::OmpReductionIdentifier>(specifier.t);
const auto &designator =
std::get<parser::ProcedureDesignator>(identifier.u);
const auto &reductionName = std::get<parser::Name>(designator.u);
bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
ReductionProcessor::createDeclareReductionHelper<
mlir::omp::DeclareReductionOp>(
converter, reductionName.ToString(), reductionType,
converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
} else {
TODO(converter.getCurrentLocation(),
"declare reduction without an initializer clause is not yet "
"supported");
}
}

static void
Expand Down
Loading