Skip to content

Commit

Permalink
[flang][openacc] Use the array section for assumed shape array reduct…
Browse files Browse the repository at this point in the history
…ion (#68147)

Use the bounds information in the reduction recipe for assumed shape
arrays.
  • Loading branch information
clementval committed Oct 3, 2023
1 parent 25a6b89 commit 438a6ef
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 35 deletions.
3 changes: 3 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRType.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ bool isUnlimitedPolymorphicType(mlir::Type ty);
/// Return true iff `ty` is the type of an assumed type.
bool isAssumedType(mlir::Type ty);

/// Return true iff `ty` is the type of an assumed shape array.
bool isAssumedShape(mlir::Type ty);

/// Return true iff `boxTy` wraps a record type or an unlimited polymorphic
/// entity. Polymorphic entities with intrinsic type spec do not have addendum
inline bool boxHasAddendum(fir::BaseBoxType boxTy) {
Expand Down
122 changes: 97 additions & 25 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,52 @@ static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder,
TODO(loc, "reduction operator");
}

static fir::ShapeOp
genShapeFromBounds(mlir::Location loc, fir::FirOpBuilder &builder,
const llvm::SmallVector<mlir::Value> &args) {
assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
llvm::SmallVector<mlir::Value> extents;
mlir::Type idxTy = builder.getIndexType();
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
for (unsigned i = 0; i < args.size(); i += 3) {
mlir::Value s1 =
builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
extents.push_back(ext);
}
return builder.create<fir::ShapeOp>(loc, extents);
}

static llvm::SmallVector<mlir::Value>
genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::acc::DataBoundsOp &dataBound) {
mlir::Type idxTy = builder.getIndexType();
mlir::Value lb, ub, step;
if (dataBound.getLowerbound() &&
fir::getIntIfConstant(dataBound.getLowerbound()) &&
dataBound.getUpperbound() &&
fir::getIntIfConstant(dataBound.getUpperbound())) {
lb = builder.createIntegerConstant(
loc, idxTy, *fir::getIntIfConstant(dataBound.getLowerbound()));
ub = builder.createIntegerConstant(
loc, idxTy, *fir::getIntIfConstant(dataBound.getUpperbound()));
step = builder.createIntegerConstant(loc, idxTy, 1);
} else if (dataBound.getExtent()) {
lb = builder.createIntegerConstant(loc, idxTy, 0);
ub = builder.createIntegerConstant(
loc, idxTy, *fir::getIntIfConstant(dataBound.getExtent()) - 1);
step = builder.createIntegerConstant(loc, idxTy, 1);
} else {
llvm::report_fatal_error("Expect constant lb/ub or extent");
}
return {lb, ub, step};
}

