
Add CUDA Graph support for the CUDA plugin EP#28002

Merged
tianleiwu merged 13 commits into main from tlwu/20260407/cuda_plugin_cuda_graph
Apr 10, 2026

Conversation

@tianleiwu
Contributor

@tianleiwu tianleiwu commented Apr 7, 2026

Description

This PR brings CUDA graph capture/replay to the CUDA plugin execution provider so plugin-based CUDA deployments can get the same reduced CPU launch overhead that the in-tree CUDA EP already supports. It also adds the ORT framework and plugin-C-API plumbing needed to let graph-enabled plugin EPs participate safely in warmup, capture, and replay, while preserving compatibility with older plugins through version-gated fallbacks.
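For readers new to the feature, the core idea can be sketched in a few lines of stdlib-only Python: during capture, kernel launches are recorded rather than executed, and afterward a single replay call re-issues the whole sequence, which is what eliminates per-launch CPU overhead. All names below are illustrative stand-ins for the CUDA runtime APIs (`cudaStreamBeginCapture`/`cudaGraphLaunch`), not this PR's actual code.

```python
from enum import Enum, auto


class GraphState(Enum):
    NOT_CAPTURED = auto()
    CAPTURING = auto()
    CAPTURED = auto()


class ToyCudaGraph:
    """Toy stand-in for CUDA stream-capture semantics (hypothetical API)."""

    def __init__(self):
        self.state = GraphState.NOT_CAPTURED
        self._ops = []  # recorded "kernel launches"

    def begin_capture(self):
        assert self.state is GraphState.NOT_CAPTURED
        self.state = GraphState.CAPTURING

    def launch(self, op):
        if self.state is GraphState.CAPTURING:
            self._ops.append(op)  # recorded, not executed, during capture
        else:
            op()  # eager launch outside capture

    def end_capture(self):
        assert self.state is GraphState.CAPTURING
        self.state = GraphState.CAPTURED

    def replay(self):
        assert self.state is GraphState.CAPTURED
        for op in self._ops:  # one replay call re-issues the whole sequence
            op()


results = []
g = ToyCudaGraph()
g.begin_capture()
g.launch(lambda: results.append("kernel_a"))
g.launch(lambda: results.append("kernel_b"))
g.end_capture()
g.replay()  # results == ["kernel_a", "kernel_b"]
```

The real implementation must additionally keep memory addresses stable between capture and replay, which is why the PR's allocator and warmup changes exist.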

Summary of Changes

CUDA plugin EP runtime and allocator integration

| File | Change |
|------|--------|
| `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implements plugin-side graph capture lifecycle callbacks, per-thread graph context management, graph replay, and stream selection for graph-enabled runs. |
| `onnxruntime/core/providers/cuda/plugin/cuda_ep.h` | Adds CUDA graph configuration/state to the plugin EP, including per-thread graph context ownership. |
| `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc` | Adds `CudaGraphSet`/`CudaGraphManager` to own captured graphs and coordinate warmup, capture, and replay by annotation ID. |
| `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h` | Declares the new graph manager types and graph-related constants. |
| `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc` | Adds external-stream wrapping so graph-enabled runs can reuse the thread’s graph stream without taking ownership of it. |
| `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h` | Declares the external-stream initialization path and stream ownership tracking. |
| `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Parses `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` provider/session options for the plugin EP. |
| `onnxruntime/core/providers/cuda/plugin/cuda_mempool_allocator_plugin.cc` | Updates allocator behavior needed for CUDA native mempool compatibility during graph capture/replay. |
| `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | Adjusts plugin kernel/device helpers used by the graph-enabled execution path. |
| `onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h` | Adds supporting helpers used by the plugin CUDA graph flow. |
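The lifecycle that `CudaGraphSet`/`CudaGraphManager` coordinate — count warmup runs per annotation ID, capture once the threshold is reached, then replay — can be approximated with a small stdlib-only Python sketch. The class and method names here are hypothetical simplifications of the plugin's C++ types, not its real API.

```python
class GraphSet:
    """Toy sketch: at most one captured graph per annotation ID."""

    def __init__(self):
        self._graphs = {}

    def contains(self, annotation_id: int) -> bool:
        return annotation_id in self._graphs

    def put(self, annotation_id: int, graph) -> None:
        if annotation_id in self._graphs:
            raise ValueError(f"graph already captured for annotation {annotation_id}")
        self._graphs[annotation_id] = graph

    def get(self, annotation_id: int):
        return self._graphs[annotation_id]


class GraphManager:
    """Decides, per annotation ID, whether a run should warm up, capture, or replay."""

    def __init__(self, min_runs_before_capture: int):
        self.min_runs = min_runs_before_capture
        self.run_counts = {}  # annotation_id -> completed warmup runs
        self.graphs = GraphSet()

    def on_run_start(self, annotation_id: int) -> str:
        if self.graphs.contains(annotation_id):
            return "replay"
        n = self.run_counts.get(annotation_id, 0)
        return "capture" if n >= self.min_runs else "warmup"

    def on_run_end(self, annotation_id: int, captured_graph=None) -> None:
        if captured_graph is not None:
            self.graphs.put(annotation_id, captured_graph)  # capture completed
        else:
            self.run_counts[annotation_id] = self.run_counts.get(annotation_id, 0) + 1


m = GraphManager(min_runs_before_capture=1)
assert m.on_run_start(0) == "warmup"
m.on_run_end(0)
assert m.on_run_start(0) == "capture"
m.on_run_end(0, captured_graph=object())
assert m.on_run_start(0) == "replay"
```

Keying by annotation ID is what allows a single session to hold several captured graphs, as exercised by the multi-`gpu_graph_id` tests below.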

ORT framework and plugin API support for graph replay

| File | Change |
|------|--------|
| `include/onnxruntime/core/session/onnxruntime_ep_c_api.h` | Documents and extends the plugin EP contract for graph-enabled runs, including replay behavior relative to `OnRunStart`/`OnRunEnd`. |
| `include/onnxruntime/core/framework/execution_provider.h` | Adds graph-capture node-assignment policy support to the execution provider interface. |
| `onnxruntime/core/session/inference_session.cc` | Generalizes the session replay path and warmup/capture retry loop so ORT can drive graph replay for graph-capable EPs. |
| `onnxruntime/core/session/inference_session.h` | Updates replay-related messaging and supporting declarations for the new run flow. |
| `onnxruntime/core/framework/session_state.cc` | Makes device-stream collection reuse thread-affine so warmup/capture/replay reuse stays on the owning thread. |
| `onnxruntime/core/framework/session_state.h` | Adds supporting state for the thread-affine stream collection pool. |
| `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc` | Bridges the new graph callbacks, hardens validation of plugin graph support, and exposes effective plugin provider options gathered from session config. |
| `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h` | Stores provider options and declares the new accessor/graph bridge behavior. |
| `onnxruntime/core/providers/webgpu/webgpu_execution_provider.h` | Aligns graph-capture policy support with the new execution-provider interface. |
| `onnxruntime/core/providers/js/js_execution_provider.h` | Aligns graph-capture policy support with the new execution-provider interface. |
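The generalized session run flow — warmup runs until the EP reports a captured graph, a bounded retry budget, and a replay fast path on subsequent runs — can be sketched in stdlib-only Python. The constant, the method names, and the toy EP below are assumptions for illustration, not ORT's actual interfaces.

```python
MAX_WARMUP_RUNS = 8  # hypothetical bound, mirroring the bounded-recursion constant


def run(ep, feeds, depth=0):
    """Toy session run loop: warm up until the EP captures, then replay."""
    if ep.is_graph_captured():
        return ep.replay(feeds)  # fast path once a graph exists
    if depth >= MAX_WARMUP_RUNS:
        raise RuntimeError("graph capture did not complete within warmup budget")
    out = ep.execute(feeds)  # normal execution; the EP may capture during this run
    if ep.graph_capture_enabled() and not ep.is_graph_captured():
        return run(ep, feeds, depth + 1)  # retry: another warmup run toward capture
    return out


class ToyEP:
    """Stand-in EP that finishes capture after a fixed number of runs."""

    def __init__(self, runs_needed):
        self.runs_needed = runs_needed
        self.runs = 0
        self.captured = False

    def graph_capture_enabled(self):
        return True

    def is_graph_captured(self):
        return self.captured

    def execute(self, feeds):
        self.runs += 1
        if self.runs >= self.runs_needed:
            self.captured = True
        return ("eager", feeds)

    def replay(self, feeds):
        return ("replay", feeds)


ep = ToyEP(runs_needed=2)
first = run(ep, {"x": 1})   # drives warmup runs until capture completes
second = run(ep, {"x": 1})  # takes the replay fast path
```

Bounding the retry depth matters: an EP that never reports a captured graph must produce a clear error rather than recurse indefinitely inside `Run`.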

Tests and validation coverage

| File | Change |
|------|--------|
| `onnxruntime/test/python/transformers/test_cuda_plugin_ep.py` | Adds end-to-end CUDA graph tests for warmup/capture/replay, replay after input updates, CUDA mempool mode, multiple graph annotation IDs, multi-GPU/device-id coverage, and a simple Add model. |

Documentation

| File | Change |
|------|--------|
| `docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md` | Adds a dedicated design/implementation document covering architecture, lifecycle, allocator interaction, concurrency, and verification guidance. |
| `docs/cuda_plugin_ep/cuda_plugin_ep_design.md` | Updates the broader plugin EP design doc to reflect that CUDA graph support is implemented and documents the framework-level changes. |
| `docs/cuda_plugin_ep/QUICK_START.md` | Updates quick-start/testing guidance and removes the outdated “no CUDA Graph support” limitation. |

Testing

  • Build ONNX Runtime with `onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON`, install the generated wheel, and deploy the CUDA plugin shared library as described in `docs/cuda_plugin_ep/QUICK_START.md`.
  • Run `python onnxruntime/test/python/transformers/test_cuda_plugin_ep.py`.
  • Pay particular attention to the new CUDA graph scenarios in that suite: warmup/capture/replay, replay after in-place input updates, CUDA mempool mode, multiple `gpu_graph_id` captures, and the second-device path when multiple GPUs are available.
  • Verify backward compatibility by confirming that older plugins still load safely through the version-gated graph callback bridge, and that graph-disabled runs continue through the normal execution path.

Motivation and Context

The CUDA plugin EP exists to decouple CUDA EP delivery from core ONNX Runtime releases, but that model only works well if important runtime optimizations are also available through the plugin path. CUDA graph replay is one of the highest-value CUDA execution optimizations because it eliminates repeated kernel-launch overhead after capture, especially for steady-state inference workloads.

Supporting that in the plugin EP required more than adding plugin-local capture code. ORT also needed a framework-level replay flow that works for plugin EPs, a plugin C API contract for graph support and node-assignment policy, and thread-affine stream reuse so captured graph resources and stream wrappers are not reused across unrelated threads. This PR packages those pieces together and documents the resulting behavior for future plugin EP work. It also depends on earlier plugin allocator work so warmup can stabilize allocations before capture begins.

Checklist

  • Tests added/updated
  • Documentation updated (if applicable)
  • No breaking changes (or documented in description)

@tianleiwu tianleiwu marked this pull request as draft April 7, 2026 22:15
@tianleiwu tianleiwu marked this pull request as ready for review April 9, 2026 05:22
Contributor

Copilot AI left a comment


Pull request overview

This PR adds CUDA Graph capture/replay support to the CUDA plugin Execution Provider (EP) by extending the OrtEp C API and teaching ORT core to orchestrate warm-up, capture, and replay in a provider-agnostic way.

Changes:

  • Extend the OrtEp C API and IExecutionProvider interface to support graph capture/replay capability querying, replay, and assignment-policy-driven validation.
  • Update InferenceSession to validate and manage graph capture generically (policy-driven validation, bounded warm-up recursion, replay fast-path).
  • Implement CUDA Graph capture/replay in the CUDA plugin EP (per-thread graph stream/context) and add Python test coverage + docs.

Reviewed changes

Copilot reviewed 27 out of 27 changed files in this pull request and generated 4 comments.

| File | Description |
|------|-------------|
| `onnxruntime/test/python/transformers/test_cuda_plugin_ep.py` | Adds CUDA Graph capture/replay tests for the CUDA plugin EP. |
| `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h` | Adds graph capture/replay and provider-options overrides to `PluginExecutionProvider`. |
| `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc` | Extracts provider options from session config; bridges graph capture/replay calls to plugin via OrtEp v1.26 callbacks. |
| `onnxruntime/core/session/inference_session.h` | Adds bounded warm-up recursion constant and a `RunImpl` helper with depth tracking. |
| `onnxruntime/core/session/inference_session.cc` | Makes graph capture validation provider-agnostic via policy; adds bounded recursive warm-up/capture behavior and replay fast-path. |
| `onnxruntime/core/providers/webgpu/webgpu_execution_provider.h` | Declares graph-capture assignment policy for WebGPU EP. |
| `onnxruntime/core/providers/webgpu/ep/ep.h` | Adds OrtEp graph capture/replay callback declarations for WebGPU wrapper EP. |
| `onnxruntime/core/providers/webgpu/ep/ep.cc` | Wires OrtEp graph capture/replay callbacks through to the underlying `IExecutionProvider`. |
| `onnxruntime/core/providers/js/js_execution_provider.h` | Declares graph-capture assignment policy for JS EP. |
| `onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h` | Declares graph-capture assignment policy for DML EP. |
| `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h` | Adds external-stream handle init support and stream ownership tracking. |
| `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc` | Implements external-stream init; avoids destroying external streams; avoids synchronizing while stream is actively capturing. |
| `onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h` | Adds a TLS “is capturing” flag to avoid illegal CUDA calls during capture. |
| `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | Skips `cudaSetDevice()` during capture to avoid CUDA capture restrictions. |
| `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h` | New plugin-side CUDA graph manager API (capture/end/replay, per-annotation storage). |
| `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc` | Implements plugin-side CUDA graph lifecycle and replay. |
| `onnxruntime/core/providers/cuda/plugin/cuda_ep.h` | Adds CUDA graph config knobs and per-thread graph context plumbing for the plugin EP. |
| `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implements plugin EP graph capture/replay callbacks and per-thread graph stream/context management. |
| `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Parses `enable_cuda_graph` and warm-up run count from session config/provider options. |
| `onnxruntime/core/providers/cuda/cuda_execution_provider.h` | Adds graph-capture assignment policy override for in-tree CUDA EP. |
| `include/onnxruntime/core/session/onnxruntime_ep_c_api.h` | Adds OrtEp v1.26 graph capture/replay callbacks + node-assignment-policy enum and docs. |
| `include/onnxruntime/core/session/onnxruntime_cxx_inline.h` | Adds `Env::CopyTensor()` convenience wrapper over `CopyTensors()`. |
| `include/onnxruntime/core/session/onnxruntime_cxx_api.h` | Declares `Env::CopyTensor()` in the public C++ API. |
| `include/onnxruntime/core/framework/execution_provider.h` | Adds `GetGraphCaptureNodeAssignmentPolicy()` virtual with default strict policy. |
| `docs/cuda_plugin_ep/QUICK_START.md` | Updates plugin quick start/prereqs; removes outdated “no CUDA Graph” limitation. |
| `docs/cuda_plugin_ep/cuda_plugin_ep_design.md` | Updates design doc to reflect implemented CUDA Graph support and framework/API changes. |
| `docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md` | New detailed design/implementation doc for CUDA Graph in the plugin EP. |


Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 28 out of 28 changed files in this pull request and generated 4 comments.

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 30 out of 30 changed files in this pull request and generated 3 comments.

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 1 comment.

@tianleiwu tianleiwu requested a review from Copilot April 10, 2026 01:41
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 1 comment.

@tianleiwu tianleiwu changed the title from "Cuda graph for cuda plugin ep" to "Add CUDA Graph support for the CUDA plugin EP" on Apr 10, 2026
Contributor

@adrianlizarraga adrianlizarraga left a comment


Review: Add CUDA graph capture/replay support to the CUDA plugin EP.

Key changes:

  • New CudaGraphManager — State machine for capture → instantiate → replay lifecycle, keyed by annotation ID for multi-graph support.
  • Per-thread graph contexts — Each thread gets its own CUDA stream + graph manager via PerThreadContext, tracked with thread_local maps and weak_ptr cleanup.
  • Thread-affine device stream pool — SessionState hands out per-thread stream sets so captured graphs replay on the same stream that captured them.
  • OnRunStart/OnRunEnd callbacks — Begin/end capture, detect allocations during capture (warning), and replay on first captured run.
  • Capture-safe kernel dispatch — Thread-local flag skips cudaSetDevice() and cudaStreamSynchronize() during capture (both prohibited by CUDA).
  • Owned vs. external streams — CudaSyncStream gains owns_stream_ flag; external graph streams skip destroy and library handle creation.
  • Sticky CUDA error cleanup — cudaGetLastError() calls added in mempool allocator to clear errors before they poison subsequent CUDA calls.
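The per-thread context pattern called out above (`thread_local` cache plus `weak_ptr` tracking for cleanup) can be illustrated with a stdlib-only Python analogue using `threading.local` and `weakref`. All names are hypothetical, and the `id()`-based keying is a deliberate simplification of the C++ design.

```python
import threading
import weakref


class GraphContext:
    """Per-thread graph state: its own stream and graph-manager slot."""

    def __init__(self, thread_id):
        self.stream = f"graph-stream-{thread_id}"  # stand-in for a CUDA stream


class ToyPluginEP:
    _tls = threading.local()  # each thread sees its own context cache

    def __init__(self):
        self._lock = threading.Lock()
        self._context_refs = []  # weak refs let dead threads' contexts be pruned

    def get_per_thread_context(self):
        cache = getattr(ToyPluginEP._tls, "cache", None)
        if cache is None:
            cache = ToyPluginEP._tls.cache = {}
        ctx = cache.get(id(self))  # id() keying is a simplification for the sketch
        if ctx is None:
            ctx = GraphContext(threading.get_ident())
            cache[id(self)] = ctx
            with self._lock:
                # prune contexts whose owning thread (and TLS cache) went away
                self._context_refs = [r for r in self._context_refs if r() is not None]
                self._context_refs.append(weakref.ref(ctx))
        return ctx


ep = ToyPluginEP()
main_ctx = ep.get_per_thread_context()
assert ep.get_per_thread_context() is main_ctx  # same thread -> same context
```

Because each thread only ever sees its own context, captured graphs replay on the stream that captured them without any cross-thread locking on the hot path; the lock is taken only when a new thread first materializes a context.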

Minor Observations

  1. Thread safety in PerThreadContext — The thread_local + weak_ptr tracking pattern with mutex protection is correctly implemented. There's a theoretical race if an EP is destroyed while a thread simultaneously calls GetPerThreadContext(), but this is a caller-responsibility concern (EP destruction should happen after all inference calls complete), not a bug.
  2. cudaStreamIsCapturing failure path — In cuda_stream_plugin.cc, if cudaStreamIsCapturing fails, execution falls through to cudaStreamSynchronize. If the stream IS capturing and the query failed, this would error — but something is already wrong at that point, so this is acceptable.
  3. Device stream pool thread token design — The thread_local shared_ptr used purely for identity-based keying (with weak_ptr for pruning dead threads) is a sound pattern.
  4. CUDA graph capture/replay state machine — The lifecycle management (capture → replay → invalidation) looks correct with proper defensive checks.

The implementation follows ORT conventions for error handling, and the overall design is well-structured.

@tianleiwu tianleiwu enabled auto-merge (squash) April 10, 2026 08:40
@tianleiwu tianleiwu merged commit 58a87dc into main Apr 10, 2026
104 of 108 checks passed
@tianleiwu tianleiwu deleted the tlwu/20260407/cuda_plugin_cuda_graph branch April 10, 2026 09:06
sanaa-hamel-microsoft added a commit that referenced this pull request Apr 10, 2026

3 participants