Skip to content

Conversation

@skc7
Copy link
Contributor

@skc7 skc7 commented Dec 11, 2025

PR adds support of openmp 6.1 feature num_threads with dims modifier.
llvmIR translation for num_threads with dims modifier is marked as NYI.

@skc7 skc7 force-pushed the users/skc7/dims/num_threads branch from 45e7dab to 33dcfd9 Compare December 11, 2025 12:57
@skc7 skc7 marked this pull request as ready for review December 11, 2025 13:16
@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2025

@llvm/pr-subscribers-mlir

Author: Chaitanya (skc7)

Changes

PR adds support of openmp 6.1 feature num_threads with dims modifier.
llvmIR translation for num_threads with dims modifier is marked as NYI.


Full diff: https://github.com/llvm/llvm-project/pull/171767.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+42-3)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+2)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+70-7)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+11-1)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+32-1)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+10-5)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+    Variadic<AnyInteger>:$num_threads_dims_values,
     Optional<IntLikeType>:$num_threads
   );
 
   let optAssemblyFormat = [{
-    `num_threads` `(` $num_threads `:` type($num_threads) `)`
+    `num_threads` `(` custom<NumThreadsClause>(
+      $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
+      $num_threads, type($num_threads)
+    ) `)`
   }];
 
   let description = [{
-    The optional `num_threads` parameter specifies the number of threads which
-    should be used to execute the parallel region.
+    num_threads clause specifies the desired number of threads in the team
+    space formed by the construct on which it appears.
+
+    With dims modifier:
+    - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
+    - Specifies upper bounds for each dimension (all must have same type)
+    - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+    - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+    Without dims modifier:
+    - Uses `num_threads`
+    - If lower bound not specified, it defaults to upper bound value
+    - Format: `num_threads(bounds : type)`
+    - Example: `num_threads(%ub : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasNumThreadsDimsModifier() {
+      return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    unsigned getNumThreadsDimsCount() {
+      if (!hasNumThreadsDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumThreadsNumDims());
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getNumThreadsDimsCount()
+    ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+      assert(index < getNumThreadsDimsCount() &&
+             "Num threads dims index out of bounds");
+      return getNumThreadsDimsValues()[index];
+    }
   }];
 }
 
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
+        /* num_threads_num_dims = */ nullptr,
+        /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        ArrayRef<NamedAttribute> attributes) {
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+                    /*num_threads_dims=*/nullptr,
+                    /*num_threads_values=*/ValueRange(),
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
-                    makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.privateNeedsBarrier, clauses.procBindKind,
-                    clauses.reductionMod, clauses.reductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                    makeArrayAttr(ctx, clauses.reductionSyms));
+  ParallelOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
+      clauses.numThreads, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms));
 }
 
 template <typename OpType>
@@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
   return success();
 }
 
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+                       std::optional<IntegerAttr> numThreadsNumDims,
+                       OperandRange numThreadsDimsValues, Value numThreads) {
+  bool hasDimsModifier =
+      numThreadsNumDims.has_value() && numThreadsNumDims.value();
+  if (hasDimsModifier && numThreads) {
+    return op->emitError("num_threads with dims modifier cannot be used "
+                         "together with number of threads");
+  }
+  if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+    return failure();
+  return success();
+}
+
 LogicalResult ParallelOp::verify() {
+  // verify num_threads clause restrictions
+  if (failed(verifyNumThreadsClause(
+          getOperation(), this->getNumThreadsNumDimsAttr(),
+          this->getNumThreadsDimsValues(), this->getNumThreads())))
+    return failure();
+
+  // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  // verify private variables restrictions
   if (failed(verifyPrivateVarList(*this)))
     return failure();
 
+  // verify reduction variables restrictions
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }
@@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      SmallVectorImpl<Type> &types,
+                      std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+                      Type &boundsType) {
+  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+    return success();
+  }
+
+  OpAsmParser::UnresolvedOperand boundsOperand;
+  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+      parser.parseType(boundsType)) {
+    return failure();
+  }
+  bounds = boundsOperand;
+  return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+                                  IntegerAttr dimsAttr, OperandRange values,
+                                  TypeRange types, Value bounds,
+                                  Type boundsType) {
+  if (!values.empty()) {
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  }
+  if (bounds) {
+    p.printOperand(bounds);
+    p << " : " << boundsType;
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   if (auto ifVar = opInst.getIfExpr())
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
+  // num_threads dims and values are not yet supported
+  assert(!opInst.hasNumThreadsDimsModifier() &&
+         "Lowering of num_threads with dims modifier is NYI.");
   if (auto numThreadsVar = opInst.getNumThreads())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
+            // num_threads dims and values are not yet supported
+            assert(!parallelOp.hasNumThreadsDimsModifier() &&
+                   "Lowering of num_threads with dims modifier is NYI.");
             if (parallelOp.getNumThreads() == blockArg)
               numThreads = hostEvalVar;
             else
@@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       threadLimit = teamsOp.getThreadLimit();
     }
 
-    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+      // num_threads dims and values are not yet supported
+      assert(!parallelOp.hasNumThreadsDimsModifier() &&
+             "Lowering of num_threads with dims modifier is NYI.");
       numThreads = parallelOp.getNumThreads();
+    }
   }
 
   // Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
 
 // -----
 
+func.func @num_threads_dims_no_values() {
+  // expected-error@+1 {{dims modifier requires values to be specified}}
+  "omp.parallel"() ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+  // expected-error@+1 {{dims(2) specified but 1 values provided}}
+  omp.parallel num_threads(dims(2): %n : i64) {
+    omp.terminator
+  }
+
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+  // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
+  "omp.parallel"(%n, %n, %m) ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+  return
+}
+
+// -----
+
 func.func @nowait_not_allowed(%n : memref<i32>) {
   // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
 // -----
 func.func @undefined_privatizer(%arg0: !llvm.ptr) {
   // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
-  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
     ^bb0(%arg2: !llvm.ptr):
       omp.terminator
     }) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
     "omp.parallel"(%data_var, %data_var, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
     "omp.parallel"(%data_var, %data_var, %if_cond) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.parallel
   omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
    omp.terminator
  }
 
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+   omp.terminator
+ }
+
  // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
  omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
    omp.terminator

@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2025

@llvm/pr-subscribers-mlir-openmp

Author: Chaitanya (skc7)

Changes

PR adds support of openmp 6.1 feature num_threads with dims modifier.
llvmIR translation for num_threads with dims modifier is marked as NYI.


Full diff: https://github.com/llvm/llvm-project/pull/171767.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+42-3)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+2)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+70-7)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+11-1)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+32-1)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+10-5)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+    Variadic<AnyInteger>:$num_threads_dims_values,
     Optional<IntLikeType>:$num_threads
   );
 
   let optAssemblyFormat = [{
-    `num_threads` `(` $num_threads `:` type($num_threads) `)`
+    `num_threads` `(` custom<NumThreadsClause>(
+      $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
+      $num_threads, type($num_threads)
+    ) `)`
   }];
 
   let description = [{
-    The optional `num_threads` parameter specifies the number of threads which
-    should be used to execute the parallel region.
+    num_threads clause specifies the desired number of threads in the team
+    space formed by the construct on which it appears.
+
+    With dims modifier:
+    - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
+    - Specifies upper bounds for each dimension (all must have same type)
+    - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+    - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+    Without dims modifier:
+    - Uses `num_threads`
+    - If lower bound not specified, it defaults to upper bound value
+    - Format: `num_threads(bounds : type)`
+    - Example: `num_threads(%ub : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasNumThreadsDimsModifier() {
+      return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    unsigned getNumThreadsDimsCount() {
+      if (!hasNumThreadsDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumThreadsNumDims());
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getNumThreadsDimsCount()
+    ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+      assert(index < getNumThreadsDimsCount() &&
+             "Num threads dims index out of bounds");
+      return getNumThreadsDimsValues()[index];
+    }
   }];
 }
 
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
+        /* num_threads_num_dims = */ nullptr,
+        /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        ArrayRef<NamedAttribute> attributes) {
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+                    /*num_threads_dims=*/nullptr,
+                    /*num_threads_values=*/ValueRange(),
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
-                    makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.privateNeedsBarrier, clauses.procBindKind,
-                    clauses.reductionMod, clauses.reductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                    makeArrayAttr(ctx, clauses.reductionSyms));
+  ParallelOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
+      clauses.numThreads, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms));
 }
 
 template <typename OpType>
