5 changes: 3 additions & 2 deletions docs/NotesOnThreading.md
@@ -3,14 +3,15 @@
This document is intended for ORT developers.

ORT allows the usage of either OpenMP or non-OpenMP (ORT) threads for execution. Threadpool management
-is abstracted behind: (1) ThreadPool class in threadpool.h and (2) functions in thread_utils.h.
+is abstracted behind: (1) ThreadPool class in [threadpool.h](https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/platform/threadpool.h) and (2) functions in [thread_utils.h](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/util/thread_utils.h).

When developing an op, please use these abstractions to parallelize your code. They centralize the dispatch logic: when OpenMP is enabled, they use OpenMP; when OpenMP is disabled, they fall back to sequential execution if the threadpool pointer is NULL, and otherwise schedule the tasks on the threadpool.

-Examples of these abstractions are: (threadpool.h has more documentation for these)
+Examples of these abstractions are: ([threadpool.h](https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/platform/threadpool.h) has more documentation for these)
* TryBatchParallelFor
* TryParallelFor
* TrySimpleParallelFor
* static version of NumThreads

**Please do not write #ifdef pragma omp in operator code**.
32 changes: 26 additions & 6 deletions include/onnxruntime/core/platform/threadpool.h
@@ -190,11 +190,11 @@ class ThreadPool {

// Similar to ParallelFor above, but takes the specified scheduling strategy
// into account.
-void
-ParallelFor(std::ptrdiff_t total, const SchedulingParams& scheduling_params,
-const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn);
+void ParallelFor(std::ptrdiff_t total, const SchedulingParams& scheduling_params,
+const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn);

-static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const SchedulingParams& scheduling_params,
+static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total,
+const SchedulingParams& scheduling_params,
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
#ifdef _OPENMP
ORT_UNUSED_PARAMETER(scheduling_params);
@@ -216,7 +216,7 @@ class ThreadPool {
}
tp->ParallelFor(total, scheduling_params, fn);
#endif
-} // namespace concurrency
+}

// Prefer using this API to get the number of threads unless you know what you're doing.
// This API takes into account if openmp is enabled/disabled and if the thread pool ptr is nullptr.
@@ -236,7 +236,27 @@ class ThreadPool {

// Directly schedule the 'total' tasks to the underlying threadpool, without
// cutting them by halves
-void SimpleParallelFor(std::ptrdiff_t total, std::function<void(std::ptrdiff_t)> fn);
+void SimpleParallelFor(std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn);

+inline static void TrySimpleParallelFor(ThreadPool* tp, std::ptrdiff_t total,
+const std::function<void(std::ptrdiff_t)>& fn) {
+#ifdef _OPENMP
+ORT_UNUSED_PARAMETER(tp);
+#pragma omp parallel for
+for (std::ptrdiff_t i = 0; i < total; ++i) {
+fn(i);
+}
+#else
+if (tp != nullptr) {
+tp->SimpleParallelFor(total, fn);
+} else {
+for (std::ptrdiff_t i = 0; i < total; ++i) {
+// In many cases, fn can be inlined here.
+fn(i);
+}
+}
+#endif
+}

/**
* Tries to call the given function in parallel, with calls split into (num_batches) batches.
5 changes: 3 additions & 2 deletions onnxruntime/core/common/threadpool.cc
@@ -109,7 +109,8 @@ ThreadPool::ThreadPool(Eigen::ThreadPoolInterface* user_threadpool, Eigen::Alloc
}

ThreadPool::~ThreadPool() = default;
-void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, std::function<void(std::ptrdiff_t)> fn) {
+
+void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn) {
if (total <= 0)
return;

@@ -320,4 +321,4 @@ Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() const {
return underlying_threadpool_;
}
} // namespace concurrency
} // namespace onnxruntime
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/cpu/math/top_k.cc
@@ -164,7 +164,7 @@ static void FindTopKElements(const Tensor* input, const TensorShape& input_shape
const int64_t num_blocks = input_shape[axis_parsed];
const int64_t block_slice = reduced_cols / k;

-int64_t tp_threads = threadpool != nullptr ? threadpool->NumThreads() : 1;
+int64_t tp_threads = concurrency::ThreadPool::NumThreads(threadpool);
int64_t num_threads = std::min(tp_threads, rows); // split on rows so can't have more threads than rows

// rough attempt to make sure there's enough work for each thread. if there's insufficient work the usage of
@@ -326,7 +326,8 @@ static void FindTopKElements(const Tensor* input, const TensorShape& input_shape
// we want to re-use the storage variables in each lambda as much as possible to minimize allocations
// on each iteration, so the lambda does multiple rows. e.g. the data_holder and indices_data vectors.
// the alternative would be to use TryBatchParallelFor with the lambda doing one row.
-threadpool->SimpleParallelFor(num_threads, find_top_k);
+// Use TrySimpleParallelFor so openmp is supported correctly
+concurrency::ThreadPool::TrySimpleParallelFor(threadpool, num_threads, find_top_k);
}
}
