diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 8e43c4284d078..d26b3cfb7a86d 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1532,4 +1532,76 @@ class OpenMP_UseDevicePtrClauseSkip< def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>; +//===----------------------------------------------------------------------===// +// V6.1: `num_teams` clause with dims modifier +//===----------------------------------------------------------------------===// + +class OpenMP_NumTeamsMultiDimClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let arguments = (ins + ConfinedAttr, [IntPositive]>:$num_teams_dims, + Variadic:$num_teams_values + ); + + let optAssemblyFormat = [{ + `num_teams_multi_dim` `(` custom($num_teams_dims, + $num_teams_values, + type($num_teams_values)) `)` + }]; + + let description = [{ + The `num_teams_multi_dim` clause with dims modifier support specifies the limit on + the number of teams to be created in a multidimensional team space. + + The dims modifier for the num_teams_multi_dim clause specifies the number of + dimensions for the league space (team space) that the clause arranges. + The dimensions argument in the dims modifier specifies the number of + dimensions and determines the length of the list argument. The list items + are specified in ascending order according to the ordinal number of the + dimensions (dimension 0, 1, 2, ..., N-1). + + - If `dims` is not specified: The space is unidimensional (1D) with a single value + - If `dims(1)` is specified: The space is explicitly unidimensional (1D) + - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D) + + **Examples:** + - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a + 3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2. + - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt. + }]; + + let extraClassDeclaration = [{ + /// Returns true if the dims modifier is explicitly present + bool hasDimsModifier() { + return getNumTeamsDims().has_value(); + } + + /// Returns the number of dimensions specified by dims modifier + /// Returns 1 if dims modifier is not present (unidimensional by default) + unsigned getNumDimensions() { + if (!hasDimsModifier()) + return 1; + return static_cast(*getNumTeamsDims()); + } + + /// Returns all dimension values as an operand range + ::mlir::OperandRange getDimensionValues() { + return getNumTeamsValues(); + } + + /// Returns the value for a specific dimension index + /// Index must be less than getNumDimensions() + ::mlir::Value getDimensionValue(unsigned index) { + assert(index < getDimensionValues().size() && + "Dimension index out of bounds"); + return getDimensionValues()[index]; + } + }]; +} + +def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>; + #endif // OPENMP_CLAUSES diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index bbfe805eefe48..ea440cf924a95 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [ AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface ], clauses = [ OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause, - OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause + OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause, + OpenMP_ThreadLimitClause ], singleRegion = true> { let summary = "teams construct"; let description = [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 0d6b2870c625a..66b36caccada3 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2621,6 +2621,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state, // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper, + clauses.numTeamsDims, clauses.numTeamsValues, /*private_vars=*/{}, /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, clauses.reductionMod, clauses.reductionVars, @@ -4453,6 +4454,69 @@ LogicalResult WorkdistributeOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// Parser and printer for Clauses with dims modifier +//===----------------------------------------------------------------------===// +// clause_name(dims(3): %v0, %v1, %v2 : i32, i32, i32) +// clause_name(%v : i32) +static ParseResult +parseDimsModifier(OpAsmParser &parser, IntegerAttr &dimsAttr, + SmallVectorImpl &values, + SmallVectorImpl &types) { + std::optional dims; + // Try to parse optional dims modifier: dims(N): + if (succeeded(parser.parseOptionalKeyword("dims"))) { + int64_t dimsValue; + if (parser.parseLParen() || parser.parseInteger(dimsValue) || + parser.parseRParen() || parser.parseColon()) { + return failure(); + } + dims = dimsValue; + } + // Parse the operand list + if (parser.parseOperandList(values)) + return failure(); + // Parse colon and types + if (parser.parseColon() || parser.parseTypeList(types)) + return failure(); + + // Verify dims matches number of values if specified + if (dims.has_value() && values.size() != static_cast(*dims)) { + return parser.emitError(parser.getCurrentLocation()) + << "dims(" << *dims << ") specified but " << values.size() + << " values provided"; + } + + // If dims not specified but we have values, it's implicitly unidimensional + if (!dims.has_value() && values.size() != 1) { + return parser.emitError(parser.getCurrentLocation()) + << "expected 1 value without dims modifier, but got " + << values.size() << " values"; + } + + // Convert to IntegerAttr + if (dims.has_value()) { + dimsAttr = parser.getBuilder().getI64IntegerAttr(*dims); + } + return success(); +} + +static void printDimsModifier(OpAsmPrinter &p, Operation *op, + IntegerAttr dimsAttr, OperandRange values, + TypeRange types) { + // Print dims modifier if present + if (dimsAttr) { + p << "dims(" << dimsAttr.getInt() << "): "; + } + + // Print operands + p.printOperands(values); + + // Print types + p << " : "; + llvm::interleaveComma(types, p); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index af24d969064ab..62619f07d6573 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref) { // expected-error @below {{expected equal sizes for allocate and allocator variables}} "omp.teams" (%data_var) ({ omp.terminator - }) {operandSegmentSizes = array} : (memref) -> () + }) {operandSegmentSizes = array} : (memref) -> () omp.terminator } return @@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) { // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}} "omp.teams" (%lb) ({ omp.terminator - }) {operandSegmentSizes = array} : (i32) -> () + }) {operandSegmentSizes = array} : (i32) -> () omp.terminator } return diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index ac29e20907b55..bb154fad12742 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1108,6 +1108,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32, omp.terminator } + // CHECK: omp.teams num_teams_multi_dim(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32) + omp.teams num_teams_multi_dim(dims(3): %lb, %ub, %ub : i32, i32, i32) { + // CHECK: omp.terminator + omp.terminator + } + // Test if. // CHECK: omp.teams if(%{{.+}}) omp.teams if(%if_cond) {