Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -1532,4 +1532,76 @@ class OpenMP_UseDevicePtrClauseSkip<

def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;

//===----------------------------------------------------------------------===//
// V6.1: `num_teams` clause with dims modifier
//===----------------------------------------------------------------------===//

class OpenMP_NumTeamsMultiDimClauseSkip<
bit traits = false, bit arguments = false, bit assemblyFormat = false,
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
Variadic<AnyInteger>:$num_teams_values
);

let optAssemblyFormat = [{
`num_teams_multi_dim` `(` custom<DimsModifier>($num_teams_dims,
$num_teams_values,
type($num_teams_values)) `)`
}];

let description = [{
The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
the number of teams to be created in a multidimensional team space.

The dims modifier for the num_teams_multi_dim clause specifies the number of
dimensions for the league space (team space) that the clause arranges.
The dimensions argument in the dims modifier specifies the number of
dimensions and determines the length of the list argument. The list items
are specified in ascending order according to the ordinal number of the
dimensions (dimension 0, 1, 2, ..., N-1).

- If `dims` is not specified: The space is unidimensional (1D) with a single value
- If `dims(1)` is specified: The space is explicitly unidimensional (1D)
- If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)

**Examples:**
- `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
- `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
}];

let extraClassDeclaration = [{
/// Returns true if the dims modifier is explicitly present
bool hasDimsModifier() {
return getNumTeamsDims().has_value();
}

/// Returns the number of dimensions specified by dims modifier
/// Returns 1 if dims modifier is not present (unidimensional by default)
unsigned getNumDimensions() {
if (!hasDimsModifier())
return 1;
return static_cast<unsigned>(*getNumTeamsDims());
}

/// Returns all dimension values as an operand range
::mlir::OperandRange getDimensionValues() {
return getNumTeamsValues();
}

/// Returns the value for a specific dimension index
/// Index must be less than getNumDimensions()
::mlir::Value getDimensionValue(unsigned index) {
assert(index < getDimensionValues().size() &&
"Dimension index out of bounds");
return getDimensionValues()[index];
}
}];
}

def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be rather called modifier instead of clause? The clause still is num_threads, but the modifier is dims.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original design was to have a separate dims modifier(with dims and values args) class and then create num_teams and thread_limit clauses from it. But this leads to both clauses having the same argument names and when added to teams Op would create an issue.

So, now created just num_teams_multi_dim clause with arguments as num_teams_dims and num_teams_values.
Will remove the old num_teams clause and replace it with num_teams_multi_dim clause and move the name back to num_teams


#endif // OPENMP_CLAUSES
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
Expand Down
64 changes: 64 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2621,6 +2621,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
clauses.numTeamsDims, clauses.numTeamsValues,
/*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
Expand Down Expand Up @@ -4453,6 +4454,69 @@ LogicalResult WorkdistributeOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// Parser and printer for Clauses with dims modifier
//===----------------------------------------------------------------------===//
// clause_name(dims(3): %v0, %v1, %v2 : i32, i32, i32)
// clause_name(%v : i32)
static ParseResult
parseDimsModifier(OpAsmParser &parser, IntegerAttr &dimsAttr,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
SmallVectorImpl<Type> &types) {
std::optional<int64_t> dims;
// Try to parse optional dims modifier: dims(N):
if (succeeded(parser.parseOptionalKeyword("dims"))) {
int64_t dimsValue;
if (parser.parseLParen() || parser.parseInteger(dimsValue) ||
parser.parseRParen() || parser.parseColon()) {
return failure();
}
dims = dimsValue;
}
// Parse the operand list
if (parser.parseOperandList(values))
return failure();
// Parse colon and types
if (parser.parseColon() || parser.parseTypeList(types))
return failure();

// Verify dims matches number of values if specified
if (dims.has_value() && values.size() != static_cast<size_t>(*dims)) {
return parser.emitError(parser.getCurrentLocation())
<< "dims(" << *dims << ") specified but " << values.size()
<< " values provided";
}

// If dims not specified but we have values, it's implicitly unidimensional
if (!dims.has_value() && values.size() != 1) {
return parser.emitError(parser.getCurrentLocation())
<< "expected 1 value without dims modifier, but got "
<< values.size() << " values";
}

// Convert to IntegerAttr
if (dims.has_value()) {
dimsAttr = parser.getBuilder().getI64IntegerAttr(*dims);
}
return success();
}

static void printDimsModifier(OpAsmPrinter &p, Operation *op,
IntegerAttr dimsAttr, OperandRange values,
TypeRange types) {
// Print dims modifier if present
if (dimsAttr) {
p << "dims(" << dimsAttr.getInt() << "): ";
}

// Print operands
p.printOperands(values);

// Print types
p << " : ";
llvm::interleaveComma(types, p);
}

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.teams" (%data_var) ({
omp.terminator
}) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
}) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
omp.terminator
}
return
Expand All @@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
}) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
}) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}

// CHECK: omp.teams num_teams_multi_dim(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32)
omp.teams num_teams_multi_dim(dims(3): %lb, %ub, %ub : i32, i32, i32) {
// CHECK: omp.terminator
omp.terminator
}

// Test if.
// CHECK: omp.teams if(%{{.+}})
omp.teams if(%if_cond) {
Expand Down