static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::acc::ReductionOperator op, mlir::Type ty,
mlir::Value value1, mlir::Value value2,
Expand All @@ -907,30 +953,14 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
if (allConstantBound) {
// Use the constant bound directly in the combiner region so they do not
// need to be passed as block argument.
mlir::Type idxTy = builder.getIndexType();
for (auto bound : llvm::reverse(bounds)) {
auto dataBound =
mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
mlir::Value lb, ub, step;
if (dataBound.getLowerbound() &&
fir::getIntIfConstant(dataBound.getLowerbound()) &&
dataBound.getUpperbound() &&
fir::getIntIfConstant(dataBound.getUpperbound())) {
lb = builder.createIntegerConstant(
loc, idxTy, *fir::getIntIfConstant(dataBound.getLowerbound()));
ub = builder.createIntegerConstant(
loc, idxTy, *fir::getIntIfConstant(dataBound.getUpperbound()));
step = builder.createIntegerConstant(loc, idxTy, 1);
} else if (dataBound.getExtent()) {
lb = builder.createIntegerConstant(loc, idxTy, 0);
ub = builder.createIntegerConstant(
loc, idxTy, *fir::getIntIfConstant(dataBound.getExtent()) - 1);
step = builder.createIntegerConstant(loc, idxTy, 1);
} else {
llvm::report_fatal_error("Expect constant lb/ub or extent");
}
auto loop = builder.create<fir::DoLoopOp>(loc, lb, ub, step,
/*unordered=*/false);
llvm::SmallVector<mlir::Value> values =
genConstantBounds(builder, loc, dataBound);
auto loop =
builder.create<fir::DoLoopOp>(loc, values[0], values[1], values[2],
/*unordered=*/false);
builder.setInsertionPointToStart(loop.getBody());
loops.push_back(loop);
ivs.push_back(loop.getInductionVar());
Expand Down Expand Up @@ -962,13 +992,55 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
builder.create<fir::StoreOp>(loc, res, addr1);
builder.setInsertionPointAfter(loops[0]);
} else if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
llvm::SmallVector<mlir::Value> tripletArgs;
fir::SequenceType seqTy =
mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
if (!seqTy)
TODO(loc, "Unsupported boxed type in OpenACC reduction");
hlfir::Entity left = hlfir::Entity{value1};
hlfir::Entity right = hlfir::Entity{value2};
auto shape = hlfir::genShape(loc, builder, left);

if (allConstantBound) {
for (auto bound : llvm::reverse(bounds)) {
auto dataBound =
mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
tripletArgs.append(genConstantBounds(builder, loc, dataBound));
}
} else {
assert(((recipe.getCombinerRegion().getArguments().size() - 2) / 3 ==
seqTy.getDimension()) &&
"Expect 3 block arguments per dimension");
for (auto arg : recipe.getCombinerRegion().getArguments().drop_front(2))
tripletArgs.push_back(arg);
}
auto shape = genShapeFromBounds(loc, builder, tripletArgs);

hlfir::DesignateOp::Subscripts triplets;
for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
i += 3)
triplets.emplace_back(hlfir::DesignateOp::Triplet{
recipe.getCombinerRegion().getArgument(i),
recipe.getCombinerRegion().getArgument(i + 1),
recipe.getCombinerRegion().getArgument(i + 2)});

llvm::SmallVector<mlir::Value> lenParamsLeft;
auto leftEntity = hlfir::Entity{value1};
hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft);
auto leftDesignate = builder.create<hlfir::DesignateOp>(
loc, value1.getType(), leftEntity, /*component=*/"",
/*componentShape=*/mlir::Value{}, triplets,
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
shape, lenParamsLeft);
auto left = hlfir::Entity{leftDesignate.getResult()};

llvm::SmallVector<mlir::Value> lenParamsRight;
auto rightEntity = hlfir::Entity{value2};
hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsRight);
auto rightDesignate = builder.create<hlfir::DesignateOp>(
loc, value2.getType(), rightEntity, /*component=*/"",
/*componentShape=*/mlir::Value{}, triplets,
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
shape, lenParamsRight);
auto right = hlfir::Entity{rightDesignate.getResult()};

llvm::SmallVector<mlir::Value, 1> typeParams;
auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
mlir::Location l, fir::FirOpBuilder &b,
Expand Down Expand Up @@ -1079,7 +1151,7 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
std::string recipeName = fir::getTypeAsString(
ty, converter.getKindMap(),
("reduction_" + stringifyReductionOperator(mlirOp)).str() + suffix);
if (!areAllBoundConstant(bounds))
if (!areAllBoundConstant(bounds) || fir::isAssumedShape(baseAddr.getType()))
ty = baseAddr.getType();
mlir::acc::ReductionRecipeOp recipe =
Fortran::lower::createOrGetReductionRecipe(
Expand Down
7 changes: 7 additions & 0 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,13 @@ bool isAssumedType(mlir::Type ty) {
return false;
}

bool isAssumedShape(mlir::Type ty) {
if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty))
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
return seqTy.hasDynamicExtents();
return false;
}

