diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index ec2c880437f33..51458b2a281d5 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1726,6 +1726,24 @@ bool ClauseProcessor::processUniform( }); } +bool ClauseProcessor::processInbranch( + mlir::omp::InbranchClauseOps &result) const { + if (findUniqueClause()) { + result.inbranch = converter.getFirOpBuilder().getUnitAttr(); + return true; + } + return false; +} + +bool ClauseProcessor::processNotinbranch( + mlir::omp::NotinbranchClauseOps &result) const { + if (findUniqueClause()) { + result.notinbranch = converter.getFirOpBuilder().getUnitAttr(); + return true; + } + return false; +} + } // namespace omp } // namespace lower } // namespace Fortran diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 063da68fb5702..ba1764ce46821 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -86,6 +86,7 @@ class ClauseProcessor { mlir::omp::HasDeviceAddrClauseOps &result, llvm::SmallVectorImpl &hasDeviceSyms) const; bool processHint(mlir::omp::HintClauseOps &result) const; + bool processInbranch(mlir::omp::InbranchClauseOps &result) const; bool processInclusive(mlir::Location currentLocation, mlir::omp::InclusiveClauseOps &result) const; bool processInitializer( @@ -93,6 +94,7 @@ class ClauseProcessor { ReductionProcessor::GenInitValueCBTy &genInitValueCB) const; bool processMergeable(mlir::omp::MergeableClauseOps &result) const; bool processNogroup(mlir::omp::NogroupClauseOps &result) const; + bool processNotinbranch(mlir::omp::NotinbranchClauseOps &result) const; bool processNowait(mlir::omp::NowaitClauseOps &result) const; bool processNumTasks(lower::StatementContext &stmtCtx, mlir::omp::NumTasksClauseOps &result) const; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index eab04ab13f9fa..a3c25f20ba437 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -3846,7 +3846,9 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::omp::DeclareSimdOperands clauseOps; ClauseProcessor cp(converter, semaCtx, clauses); cp.processAligned(clauseOps); + cp.processInbranch(clauseOps); cp.processLinear(clauseOps); + cp.processNotinbranch(clauseOps); cp.processSimdlen(clauseOps); cp.processUniform(clauseOps); diff --git a/flang/test/Lower/OpenMP/declare-simd.f90 b/flang/test/Lower/OpenMP/declare-simd.f90 index 6ebb82d801708..b80bcc5105e7e 100644 --- a/flang/test/Lower/OpenMP/declare-simd.f90 +++ b/flang/test/Lower/OpenMP/declare-simd.f90 @@ -102,11 +102,35 @@ end subroutine declare_simd_uniform ! CHECK: omp.declare_simd uniform(%[[XDECL]]#0 : !fir.ref>>>, %[[YDECL]]#0 : !fir.ref>>>) ! CHECK: return +subroutine declare_simd_inbranch() +#ifdef OMP_60 +!$omp declare_simd inbranch +#else +!$omp declare simd inbranch +#endif +end subroutine declare_simd_inbranch + +! CHECK-LABEL: func.func @_QPdeclare_simd_inbranch() +! CHECK: omp.declare_simd inbranch{{$}} +! CHECK: return + +subroutine declare_simd_notinbranch() +#ifdef OMP_60 +!$omp declare_simd notinbranch +#else +!$omp declare simd notinbranch +#endif +end subroutine declare_simd_notinbranch + +! CHECK-LABEL: func.func @_QPdeclare_simd_notinbranch() +! CHECK: omp.declare_simd notinbranch{{$}} +! CHECK: return + subroutine declare_simd_combined(x, y, n, i) #ifdef OMP_60 -!$omp declare_simd aligned(x, y : 64) linear(i) simdlen(8) uniform(x, y) +!$omp declare_simd inbranch aligned(x, y : 64) linear(i) simdlen(8) uniform(x, y) #else -!$omp declare simd aligned(x, y : 64) linear(i) simdlen(8) uniform(x, y) +!$omp declare simd inbranch aligned(x, y : 64) linear(i) simdlen(8) uniform(x, y) #endif real(8), pointer, intent(inout) :: x(:) real(8), pointer, intent(in) :: y(:) @@ -132,6 +156,7 @@ end subroutine declare_simd_combined ! CHECK: omp.declare_simd ! CHECK-SAME: aligned(%[[X_DECL]]#0 : !fir.ref>>> -> 64 : i64, ! CHECK-SAME: %[[Y_DECL]]#0 : !fir.ref>>> -> 64 : i64) +! CHECK-SAME: inbranch ! CHECK-SAME: linear(%[[I_DECL]]#0 = %[[C1]] : !fir.ref) ! CHECK-SAME: simdlen(8) ! CHECK-SAME: uniform(%[[X_DECL]]#0 : !fir.ref>>>, diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 3972c9aca4b12..e77e31ca884d4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -665,6 +665,28 @@ class OpenMP_IfClauseSkip< def OpenMP_IfClause : OpenMP_IfClauseSkip<>; +//===----------------------------------------------------------------------===// +// V5.2: [9.8.1.1]: `inbranch` clause +//===----------------------------------------------------------------------===// + +class OpenMP_InbranchClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false> + : OpenMP_Clause { + let arguments = (ins UnitAttr:$inbranch); + + let optAssemblyFormat = [{ + `inbranch` $inbranch + }]; + + let description = [{ + The `inbranch` clause indicates that the generated SIMD function variant + is intended for use in conditional branches. + }]; +} +def OpenMP_InbranchClause : OpenMP_InbranchClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [5.5.10] `in_reduction` clause //===----------------------------------------------------------------------===// @@ -913,6 +935,28 @@ class OpenMP_NontemporalClauseSkip< def OpenMP_NontemporalClause : OpenMP_NontemporalClauseSkip<>; +//===----------------------------------------------------------------------===// +// V5.2: [9.8.1.2]: `notinbranch` clause +//===----------------------------------------------------------------------===// + +class OpenMP_NotinbranchClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false> + : OpenMP_Clause { + let arguments = (ins UnitAttr:$notinbranch); + + let optAssemblyFormat = [{ + `notinbranch` $notinbranch + }]; + + let description = [{ + The `notinbranch` clause indicates that the generated SIMD function variant + is intended for use when not in conditional branches. + }]; +} +def OpenMP_NotinbranchClause : OpenMP_NotinbranchClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [15.6] `nowait` clause //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 8e48e50464532..dfec6609e1161 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2247,7 +2247,8 @@ def WorkdistributeOp : OpenMP_Op<"workdistribute"> { def DeclareSimdOp : OpenMP_Op<"declare_simd", traits = [AttrSizedOperandSegments], - clauses = [OpenMP_AlignedClause, OpenMP_LinearClause, + clauses = [OpenMP_AlignedClause, OpenMP_InbranchClause, + OpenMP_LinearClause, OpenMP_NotinbranchClause, OpenMP_SimdlenClause, OpenMP_UniformClause]> { let summary = "declare simd directive"; let description = [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 9d548c76d37da..70753b0f2a69a 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -4476,6 +4476,9 @@ LogicalResult DeclareSimdOp::verify() { if (!func) return emitOpError() << "must be nested inside a function"; + if (getInbranch() && getNotinbranch()) + return emitOpError("cannot have both 'inbranch' and 'notinbranch'"); + return verifyAlignedClause(*this, getAlignments(), getAlignedVars()); } @@ -4483,10 +4486,10 @@ void DeclareSimdOp::build(OpBuilder &odsBuilder, OperationState &odsState, const DeclareSimdOperands &clauses) { MLIRContext *ctx = odsBuilder.getContext(); DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars, - makeArrayAttr(ctx, clauses.alignments), + makeArrayAttr(ctx, clauses.alignments), clauses.inbranch, clauses.linearVars, clauses.linearStepVars, - clauses.linearVarTypes, clauses.simdlen, - clauses.uniformVars); + clauses.linearVarTypes, clauses.notinbranch, + clauses.simdlen, clauses.uniformVars); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 1350c5e2ee8d2..0d9d1f1663ef9 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -3143,3 +3143,10 @@ func.func @invalid_workdistribute() -> () { // ----- // expected-error @+1 {{'omp.declare_simd' op must be nested inside a function}} omp.declare_simd + +// ----- +func.func @omp_declare_simd_branch() -> () { + // expected-error @+1 {{'omp.declare_simd' op cannot have both 'inbranch' and 'notinbranch'}} + omp.declare_simd inbranch notinbranch + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 54534cca766d9..cc5e4fdcda5ba 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3425,6 +3425,41 @@ func.func @omp_declare_simd_uniform(%a: f64, %b: f64, return } +// CHECK-LABEL: func.func @omp_declare_simd_inbranch +func.func @omp_declare_simd_inbranch() -> () { + // CHECK: omp.declare_simd inbranch + omp.declare_simd inbranch + return +} + +// CHECK-LABEL: func.func @omp_declare_simd_notinbranch +func.func @omp_declare_simd_notinbranch() -> () { + // CHECK: omp.declare_simd notinbranch + omp.declare_simd notinbranch + return +} + +// CHECK-LABEL: func.func @omp_declare_simd_multiple_clauses +func.func @omp_declare_simd_multiple_clauses(%a: f64, %b: f64, + %p0: memref, %p1: memref, + %iv: i32, %step: i32) -> () { + // CHECK: omp.declare_simd + // CHECK-SAME: aligned( + // CHECK-SAME: %{{.*}} : memref -> 32 : i64, + // CHECK-SAME: %{{.*}} : memref -> 128 : i64) + // CHECK-SAME: notinbranch + // CHECK-SAME: simdlen(8) + // CHECK-SAME: uniform( + // CHECK-SAME: %{{.*}} : memref, + // CHECK-SAME: %{{.*}} : memref) + omp.declare_simd simdlen(8) + aligned(%p0 : memref -> 32 : i64, + %p1 : memref -> 128 : i64) + uniform(%p0 : memref, %p1 : memref) + notinbranch + return +} + // CHECK-LABEL: func.func @omp_declare_simd_all_clauses func.func @omp_declare_simd_all_clauses(%a: f64, %b: f64, %p0: memref, %p1: memref, @@ -3433,6 +3468,7 @@ func.func @omp_declare_simd_all_clauses(%a: f64, %b: f64, // CHECK-SAME: aligned( // CHECK-SAME: %{{.*}} : memref -> 32 : i64, // CHECK-SAME: %{{.*}} : memref -> 128 : i64) + // CHECK-SAME: inbranch // CHECK-SAME: linear(%{{.*}} = %{{.*}} : i32) // CHECK-SAME: simdlen(8) // CHECK-SAME: uniform( @@ -3443,5 +3479,6 @@ func.func @omp_declare_simd_all_clauses(%a: f64, %b: f64, %p1 : memref -> 128 : i64) linear(%iv = %step : i32) uniform(%p0 : memref, %p1 : memref) + inbranch return }