2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
@@ -175,7 +175,7 @@ static inline void ExecuteLambdaInParallel(TLambda lambda, int max, int step, do
   }
 #else
   const int total_tasks = max / (step > 0 ? step : 1) + (max % step > 0 ? 1 : 0);
-  concurrency::ThreadPool::TryParallelFor(ttp, total_tasks, cost, [lambda, step](ptrdiff_t first, ptrdiff_t last) {
+  concurrency::ThreadPool::TryParallelFor(ttp, total_tasks, cost, [&lambda, step](ptrdiff_t first, ptrdiff_t last) {
     for (int i = static_cast<int>(first), end = static_cast<int>(last); i < end; ++i) {
       lambda(i * step);
     }
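The one-character fix above changes the range callable's capture of `lambda` from by-value to by-reference, so the callable no longer copies the user lambda's captured state; since the callable does not outlive the `TryParallelFor` call, the reference stays valid. A minimal sketch of the pattern, where `run_parallel_for` is a hypothetical serial stand-in for the real thread-pool API, not onnxruntime's actual interface:

// Sketch only: run_parallel_for mimics the shape of
// concurrency::ThreadPool::TryParallelFor but runs serially.
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Partitions [0, total) into ranges and invokes fn on each; a real pool
// would dispatch the ranges to worker threads instead of looping here.
void run_parallel_for(std::ptrdiff_t total,
                      const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn) {
  const std::ptrdiff_t chunk = 4;  // fixed chunk size, for the sketch only
  for (std::ptrdiff_t first = 0; first < total; first += chunk) {
    fn(first, std::min(first + chunk, total));
  }
}

int main() {
  std::vector<int> out(16);
  auto work = [&out](std::ptrdiff_t i) { out[i] = static_cast<int>(i) * 2; };
  // Capturing `work` by reference ([&work]) mirrors the fix: the range
  // callable does not copy the user lambda's captured state.
  run_parallel_for(static_cast<std::ptrdiff_t>(out.size()),
                   [&work](std::ptrdiff_t first, std::ptrdiff_t last) {
                     for (std::ptrdiff_t i = first; i < last; ++i) work(i);
                   });
  return 0;
}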
23 changes: 14 additions & 9 deletions onnxruntime/core/providers/cpu/tensor/gather.cc
@@ -4,6 +4,7 @@
 //https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gather
 #include "core/providers/cpu/tensor/gather.h"
 #include "core/common/common.h"
+#include "core/platform/threadpool.h"
 
 namespace onnxruntime {
 
@@ -57,11 +58,10 @@ template <typename Tin>
 Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uint8_t* dst_base, bool is_string_type,
                       const size_t element_bytes, const int64_t block_size, const int64_t M,
                       const int64_t N, const int64_t data_batch_bytes, const int64_t gathered_batch_bytes,
-                      const TensorShape& input_data_shape, const int64_t axis) {
+                      const TensorShape& input_data_shape, const int64_t axis, concurrency::ThreadPool* tp) {
   const Tin* indices_data = indices_tensor->template Data<Tin>();
 
   // Check the indices first in case there's a out of bound index.
-  // We can't merge this code in the omp loop below as omp does not allow return in the loop
   auto axis_dim_limit = input_data_shape[axis];
 
   for (int64_t i = 0; i < N; ++i) {
@@ -73,10 +73,7 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin
     }
   }
 
-#ifdef _OPENMP
-#pragma omp parallel for
-#endif
-  for (int64_t index = 0; index < M * N; ++index) {
+  auto lambda = [&](int64_t index) {
     int64_t batch = index / N;
     int64_t i = index % N;
 
@@ -93,7 +90,13 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin
     } else {
       memcpy(dst_base + dst_offset, src_base + src_offset, block_size);
     }
-  }
+  };
+  concurrency::ThreadPool::TryParallelFor(tp, M * N, static_cast<double>(block_size),
+                                          [&lambda](ptrdiff_t first, ptrdiff_t last) {
+                                            for (int index = static_cast<int>(first), end = static_cast<int>(last); index < end; ++index) {
+                                              lambda(index);
+                                            }
+                                          });
 
   return Status::OK();
 }
@@ -117,13 +120,15 @@ Status Gather::Compute(OpKernelContext* context) const {
   const auto* src_base = static_cast<const uint8_t*>(p.input_tensor->DataRaw());
   auto* dst_base = static_cast<uint8_t*>(p.output_tensor->MutableDataRaw());
 
+  concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
+
   if (p.indices_tensor->IsDataType<int32_t>()) {
     return GatherCopyData<int32_t>(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes,
-                                   block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis);
+                                   block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp);
   }
   if (p.indices_tensor->IsDataType<int64_t>()) {
     return GatherCopyData<int64_t>(p.indices_tensor, src_base, dst_base, is_string_type, element_bytes,
-                                   block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis);
+                                   block_size, M, N, data_batch_bytes, gathered_batch_bytes, input_data_shape, p.axis, tp);
   }
 
   return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather.");
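Here the OpenMP loop over the M * N flat gather indices becomes a per-index lambda plus a range callable handed to TryParallelFor; the out-of-bounds check stays up front (the deleted comment notes OpenMP could not early-return from inside the loop, and the pool-based version keeps the same structure). The lambda recovers (batch, i) from the flat index by div/mod, as in this tiny standalone illustration:

// Standalone illustration of the flat-index decomposition used by the
// new lambda in GatherCopyData: M batches, N indices per batch.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t M = 3, N = 4;
  for (int64_t index = 0; index < M * N; ++index) {
    const int64_t batch = index / N;  // which batch the flat index falls in
    const int64_t i = index % N;      // position within that batch
    std::printf("index=%2lld -> batch=%lld, i=%lld\n",
                static_cast<long long>(index), static_cast<long long>(batch),
                static_cast<long long>(i));
  }
  return 0;
}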
61 changes: 34 additions & 27 deletions onnxruntime/core/providers/cpu/tensor/gather_nd.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "gather_nd.h"
+#include "core/platform/threadpool.h"
 
 namespace onnxruntime {
 
@@ -43,7 +44,7 @@ ONNX_CPU_OPERATOR_KERNEL(
     GatherND);
 
 template <typename Tind>
-Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
+Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p, concurrency::ThreadPool* tp) const {
   const auto* input_tensor = context->Input<Tensor>(0);
   const auto* indices_tensor = context->Input<Tensor>(1);
   ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherNDBase PrepareForCompute: Input count mismatch");
@@ -72,9 +73,6 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
   auto* output_tensor = context->Output(0, TensorShape(std::move(shape)));
 
   std::vector<int64_t> sizes_from_slice_dims(num_slice_dims);
-#ifdef _OPENMP
-#pragma omp parallel for
-#endif
   for (int64_t i = 0; i < num_slice_dims; ++i) {
     sizes_from_slice_dims[i] = input_shape.SizeFromDimension(batch_dims_ + i + 1);
   }
@@ -95,10 +93,7 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
   }
 
   // Compute the element_offset
-#ifdef _OPENMP
-#pragma omp parallel for
-#endif
-  for (int64_t slice_idx = 0; slice_idx < num_slices; ++slice_idx) {
+  auto lambda = [&](int64_t slice_idx) {
     const size_t batch_idx = slice_idx / num_slices_per_batch;
     const size_t input_base_offset = batch_idx * input_batch_stride;
 
@@ -118,46 +113,58 @@ Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) con
     }
 
     p.slice_offsets[slice_idx] = input_base_offset + relative_slice_offset;
-  }
+  };
+  concurrency::ThreadPool::TryParallelFor(tp, num_slices, static_cast<double>(num_slice_dims),
+                                          [&lambda](ptrdiff_t first, ptrdiff_t last) {
+                                            for (int slice_idx = static_cast<int>(first), end = static_cast<int>(last); slice_idx < end; ++slice_idx) {
+                                              lambda(slice_idx);
+                                            }
+                                          });
 
   return err_index == 0 ? Status::OK()
                         : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index found, index = ", err_index);
 }
 
-template Status GatherNDBase::PrepareForCompute<int32_t>(OpKernelContext*, Prepare&) const;
-template Status GatherNDBase::PrepareForCompute<int64_t>(OpKernelContext*, Prepare&) const;
+template Status GatherNDBase::PrepareForCompute<int32_t>(OpKernelContext*, Prepare&, concurrency::ThreadPool*) const;
+template Status GatherNDBase::PrepareForCompute<int64_t>(OpKernelContext*, Prepare&, concurrency::ThreadPool*) const;
 
 Status GatherND::Compute(OpKernelContext* context) const {
   Prepare p;
+  concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
   ORT_RETURN_IF_ERROR(context->Input<Tensor>(1)->IsDataType<int32_t>()
-                          ? PrepareForCompute<int32_t>(context, p)
-                          : PrepareForCompute<int64_t>(context, p));
+                          ? PrepareForCompute<int32_t>(context, p, tp)
+                          : PrepareForCompute<int64_t>(context, p, tp));
 
-  return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p);
+  return nullptr == p.input_str_base ? GatherNumber(p, tp) : GatherString(p, tp);
 }
 
-Status GatherND::GatherNumber(const Prepare& p) const {
-#ifdef _OPENMP
-#pragma omp parallel for
-#endif
-  for (int64_t slice_idx = 0; slice_idx < static_cast<int64_t>(p.slice_offsets.size()); ++slice_idx) {
+Status GatherND::GatherNumber(const Prepare& p, concurrency::ThreadPool* tp) const {
+  auto lambda = [&](int64_t slice_idx) {
     memcpy(p.output_base + slice_idx * p.bytes_per_slice, p.input_base + p.slice_offsets[slice_idx] * p.element_bytes,
            p.bytes_per_slice);
-  }
-
+  };
+  concurrency::ThreadPool::TryParallelFor(tp, p.slice_offsets.size(), static_cast<double>(p.bytes_per_slice),
+                                          [&lambda](ptrdiff_t first, ptrdiff_t last) {
+                                            for (int slice_idx = static_cast<int>(first), end = static_cast<int>(last); slice_idx < end; ++slice_idx) {
+                                              lambda(slice_idx);
+                                            }
+                                          });
   return Status::OK();
 }
 
-Status GatherND::GatherString(const Prepare& p) const {
-#ifdef _OPENMP
-#pragma omp parallel for
-#endif
-  for (int64_t slice_idx = 0; slice_idx < static_cast<int64_t>(p.slice_offsets.size()); ++slice_idx) {
+Status GatherND::GatherString(const Prepare& p, concurrency::ThreadPool* tp) const {
+  auto lambda = [&](int64_t slice_idx) {
    const int64_t slice_base_offset = slice_idx * p.element_count_per_slice;
    for (int64_t j = 0; j < static_cast<int64_t>(p.element_count_per_slice); ++j) {
      p.output_str_base[slice_base_offset + j] = p.input_str_base[p.slice_offsets[slice_idx] + j];
    }
-  }
+  };
+  concurrency::ThreadPool::TryParallelFor(tp, p.slice_offsets.size(), static_cast<double>(p.element_count_per_slice),
+                                          [&lambda](ptrdiff_t first, ptrdiff_t last) {
+                                            for (int slice_idx = static_cast<int>(first), end = static_cast<int>(last); slice_idx < end; ++slice_idx) {
+                                              lambda(slice_idx);
+                                            }
+                                          });
 
   return Status::OK();
 }
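Each TryParallelFor call passes a per-iteration cost hint as a double (block_size for the numeric copy, bytes_per_slice and element_count_per_slice in GatherNumber/GatherString, num_slice_dims for the offset computation), which lets a pool weigh the total work against dispatch overhead before sharding. A hedged sketch of that idea; the budget and formula below are illustrative assumptions, not onnxruntime's actual policy:

// Sketch of a cost-based "parallelize or run inline" decision, the idea
// behind the static_cast<double>(...) cost hints above.
#include <cstddef>

bool worth_parallelizing(std::ptrdiff_t total_iters, double cost_per_iter) {
  // If the whole loop costs less than the assumed dispatch overhead,
  // sharding it across worker threads would not pay off.
  const double kDispatchBudget = 50000.0;  // arbitrary budget for the sketch
  return static_cast<double>(total_iters) * cost_per_iter > kDispatchBudget;
}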
11 changes: 6 additions & 5 deletions onnxruntime/core/providers/cpu/tensor/gather_nd.h
@@ -5,10 +5,11 @@
 
 #include "core/common/common.h"
 #include "core/framework/op_kernel.h"
-#include "core/platform/threadpool.h"
 
 namespace onnxruntime {
-
+namespace concurrency {
+class ThreadPool;
+}
 class GatherNDBase {
  protected:
  struct Prepare {
@@ -32,7 +33,7 @@ class GatherNDBase {
  };  // struct Prepare
 
  template <typename Tind>
-  Status PrepareForCompute(OpKernelContext* context, Prepare& p) const;
+  Status PrepareForCompute(OpKernelContext* context, Prepare& p, concurrency::ThreadPool* tp) const;
  int64_t batch_dims_;
 };  // class GatherNDBase
 
@@ -44,8 +45,8 @@ class GatherND final : public OpKernel, protected GatherNDBase {
  Status Compute(OpKernelContext* context) const override;
 
  private:
-  Status GatherNumber(const Prepare& p) const;
-  Status GatherString(const Prepare& p) const;
+  Status GatherNumber(const Prepare& p, concurrency::ThreadPool* tp) const;
+  Status GatherString(const Prepare& p, concurrency::ThreadPool* tp) const;
 };
 
 }  // namespace onnxruntime
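The header swaps the core/platform/threadpool.h include for a forward declaration of concurrency::ThreadPool: the declarations only mention ThreadPool*, and a pointer to an incomplete type needs no class definition, so translation units that include gather_nd.h no longer pull in the thread-pool header, while gather_nd.cc (which calls TryParallelFor) includes it directly. A minimal illustration with hypothetical names:

// Hypothetical sketch: a pointer to an incomplete type can be declared,
// stored, and passed around without the full class definition.
namespace demo {
class Widget;  // forward declaration only; Widget is incomplete here

struct Holder {
  void Set(Widget* w) { w_ = w; }  // fine: only a pointer is involved
  Widget* w_ = nullptr;
};
}  // namespace demo
// Calling a member of Widget would need its full definition, so a .cc file
// doing that must include the header that defines Widget.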