@@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
// if the LHS is not).
mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
builder.setInsertionPointToStart(elementalLoopNest->body);
lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
elementalLoopNest->oneBasedIndices);
rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
@@ -484,7 +484,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
for (auto &cleanupConversion : argConversionCleanups)
cleanupConversion();
if (elementalLoopNest)
builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
builder.setInsertionPointAfter(elementalLoopNest->outerOp);
} else {
// TODO: preserve allocatable assignment aspects for forall once
// they are conveyed in hlfir.region_assign.
@@ -493,7 +493,7 @@
generateCleanupIfAny(loweredLhs.elementalCleanup);
if (loweredLhs.vectorSubscriptLoopNest)
builder.setInsertionPointAfter(
loweredLhs.vectorSubscriptLoopNest->outerLoop);
loweredLhs.vectorSubscriptLoopNest->outerOp);
generateCleanupIfAny(oldRhsYield);
generateCleanupIfAny(loweredLhs.nonElementalCleanup);
}
@@ -518,16 +518,16 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
hlfir::Entity savedMask{maybeSaved->first};
mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
constructStack.push_back(whereLoopNest->outerOp);
builder.setInsertionPointToStart(whereLoopNest->body);
mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
whereLoopNest->oneBasedIndices);
generateMaskIfOp(cdt);
if (maybeSaved->second) {
// If this is the same run as the one that saved the value, the clean-up
// was left-over to be done now.
auto insertionPoint = builder.saveInsertionPoint();
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
builder.setInsertionPointAfter(whereLoopNest->outerOp);
generateCleanupIfAny(maybeSaved->second);
builder.restoreInsertionPoint(insertionPoint);
}
@@ -539,8 +539,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
mask.generateNoneElementalPart(builder, mapper);
mlir::Value shape = mask.generateShape(builder, mapper);
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
constructStack.push_back(whereLoopNest->outerOp);
builder.setInsertionPointToStart(whereLoopNest->body);
mlir::Value cdt = generateMaskedEntity(mask);
generateMaskIfOp(cdt);
return;
@@ -754,7 +754,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
loc, builder, loweredLhs.vectorSubscriptShape.value());
builder.setInsertionPointToStart(
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
loweredLhs.vectorSubscriptLoopNest->body);
}
loweredLhs.lhs = temp->second.fetch(loc, builder);
return loweredLhs;
@@ -772,7 +772,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
!elementalAddrLhs.isOrdered());
builder.setInsertionPointToStart(
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
loweredLhs.vectorSubscriptLoopNest->body);
mapper.map(elementalAddrLhs.getIndices(),
loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
@@ -798,11 +798,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
if (!maskedExpr.noneElementalPartWasGenerated) {
// Generate none elemental part before the where loops (but inside the
// current forall loops if any).
builder.setInsertionPoint(whereLoopNest->outerLoop);
builder.setInsertionPoint(whereLoopNest->outerOp);
maskedExpr.generateNoneElementalPart(builder, mapper);
}
// Generate the none elemental part cleanup after the where loops.
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
builder.setInsertionPointAfter(whereLoopNest->outerOp);
maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
// Generate the value of the current element for the masked expression
// at the current insertion point (inside the where loops, and any fir.if
@@ -1242,7 +1242,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
fir::factory::TemporaryStorage *temp = nullptr;
if (loweredLhs.vectorSubscriptLoopNest)
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
// Vector subscripted entity for which the shape must also be saved on top
// of the element addresses (e.g. the shape may change in each forall
@@ -1265,7 +1265,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
// subscripted LHS.
auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
auto insertionPoint = builder.saveInsertionPoint();
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
vectorTmp.pushShape(loc, builder, shape);
builder.restoreInsertionPoint(insertionPoint);
} else {
@@ -1291,7 +1291,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
if (loweredLhs.vectorSubscriptLoopNest) {
constructStack.pop_back();
builder.setInsertionPointAfter(
loweredLhs.vectorSubscriptLoopNest->outerLoop);
loweredLhs.vectorSubscriptLoopNest->outerOp);
}
}

16 changes: 10 additions & 6 deletions flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -20,6 +20,7 @@
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "flang/Optimizer/OpenMP/Passes.h"
#include "flang/Optimizer/Transforms/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dominance.h"
@@ -482,8 +483,9 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
// Generate a loop nest looping around the hlfir.elemental shape and clone
// hlfir.elemental region inside the inner loop
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
flangomp::shouldUseWorkshareLowering(elemental));
builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
loopNest.oneBasedIndices);
hlfir::Entity elementValue{yield.getElementValue()};
@@ -553,8 +555,9 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
llvm::SmallVector<mlir::Value> extents =
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
flangomp::shouldUseWorkshareLowering(assign));
builder.setInsertionPointToStart(loopNest.body);
auto arrayElement =
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
@@ -648,8 +651,9 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
llvm::SmallVector<mlir::Value> extents =
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
flangomp::shouldUseWorkshareLowering(assign));
builder.setInsertionPointToStart(loopNest.body);
auto rhsArrayElement =
hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);
7 changes: 4 additions & 3 deletions flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -1,9 +1,10 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

add_flang_library(FlangOpenMPTransforms
OMPFunctionFiltering.cpp
OMPMapInfoFinalization.cpp
OMPMarkDeclareTarget.cpp
FunctionFiltering.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
LowerWorkshare.cpp

DEPENDS
FIRDialect
@@ -1,4 +1,4 @@
//===- OMPFunctionFiltering.cpp -------------------------------------------===//
//===- FunctionFiltering.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -22,18 +22,17 @@
#include "llvm/ADT/SmallVector.h"

namespace flangomp {
#define GEN_PASS_DEF_OMPFUNCTIONFILTERING
#define GEN_PASS_DEF_FUNCTIONFILTERING
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp

using namespace mlir;

namespace {
class OMPFunctionFilteringPass
: public flangomp::impl::OMPFunctionFilteringBase<
OMPFunctionFilteringPass> {
class FunctionFilteringPass
: public flangomp::impl::FunctionFilteringBase<FunctionFilteringPass> {
public:
OMPFunctionFilteringPass() = default;
FunctionFilteringPass() = default;

void runOnOperation() override {
MLIRContext *context = &getContext();
442 changes: 442 additions & 0 deletions flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp

Large diffs are not rendered by default.

@@ -1,4 +1,4 @@
//===- OMPMapInfoFinalization.cpp -----------------------------------------===//
//===- MapInfoFinalization.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -41,14 +41,14 @@
#include <iterator>

namespace flangomp {
#define GEN_PASS_DEF_OMPMAPINFOFINALIZATIONPASS
#define GEN_PASS_DEF_MAPINFOFINALIZATIONPASS
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp

namespace {
class OMPMapInfoFinalizationPass
: public flangomp::impl::OMPMapInfoFinalizationPassBase<
OMPMapInfoFinalizationPass> {
class MapInfoFinalizationPass
: public flangomp::impl::MapInfoFinalizationPassBase<
MapInfoFinalizationPass> {

void genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
fir::FirOpBuilder &builder,
@@ -244,7 +244,7 @@ class OMPMapInfoFinalizationPass
// all users appropriately, making sure to only add a single member link
// per new generation for the original originating descriptor MapInfoOp.
assert(llvm::hasSingleElement(op->getUsers()) &&
"OMPMapInfoFinalization currently only supports single users "
"MapInfoFinalization currently only supports single users "
"of a MapInfoOp");

if (!op.getMembers().empty()) {
@@ -1,4 +1,4 @@
//===- OMPMarkDeclareTarget.cpp -------------------------------------------===//
//===- MarkDeclareTarget.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -23,14 +23,13 @@
#include "llvm/ADT/SmallPtrSet.h"

namespace flangomp {
#define GEN_PASS_DEF_OMPMARKDECLARETARGETPASS
#define GEN_PASS_DEF_MARKDECLARETARGETPASS
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp

namespace {
class OMPMarkDeclareTargetPass
: public flangomp::impl::OMPMarkDeclareTargetPassBase<
OMPMarkDeclareTargetPass> {
class MarkDeclareTargetPass
: public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {

void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
mlir::omp::DeclareTargetCaptureClause parentCapClause,
1 change: 1 addition & 0 deletions flang/test/Fir/basic-program.fir
@@ -47,6 +47,7 @@ func.func @_QQmain() {
// PASSES-NEXT: LowerHLFIRIntrinsics
// PASSES-NEXT: BufferizeHLFIR
// PASSES-NEXT: ConvertHLFIRtoFIR
// PASSES-NEXT: LowerWorkshare
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
58 changes: 58 additions & 0 deletions flang/test/HLFIR/bufferize-workshare.fir
@@ -0,0 +1,58 @@
// RUN: fir-opt --bufferize-hlfir %s | FileCheck %s

// CHECK-LABEL: func.func @simple(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
// CHECK: omp.parallel {
// CHECK: omp.workshare {
// CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: %[[VAL_7:.*]] = arith.constant true
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK: omp.workshare_loop_wrapper {
// CHECK: omp.loop_nest (%[[VAL_9:.*]]) : index = (%[[VAL_8]]) to (%[[VAL_1]]) inclusive step (%[[VAL_8]]) {
// CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_9]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_11]], %[[VAL_2]] : i32
// CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_13]] temporary_lhs : i32, !fir.ref<i32>
// CHECK: omp.yield
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[VAL_14:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_7]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_6]]#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: hlfir.assign %[[VAL_6]]#0 to %[[VAL_4]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
// CHECK: fir.freemem %[[VAL_6]]#0 : !fir.heap<!fir.array<42xi32>>
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: return
// CHECK: }
func.func @simple(%arg: !fir.ref<!fir.array<42xi32>>) {
omp.parallel {
omp.workshare {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
^bb0(%i: index):
%ref = hlfir.designate %array#0 (%i) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
%val = fir.load %ref : !fir.ref<i32>
%sub = arith.subi %val, %c1_i32 : i32
hlfir.yield_element %sub : i32
}
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
hlfir.destroy %elemental : !hlfir.expr<42xi32>
omp.terminator
}
omp.terminator
}
return
}
6 changes: 3 additions & 3 deletions flang/test/Lower/OpenMP/workshare.f90
@@ -6,7 +6,7 @@ subroutine sb1(arr)
integer :: arr(:)
!CHECK: omp.parallel {
!$omp parallel
!CHECK: omp.single {
!CHECK: omp.workshare {
!$omp workshare
arr = 0
!$omp end workshare
@@ -20,7 +20,7 @@ subroutine sb2(arr)
integer :: arr(:)
!CHECK: omp.parallel {
!$omp parallel
!CHECK: omp.single nowait {
!CHECK: omp.workshare nowait {
!$omp workshare
arr = 0
!$omp end workshare nowait
@@ -33,7 +33,7 @@ subroutine sb2(arr)
subroutine sb3(arr)
integer :: arr(:)
!CHECK: omp.parallel {
!CHECK: omp.single {
!CHECK: omp.workshare {
!$omp parallel workshare
arr = 0
!$omp end parallel workshare
199 changes: 199 additions & 0 deletions flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -0,0 +1,199 @@
// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s

// checks:
// nowait on final omp.single
func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
omp.parallel {
omp.workshare {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%0 = fir.shape %c42 : (index) -> !fir.shape<1>
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
%3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
%true = arith.constant true
%c1 = arith.constant 1 : index
omp.workshare_loop_wrapper {
omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
%7 = hlfir.designate %1#0 (%arg1) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
%8 = fir.load %7 : !fir.ref<i32>
%9 = arith.subi %8, %c1_i32 : i32
%10 = hlfir.designate %3#0 (%arg1) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
omp.yield
}
omp.terminator
}
%4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
%5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
%6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
omp.terminator
}
omp.terminator
}
return
}

// -----

// checks:
// fir.alloca hoisted out and copyprivate'd
func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
omp.workshare {
%c1_i32 = arith.constant 1 : i32
%alloc = fir.alloca i32
fir.store %c1_i32 to %alloc : !fir.ref<i32>
%c42 = arith.constant 42 : index
%0 = fir.shape %c42 : (index) -> !fir.shape<1>
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
%3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
%true = arith.constant true
%c1 = arith.constant 1 : index
omp.workshare_loop_wrapper {
omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
%7 = hlfir.designate %1#0 (%arg1) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
%8 = fir.load %7 : !fir.ref<i32>
%ld = fir.load %alloc : !fir.ref<i32>
%n8 = arith.subi %8, %ld : i32
%9 = arith.subi %n8, %c1_i32 : i32
%10 = hlfir.designate %3#0 (%arg1) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
omp.yield
}
omp.terminator
}
%4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
%5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
%6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
"test.test1"(%alloc) : (!fir.ref<i32>) -> ()
hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
omp.terminator
}
return
}

// CHECK-LABEL: func.func private @_workshare_copy_heap_42xi32(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @wsfunc(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
// CHECK: omp.parallel {
// CHECK: %[[VAL_1:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
// CHECK: omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 42 : index
// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
// CHECK: fir.store %[[VAL_5]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[VAL_7:.*]] = arith.constant 42 : index
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_9]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_11]](%[[VAL_9]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: %[[VAL_13:.*]] = arith.constant true
// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
// CHECK: omp.wsloop {
// CHECK: omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_14]]) to (%[[VAL_7]]) inclusive step (%[[VAL_14]]) {
// CHECK: %[[VAL_16:.*]] = hlfir.designate %[[VAL_10]]#0 (%[[VAL_15]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
// CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_8]] : i32
// CHECK: %[[VAL_19:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
// CHECK: omp.yield
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.single nowait {
// CHECK: %[[VAL_20:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: %[[VAL_21:.*]] = fir.insert_value %[[VAL_20]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: hlfir.assign %[[VAL_12]]#0 to %[[VAL_10]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
// CHECK: fir.freemem %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[VAL_22:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: %[[VAL_23:.*]] = fir.insert_value %[[VAL_22]], %[[VAL_13]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: omp.barrier
// CHECK: omp.terminator
// CHECK: }
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func private @_workshare_copy_heap_42xi32(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func private @_workshare_copy_i32(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<i32>,
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<i32>) {
// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
// CHECK: fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<i32>
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @wsfunc(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
// CHECK: %[[VAL_1:.*]] = fir.alloca i32
// CHECK: %[[VAL_2:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
// CHECK: omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_2]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
// CHECK: fir.store %[[VAL_3]] to %[[VAL_1]] : !fir.ref<i32>
// CHECK: %[[VAL_4:.*]] = arith.constant 42 : index
// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_7:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
// CHECK: fir.store %[[VAL_7]] to %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_7]](%[[VAL_5]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[VAL_9:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_10:.*]] = arith.constant 42 : index
// CHECK: %[[VAL_11:.*]] = fir.shape %[[VAL_10]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: %[[VAL_15:.*]] = arith.constant true
// CHECK: %[[VAL_16:.*]] = arith.constant 1 : index
// CHECK: omp.wsloop {
// CHECK: omp.loop_nest (%[[VAL_17:.*]]) : index = (%[[VAL_16]]) to (%[[VAL_10]]) inclusive step (%[[VAL_16]]) {
// CHECK: %[[VAL_18:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_17]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
// CHECK: %[[VAL_20:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_19]], %[[VAL_20]] : i32
// CHECK: %[[VAL_22:.*]] = arith.subi %[[VAL_21]], %[[VAL_9]] : i32
// CHECK: %[[VAL_23:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_17]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: hlfir.assign %[[VAL_22]] to %[[VAL_23]] temporary_lhs : i32, !fir.ref<i32>
// CHECK: omp.yield
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.single nowait {
// CHECK: %[[VAL_24:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: %[[VAL_25:.*]] = fir.insert_value %[[VAL_24]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: "test.test1"(%[[VAL_1]]) : (!fir.ref<i32>) -> ()
// CHECK: hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
// CHECK: fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[VAL_26:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: %[[VAL_27:.*]] = fir.insert_value %[[VAL_26]], %[[VAL_15]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: omp.barrier
// CHECK: return
// CHECK: }

23 changes: 23 additions & 0 deletions flang/test/Transforms/OpenMP/lower-workshare2.mlir
@@ -0,0 +1,23 @@
// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s

// Check that we correctly handle nowait

// CHECK-LABEL: func.func @nonowait
func.func @nonowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
// CHECK: omp.barrier
omp.workshare {
omp.terminator
}
return
}

// -----

// CHECK-LABEL: func.func @nowait
func.func @nowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
// CHECK-NOT: omp.barrier
omp.workshare nowait {
omp.terminator
}
return
}
74 changes: 74 additions & 0 deletions flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -0,0 +1,74 @@
// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s


// Check if we store the correct values

func.func @wsfunc() {
omp.parallel {
// CHECK: fir.alloca
// CHECK: fir.alloca
// CHECK: fir.alloca
// CHECK: fir.alloca
// CHECK: fir.alloca
// CHECK-NOT: fir.alloca
omp.workshare {

%t1 = "test.test1"() : () -> i32
// CHECK: %[[T1:.*]] = "test.test1"
// CHECK: fir.store %[[T1]]
%t2 = "test.test2"() : () -> i32
// CHECK: %[[T2:.*]] = "test.test2"
// CHECK: fir.store %[[T2]]
%t3 = "test.test3"() : () -> i32
// CHECK: %[[T3:.*]] = "test.test3"
// CHECK-NOT: fir.store %[[T3]]
%t4 = "test.test4"() : () -> i32
// CHECK: %[[T4:.*]] = "test.test4"
// CHECK: fir.store %[[T4]]
%t5 = "test.test5"() : () -> i32
// CHECK: %[[T5:.*]] = "test.test5"
// CHECK: fir.store %[[T5]]
%t6 = "test.test6"() : () -> i32
// CHECK: %[[T6:.*]] = "test.test6"
// CHECK-NOT: fir.store %[[T6]]


"test.test1"(%t1) : (i32) -> ()
"test.test1"(%t2) : (i32) -> ()
"test.test1"(%t3) : (i32) -> ()

%true = arith.constant true
fir.if %true {
"test.test2"(%t3) : (i32) -> ()
}

%c1_i32 = arith.constant 1 : i32

%t5_pure_use = arith.addi %t5, %c1_i32 : i32

%t6_mem_effect_use = "test.test8"(%t6) : (i32) -> i32
// CHECK: %[[T6_USE:.*]] = "test.test8"
// CHECK: fir.store %[[T6_USE]]

%c42 = arith.constant 42 : index
%c1 = arith.constant 1 : index
omp.workshare_loop_wrapper {
omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
"test.test10"(%t1) : (i32) -> ()
"test.test10"(%t5_pure_use) : (i32) -> ()
"test.test10"(%t6_mem_effect_use) : (i32) -> ()
omp.yield
}
omp.terminator
}

"test.test10"(%t2) : (i32) -> ()
fir.if %true {
"test.test10"(%t4) : (i32) -> ()
}
omp.terminator
}
omp.terminator
}
return
}
55 changes: 55 additions & 0 deletions flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -0,0 +1,55 @@
// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s

// Check that we clean up unused pure operations from either the parallel or
// single regions

func.func @wsfunc() {
%a = fir.alloca i32
omp.parallel {
omp.workshare {
%t1 = "test.test1"() : () -> i32

%c1 = arith.constant 1 : index
%c42 = arith.constant 42 : index

%c2 = arith.constant 2 : index
"test.test3"(%c2) : (index) -> ()

omp.workshare_loop_wrapper {
omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
"test.test2"() : () -> ()
omp.yield
}
omp.terminator
}
omp.terminator
}
omp.terminator
}
return
}

// CHECK-LABEL: func.func @wsfunc() {
// CHECK: %[[VAL_0:.*]] = fir.alloca i32
// CHECK: omp.parallel {
// CHECK: omp.single {
// CHECK: %[[VAL_1:.*]] = "test.test1"() : () -> i32
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK: "test.test3"(%[[VAL_2]]) : (index) -> ()
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 42 : index
// CHECK: omp.wsloop nowait {
// CHECK: omp.loop_nest (%[[VAL_5:.*]]) : index = (%[[VAL_3]]) to (%[[VAL_4]]) inclusive step (%[[VAL_3]]) {
// CHECK: "test.test2"() : () -> ()
// CHECK: omp.yield
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.barrier
// CHECK: omp.terminator
// CHECK: }
// CHECK: return
// CHECK: }

42 changes: 42 additions & 0 deletions flang/test/Transforms/OpenMP/lower-workshare5.mlir
@@ -0,0 +1,42 @@
// XFAIL: *
// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s

// TODO we can lower these but we have no guarantee that the parent of
// omp.workshare supports multi-block regions, thus we fail for now.

func.func @wsfunc() {
%a = fir.alloca i32
omp.parallel {
omp.workshare {
^bb1:
%c1 = arith.constant 1 : i32
cf.br ^bb3(%c1: i32)
^bb3(%arg1: i32):
"test.test2"(%arg1) : (i32) -> ()
omp.terminator
}
omp.terminator
}
return
}

// -----

func.func @wsfunc() {
%a = fir.alloca i32
omp.parallel {
omp.workshare {
^bb1:
%c1 = arith.constant 1 : i32
cf.br ^bb3(%c1: i32)
^bb2:
"test.test2"(%r) : (i32) -> ()
omp.terminator
^bb3(%arg1: i32):
%r = "test.test2"(%arg1) : (i32) -> i32
cf.br ^bb2
}
omp.terminator
}
return
}
51 changes: 51 additions & 0 deletions flang/test/Transforms/OpenMP/lower-workshare6.mlir
@@ -0,0 +1,51 @@
// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s

// Checks that the omp.workshare_loop_wrapper binds to the correct omp.workshare

func.func @wsfunc() {
%c1 = arith.constant 1 : index
%c42 = arith.constant 42 : index
omp.parallel {
omp.workshare nowait {
omp.parallel {
omp.workshare nowait {
omp.workshare_loop_wrapper {
omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
"test.test2"() : () -> ()
omp.yield
}
omp.terminator
}
omp.terminator
}
omp.terminator
}
omp.terminator
}
omp.terminator
}
return
}

// CHECK-LABEL: func.func @wsfunc() {
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
// CHECK: omp.parallel {
// CHECK: omp.single nowait {
// CHECK: omp.parallel {
// CHECK: omp.wsloop nowait {
// CHECK: omp.loop_nest (%[[VAL_2:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_1]]) inclusive step (%[[VAL_0]]) {
// CHECK: "test.test2"() : () -> ()
// CHECK: omp.yield
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: return
// CHECK: }

140 changes: 140 additions & 0 deletions flang/test/Transforms/OpenMP/should-use-workshare-lowering.mlir
@@ -0,0 +1,140 @@
// RUN: fir-opt --bufferize-hlfir %s | FileCheck %s

// Checks that we correctly identify when to use the lowering to
// omp.workshare_loop_wrapper

// CHECK-LABEL: @should_parallelize_0
// CHECK: omp.workshare_loop_wrapper
func.func @should_parallelize_0(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
omp.workshare {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
^bb0(%i: index):
hlfir.yield_element %c1_i32 : i32
}
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
hlfir.destroy %elemental : !hlfir.expr<42xi32>
omp.terminator
}
return
}

// CHECK-LABEL: @should_parallelize_1
// CHECK: omp.workshare_loop_wrapper
func.func @should_parallelize_1(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
omp.parallel {
omp.workshare {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
^bb0(%i: index):
hlfir.yield_element %c1_i32 : i32
}
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
hlfir.destroy %elemental : !hlfir.expr<42xi32>
omp.terminator
}
omp.terminator
}
return
}


// CHECK-LABEL: @should_not_parallelize_0
// CHECK-NOT: omp.workshare_loop_wrapper
func.func @should_not_parallelize_0(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
omp.workshare {
omp.single {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
^bb0(%i: index):
hlfir.yield_element %c1_i32 : i32
}
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
hlfir.destroy %elemental : !hlfir.expr<42xi32>
omp.terminator
}
omp.terminator
}
return
}

// CHECK-LABEL: @should_not_parallelize_1
// CHECK-NOT: omp.workshare_loop_wrapper
func.func @should_not_parallelize_1(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
omp.workshare {
omp.critical {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
^bb0(%i: index):
hlfir.yield_element %c1_i32 : i32
}
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
hlfir.destroy %elemental : !hlfir.expr<42xi32>
omp.terminator
}
omp.terminator
}
return
}

// CHECK-LABEL: @should_not_parallelize_2
// CHECK-NOT: omp.workshare_loop_wrapper
func.func @should_not_parallelize_2(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
omp.workshare {
omp.parallel {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
^bb0(%i: index):
hlfir.yield_element %c1_i32 : i32
}
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
hlfir.destroy %elemental : !hlfir.expr<42xi32>
omp.terminator
}
omp.terminator
}
return
}

// CHECK-LABEL: @should_not_parallelize_3
// CHECK-NOT: omp.workshare_loop_wrapper
func.func @should_not_parallelize_3(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
omp.workshare {
omp.parallel {
omp.workshare {
omp.parallel {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
^bb0(%i: index):
hlfir.yield_element %c1_i32 : i32
}
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
hlfir.destroy %elemental : !hlfir.expr<42xi32>
omp.terminator
}
omp.terminator
}
omp.terminator
}
omp.terminator
}
return
}
5 changes: 4 additions & 1 deletion flang/tools/bbc/bbc.cpp
@@ -429,7 +429,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(

if (emitFIR && useHLFIR) {
// lower HLFIR to FIR
fir::createHLFIRToFIRPassPipeline(pm, llvm::OptimizationLevel::O2);
fir::createHLFIRToFIRPassPipeline(pm, enableOpenMP,
llvm::OptimizationLevel::O2);
if (mlir::failed(pm.run(mlirModule))) {
llvm::errs() << "FATAL: lowering from HLFIR to FIR failed";
return mlir::failure();
@@ -444,6 +445,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(

// Add O2 optimizer pass pipeline.
MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
if (enableOpenMP)
config.EnableOpenMP = true;
config.NSWOnLoopVarInc = setNSW;
fir::registerDefaultInlinerPass(config);
fir::createDefaultFIROptimizerPassPipeline(pm, config);
1 change: 1 addition & 0 deletions flang/tools/tco/tco.cpp
@@ -138,6 +138,7 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
return mlir::failure();
} else {
MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
config.EnableOpenMP = true; // assume the input contains OpenMP
config.AliasAnalysis = true; // enabled when optimizing for speed
if (codeGenLLVM) {
// Run only CodeGen passes.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -316,6 +316,8 @@ using TeamsOperands =
detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;

using WorkshareOperands = detail::Clauses<NowaitClauseOps>;

using WsloopOperands =
detail::Clauses<AllocateClauseOps, LinearClauseOps, NowaitClauseOps,
OrderClauseOps, OrderedClauseOps, PrivateClauseOps,
43 changes: 43 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -286,6 +286,49 @@ def SingleOp : OpenMP_Op<"single", traits = [
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// 2.8.3 Workshare Construct
//===----------------------------------------------------------------------===//

def WorkshareOp : OpenMP_Op<"workshare", traits = [
RecursiveMemoryEffects,
], clauses = [
OpenMP_NowaitClause,
], singleRegion = true> {
let summary = "workshare directive";
let description = [{
The workshare construct divides the execution of the enclosed structured
block into separate units of work, and causes the threads of the team to
share the work such that each unit is executed only once by one thread, in
the context of its implicit task.

This operation is used for the intermediate representation of the workshare
block before the work gets divided between the threads. See the flang
LowerWorkshare pass for details.
}] # clausesDescription;

let builders = [
OpBuilder<(ins CArg<"const WorkshareOperands &">:$clauses)>
];
}
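
To illustrate the expected textual form of the new operation (a sketch only, mirroring the ops.mlir tests added later in this patch; the function name is made up for illustration), an omp.workshare region with the optional nowait clause could look like:

func.func @workshare_sketch() {
  omp.parallel {
    // Single-block region; the flang LowerWorkshare pass later splits its
    // contents into omp.single and omp.wsloop parts.
    omp.workshare nowait {
      "test.payload"() : () -> ()
      omp.terminator
    }
    omp.terminator
  }
  return
}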

def WorkshareLoopWrapperOp : OpenMP_Op<"workshare_loop_wrapper", traits = [
DeclareOpInterfaceMethods<LoopWrapperInterface>,
RecursiveMemoryEffects, SingleBlock
], singleRegion = true> {
let summary = "contains loop nests to be parallelized by workshare";
let description = [{
This operation wraps a loop nest that is marked for dividing into units of
work by an encompassing omp.workshare operation.
}];

let builders = [
OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
];
let assemblyFormat = "$region attr-dict";
let hasVerifier = 1;
}
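
Similarly, a sketch of omp.workshare_loop_wrapper (illustrative only, matching the forms exercised in the ops.mlir and invalid.mlir tests below; the function name is hypothetical): the wrapper directly wraps a single omp.loop_nest and must itself be nested inside omp.workshare, which is what the verifier added in OpenMPDialect.cpp enforces:

func.func @workshare_loop_wrapper_sketch(%idx : index) {
  omp.workshare {
    // The wrapper's block holds exactly one loop nest plus a terminator and is
    // later rewritten into an omp.wsloop by the flang LowerWorkshare pass.
    omp.workshare_loop_wrapper {
      omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
        omp.yield
      }
      omp.terminator
    }
    omp.terminator
  }
  return
}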

//===----------------------------------------------------------------------===//
// Loop Nest
//===----------------------------------------------------------------------===//
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1683,6 +1683,29 @@ LogicalResult SingleOp::verify() {
getCopyprivateSyms());
}

//===----------------------------------------------------------------------===//
// WorkshareOp
//===----------------------------------------------------------------------===//

void WorkshareOp::build(OpBuilder &builder, OperationState &state,
const WorkshareOperands &clauses) {
WorkshareOp::build(builder, state, clauses.nowait);
}

//===----------------------------------------------------------------------===//
// WorkshareLoopWrapperOp
//===----------------------------------------------------------------------===//

LogicalResult WorkshareLoopWrapperOp::verify() {
if (!isWrapper())
return emitOpError() << "must be a loop wrapper";
if (getNestedWrapper())
return emitError() << "nested wrappers not supported";
if (!(*this)->getParentOfType<WorkshareOp>())
return emitError() << "must be nested in an omp.workshare";
return success();
}

//===----------------------------------------------------------------------===//
// WsloopOp
//===----------------------------------------------------------------------===//
42 changes: 42 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2383,3 +2383,45 @@ func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) {
}) : (i32, i32) -> ()
return
}

// -----
func.func @nested_wrapper(%idx : index) {
omp.workshare {
// expected-error @below {{nested wrappers not supported}}
omp.workshare_loop_wrapper {
omp.simd {
omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
omp.yield
}
omp.terminator
}
omp.terminator
}
omp.terminator
}
return
}

// -----
func.func @not_wrapper() {
omp.workshare {
// expected-error @below {{must be a loop wrapper}}
omp.workshare_loop_wrapper {
omp.terminator
}
omp.terminator
}
return
}

// -----
func.func @missing_workshare(%idx : index) {
// expected-error @below {{must be nested in an omp.workshare}}
omp.workshare_loop_wrapper {
omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
omp.yield
}
omp.terminator
}
return
}
69 changes: 69 additions & 0 deletions mlir/test/Dialect/OpenMP/ops.mlir
@@ -2789,3 +2789,72 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_

return
}

// CHECK-LABEL: func @omp_workshare
func.func @omp_workshare() {
// CHECK: omp.workshare {
omp.workshare {
"test.payload"() : () -> ()
// CHECK: omp.terminator
omp.terminator
}
return
}

// CHECK-LABEL: func @omp_workshare_nowait
func.func @omp_workshare_nowait() {
// CHECK: omp.workshare nowait {
omp.workshare nowait {
"test.payload"() : () -> ()
// CHECK: omp.terminator
omp.terminator
}
return
}

// CHECK-LABEL: func @omp_workshare_multiple_blocks
func.func @omp_workshare_multiple_blocks() {
// CHECK: omp.workshare {
omp.workshare {
cf.br ^bb2
^bb2:
// CHECK: omp.terminator
omp.terminator
}
return
}

// CHECK-LABEL: func @omp_workshare_loop_wrapper
func.func @omp_workshare_loop_wrapper(%idx : index) {
// CHECK-NEXT: omp.workshare {
omp.workshare {
// CHECK-NEXT: omp.workshare_loop_wrapper
omp.workshare_loop_wrapper {
// CHECK-NEXT: omp.loop_nest
omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
omp.yield
}
omp.terminator
}
omp.terminator
}
return
}

// CHECK-LABEL: func @omp_workshare_loop_wrapper_attrs
func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
// CHECK-NEXT: omp.workshare {
omp.workshare {
// CHECK-NEXT: omp.workshare_loop_wrapper {
omp.workshare_loop_wrapper {
// CHECK-NEXT: omp.loop_nest
omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
omp.yield
}
omp.terminator
// CHECK: } {attr_in_dict}
} {attr_in_dict}
omp.terminator
}
return
}