diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index ad4cffc707535..448eb1a14fa9f 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -2607,6 +2607,7 @@ genSingleOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpClauseList &beginClauseList, const Fortran::parser::OmpClauseList &endClauseList) { llvm::SmallVector allocateOperands, allocatorOperands; + llvm::SmallVector copyPrivateVars; mlir::UnitAttr nowaitAttr; ClauseProcessor cp(converter, beginClauseList); @@ -2620,7 +2621,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter, OpWithBodyGenInfo(converter, currentLocation, eval) .setGenNested(genNested) .setClauses(&beginClauseList), - allocateOperands, allocatorOperands, nowaitAttr); + allocateOperands, allocatorOperands, copyPrivateVars, + /*copyPrivateFuncs=*/nullptr, nowaitAttr); } static mlir::omp::TaskOp diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ca36350548577..c3985ac4dfe39 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -378,10 +378,16 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> { master thread), in the context of its implicit task. The other threads in the team, which do not execute the block, wait at an implicit barrier at the end of the single construct unless a nowait clause is specified. + + If copyprivate variables and functions are specified, then each thread + variable is updated with the variable value of the thread that executed + the single region, using the specified copy functions. }]; let arguments = (ins Variadic:$allocate_vars, Variadic:$allocators_vars, + Variadic:$copyprivate_vars, + OptionalAttr:$copyprivate_funcs, UnitAttr:$nowait); let regions = (region AnyRegion:$region); @@ -393,6 +399,10 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> { $allocators_vars, type($allocators_vars) ) `)` |`nowait` $nowait + |`copyprivate` `(` + custom( + $copyprivate_vars, type($copyprivate_vars), $copyprivate_funcs + ) `)` ) $region attr-dict }]; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 381f17d080419..87d198a6ac350 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -505,6 +505,110 @@ static LogicalResult verifyReductionVarList(Operation *op, return success(); } +//===----------------------------------------------------------------------===// +// Parser, printer and verifier for CopyPrivateVarList +//===----------------------------------------------------------------------===// + +/// copyprivate-entry-list ::= copyprivate-entry +/// | copyprivate-entry-list `,` copyprivate-entry +/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type +static ParseResult parseCopyPrivateVarList( + OpAsmParser &parser, + SmallVectorImpl &operands, + SmallVectorImpl &types, ArrayAttr ©PrivateSymbols) { + SmallVector copyPrivateFuncsVec; + if (failed(parser.parseCommaSeparatedList([&]() { + if (parser.parseOperand(operands.emplace_back()) || + parser.parseArrow() || + parser.parseAttribute(copyPrivateFuncsVec.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + SmallVector copyPrivateFuncs(copyPrivateFuncsVec.begin(), + copyPrivateFuncsVec.end()); + copyPrivateSymbols = ArrayAttr::get(parser.getContext(), copyPrivateFuncs); + return success(); +} + +/// Print CopyPrivate clause +static void printCopyPrivateVarList(OpAsmPrinter &p, Operation *op, + OperandRange copyPrivateVars, + TypeRange copyPrivateTypes, + std::optional copyPrivateFuncs) { + if (!copyPrivateFuncs.has_value()) + return; + llvm::interleaveComma( + llvm::zip(copyPrivateVars, *copyPrivateFuncs, copyPrivateTypes), p, + [&](const auto &args) { + p << std::get<0>(args) << " -> " << std::get<1>(args) << " : " + << std::get<2>(args); + }); +} + +/// Verifies CopyPrivate Clause +static LogicalResult +verifyCopyPrivateVarList(Operation *op, OperandRange copyPrivateVars, + std::optional copyPrivateFuncs) { + size_t copyPrivateFuncsSize = + copyPrivateFuncs.has_value() ? copyPrivateFuncs->size() : 0; + if (copyPrivateFuncsSize != copyPrivateVars.size()) + return op->emitOpError() << "inconsistent number of copyPrivate vars (= " + << copyPrivateVars.size() + << ") and functions (= " << copyPrivateFuncsSize + << "), both must be equal"; + if (!copyPrivateFuncs.has_value()) + return success(); + + for (auto copyPrivateVarAndFunc : + llvm::zip(copyPrivateVars, *copyPrivateFuncs)) { + auto symbolRef = + llvm::cast(std::get<1>(copyPrivateVarAndFunc)); + std::optional> + funcOp; + if (mlir::func::FuncOp mlirFuncOp = + SymbolTable::lookupNearestSymbolFrom(op, + symbolRef)) + funcOp = mlirFuncOp; + else if (mlir::LLVM::LLVMFuncOp llvmFuncOp = + SymbolTable::lookupNearestSymbolFrom( + op, symbolRef)) + funcOp = llvmFuncOp; + + auto getNumArguments = [&] { + return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp); + }; + + auto getArgumentType = [&](unsigned i) { + return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; }, + *funcOp); + }; + + if (!funcOp) + return op->emitOpError() << "expected symbol reference " << symbolRef + << " to point to a copy function"; + + if (getNumArguments() != 2) + return op->emitOpError() + << "expected copy function " << symbolRef << " to have 2 operands"; + + Type argTy = getArgumentType(0); + if (argTy != getArgumentType(1)) + return op->emitOpError() << "expected copy function " << symbolRef + << " arguments to have the same type"; + + Type varType = std::get<0>(copyPrivateVarAndFunc).getType(); + if (argTy != varType) + return op->emitOpError() + << "expected copy function arguments' type (" << argTy + << ") to be the same as copyprivate variable's type (" << varType + << ")"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // Parser, printer and verifier for DependVarList //===----------------------------------------------------------------------===// @@ -1072,7 +1176,8 @@ LogicalResult SingleOp::verify() { return emitError( "expected equal sizes for allocate and allocator variables"); - return success(); + return verifyCopyPrivateVarList(*this, getCopyprivateVars(), + getCopyprivateFuncs()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 812b79e35595f..8f66af4b623fc 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1284,7 +1284,63 @@ func.func @omp_single(%data_var : memref) -> () { // expected-error @below {{expected equal sizes for allocate and allocator variables}} "omp.single" (%data_var) ({ omp.barrier - }) {operandSegmentSizes = array} : (memref) -> () + }) {operandSegmentSizes = array} : (memref) -> () + return +} + +// ----- + +func.func @omp_single_copyprivate(%data_var : memref) -> () { + // expected-error @below {{inconsistent number of copyPrivate vars (= 1) and functions (= 0), both must be equal}} + "omp.single" (%data_var) ({ + omp.barrier + }) {operandSegmentSizes = array} : (memref) -> () + return +} + +// ----- + +func.func @omp_single_copyprivate(%data_var : memref) -> () { + // expected-error @below {{expected symbol reference @copy_func to point to a copy function}} + omp.single copyprivate(%data_var -> @copy_func : memref) { + omp.barrier + } + return +} + +// ----- + +func.func private @copy_func(memref) + +func.func @omp_single_copyprivate(%data_var : memref) -> () { + // expected-error @below {{expected copy function @copy_func to have 2 operands}} + omp.single copyprivate(%data_var -> @copy_func : memref) { + omp.barrier + } + return +} + +// ----- + +func.func private @copy_func(memref, memref) + +func.func @omp_single_copyprivate(%data_var : memref) -> () { + // expected-error @below {{expected copy function @copy_func arguments to have the same type}} + omp.single copyprivate(%data_var -> @copy_func : memref) { + omp.barrier + } + return +} + +// ----- + +func.func private @copy_func(memref, memref) + +func.func @omp_single_copyprivate(%data_var : memref) -> () { + // expected-error @below {{expected copy function arguments' type ('memref') to be the same as copyprivate variable's type ('memref')}} + omp.single copyprivate(%data_var -> @copy_func : memref) { + omp.barrier + } return } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 65a704d18107b..172e358a4237d 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1607,6 +1607,23 @@ func.func @omp_single_multiple_blocks() { return } +func.func private @copy_i32(memref, memref) + +// CHECK-LABEL: func @omp_single_copyprivate +func.func @omp_single_copyprivate(%data_var: memref) { + omp.parallel { + // CHECK: omp.single copyprivate(%{{.*}} -> @copy_i32 : memref) { + omp.single copyprivate(%data_var -> @copy_i32 : memref) { + "test.payload"() : () -> () + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.terminator + omp.terminator + } + return +} + // CHECK-LABEL: @omp_task // CHECK-SAME: (%[[bool_var:.*]]: i1, %[[i64_var:.*]]: i64, %[[i32_var:.*]]: i32, %[[data_var:.*]]: memref) func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memref) {