Skip to content

Commit

Permalink
[mlir][flang][openmp] Rework wsloop reduction operations (#80019)
Browse files Browse the repository at this point in the history
This patch reworks the way that wsloop reduction operations function to
better match the expected semantics from the OpenMP specification,
following the rework of parallel reductions.

The new semantics create a private reduction variable as a block
argument which should be used normally for all operations on that
variable in the region; this private variable is then combined with the
others into the shared variable. This way no special omp.reduction
operations are needed inside the region. These block arguments follow
the loop control block arguments.

---------

Co-authored-by: Kiran Chandramohan <kiran.chandramohan@arm.com>
  • Loading branch information
DavidTruby and kiranchandramohan committed Feb 13, 2024
1 parent a69ecb2 commit be9f8ff
Show file tree
Hide file tree
Showing 37 changed files with 2,477 additions and 1,997 deletions.
66 changes: 62 additions & 4 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3352,6 +3352,57 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
return args;
}

static llvm::SmallVector<const Fortran::semantics::Symbol *>
genLoopAndReductionVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
mlir::Location &loc,
const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs,
const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs,
llvm::SmallVector<mlir::Type> &reductionTypes) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

llvm::SmallVector<mlir::Type> blockArgTypes;
llvm::SmallVector<mlir::Location> blockArgLocs;
blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
blockArgLocs.reserve(blockArgTypes.size());
mlir::Block *entryBlock;

if (loopArgs.size()) {
std::size_t loopVarTypeSize = 0;
for (const Fortran::semantics::Symbol *arg : loopArgs)
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
loopVarType);
std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
}
if (reductionArgs.size()) {
llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
}
entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes,
blockArgLocs);
// The argument is not currently in memory, so make a temporary for the
// argument, and store it there, then bind that location to the argument.
if (loopArgs.size()) {
mlir::Operation *storeOp = nullptr;
for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
mlir::Value indexVal =
fir::getBase(op->getRegion(0).front().getArgument(argIndex));
storeOp =
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
}
firOpBuilder.setInsertionPointAfter(storeOp);
}
// Bind the reduction arguments to their block arguments
for (auto [arg, prv] : llvm::zip_equal(
reductionArgs,
llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
converter.bindSymbol(*arg, prv);
}

return loopArgs;
}

static void
createSimdLoop(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Expand Down Expand Up @@ -3429,6 +3480,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
Expand All @@ -3440,7 +3492,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
loopVarTypeSize);
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
cp.processReduction(loc, reductionVars, reductionDeclSymbols,
&reductionSymbols);
cp.processTODO<Fortran::parser::OmpClause::Linear,
Fortran::parser::OmpClause::Order>(loc, ompDirective);

Expand Down Expand Up @@ -3484,14 +3537,20 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(beginClauseList));

llvm::SmallVector<mlir::Type> reductionTypes;
reductionTypes.reserve(reductionVars.size());
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
[](mlir::Value v) { return v.getType(); });

auto ivCallback = [&](mlir::Operation *op) {
return genLoopVars(op, converter, loc, iv);
return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols, reductionTypes);
};

createBodyOfOp<mlir::omp::WsLoopOp>(
wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
.setClauses(&beginClauseList)
.setDataSharingProcessor(&dsp)
.setReductions(&reductionSymbols, &reductionTypes)
.setGenRegionEntryCb(ivCallback));
}

Expand Down Expand Up @@ -3594,12 +3653,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
// 2.9.3.1 SIMD construct
createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
currentLocation);
genOpenMPReduction(converter, semaCtx, loopOpClauseList);
} else {
createWsLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
endClauseList, currentLocation);
}

genOpenMPReduction(converter, semaCtx, loopOpClauseList);
}

static void
Expand Down
20 changes: 16 additions & 4 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
Expand Up @@ -701,10 +701,17 @@ func.func @_QPsb() {
// CHECK-SAME: %[[ARRAY_REF:.*]]: !llvm.ptr
// CHECK: %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
// CHECK: omp.parallel {
// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] -> %[[RED_ACCUMULATOR]] : !llvm.ptr) for
// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) for
// CHECK: %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
// CHECK: %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
// CHECK: omp.reduction %[[ARRAY_ELEM]], %[[RED_ACCUMULATOR]] : i32, !llvm.ptr
// CHECK: %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
// CHECK: %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK: %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
// CHECK: %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK: %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
// CHECK: %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
// CHECK: %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
// CHECK: llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
// CHECK: omp.yield
// CHECK: omp.terminator
// CHECK: llvm.return
Expand Down Expand Up @@ -733,15 +740,20 @@ func.func @_QPsimple_reduction(%arg0: !fir.ref<!fir.array<100x!fir.logical<4>>>
%c1_i32 = arith.constant 1 : i32
%c100_i32 = arith.constant 100 : i32
%c1_i32_0 = arith.constant 1 : i32
omp.wsloop reduction(@eqv_reduction -> %1 : !fir.ref<!fir.logical<4>>) for (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
omp.wsloop reduction(@eqv_reduction %1 -> %prv : !fir.ref<!fir.logical<4>>) for (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
fir.store %arg1 to %3 : !fir.ref<i32>
%4 = fir.load %3 : !fir.ref<i32>
%5 = fir.convert %4 : (i32) -> i64
%c1_i64 = arith.constant 1 : i64
%6 = arith.subi %5, %c1_i64 : i64
%7 = fir.coordinate_of %arg0, %6 : (!fir.ref<!fir.array<100x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
%8 = fir.load %7 : !fir.ref<!fir.logical<4>>
omp.reduction %8, %1 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
%lprv = fir.load %prv : !fir.ref<!fir.logical<4>>
%lprv1 = fir.convert %lprv : (!fir.logical<4>) -> i1
%9 = fir.convert %8 : (!fir.logical<4>) -> i1
%10 = arith.cmpi eq, %9, %lprv1 : i1
%11 = fir.convert %10 : (i1) -> !fir.logical<4>
fir.store %11 to %prv : !fir.ref<!fir.logical<4>>
omp.yield
}
omp.terminator
Expand Down

0 comments on commit be9f8ff

Please sign in to comment.