bool isPolymorphicType(mlir::Type ty) {
if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
ty = refTy;
Expand Down
60 changes: 50 additions & 10 deletions flang/test/Lower/OpenACC/acc-reduction.f90
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,32 @@
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s --check-prefixes=CHECK,FIR
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s --check-prefixes=CHECK,HLFIR

! CHECK-LABEL: acc.reduction.recipe @"reduction_add_section_lb1.ub3_ref_?xi32" : !fir.box<!fir.array<?xi32>> reduction_operator <add> init {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xi32>>):
! HLFIR: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %c0{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
! HLFIR: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1 : (index) -> !fir.shape<1>
! HLFIR: %[[TEMP:.*]] = fir.allocmem !fir.array<?xi32>, %0#1 {bindc_name = ".tmp", uniq_name = ""}
! HLFIR: %[[DECLARE:.*]]:2 = hlfir.declare %[[TEMP]](%[[SHAPE]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.heap<!fir.array<?xi32>>)
! HLFIR: hlfir.assign %c0{{.*}} to %[[DECLARE]]#0 : i32, !fir.box<!fir.array<?xi32>>
! HLFIR: acc.yield %[[DECLARE]]#0 : !fir.box<!fir.array<?xi32>>
! CHECK: } combiner {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xi32>>, %[[ARG1:.*]]: !fir.box<!fir.array<?xi32>>):
! HLFIR: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
! HLFIR: %[[DES1:.*]] = hlfir.designate %[[ARG0]] shape %[[SHAPE]] : (!fir.box<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xi32>>
! HLFIR: %[[DES2:.*]] = hlfir.designate %[[ARG1]] shape %[[SHAPE]] : (!fir.box<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xi32>>
! HLFIR: %[[ELEMENTAL:.*]] = hlfir.elemental %[[SHAPE]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
! HLFIR: ^bb0(%[[IV:.*]]: index):
! HLFIR: %[[DES_V1:.*]] = hlfir.designate %[[DES1]] (%[[IV]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
! HLFIR: %[[DES_V2:.*]] = hlfir.designate %[[DES2]] (%[[IV]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
! HLFIR: %[[LOAD_V1:.*]] = fir.load %[[DES_V1]] : !fir.ref<i32>
! HLFIR: %[[LOAD_V2:.*]] = fir.load %[[DES_V2]] : !fir.ref<i32>
! HLFIR: %[[COMBINED:.*]] = arith.addi %[[LOAD_V1]], %[[LOAD_V2]] : i32
! HLFIR: hlfir.yield_element %[[COMBINED]] : i32
! HLFIR: }
! HLFIR: hlfir.assign %[[ELEMENTAL]] to %[[ARG0]] : !hlfir.expr<?xi32>, !fir.box<!fir.array<?xi32>>
! HLFIR: acc.yield %[[ARG0]] : !fir.box<!fir.array<?xi32>>
! HLFIR: }

! CHECK-LABEL: acc.reduction.recipe @"reduction_max_ref_?xf32" : !fir.box<!fir.array<?xf32>> reduction_operator <max> init {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xf32>>):
! CHECK: %[[INIT_VALUE:.*]] = arith.constant -1.401300e-45 : f32
Expand All @@ -15,12 +41,12 @@
! HLFIR: acc.yield %[[DECLARE]]#0 : !fir.box<!fir.array<?xf32>>
! CHECK: } combiner {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xf32>>, %[[ARG1:.*]]: !fir.box<!fir.array<?xf32>>
! HLFIR: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %{{.*}} : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
! HLFIR: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1 : (index) -> !fir.shape<1>
! HLFIR: %[[ELEMENTAL:.*]] = hlfir.elemental %[[SHAPE]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
! HLFIR: %[[LEFT:.*]] = hlfir.designate %[[ARG0]] (%{{.*}}:%{{.*}}:%{{.*}}) shape %{{.*}} : (!fir.box<!fir.array<?xf32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
! HLFIR: %[[RIGHT:.*]] = hlfir.designate %[[ARG1]] (%{{.*}}:%{{.*}}:%{{.*}}) shape %{{.*}} : (!fir.box<!fir.array<?xf32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
! HLFIR: %[[ELEMENTAL:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
! HLFIR: ^bb0(%{{.*}}: index):
! HLFIR: %[[DES_V1:.*]] = hlfir.designate %[[ARG0]] (%{{.*}}) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
! HLFIR: %[[DES_V2:.*]] = hlfir.designate %[[ARG1]] (%{{.*}}) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
! HLFIR: %[[DES_V1:.*]] = hlfir.designate %[[LEFT]] (%{{.*}}) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
! HLFIR: %[[DES_V2:.*]] = hlfir.designate %[[RIGHT]] (%{{.*}}) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
! HLFIR: %[[LOAD_V1:.*]] = fir.load %[[DES_V1]] : !fir.ref<f32>
! HLFIR: %[[LOAD_V2:.*]] = fir.load %[[DES_V2]] : !fir.ref<f32>
! HLFIR: %[[CMPF:.*]] = arith.cmpf ogt, %[[LOAD_V1]], %[[LOAD_V2]] : f32
Expand All @@ -43,12 +69,12 @@
! HLFIR: acc.yield %[[DECLARE]]#0 : !fir.box<!fir.array<?xi32>>
! CHECK: } combiner {
! CHECK: ^bb0(%[[V1:.*]]: !fir.box<!fir.array<?xi32>>, %[[V2:.*]]: !fir.box<!fir.array<?xi32>>
! HLFIR: %[[BOX_DIMS]]:3 = fir.box_dims %[[V1]], %{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
! HLFIR: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1 : (index) -> !fir.shape<1>
! HLFIR: %[[ELEMENTAL:.*]] = hlfir.elemental %[[SHAPE]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
! HLFIR: %[[LEFT:.*]] = hlfir.designate %[[ARG0]] (%{{.*}}:%{{.*}}:%{{.*}}) shape %{{.*}} : (!fir.box<!fir.array<?xi32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xi32>>
! HLFIR: %[[RIGHT:.*]] = hlfir.designate %[[ARG1]] (%{{.*}}:%{{.*}}:%{{.*}}) shape %{{.*}} : (!fir.box<!fir.array<?xi32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xi32>>
! HLFIR: %[[ELEMENTAL:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
! HLFIR: ^bb0(%{{.*}}: index):
! HLFIR: %[[DES_V1:.*]] = hlfir.designate %[[V1]] (%{{.*}}) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
! HLFIR: %[[DES_V2:.*]] = hlfir.designate %[[V2]] (%{{.*}}) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
! HLFIR: %[[DES_V1:.*]] = hlfir.designate %[[LEFT]] (%{{.*}}) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
! HLFIR: %[[DES_V2:.*]] = hlfir.designate %[[RIGHT]] (%{{.*}}) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
! HLFIR: %[[LOAD_V1:.*]] = fir.load %[[DES_V1]] : !fir.ref<i32>
! HLFIR: %[[LOAD_V2:.*]] = fir.load %[[DES_V2]] : !fir.ref<i32>
! HLFIR: %[[COMBINED:.*]] = arith.addi %[[LOAD_V1]], %[[LOAD_V2]] : i32
Expand Down Expand Up @@ -1084,3 +1110,17 @@ subroutine acc_reduction_add_dynamic_extent_max(a)
! HLFIR: %[[DECLARG0:.*]]:2 = hlfir.declare %[[ARG0]]
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
! HLFIR: acc.parallel reduction(@"reduction_max_ref_?xf32" -> %[[RED]] : !fir.ref<!fir.array<?xf32>>) {

subroutine acc_reduction_add_dynamic_extent_add_with_section(a)
integer :: a(:)
!$acc parallel reduction(+:a(2:4))
!$acc end parallel
end subroutine

! CHECK-LABEL: func.func @_QPacc_reduction_add_dynamic_extent_add_with_section(
! CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"})
! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFacc_reduction_add_dynamic_extent_add_with_sectionEa"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
! HLFIR: %[[BOUND:.*]] = acc.bounds lowerbound(%c1{{.*}} : index) upperbound(%c3{{.*}} : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}} : index) {strideInBytes = true}
! HLFIR: %[[BOX_ADDR:.*]] = fir.box_addr %[[DECL]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<?xi32>> {name = "a(2:4)"}
! HLFIR: acc.parallel reduction(@"reduction_add_section_lb1.ub3_ref_?xi32" -> %[[RED]] : !fir.ref<!fir.array<?xi32>>)

0 comments on commit 438a6ef

Please sign in to comment.