[Plugin EP] Port graph capture/replay APIs#27958
Conversation
…ve the hardcoded list of EPs that support graph capture from inference_session.cc
There was a problem hiding this comment.
Pull request overview
This PR (draft) ports graph capture/replay support to the Plugin EP pathway by extending the OrtEp C API, wiring those callbacks through the plugin EP provider wrapper, and updating session initialization logic to validate/capture based on an EP-provided node-assignment policy.
Changes:
- Bump ORT version/API version to 1.26.0 /
ORT_API_VERSION=26and add newOrtEpgraph capture/replay callbacks plusOrtGraphCaptureNodeAssignmentPolicy. - Update
InferenceSession::Initialize()to select any EP with graph capture enabled, validate graph assignment via EP policy, and cache a single EP for replay. - Add/extend tests for plugin EP graph capture APIs and add an end-to-end autoep WebGPU plugin EP graph capture/replay test.
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| VERSION_NUMBER | Bumps runtime version to 1.26.0. |
| include/onnxruntime/core/session/onnxruntime_c_api.h | Bumps ORT_API_VERSION to 26. |
| onnxruntime/core/session/onnxruntime_c_api.cc | Updates version string static assert to 1.26.0. |
| include/onnxruntime/core/session/onnxruntime_ep_c_api.h | Adds graph capture/replay callbacks to OrtEp and introduces OrtGraphCaptureNodeAssignmentPolicy. |
| include/onnxruntime/core/framework/execution_provider.h | Adds IExecutionProvider::GetGraphCaptureNodeAssignmentPolicy() with a strict default. |
| onnxruntime/core/session/inference_session.cc | Generalizes graph-capture EP selection/validation and uses EP-specified node assignment policy. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h | Exposes graph capture/replay APIs on PluginExecutionProvider. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc | Implements plugin-side forwarding for graph capture/replay and policy query with version gating. |
| onnxruntime/core/providers/webgpu/ep/ep.h | Declares plugin adapter entrypoints for graph capture/replay and assignment policy. |
| onnxruntime/core/providers/webgpu/ep/ep.cc | Wires WebGPU plugin EP adapter function pointers and forwards to EP impl. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Returns ALLOW_CPU_FOR_SHAPES policy for WebGPU EP. |
| onnxruntime/core/providers/js/js_execution_provider.h | Returns ALLOW_CPU_FOR_SHAPES policy for JS EP. |
| onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h | Returns ALLOW_CPU_FOR_SHAPES policy for DML EP wrapper. |
| onnxruntime/core/providers/cuda/cuda_execution_provider.h | Returns ALLOW_CPU_FOR_SHAPES policy for CUDA EP. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.h | Declares plugin CUDA EP adapter entrypoints for graph capture/replay and policy. |
| onnxruntime/core/providers/cuda/plugin/cuda_ep.cc | Wires plugin CUDA EP adapter function pointers (currently stubbed). |
| onnxruntime/test/framework/ep_plugin_provider_test.cc | Adds unit tests for plugin EP graph capture/replay function-pointer behavior and version gating. |
| onnxruntime/test/autoep/test_graph_capture.cc | Adds end-to-end test exercising WebGPU plugin EP graph capture + replay via public APIs. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…ionProvider::IsGraphCaptureEnabled() as ORT previously never managed graph capture for this EP; Update use of webgpu name with constant
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 21 out of 21 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 19 out of 19 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 20 out of 20 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 21 out of 21 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 20 out of 20 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
tianleiwu
left a comment
There was a problem hiding this comment.
Review Summary
Well-structured PR that generalizes the hardcoded EP-specific graph capture logic in InferenceSession into a policy-based, extensible design driven by IExecutionProvider virtual methods. The C API additions are thoroughly documented, version-gated, and defensively validated. The recursion-depth guard in Run() is a valuable safety net.
Highlights
AreAllComputeNodesAssignedToEpOrCpu()now requireshas_node_on_provider, which is a correctness improvement — avoids false positives where no nodes are on the target EP.PluginExecutionProvider::IsGraphCaptureEnabled()validation that checksIsGraphCaptured/ReplayGraphare non-null is a valuable defensive measure — without it, ORT would silently hang in the recursive warm-up loop.- NvTensorRTRTX change to return
falsefromIsGraphCaptureEnabled()is correct: it already returnedfalsefrom the publicIsGraphCaptured()and manages capture internally. - Excellent API documentation on the four new
OrtEpmembers — exactly the level of detail an out-of-tree EP author needs.
Minor Observations (not blocking)
CopyTensorAPI consistency: The new single-tensorCopyTensor(const OrtValue*, OrtValue*, OrtSyncStream*)takes raw pointers while the existingCopyTensorstakesconst std::vector<Value>&wrappers. Works correctly but inconsistent style.- CUDA plugin stubs: Since
IsGraphCaptureEnabledImplreturnsfalse, the other three stub registrations are dead code. Fine as scaffolding for upcoming work.
Overall: clean, well-tested, ready to merge. Two inline suggestions below.
This reverts commit 7afe4c2.
NOTE: This PR cannot be merged until the ORT version is updated to 1.26.0 in the main branch
Description
Ports graph capture/replay APIs (e.g., CUDA Graph) to the Plugin EP (
OrtEp) C API so that plugin-based execution providers can participate in ORT-managed graph capture and replay.What changed
New Plugin EP C API functions (
onnxruntime_ep_c_api.h):OrtEp::IsGraphCaptureEnabled— indicates whether the EP has graph capture enabled.OrtEp::IsGraphCaptured— indicates whether a graph has been captured for a given annotation ID.OrtEp::ReplayGraph— replays a previously captured graph.OrtEp::GetGraphCaptureNodeAssignmentPolicy— returns the node assignment validation policy for graph capture.All four are optional (NULL defaults to safe behavior) and version-gated (
ort_version_supported >= 26).If
IsGraphCaptureEnabledreturns true,IsGraphCapturedandReplayGraphmust also be implemented;otherwise
PluginExecutionProviderlogs a warning and disables graph capture for that EP.New
OrtGraphCaptureNodeAssignmentPolicyenum (onnxruntime_ep_c_api.h):Replaces the hardcoded EP-name checks in
InferenceSession::Initialize()with a policy-based approach:ALL_NODES_ON_EP— all nodes must be on the target EP (e.g., TensorRT).ALLOW_CPU_FOR_SHAPES— CPU nodes allowed for shape computation if no memcpy nodes exist (e.g., CUDA, WebGPU, DML).Refactored
InferenceSessiongraph capture selection (inference_session.cc):graph_support_ep_listand per-EPstrcmpchecks.IsGraphCaptureEnabled()+GetGraphCaptureNodeAssignmentPolicy()to select and validate the graph-capturing EP.AreAllComputeNodesAssignedToCudaOrJsOrDmlEpWebGpuEp()→ generalized toAreAllComputeNodesAssignedToEpOrCpu(), which also requires at least one node on the target EP.IExecutionProvider::GetGraphCaptureNodeAssignmentPolicy()added to the base class (defaults toALL_NODES_ON_EP).Bounded graph capture recursion (
inference_session.cc/h):Run()now delegates toRunImpl()with agraph_capture_depthparameter.kMaxGraphCaptureRunAttempts = 8, returning a clear error if the EP never reportsIsGraphCaptured() == true.EP implementations:
IExecutionProvider.IsGraphCaptureEnabled()now returnsfalsesince this EP manages graph capture internally (not via ORT).C++ wrapper (
onnxruntime_cxx_api.h/onnxruntime_cxx_inline.h):Ort::Env::CopyTensor()convenience overload for copying a single tensor (wrapsCopyTensorswithnum_tensors=1).Tests
ep_plugin_provider_test.cc: Unit tests for each newPluginExecutionProvidergraph capture method, including NULL function pointer defaults, version < 26 backward compatibility, and validation thatIsGraphCaptureEnabled()returns false whenIsGraphCapturedorReplayGraphare NULL.test_graph_capture.cc: End-to-end test for WebGPU plugin EP graph capture/replay using IO binding (warm-up + capture run, then replay with different inputs).Motivation and Context
Previously, graph capture support was limited to a hardcoded list of EPs (
kCudaExecutionProvider,kTensorrtExecutionProvider,kJsExecutionProvider,kWebGpuExecutionProvider,kDmlExecutionProvider) with EP-specific validation logic inInferenceSession. This made it impossible for plugin EPs to participate in ORT-managed graph capture/replay without modifying the core session code.This PR makes graph capture/replay extensible to any EP, including out-of-tree plugin EPs, by exposing it through the
OrtEpC API.