@@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
   return success();
 }
 
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+                       std::optional<IntegerAttr> numThreadsNumDims,
+                       OperandRange numThreadsDimsValues, Value numThreads) {
+  bool hasDimsModifier =
+      numThreadsNumDims.has_value() && numThreadsNumDims.value();
+  if (hasDimsModifier && numThreads) {
+    return op->emitError("num_threads with dims modifier cannot be used "
+                         "together with number of threads");
+  }
+  if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+    return failure();
+  return success();
+}
+
 LogicalResult ParallelOp::verify() {
+  // verify num_threads clause restrictions
+  if (failed(verifyNumThreadsClause(
+          getOperation(), this->getNumThreadsNumDimsAttr(),
+          this->getNumThreadsDimsValues(), this->getNumThreads())))
+    return failure();
+
+  // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  // verify private variables restrictions
   if (failed(verifyPrivateVarList(*this)))
     return failure();
 
+  // verify reduction variables restrictions
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }
@@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      SmallVectorImpl<Type> &types,
+                      std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+                      Type &boundsType) {
+  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+    return success();
+  }
+
+  OpAsmParser::UnresolvedOperand boundsOperand;
+  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+      parser.parseType(boundsType)) {
+    return failure();
+  }
+  bounds = boundsOperand;
+  return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+                                  IntegerAttr dimsAttr, OperandRange values,
+                                  TypeRange types, Value bounds,
+                                  Type boundsType) {
+  if (!values.empty()) {
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  }
+  if (bounds) {
+    p.printOperand(bounds);
+    p << " : " << boundsType;
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   if (auto ifVar = opInst.getIfExpr())
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
+  // num_threads dims and values are not yet supported
+  assert(!opInst.hasNumThreadsDimsModifier() &&
+         "Lowering of num_threads with dims modifier is NYI.");
   if (auto numThreadsVar = opInst.getNumThreads())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
+            // num_threads dims and values are not yet supported
+            assert(!parallelOp.hasNumThreadsDimsModifier() &&
+                   "Lowering of num_threads with dims modifier is NYI.");
             if (parallelOp.getNumThreads() == blockArg)
               numThreads = hostEvalVar;
             else
@@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       threadLimit = teamsOp.getThreadLimit();
     }
 
-    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+      // num_threads dims and values are not yet supported
+      assert(!parallelOp.hasNumThreadsDimsModifier() &&
+             "Lowering of num_threads with dims modifier is NYI.");
       numThreads = parallelOp.getNumThreads();
+    }
   }
 
   // Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
 
 // -----
 
