Skip to content

Commit

Permalink
[flang][OpenMP] simplify getReductionName (#85666)
Browse files Browse the repository at this point in the history
Re-use fir::getTypeAsString instead of creating something new here. This
spells integer names like i32 instead of i_32 so there is a lot of test
churn.
  • Loading branch information
tblah committed Mar 20, 2024
1 parent 576d81b commit 3deaa77
Show file tree
Hide file tree
Showing 38 changed files with 163 additions and 213 deletions.
75 changes: 22 additions & 53 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,10 @@ bool ReductionProcessor::supportedIntrinsicProcReduction(
return redType;
}

std::string ReductionProcessor::getReductionName(llvm::StringRef name,
mlir::Type ty, bool isByRef) {
std::string
ReductionProcessor::getReductionName(llvm::StringRef name,
const fir::KindMapping &kindMap,
mlir::Type ty, bool isByRef) {
ty = fir::unwrapRefType(ty);

// extra string to distinguish reduction functions for variables passed by
Expand All @@ -91,47 +93,12 @@ std::string ReductionProcessor::getReductionName(llvm::StringRef name,
if (isByRef)
byrefAddition = "_byref";

if (fir::isa_trivial(ty))
return (llvm::Twine(name) +
(ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
.str();

// creates a name like reduction_i_64_box_ux4x3
if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
// TODO: support for allocatable boxes:
// !fir.box<!fir.heap<!fir.array<...>>>
fir::SequenceType seqTy = fir::unwrapRefType(boxTy.getEleTy())
.dyn_cast_or_null<fir::SequenceType>();
if (!seqTy)
return {};

std::string prefix = getReductionName(
name, fir::unwrapSeqOrBoxedSeqType(ty), /*isByRef=*/false);
if (prefix.empty())
return {};
std::stringstream tyStr;
tyStr << prefix << "_box_";
bool first = true;
for (std::int64_t extent : seqTy.getShape()) {
if (first)
first = false;
else
tyStr << "x";
if (extent == seqTy.getUnknownExtent())
tyStr << 'u'; // I'm not sure that '?' is safe in symbol names
else
tyStr << extent;
}
return (tyStr.str() + byrefAddition).str();
}

return {};
return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
}

std::string ReductionProcessor::getReductionName(
omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, mlir::Type ty,
bool isByRef) {
omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
std::string reductionName;

switch (intrinsicOp) {
Expand All @@ -154,17 +121,17 @@ std::string ReductionProcessor::getReductionName(
break;
}

return getReductionName(reductionName, ty, isByRef);
return getReductionName(reductionName, kindMap, ty, isByRef);
}

mlir::Value
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
ReductionIdentifier redId,
fir::FirOpBuilder &builder) {
type = fir::unwrapRefType(type);
assert((fir::isa_integer(type) || fir::isa_real(type) ||
type.isa<fir::LogicalType>()) &&
"only integer, logical and real types are currently supported");
if (!fir::isa_integer(type) && !fir::isa_real(type) &&
!mlir::isa<fir::LogicalType>(type))
TODO(loc, "Reduction of some types is not supported");
switch (redId) {
case ReductionIdentifier::MAX: {
if (auto ty = type.dyn_cast<mlir::FloatType>()) {
Expand Down Expand Up @@ -463,8 +430,7 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
mlir::OpBuilder::InsertionGuard guard(builder);
mlir::ModuleOp module = builder.getModule();

if (reductionOpName.empty())
TODO(loc, "Reduction of some types is not supported");
assert(!reductionOpName.empty());

auto decl =
module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
Expand Down Expand Up @@ -601,15 +567,18 @@ void ReductionProcessor::addDeclareReduction(

for (mlir::Value symVal : reductionVars) {
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
const auto &kindMap = firOpBuilder.getKindMap();
if (redType.getEleTy().isa<fir::LogicalType>())
decl = createDeclareReduction(
firOpBuilder,
getReductionName(intrinsicOp, firOpBuilder.getI1Type(), isByRef),
redId, redType, currentLocation, isByRef);
decl = createDeclareReduction(firOpBuilder,
getReductionName(intrinsicOp, kindMap,
firOpBuilder.getI1Type(),
isByRef),
redId, redType, currentLocation, isByRef);
else
decl = createDeclareReduction(
firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
redId, redType, currentLocation, isByRef);
firOpBuilder,
getReductionName(intrinsicOp, kindMap, redType, isByRef), redId,
redType, currentLocation, isByRef);
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
firOpBuilder.getContext(), decl.getSymName()));
}
Expand All @@ -631,7 +600,7 @@ void ReductionProcessor::addDeclareReduction(
decl = createDeclareReduction(
firOpBuilder,
getReductionName(getRealName(*reductionIntrinsic).ToString(),
redType, isByRef),
firOpBuilder.getKindMap(), redType, isByRef),
redId, redType, currentLocation, isByRef);
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
firOpBuilder.getContext(), decl.getSymName()));
Expand Down
8 changes: 5 additions & 3 deletions flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ class ReductionProcessor {
static bool
doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);

static std::string getReductionName(llvm::StringRef name, mlir::Type ty,
bool isByRef);
static std::string getReductionName(llvm::StringRef name,
const fir::KindMapping &kindMap,
mlir::Type ty, bool isByRef);

static std::string
getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
mlir::Type ty, bool isByRef);
const fir::KindMapping &kindMap, mlir::Type ty,
bool isByRef);

/// This function returns the identity value of the operator \p
/// reductionOpName. For example:
Expand Down
22 changes: 11 additions & 11 deletions flang/test/Lower/OpenMP/FIR/wsloop-reduction-add-byref.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
! RUN: %flang_fc1 -emit-fir -flang-deprecated-no-hlfir -fopenmp -mmlir --force-byref-reduction %s -o - | FileCheck %s
! NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

! CHECK-LABEL: omp.declare_reduction @add_reduction_f_64_byref : !fir.ref<f64>
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_f64 : !fir.ref<f64>
! CHECK-SAME: init {
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<f64>):
! CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64
Expand All @@ -19,7 +19,7 @@
! CHECK: omp.yield(%[[ARG0]] : !fir.ref<f64>)
! CHECK: }

! CHECK-LABEL: omp.declare_reduction @add_reduction_i_64_byref : !fir.ref<i64>
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_i64 : !fir.ref<i64>
! CHECK-SAME: init {
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<i64>):
! CHECK: %[[C0_1:.*]] = arith.constant 0 : i64
Expand All @@ -36,7 +36,7 @@
! CHECK: omp.yield(%[[ARG0]] : !fir.ref<i64>)
! CHECK: }

