Mark stage buffers as consumed with stream callback (NVIDIA#712)
This fixes an issue where CPU buffers were marked as free
before they were consumed by the async GPU kernels of the Mixed stage
(and similarly for the Support stage); a short sketch of the hazard
follows the changed-files summary below.

Short-circuit ReleaseIdxs for invalid idxs

Introduce kInvalidIdx in queue policy

Check if indexes are valid before processing.
klecki authored and haoxintong committed Jul 16, 2019
1 parent 90d2653 commit 796a9f6
Showing 2 changed files with 89 additions and 50 deletions.
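
Before the diffs, here is a minimal sketch of the hazard being addressed (illustrative C++/CUDA only, not DALI code; the buffer sizes and names are made up): an asynchronous host-to-device copy keeps reading its pinned source buffer after the launch call returns, so handing that buffer back to the CPU stage immediately is a race.

// Illustrative sketch of the race (not DALI code): the copy below reads host_buf
// asynchronously, so the buffer must not be handed back to a producer right away.
#include <cuda_runtime_api.h>

int main() {
  cudaStream_t mixed_stream;
  cudaStreamCreate(&mixed_stream);

  float *host_buf = nullptr, *dev_buf = nullptr;
  cudaMallocHost(&host_buf, 1024 * sizeof(float));  // pinned CPU-stage output buffer
  cudaMalloc(&dev_buf, 1024 * sizeof(float));

  // Mixed-stage work: reads host_buf on the stream's timeline, after this call returns.
  cudaMemcpyAsync(dev_buf, host_buf, 1024 * sizeof(float),
                  cudaMemcpyHostToDevice, mixed_stream);

  // Pre-fix behaviour (conceptually): marking host_buf as free here would let the CPU
  // stage overwrite it while the copy may still be in flight.
  // The fix: release it only from a completion hook registered on the stream
  // (cudaStreamAddCallback in this commit), or after the stream is synchronized.
  cudaStreamSynchronize(mixed_stream);

  cudaFreeHost(host_buf);
  cudaFree(dev_buf);
  cudaStreamDestroy(mixed_stream);
  return 0;
}
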
35 changes: 16 additions & 19 deletions dali/pipeline/executor/executor.h
@@ -265,7 +265,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunCPU() {
TimeRange tr("[Executor] RunCPU");

auto support_idx = QueuePolicy::AcquireIdxs(OpType::SUPPORT);
if (exec_error_ || QueuePolicy::IsStopSignaled()) {
if (exec_error_ || QueuePolicy::IsStopSignaled() || !QueuePolicy::AreValid(support_idx)) {
QueuePolicy::ReleaseIdxs(OpType::SUPPORT, support_idx);
return;
}
@@ -295,22 +295,21 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunCPU() {
QueuePolicy::ReleaseIdxs(OpType::SUPPORT, support_idx);

auto cpu_idx = QueuePolicy::AcquireIdxs(OpType::CPU);
if (exec_error_ || QueuePolicy::IsStopSignaled()) {
if (exec_error_ || QueuePolicy::IsStopSignaled() || !QueuePolicy::AreValid(cpu_idx)) {
QueuePolicy::ReleaseIdxs(OpType::CPU, cpu_idx);
return;
}
auto queue_idx = cpu_idx;

// Run the cpu-ops in the thread pool
for (int i = 0; i < batch_size_; ++i) {
thread_pool_.DoWorkWithID(std::bind(
[this, queue_idx] (int data_idx, int tid) {
[this, cpu_idx] (int data_idx, int tid) {
TimeRange tr("[Executor] RunCPU on " + to_string(data_idx));
SampleWorkspace ws;
for (int j = 0; j < graph_->NumOp(OpType::CPU); ++j) {
OpNode &op_node = graph_->Node(OpType::CPU, j);
OperatorBase &op = *op_node.op;
WorkspacePolicy::template GetWorkspace<OpType::CPU>(queue_idx, *graph_, op_node)
WorkspacePolicy::template GetWorkspace<OpType::CPU>(cpu_idx, *graph_, op_node)
.GetSample(&ws, data_idx, tid);
TimeRange tr("[Executor] Run CPU op " + op_node.instance_name
+ " on " + to_string(data_idx),
@@ -338,18 +337,17 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunMixed() {
DeviceGuard g(device_id_);

auto mixed_idx = QueuePolicy::AcquireIdxs(OpType::MIXED);
if (exec_error_ || QueuePolicy::IsStopSignaled()) {
if (exec_error_ || QueuePolicy::IsStopSignaled() || !QueuePolicy::AreValid(mixed_idx)) {
QueuePolicy::ReleaseIdxs(OpType::MIXED, mixed_idx);
return;
}
auto queue_idx = mixed_idx;

try {
for (int i = 0; i < graph_->NumOp(OpType::MIXED); ++i) {
OpNode &op_node = graph_->Node(OpType::MIXED, i);
OperatorBase &op = *op_node.op;
typename WorkspacePolicy::template ws_t<OpType::MIXED> ws =
WorkspacePolicy::template GetWorkspace<OpType::MIXED>(queue_idx, *graph_, i);
WorkspacePolicy::template GetWorkspace<OpType::MIXED>(mixed_idx, *graph_, i);
TimeRange tr("[Executor] Run Mixed op " + op_node.instance_name,
TimeRange::kOrange);
op.Run(&ws);
@@ -366,19 +364,18 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunMixed() {
}

// Pass the work to the gpu stage
QueuePolicy::ReleaseIdxs(OpType::MIXED, mixed_idx);
QueuePolicy::ReleaseIdxs(OpType::MIXED, mixed_idx, mixed_op_stream_);
}

template <typename WorkspacePolicy, typename QueuePolicy>
void Executor<WorkspacePolicy, QueuePolicy>::RunGPU() {
TimeRange tr("[Executor] RunGPU");

auto gpu_idx = QueuePolicy::AcquireIdxs(OpType::GPU);
if (exec_error_ || QueuePolicy::IsStopSignaled()) {
if (exec_error_ || QueuePolicy::IsStopSignaled() || !QueuePolicy::AreValid(gpu_idx)) {
QueuePolicy::ReleaseIdxs(OpType::GPU, gpu_idx);
return;
}
auto queue_idx = gpu_idx;
DeviceGuard g(device_id_);

// Enforce our assumed dependency between consecutive
@@ -396,7 +393,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunGPU() {
OpNode &op_node = graph_->Node(OpType::GPU, i);
OperatorBase &op = *op_node.op;
typename WorkspacePolicy::template ws_t<OpType::GPU> ws =
WorkspacePolicy::template GetWorkspace<OpType::GPU>(queue_idx, *graph_, i);
WorkspacePolicy::template GetWorkspace<OpType::GPU>(gpu_idx, *graph_, i);
auto parent_events = ws.ParentEvents();

for (auto &event : parent_events) {
@@ -418,14 +415,14 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunGPU() {
int src_idx = graph_->NodeIdx(src_id);

// Record events for each output requested by the user
cudaEvent_t event = gpu_output_events_[i].GetEvent(queue_idx[OpType::GPU]);
cudaEvent_t event = gpu_output_events_[i].GetEvent(gpu_idx[OpType::GPU]);
if (graph_->NodeType(src_id) == OpType::MIXED) {
typename WorkspacePolicy::template ws_t<OpType::MIXED> ws =
WorkspacePolicy::template GetWorkspace<OpType::MIXED>(queue_idx, *graph_, src_idx);
WorkspacePolicy::template GetWorkspace<OpType::MIXED>(gpu_idx, *graph_, src_idx);
CUDA_CALL(cudaEventRecord(event, ws.stream()));
} else if (graph_->NodeType(src_id) == OpType::GPU) {
typename WorkspacePolicy::template ws_t<OpType::GPU> ws =
WorkspacePolicy::template GetWorkspace<OpType::GPU>(queue_idx, *graph_, src_idx);
WorkspacePolicy::template GetWorkspace<OpType::GPU>(gpu_idx, *graph_, src_idx);
CUDA_CALL(cudaEventRecord(event, ws.stream()));
} else {
DALI_FAIL("Internal error. Output node is not gpu/mixed");
@@ -439,17 +436,17 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunGPU() {
// Let the ReleaseIdx chain wake the output cv
}
// Update the ready queue to signal that all the work
// in the `queue_idx` set of output buffers has been
// in the `gpu_idx` set of output buffers has been
// issued. Notify any waiting threads.


// We do not release, but handle to used outputs
QueuePolicy::QueueOutputIdxs(gpu_idx);
QueuePolicy::QueueOutputIdxs(gpu_idx, gpu_op_stream_);

// Save the queue_idx so we can enforce the
// Save the gpu_idx so we can enforce the
// dependency between consecutive iterations
// of the gpu stage of the pipeline.
previous_gpu_queue_idx_ = queue_idx[OpType::GPU];
previous_gpu_queue_idx_ = gpu_idx[OpType::GPU];

// call any registered previously callback
if (callback_) {
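
The executor-side change above is that RunMixed and RunGPU now hand their streams (mixed_op_stream_, gpu_op_stream_) to the queue policy, so upstream buffers are only marked free once the stream has finished reading them. A minimal sketch of that deferral pattern follows (a simplified ReleaseCommand is assumed; the real one, added in the next file, also carries the policy pointer and stage):

// Minimal sketch of deferring a host-side release with cudaStreamAddCallback.
// Simplified: the real ReleaseCommand in queue_policy.h also stores the policy and stage.
#include <cuda_runtime_api.h>
#include <cstdio>

struct ReleaseCommand {
  int idx;  // queue slot to hand back to the producing stage
};

// Must match cudaStreamCallback_t: (stream, status, userData).
static void CUDART_CB release_callback(cudaStream_t, cudaError_t status, void *userData) {
  auto *cmd = static_cast<ReleaseCommand *>(userData);
  if (status == cudaSuccess) {
    std::printf("slot %d is no longer read by the GPU; safe to reuse\n", cmd->idx);
  }
  // Note: CUDA API calls are not allowed inside stream callbacks.
}

int main() {
  cudaStream_t mixed_stream;
  cudaStreamCreate(&mixed_stream);

  static ReleaseCommand cmd{0};  // storage must outlive the callback (DALI keeps it in a vector)
  // ... launch the Mixed-stage copies / kernels on mixed_stream here ...
  cudaStreamAddCallback(mixed_stream, release_callback, &cmd, 0 /* flags, must be 0 */);

  cudaStreamSynchronize(mixed_stream);
  cudaStreamDestroy(mixed_stream);
  return 0;
}
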
104 changes: 73 additions & 31 deletions dali/pipeline/executor/queue_policy.h
@@ -15,10 +15,12 @@
#ifndef DALI_PIPELINE_EXECUTOR_QUEUE_POLICY_H_
#define DALI_PIPELINE_EXECUTOR_QUEUE_POLICY_H_

#include <cuda_runtime_api.h>
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <vector>

#include "dali/pipeline/executor/queue_metadata.h"

@@ -34,6 +36,8 @@ namespace dali {
// QueueIdxs AcquireIdxs(OpType stage);
// // Finish stage and release the indexes. Not called by the last stage, as it "returns" outputs
// void ReleaseIdxs(OpType stage, QueueIdxs idxs);
// // Check if acquired indexes are valid
// bool AreValid(QueueIdxs idxs);
// // Called by the last stage - mark the Queue idxs as ready to be used as output
// void QueueOutputIdxs(QueueIdxs idxs);
// // Get the indexes of ready outputs and mark them as in_use by the user
@@ -49,6 +53,7 @@ namespace dali {

// Each stage requires ready buffers from previous stage and free buffers from current stage
struct UniformQueuePolicy {
static const int kInvalidIdx = -1;
static bool IsUniformPolicy() {
return true;
}
@@ -74,7 +79,7 @@ struct UniformQueuePolicy {
return !free_queue_.empty() || stage_work_stop_[static_cast<int>(stage)];
});
if (stage_work_stop_[static_cast<int>(stage)]) {
return QueueIdxs{-1}; // We return anything due to exec error
return QueueIdxs{kInvalidIdx}; // We return anything due to exec error
}
int queue_idx = free_queue_.front();
free_queue_.pop();
@@ -90,15 +95,23 @@ struct UniformQueuePolicy {
return QueueIdxs{queue_idx};
}

void ReleaseIdxs(OpType stage, QueueIdxs idxs) {
void ReleaseIdxs(OpType stage, QueueIdxs idxs, cudaStream_t = 0) {
if (idxs[stage] == kInvalidIdx) {
return;
}
if (HasNextStage(stage)) {
auto next_stage = NextStage(stage);
std::lock_guard<std::mutex> lock(stage_work_mutex_[static_cast<int>(next_stage)]);
stage_work_queue_[static_cast<int>(next_stage)].push(idxs[stage]);
}
}

void QueueOutputIdxs(QueueIdxs idxs) {
bool AreValid(QueueIdxs idxs) {
return idxs[OpType::SUPPORT] != kInvalidIdx && idxs[OpType::CPU] != kInvalidIdx &&
idxs[OpType::MIXED] != kInvalidIdx && idxs[OpType::GPU] != kInvalidIdx;
}

void QueueOutputIdxs(QueueIdxs idxs, cudaStream_t = 0) {
// We have to give up the elements to be occupied
{
std::lock_guard<std::mutex> lock(ready_mutex_);
@@ -115,7 +128,7 @@ struct UniformQueuePolicy {
return !ready_queue_.empty() || ready_stop_;
});
if (ready_stop_) {
return OutputIdxs{-1};
return OutputIdxs{kInvalidIdx};
}
int output_idx = ready_queue_.front();
ready_queue_.pop();
@@ -178,9 +191,21 @@ struct UniformQueuePolicy {
std::atomic<bool> ready_stop_ = {false};
};

struct SeparateQueuePolicy;

struct ReleaseCommand {
SeparateQueuePolicy *policy;
OpType stage;
int idx;
};

static void release_callback(cudaStream_t stream, cudaError_t status, void *userData);

// Ready buffers from previous stage imply that we can process corresponding buffers from current
// stage
struct SeparateQueuePolicy {
static const int kInvalidIdx = -1;

static bool IsUniformPolicy() {
return false;
}
@@ -192,6 +217,8 @@ struct SeparateQueuePolicy {
stage_free_[stage].push(i);
}
}
support_release_commands_.resize(stage_queue_depths[static_cast<int>(OpType::SUPPORT)]);
cpu_release_commands_.resize(stage_queue_depths[static_cast<int>(OpType::CPU)]);
}

QueueIdxs AcquireIdxs(OpType stage) {
Expand All @@ -207,7 +234,7 @@ struct SeparateQueuePolicy {
return !stage_ready_[previous_stage].empty() || stage_ready_stop_[previous_stage];
});
if (stage_ready_stop_[previous_stage]) {
return QueueIdxs{-1};
return QueueIdxs{kInvalidIdx};
}
// We fill in the information about all the previous stages here
result = stage_ready_[previous_stage].front();
@@ -221,7 +248,7 @@ struct SeparateQueuePolicy {
return !stage_free_[current_stage].empty() || stage_free_stop_[current_stage];
});
if (stage_free_stop_[current_stage]) {
return QueueIdxs{-1};
return QueueIdxs{kInvalidIdx};
}
// We add info about current stage
result[stage] = stage_free_[current_stage].front();
Expand All @@ -230,33 +257,42 @@ struct SeparateQueuePolicy {
return result;
}

void ReleaseIdxs(OpType stage, QueueIdxs idxs) {
void ReleaseIdxs(OpType stage, QueueIdxs idxs, cudaStream_t stage_stream = 0) {
if (idxs[stage] == kInvalidIdx) {
return;
}
int current_stage = static_cast<int>(stage);
// We have a special case for Support ops - they are set free by a GPU stage,
// during QueueOutputIdxs
if (stage != OpType::CPU) {
if (HasPreviousStage(stage)) {
ReleaseStageIdx(PreviousStage(stage), idxs);
}
// TODO(klecki) when we move to CUDA 10, we should move to cudaLaunchHostFunc
if (stage == OpType::MIXED) {
auto &command = cpu_release_commands_[idxs[OpType::CPU]];
command = ReleaseCommand{this, OpType::CPU, idxs[OpType::CPU]};
cudaStreamAddCallback(stage_stream, &release_callback, static_cast<void*>(&command), 0);
}
{
std::lock_guard<std::mutex> ready_current_lock(stage_ready_mutex_[current_stage]);
// stage_ready_[current_stage].push(idxs[stage]);
// Store the idxs up to the point of stage that we processed
stage_ready_[current_stage].push(idxs);
}
stage_ready_cv_[current_stage].notify_one();
}

void QueueOutputIdxs(QueueIdxs idxs) {
bool AreValid(QueueIdxs idxs) {
return idxs[OpType::SUPPORT] != kInvalidIdx && idxs[OpType::CPU] != kInvalidIdx &&
idxs[OpType::MIXED] != kInvalidIdx && idxs[OpType::GPU] != kInvalidIdx;
}


void QueueOutputIdxs(QueueIdxs idxs, cudaStream_t gpu_op_stream) {
{
std::lock_guard<std::mutex> ready_output_lock(ready_output_mutex_);
ready_output_queue_.push({idxs[OpType::MIXED], idxs[OpType::GPU]});
}
ready_output_cv_.notify_all();

// In case of GPU we release also the Support Op
ReleaseStageIdx(OpType::SUPPORT, idxs);
auto &command = support_release_commands_[idxs[OpType::SUPPORT]];
command = ReleaseCommand{this, OpType::SUPPORT, idxs[OpType::SUPPORT]};
cudaStreamAddCallback(gpu_op_stream, &release_callback, static_cast<void*>(&command), 0);
}

OutputIdxs UseOutputIdxs() {
Expand All @@ -267,7 +303,7 @@ struct SeparateQueuePolicy {
return !ready_output_queue_.empty() || ready_stop_;
});
if (ready_stop_) {
return OutputIdxs{-1, -1};
return OutputIdxs{kInvalidIdx, kInvalidIdx};
}
auto output_idx = ready_output_queue_.front();
ready_output_queue_.pop();
@@ -282,22 +318,12 @@ struct SeparateQueuePolicy {
// Mark the last in-use buffer as free and signal
// to waiting threads
if (!in_use_queue_.empty()) {
auto mixed_idx = static_cast<int>(OpType::MIXED);
auto gpu_idx = static_cast<int>(OpType::GPU);
// TODO(klecki): in_use_queue should be guarded, but we assume it is used only in synchronous
// python calls
auto processed = in_use_queue_.front();
in_use_queue_.pop();
{
std::lock_guard<std::mutex> lock(stage_free_mutex_[mixed_idx]);
stage_free_[mixed_idx].push(processed.mixed);
}
stage_free_cv_[mixed_idx].notify_one();
{
std::lock_guard<std::mutex> lock(stage_free_mutex_[gpu_idx]);
stage_free_[gpu_idx].push(processed.gpu);
}
stage_free_cv_[gpu_idx].notify_one();
ReleaseStageIdx(OpType::MIXED, processed.mixed);
ReleaseStageIdx(OpType::GPU, processed.gpu);
}
}

@@ -335,17 +361,23 @@ struct SeparateQueuePolicy {
}

private:
void ReleaseStageIdx(OpType stage, QueueIdxs idxs) {
friend void release_callback(cudaStream_t stream, cudaError_t status, void *userData);

void ReleaseStageIdx(OpType stage, int idx) {
auto released_stage = static_cast<int>(stage);
// We release the consumed buffer
{
std::lock_guard<std::mutex> free_lock(stage_free_mutex_[released_stage]);
stage_free_[released_stage].push(idxs[stage]);
stage_free_[released_stage].push(idx);
}
// We freed a buffer, so we notify the released stage that it can continue its work
stage_free_cv_[released_stage].notify_one();
}

void ReleaseStageIdx(OpType stage, QueueIdxs idxs) {
ReleaseStageIdx(stage, idxs[stage]);
}

static const int kOpCount = static_cast<int>(OpType::COUNT);
// For syncing free and ready buffers between stages
std::array<std::mutex, kOpCount> stage_free_mutex_;
@@ -376,8 +408,18 @@ struct SeparateQueuePolicy {

std::queue<OutputIdxs> ready_output_queue_;
std::queue<OutputIdxs> in_use_queue_;
std::vector<ReleaseCommand> support_release_commands_;
std::vector<ReleaseCommand> cpu_release_commands_;
};

// void (CUDART_CB *cudaStreamCallback_t)(cudaStream_t stream, cudaError_t status, void *userData);
void release_callback(cudaStream_t stream, cudaError_t status, void *userData) {
auto command = static_cast<ReleaseCommand*>(userData);
command->policy->ReleaseStageIdx(command->stage, command->idx);
}



} // namespace dali

#endif // DALI_PIPELINE_EXECUTOR_QUEUE_POLICY_H_
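
The release_callback above matches the cudaStreamCallback_t signature that cudaStreamAddCallback requires. The TODO in SeparateQueuePolicy::ReleaseIdxs points at CUDA 10's cudaLaunchHostFunc as a future replacement; the sketch below is a hedged illustration of that variant (not part of this commit; ReleaseOnHost and the simplified ReleaseCommand are made up for the example):

// Hypothetical CUDA 10+ variant of the release hook, as suggested by the TODO in
// SeparateQueuePolicy::ReleaseIdxs. Illustrative only; not part of this commit.
#include <cuda_runtime_api.h>
#include <cstdio>

struct ReleaseCommand {  // simplified stand-in for the struct in queue_policy.h
  int idx;
};

// cudaHostFn_t callbacks receive only userData (no stream or status arguments),
// so any error handling has to happen elsewhere.
static void CUDART_CB ReleaseOnHost(void *userData) {
  auto *cmd = static_cast<ReleaseCommand *>(userData);
  std::printf("returning queue slot %d to the free queue\n", cmd->idx);
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  static ReleaseCommand cmd{1};  // must outlive the callback
  // ... enqueue the stage's GPU work on `stream` here ...
  cudaLaunchHostFunc(stream, ReleaseOnHost, &cmd);  // runs after all prior work on `stream`

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}
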