+func.func @num_threads_dims_no_values() {
+  // expected-error@+1 {{dims modifier requires values to be specified}}
+  "omp.parallel"() ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+  // expected-error@+1 {{dims(2) specified but 1 values provided}}
+  omp.parallel num_threads(dims(2): %n : i64) {
+    omp.terminator
+  }
+
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+  // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
+  "omp.parallel"(%n, %n, %m) ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+  return
+}
+
+// -----
+
 func.func @nowait_not_allowed(%n : memref<i32>) {
   // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
 // -----
 func.func @undefined_privatizer(%arg0: !llvm.ptr) {
   // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
-  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
     ^bb0(%arg2: !llvm.ptr):
       omp.terminator
     }) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
     "omp.parallel"(%data_var, %data_var, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
     "omp.parallel"(%data_var, %data_var, %if_cond) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.parallel
   omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
    omp.terminator
  }
 
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+   omp.terminator
+ }
+
  // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
  omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
    omp.terminator

@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Chaitanya (skc7)

Changes

PR adds support of openmp 6.1 feature num_threads with dims modifier.
llvmIR translation for num_threads with dims modifier is marked as NYI.


Full diff: https://github.com/llvm/llvm-project/pull/171767.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+42-3)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+2)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+70-7)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+11-1)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+32-1)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+10-5)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+    Variadic<AnyInteger>:$num_threads_dims_values,
     Optional<IntLikeType>:$num_threads
   );
 
   let optAssemblyFormat = [{
-    `num_threads` `(` $num_threads `:` type($num_threads) `)`
+    `num_threads` `(` custom<NumThreadsClause>(
+      $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
+      $num_threads, type($num_threads)
+    ) `)`
   }];
 
   let description = [{
-    The optional `num_threads` parameter specifies the number of threads which
-    should be used to execute the parallel region.
+    num_threads clause specifies the desired number of threads in the team
+    space formed by the construct on which it appears.
+
+    With dims modifier:
+    - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
+    - Specifies upper bounds for each dimension (all must have same type)
+    - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+    - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+    Without dims modifier:
+    - Uses `num_threads`
+    - If lower bound not specified, it defaults to upper bound value
+    - Format: `num_threads(bounds : type)`
+    - Example: `num_threads(%ub : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasNumThreadsDimsModifier() {
+      return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    unsigned getNumThreadsDimsCount() {
+      if (!hasNumThreadsDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumThreadsNumDims());
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getNumThreadsDimsCount()
+    ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+      assert(index < getNumThreadsDimsCount() &&
+             "Num threads dims index out of bounds");
+      return getNumThreadsDimsValues()[index];
+    }
   }];
 }
 
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
+        /* num_threads_num_dims = */ nullptr,
+        /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        ArrayRef<NamedAttribute> attributes) {
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+                    /*num_threads_dims=*/nullptr,
+                    /*num_threads_values=*/ValueRange(),
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
-                    makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.privateNeedsBarrier, clauses.procBindKind,
-                    clauses.reductionMod, clauses.reductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                    makeArrayAttr(ctx, clauses.reductionSyms));
+  ParallelOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
+      clauses.numThreads, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms));
 }
 
 template <typename OpType>
@@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
   return success();
 }
 
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+                       std::optional<IntegerAttr> numThreadsNumDims,
+                       OperandRange numThreadsDimsValues, Value numThreads) {
+  bool hasDimsModifier =
+      numThreadsNumDims.has_value() && numThreadsNumDims.value();
+  if (hasDimsModifier && numThreads) {
+    return op->emitError("num_threads with dims modifier cannot be used "
+                         "together with number of threads");
+  }
+  if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+    return failure();
+  return success();
+}
+
 LogicalResult ParallelOp::verify() {
+  // verify num_threads clause restrictions
+  if (failed(verifyNumThreadsClause(
+          getOperation(), this->getNumThreadsNumDimsAttr(),
+          this->getNumThreadsDimsValues(), this->getNumThreads())))
+    return failure();
+
+  // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  // verify private variables restrictions
   if (failed(verifyPrivateVarList(*this)))
     return failure();
 
+  // verify reduction variables restrictions
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }
@@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      SmallVectorImpl<Type> &types,
+                      std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+                      Type &boundsType) {
+  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+    return success();
+  }
+
+  OpAsmParser::UnresolvedOperand boundsOperand;
+  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+      parser.parseType(boundsType)) {
+    return failure();
+  }
+  bounds = boundsOperand;
+  return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+                                  IntegerAttr dimsAttr, OperandRange values,
+                                  TypeRange types, Value bounds,
+                                  Type boundsType) {
+  if (!values.empty()) {
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  }
+  if (bounds) {
+    p.printOperand(bounds);
+    p << " : " << boundsType;
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   if (auto ifVar = opInst.getIfExpr())
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
+  // num_threads dims and values are not yet supported
+  assert(!opInst.hasNumThreadsDimsModifier() &&
+         "Lowering of num_threads with dims modifier is NYI.");
   if (auto numThreadsVar = opInst.getNumThreads())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
+            // num_threads dims and values are not yet supported
+            assert(!parallelOp.hasNumThreadsDimsModifier() &&
+                   "Lowering of num_threads with dims modifier is NYI.");
             if (parallelOp.getNumThreads() == blockArg)
               numThreads = hostEvalVar;
             else
@@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       threadLimit = teamsOp.getThreadLimit();
     }
 
-    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+      // num_threads dims and values are not yet supported
+      assert(!parallelOp.hasNumThreadsDimsModifier() &&
+             "Lowering of num_threads with dims modifier is NYI.");
       numThreads = parallelOp.getNumThreads();
+    }
   }
 
   // Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
 
 // -----
 
+func.func @num_threads_dims_no_values() {
+  // expected-error@+1 {{dims modifier requires values to be specified}}
+  "omp.parallel"() ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+  // expected-error@+1 {{dims(2) specified but 1 values provided}}
+  omp.parallel num_threads(dims(2): %n : i64) {
+    omp.terminator
+  }
+
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+  // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
+  "omp.parallel"(%n, %n, %m) ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+  return
+}
+
+// -----
+
 func.func @nowait_not_allowed(%n : memref<i32>) {
   // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
 // -----
 func.func @undefined_privatizer(%arg0: !llvm.ptr) {
   // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
-  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
     ^bb0(%arg2: !llvm.ptr):
       omp.terminator
     }) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
     "omp.parallel"(%data_var, %data_var, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
     "omp.parallel"(%data_var, %data_var, %if_cond) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.parallel
   omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
    omp.terminator
  }
 
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+   omp.terminator
+ }
+
  // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
  omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
    omp.terminator

@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2025

@llvm/pr-subscribers-flang-openmp

Author: Chaitanya (skc7)

Changes

PR adds support of openmp 6.1 feature num_threads with dims modifier.
llvmIR translation for num_threads with dims modifier is marked as NYI.


Full diff: https://github.com/llvm/llvm-project/pull/171767.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+42-3)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+2)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+70-7)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+11-1)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+32-1)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+10-5)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+    Variadic<AnyInteger>:$num_threads_dims_values,
     Optional<IntLikeType>:$num_threads
   );
 
   let optAssemblyFormat = [{
-    `num_threads` `(` $num_threads `:` type($num_threads) `)`
+    `num_threads` `(` custom<NumThreadsClause>(
+      $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
+      $num_threads, type($num_threads)
+    ) `)`
   }];
 
   let description = [{
-    The optional `num_threads` parameter specifies the number of threads which
-    should be used to execute the parallel region.
+    num_threads clause specifies the desired number of threads in the team
+    space formed by the construct on which it appears.
+
+    With dims modifier:
+    - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
+    - Specifies upper bounds for each dimension (all must have same type)
+    - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+    - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+    Without dims modifier:
+    - Uses `num_threads`
+    - If lower bound not specified, it defaults to upper bound value
+    - Format: `num_threads(bounds : type)`
+    - Example: `num_threads(%ub : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasNumThreadsDimsModifier() {
+      return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    unsigned getNumThreadsDimsCount() {
+      if (!hasNumThreadsDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumThreadsNumDims());
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getNumThreadsDimsCount()
+    ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+      assert(index < getNumThreadsDimsCount() &&
+             "Num threads dims index out of bounds");
+      return getNumThreadsDimsValues()[index];
+    }
   }];
 }
 
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
+        /* num_threads_num_dims = */ nullptr,
+        /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        ArrayRef<NamedAttribute> attributes) {
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+                    /*num_threads_dims=*/nullptr,
+                    /*num_threads_values=*/ValueRange(),
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
-                    makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.privateNeedsBarrier, clauses.procBindKind,
-                    clauses.reductionMod, clauses.reductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                    makeArrayAttr(ctx, clauses.reductionSyms));
+  ParallelOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
+      clauses.numThreads, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms));
 }
 
 template <typename OpType>
