diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 2e15f4de4545d..64c7e5700c771 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -378,8 +378,15 @@ static LogicalResult checkImplementationStatus(Operation &op) { result = todo("num_teams with multi-dimensional values"); }; auto checkNumThreads = [&todo](auto op, LogicalResult &result) { - if (op.hasNumThreadsMultiDim()) - result = todo("num_threads with multi-dimensional values"); + if (op.getNumThreadsDimsCount() > 3) { + result = todo("num_threads with more than 3 dimensions"); + return; + } + + if (op.hasNumThreadsMultiDim() && + !op->template getParentOfType()) + result = todo( + "num_threads with multi-dimensional values outside target region"); }; auto checkThreadLimit = [&todo](auto op, LogicalResult &result) { @@ -6501,13 +6508,12 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, /// /// Loop bounds and steps are only optionally populated, if output vectors are /// provided. -static void -extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, - Value &numTeamsLower, Value &numTeamsUpper, - Value &threadLimit, - llvm::SmallVectorImpl *lowerBounds = nullptr, - llvm::SmallVectorImpl *upperBounds = nullptr, - llvm::SmallVectorImpl *steps = nullptr) { +static void extractHostEvalClauses( + omp::TargetOp targetOp, llvm::SmallVectorImpl &numThreadsVars, + Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit, + llvm::SmallVectorImpl *lowerBounds = nullptr, + llvm::SmallVectorImpl *upperBounds = nullptr, + llvm::SmallVectorImpl *steps = nullptr) { auto blockArgIface = llvm::cast(*targetOp); for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(), blockArgIface.getHostEvalBlockArgs())) { @@ -6528,11 +6534,19 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, llvm_unreachable("unsupported host_eval use"); }) .Case([&](omp::ParallelOp parallelOp) { - if (!parallelOp.getNumThreadsVars().empty() && - parallelOp.getNumThreads(0) == blockArg) - numThreads = hostEvalVar; - else + if (llvm::is_contained(parallelOp.getNumThreadsVars(), blockArg)) { + for (auto [i, threadsVar] : + llvm::enumerate(parallelOp.getNumThreadsVars())) { + if (threadsVar == blockArg) { + if (numThreadsVars.size() <= i) + numThreadsVars.resize(i + 1); + numThreadsVars[i] = hostEvalVar; + break; + } + } + } else { llvm_unreachable("unsupported host_eval use"); + } }) .Case([&](omp::LoopNestOp loopOp) { auto processBounds = @@ -6639,10 +6653,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, bool isTargetDevice, bool isGPU) { // TODO: Handle constant 'if' clauses. - Value numThreads, numTeamsLower, numTeamsUpper, threadLimit; + Value numTeamsLower, numTeamsUpper, threadLimit; + llvm::SmallVector numThreadsVars; if (!isTargetDevice) { - extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper, - threadLimit); + extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower, + numTeamsUpper, threadLimit); } else { // In the target device, values for these clauses are not passed as // host_eval, but instead evaluated prior to entry to the region. This @@ -6657,8 +6672,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, } if (auto parallelOp = castOrGetParentOfType(capturedOp)) { - if (!parallelOp.getNumThreadsVars().empty()) - numThreads = parallelOp.getNumThreads(0); + // Handle multi-dimensional num_threads + numThreadsVars.reserve(parallelOp.getNumThreadsVars().size()); + for (auto threadsVar : parallelOp.getNumThreadsVars()) + numThreadsVars.push_back(threadsVar); } } @@ -6706,10 +6723,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD. int32_t maxThreadsVal = -1; - if (castOrGetParentOfType(capturedOp)) - setMaxValueFromClause(numThreads, maxThreadsVal); - else if (castOrGetParentOfType(capturedOp, - /*immediateParent=*/true)) + if (castOrGetParentOfType(capturedOp)) { + // For multi-dimensional num_threads, only use the first dimension for now + if (!numThreadsVars.empty()) + setMaxValueFromClause(numThreadsVars[0], maxThreadsVal); + } else if (castOrGetParentOfType(capturedOp, + /*immediateParent=*/true)) maxThreadsVal = 1; // For max values, < 0 means unset, == 0 means set but unknown. Select the @@ -6773,10 +6792,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, omp::LoopNestOp loopOp = castOrGetParentOfType(capturedOp); unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0; - Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit; + Value numTeamsLower, numTeamsUpper, teamsThreadLimit; + llvm::SmallVector numThreadsVars; llvm::SmallVector lowerBounds(numLoops), upperBounds(numLoops), steps(numLoops); - extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper, + extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower, numTeamsUpper, teamsThreadLimit, &lowerBounds, &upperBounds, &steps); // TODO: Handle constant 'if' clauses. @@ -6801,8 +6821,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc( moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty()); - if (numThreads) - attrs.MaxThreads = moduleTranslation.lookupValue(numThreads); + // Handle multi-dimensional num_threads (only first value for now) + if (!numThreadsVars.empty()) + attrs.MaxThreads = moduleTranslation.lookupValue(numThreadsVars[0]); if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp), omp::TargetRegionFlags::trip_count)) { diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index e0872226531e6..1d85806bfaf55 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -457,8 +457,8 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) { // ----- -llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) { - // expected-error@below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}} +llvm.func @parallel_num_threads_multi_dim_standalone(%lb : i32, %ub : i32) { + // expected-error@below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values outside target region in omp.parallel operation}} // expected-error@below {{LLVM Translation failed for operation: omp.parallel}} omp.parallel num_threads(%lb, %ub : i32, i32) { omp.terminator @@ -468,6 +468,17 @@ llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) { // ----- +llvm.func @parallel_num_threads_too_many_dims(%lb : i32, %ub : i32) { + // expected-error@below {{not yet implemented: Unhandled clause num_threads with more than 3 dimensions in omp.parallel operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.parallel}} + omp.parallel num_threads(%lb, %ub, %lb, %ub : i32, i32, i32, i32) { + omp.terminator + } + llvm.return +} + +// ----- + llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) { // expected-error@below {{not yet implemented: Unhandled clause thread_limit with multi-dimensional values in omp.teams operation}} // expected-error@below {{LLVM Translation failed for operation: omp.teams}}