Add CUDA Graph support for the CUDA plugin EP#28002
Conversation
Pull request overview
This PR adds CUDA Graph capture/replay support to the CUDA plugin Execution Provider (EP) by extending the OrtEp C API and teaching ORT core to orchestrate warm-up, capture, and replay in a provider-agnostic way.
Changes:
- Extend the OrtEp C API and the `IExecutionProvider` interface to support graph capture/replay capability querying, replay, and assignment-policy-driven validation.
- Update `InferenceSession` to validate and manage graph capture generically (policy-driven validation, bounded warm-up recursion, replay fast-path).
- Implement CUDA Graph capture/replay in the CUDA plugin EP (per-thread graph stream/context) and add Python test coverage and docs.
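The warm-up/capture/replay orchestration described above can be sketched as a simplified Python model. This is illustrative only: `GraphSession`, `min_warmup_runs`, and `MAX_WARMUP_DEPTH` are hypothetical names, not the actual ORT identifiers, and the real flow lives in `InferenceSession::Run` in C++.

```python
# Illustrative model of provider-agnostic CUDA Graph orchestration:
# run N warm-up iterations, capture on the next run, then replay.
# All names here are hypothetical, not the real ORT API.

MAX_WARMUP_DEPTH = 3  # bound on recursive warm-up runs (illustrative)

class GraphSession:
    def __init__(self, ep, min_warmup_runs=2):
        self.ep = ep                      # the execution provider
        self.min_warmup_runs = min_warmup_runs
        self.run_count = 0

    def run(self, inputs, depth=0):
        # Replay fast-path: once a graph is captured, skip normal execution.
        if self.ep.is_graph_captured():
            return self.ep.replay_graph(inputs)

        if depth > MAX_WARMUP_DEPTH:
            raise RuntimeError("graph capture did not complete within warm-up bound")

        capture_this_run = self.run_count >= self.min_warmup_runs
        if capture_this_run:
            self.ep.begin_capture()
        out = self.ep.execute(inputs)      # normal kernel execution
        if capture_this_run:
            self.ep.end_capture()
        self.run_count += 1

        # A single user-visible Run() drives warm-up runs internally; the
        # recursion is bounded so a capture failure cannot loop forever.
        if not self.ep.is_graph_captured():
            return self.run(inputs, depth + 1)
        return out
```

The key property the bound protects is termination: if an EP never reports a captured graph, the session errors out instead of recursing indefinitely.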
Reviewed changes
Copilot reviewed 27 out of 27 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/test/python/transformers/test_cuda_plugin_ep.py | Adds CUDA Graph capture/replay tests for the CUDA plugin EP. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h | Adds graph capture/replay and provider-options overrides to PluginExecutionProvider. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc | Extracts provider options from session config; bridges graph capture/replay calls to plugin via OrtEp v1.26 callbacks. |
| onnxruntime/core/session/inference_session.h | Adds bounded warm-up recursion constant and a RunImpl helper with depth tracking. |
| onnxruntime/core/session/inference_session.cc | Makes graph capture validation provider-agnostic via policy; adds bounded recursive warm-up/capture behavior and replay fast-path. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Declares graph-capture assignment policy for WebGPU EP. |
| onnxruntime/core/providers/webgpu/ep/ep.h | Adds OrtEp graph capture/replay callback declarations for WebGPU wrapper EP. |
| onnxruntime/core/providers/webgpu/ep/ep.cc | Wires OrtEp graph capture/replay callbacks through to the underlying IExecutionProvider. |
| onnxruntime/core/providers/js/js_execution_provider.h | Declares graph-capture assignment policy for JS EP. |
| onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h | Declares graph-capture assignment policy for DML EP. |
| onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h | Adds external-stream handle init support and stream ownership tracking. |
| onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc | Implements external-stream init; avoids destroying external streams; avoids synchronizing while stream is actively capturing. |
| onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h | Adds a TLS “is capturing” flag to avoid illegal CUDA calls during capture. |
| onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h | Skips cudaSetDevice() during capture to avoid CUDA capture restrictions. |
| onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h | New plugin-side CUDA graph manager API (capture/end/replay, per-annotation storage). |
| onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc | Implements plugin-side CUDA graph lifecycle and replay. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.h | Adds CUDA graph config knobs and per-thread graph context plumbing for the plugin EP. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.cc | Implements plugin EP graph capture/replay callbacks and per-thread graph stream/context management. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc | Parses enable_cuda_graph and warm-up run count from session config/provider options. |
| onnxruntime/core/providers/cuda/cuda_execution_provider.h | Adds graph-capture assignment policy override for in-tree CUDA EP. |
| include/onnxruntime/core/session/onnxruntime_ep_c_api.h | Adds OrtEp v1.26 graph capture/replay callbacks + node-assignment-policy enum and docs. |
| include/onnxruntime/core/session/onnxruntime_cxx_inline.h | Adds Env::CopyTensor() convenience wrapper over CopyTensors(). |
| include/onnxruntime/core/session/onnxruntime_cxx_api.h | Declares Env::CopyTensor() in the public C++ API. |
| include/onnxruntime/core/framework/execution_provider.h | Adds GetGraphCaptureNodeAssignmentPolicy() virtual with default strict policy. |
| docs/cuda_plugin_ep/QUICK_START.md | Updates plugin quick start/prereqs; removes outdated “no CUDA Graph” limitation. |
| docs/cuda_plugin_ep/cuda_plugin_ep_design.md | Updates design doc to reflect implemented CUDA Graph support and framework/API changes. |
| docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md | New detailed design/implementation doc for CUDA Graph in the plugin EP. |
adrianlizarraga
left a comment
Review: Add CUDA graph capture/replay support to the CUDA plugin EP.
Key changes:
- New CudaGraphManager — State machine for capture → instantiate → replay lifecycle, keyed by annotation ID for multi-graph support.
- Per-thread graph contexts — Each thread gets its own CUDA stream + graph manager via PerThreadContext, tracked with thread_local maps and weak_ptr cleanup.
- Thread-affine device stream pool — SessionState hands out per-thread stream sets so captured graphs replay on the same stream that captured them.
- OnRunStart/OnRunEnd callbacks — Begin/end capture, detect allocations during capture (warning), and replay on first captured run.
- Capture-safe kernel dispatch — Thread-local flag skips cudaSetDevice() and cudaStreamSynchronize() during capture (both prohibited by CUDA).
- Owned vs. external streams — CudaSyncStream gains owns_stream_ flag; external graph streams skip destroy and library handle creation.
- Sticky CUDA error cleanup — cudaGetLastError() calls added in mempool allocator to clear errors before they poison subsequent CUDA calls.
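The capture → instantiate → replay lifecycle keyed by annotation ID can be modeled as a small state machine. The sketch below is Python and purely illustrative (the real `CudaGraphManager` is C++ and drives `cudaStreamBeginCapture`, `cudaStreamEndCapture`, `cudaGraphInstantiate`, and `cudaGraphLaunch`); the class and method names are assumptions for the sketch.

```python
from enum import Enum, auto

class GraphState(Enum):
    CAPTURING = auto()     # stream capture in progress
    INSTANTIATED = auto()  # executable graph ready to replay

class GraphManager:
    """Illustrative per-annotation capture/replay state machine."""

    def __init__(self):
        self.graphs = {}   # annotation_id -> GraphState

    def begin_capture(self, annotation_id):
        # Real code: cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal)
        assert annotation_id not in self.graphs, "capture already started or done"
        self.graphs[annotation_id] = GraphState.CAPTURING

    def end_capture(self, annotation_id):
        # Real code: cudaStreamEndCapture -> cudaGraph_t, then cudaGraphInstantiate
        assert self.graphs.get(annotation_id) is GraphState.CAPTURING
        self.graphs[annotation_id] = GraphState.INSTANTIATED

    def replay(self, annotation_id):
        # Real code: cudaGraphLaunch(graph_exec, stream)
        assert self.graphs.get(annotation_id) is GraphState.INSTANTIATED, \
            "replay requested before a graph was captured"
        return True

    def invalidate(self, annotation_id):
        # E.g. when shapes change; real code destroys the cudaGraphExec_t.
        self.graphs.pop(annotation_id, None)
```

Keying every transition by annotation ID is what allows multiple independent graphs (multi-graph support) to coexist in one manager, each with its own defensive lifecycle checks.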
Minor Observations
- Thread safety in PerThreadContext — The thread_local + weak_ptr tracking pattern with mutex protection is correctly implemented. There's a theoretical race if an EP is destroyed while a thread simultaneously calls GetPerThreadContext(), but this is a caller-responsibility concern (EP destruction should happen after all inference calls complete), not a bug.
- cudaStreamIsCapturing failure path — In cuda_stream_plugin.cc, if cudaStreamIsCapturing fails, execution falls through to cudaStreamSynchronize. If the stream IS capturing and the query failed, this would error — but something is already wrong at that point, so this is acceptable.
- Device stream pool thread token design — The thread_local shared_ptr used purely for identity-based keying (with weak_ptr for pruning dead threads) is a sound pattern.
- CUDA graph capture/replay state machine — The lifecycle management (capture → replay → invalidation) looks correct with proper defensive checks.
The implementation follows ORT conventions for error handling, and the overall design is well-structured.
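The thread-local identity-token pattern noted above (a per-thread token used only for keying, with weak references pruning entries for dead threads) can be illustrated in Python. This is a sketch of the pattern, not the actual implementation, which uses `thread_local` `shared_ptr` tokens with `weak_ptr` cleanup in C++; `StreamPool` and `_Token` are hypothetical names.

```python
import threading
import weakref

class _Token:
    """Per-thread identity object; collectable once its thread exits."""
    pass

class StreamPool:
    """Hands out one 'stream set' per thread, keyed by a thread-local token."""

    def __init__(self):
        self._tls = threading.local()
        self._lock = threading.Lock()
        # Weak keys: when a thread exits and its token is garbage-collected,
        # the entry is pruned automatically instead of leaking.
        self._per_thread = weakref.WeakKeyDictionary()

    def get_stream_set(self):
        token = getattr(self._tls, "token", None)
        if token is None:
            token = _Token()
            self._tls.token = token
        with self._lock:
            if token not in self._per_thread:
                self._per_thread[token] = object()  # stands in for a stream set
            return self._per_thread[token]
```

The point of the token is identity, not data: the same thread always sees the same stream set, so a graph captured on a thread's stream is replayed on that same stream, while threads that die stop pinning pool entries.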
Description
This PR brings CUDA graph capture/replay to the CUDA plugin execution provider so plugin-based CUDA deployments can get the same reduced CPU launch overhead that the in-tree CUDA EP already supports. It also adds the ORT framework and plugin-C-API plumbing needed to let graph-enabled plugin EPs participate safely in warmup, capture, and replay, while preserving compatibility with older plugins through version-gated fallbacks.
Summary of Changes
CUDA plugin EP runtime and allocator integration
- onnxruntime/core/providers/cuda/plugin/cuda_ep.cc / cuda_ep.h
- onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc / cuda_graph_plugin.h: adds CudaGraphSet/CudaGraphManager to own captured graphs and coordinate warmup, capture, and replay by annotation ID.
- onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc / cuda_stream_plugin.h
- onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc: adds the enable_cuda_graph and min_num_runs_before_cuda_graph_capture provider/session options for the plugin EP.
- onnxruntime/core/providers/cuda/plugin/cuda_mempool_allocator_plugin.cc
- onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h
- onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h
ORT framework and plugin API support for graph replay
- include/onnxruntime/core/session/onnxruntime_ep_c_api.h: graph capture/replay callbacks alongside OnRunStart/OnRunEnd.
- include/onnxruntime/core/framework/execution_provider.h
- onnxruntime/core/session/inference_session.cc / inference_session.h
- onnxruntime/core/framework/session_state.cc / session_state.h
- onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc / ep_plugin_provider_interfaces.h
- onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
- onnxruntime/core/providers/js/js_execution_provider.h
Tests and validation coverage
- onnxruntime/test/python/transformers/test_cuda_plugin_ep.py
Documentation
- docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md
- docs/cuda_plugin_ep/cuda_plugin_ep_design.md
- docs/cuda_plugin_ep/QUICK_START.md
Testing
- Build with onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON, install the generated wheel, and deploy the CUDA plugin shared library as described in docs/cuda_plugin_ep/QUICK_START.md.
- Run python onnxruntime/test/python/transformers/test_cuda_plugin_ep.py.
- Coverage includes gpu_graph_id captures and the second-device path when multiple GPUs are available.
Motivation and Context
The CUDA plugin EP exists to decouple CUDA EP delivery from core ONNX Runtime releases, but that model only works well if important runtime optimizations are also available through the plugin path. CUDA graph replay is one of the highest-value CUDA execution optimizations because it eliminates repeated kernel-launch overhead after capture, especially for steady-state inference workloads.
Supporting that in the plugin EP required more than adding plugin-local capture code. ORT also needed a framework-level replay flow that works for plugin EPs, a plugin C API contract for graph support and node-assignment policy, and thread-affine stream reuse so captured graph resources and stream wrappers are not reused across unrelated threads. This PR packages those pieces together and documents the resulting behavior for future plugin EP work. It also depends on earlier plugin allocator work so warmup can stabilize allocations before capture begins.
Checklist