diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index cf56efb632d0f2..83acaba367ffd2 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -1806,6 +1806,7 @@ CHECK_SIMPLE_CLAUSE(MemoryOrder, OMPC_memory_order) CHECK_SIMPLE_CLAUSE(Bind, OMPC_bind) CHECK_SIMPLE_CLAUSE(Align, OMPC_align) CHECK_SIMPLE_CLAUSE(Compare, OMPC_compare) +CHECK_SIMPLE_CLAUSE(CancellationConstructType, OMPC_cancellation_construct_type) CHECK_REQ_SCALAR_INT_CLAUSE(Grainsize, OMPC_grainsize) CHECK_REQ_SCALAR_INT_CLAUSE(NumTasks, OMPC_num_tasks) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index e5a1dd3931247e..9575f6df6caeab 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -163,6 +163,25 @@ def OMPC_MemoryOrder : Clause<"memory_order"> { ]; } +def OMP_CANCELLATION_CONSTRUCT_Parallel : ClauseVal<"parallel", 1, 1> {} +def OMP_CANCELLATION_CONSTRUCT_Loop : ClauseVal<"loop", 2, 1> {} +def OMP_CANCELLATION_CONSTRUCT_Sections : ClauseVal<"sections", 3, 1> {} +def OMP_CANCELLATION_CONSTRUCT_Taskgroup : ClauseVal<"taskgroup", 4, 1> {} +def OMP_CANCELLATION_CONSTRUCT_None : ClauseVal<"none", 5, 0> { + let isDefault = 1; +} + +def OMPC_CancellationConstructType : Clause<"cancellation_construct_type"> { + let enumClauseValue = "CancellationConstructType"; + let allowedClauseValues = [ + OMP_CANCELLATION_CONSTRUCT_Parallel, + OMP_CANCELLATION_CONSTRUCT_Loop, + OMP_CANCELLATION_CONSTRUCT_Sections, + OMP_CANCELLATION_CONSTRUCT_Taskgroup, + OMP_CANCELLATION_CONSTRUCT_None + ]; +} + def OMPC_Ordered : Clause<"ordered"> { let clangClass = "OMPOrderedClause"; let flangClass = "ScalarIntConstantExpr"; diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index bc5a7ff89783f7..e69891ec9abf31 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -998,6 +998,40 @@ def ThreadprivateOp : OpenMP_Op<"threadprivate"> { }]; } +//===----------------------------------------------------------------------===// +// 2.18.1 Cancel Construct +//===----------------------------------------------------------------------===// +def CancelOp : OpenMP_Op<"cancel"> { + let summary = "cancel directive"; + let description = [{ + The cancel construct activates cancellation of the innermost enclosing + region of the type specified. + }]; + let arguments = (ins CancellationConstructTypeAttr:$cancellation_construct_type_val, + Optional:$if_expr); + let assemblyFormat = [{ `cancellation_construct_type` `(` + custom($cancellation_construct_type_val) `)` + ( `if` `(` $if_expr^ `)` )? attr-dict}]; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// 2.18.2 Cancellation Point Construct +//===----------------------------------------------------------------------===// +def CancellationPointOp : OpenMP_Op<"cancellationpoint"> { + let summary = "cancellation point directive"; + let description = [{ + The cancellation point construct introduces a user-defined cancellation + point at which implicit or explicit tasks check if cancellation of the + innermost enclosing region of the type specified has been activated. + }]; + let arguments = (ins CancellationConstructTypeAttr:$cancellation_construct_type_val); + let assemblyFormat = [{ `cancellation_construct_type` `(` + custom($cancellation_construct_type_val) `)` + attr-dict}]; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // 2.19.5.7 declare reduction Directive //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index fa2becae7e6374..5540eec8f3005b 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -950,6 +950,80 @@ LogicalResult AtomicCaptureOp::verifyRegions() { return success(); } +//===----------------------------------------------------------------------===// +// Verifier for CancelOp +//===----------------------------------------------------------------------===// + +LogicalResult CancelOp::verify() { + ClauseCancellationConstructType cct = cancellation_construct_type_val(); + Operation *parentOp = (*this)->getParentOp(); + + if (!parentOp) { + return emitOpError() << "must be used within a region supporting " + "cancel directive"; + } + + if ((cct == ClauseCancellationConstructType::Parallel) && + !isa(parentOp)) { + return emitOpError() << "cancel parallel must appear " + << "inside a parallel region"; + } else if (cct == ClauseCancellationConstructType::Loop) { + if (!isa(parentOp)) { + return emitOpError() << "cancel loop must appear " + << "inside a worksharing-loop region"; + } else { + if (cast(parentOp).nowaitAttr()) { + return emitError() << "A worksharing construct that is canceled " + << "must not have a nowait clause"; + } else if (cast(parentOp).ordered_valAttr()) { + return emitError() << "A worksharing construct that is canceled " + << "must not have an ordered clause"; + } + } + } else if (cct == ClauseCancellationConstructType::Sections) { + if (!(isa(parentOp) || isa(parentOp))) { + return emitOpError() << "cancel sections must appear " + << "inside a sections region"; + } + if (parentOp->getParentOp() && isa(parentOp->getParentOp()) && + cast(parentOp->getParentOp()).nowaitAttr()) { + return emitError() << "A sections construct that is canceled " + << "must not have a nowait clause"; + } + } + // TODO : Add more when we support taskgroup. + return success(); +} +//===----------------------------------------------------------------------===// +// Verifier for CancelOp +//===----------------------------------------------------------------------===// + +LogicalResult CancellationPointOp::verify() { + ClauseCancellationConstructType cct = cancellation_construct_type_val(); + Operation *parentOp = (*this)->getParentOp(); + + if (!parentOp) { + return emitOpError() << "must be used within a region supporting " + "cancellation point directive"; + } + + if ((cct == ClauseCancellationConstructType::Parallel) && + !(isa(parentOp))) { + return emitOpError() << "cancellation point parallel must appear " + << "inside a parallel region"; + } else if ((cct == ClauseCancellationConstructType::Loop) && + !isa(parentOp)) { + return emitOpError() << "cancellation point loop must appear " + << "inside a worksharing-loop region"; + } else if ((cct == ClauseCancellationConstructType::Sections) && + !(isa(parentOp) || isa(parentOp))) { + return emitOpError() << "cancellation point sections must appear " + << "inside a sections region"; + } + // TODO : Add more when we support taskgroup. + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index e88c7c4b705972..61e44b04593bd0 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1143,3 +1143,116 @@ func.func @omp_task(%mem: memref<1xf32>) { } return } + +// ----- + +func @omp_cancel() { + omp.sections { + // expected-error @below {{cancel parallel must appear inside a parallel region}} + omp.cancel cancellation_construct_type(parallel) + // CHECK: omp.terminator + omp.terminator + } + return +} + +// ----- + +func @omp_cancel1() { + omp.parallel { + // expected-error @below {{cancel sections must appear inside a sections region}} + omp.cancel cancellation_construct_type(sections) + // CHECK: omp.terminator + omp.terminator + } + return +} + +// ----- + +func @omp_cancel2() { + omp.sections { + // expected-error @below {{cancel loop must appear inside a worksharing-loop region}} + omp.cancel cancellation_construct_type(loop) + // CHECK: omp.terminator + omp.terminator + } + return +} + +// ----- + +func @omp_cancel3(%arg1 : i32, %arg2 : i32, %arg3 : i32) -> () { + omp.wsloop nowait + for (%0) : i32 = (%arg1) to (%arg2) step (%arg3) { + // expected-error @below {{A worksharing construct that is canceled must not have a nowait clause}} + omp.cancel cancellation_construct_type(loop) + // CHECK: omp.terminator + omp.terminator + } + return +} + +// ----- + +func @omp_cancel4(%arg1 : i32, %arg2 : i32, %arg3 : i32) -> () { + omp.wsloop ordered(1) + for (%0) : i32 = (%arg1) to (%arg2) step (%arg3) { + // expected-error @below {{A worksharing construct that is canceled must not have an ordered clause}} + omp.cancel cancellation_construct_type(loop) + // CHECK: omp.terminator + omp.terminator + } + return +} + +// ----- + +func @omp_cancel5() -> () { + omp.sections nowait { + omp.section { + // expected-error @below {{A sections construct that is canceled must not have a nowait clause}} + omp.cancel cancellation_construct_type(sections) + omp.terminator + } + // CHECK: omp.terminator + omp.terminator + } + return +} + +// ----- + +func @omp_cancellationpoint() { + omp.sections { + // expected-error @below {{cancellation point parallel must appear inside a parallel region}} + omp.cancellationpoint cancellation_construct_type(parallel) + // CHECK: omp.terminator + omp.terminator + } + return +} + +// ----- + +func @omp_cancellationpoint1() { + omp.parallel { + // expected-error @below {{cancellation point sections must appear inside a sections region}} + omp.cancellationpoint cancellation_construct_type(sections) + // CHECK: omp.terminator + omp.terminator + } + return +} + +// ----- + +func @omp_cancellationpoint2() { + omp.sections { + // expected-error @below {{cancellation point loop must appear inside a worksharing-loop region}} + omp.cancellationpoint cancellation_construct_type(loop) + // CHECK: omp.terminator + omp.terminator + } + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 15ec6f796b8809..9dd76c4dfc2ac0 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1276,3 +1276,77 @@ func.func @omp_threadprivate() { } llvm.mlir.global internal @_QFsubEx() : i32 + +func @omp_cancel_parallel(%if_cond : i1) -> () { + // Test with optional operand; if_expr. + omp.parallel { + // CHECK: omp.cancel cancellation_construct_type(parallel) if(%{{.*}}) + omp.cancel cancellation_construct_type(parallel) if(%if_cond) + // CHECK: omp.terminator + omp.terminator + } + return +} + +func @omp_cancel_wsloop(%lb : index, %ub : index, %step : index) { + omp.wsloop + for (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.cancel cancellation_construct_type(loop) + omp.cancel cancellation_construct_type(loop) + // CHECK: omp.terminator + omp.terminator + } + return +} + +func @omp_cancel_sections() -> () { + omp.sections { + omp.section { + // CHECK: omp.cancel cancellation_construct_type(sections) + omp.cancel cancellation_construct_type(sections) + omp.terminator + } + // CHECK: omp.terminator + omp.terminator + } + return +} + +func @omp_cancellationpoint_parallel() -> () { + omp.parallel { + // CHECK: omp.cancellationpoint cancellation_construct_type(parallel) + omp.cancellationpoint cancellation_construct_type(parallel) + // CHECK: omp.cancel cancellation_construct_type(parallel) + omp.cancel cancellation_construct_type(parallel) + omp.terminator + } + return +} + +func @omp_cancellationpoint_wsloop(%lb : index, %ub : index, %step : index) { + omp.wsloop + for (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.cancellationpoint cancellation_construct_type(loop) + omp.cancellationpoint cancellation_construct_type(loop) + // CHECK: omp.cancel cancellation_construct_type(loop) + omp.cancel cancellation_construct_type(loop) + // CHECK: omp.terminator + omp.terminator + } + return +} + +func @omp_cancellationpoint_sections() -> () { + omp.sections { + omp.section { + // CHECK: omp.cancellationpoint cancellation_construct_type(sections) + omp.cancellationpoint cancellation_construct_type(sections) + // CHECK: omp.cancel cancellation_construct_type(sections) + omp.cancel cancellation_construct_type(sections) + omp.terminator + } + // CHECK: omp.terminator + omp.terminator + } + return +}