-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[OpenMP][MLIR] Add num_threads clause with dims modifier support #171767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/skc7/omp_teams_multidim
Are you sure you want to change the base?
Conversation
45e7dab to
33dcfd9
Compare
|
@llvm/pr-subscribers-mlir Author: Chaitanya (skc7) Changes: This PR adds support for the OpenMP 6.1 `num_threads` clause with the `dims` modifier. Full diff: https://github.com/llvm/llvm-project/pull/171767.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+ Variadic<AnyInteger>:$num_threads_dims_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
- `num_threads` `(` $num_threads `:` type($num_threads) `)`
+ `num_threads` `(` custom<NumThreadsClause>(
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
+ $num_threads, type($num_threads)
+ ) `)`
}];
let description = [{
- The optional `num_threads` parameter specifies the number of threads which
- should be used to execute the parallel region.
+ num_threads clause specifies the desired number of threads in the team
+ space formed by the construct on which it appears.
+
+ With dims modifier:
+ - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_threads`
+ - Takes a single value: the requested number of threads (this form has no lower bound)
+ - Format: `num_threads(value : type)`
+ - Example: `num_threads(%ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasNumThreadsDimsModifier() {
+ return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getNumThreadsDimsCount() {
+ if (!hasNumThreadsDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumThreadsNumDims());
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumThreadsDimsCount()
+ ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+ assert(index < getNumThreadsDimsCount() &&
+ "Num threads dims index out of bounds");
+ return getNumThreadsDimsValues()[index];
+ }
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
+ /* num_threads_num_dims = */ nullptr,
+ /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+ /*num_threads_dims=*/nullptr,
+ /*num_threads_values=*/ValueRange(),
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.procBindKind,
- clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(
+ builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
+ clauses.numThreads, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+ clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+ std::optional<IntegerAttr> numThreadsNumDims,
+ OperandRange numThreadsDimsValues, Value numThreads) {
+ bool hasDimsModifier =
+ numThreadsNumDims.has_value() && numThreadsNumDims.value();
+ if (hasDimsModifier && numThreads) {
+ return op->emitError("num_threads with dims modifier cannot be used "
+ "together with number of threads");
+ }
+ if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+ return failure();
+ return success();
+}
+
LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ if (failed(verifyNumThreadsClause(
+ getOperation(), this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues(), this->getNumThreads())))
+ return failure();
+
+ // verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // verify private variables restrictions
if (failed(verifyPrivateVarList(*this)))
return failure();
+ // verify reduction variables restrictions
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+ Type &boundsType) {
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand boundsOperand;
+ if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+ parser.parseType(boundsType)) {
+ return failure();
+ }
+ bounds = boundsOperand;
+ return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value bounds,
+ Type boundsType) {
+ if (!values.empty()) {
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ }
+ if (bounds) {
+ p.printOperand(bounds);
+ p << " : " << boundsType;
+ }
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
+ // num_threads dims and values are not yet supported
+ assert(!opInst.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
else
@@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
threadLimit = teamsOp.getThreadLimit();
}
- if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
+ }
}
// Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
// -----
+func.func @num_threads_dims_no_values() {
+ // expected-error@+1 {{dims modifier requires values to be specified}}
+ "omp.parallel"() ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+ // expected-error@+1 {{dims(2) specified but 1 values provided}}
+ omp.parallel num_threads(dims(2): %n : i64) {
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+ // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
+ "omp.parallel"(%n, %n, %m) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+ return
+}
+
+// -----
+
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ omp.terminator
+ }
+
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.terminator
|
|
@llvm/pr-subscribers-mlir-openmp Author: Chaitanya (skc7) Changes: This PR adds support for the OpenMP 6.1 `num_threads` clause with the `dims` modifier. Full diff: https://github.com/llvm/llvm-project/pull/171767.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+ Variadic<AnyInteger>:$num_threads_dims_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
- `num_threads` `(` $num_threads `:` type($num_threads) `)`
+ `num_threads` `(` custom<NumThreadsClause>(
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
+ $num_threads, type($num_threads)
+ ) `)`
}];
let description = [{
- The optional `num_threads` parameter specifies the number of threads which
- should be used to execute the parallel region.
+ num_threads clause specifies the desired number of threads in the team
+ space formed by the construct on which it appears.
+
+ With dims modifier:
+ - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_threads`
+ - Takes a single value: the requested number of threads (this form has no lower bound)
+ - Format: `num_threads(value : type)`
+ - Example: `num_threads(%ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasNumThreadsDimsModifier() {
+ return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getNumThreadsDimsCount() {
+ if (!hasNumThreadsDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumThreadsNumDims());
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumThreadsDimsCount()
+ ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+ assert(index < getNumThreadsDimsCount() &&
+ "Num threads dims index out of bounds");
+ return getNumThreadsDimsValues()[index];
+ }
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
+ /* num_threads_num_dims = */ nullptr,
+ /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+ /*num_threads_dims=*/nullptr,
+ /*num_threads_values=*/ValueRange(),
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.procBindKind,
- clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(
+ builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
+ clauses.numThreads, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+ clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+ std::optional<IntegerAttr> numThreadsNumDims,
+ OperandRange numThreadsDimsValues, Value numThreads) {
+ bool hasDimsModifier =
+ numThreadsNumDims.has_value() && numThreadsNumDims.value();
+ if (hasDimsModifier && numThreads) {
+ return op->emitError("num_threads with dims modifier cannot be used "
+ "together with number of threads");
+ }
+ if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+ return failure();
+ return success();
+}
+
LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ if (failed(verifyNumThreadsClause(
+ getOperation(), this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues(), this->getNumThreads())))
+ return failure();
+
+ // verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // verify private variables restrictions
if (failed(verifyPrivateVarList(*this)))
return failure();
+ // verify reduction variables restrictions
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+ Type &boundsType) {
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand boundsOperand;
+ if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+ parser.parseType(boundsType)) {
+ return failure();
+ }
+ bounds = boundsOperand;
+ return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value bounds,
+ Type boundsType) {
+ if (!values.empty()) {
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ }
+ if (bounds) {
+ p.printOperand(bounds);
+ p << " : " << boundsType;
+ }
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
+ // num_threads dims and values are not yet supported
+ assert(!opInst.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
else
@@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
threadLimit = teamsOp.getThreadLimit();
}
- if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
+ }
}
// Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
// -----
+func.func @num_threads_dims_no_values() {
+ // expected-error@+1 {{dims modifier requires values to be specified}}
+ "omp.parallel"() ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+ // expected-error@+1 {{dims(2) specified but 1 values provided}}
+ omp.parallel num_threads(dims(2): %n : i64) {
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+ // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
+ "omp.parallel"(%n, %n, %m) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+ return
+}
+
+// -----
+
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ omp.terminator
+ }
+
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.terminator
|
|
@llvm/pr-subscribers-mlir-llvm Author: Chaitanya (skc7) Changes: This PR adds support for the OpenMP 6.1 `num_threads` clause with the `dims` modifier. Full diff: https://github.com/llvm/llvm-project/pull/171767.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+ Variadic<AnyInteger>:$num_threads_dims_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
- `num_threads` `(` $num_threads `:` type($num_threads) `)`
+ `num_threads` `(` custom<NumThreadsClause>(
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
+ $num_threads, type($num_threads)
+ ) `)`
}];
let description = [{
- The optional `num_threads` parameter specifies the number of threads which
- should be used to execute the parallel region.
+ num_threads clause specifies the desired number of threads in the team
+ space formed by the construct on which it appears.
+
+ With dims modifier:
+ - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_threads`
+ - Takes a single value: the requested number of threads (this form has no lower bound)
+ - Format: `num_threads(value : type)`
+ - Example: `num_threads(%ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasNumThreadsDimsModifier() {
+ return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getNumThreadsDimsCount() {
+ if (!hasNumThreadsDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumThreadsNumDims());
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumThreadsDimsCount()
+ ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+ assert(index < getNumThreadsDimsCount() &&
+ "Num threads dims index out of bounds");
+ return getNumThreadsDimsValues()[index];
+ }
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
+ /* num_threads_num_dims = */ nullptr,
+ /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+ /*num_threads_dims=*/nullptr,
+ /*num_threads_values=*/ValueRange(),
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.procBindKind,
- clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(
+ builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
+ clauses.numThreads, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+ clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+ std::optional<IntegerAttr> numThreadsNumDims,
+ OperandRange numThreadsDimsValues, Value numThreads) {
+ bool hasDimsModifier =
+ numThreadsNumDims.has_value() && numThreadsNumDims.value();
+ if (hasDimsModifier && numThreads) {
+ return op->emitError("num_threads with dims modifier cannot be used "
+ "together with number of threads");
+ }
+ if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+ return failure();
+ return success();
+}
+
LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ if (failed(verifyNumThreadsClause(
+ getOperation(), this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues(), this->getNumThreads())))
+ return failure();
+
+ // verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // verify private variables restrictions
if (failed(verifyPrivateVarList(*this)))
return failure();
+ // verify reduction variables restrictions
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+ Type &boundsType) {
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand boundsOperand;
+ if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+ parser.parseType(boundsType)) {
+ return failure();
+ }
+ bounds = boundsOperand;
+ return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value bounds,
+ Type boundsType) {
+ if (!values.empty()) {
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ }
+ if (bounds) {
+ p.printOperand(bounds);
+ p << " : " << boundsType;
+ }
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
+ // num_threads dims and values are not yet supported
+ assert(!opInst.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
else
@@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
threadLimit = teamsOp.getThreadLimit();
}
- if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
+ }
}
// Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
// -----
+func.func @num_threads_dims_no_values() {
+ // expected-error@+1 {{dims modifier requires values to be specified}}
+ "omp.parallel"() ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+ // expected-error@+1 {{dims(2) specified but 1 values provided}}
+ omp.parallel num_threads(dims(2): %n : i64) {
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+ // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
+ "omp.parallel"(%n, %n, %m) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+ return
+}
+
+// -----
+
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ omp.terminator
+ }
+
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.terminator
|
|
@llvm/pr-subscribers-flang-openmp Author: Chaitanya (skc7) Changes: This PR adds support for the OpenMP 6.1 `num_threads` clause with the `dims` modifier. Full diff: https://github.com/llvm/llvm-project/pull/171767.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+ Variadic<AnyInteger>:$num_threads_dims_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
- `num_threads` `(` $num_threads `:` type($num_threads) `)`
+ `num_threads` `(` custom<NumThreadsClause>(
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
+ $num_threads, type($num_threads)
+ ) `)`
}];
let description = [{
- The optional `num_threads` parameter specifies the number of threads which
- should be used to execute the parallel region.
+ num_threads clause specifies the desired number of threads in the team
+ space formed by the construct on which it appears.
+
+ With dims modifier:
+ - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_threads`
+ - If lower bound not specified, it defaults to upper bound value
+ - Format: `num_threads(bounds : type)`
+ - Example: `num_threads(%ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasNumThreadsDimsModifier() {
+ return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getNumThreadsDimsCount() {
+ if (!hasNumThreadsDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumThreadsNumDims());
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumThreadsDimsCount()
+ ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+ assert(index < getNumThreadsDimsCount() &&
+ "Num threads dims index out of bounds");
+ return getNumThreadsDimsValues()[index];
+ }
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
+ /* num_threads_num_dims = */ nullptr,
+ /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+ /*num_threads_dims=*/nullptr,
+ /*num_threads_values=*/ValueRange(),
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.procBindKind,
- clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(
+ builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
+ clauses.numThreads, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+ clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+ std::optional<IntegerAttr> numThreadsNumDims,
+ OperandRange numThreadsDimsValues, Value numThreads) {
+ bool hasDimsModifier =
+ numThreadsNumDims.has_value() && numThreadsNumDims.value();
+ if (hasDimsModifier && numThreads) {
+ return op->emitError("num_threads with dims modifier cannot be used "
+ "together with number of threads");
+ }
+ if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+ return failure();
+ return success();
+}
+
LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ if (failed(verifyNumThreadsClause(
+ getOperation(), this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues(), this->getNumThreads())))
+ return failure();
+
+ // verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // verify private variables restrictions
if (failed(verifyPrivateVarList(*this)))
return failure();
+ // verify reduction variables restrictions
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+ Type &boundsType) {
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand boundsOperand;
+ if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+ parser.parseType(boundsType)) {
+ return failure();
+ }
+ bounds = boundsOperand;
+ return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value bounds,
+ Type boundsType) {
+ if (!values.empty()) {
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ }
+ if (bounds) {
+ p.printOperand(bounds);
+ p << " : " << boundsType;
+ }
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
+ // num_threads dims and values are not yet supported
+ assert(!opInst.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
else
@@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
threadLimit = teamsOp.getThreadLimit();
}
- if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
+ }
}
// Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
// -----
+func.func @num_threads_dims_no_values() {
+ // expected-error@+1 {{dims modifier requires values to be specified}}
+ "omp.parallel"() ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+ // expected-error@+1 {{dims(2) specified but 1 values provided}}
+ omp.parallel num_threads(dims(2): %n : i64) {
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+ // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
+ "omp.parallel"(%n, %n, %m) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+ return
+}
+
+// -----
+
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ omp.terminator
+ }
+
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.terminator
|
This PR adds support for the OpenMP 6.1 `num_threads` clause with the `dims` modifier.
LLVM IR translation for `num_threads` with the `dims` modifier is marked as not yet implemented (NYI).