! CHECK-LABEL: omp.declare_reduction @add_reduction_f_32_byref : !fir.ref<f32>
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_f32 : !fir.ref<f32>
! CHECK-SAME: init {
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<f32>):
! CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f32
Expand All @@ -53,7 +53,7 @@
! CHECK: omp.yield(%[[ARG0]] : !fir.ref<f32>)
! CHECK: }

! CHECK-LABEL: omp.declare_reduction @add_reduction_i_32_byref : !fir.ref<i32>
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_i32 : !fir.ref<i32>
! CHECK-SAME: init {
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<i32>):
! CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
Expand All @@ -80,7 +80,7 @@
! CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_5:.*]] = arith.constant 100 : i32
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
! CHECK: omp.wsloop byref reduction(@add_reduction_i_32_byref %[[VAL_1]] -> %[[VAL_7:.*]] : !fir.ref<i32>) for (%[[VAL_8:.*]]) : i32 = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_6]]) {
! CHECK: omp.wsloop byref reduction(@add_reduction_byref_i32 %[[VAL_1]] -> %[[VAL_7:.*]] : !fir.ref<i32>) for (%[[VAL_8:.*]]) : i32 = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_6]]) {
! CHECK: fir.store %[[VAL_8]] to %[[VAL_3]] : !fir.ref<i32>
! CHECK: %[[VAL_9:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
! CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
Expand Down Expand Up @@ -116,7 +116,7 @@ subroutine simple_int_reduction
! CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_5:.*]] = arith.constant 100 : i32
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
! CHECK: omp.wsloop byref reduction(@add_reduction_f_32_byref %[[VAL_1]] -> %[[VAL_7:.*]] : !fir.ref<f32>) for (%[[VAL_8:.*]]) : i32 = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_6]]) {
! CHECK: omp.wsloop byref reduction(@add_reduction_byref_f32 %[[VAL_1]] -> %[[VAL_7:.*]] : !fir.ref<f32>) for (%[[VAL_8:.*]]) : i32 = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_6]]) {
! CHECK: fir.store %[[VAL_8]] to %[[VAL_3]] : !fir.ref<i32>
! CHECK: %[[VAL_9:.*]] = fir.load %[[VAL_7]] : !fir.ref<f32>
! CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
Expand Down Expand Up @@ -152,7 +152,7 @@ subroutine simple_real_reduction
! CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_5:.*]] = arith.constant 100 : i32
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
! CHECK: omp.wsloop byref reduction(@add_reduction_i_32_byref %[[VAL_1]] -> %[[VAL_7:.*]] : !fir.ref<i32>) for (%[[VAL_8:.*]]) : i32 = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_6]]) {
! CHECK: omp.wsloop byref reduction(@add_reduction_byref_i32 %[[VAL_1]] -> %[[VAL_7:.*]] : !fir.ref<i32>) for (%[[VAL_8:.*]]) : i32 = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_6]]) {
! CHECK: fir.store %[[VAL_8]] to %[[VAL_3]] : !fir.ref<i32>
! CHECK: %[[VAL_9:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
! CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
Expand Down Expand Up @@ -187,7 +187,7 @@ subroutine simple_int_reduction_switch_order
! CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_5:.*]] = arith.constant 100 : i32
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
! CHECK: omp.wsloop byref reduction(@add_reduction_f_32_byref %[[VAL_1]] -> %[[VAL_7:.*]] : !fir.ref<f32>) for (%[[VAL_8:.*]]) : i32 = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_6]]) {
! CHECK: omp.wsloop byref reduction(@add_reduction_byref_f32 %[[VAL_1]] -> %[[VAL_7:.*]] : !fir.ref<f32>) for (%[[VAL_8:.*]]) : i32 = (%[[VAL_4]]) to (%[[VAL_5]]) inclusive step (%[[VAL_6]]) {
! CHECK: fir.store %[[VAL_8]] to %[[VAL_3]] : !fir.ref<i32>
! CHECK: %[[VAL_9:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
! CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (i32) -> f32
Expand Down Expand Up @@ -229,7 +229,7 @@ subroutine simple_real_reduction_switch_order
! CHECK: %[[VAL_8:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_9:.*]] = arith.constant 100 : i32
! CHECK: %[[VAL_10:.*]] = arith.constant 1 : i32
! CHECK: omp.wsloop byref reduction(@add_reduction_i_32_byref %[[VAL_1]] -> %[[VAL_11:.*]] : !fir.ref<i32>, @add_reduction_i_32_byref %[[VAL_2]] -> %[[VAL_12:.*]] : !fir.ref<i32>, @add_reduction_i_32_byref %[[VAL_3]] -> %[[VAL_13:.*]] : !fir.ref<i32>) for (%[[VAL_14:.*]]) : i32 = (%[[VAL_8]]) to (%[[VAL_9]]) inclusive step (%[[VAL_10]]) {
! CHECK: omp.wsloop byref reduction(@add_reduction_byref_i32 %[[VAL_1]] -> %[[VAL_11:.*]] : !fir.ref<i32>, @add_reduction_byref_i32 %[[VAL_2]] -> %[[VAL_12:.*]] : !fir.ref<i32>, @add_reduction_byref_i32 %[[VAL_3]] -> %[[VAL_13:.*]] : !fir.ref<i32>) for (%[[VAL_14:.*]]) : i32 = (%[[VAL_8]]) to (%[[VAL_9]]) inclusive step (%[[VAL_10]]) {
! CHECK: fir.store %[[VAL_14]] to %[[VAL_7]] : !fir.ref<i32>
! CHECK: %[[VAL_15:.*]] = fir.load %[[VAL_11]] : !fir.ref<i32>
! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
Expand Down Expand Up @@ -282,7 +282,7 @@ subroutine multiple_int_reductions_same_type
! CHECK: %[[VAL_8:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_9:.*]] = arith.constant 100 : i32
! CHECK: %[[VAL_10:.*]] = arith.constant 1 : i32
! CHECK: omp.wsloop byref reduction(@add_reduction_f_32_byref %[[VAL_1]] -> %[[VAL_11:.*]] : !fir.ref<f32>, @add_reduction_f_32_byref %[[VAL_2]] -> %[[VAL_12:.*]] : !fir.ref<f32>, @add_reduction_f_32_byref %[[VAL_3]] -> %[[VAL_13:.*]] : !fir.ref<f32>) for (%[[VAL_14:.*]]) : i32 = (%[[VAL_8]]) to (%[[VAL_9]]) inclusive step (%[[VAL_10]]) {
! CHECK: omp.wsloop byref reduction(@add_reduction_byref_f32 %[[VAL_1]] -> %[[VAL_11:.*]] : !fir.ref<f32>, @add_reduction_byref_f32 %[[VAL_2]] -> %[[VAL_12:.*]] : !fir.ref<f32>, @add_reduction_byref_f32 %[[VAL_3]] -> %[[VAL_13:.*]] : !fir.ref<f32>) for (%[[VAL_14:.*]]) : i32 = (%[[VAL_8]]) to (%[[VAL_9]]) inclusive step (%[[VAL_10]]) {
! CHECK: fir.store %[[VAL_14]] to %[[VAL_7]] : !fir.ref<i32>
! CHECK: %[[VAL_15:.*]] = fir.load %[[VAL_11]] : !fir.ref<f32>
! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
Expand Down Expand Up @@ -341,7 +341,7 @@ subroutine multiple_real_reductions_same_type
! CHECK: %[[VAL_10:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_11:.*]] = arith.constant 100 : i32
! CHECK: %[[VAL_12:.*]] = arith.constant 1 : i32
! CHECK: omp.wsloop byref reduction(@add_reduction_i_32_byref %[[VAL_2]] -> %[[VAL_13:.*]] : !fir.ref<i32>, @add_reduction_i_64_byref %[[VAL_3]] -> %[[VAL_14:.*]] : !fir.ref<i64>, @add_reduction_f_32_byref %[[VAL_4]] -> %[[VAL_15:.*]] : !fir.ref<f32>, @add_reduction_f_64_byref %[[VAL_1]] -> %[[VAL_16:.*]] : !fir.ref<f64>) for (%[[VAL_17:.*]]) : i32 = (%[[VAL_10]]) to (%[[VAL_11]]) inclusive step (%[[VAL_12]]) {
! CHECK: omp.wsloop byref reduction(@add_reduction_byref_i32 %[[VAL_2]] -> %[[VAL_13:.*]] : !fir.ref<i32>, @add_reduction_byref_i64 %[[VAL_3]] -> %[[VAL_14:.*]] : !fir.ref<i64>, @add_reduction_byref_f32 %[[VAL_4]] -> %[[VAL_15:.*]] : !fir.ref<f32>, @add_reduction_byref_f64 %[[VAL_1]] -> %[[VAL_16:.*]] : !fir.ref<f64>) for (%[[VAL_17:.*]]) : i32 = (%[[VAL_10]]) to (%[[VAL_11]]) inclusive step (%[[VAL_12]]) {
! CHECK: fir.store %[[VAL_17]] to %[[VAL_9]] : !fir.ref<i32>
! CHECK: %[[VAL_18:.*]] = fir.load %[[VAL_13]] : !fir.ref<i32>
! CHECK: %[[VAL_19:.*]] = fir.load %[[VAL_9]] : !fir.ref<i32>
Expand Down

0 comments on commit 3deaa77

Please sign in to comment.