diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index 7519da844eebb..9397af8b8bd05 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, ? nullptr : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), reductionDeclSymbols), - procBindKindAttr); + procBindKindAttr, /*private_vars=*/llvm::SmallVector{}, + /*privatizers=*/nullptr); } static mlir::omp::SectionOp diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 024f43f7e7e3b..f907e21e9b4de 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -276,7 +276,9 @@ def ParallelOp : OpenMP_Op<"parallel", [ Variadic:$allocators_vars, Variadic:$reduction_vars, OptionalAttr:$reductions, - OptionalAttr:$proc_bind_val); + OptionalAttr:$proc_bind_val, + Variadic:$private_vars, + OptionalAttr:$privatizers); let regions = (region AnyRegion:$region); @@ -297,7 +299,9 @@ def ParallelOp : OpenMP_Op<"parallel", [ $allocators_vars, type($allocators_vars) ) `)` | `proc_bind` `(` custom($proc_bind_val) `)` - ) custom($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict + ) custom($region, $reduction_vars, type($reduction_vars), + $reductions, $private_vars, type($private_vars), + $privatizers) attr-dict }]; let hasVerifier = 1; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index ea5f31ee8c6aa..464a647564ace 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -450,7 +450,9 @@ struct ParallelOpLowering : public OpRewritePattern { /* allocators_vars = */ llvm::SmallVector{}, /* reduction_vars = */ llvm::SmallVector{}, /* reductions = */ ArrayAttr{}, - /* proc_bind_val = */ omp::ClauseProcBindKindAttr{}); + /* proc_bind_val = */ omp::ClauseProcBindKindAttr{}, + /* private_vars = */ ValueRange(), + /* privatizers = */ nullptr); { OpBuilder::InsertionGuard guard(rewriter); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 82e32da91aaee..c2b471ab96183 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op, // Parser, printer and verifier for ReductionVarList //===----------------------------------------------------------------------===// -ParseResult -parseReductionClause(OpAsmParser &parser, Region ®ion, - SmallVectorImpl &operands, - SmallVectorImpl &types, ArrayAttr &reductionSymbols, - SmallVectorImpl &privates) { - if (failed(parser.parseOptionalKeyword("reduction"))) - return failure(); - +ParseResult parseClauseWithRegionArgs( + OpAsmParser &parser, Region ®ion, + SmallVectorImpl &operands, + SmallVectorImpl &types, ArrayAttr &symbols, + SmallVectorImpl ®ionPrivateArgs) { SmallVector reductionVec; + unsigned regionArgOffset = regionPrivateArgs.size(); if (failed( parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { if (parser.parseAttribute(reductionVec.emplace_back()) || parser.parseOperand(operands.emplace_back()) || parser.parseArrow() || - parser.parseArgument(privates.emplace_back()) || + parser.parseArgument(regionPrivateArgs.emplace_back()) || parser.parseColonType(types.emplace_back())) return failure(); return success(); }))) return failure(); - for (auto [prv, type] : llvm::zip_equal(privates, types)) { + auto *argsBegin = regionPrivateArgs.begin(); + MutableArrayRef argsSubrange(argsBegin + regionArgOffset, + argsBegin + regionArgOffset + types.size()); + for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) { prv.type = type; } SmallVector reductions(reductionVec.begin(), reductionVec.end()); - reductionSymbols = ArrayAttr::get(parser.getContext(), reductions); + symbols = ArrayAttr::get(parser.getContext(), reductions); return success(); } -static void printReductionClause(OpAsmPrinter &p, Operation *op, - ValueRange reductionArgs, ValueRange operands, - TypeRange types, ArrayAttr reductionSymbols) { - p << "reduction("; +static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op, + ValueRange argsSubrange, + StringRef clauseName, ValueRange operands, + TypeRange types, ArrayAttr symbols) { + p << clauseName << "("; llvm::interleaveComma( - llvm::zip_equal(reductionSymbols, operands, reductionArgs, types), p, - [&p](auto t) { + llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) { auto [sym, op, arg, type] = t; p << sym << " " << op << " -> " << arg << " : " << type; }); p << ") "; } -static ParseResult -parseParallelRegion(OpAsmParser &parser, Region ®ion, - SmallVectorImpl &operands, - SmallVectorImpl &types, ArrayAttr &reductionSymbols) { +static ParseResult parseParallelRegion( + OpAsmParser &parser, Region ®ion, + SmallVectorImpl &reductionVarOperands, + SmallVectorImpl &reductionVarTypes, ArrayAttr &reductionSymbols, + llvm::SmallVectorImpl &privateVarOperands, + llvm::SmallVectorImpl &privateVarsTypes, + ArrayAttr &privatizerSymbols) { + llvm::SmallVector regionPrivateArgs; - llvm::SmallVector privates; - if (succeeded(parseReductionClause(parser, region, operands, types, - reductionSymbols, privates))) - return parser.parseRegion(region, privates); + if (succeeded(parser.parseOptionalKeyword("reduction"))) { + if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands, + reductionVarTypes, reductionSymbols, + regionPrivateArgs))) + return failure(); + } - return parser.parseRegion(region); + if (succeeded(parser.parseOptionalKeyword("private"))) { + if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands, + privateVarsTypes, privatizerSymbols, + regionPrivateArgs))) + return failure(); + } + + return parser.parseRegion(region, regionPrivateArgs); } static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion, - ValueRange operands, TypeRange types, - ArrayAttr reductionSymbols) { - if (reductionSymbols) - printReductionClause(p, op, region.front().getArguments(), operands, types, - reductionSymbols); + ValueRange reductionVarOperands, + TypeRange reductionVarTypes, + ArrayAttr reductionSymbols, + ValueRange privateVarOperands, + TypeRange privateVarTypes, + ArrayAttr privatizerSymbols) { + if (reductionSymbols) { + auto *argsBegin = region.front().getArguments().begin(); + MutableArrayRef argsSubrange(argsBegin, + argsBegin + reductionVarTypes.size()); + printClauseWithRegionArgs(p, op, argsSubrange, "reduction", + reductionVarOperands, reductionVarTypes, + reductionSymbols); + } + + if (privatizerSymbols) { + auto *argsBegin = region.front().getArguments().begin(); + MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(), + argsBegin + reductionVarOperands.size() + + privateVarTypes.size()); + printClauseWithRegionArgs(p, op, argsSubrange, "private", + privateVarOperands, privateVarTypes, + privatizerSymbols); + } + p.printRegion(region, /*printEntryBlockArgs=*/false); } @@ -1174,14 +1208,64 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr, - /*proc_bind_val=*/nullptr); + /*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(), + /*privatizers=*/nullptr); state.addAttributes(attributes); } +template +static LogicalResult verifyPrivateVarList(OpType &op) { + auto privateVars = op.getPrivateVars(); + auto privatizers = op.getPrivatizersAttr(); + + if (privateVars.empty() && (privatizers == nullptr || privatizers.empty())) + return success(); + + auto numPrivateVars = privateVars.size(); + auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size(); + + if (numPrivateVars != numPrivatizers) + return op.emitError() << "inconsistent number of private variables and " + "privatizer op symbols, private vars: " + << numPrivateVars + << " vs. privatizer op symbols: " << numPrivatizers; + + for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) { + Type varType = std::get<0>(privateVarInfo).getType(); + SymbolRefAttr privatizerSym = + std::get<1>(privateVarInfo).template cast(); + PrivateClauseOp privatizerOp = + SymbolTable::lookupNearestSymbolFrom(op, + privatizerSym); + + if (privatizerOp == nullptr) + return op.emitError() << "failed to lookup privatizer op with symbol: '" + << privatizerSym << "'"; + + Type privatizerType = privatizerOp.getType(); + + if (varType != privatizerType) + return op.emitError() + << "type mismatch between a " + << (privatizerOp.getDataSharingType() == + DataSharingClauseType::Private + ? "private" + : "firstprivate") + << " variable and its privatizer op, var type: " << varType + << " vs. privatizer op type: " << privatizerType; + } + + return success(); +} + LogicalResult ParallelOp::verify() { if (getAllocateVars().size() != getAllocatorsVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); + + if (failed(verifyPrivateVarList(*this))) + return failure(); + return verifyReductionVarList(*this, getReductions(), getReductionVars()); } @@ -1279,9 +1363,10 @@ parseWsLoop(OpAsmParser &parser, Region ®ion, // Parse an optional reduction clause llvm::SmallVector privates; - bool hasReduction = succeeded( - parseReductionClause(parser, region, reductionOperands, reductionTypes, - reductionSymbols, privates)); + bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) && + succeeded(parseClauseWithRegionArgs( + parser, region, reductionOperands, reductionTypes, + reductionSymbols, privates)); if (parser.parseKeyword("for")) return failure(); @@ -1328,8 +1413,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region ®ion, if (reductionSymbols) { auto reductionArgs = region.front().getArguments().drop_front(loopVarTypes.size()); - printReductionClause(p, op, reductionArgs, reductionOperands, - reductionTypes, reductionSymbols); + printClauseWithRegionArgs(p, op, reductionArgs, "reduction", + reductionOperands, reductionTypes, + reductionSymbols); } p << " for "; diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 7965d6dc28420..d9261b89e24e3 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1865,3 +1865,59 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc { ^bb0(%arg0: f32): omp.yield(%arg0 : f32) } + +// ----- + +func.func @private_type_mismatch(%arg0: index) { +// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}} + omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) { + omp.terminator + } + + return +} + +omp.private {type = private} @var1.privatizer : !llvm.ptr alloc { +^bb0(%arg0: !llvm.ptr): + omp.yield(%arg0 : !llvm.ptr) +} + +// ----- + +func.func @firstprivate_type_mismatch(%arg0: index) { + // expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}} + omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) { + omp.terminator + } + + return +} + +omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc { +^bb0(%arg0: !llvm.ptr): + omp.yield(%arg0 : !llvm.ptr) +} copy { +^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + omp.yield(%arg0 : !llvm.ptr) +} + +// ----- + +func.func @undefined_privatizer(%arg0: index) { + // expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}} + omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) { + omp.terminator + } + + return +} + +// ----- +func.func @undefined_privatizer(%arg0: !llvm.ptr) { + // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}} + "omp.parallel"(%arg0) <{operandSegmentSizes = array, privatizers = [@x.privatizer, @y.privatizer]}> ({ + ^bb0(%arg2: !llvm.ptr): + omp.terminator + }) : (!llvm.ptr) -> () + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index d92a554caf77c..211ff0ff9272e 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -59,7 +59,7 @@ func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i // CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel"(%num_threads, %data_var, %data_var) ({ omp.terminator - }) {operandSegmentSizes = array} : (i32, memref, memref) -> () + }) {operandSegmentSizes = array} : (i32, memref, memref) -> () // CHECK: omp.barrier omp.barrier @@ -68,22 +68,22 @@ func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i // CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel"(%if_cond, %data_var, %data_var) ({ omp.terminator - }) {operandSegmentSizes = array} : (i1, memref, memref) -> () + }) {operandSegmentSizes = array} : (i1, memref, memref) -> () // test without allocate // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) "omp.parallel"(%if_cond, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array} : (i1, i32) -> () + }) {operandSegmentSizes = array} : (i1, i32) -> () omp.terminator - }) {operandSegmentSizes = array, proc_bind_val = #omp} : (i1, i32, memref, memref) -> () + }) {operandSegmentSizes = array, proc_bind_val = #omp} : (i1, i32, memref, memref) -> () // test with multiple parameters for single variadic argument // CHECK: omp.parallel allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel" (%data_var, %data_var) ({ omp.terminator - }) {operandSegmentSizes = array} : (memref, memref) -> () + }) {operandSegmentSizes = array} : (memref, memref) -> () return } @@ -2231,3 +2231,63 @@ func.func @omp_target_enter_update_exit_data_depend(%a: memref, %b: memre omp.target_exit_data map_entries(%map_c : memref) depend(taskdependin -> %c : memref) return } + +// CHECK-LABEL: parallel_op_privatizers +// CHECK-SAME: (%[[ARG0:[^[:space:]]+]]: !llvm.ptr, %[[ARG1:[^[:space:]]+]]: !llvm.ptr) +func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) { + // CHECK: omp.parallel private( + // CHECK-SAME: @x.privatizer %[[ARG0]] -> %[[ARG0_PRIV:[^[:space:]]+]] : !llvm.ptr, + // CHECK-SAME: @y.privatizer %[[ARG1]] -> %[[ARG1_PRIV:[^[:space:]]+]] : !llvm.ptr) + omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr) { + // CHECK: llvm.load %[[ARG0_PRIV]] + %0 = llvm.load %arg2 : !llvm.ptr -> i32 + // CHECK: llvm.load %[[ARG1_PRIV]] + %1 = llvm.load %arg3 : !llvm.ptr -> i32 + omp.terminator + } + return +} + +// CHECK-LABEL: omp.private {type = private} @x.privatizer : !llvm.ptr alloc { +omp.private {type = private} @x.privatizer : !llvm.ptr alloc { +// CHECK: ^bb0(%{{.*}}: {{.*}}): +^bb0(%arg0: !llvm.ptr): + omp.yield(%arg0 : !llvm.ptr) +} + +// CHECK-LABEL: omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc { +omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc { +// CHECK: ^bb0(%{{.*}}: {{.*}}): +^bb0(%arg0: !llvm.ptr): + omp.yield(%arg0 : !llvm.ptr) +// CHECK: } copy { +} copy { +// CHECK: ^bb0(%{{.*}}: {{.*}}, %{{.*}}: {{.*}}): +^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + omp.yield(%arg0 : !llvm.ptr) +} + +// CHECK-LABEL: parallel_op_reduction_and_private +func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !llvm.ptr, %reduc_var: !llvm.ptr, %reduc_var2: !llvm.ptr) { + // CHECK: omp.parallel + // CHECK-SAME: reduction( + // CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr, + // CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr) + // + // CHECK-SAME: private( + // CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]] : !llvm.ptr, + // CHECK-SAME: @y.privatizer %[[PRIV_VAR2:[^[:space:]]+]] -> %[[PRIV_ARG2:[^[:space:]]+]] : !llvm.ptr) + omp.parallel reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr) + private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr) { + // CHECK: llvm.load %[[PRIV_ARG]] + %0 = llvm.load %priv_arg : !llvm.ptr -> f32 + // CHECK: llvm.load %[[PRIV_ARG2]] + %1 = llvm.load %priv_arg2 : !llvm.ptr -> f32 + // CHECK: llvm.load %[[REDUC_ARG]] + %2 = llvm.load %reduc_arg : !llvm.ptr -> f32 + // CHECK: llvm.load %[[REDUC_ARG2]] + %3 = llvm.load %reduc_arg2 : !llvm.ptr -> f32 + omp.terminator + } + return +} diff --git a/mlir/test/Dialect/OpenMP/roundtrip.mlir b/mlir/test/Dialect/OpenMP/roundtrip.mlir deleted file mode 100644 index 2553442638ee8..0000000000000 --- a/mlir/test/Dialect/OpenMP/roundtrip.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s - -// CHECK: omp.private {type = private} @x.privatizer : !llvm.ptr alloc { -omp.private {type = private} @x.privatizer : !llvm.ptr alloc { -// CHECK: ^bb0(%arg0: {{.*}}): -^bb0(%arg0: !llvm.ptr): - omp.yield(%arg0 : !llvm.ptr) -} - -// CHECK: omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc { -omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc { -// CHECK: ^bb0(%arg0: {{.*}}): -^bb0(%arg0: !llvm.ptr): - omp.yield(%arg0 : !llvm.ptr) -// CHECK: } copy { -} copy { -// CHECK: ^bb0(%arg0: {{.*}}, %arg1: {{.*}}): -^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): - omp.yield(%arg0 : !llvm.ptr) -} -