@@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
   return success();
 }
 
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+                       std::optional<IntegerAttr> numThreadsNumDims,
+                       OperandRange numThreadsDimsValues, Value numThreads) {
+  bool hasDimsModifier =
+      numThreadsNumDims.has_value() && numThreadsNumDims.value();
+  if (hasDimsModifier && numThreads) {
+    return op->emitError("num_threads with dims modifier cannot be used "
+                         "together with number of threads");
+  }
+  if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+    return failure();
+  return success();
+}
+
 LogicalResult ParallelOp::verify() {
+  // verify num_threads clause restrictions
+  if (failed(verifyNumThreadsClause(
+          getOperation(), this->getNumThreadsNumDimsAttr(),
+          this->getNumThreadsDimsValues(), this->getNumThreads())))
+    return failure();
+
+  // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  // verify private variables restrictions
   if (failed(verifyPrivateVarList(*this)))
     return failure();
 
+  // verify reduction variables restrictions
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }
@@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      SmallVectorImpl<Type> &types,
+                      std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+                      Type &boundsType) {
+  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+    return success();
+  }
+
+  OpAsmParser::UnresolvedOperand boundsOperand;
+  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+      parser.parseType(boundsType)) {
+    return failure();
+  }
+  bounds = boundsOperand;
+  return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+                                  IntegerAttr dimsAttr, OperandRange values,
+                                  TypeRange types, Value bounds,
+                                  Type boundsType) {
+  if (!values.empty()) {
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  }
+  if (bounds) {
+    p.printOperand(bounds);
+    p << " : " << boundsType;
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   if (auto ifVar = opInst.getIfExpr())
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
+  // num_threads dims and values are not yet supported
+  assert(!opInst.hasNumThreadsDimsModifier() &&
+         "Lowering of num_threads with dims modifier is NYI.");
   if (auto numThreadsVar = opInst.getNumThreads())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
+            // num_threads dims and values are not yet supported
+            assert(!parallelOp.hasNumThreadsDimsModifier() &&
+                   "Lowering of num_threads with dims modifier is NYI.");
             if (parallelOp.getNumThreads() == blockArg)
               numThreads = hostEvalVar;
             else
@@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       threadLimit = teamsOp.getThreadLimit();
     }
 
-    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+      // num_threads dims and values are not yet supported
+      assert(!parallelOp.hasNumThreadsDimsModifier() &&
+             "Lowering of num_threads with dims modifier is NYI.");
       numThreads = parallelOp.getNumThreads();
+    }
   }
 
   // Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
 
 // -----
 
+func.func @num_threads_dims_no_values() {
+  // expected-error@+1 {{dims modifier requires values to be specified}}
+  "omp.parallel"() ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+  // expected-error@+1 {{dims(2) specified but 1 values provided}}
+  omp.parallel num_threads(dims(2): %n : i64) {
+    omp.terminator
+  }
+
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+  // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
+  "omp.parallel"(%n, %n, %m) ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+  return
+}
+
+// -----
+
 func.func @nowait_not_allowed(%n : memref<i32>) {
   // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
 // -----
 func.func @undefined_privatizer(%arg0: !llvm.ptr) {
   // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
-  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
     ^bb0(%arg2: !llvm.ptr):
       omp.terminator
     }) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
     "omp.parallel"(%data_var, %data_var, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
     "omp.parallel"(%data_var, %data_var, %if_cond) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.parallel
   omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
    omp.terminator
  }
 
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+   omp.terminator
+ }
+
  // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
  omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
    omp.terminator

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants