diff --git a/offload/DeviceRTL/src/Parallelism.cpp b/offload/DeviceRTL/src/Parallelism.cpp
index aa5e74029ec3e..0ea2f89337fee 100644
--- a/offload/DeviceRTL/src/Parallelism.cpp
+++ b/offload/DeviceRTL/src/Parallelism.cpp
@@ -103,11 +103,10 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause,
 
 extern "C" {
 
-[[clang::always_inline]] void
-__kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args,
-                     const int64_t nargs, int32_t nt_strict = false,
-                     int32_t nt_severity = severity_fatal,
-                     const char *nt_message = nullptr) {
+[[clang::always_inline]] void __kmpc_parallel_spmd_impl(
+    IdentTy *ident, int32_t num_threads, void *fn, void **args,
+    const int64_t nargs, int32_t nt_strict = false,
+    int32_t nt_severity = severity_fatal, const char *nt_message = nullptr) {
   uint32_t TId = mapping::getThreadIdInBlock();
   uint32_t NumThreads =
       determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
@@ -163,7 +162,26 @@ __kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args,
   return;
 }
 
-[[clang::always_inline]] void __kmpc_parallel_51(
+// Variant that takes no strict num_threads arguments; the shared
+// implementation runs with its nt_strict/nt_severity/nt_message defaults.
+[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident,
+                                                   int32_t num_threads,
+                                                   void *fn, void **args,
+                                                   const int64_t nargs) {
+  return __kmpc_parallel_spmd_impl(ident, num_threads, fn, args, nargs);
+}
+
+// Variant that forwards nt_strict, nt_severity and nt_message to the shared
+// implementation.
+[[clang::always_inline]] void __kmpc_parallel_spmd_60(
+    IdentTy *ident, int32_t num_threads, void *fn, void **args,
+    const int64_t nargs, int32_t nt_strict = false,
+    int32_t nt_severity = severity_fatal, const char *nt_message = nullptr) {
+  return __kmpc_parallel_spmd_impl(ident, num_threads, fn, args, nargs,
+                                   nt_strict, nt_severity, nt_message);
+}
+
+[[clang::always_inline]] void __kmpc_parallel_impl(
     IdentTy *ident, int32_t, int32_t if_expr, int32_t num_threads,
     int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
     int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
@@ -198,8 +216,13 @@ __kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args,
     // This was moved to its own routine so it could be called directly
     // in certain situations to avoid resource consumption of unused
     // logic in parallel_51.
-    __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs, nt_strict,
-                         nt_severity, nt_message);
+    // Forward the strict num_threads arguments only when nt_strict is set;
+    // otherwise the variant without them is sufficient.
+    if (nt_strict)
+      __kmpc_parallel_spmd_60(ident, num_threads, fn, args, nargs, nt_strict,
+                              nt_severity, nt_message);
+    else
+      __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs);
 
     return;
   }
@@ -308,14 +331,24 @@ __kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args,
     __kmpc_end_sharing_variables();
 }
 
+// Variant that takes no strict num_threads arguments; the shared
+// implementation runs with its nt_strict/nt_severity/nt_message defaults.
+[[clang::always_inline]] void
+__kmpc_parallel_51(IdentTy *ident, int32_t id, int32_t if_expr,
+                   int32_t num_threads, int proc_bind, void *fn,
+                   void *wrapper_fn, void **args, int64_t nargs) {
+  return __kmpc_parallel_impl(ident, id, if_expr, num_threads, proc_bind, fn,
+                              wrapper_fn, args, nargs);
+}
+
 [[clang::always_inline]] void __kmpc_parallel_60(
     IdentTy *ident, int32_t id, int32_t if_expr, int32_t num_threads,
     int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
     int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
     const char *nt_message = nullptr) {
-  return __kmpc_parallel_51(ident, id, if_expr, num_threads, proc_bind, fn,
-                            wrapper_fn, args, nargs, nt_strict, nt_severity,
-                            nt_message);
+  return __kmpc_parallel_impl(ident, id, if_expr, num_threads, proc_bind, fn,
+                              wrapper_fn, args, nargs, nt_strict, nt_severity,
+                              nt_message);
 }
 
 [[clang::noinline]] bool __kmpc_kernel_parallel(ParallelRegionFnTy *WorkFn) {