Add direct operator calls in debug mode (NVIDIA#3734)
* Add base backend implementation of eager operators
* Add backend implementation of PipelineDebug managing backend operators
* Add OperatorManager util class for debug mode
* Replace minipipelines in debug mode with eager operators

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
ksztenderski authored and cyyever committed May 13, 2022
1 parent b289133 commit f960d59
Showing 10 changed files with 999 additions and 245 deletions.
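
For context, a minimal sketch (not part of this commit) of what the change enables: instead of routing each debug-mode call through a throwaway single-operator pipeline, a binding can build an OpSpec and execute it directly through the EagerOperator class added in dali/pipeline/operator/eager_operator.h below. The operator name, argument values, and the RunEagerly helper are illustrative assumptions, not code from this PR.

#include "dali/pipeline/operator/eager_operator.h"
#include "dali/pipeline/operator/op_spec.h"

namespace dali {

// Hypothetical helper: run a single CPU operator eagerly on already-built inputs.
std::vector<std::shared_ptr<TensorList<CPUBackend>>> RunEagerly(
    const std::vector<std::shared_ptr<TensorList<CPUBackend>>> &inputs) {
  OpSpec spec("Copy");               // example operator name
  spec.AddArg("max_batch_size", 8);  // read by EagerOperator's constructor
  spec.AddArg("device", "cpu");
  spec.AddInput("data", "cpu");
  spec.AddOutput("out", "cpu");

  EagerOperator<CPUBackend> op(spec);
  // Dispatches to the CPUBackend specialization, which uses the shared thread pool.
  return op.Run<CPUBackend, CPUBackend>(inputs, {});
}

}  // namespace dali

GPU and mixed backends follow the same pattern but run on a shared CUDA stream; see the Run specializations in the new header.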
10 changes: 1 addition & 9 deletions dali/pipeline/executor/executor.h
@@ -38,6 +38,7 @@
#include "dali/pipeline/graph/op_graph_verifier.h"
#include "dali/pipeline/operator/batch_size_provider.h"
#include "dali/pipeline/operator/common.h"
#include "dali/pipeline/util/batch_utils.h"
#include "dali/pipeline/util/event_pool.h"
#include "dali/pipeline/util/stream_pool.h"
#include "dali/pipeline/util/thread_pool.h"
@@ -361,15 +362,6 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy {
WorkspacePolicy ws_policy_;

private:
template <typename InputRef>
static bool SetDefaultLayoutIfNeeded(InputRef &in, const OpSchema &schema, int in_idx) {
if (!in.GetLayout().empty()) return false;
auto default_layout = schema.GetInputLayout(in_idx, in.shape().sample_dim(), in.GetLayout());
if (default_layout.empty()) return false;
in.SetLayout(default_layout);
return true;
}

template <typename Workspace>
void RunHelper(OpNode &op_node, Workspace &ws);

239 changes: 239 additions & 0 deletions dali/pipeline/operator/eager_operator.h
@@ -0,0 +1,239 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_PIPELINE_OPERATOR_EAGER_OPERATOR_H_
#define DALI_PIPELINE_OPERATOR_EAGER_OPERATOR_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "dali/core/cuda_stream_pool.h"
#include "dali/pipeline/data/tensor_list.h"
#include "dali/pipeline/operator/op_spec.h"
#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/util/backend2workspace_map.h"
#include "dali/pipeline/util/batch_utils.h"
#include "dali/pipeline/util/thread_pool.h"
#include "dali/pipeline/workspace/workspace.h"

namespace dali {

template <typename Backend>
std::shared_ptr<TensorList<Backend>> AsTensorList(const std::shared_ptr<TensorList<Backend>> &in) {
return in;
}

template <typename Backend>
std::shared_ptr<TensorList<Backend>> AsTensorList(
const std::shared_ptr<TensorVector<Backend>> &in) {
if (in->IsContiguous()) {
// Filled contiguous TensorVector, we can return TensorList directly.
return in->AsTensorList(false);
}

auto tl = std::make_shared<TensorList<Backend>>();
tl->Copy(*in);
return tl;
}

/**
* @brief Direct operator providing eager execution of an operator in Run.
*/
template <typename Backend>
class DLL_PUBLIC EagerOperator {
public:
DLL_PUBLIC inline EagerOperator(const OpSpec &spec)
: batch_size_(spec.GetArgument<int>("max_batch_size")),
op_spec_(spec),
op_(InstantiateOperator(spec)) {
num_outputs_ = op_spec_.GetSchema().CalculateOutputs(op_spec_) +
op_spec_.GetSchema().CalculateAdditionalOutputs(op_spec_);
}

// Runs operator using shared thread pool and shared CUDA stream.
template <typename InBackend, typename OutBackend>
DLL_PUBLIC std::vector<std::shared_ptr<TensorList<OutBackend>>> Run(
const std::vector<std::shared_ptr<TensorList<InBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs) {
DALI_FAIL("Unsupported backends in EagerOperator.Run().");
}

// Runs operator using specified thread pool.
template <typename InBackend, typename OutBackend>
DLL_PUBLIC std::vector<std::shared_ptr<TensorList<OutBackend>>> Run(
const std::vector<std::shared_ptr<TensorList<InBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs,
ThreadPool *tp) {
DALI_FAIL("Unsupported backends in EagerOperator.Run() with thread pool.");
}

// Runs operator using specified CUDA stream.
template <typename InBackend, typename OutBackend>
DLL_PUBLIC std::vector<std::shared_ptr<TensorList<OutBackend>>> Run(
const std::vector<std::shared_ptr<TensorList<InBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs,
CUDAStreamLease &cuda_stream) {
DALI_FAIL("Unsupported backends in EagerOperator.Run() with CUDA stream");
}

// Update shared thread pool used for all direct operators.
DLL_PUBLIC inline static void UpdateThreadPool(int num_threads, int device_id,
bool set_affinity) {
shared_thread_pool = std::make_unique<ThreadPool>(num_threads, device_id, set_affinity);
}

// Update shared CUDA stream used for all direct operators.
DLL_PUBLIC inline static void UpdateCudaStream(int device_id) {
if (device_id != CPU_ONLY_DEVICE_ID) {
DeviceGuard g(device_id);
shared_cuda_stream = CUDAStreamPool::instance().Get(device_id);
}
}

private:
template <typename InBackend, typename OutBackend, typename WSInputType, typename WSOutputType>
std::vector<std::shared_ptr<TensorList<OutBackend>>> RunImpl(
const std::vector<std::shared_ptr<TensorList<InBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs);

int batch_size_;
size_t num_outputs_;
workspace_t<Backend> ws_;
OpSpec op_spec_;
std::unique_ptr<OperatorBase> op_;

static CUDAStreamLease shared_cuda_stream;
static std::unique_ptr<ThreadPool> shared_thread_pool;
};

template <>
template <>
std::vector<std::shared_ptr<TensorList<CPUBackend>>> EagerOperator<CPUBackend>::Run(
const std::vector<std::shared_ptr<TensorList<CPUBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs,
ThreadPool *thread_pool) {
ws_.Clear();
ws_.SetThreadPool(thread_pool);

return RunImpl<CPUBackend, CPUBackend, TensorVector<CPUBackend>, TensorVector<CPUBackend>>(
inputs, kwargs);
}

template <>
template <>
std::vector<std::shared_ptr<TensorList<GPUBackend>>> EagerOperator<GPUBackend>::Run(
const std::vector<std::shared_ptr<TensorList<GPUBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs,
CUDAStreamLease &cuda_stream) {
ws_.Clear();
ws_.set_stream(cuda_stream);
auto output = RunImpl<GPUBackend, GPUBackend, TensorList<GPUBackend>, TensorList<GPUBackend>>(
inputs, kwargs);
CUDA_CALL(cudaStreamSynchronize(cuda_stream));
return output;
}

template <>
template <>
std::vector<std::shared_ptr<TensorList<GPUBackend>>> EagerOperator<MixedBackend>::Run(
const std::vector<std::shared_ptr<TensorList<CPUBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs,
CUDAStreamLease &cuda_stream) {
ws_.Clear();
ws_.set_stream(cuda_stream);
auto output = RunImpl<CPUBackend, GPUBackend, TensorVector<CPUBackend>, TensorList<GPUBackend>>(
inputs, kwargs);
CUDA_CALL(cudaStreamSynchronize(cuda_stream));
return output;
}

template <>
template <>
std::vector<std::shared_ptr<TensorList<CPUBackend>>> EagerOperator<CPUBackend>::Run(
const std::vector<std::shared_ptr<TensorList<CPUBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs) {
return Run<CPUBackend, CPUBackend>(inputs, kwargs, shared_thread_pool.get());
}

template <>
template <>
std::vector<std::shared_ptr<TensorList<GPUBackend>>> EagerOperator<GPUBackend>::Run(
const std::vector<std::shared_ptr<TensorList<GPUBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs) {
return Run<GPUBackend, GPUBackend>(inputs, kwargs, shared_cuda_stream);
}

template <>
template <>
std::vector<std::shared_ptr<TensorList<GPUBackend>>> EagerOperator<MixedBackend>::Run(
const std::vector<std::shared_ptr<TensorList<CPUBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs) {
return Run<CPUBackend, GPUBackend>(inputs, kwargs, shared_cuda_stream);
}

template <typename Backend>
template <typename InBackend, typename OutBackend, typename WSInputType, typename WSOutputType>
std::vector<std::shared_ptr<TensorList<OutBackend>>> EagerOperator<Backend>::RunImpl(
const std::vector<std::shared_ptr<TensorList<InBackend>>> &inputs,
const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs) {
// Convert and add inputs to the workspace.
for (size_t in_idx = 0; in_idx < inputs.size(); ++in_idx) {
auto tensor_in = std::make_shared<WSInputType>();
tensor_in->ShareData(*inputs[in_idx]);
SetDefaultLayoutIfNeeded(*tensor_in, op_spec_.GetSchema(), in_idx);
ws_.AddInput(tensor_in);
}

for (auto &arg : kwargs) {
ws_.AddArgumentInput(arg.first, arg.second);
}

std::vector<OutputDesc> output_desc{};
std::vector<std::shared_ptr<TensorList<OutBackend>>> outputs{};

outputs.reserve(num_outputs_);

for (size_t i = 0; i < num_outputs_; ++i) {
ws_.AddOutput(std::make_shared<WSOutputType>(batch_size_));
}

ws_.SetBatchSizes(batch_size_);

// Setup outputs.
if (op_->Setup(output_desc, ws_) && op_->CanInferOutputs()) {
for (size_t i = 0; i < num_outputs_; ++i) {
ws_.template Output<OutBackend>(i).Resize(output_desc[i].shape, output_desc[i].type);
}
}

op_->Run(ws_);

for (size_t i = 0; i < num_outputs_; ++i) {
outputs.push_back(AsTensorList<OutBackend>(ws_.template OutputPtr<OutBackend>(i)));
}

return outputs;
}

template <typename Backend>
std::unique_ptr<ThreadPool> EagerOperator<Backend>::shared_thread_pool =
std::make_unique<ThreadPool>(1, CPU_ONLY_DEVICE_ID, false);

template <typename Backend>
CUDAStreamLease EagerOperator<Backend>::shared_cuda_stream{};

} // namespace dali

#endif // DALI_PIPELINE_OPERATOR_EAGER_OPERATOR_H_
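
A hedged sketch of the GPU path, using the declarations above (again, not part of the commit): the shared thread pool and CUDA stream are configured once, after which GPU-backend operators can be run eagerly. The operator name, argument values, and the ConfigureAndRunGpu wrapper are assumptions for illustration.

#include "dali/pipeline/operator/eager_operator.h"

namespace dali {

// Hypothetical wrapper: assumes `inputs` (GPU data) and `kwargs` (CPU argument
// inputs) were prepared by the caller.
std::vector<std::shared_ptr<TensorList<GPUBackend>>> ConfigureAndRunGpu(
    const std::vector<std::shared_ptr<TensorList<GPUBackend>>> &inputs,
    const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs) {
  // One-time setup of the resources shared by eager operators of this backend.
  EagerOperator<GPUBackend>::UpdateThreadPool(/*num_threads=*/4, /*device_id=*/0,
                                              /*set_affinity=*/false);
  EagerOperator<GPUBackend>::UpdateCudaStream(/*device_id=*/0);

  OpSpec spec("Copy");               // example operator name
  spec.AddArg("max_batch_size", 8);
  spec.AddArg("device", "gpu");
  spec.AddInput("data", "gpu");
  spec.AddOutput("out", "gpu");

  EagerOperator<GPUBackend> op(spec);
  // Dispatches to the GPUBackend specialization above: the shared CUDA stream is
  // set on the workspace and cudaStreamSynchronize is called after the run.
  return op.Run<GPUBackend, GPUBackend>(inputs, kwargs);
}

}  // namespace dali

The Mixed-backend specialization follows the same flow, taking CPU inputs and producing GPU outputs.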
21 changes: 1 addition & 20 deletions dali/pipeline/operator/operator.h
@@ -33,6 +33,7 @@
#include "dali/pipeline/operator/op_schema.h"
#include "dali/pipeline/operator/op_spec.h"
#include "dali/pipeline/operator/operator_factory.h"
#include "dali/pipeline/util/batch_utils.h"
#include "dali/pipeline/util/backend2workspace_map.h"
#include "dali/pipeline/workspace/device_workspace.h"
#include "dali/pipeline/workspace/sample_workspace.h"
@@ -61,26 +62,6 @@ const std::string kSeed = "seed"; // NOLINT
const std::string kDtype = "dtype"; // NOLINT
} // namespace arg_names

/**
* @brief Verifies that the inputs in the workspace satisfy the layout
* constraints imposed by the schema.
*/
template <typename Workspace>
inline void CheckInputLayouts(const Workspace &ws, const OpSpec &spec) {
auto &schema = spec.GetSchema();
for (int i = 0; i < spec.NumRegularInput(); ++i) {
if (ws.template InputIsType<CPUBackend>(i)) {
auto &input = ws.template Input<CPUBackend>(i);
(void) schema.GetInputLayout(i, input.shape().sample_dim(), input.GetLayout());
} else if (ws.template InputIsType<GPUBackend>(i)) {
auto &input = ws.template Input<GPUBackend>(i);
(void) schema.GetInputLayout(i, input.shape().sample_dim(), input.GetLayout());
} else {
DALI_FAIL(make_string("Input ", i, " has an unknown backend"));
}
}
}

/**
* @brief Baseclass for the basic unit of computation in the pipeline.
*
(Diffs for the remaining 7 changed files are not shown.)
