391 changes: 391 additions & 0 deletions flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -0,0 +1,391 @@
//===- LowerWorkshare.cpp - Lower the omp.workshare construct ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Lowers the omp.workshare construct to a sequence of omp.single and
// omp.wsloop operations.
//===----------------------------------------------------------------------===//

#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Dialect/FIROps.h>
#include <flang/Optimizer/Dialect/FIRType.h>
#include <flang/Optimizer/HLFIR/HLFIROps.h>
#include <flang/Optimizer/OpenMP/Passes.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallVectorExtras.h>
#include <llvm/ADT/iterator_range.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/Visitors.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>

#include <variant>

namespace flangomp {
#define GEN_PASS_DEF_LOWERWORKSHARE
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp

#define DEBUG_TYPE "lower-workshare"

using namespace mlir;

namespace flangomp {
bool shouldUseWorkshareLowering(Operation *op) {
// TODO: This check is insufficient, as we could have
//
// omp.parallel {
//   omp.workshare {
//     omp.parallel {
//       hlfir.elemental {}
//
// in which case the hlfir.elemental must _not_ use the workshare lowering.
//
// The OpenMP standard says:
// For a parallel construct, the construct is a unit of work with respect to
// the workshare construct. The statements contained in the parallel
// construct are executed by a new thread team.
//
// TODO: Handle single, critical, etc. similarly; we need to think through
// the patterns and implement this function accordingly.
//
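// A more complete check might walk up the parent chain and stop at the
// first construct that forms a new unit of work. A minimal sketch (not
// implemented; the exact set of "unit of work" ops still needs to be
// determined):
//
//   for (Operation *parent = op->getParentOp(); parent;
//        parent = parent->getParentOp()) {
//     if (isa<omp::WorkshareOp>(parent))
//       return true;
//     if (isa<omp::ParallelOp, omp::SingleOp, omp::CriticalOp>(parent))
//       return false; // A new unit of work w.r.t. the workshare construct.
//   }
//   return false;
//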
return op->getParentOfType<omp::WorkshareOp>();
}
} // namespace flangomp

namespace {

struct SingleRegion {
Block::iterator begin, end;
};

static bool mustParallelizeOp(Operation *op) {
// TODO: As in shouldUseWorkshareLowering, we must be careful not to pick up
// a workshare_loop_wrapper nested in another omp.parallel, e.g.
//
// omp.parallel {
//   omp.workshare {
//     omp.parallel {
//       omp.workshare {
//         omp.workshare_loop_wrapper {}
return op
->walk(
[](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
.wasInterrupted();
}

static bool isSafeToParallelize(Operation *op) {
return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
isMemoryEffectFree(op);
}
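
// For example (IR shapes taken from the tests below):
//
//   %c42 = arith.constant 42 : index         // safe: memory-effect free
//   %d:2 = hlfir.declare %arg(%shape) ...    // safe: declare op
//   %tmp = fir.allocmem !fir.array<42xi32>   // not safe: must execute once
//
// Safe ops are replicated in every thread, while unsafe ops are emitted
// inside an omp.single and their results communicated via copyprivate.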

static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
fir::FirOpBuilder builder) {
mlir::ModuleOp module = builder.getModule();
auto rt = cast<fir::ReferenceType>(varType);
mlir::Type eleTy = rt.getEleTy();
std::string copyFuncName =
fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");

if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
return decl;
// Create the copyprivate copy function for this type.
mlir::OpBuilder::InsertionGuard guard(builder);
mlir::OpBuilder modBuilder(module.getBodyRegion());
llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
mlir::func::FuncOp funcOp =
modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
{loc, loc});
builder.setInsertionPointToStart(&funcOp.getRegion().back());

Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(0));
builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(1));

builder.create<mlir::func::ReturnOp>(loc);
return funcOp;
}
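
// For example, for a varType of !fir.ref<!fir.heap<!fir.array<42xi32>>> this
// generates a copy function like the one the tests below check for:
//
//   func.func private @_workshare_copy_heap_42xi32(
//       %arg0: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
//       %arg1: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
//     %0 = fir.load %arg0 : !fir.ref<!fir.heap<!fir.array<42xi32>>>
//     fir.store %0 to %arg1 : !fir.ref<!fir.heap<!fir.array<42xi32>>>
//     return
//   }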

static bool isUserOutsideSR(Operation *user, Operation *parentOp,
SingleRegion sr) {
while (user->getParentOp() != parentOp)
user = user->getParentOp();
return sr.begin->getBlock() != user->getBlock() ||
!(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user));
}

static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
Block *srBlock = sr.begin->getBlock();
Operation *parentOp = srBlock->getParentOp();

for (auto &use : v.getUses()) {
Operation *user = use.getOwner();
if (isUserOutsideSR(user, parentOp, sr))
return true;

// A user nested inside another op within the SR cannot expose its results
// outside the SR.
if (user->getBlock() != srBlock)
continue;

// Results of ops that are not safe to parallelize are stored and reloaded
// separately, so we do not need to recurse through them.
if (!isSafeToParallelize(user))
continue;

for (auto res : user->getResults())
if (isTransitivelyUsedOutside(res, sr))
return true;
}
return false;
}
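
// For example (hypothetical test ops), with the single region spanning the
// two definitions below, %outside is transitively used outside the region
// while %inside is not:
//
//   %outside = "test.test1"() : () -> i32
//   %inside = "test.test2"() : () -> i32
//   "test.use"(%inside) : (i32) -> ()        // use within the same region
//   omp.workshare_loop_wrapper {
//     "test.use"(%outside) : (i32) -> ()     // use outside the region
//   }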

static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
IRMapping &rootMapping, Location loc) {
OpBuilder rootBuilder(sourceRegion.getContext());
ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
OpBuilder copyFuncBuilder(m.getBodyRegion());
fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);

auto mapReloadedValue =
[&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
if (auto reloaded = rootMapping.lookupOrNull(v))
return nullptr;
Type ty = v.getType();
Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
rootMapping.map(v, reloaded);
return alloc;
};
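
// The alloca/store/load protocol built by mapReloadedValue looks like this
// (a sketch; names are illustrative, the copy function comes from
// createCopyFunc):
//
//   %tmp = fir.alloca i32        // hoisted in front of the omp.single
//   omp.single copyprivate(%tmp -> @_workshare_copy_i32 : !fir.ref<i32>) {
//     ...
//     fir.store %v to %tmp : !fir.ref<i32>
//   }
//   %v_reloaded = fir.load %tmp : !fir.ref<i32>   // visible to all threads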

auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
OpBuilder singleBuilder,
OpBuilder parallelBuilder) -> SmallVector<Value> {
IRMapping singleMapping = rootMapping;
SmallVector<Value> copyPrivate;

for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
if (isSafeToParallelize(&op)) {
singleBuilder.clone(op, singleMapping);
parallelBuilder.clone(op, rootMapping);
} else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
auto hoisted =
cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
rootMapping.map(&*alloca, &*hoisted);
rootMapping.map(alloca.getResult(), hoisted.getResult());
copyPrivate.push_back(hoisted);
} else {
singleBuilder.clone(op, singleMapping);
// Prepare reloaded values for results of operations that cannot be
// safely parallelized and which are used after the region `sr`
for (auto res : op.getResults()) {
if (isTransitivelyUsedOutside(res, sr)) {
auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
parallelBuilder, singleMapping);
if (alloc)
copyPrivate.push_back(alloc);
}
}
}
}
singleBuilder.create<omp::TerminatorOp>(loc);
return copyPrivate;
};

// TODO: Blocks need to be handled (cloned) in dominator-tree order.
for (Block &block : sourceRegion) {
rootBuilder.createBlock(
&targetRegion, {}, block.getArgumentTypes(),
llvm::map_to_vector(block.getArguments(),
[](BlockArgument arg) { return arg.getLoc(); }));
Operation *terminator = block.getTerminator();

SmallVector<std::variant<SingleRegion, Operation *>> regions;

auto it = block.begin();
auto getOneRegion = [&]() {
if (&*it == terminator)
return false;
if (mustParallelizeOp(&*it)) {
regions.push_back(&*it);
it++;
return true;
}
SingleRegion sr;
sr.begin = it;
while (&*it != terminator && !mustParallelizeOp(&*it))
it++;
sr.end = it;
assert(sr.begin != sr.end);
regions.push_back(sr);
return true;
};
while (getOneRegion())
;
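
// At this point a block of the form
//
//   <ops A> ; omp.workshare_loop_wrapper {...} ; <ops B> ; terminator
//
// has been partitioned into
//
//   regions = [SingleRegion(A), workshare_loop_wrapper, SingleRegion(B)]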

for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
bool isLast = i + 1 == regions.size();
if (std::holds_alternative<SingleRegion>(opOrSingle)) {
OpBuilder singleBuilder(sourceRegion.getContext());
Block *singleBlock = new Block();
singleBuilder.setInsertionPointToStart(singleBlock);

OpBuilder allocaBuilder(sourceRegion.getContext());
Block *allocaBlock = new Block();
allocaBuilder.setInsertionPointToStart(allocaBlock);

OpBuilder parallelBuilder(sourceRegion.getContext());
Block *parallelBlock = new Block();
parallelBuilder.setInsertionPointToStart(parallelBlock);

omp::SingleOperands singleOperands;
if (isLast)
singleOperands.nowait = rootBuilder.getUnitAttr();
singleOperands.copyprivateVars =
moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
singleBuilder, parallelBuilder);
for (auto var : singleOperands.copyprivateVars) {
Type ty;
if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
ty = firAlloca.getAllocatedType();
} else {
ty = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
}
mlir::func::FuncOp funcOp =
createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
}
omp::SingleOp singleOp =
rootBuilder.create<omp::SingleOp>(loc, singleOperands);
singleOp.getRegion().push_back(singleBlock);
rootBuilder.getInsertionBlock()->getOperations().splice(
rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
targetRegion.front().getOperations().splice(
singleOp->getIterator(), allocaBlock->getOperations());
delete allocaBlock;
delete parallelBlock;
} else {
auto op = std::get<Operation *>(opOrSingle);
if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
omp::WsloopOperands wsloopOperands;
if (isLast)
wsloopOperands.nowait = rootBuilder.getUnitAttr();
auto wsloop =
rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
rootBuilder.clone(*wslw, rootMapping));
wsloop.getRegion().takeBody(clonedWslw.getRegion());
clonedWslw->erase();
} else {
assert(mustParallelizeOp(op));
Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
for (auto [region, clonedRegion] :
llvm::zip(op->getRegions(), cloned->getRegions()))
parallelizeRegion(region, clonedRegion, rootMapping, loc);
}
}
}

rootBuilder.clone(*block.getTerminator(), rootMapping);
}
}

/// Lowers workshare to a sequence of single-thread regions and parallel loops
///
/// For example:
///
/// omp.workshare {
/// %a = fir.allocmem
/// omp.workshare_loop_wrapper {}
/// fir.call Assign %b %a
/// fir.freemem %a
/// }
///
/// becomes
///
/// omp.single {
/// %a = fir.allocmem
/// fir.store %a %tmp
/// }
/// %a_reloaded = fir.load %tmp
/// omp.workshare_loop_wrapper {}
/// omp.single {
/// fir.call Assign %b %a_reloaded
/// fir.freemem %a_reloaded
/// }
///
/// Note that we allocate temporary memory for values produced in omp.single
/// regions that need to be accessed by all threads in the closest omp.parallel.
void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
Location loc = wsOp->getLoc();
IRMapping rootMapping;

OpBuilder rootBuilder(wsOp);

// TODO: We would like something like scf.execute_region here, but it is not
// registered, so we use fir.if for now. fir.if does not seem to support
// multiple blocks, so this does not work for the multi-block case...
auto ifOp = rootBuilder.create<fir::IfOp>(
loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
ifOp.getThenRegion().front().erase();

parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);

Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
assert(isa<omp::TerminatorOp>(terminatorOp));
OpBuilder termBuilder(terminatorOp);

if (!wsOp.getNowait())
termBuilder.create<omp::BarrierOp>(loc);

termBuilder.create<fir::ResultOp>(loc, ValueRange());

terminatorOp->erase();
wsOp->erase();
}

class LowerWorksharePass
: public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
public:
void runOnOperation() override {
SmallPtrSet<Operation *, 8> parents;
getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
Operation *isolatedParent =
wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
parents.insert(isolatedParent);

lowerWorkshare(wsOp);
});

// Do folding
for (Operation *isolatedParent : parents) {
RewritePatternSet patterns(&getContext());
GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
if (failed(applyPatternsAndFoldGreedily(isolatedParent,
std::move(patterns), config))) {
emitError(isolatedParent->getLoc(), "error in lower workshare\n");
signalPassFailure();
}
}
}
};
} // namespace
58 changes: 58 additions & 0 deletions flang/test/HLFIR/bufferize-workshare.fir
@@ -0,0 +1,58 @@
// RUN: fir-opt --bufferize-hlfir %s | FileCheck %s

// CHECK-LABEL: func.func @simple(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
// CHECK: omp.parallel {
// CHECK: omp.workshare {
// CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: %[[VAL_7:.*]] = arith.constant true
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK: "omp.workshare_loop_wrapper"() ({
// CHECK: omp.loop_nest (%[[VAL_9:.*]]) : index = (%[[VAL_8]]) to (%[[VAL_1]]) inclusive step (%[[VAL_8]]) {
// CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_9]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_11]], %[[VAL_2]] : i32
// CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_13]] temporary_lhs : i32, !fir.ref<i32>
// CHECK: omp.yield
// CHECK: }
// CHECK: omp.terminator
// CHECK: }) : () -> ()
// CHECK: %[[VAL_14:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_7]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_6]]#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
// CHECK: hlfir.assign %[[VAL_6]]#0 to %[[VAL_4]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
// CHECK: fir.freemem %[[VAL_6]]#0 : !fir.heap<!fir.array<42xi32>>
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: return
// CHECK: }
func.func @simple(%arg: !fir.ref<!fir.array<42xi32>>) {
omp.parallel {
omp.workshare {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
%array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
^bb0(%i: index):
%ref = hlfir.designate %array#0 (%i) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
%val = fir.load %ref : !fir.ref<i32>
%sub = arith.subi %val, %c1_i32 : i32
hlfir.yield_element %sub : i32
}
hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
hlfir.destroy %elemental : !hlfir.expr<42xi32>
omp.terminator
}
omp.terminator
}
return
}
191 changes: 191 additions & 0 deletions flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -0,0 +1,191 @@
// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s

// Checks that the final omp.single has nowait.
func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
omp.parallel {
omp.workshare {
%c42 = arith.constant 42 : index
%c1_i32 = arith.constant 1 : i32
%0 = fir.shape %c42 : (index) -> !fir.shape<1>
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
%3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
%true = arith.constant true
%c1 = arith.constant 1 : index
"omp.workshare_loop_wrapper"() ({
omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
%7 = hlfir.designate %1#0 (%arg1) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
%8 = fir.load %7 : !fir.ref<i32>
%9 = arith.subi %8, %c1_i32 : i32
%10 = hlfir.designate %3#0 (%arg1) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
omp.yield
}
omp.terminator
}) : () -> ()
%4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
%5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
%6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
omp.terminator
}
omp.terminator
}
return
}

// -----

// Checks that the fir.alloca is hoisted out and copyprivate'd.
func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
omp.workshare {
%c1_i32 = arith.constant 1 : i32
%alloc = fir.alloca i32
fir.store %c1_i32 to %alloc : !fir.ref<i32>
%c42 = arith.constant 42 : index
%0 = fir.shape %c42 : (index) -> !fir.shape<1>
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
%2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
%3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
%true = arith.constant true
%c1 = arith.constant 1 : index
"omp.workshare_loop_wrapper"() ({
omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
%7 = hlfir.designate %1#0 (%arg1) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
%8 = fir.load %7 : !fir.ref<i32>
%ld = fir.load %alloc : !fir.ref<i32>
%n8 = arith.subi %8, %ld : i32
%9 = arith.subi %n8, %c1_i32 : i32
%10 = hlfir.designate %3#0 (%arg1) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
omp.yield
}
omp.terminator
}) : () -> ()
%4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
%5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
%6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
"test.test1"(%alloc) : (!fir.ref<i32>) -> ()
hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
omp.terminator
}
return
}

// CHECK-LABEL: func.func private @_workshare_copy_heap_42xi32(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @wsfunc(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_3:.*]] = arith.constant 42 : index
// CHECK: %[[VAL_4:.*]] = arith.constant true
// CHECK: omp.parallel {
// CHECK: fir.if %[[VAL_4]] {
// CHECK: %[[VAL_5:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
// CHECK: omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_6]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_8:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
// CHECK: fir.store %[[VAL_8]] to %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_6]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[VAL_10:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_10]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_12]](%[[VAL_10]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: omp.wsloop {
// CHECK: omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
// CHECK: %[[VAL_15:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_14]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
// CHECK: %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_2]] : i32
// CHECK: %[[VAL_18:.*]] = hlfir.designate %[[VAL_13]]#0 (%[[VAL_14]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
// CHECK: omp.yield
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.single nowait {
// CHECK: hlfir.assign %[[VAL_13]]#0 to %[[VAL_11]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
// CHECK: fir.freemem %[[VAL_13]]#0 : !fir.heap<!fir.array<42xi32>>
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.barrier
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func private @_workshare_copy_heap_42xi32(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func private @_workshare_copy_i32(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<i32>,
// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<i32>) {
// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
// CHECK: fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<i32>
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @wsfunc(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 42 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_4:.*]] = arith.constant true
// CHECK: fir.if %[[VAL_4]] {
// CHECK: %[[VAL_5:.*]] = fir.alloca i32
// CHECK: %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
// CHECK: omp.single copyprivate(%[[VAL_5]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_6]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
// CHECK: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i32>
// CHECK: %[[VAL_7:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_9:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
// CHECK: fir.store %[[VAL_9]] to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[VAL_11:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
// CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
// CHECK: %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
// CHECK: omp.wsloop {
// CHECK: omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
// CHECK: %[[VAL_16:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_15]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
// CHECK: %[[VAL_18:.*]] = fir.load %[[VAL_5]] : !fir.ref<i32>
// CHECK: %[[VAL_19:.*]] = arith.subi %[[VAL_17]], %[[VAL_18]] : i32
// CHECK: %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_3]] : i32
// CHECK: %[[VAL_21:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_15]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
// CHECK: hlfir.assign %[[VAL_20]] to %[[VAL_21]] temporary_lhs : i32, !fir.ref<i32>
// CHECK: omp.yield
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.single nowait {
// CHECK: "test.test1"(%[[VAL_5]]) : (!fir.ref<i32>) -> ()
// CHECK: hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
// CHECK: fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
// CHECK: omp.terminator
// CHECK: }
// CHECK: omp.barrier
// CHECK: }
// CHECK: return
// CHECK: }
21 changes: 21 additions & 0 deletions flang/test/Transforms/OpenMP/lower-workshare2.mlir
@@ -0,0 +1,21 @@
// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s

// CHECK-LABEL: func.func @nonowait
func.func @nonowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
// CHECK: omp.barrier
omp.workshare {
omp.terminator
}
return
}

// -----

// CHECK-LABEL: func.func @nowait
func.func @nowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
// CHECK-NOT: omp.barrier
omp.workshare nowait {
omp.terminator
}
return
}
74 changes: 74 additions & 0 deletions flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -0,0 +1,74 @@
// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s


// Tests that the correct values are stored and reloaded where needed.

func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
omp.parallel {
// CHECK: fir.alloca
// CHECK: fir.alloca
// CHECK: fir.alloca
// CHECK: fir.alloca
// CHECK: fir.alloca
// CHECK-NOT: fir.alloca
omp.workshare {

%t1 = "test.test1"() : () -> i32
// CHECK: %[[T1:.*]] = "test.test1"
// CHECK: fir.store %[[T1]]
%t2 = "test.test2"() : () -> i32
// CHECK: %[[T2:.*]] = "test.test2"
// CHECK: fir.store %[[T2]]
%t3 = "test.test3"() : () -> i32
// CHECK: %[[T3:.*]] = "test.test3"
// CHECK-NOT: fir.store %[[T3]]
%t4 = "test.test4"() : () -> i32
// CHECK: %[[T4:.*]] = "test.test4"
// CHECK: fir.store %[[T4]]
%t5 = "test.test5"() : () -> i32
// CHECK: %[[T5:.*]] = "test.test5"
// CHECK: fir.store %[[T5]]
%t6 = "test.test6"() : () -> i32
// CHECK: %[[T6:.*]] = "test.test6"
// CHECK-NOT: fir.store %[[T6]]


"test.test1"(%t1) : (i32) -> ()
"test.test1"(%t2) : (i32) -> ()
"test.test1"(%t3) : (i32) -> ()

%true = arith.constant true
fir.if %true {
"test.test2"(%t3) : (i32) -> ()
}

%c1_i32 = arith.constant 1 : i32

%t5_pure_use = arith.addi %t5, %c1_i32 : i32

%t6_mem_effect_use = "test.test8"(%t6) : (i32) -> i32
// CHECK: %[[T6_USE:.*]] = "test.test8"
// CHECK: fir.store %[[T6_USE]]

%c42 = arith.constant 42 : index
%c1 = arith.constant 1 : index
"omp.workshare_loop_wrapper"() ({
omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
"test.test10"(%t1) : (i32) -> ()
"test.test10"(%t5_pure_use) : (i32) -> ()
"test.test10"(%t6_mem_effect_use) : (i32) -> ()
omp.yield
}
omp.terminator
}) : () -> ()

"test.test10"(%t2) : (i32) -> ()
fir.if %true {
"test.test10"(%t4) : (i32) -> ()
}
omp.terminator
}
omp.terminator
}
return
}
1 change: 1 addition & 0 deletions flang/tools/bbc/CMakeLists.txt
@@ -25,6 +25,7 @@
FIRTransforms
FIRBuilder
HLFIRDialect
HLFIRTransforms
FlangOpenMPTransforms
${dialect_libs}
${extension_libs}
MLIRAffineToStandard
1 change: 1 addition & 0 deletions flang/tools/fir-opt/CMakeLists.txt
@@ -19,6 +19,7 @@
FIRCodeGen
FIRCodeGen
HLFIRDialect
HLFIRTransforms
FlangOpenMPTransforms
FIRAnalysis
${test_libs}
${dialect_libs}
2 changes: 2 additions & 0 deletions flang/tools/fir-opt/fir-opt.cpp
@@ -14,6 +14,7 @@
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "flang/Optimizer/OpenMP/Passes.h"
#include "flang/Optimizer/Support/InitFIR.h"
#include "flang/Optimizer/Transforms/Passes.h"

@@ -34,6 +35,7 @@
int main(int argc, char **argv) {
fir::registerOptCodeGenPasses();
fir::registerOptTransformPasses();
hlfir::registerHLFIRPasses();
flangomp::registerFlangOpenMPPasses();
#ifdef FLANG_INCLUDE_TESTS
fir::test::registerTestFIRAliasAnalysisPass();
mlir::registerSideEffectTestPasses();
1 change: 1 addition & 0 deletions flang/tools/tco/CMakeLists.txt
@@ -17,6 +17,7 @@
FIRBuilder
FIRBuilder
HLFIRDialect
HLFIRTransforms
FlangOpenMPTransforms
${dialect_libs}
${extension_libs}
MLIRIR
Expand Down