Skip to content

Commit

Permalink
cuda graph enhancement (#19636)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

1. add a config key in run_options to control cuda graph in runtime.
2. enhance cuda graph class to support multiple graph saving and
retrieving in one ORT session
3. provide model modification/inference example on Phi2
4. benchmark shows an average of 13% latency reduction in token
generation.



Limitation: the TRT EP and ROCm EP have not adopted this feature yet. We can
revisit this in the future.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
wangyems committed Mar 7, 2024
1 parent bff4f8b commit 72ce4de
Show file tree
Hide file tree
Showing 23 changed files with 766 additions and 177 deletions.
14 changes: 7 additions & 7 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,21 @@ class IExecutionProvider {

/**
Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
the provider. Currently only CUDA execution provider supports it.
the provider.
*/
virtual bool IsGraphCaptureEnabled() const { return false; }

/**
Indicate whether the graph has been captured and instantiated. Currently
only CUDA execution provider supports it.
Indicate whether the graph has been captured and instantiated.
*/
virtual bool IsGraphCaptured() const { return false; }
virtual bool IsGraphCaptured(int /*graph_annotation_id*/) const { return false; }

/**
Run the instantiated graph. Currently only CUDA execution provider supports
it.
Run the instantiated graph.
*/
virtual common::Status ReplayGraph() { return Status::OK(); }
virtual common::Status ReplayGraph(int /*graph_annotation_id*/) {
return Status::OK();
}

/**
Called when session creation is complete
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,10 @@ static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_

// Set RPC control latency for QNN HTP backend
static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";

// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
// The value should be an integer. If the value is not set, the default value is 0 and
// the ORT session captures only one cuda graph before another capture is requested.
// If the value is set to -1, cuda graph capture/replay is disabled in that run.
// Users are not expected to set the value to 0 as it is reserved for internal use.
static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";
74 changes: 49 additions & 25 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/common/parse_string.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "core/providers/cuda/cuda_execution_provider.h"
Expand All @@ -11,6 +12,7 @@
#include "core/providers/cuda/cuda_fwd.h"
#include "core/providers/cuda/gpu_data_transfer.h"
#include "core/providers/cuda/cuda_profiler.h"
#include "core/session/onnxruntime_run_options_config_keys.h"

#ifndef USE_CUDA_MINIMAL
#ifndef DISABLE_CONTRIB_OPS
Expand Down Expand Up @@ -190,27 +192,46 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
#endif
}

bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
// Decides whether a graph capture may start for the given annotation id.
// Capture is allowed only when both conditions hold:
//   1. the number of regular runs recorded so far has reached
//      min_num_runs_before_cuda_graph_capture_, and
//   2. IsGraphCaptureAllowedOnRun() accepts the annotation id.
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed(
    CudaGraphAnnotation_t cuda_graph_annotation_id) const {
  // Not enough regular (non-captured) runs have happened yet.
  if (regular_run_count_before_graph_capture_ < min_num_runs_before_cuda_graph_capture_) {
    return false;
  }

  return IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id);
}

void CUDAExecutionProvider::PerThreadContext::CaptureBegin() {
cuda_graph_.Reset();
cuda_graph_.CaptureBegin();
// Returns whether the given annotation id permits graph capture for this run.
// The decision is delegated entirely to cuda_graph_.
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun(
CudaGraphAnnotation_t cuda_graph_annotation_id) const {
return cuda_graph_.IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id);
}

void CUDAExecutionProvider::PerThreadContext::CaptureEnd() {
cuda_graph_.CaptureEnd();
is_graph_captured_ = true;
// Reads the cuda graph annotation id from the run options.
//
// Looks up kOrtRunOptionsConfigCudaGraphAnnotation in the run's config options.
// Returns 0 when the key is absent (one cuda graph per session behavior);
// otherwise returns the entry parsed as an int. ORT_ENFORCE fails if the entry
// cannot be parsed.
CudaGraphAnnotation_t CUDAExecutionProvider::PerThreadContext::GetCudaGraphAnnotationId(
    const onnxruntime::RunOptions& run_options) const {
  // If graph annotation is not provided, fall back to the one cuda graph per session behavior
  CudaGraphAnnotation_t cuda_graph_annotation_id = 0;

  const auto graph_annotation_str =
      run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation);
  if (!graph_annotation_str.has_value()) {
    return cuda_graph_annotation_id;
  }

  // NOTE(review): the ORT_ENFORCE arguments are kept verbatim on purpose; the macro
  // presumably embeds the condition text in its failure message - confirm before renaming.
  ORT_ENFORCE(TryParseStringWithClassicLocale<int>(*graph_annotation_str, cuda_graph_annotation_id),
              "Failed to parse the cuda graph annotation id: ",
              *graph_annotation_str);
  return cuda_graph_annotation_id;
}

// Starts a graph capture for the given annotation id by forwarding to cuda_graph_.
void CUDAExecutionProvider::PerThreadContext::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) {
cuda_graph_.CaptureBegin(cuda_graph_annotation_id);
}

// Ends the graph capture for the given annotation id by forwarding to cuda_graph_.
void CUDAExecutionProvider::PerThreadContext::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) {
cuda_graph_.CaptureEnd(cuda_graph_annotation_id);
}

bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured() const {
return is_graph_captured_;
// Returns whether cuda_graph_ reports a captured graph for the given annotation id.
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const {
return cuda_graph_.IsGraphCaptured(graph_annotation_id);
}

Status CUDAExecutionProvider::PerThreadContext::ReplayGraph() {
ORT_ENFORCE(IsGraphCaptured());
return cuda_graph_.Replay();
// Replays the graph stored under the given annotation id and returns the
// Status produced by cuda_graph_.Replay().
Status CUDAExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) {
return cuda_graph_.Replay(graph_annotation_id);
}

void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
Expand Down Expand Up @@ -386,23 +407,26 @@ Status CUDAExecutionProvider::Sync() const {
return Status::OK();
}

Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
// always set CUDA device when session::Run() in case it runs in a worker thread
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options);
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) &&
GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) {
LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model";
GetPerThreadContext().CaptureBegin();
GetPerThreadContext().CaptureBegin(cuda_graph_annotation_id);
}
return Status::OK();
}

Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
if (GetPerThreadContext().IsGraphCaptureAllowed()) {
GetPerThreadContext().CaptureEnd();
Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) {
CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options);
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) {
if (GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) {
GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id);
// CUDA work issued to a capturing stream doesn’t actually run on the GPU,
// so run the captured graph here to actually execute the work.
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id));
} else {
GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
}
Expand Down Expand Up @@ -433,12 +457,12 @@ bool CUDAExecutionProvider::IsGraphCaptureEnabled() const {
return info_.enable_cuda_graph;
}

bool CUDAExecutionProvider::IsGraphCaptured() const {
return GetPerThreadContext().IsGraphCaptured();
// Returns whether the per-thread context has a captured graph for the given annotation id.
bool CUDAExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
return GetPerThreadContext().IsGraphCaptured(graph_annotation_id);
}

Status CUDAExecutionProvider::ReplayGraph() {
return GetPerThreadContext().ReplayGraph();
// Replays the graph stored under the given annotation id via the per-thread context
// and returns the resulting Status.
Status CUDAExecutionProvider::ReplayGraph(int graph_annotation_id) {
return GetPerThreadContext().ReplayGraph(graph_annotation_id);
}

namespace cuda {
Expand Down
17 changes: 9 additions & 8 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
std::unique_ptr<profiling::EpProfiler> GetProfiler() override;

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override;
Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
Expand Down Expand Up @@ -168,11 +168,13 @@ class CUDAExecutionProvider : public IExecutionProvider {
}
}

bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
bool IsGraphCaptured() const;
Status ReplayGraph();
bool IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id);
void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id);
bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const;
Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id);
void IncrementRegularRunCountBeforeGraphCapture();

private:
Expand All @@ -192,7 +194,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
// Cuda graph with multi threads will be supported in the future, so cuda_graph_
// is put under PerThreadContext.
CUDAGraph cuda_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;

// There is chance that the second regular run allocates GPU memory for causes like:
Expand Down
89 changes: 62 additions & 27 deletions onnxruntime/core/providers/cuda/cuda_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,44 @@

namespace onnxruntime {

CUDAGraph::CUDAGraph(cudaStream_t stream) : stream_(stream) {
// Releases every stored graph by calling Clear().
CudaGraphSet::~CudaGraphSet() {
Clear();
}

void CUDAGraph::SetStream(cudaStream_t stream) {
void CudaGraphSet::Clear() {
for (auto& it : cuda_graphs_) {
CUDA_CALL_THROW(cudaGraphExecDestroy(it.second));
}
cuda_graphs_.clear();
}

// Returns true when a graph is stored under the given annotation id.
bool CudaGraphSet::Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
  return cuda_graphs_.count(cuda_graph_annotation_id) != 0;
}

// Stores graph_exec under the given annotation id.
// The id must not already be present in the set; ORT_ENFORCE fails otherwise,
// so an existing entry is never overwritten.
void CudaGraphSet::Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec) {
ORT_ENFORCE(!Contains(cuda_graph_annotation_id));
cuda_graphs_.emplace(cuda_graph_annotation_id, graph_exec);
}

// Returns the graph stored under the given annotation id.
// The id must be present in the set; ORT_ENFORCE fails otherwise.
cudaGraphExec_t CudaGraphSet::Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
  ORT_ENFORCE(Contains(cuda_graph_annotation_id));

  // The enforce above guarantees the lookup succeeds.
  const auto found = cuda_graphs_.find(cuda_graph_annotation_id);
  return found->second;
}

// Constructs a manager that uses the given stream; the handle is only stored in stream_.
CUDAGraphManager::CUDAGraphManager(cudaStream_t stream) : stream_(stream) {
}

// Replaces the stream handle stored in stream_ with the given one.
void CUDAGraphManager::SetStream(cudaStream_t stream) {
stream_ = stream;
}

void CUDAGraph::CaptureBegin() {
ORT_ENFORCE(!has_graph_exec_,
"This cuda graph has already captured a graph. "
"Create a new instance to capture a new graph.");
void CUDAGraphManager::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) {
ORT_ENFORCE(IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id));

ORT_ENFORCE(!cuda_graph_set_.Contains(cuda_graph_annotation_id),
"Trying to capture a graph with annotation id ", cuda_graph_annotation_id,
" that already used. Please use a different annotation id.");

CUDA_CALL_THROW(cudaStreamSynchronize(stream_));
// For now cuda graph can only work with a single thread. In the future, we
Expand All @@ -29,40 +56,48 @@ void CUDAGraph::CaptureBegin() {
CUDA_CALL_THROW(cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal));
}

void CUDAGraph::CaptureEnd() {
CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph_));
if (graph_ == NULL) {
void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) {
cudaGraph_t graph = NULL;
CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph));
if (graph == NULL) {
ORT_THROW("CUDAGraph::CaptureEnd: graph_ is NULL");
}

has_graph_ = true;
CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
has_graph_exec_ = true;
CUDA_CALL_THROW(cudaGraphDestroy(graph_));
has_graph_ = false;
cudaGraphExec_t graph_exec = NULL;
CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
CUDA_CALL_THROW(cudaGraphDestroy(graph));

// Currently all the captured graphs will be tied to the session's lifecycle
// TODO(wy): Add an interface to free captured graphs
cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec);
}

Status CUDAGraph::Replay() {
Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) {
// Although this function is not thread safe, the lock is not needed here because
// CUDA EP maintains a separate cuda graph per thread
LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_;
CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec_, stream_));
LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id "
<< cuda_graph_annotation_id;

cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id);
CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_));

CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
return Status::OK();
}

void CUDAGraph::Reset() {
if (has_graph_) {
CUDA_CALL_THROW(cudaGraphDestroy(graph_));
has_graph_ = false;
}
if (has_graph_exec_) {
CUDA_CALL_THROW(cudaGraphExecDestroy(graph_exec_));
has_graph_exec_ = false;
}
// Returns whether capture is permitted for a run carrying the given annotation id.
// Only kCudaGraphAnnotationSkip is rejected; every other id is allowed.
bool CUDAGraphManager::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
  const bool skip_requested = (cuda_graph_annotation_id == kCudaGraphAnnotationSkip);
  return !skip_requested;
}

// Returns true when cuda_graph_set_ already holds a graph for the given annotation id.
bool CUDAGraphManager::IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
return cuda_graph_set_.Contains(cuda_graph_annotation_id);
}

// Discards all graphs held by this manager by clearing cuda_graph_set_.
void CUDAGraphManager::Reset() {
cuda_graph_set_.Clear();
}

CUDAGraph::~CUDAGraph() {
// Calls Reset() on destruction so the stored graphs are cleared.
CUDAGraphManager::~CUDAGraphManager() {
Reset();
}

Expand Down
48 changes: 35 additions & 13 deletions onnxruntime/core/providers/cuda/cuda_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,55 @@

#pragma once

#include <unordered_map>

#include "core/common/common.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/cuda/cuda_pch.h"

namespace onnxruntime {

using CaptureId_t = unsigned long long;
using CudaGraphAnnotation_t = int;
using CudaGraphSet_t = std::unordered_map<CudaGraphAnnotation_t, cudaGraphExec_t>;

constexpr CudaGraphAnnotation_t kCudaGraphAnnotationSkip = -1;
constexpr CudaGraphAnnotation_t kCudaGraphAnnotationDefault = 0;

struct CudaGraphSet {
CudaGraphSet(){};
~CudaGraphSet();

struct CUDAGraph {
CUDAGraph(){};
CUDAGraph(cudaStream_t stream);
~CUDAGraph();
void Clear();
bool Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
void Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec);
cudaGraphExec_t Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const;

private:
CudaGraphSet_t cuda_graphs_;
};

struct CUDAGraphManager {
CUDAGraphManager(){};
CUDAGraphManager(cudaStream_t stream);
~CUDAGraphManager();

void SetStream(cudaStream_t stream);
void CaptureBegin();
void CaptureEnd();
Status Replay();
void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id);
void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id);
Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id);

void Reset();

private:
cudaGraph_t graph_ = NULL;
cudaGraphExec_t graph_exec_ = NULL;
bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const;

bool has_graph_ = false;
bool has_graph_exec_ = false;
private:
CudaGraphSet cuda_graph_set_;
CudaGraphAnnotation_t cuda_graph_annotation_id_ = kCudaGraphAnnotationDefault;

cudaStream_t stream_ = nullptr; // Does not own the stream
};

using CUDAGraph = CUDAGraphManager;

} // namespace onnxruntime

0 comments on commit 72ce4de

Please sign in to comment.