From 1851b73f71032018b00d59251a01abef9db85762 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 18 Nov 2025 20:27:45 -0800 Subject: [PATCH 01/17] [build] fix packaging pipeline for arm64/linux (#26592) ### Description ### Motivation and Context --- .../azure-pipelines/templates/c-api-linux-cpu.yml | 12 ++++++------ ...api-artifacts-package-and-publish-steps-posix.yml | 4 ++-- .../templates/mac-cpu-packaging-pipeline.yml | 3 ++- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index f1599b6843fb5..3f34e7ae37538 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -46,12 +46,12 @@ jobs: clean: true submodules: none - - task: UsePythonVersion@0 - displayName: Use Python 3.12 - inputs: - versionSpec: 3.12 - ${{ if eq(parameters.OnnxruntimeArch, 'aarch64') }}: - architecture: arm64 + - ${{ if eq(parameters.OnnxruntimeArch, 'x64') }}: + # Only need to install Python on x64 agents as Python is pre-installed on arm64 agents + - task: UsePythonVersion@0 + displayName: Use Python 3.12 + inputs: + versionSpec: 3.12 - template: set-version-number-variables-step.yml - ${{ if eq(parameters.OnnxruntimeArch, 'x64') }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml b/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml index 166b03f6b55e1..749b6093cf9d3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml +++ b/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml @@ -28,12 +28,13 @@ parameters: displayName: Architecture type: string #default: 'linux-x64' - + steps: - task: PythonScript@0 inputs: scriptSource: 'filePath' scriptPath: 'tools/ci_build/linux_java_copy_strip_binary.py' + pythonInterpreter: 'python3' arguments: >- --binary-dir $(Build.BinariesDirectory) --build-config ${{parameters.buildConfig}} @@ -47,4 +48,3 @@ steps: inputs: targetPath: '$(Build.BinariesDirectory)/${{parameters.artifactName}}' artifactName: 'drop-${{parameters.artifactName}}' - diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml index 907563cb77242..0bc0a94fdd6e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml @@ -53,6 +53,7 @@ stages: - task: UsePythonVersion@0 inputs: versionSpec: '3.13' + architecture: arm64 addToPath: true - script: | @@ -77,7 +78,7 @@ stages: cd temp find $(Build.ArtifactStagingDirectory) -name '*.zip' -exec unzip {} \; rm -rf $(Build.ArtifactStagingDirectory)/*; - find . -type d -name 'onnxruntime-osx-*' -exec tar -czf {}.tgz {} \; + find . -type d -name 'onnxruntime-osx-*' -exec tar -czf {}.tgz {} \; ls -l mv *.tgz $(Build.ArtifactStagingDirectory) displayName: 'Unzip Signed Files and Repackage to TGZ' From e6023b0c6067470b60789478680c501128cf2a8b Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 19 Nov 2025 13:56:37 -0800 Subject: [PATCH 02/17] Create a CUDA based memory arena instead of Cuda Allocator wrapped into BFCArena (#26535) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This change allows users to better control GPU memory in shared environments with multiple tenants or multiple inference sessions per process. Cuda based memory pool features native allocations on streams, allows trimming the memory on Shrink if enabled and releases memory back to the system based on the user specified parameters. In my limited testing latencies were comparable with running on BFCArena, although your milage and requirements may vary. CudaMemoryPoolArena is enabled via OrtArenaCfg with introducing 3 new parameters: - `use_cuda_mempool` set to 1 to enable - `cuda_mempool_release_threshold` amount of memory to keep cached - `cuda_mempool_bytes_to_keep_on_shrink` the amount of memory to keep on Shrink when being trimmed, allocated memory is not affected. ### Motivation and Context Better GPU memory control in multitenant environments. There are some new options for `onnxruntime_perf_test` introduced in this PR so they may assist clients to figure out the best settings for they case: - `--enable_cuda_mempool 209715200;1048576` with first parameter being `cuda_mempool_release_threshold`. The second `cuda_mempool_bytes_to_keep_on_shrink` can be zero if shrink is not enabled. - `--shrink_arena_between_runs gpu:0` measure perf and memory consumption with shrink. This new allocator strictly speaking does not need `Shrink()` since cuda mempool may release memory on the go according to `cuda_mempool_release_threshold`. Here is some performance numbers gathered when running HF_Bart model. If the CudaMempool release threshold is set too low, latency increases because the system ends up constantly allocating and releasing memory. But as we raise the threshold and allow more memory to stay allocated, latency improves—and we end up using only about half as much memory between runs compared to BFCArena. Running default setup with BFCArena > onnxruntime_perf_test -s -e cuda -I -S 10 -m times -r 100 "hf_Bart_torchscript.onnx" Average inference time cost total: 66.493545 ms P99 Latency: 0.0805385 s Total memory allocated: 1,409,286,144 200 MB release threshold > onnxruntime_perf_test -s -e cuda --enable_cuda_mempool 209715200;0 -I -S 10 -m times -r 100 hf_Bart_torchscript.onnx Average inference time cost total: 77.367473 ms P99 Latency: 0.0931895 s 0.5Gb release threshold > onnxruntime_perf_test -s -e cuda --enable_cuda_mempool 536870912;0 -I -S 10 -m times -r 100 hf_Bart_torchscript.onnx Average inference time cost total: 75.112840 ms P99 Latency: 0.0910992 s 1Gb release threshold > onnxruntime_perf_test -s -e cuda --enable_cuda_mempool 1073741824;0 -I -S 10 -m times -r 100 hf_Bart_torchscript.onnx Average inference time cost total: 66.533892 ms P99 Latency: 0.0761336 s Enabling shrink show we’re retaining only half the memory compared to BFCArena in between inference runs. >CudaMempoolArena::Shrink: pool current_in_use: 709,603,688 reserved size after trim : 738,197,504 bytes. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../onnxruntime/core/framework/allocator.h | 19 ++ onnxruntime/core/framework/allocator.cc | 22 ++ onnxruntime/core/framework/allocator_utils.cc | 14 +- onnxruntime/core/framework/bfc_arena.cc | 28 +- onnxruntime/core/framework/bfc_arena.h | 48 ++- .../framework/device_stream_collection.cc | 16 +- .../providers/cuda/cuda_execution_provider.cc | 86 +++-- .../providers/cuda/cuda_execution_provider.h | 13 +- .../core/providers/cuda/cuda_mempool_arena.cc | 272 ++++++++++++++++ .../core/providers/cuda/cuda_mempool_arena.h | 177 ++++++++++ .../providers/cuda/cuda_provider_factory.cc | 8 +- onnxruntime/core/session/inference_session.cc | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 6 + onnxruntime/test/framework/bfc_arena_test.cc | 4 +- .../test/framework/inference_session_test.cc | 23 +- .../test/framework/session_state_test.cc | 2 +- .../test/perftest/command_args_parser.cc | 37 ++- onnxruntime/test/perftest/ort_test_session.cc | 38 ++- onnxruntime/test/perftest/ort_test_session.h | 6 +- .../test/perftest/test_configuration.h | 7 + .../providers/cuda/cuda_mempool_arena_test.cc | 306 ++++++++++++++++++ 21 files changed, 1024 insertions(+), 110 deletions(-) create mode 100644 onnxruntime/core/providers/cuda/cuda_mempool_arena.cc create mode 100644 onnxruntime/core/providers/cuda/cuda_mempool_arena.h create mode 100644 onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 24cc460a17fa9..983be1f9efd5c 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -38,6 +38,13 @@ struct OrtArenaCfg { int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default int initial_growth_chunk_size_bytes; // use -1 to allow ORT to choose the default int64_t max_power_of_two_extend_bytes; // use -1 to allow ORT to choose the default + // Use CudaMemPool based arena if available (starting with cuda 11.2) + int use_cuda_mempool = -1; + // Amount of reserved memory in bytes to hold onto before trying + // to release memory back to the OS. + uint64_t cuda_mempool_release_threshold = 0; + // Bytes to keep on shrink for CudaMemPool, 0 is to attempt to release all, allocated space not affected. + size_t cuda_mempool_bytes_to_keep_on_shrink = 0; bool IsValid() { return arena_extend_strategy >= -1 && arena_extend_strategy <= 1 && @@ -55,6 +62,9 @@ struct OrtArenaCfg { static constexpr const char* InitialGrowthChunkSizeBytes = "arena.initial_growth_chunk_size_bytes"; static constexpr const char* MaxPowerOfTwoExtendBytes = "arena.max_power_of_two_extend_bytes"; static constexpr const char* MaxMem = "arena.max_mem"; + static constexpr const char* UseCudaMemPool = "arena.use_cuda_mempool"; + static constexpr const char* CudaMempoolReleaseThreshold = "arena.cuda_mempool_release_threshold"; + static constexpr const char* CudaMempoolBytesToKeepOnShrink = "arena.cuda_mempool_bytes_to_keep_on_shrink"; }; static onnxruntime::common::Status FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg); @@ -348,4 +358,13 @@ void AllocatorDefaultFree(void* p); void* AllocatorDefaultAllocAligned(size_t size, size_t alignment); void AllocatorDefaultFreeAligned(void* p, size_t alignment); +class IArena : public IAllocator { + public: + using IAllocator::IAllocator; + virtual Status Shrink() = 0; + // Only implemented when IsStreamAware() returns true + virtual void ReleaseStreamBuffers(Stream* /*stream*/) {} + static IArena* SafeArenaCast(IAllocator* allocator); +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 91b5b811a3529..a656abb098911 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -58,6 +58,18 @@ Status OrtArenaCfg::FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.max_mem)); } + if (auto it = kvps_entries.find(ConfigKeyNames::UseCudaMemPool); it != kvps_entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.use_cuda_mempool)); + } + + if (auto it = kvps_entries.find(ConfigKeyNames::CudaMempoolReleaseThreshold); it != kvps_entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.cuda_mempool_release_threshold)); + } + + if (auto it = kvps_entries.find(ConfigKeyNames::CudaMempoolBytesToKeepOnShrink); it != kvps_entries.end()) { + ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.cuda_mempool_bytes_to_keep_on_shrink)); + } + if (!cfg.IsValid()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid arena configuration. Please check the values provided."); @@ -177,6 +189,16 @@ void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve return alloc.Alloc(size); } + +IArena* IArena::SafeArenaCast(IAllocator* allocator) { +#if !defined(ORT_NO_RTTI) + auto* result = dynamic_cast(allocator); + return result; +#else + return static_cast(allocator); +#endif +} + } // namespace onnxruntime std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) { return (out << info.ToString()); } diff --git a/onnxruntime/core/framework/allocator_utils.cc b/onnxruntime/core/framework/allocator_utils.cc index 8c4e74c4b1cc7..ee9cf5bb39ca0 100644 --- a/onnxruntime/core/framework/allocator_utils.cc +++ b/onnxruntime/core/framework/allocator_utils.cc @@ -52,14 +52,14 @@ AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) { if (info.use_stream_aware_arena) { #ifdef ORT_ENABLE_STREAM return AllocatorPtr( - std::make_unique(std::move(device_allocator), - max_mem, - arena_extend_str, - initial_chunk_size_bytes, - max_dead_bytes_per_chunk, - initial_growth_chunk_size_bytes)); + std::make_unique(std::move(device_allocator), + max_mem, + arena_extend_str, + initial_chunk_size_bytes, + max_dead_bytes_per_chunk, + initial_growth_chunk_size_bytes)); #else - ORT_THROW("StreamAwareArena should be transparent to minimal build."); + ORT_THROW("StreamAwareBFCArena should be transparent to minimal build."); #endif } else { return AllocatorPtr( diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index 3a5af42d03cdd..cfe155986eff2 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -13,11 +13,10 @@ BFCArena::BFCArena(std::unique_ptr resource_allocator, int max_dead_bytes_per_chunk, int initial_growth_chunk_size_bytes, int64_t max_power_of_two_extend_bytes) - : IAllocator(OrtMemoryInfo(resource_allocator->Info().name.c_str(), - OrtAllocatorType::OrtArenaAllocator, - resource_allocator->Info().device, - resource_allocator->Info().mem_type)), - arena_type_(ArenaType::BaseArena), + : IArena(OrtMemoryInfo(resource_allocator->Info().name.c_str(), + OrtAllocatorType::OrtArenaAllocator, + resource_allocator->Info().device, + resource_allocator->Info().mem_type)), device_allocator_(std::move(resource_allocator)), free_chunks_list_(kInvalidChunkHandle), next_allocation_id_(1), @@ -827,13 +826,13 @@ void BFCArena::ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_fla } } -StreamAwareArena::StreamAwareArena(std::unique_ptr resource_allocator, - size_t total_memory, - ArenaExtendStrategy arena_extend_strategy, - int initial_chunk_size_bytes, - int max_dead_bytes_per_chunk, - int initial_growth_chunk_size_bytes, - int64_t max_power_of_two_extend_bytes) +StreamAwareBFCArena::StreamAwareBFCArena(std::unique_ptr resource_allocator, + size_t total_memory, + ArenaExtendStrategy arena_extend_strategy, + int initial_chunk_size_bytes, + int max_dead_bytes_per_chunk, + int initial_growth_chunk_size_bytes, + int64_t max_power_of_two_extend_bytes) : BFCArena(std::move(resource_allocator), total_memory, arena_extend_strategy, @@ -841,14 +840,13 @@ StreamAwareArena::StreamAwareArena(std::unique_ptr resource_allocato max_dead_bytes_per_chunk, initial_growth_chunk_size_bytes, max_power_of_two_extend_bytes) { - arena_type_ = ArenaType::StreamAwareArena; } -void* StreamAwareArena::AllocOnStream(size_t size, Stream* current_stream) { +void* StreamAwareBFCArena::AllocOnStream(size_t size, Stream* current_stream) { return AllocateRawInternal(size, false, current_stream); } -void StreamAwareArena::ReleaseStreamBuffers(Stream* stream) { +void StreamAwareBFCArena::ReleaseStreamBuffers(Stream* stream) { // since chunks on target stream will be reset to nullptr, trigger coalesce to see whether we can get bigger chunk. ResetChunkOnTargetStream(stream, true); } diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index f3c0544124241..e3494853f7064 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -43,7 +43,7 @@ namespace onnxruntime { #endif #endif -class StreamAwareArena; +class StreamAwareBFCArena; // A memory allocator that implements a 'best-fit with coalescing' // algorithm. This is essentially a very simple version of Doug Lea's // malloc (dlmalloc). @@ -52,7 +52,7 @@ class StreamAwareArena; // coalescing. One assumption we make is that the process using this // allocator owns pretty much all of the memory, and that nearly // all requests to allocate memory go through this interface. -class BFCArena : public IAllocator { +class BFCArena : public IArena { public: static const ArenaExtendStrategy DEFAULT_ARENA_EXTEND_STRATEGY = ArenaExtendStrategy::kNextPowerOfTwo; static const int DEFAULT_INITIAL_CHUNK_SIZE_BYTES = 1 * 1024 * 1024; @@ -61,11 +61,6 @@ class BFCArena : public IAllocator { static const int64_t DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES = 1024 * 1024 * 1024; // 1GB static const size_t DEFAULT_MAX_MEM = std::numeric_limits::max(); - enum ArenaType { - BaseArena, - StreamAwareArena, - }; - BFCArena(std::unique_ptr resource_allocator, size_t total_memory, ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, @@ -84,14 +79,6 @@ class BFCArena : public IAllocator { // If p is NULL, no operation is performed. void Free(void* p) override; - // Frees all allocation regions in which no chunk is in use. - // Does not free any reserved chunks. - // Resets the size that the arena will grow by in the next allocation to - // `initial_growth_chunk_size_bytes_` but ultimately all - // future allocation sizes are determined by the arena growth strategy - // and the allocation request. - Status Shrink(); - void* Reserve(size_t size) override; void GetStats(AllocatorStats* stats) override; @@ -100,7 +87,13 @@ class BFCArena : public IAllocator { size_t AllocatedSize(const void* ptr); - ArenaType GetArenaType() const { return arena_type_; } + // Frees all allocation regions in which no chunk is in use. + // Does not free any reserved chunks. + // Resets the size that the arena will grow by in the next allocation to + // `initial_growth_chunk_size_bytes_` but ultimately all + // future allocation sizes are determined by the arena growth strategy + // and the allocation request. + Status Shrink() override; protected: void* AllocateRawInternal(size_t num_bytes, @@ -112,7 +105,6 @@ class BFCArena : public IAllocator { // perform coalesce if coalesce_flag is true void ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_flag); #endif - ArenaType arena_type_; private: void DeallocateRawInternal(void* ptr); @@ -510,26 +502,22 @@ class BFCArena : public IAllocator { }; #ifdef ORT_ENABLE_STREAM -class StreamAwareArena : public BFCArena { +class StreamAwareBFCArena : public BFCArena { public: - StreamAwareArena(std::unique_ptr resource_allocator, - size_t total_memory, - ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, - int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES, - int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK, - int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES, - int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES); + StreamAwareBFCArena(std::unique_ptr resource_allocator, + size_t total_memory, + ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, + int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES, + int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK, + int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES, + int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES); bool IsStreamAware() const override { return true; } // Standard alloc behavior. Returns valid pointer if size > 0 and memory was available. Otherwise returns nullptr. void* AllocOnStream(size_t size, Stream* current_stream_id) override; - void ReleaseStreamBuffers(Stream* stream); - - static StreamAwareArena* FromBFCArena(BFCArena& arena) { - return arena.GetArenaType() == ArenaType::StreamAwareArena ? reinterpret_cast(&arena) : nullptr; - } + void ReleaseStreamBuffers(Stream* stream) override; }; #endif #ifdef __GNUC__ diff --git a/onnxruntime/core/framework/device_stream_collection.cc b/onnxruntime/core/framework/device_stream_collection.cc index 8d15e03c2e5ce..a32973ddb8c9e 100644 --- a/onnxruntime/core/framework/device_stream_collection.cc +++ b/onnxruntime/core/framework/device_stream_collection.cc @@ -21,7 +21,8 @@ struct DummyStream : Stream { class DeviceStreamCollectionImpl { public: - DeviceStreamCollectionImpl(size_t num_streams, const AllocatorMap& allocators, bool is_main_graph) : num_streams_(num_streams), allocators_(allocators), is_main_graph_(is_main_graph) { + DeviceStreamCollectionImpl(size_t num_streams, const AllocatorMap& allocators, bool is_main_graph) + : num_streams_(num_streams), allocators_(allocators), is_main_graph_(is_main_graph) { device_streams_.resize(num_streams, nullptr); owned_streams_.reserve(num_streams); root_stream_ = std::make_unique(nullptr, root_stream_device_); @@ -32,13 +33,16 @@ class DeviceStreamCollectionImpl { void ReleaseSingleStreamBuffers(Stream* stream) { if (!stream) return; - for (auto it : allocators_) { + for (const auto& it : allocators_) { if (it.second->Info().device == stream->GetDevice() && it.second->Info().alloc_type == OrtArenaAllocator) { - auto* arena_alloc = static_cast(it.second.get()); - auto* stream_aware_alloc = StreamAwareArena::FromBFCArena(*arena_alloc); - if (stream_aware_alloc) { - stream_aware_alloc->ReleaseStreamBuffers(stream); + if (it.second->IsStreamAware()) { + // Previously we only had one StreamAwareBFCArena. We need to guard + // against multiple allocators now. + auto* arena_alloc = IArena::SafeArenaCast(it.second.get()); + if (arena_alloc) { + arena_alloc->ReleaseStreamBuffers(stream); + } } } } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3816cc1f8f6b9..eff0801a00460 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -14,6 +14,7 @@ #include "core/providers/cuda/cuda_fwd.h" #include "core/providers/cuda/gpu_data_transfer.h" #include "core/providers/cuda/cuda_profiler.h" +#include "core/providers/cuda/cuda_mempool_arena.h" #include "core/session/onnxruntime_run_options_config_keys.h" #ifndef USE_CUDA_MINIMAL @@ -134,11 +135,10 @@ ONNX_OPERATOR_KERNEL_EX( } // namespace cuda -AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId device_id, - size_t gpu_mem_limit, - ArenaExtendStrategy arena_extend_strategy, - CUDAExecutionProviderExternalAllocatorInfo external_allocator_info, - const OrtArenaCfg* default_memory_arena_cfg) { +AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(const CUDAAllocatorParams& cuda_allocator_params) { + ORT_ENFORCE(cuda_allocator_params.external_alloc_info != nullptr, + "CUDAAllocatorParams.external_alloc_info is nullptr."); + const auto& external_allocator_info = *(cuda_allocator_params.external_alloc_info); if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( [external_allocator_info](OrtDevice::DeviceId id) { @@ -147,24 +147,59 @@ AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId devi external_allocator_info.free, external_allocator_info.empty_cache); }, - device_id, + cuda_allocator_params.device_id, false); return CreateAllocator(default_memory_info); } else { - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId id) { - return std::make_unique(id, CUDA); - }, - device_id, - true, - {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, - // make it stream aware - true); - - // CUDA malloc/free is expensive so always use an arena - return CreateAllocator(default_memory_info); + const auto* arena_cfg = cuda_allocator_params.arena_cfg; + const bool cuda_mempool_requested = arena_cfg != nullptr && arena_cfg->use_cuda_mempool == 1; + bool use_cuda_mempool = cuda_mempool_requested && cuda::CudaMempoolArena::IsCudaVersionSupported(); + + if (cuda_mempool_requested && !use_cuda_mempool) { + LOGS_DEFAULT(WARNING) + << "CUDA memory pool requested but not supported on this device/driver." + << "Falling back to default BFCArena with CUDA allocator."; + } + + if (use_cuda_mempool) { + const bool cuda_graph_enabled = cuda_allocator_params.provider_info != nullptr && + cuda_allocator_params.provider_info->enable_cuda_graph; + + if (cuda_graph_enabled) { + LOGS_DEFAULT(WARNING) + << "CUDA Mempool Arena allocator is not compatible with requested CUDA Graph Capture" + << "Falling back to default BFCArena with CUDA allocator."; + use_cuda_mempool = false; + } + } + + if (use_cuda_mempool) { + auto device = OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + cuda_allocator_params.device_id); + auto mem_info = OrtMemoryInfo("CUDAMemPoolArena", OrtAllocatorType::OrtArenaAllocator, device, OrtMemTypeDefault); + + auto mempool_allocator = std::make_shared(mem_info, + arena_cfg->cuda_mempool_release_threshold, + arena_cfg->cuda_mempool_bytes_to_keep_on_shrink, + cuda_allocator_params.logger); + + return mempool_allocator; + } else { + AllocatorCreationInfo default_memory_info( + [](OrtDevice::DeviceId id) { + return std::make_unique(id, CUDA); + }, + cuda_allocator_params.device_id, + true, + {arena_cfg ? *arena_cfg + : OrtArenaCfg(cuda_allocator_params.cuda_mem_threshold, + static_cast(cuda_allocator_params.arena_extend_strategy), -1, -1, -1, -1L)}, + // make it stream aware + true); + // CUDA malloc/free is expensive so always use an arena + return CreateAllocator(default_memory_info); + } } } @@ -3044,9 +3079,18 @@ std::vector CUDAExecutionProvider::CreatePreferredAllocators() { return std::make_unique(device_id, CUDA_PINNED); }, info_.device_id); + + CUDAExecutionProvider::CUDAAllocatorParams params{}; + params.device_id = info_.device_id; + params.cuda_mem_threshold = info_.gpu_mem_limit; + params.arena_extend_strategy = info_.arena_extend_strategy; + params.provider_info = &info_; + params.external_alloc_info = &info_.external_allocator_info; + params.arena_cfg = info_.default_memory_arena_cfg; + params.logger = GetLogger(); + return std::vector{ - CreateCudaAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy, - info_.external_allocator_info, info_.default_memory_arena_cfg), + CreateCudaAllocator(params), CreateAllocator(pinned_memory_info), }; } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 57fde8146d929..751bbb90f8619 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -103,8 +103,17 @@ class CUDAExecutionProvider : public IExecutionProvider { return CUDAExecutionProviderInfo::ToProviderOptions(info_); } - static AllocatorPtr CreateCudaAllocator(OrtDevice::DeviceId device_id, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, - CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); + struct CUDAAllocatorParams { + OrtDevice::DeviceId device_id = 0; + size_t cuda_mem_threshold = std::numeric_limits::max(); + ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo; + const CUDAExecutionProviderInfo* provider_info = nullptr; + const CUDAExecutionProviderExternalAllocatorInfo* external_alloc_info = nullptr; + const OrtArenaCfg* arena_cfg = nullptr; + const logging::Logger* logger = nullptr; + }; + + static AllocatorPtr CreateCudaAllocator(const CUDAAllocatorParams& cuda_allocator_params); ITuningContext* GetTuningContext() const override; diff --git a/onnxruntime/core/providers/cuda/cuda_mempool_arena.cc b/onnxruntime/core/providers/cuda/cuda_mempool_arena.cc new file mode 100644 index 0000000000000..802867ec0d89b --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_mempool_arena.cc @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft. +// Licensed under the MIT License. + +#include "cuda_mempool_arena.h" + +#include + +#include "core/providers/cuda/shared_inc/cuda_call.h" // ORT CudaCall helpers +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +namespace cuda { + +// ======== CudaMempoolArena ======== + +CudaMempoolArena::CudaMempoolArena(const OrtMemoryInfo& memory_info, + uint64_t pool_release_threshold, + size_t bytes_to_keep_on_shrink, + const logging::Logger* logger) + : IArena(memory_info), + pool_release_threshold_(pool_release_threshold), + bytes_to_keep_on_shrink_(bytes_to_keep_on_shrink), + logger_(logger) { + if (logger_ == nullptr) { + logger_ = &::onnxruntime::logging::LoggingManager::DefaultLogger(); + } + + // Create a process-local device memory pool for device_id_. + // 'cudaMemAllocationTypeDevice' (for cudaMemPoolProps.allocType) not clear when it is available + + cudaMemPoolProps props{}; + // Pinned is not the same as pinned allocator, cudaMemLocationTypeDevice actually does not exist + // even though is present in some internet docs. + props.allocType = cudaMemAllocationTypePinned; + props.handleTypes = cudaMemHandleTypeNone; // local to process + props.location.type = cudaMemLocationTypeDevice; // Device memory + props.location.id = this->Info().device.Id(); + + CUDA_CALL_THROW(cudaMemPoolCreate(&pool_, &props)); + + if (pool_release_threshold_ != 0) { + CUDA_CALL_THROW(cudaMemPoolSetAttribute(pool_, cudaMemPoolAttrReleaseThreshold, + &pool_release_threshold_)); + } + + LOGS(*logger_, INFO) << "CudaMempoolArena created on device " << this->Info().device.Id() + << " with pool_release_threshold=" << pool_release_threshold_ + << " bytes_to_keep_on_shrink=" << bytes_to_keep_on_shrink_ << "."; + + // Intentionally DO NOT call cudaDeviceSetMemPool(device_id_, pool_); + // All allocations explicitly target this pool via cudaMallocFromPoolAsync. +} + +CudaMempoolArena::~CudaMempoolArena() { + // 1) Best-effort: enqueue frees for any remaining allocations on their recorded streams. + // No locking by design: destruction implies no concurrent access. + for (auto& kv : alloc_map_) { + void* p = kv.first; + const cudaStream_t s = kv.second.stream; + ORT_IGNORE_RETURN_VALUE(cudaFreeAsync(p, s)); // ignore errors in destructor + } + + // 2) Synchronize all streams we know about (those that ever held allocations). + SyncAllKnownStreams_NoThrow(); + + // Now it is safe to drop our bookkeeping. + alloc_map_.clear(); + stream_map_.clear(); + + // 3) Safety barrier: ensure any frees enqueued on destroyed/unknown streams are completed. + ORT_IGNORE_RETURN_VALUE(cudaDeviceSynchronize()); // ignore errors in destructor + + // 4) Trim to zero and destroy the pool. + if (pool_) { + ORT_IGNORE_RETURN_VALUE(cudaMemPoolTrimTo(pool_, 0)); // best-effort + ORT_IGNORE_RETURN_VALUE(cudaMemPoolDestroy(pool_)); + pool_ = nullptr; + } +} + +void* CudaMempoolArena::Alloc(size_t size) { + if (size == 0) return nullptr; + + void* p = nullptr; + constexpr const cudaStream_t kDefaultStream = static_cast(0); + cudaError_t err = cudaMallocFromPoolAsync(&p, size, pool_, kDefaultStream); + if (err != cudaSuccess) { + ORT_THROW("CudaMempoolArena::Alloc: cudaMallocFromPoolAsync failed: ", + cudaGetErrorString(err), " (", static_cast(err), "), size=", size); + } + + LOGS(*logger_, VERBOSE) << "CudaMempoolArena::Alloc: allocated " + << size << " bytes at " << p << " on default stream."; + + // In case the default stream is busy. + ORT_IGNORE_RETURN_VALUE(cudaStreamSynchronize(kDefaultStream)); + + { + std::lock_guard lock(mutex_); + AllocationRecord rec{size, kDefaultStream}; + alloc_map_.emplace(p, rec); + stream_map_[kDefaultStream].insert(p); + + total_allocated_ += size; + in_use_bytes_ += size; + max_bytes_in_use_ = std::max(max_bytes_in_use_, in_use_bytes_); + max_alloc_size_ = std::max(max_alloc_size_, size); + ++num_allocs_; + } + + return p; +} + +void* CudaMempoolArena::AllocOnStream(size_t size, Stream* stream) { + if (size == 0) return nullptr; + + void* p = nullptr; + const cudaStream_t s = ResolveCudaStream(stream); + + cudaError_t err = cudaMallocFromPoolAsync(&p, size, pool_, s); + if (err != cudaSuccess) { + ORT_THROW("CudaMempoolArena::AllocOnStream: cudaMallocFromPoolAsync failed on stream=", + reinterpret_cast(s), ": ", + cudaGetErrorString(err), " (", static_cast(err), "), size=", size); + } + + LOGS(*logger_, VERBOSE) << "CudaMempoolArena::AllocOnStream: allocated " + << size << " bytes at " << p << " on stream " + << reinterpret_cast(s) << "."; + + { + std::lock_guard lock(mutex_); + AllocationRecord rec{size, s}; + alloc_map_.emplace(p, rec); + stream_map_[s].insert(p); + + total_allocated_ += size; + in_use_bytes_ += size; + max_bytes_in_use_ = std::max(max_bytes_in_use_, in_use_bytes_); + max_alloc_size_ = std::max(max_alloc_size_, size); + ++num_allocs_; + } + + return p; +} + +void CudaMempoolArena::Free(void* p) { + if (!p) return; + + cudaStream_t s = static_cast(0); + size_t sz = 0; + + { + std::lock_guard lock(mutex_); + auto it = alloc_map_.find(p); + if (it == alloc_map_.end()) { + // Not owned by this allocator; ignore per ORT convention. + LOGS(*logger_, WARNING) << "CudaMempoolArena::Free: pointer " + << p << " not found in allocation map; ignoring."; + return; + } + + s = it->second.stream; + sz = it->second.bytes; + + alloc_map_.erase(it); + + auto sit = stream_map_.find(s); + if (sit != stream_map_.end()) { + sit->second.erase(p); + if (sit->second.empty()) { + stream_map_.erase(sit); + } + } + + in_use_bytes_ = (sz <= in_use_bytes_) ? (in_use_bytes_ - sz) : 0; + } + + // Ordered free on the stream that allocated p + CUDA_CALL_THROW(cudaFreeAsync(p, s)); +} + +Status CudaMempoolArena::Shrink() { + // Trim the pool; live allocations are not affected. + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemPoolTrimTo(pool_, bytes_to_keep_on_shrink_))); + + size_t current_in_use = 0; + ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaMemPoolGetAttribute(pool_, cudaMemPoolAttrUsedMemCurrent, + ¤t_in_use))); + + // Query current reserved size. cudaMemPoolAttrReservedMemCurrent + size_t reserved_size = 0; + if (CUDA_CALL(cudaMemPoolGetAttribute(pool_, cudaMemPoolAttrReservedMemCurrent, + &reserved_size)) + .IsOK()) { + LOGS(*logger_, INFO) << "CudaMempoolArena::Shrink: pool current_in_use: " << current_in_use + << " reserved size after trim : " << reserved_size << " bytes."; + } else { + LOGS(*logger_, INFO) << "CudaMempoolArena pool has been shrunk; unable to query reserved size."; + } + + // Right-size maps under lock. + std::lock_guard lock(mutex_); + MaybeRehashLocked(); + ++num_arena_shrinkages_; + return Status::OK(); +} + +void CudaMempoolArena::GetStats(AllocatorStats* stats) { + if (!stats) return; + std::lock_guard lock(mutex_); + stats->num_allocs = num_allocs_; + stats->total_allocated_bytes = total_allocated_; + stats->bytes_in_use = in_use_bytes_; + stats->max_bytes_in_use = max_bytes_in_use_; + stats->num_arena_shrinkages = num_arena_shrinkages_; +} + +cudaStream_t CudaMempoolArena::ResolveCudaStream(Stream* stream) noexcept { + if (!stream) return static_cast(0); + return static_cast(stream->GetHandle()); +} + +void CudaMempoolArena::MaybeRehashLocked() { + const size_t alloc_sz = alloc_map_.size(); + const size_t stream_sz = stream_map_.size(); + if (alloc_sz > 0) alloc_map_.reserve(alloc_sz); + if (stream_sz > 0) stream_map_.reserve(stream_sz); +} + +void CudaMempoolArena::SyncAllKnownStreams_NoThrow() { + for (const auto& kv : stream_map_) { + const cudaStream_t s = kv.first; + ORT_IGNORE_RETURN_VALUE(cudaStreamSynchronize(s)); // ignore errors; device-wide sync follows + } +} + +bool CudaMempoolArena::IsCudaVersionSupported() noexcept { + int ort_cuda_rt_version = 0; + cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); + if (cuda_status != cudaSuccess) { + return false; + } + + if (ort_cuda_rt_version < 11020) { + return false; + } + + int ort_cuda_driver_version = 0; + cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); + if (cuda_status != cudaSuccess) { + return false; + } + + if (ort_cuda_driver_version < 11020) { + return false; + } + + // Check if the driver version supports the runtime version + if (ort_cuda_rt_version >= 12000 && ort_cuda_driver_version < 12000) { + return false; + } + + if (ort_cuda_rt_version >= 13000 && ort_cuda_driver_version < 13000) { + return false; + } + + return true; +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_mempool_arena.h b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h new file mode 100644 index 0000000000000..750cbaf93b6d4 --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" // ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE, ORT_THROW/ENFORCE +#include "core/common/inlined_containers.h" // InlinedHashMap, InlinedHashSet, InlinedVector +#include "core/providers/cuda/cuda_stream_handle.h" // ORT Stream -> cudaStream_t +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +namespace logging { +class Logger; +} +namespace cuda { +/** + * @brief Stream-aware CUDA allocator implemented on top of a private `cudaMemPool_t`. + * The purpose of this arena is to assist with memory allocations in environments where + * a single process is hosting more than one cuda session. This arena hosts cuda memory pool + * which has some tunable parameters to control its memory usage and de-allocates memory back to + * the device according to the specified params. This is opposite to the BFCArena which only + * attempts to free memory on Shrink() if configured at the end of the run. + * + * ### Behavior + * - Creates a **process-local** CUDA mempool for a specific device (from `OrtMemoryInfo`). + * - All allocations use **`cudaMallocFromPoolAsync()`** on either the legacy default stream (0) or a + * caller-provided stream. The allocation stream is recorded for ordered free. + * - `Free()` enqueue **`cudaFreeAsync()`** on the recorded stream to + * respect CUDA's stream-ordered semantics. + * - `Shrink()` trims the pool with **`cudaMemPoolTrimTo(bytes_to_keep)`** and right-sizes the book-keeping maps + * under lock. + * + * ### Tuning + * - `pool_release_threshold`: if non-zero, sets `cudaMemPoolAttrReleaseThreshold`. **Recommended: 1 MB.**, but + * must be experimentally determined based on workload for optimal memory consumption vs performance. + * `cudaMemPoolAttrReservedMemCurrent`. **Recommended: 10 MB.** + * - `bytes_to_keep_on_shrink`: target size for `cudaMemPoolTrimTo()` on `Shrink()`. This is only relevant + * if Shrink() is enabled. It usually costs performance, and strictly speaking is not necessary for cuda mempools + * since they release memory on at synchronous points according to `pool_release_threshold`. + * + * ### Thread-safety + * - All updates to internal maps and statistics are guarded by an internal `std::mutex`. + * + * @note The allocator **does not** set the device default mempool and **does not** switch the current device. + */ +class CudaMempoolArena final : public IArena { + public: + /** + * @brief Construct a `CudaMempoolArena` with a private CUDA mempool. + * + * @param memory_info `OrtMemoryInfo` whose device id selects the CUDA device. + * @param pool_release_threshold Optional release threshold (bytes) for `cudaMemPoolAttrReleaseThreshold`. + * If 0, the attribute is not set. **Recommended value: 1 MB.** + * @param bytes_to_keep_on_shrink Target size (bytes) for `cudaMemPoolTrimTo()` on `Shrink()`. + * @param logger Cuda EP Logger + * + * The created pool is process-local and is **not** set as the device default pool. + */ + CudaMempoolArena(const OrtMemoryInfo& memory_info, + uint64_t pool_release_threshold, + size_t bytes_to_keep_on_shrink, + const logging::Logger* logger); + + /** + * @brief Destructor: + * 1) Enqueues cudaFreeAsync() for any outstanding allocations. + * 2) Synchronizes all known streams (best-effort; ignores invalid handles). + * 3) Calls cudaDeviceSynchronize() as a final barrier to ensure queued frees complete. + * 4) Trims pool to zero and destroys it. + */ + ~CudaMempoolArena() override; + + // -------- IAllocator overrides -------- + + /** + * @brief Allocate @p size bytes using the legacy default CUDA stream (0). + * @return device pointer or nullptr when size == 0 + * @throws on allocation failure + */ + void* Alloc(size_t size) override; + + /** + * @brief Allocate @p size bytes on the given ORT stream (uses `cudaMallocFromPoolAsync`). + * @return device pointer or nullptr when size == 0 + * @throws on allocation failure + */ + void* AllocOnStream(size_t size, Stream* stream) override; + + /** + * @brief Enqueue an ordered async free on the stream that allocated @p p. + * No-op if @p p is null or not owned by this allocator. + */ + void Free(void* p) override; + + /** + * @brief Reserve @p size bytes; implemented in terms of `Alloc(size)`. + * This is done so all the memory is gone including initializers when + * the session is torn down. + * @return device pointer or nullptr when size == 0 + * @throws on allocation failure + */ + void* Reserve(size_t size) override { return Alloc(size); } + + /// @brief This allocator is stream-aware. + bool IsStreamAware() const override { return true; } + + /// @brief Populate basic allocation statistics. + void GetStats(AllocatorStats* stats) override; + + // -------- IArena overrides -------- + + /** + * @brief Enqueue `cudaFreeAsync()` for all allocations made on @p stream. + * we intentionally do not implement this method. The call to this method + * will yank memory from under live OrtValues such as allocated for output + * bound and the resulting output OrtValue will not be valid. + * Then when the OrtValues attempt to release memory those entries are not found + * in the map: CudaMempoolArena::Free: pointer 0000000203800400 not found in allocation map; ignoring + * The reason this works with BFCArena is because it does not really release memory. + */ + // void ReleaseStreamBuffers(Stream* stream) override; + + /** + * @brief Trim the pool to `bytes_to_keep_on_shrink_` (configured at construction) using `cudaMemPoolTrimTo()`. + * Memory still allocated is not affected. Shrink() may affect your performance and contrary to BFCArena + * This allocator does not need Shrink. Cuda mempool is capable of releasing memory automatically + * according to pool_release_threshold_ set at construction. + * Also rehashes internal maps under lock to keep them reasonably sized. + */ + Status Shrink() override; + + static bool IsCudaVersionSupported() noexcept; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CudaMempoolArena); + + private: + /// Convert ORT `Stream*` to native `cudaStream_t`; null means legacy default (0). + static cudaStream_t ResolveCudaStream(Stream* stream) noexcept; + + /// Rehash internal maps under lock; invoked only by `Shrink()`. + void MaybeRehashLocked(); + + /// Best-effort synchronization of all streams in stream_map_. Non-throwing; ignores errors. + void SyncAllKnownStreams_NoThrow(); + + struct AllocationRecord { + size_t bytes; + cudaStream_t stream; // stream on which allocation/free are ordered + }; + + // ---- Pool/context configuration (immutable) ---- + uint64_t pool_release_threshold_; + size_t bytes_to_keep_on_shrink_; + size_t initial_pool_size_bytes_; + const logging::Logger* logger_; + cudaMemPool_t pool_{nullptr}; + + // ---- Bookkeeping (guarded by mutex_) ---- + std::mutex mutex_; + InlinedHashMap alloc_map_; // ptr -> record + InlinedHashMap> stream_map_; // stream -> ptrs + + // ---- Stats (guarded by mutex_) ---- + size_t total_allocated_ = 0; + size_t in_use_bytes_ = 0; + size_t max_bytes_in_use_ = 0; + size_t num_allocs_ = 0; + size_t num_arena_shrinkages_ = 0; + size_t max_alloc_size_ = 0; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 3b361f155831b..70afba320576b 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -181,7 +181,13 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA { } std::shared_ptr CreateCudaAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::CUDAExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { - return CUDAExecutionProvider::CreateCudaAllocator(device_id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); + CUDAExecutionProvider::CUDAAllocatorParams params{}; + params.device_id = device_id; + params.cuda_mem_threshold = gpu_mem_limit; + params.arena_extend_strategy = arena_extend_strategy; + params.external_alloc_info = &external_allocator_info; + params.arena_cfg = default_memory_arena_cfg; + return CUDAExecutionProvider::CreateCudaAllocator(params); } } g_info; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 14f0892687ad1..4d4dea9cb444c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3642,7 +3642,7 @@ common::Status InferenceSession::ValidateAndParseShrinkArenaString(const std::st void InferenceSession::ShrinkMemoryArenas(gsl::span arenas_to_shrink) { for (auto& alloc : arenas_to_shrink) { - auto status = static_cast(alloc.get())->Shrink(); + auto status = static_cast(alloc.get())->Shrink(); if (!status.IsOK()) { LOGS(*session_logger_, WARNING) << "Unable to shrink arena: " << alloc->Info().ToString() diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 546b11ae580d5..4891ece8bcda3 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2293,6 +2293,12 @@ ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfgV2, _In_reads_(num_keys) const char* cfg->initial_growth_chunk_size_bytes = static_cast(arena_config_values[i]); } else if (strcmp(arena_config_keys[i], "max_power_of_two_extend_bytes") == 0) { cfg->max_power_of_two_extend_bytes = static_cast(arena_config_values[i]); + } else if (strcmp(arena_config_keys[i], "use_cuda_mempool") == 0) { + cfg->use_cuda_mempool = static_cast(arena_config_values[i]); + } else if (strcmp(arena_config_keys[i], "cuda_mempool_release_threshold") == 0) { + cfg->cuda_mempool_release_threshold = static_cast(arena_config_values[i]); + } else if (strcmp(arena_config_keys[i], "cuda_mempool_bytes_to_keep_on_shrink") == 0) { + cfg->cuda_mempool_bytes_to_keep_on_shrink = static_cast(arena_config_values[i]); } else { std::ostringstream oss; oss << "Invalid key found: " << arena_config_keys[i]; diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index 9ded9d2bfeac0..5a50998af584f 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -339,7 +339,7 @@ struct StreamMock : public Stream { #ifdef ORT_ENABLE_STREAM TEST(StreamAwareArenaTest, TwoStreamAllocation) { - StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30); + StreamAwareBFCArena a(std::unique_ptr(new CPUAllocator()), 1 << 30); CheckStats(&a, 0, 0, 0, 0); OrtDevice tmp; @@ -451,7 +451,7 @@ TEST(BFCArenaTest, TestExtendStrategy) { 0, true, config}; auto allocator = CreateAllocator(device_info); size_t block_size = 1 << 20; // 1MB - BFCArena& a = *static_cast(allocator.get()); + auto& a = *allocator; a.Alloc(block_size); AllocatorStats stats; a.GetStats(&stats); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 8f6ed6f55c11a..aca345fccdc01 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -388,10 +388,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, std::vector expected_values_mul_y_2 = {174, 216, 258, 102, 128, 154, 30, 40, 50}; // Now run - st = session_object.Run(run_options, *io_binding.get()); - - std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; - ASSERT_TRUE(st.IsOK()); + ASSERT_STATUS_OK(session_object.Run(run_options, *io_binding)); if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider)) || (output_device && output_device->Type() == OrtDevice::GPU)) { @@ -402,21 +399,19 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, auto& rtensor = outputs.front().Get(); auto element_type = rtensor.DataType(); auto& shape = rtensor.Shape(); - std::unique_ptr cpu_tensor = std::make_unique(element_type, shape, cpu_alloc); + Tensor cpu_tensor(element_type, shape, cpu_alloc); #ifdef USE_CUDA - st = GetProviderInfo_CUDA().CreateGPUDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); #endif #ifdef USE_ROCM - st = GetProviderInfo_ROCM().CreateGPUDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); + st = GetProviderInfo_ROCM().CreateGPUDataTransfer()->CopyTensor(rtensor, cpu_tensor); #endif #ifdef USE_WEBGPU - st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); #endif ASSERT_TRUE(st.IsOK()); OrtValue ml_value; - ml_value.Init(cpu_tensor.release(), - DataTypeImpl::GetType(), - DataTypeImpl::GetType()->GetDeleteFunc()); + Tensor::InitOrtValue(std::move(cpu_tensor), ml_value); VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y); #endif } else { @@ -2230,7 +2225,7 @@ TEST(InferenceSessionTests, TestArenaShrinkageAfterRun) { auto cuda_alloc = session_object.GetAllocator(mem_info); AllocatorStats alloc_stats; - static_cast(cuda_alloc.get())->GetStats(&alloc_stats); + cuda_alloc->GetStats(&alloc_stats); #ifdef ENABLE_TRAINING // In training builds, initializers are allocated using the Reserve() call which // will not cause an arena extension @@ -2250,7 +2245,7 @@ TEST(InferenceSessionTests, TestArenaShrinkageAfterRun) { RunOptions run_options_1; RunModel(session_object, run_options_1); - static_cast(cuda_alloc.get())->GetStats(&alloc_stats); + cuda_alloc->GetStats(&alloc_stats); // The arena would have made 2 more extensions as part of servicing memory requests within Run() // 1) - To take the solitary feed to cuda memory @@ -2274,7 +2269,7 @@ TEST(InferenceSessionTests, TestArenaShrinkageAfterRun) { "gpu:0")); RunModel(session_object, run_options_2); - static_cast(cuda_alloc.get())->GetStats(&alloc_stats); + cuda_alloc->GetStats(&alloc_stats); // The arena would have made no extensions in this Run() as the freed memory after the first Run() // will be re-used diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 9bdc0898c81c1..ed2b98e5280b5 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -405,7 +405,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { // One reserve call should have been made (for allocating memory for the sole initializer in the model) ASSERT_EQ(1, alloc_stats.num_reserves); - // This counter comes from Reserve(). The actual call for arena based allocator went to StreamAwareArena instance + // This counter comes from Reserve(). The actual call for arena based allocator went to StreamAwareBFCArena instance ASSERT_EQ(1, alloc_stats.num_allocs); } } diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 8960898f036fc..2c9377d48f0c4 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -15,6 +15,8 @@ #include #include #include +#include +#include #include "test_configuration.h" #include "strings_helper.h" @@ -24,6 +26,7 @@ #include "absl/flags/usage.h" #include "absl/flags/usage_config.h" #include "absl/flags/reflection.h" +#include "absl/strings/str_split.h" static const onnxruntime::perftest::PerformanceTestConfig& DefaultPerformanceTestConfig() { static onnxruntime::perftest::PerformanceTestConfig default_config{}; @@ -149,8 +152,10 @@ ABSL_FLAG(std::string, C, "", "Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" "[Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n"); ABSL_FLAG(std::string, R, "", "Allows user to register custom op by .so or .dll file."); -ABSL_FLAG(bool, A, DefaultPerformanceTestConfig().run_config.enable_cpu_mem_arena, "Disables memory arena."); -ABSL_FLAG(bool, M, DefaultPerformanceTestConfig().run_config.enable_memory_pattern, "Disables memory pattern."); +ABSL_FLAG(bool, A, !DefaultPerformanceTestConfig().run_config.enable_cpu_mem_arena, "Disables memory arena."); +ABSL_FLAG(std::string, shrink_arena_between_runs, "", "When arena is enabled call Shrink for specified devices 'cpu:0;gpu:0'"); +ABSL_FLAG(std::string, enable_cuda_mempool, "", "When cuda is enabled use CudaMempoolArena with params 'pool_release_threshold;bytes_to_keep_on_shrink'"); +ABSL_FLAG(bool, M, !DefaultPerformanceTestConfig().run_config.enable_memory_pattern, "Disables memory pattern."); ABSL_FLAG(bool, s, DefaultPerformanceTestConfig().run_config.f_dump_statistics, "Shows statistics result, like P75, P90. If no result_file provided this defaults to on."); ABSL_FLAG(bool, v, DefaultPerformanceTestConfig().run_config.f_verbose, "Shows verbose information."); ABSL_FLAG(bool, I, DefaultPerformanceTestConfig().run_config.generate_model_input_binding, "Generates tensor input binding. Free dimensions are treated as 1 unless overridden using -f."); @@ -261,10 +266,34 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a } // -M - test_config.run_config.enable_memory_pattern = absl::GetFlag(FLAGS_M); + test_config.run_config.enable_memory_pattern = !absl::GetFlag(FLAGS_M); // -A - test_config.run_config.enable_cpu_mem_arena = absl::GetFlag(FLAGS_A); + test_config.run_config.enable_cpu_mem_arena = !absl::GetFlag(FLAGS_A); + + // --shrink_arena_between_runs + if (test_config.run_config.enable_cpu_mem_arena) { + auto shrink_spec = absl::GetFlag(FLAGS_shrink_arena_between_runs); + test_config.run_config.run_config_entries.emplace( + kOrtRunOptionsConfigEnableMemoryArenaShrinkage, + std::move(shrink_spec)); + } + + // --enable_cuda_mempool + { + auto cuda_mempool_spec = absl::GetFlag(FLAGS_enable_cuda_mempool); + if (!cuda_mempool_spec.empty()) { + // Split the string with ';' separator in two parts + std::vector parts = absl::StrSplit(cuda_mempool_spec, ';'); + if (parts.size() == 2U) { + test_config.run_config.cuda_mempool_arena_config = { + std::move(parts[0]), std::move(parts[1])}; + } else { + std::cerr << "Invalid format for --enable_cuda_mempool. " + << "Expected format : " << std::endl; + } + } + } // -e { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 0f2da07c69d85..cb40a9beafeee 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -10,9 +10,11 @@ #include #include #include +#include #include #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/providers/cuda/cuda_provider_options.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" #include "core/providers/dnnl/dnnl_provider_options.h" #include @@ -45,12 +47,15 @@ RunTiming OnnxRuntimeTestSession::Run() { auto& input = test_inputs_.at(id); auto start = std::chrono::high_resolution_clock::now(); Ort::RunOptions run_options; + for (const auto& kv : run_config_entries_) { + run_options.AddConfigEntry(kv.first.c_str(), kv.second.c_str()); + } + RunTiming timing; if (CUDA == device_memory_name_) { run_options.AddConfigEntry(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "1"); Ort::IoBinding io_binding(session_); - const OrtMemoryInfo* mem_info; - Ort::ThrowOnError(Ort::GetApi().AllocatorGetInfo(allocator_, &mem_info)); + auto mem_info = allocator_.GetInfo(); for (size_t i = 0; i < input_names_.size(); ++i) { io_binding.BindInput(input_names_[i], input[i]); @@ -76,7 +81,11 @@ RunTiming OnnxRuntimeTestSession::Run() { OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device& rd, const PerformanceTestConfig& performance_test_config, const TestModelInfo& m) - : rand_engine_(rd()), input_names_(m.GetInputCount()), input_names_str_(m.GetInputCount()), input_length_(m.GetInputCount()) { + : rand_engine_(rd()), + input_names_(m.GetInputCount()), + input_names_str_(m.GetInputCount()), + input_length_(m.GetInputCount()), + run_config_entries_(performance_test_config.run_config.run_config_entries) { Ort::SessionOptions session_options; // Add EP devices if any (created by plugin EP) @@ -218,8 +227,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name_ == onnxruntime::kCudaExecutionProvider) { #ifdef USE_CUDA - Ort::CUDAProviderOptions cuda_options; + Ort::CUDAProviderOptions cuda_options; const char* config_val = nullptr; switch (performance_test_config.run_config.cudnn_conv_algo) { case 0: @@ -249,6 +258,24 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } cuda_options.Update(provider_options); + if (performance_test_config.run_config.cuda_mempool_arena_config) { + // Enable and configure cuda_mempool arena + const size_t release_threshold = + static_cast(std::atoll(performance_test_config.run_config.cuda_mempool_arena_config->release_threshold.c_str())); + const size_t bytes_to_keep_on_shrink = + static_cast(std::atoll(performance_test_config.run_config.cuda_mempool_arena_config->bytes_to_keep.c_str())); + // Create a map of properties for the arena configuration + std::unordered_map arena_config_map = { + {"use_cuda_mempool", 1U}, + {"cuda_mempool_bytes_to_keep_on_shrink", bytes_to_keep_on_shrink}, + {"cuda_mempool_release_threshold", release_threshold}, + }; + // Must be kept alive while session is alive + Ort::ArenaCfg cuda_arena_cfg(arena_config_map); + cuda_mempool_arena_cfg_ = std::move(cuda_arena_cfg); + (*cuda_options).default_memory_arena_cfg = cuda_mempool_arena_cfg_; + } + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); #else ORT_THROW("CUDA is not supported in this build\n"); @@ -1014,7 +1041,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); } custom_allocator_ = Ort::Allocator(session_, memory_info); - allocator_ = custom_allocator_; + // Switch to custom + allocator_ = Ort::UnownedAllocator(custom_allocator_); // free dimensions are treated as 1 if not overridden transform_fcn = [](int64_t input) { return (input == -1) ? -input : input; }; diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index ada467824ca18..743db63b7b43c 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -42,9 +42,11 @@ class OnnxRuntimeTestSession : public TestSession { Ort::Session session_{nullptr}; std::mt19937 rand_engine_; std::uniform_int_distribution dist_; - OrtAllocator* allocator_ = Ort::AllocatorWithDefaultOptions(); + Ort::AllocatorWithDefaultOptions default_allocator_; // Note: custom_allocator_, if used, must outlive the `Ort::Value`s allocated with it in test_inputs_ and outputs_. + // and must be declared before them to ensure it is destroyed after them. Ort::Allocator custom_allocator_{nullptr}; + Ort::UnownedAllocator allocator_{default_allocator_}; std::vector> test_inputs_; std::vector outputs_; std::vector output_names_; @@ -56,9 +58,11 @@ class OnnxRuntimeTestSession : public TestSession { const int input_length_; std::string provider_name_; std::string device_memory_name_; // Device memory type name to use from the list in allocator.h + const std::unordered_map& run_config_entries_; #if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_NV) cudaStream_t stream_; // Device stream if required by IO bindings #endif + Ort::ArenaCfg cuda_mempool_arena_cfg_{nullptr}; }; } // namespace perftest diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index c982a8daadc9d..1d8ad77096ef3 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -63,6 +64,7 @@ struct RunConfig { bool set_denormal_as_zero{false}; std::basic_string ep_runtime_config_string; std::unordered_map session_config_entries; + std::unordered_map run_config_entries; std::map free_dim_name_overrides; std::map free_dim_denotation_overrides; std::string intra_op_thread_affinities; @@ -75,6 +77,11 @@ struct RunConfig { bool compile_ep_context{false}; std::basic_string compile_model_path; bool compile_binary_embed{false}; + struct CudaMempoolArenaConfig { + std::string release_threshold; + std::string bytes_to_keep; + }; + std::optional cuda_mempool_arena_config; }; struct PerformanceTestConfig { diff --git a/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc new file mode 100644 index 0000000000000..70c7a5b2bcdcb --- /dev/null +++ b/onnxruntime/test/providers/cuda/cuda_mempool_arena_test.cc @@ -0,0 +1,306 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_CUDA + +#include +#include + +#include + +#include "core/common/inlined_containers.h" // InlinedVector +#include "core/framework/allocator.h" // OrtMemoryInfo, IAllocator, AllocatorStats, onnxruntime::CUDA +#include "core/framework/execution_provider.h" +#include "core/framework/stream_handles.h" // onnxruntime::Stream (interface) +#include "core/providers/cuda/cuda_provider_options.h" +#include "core/providers/cuda/cuda_provider_factory.h" +#include "core/providers/cuda/cuda_provider_factory_creator.h" +#include "test/util/include/asserts.h" + +namespace onnxruntime { +namespace test { + +// --------- Helpers --------- + +static bool IsCudaMemPoolSupported() { + int ort_cuda_rt_version = 0; + cudaError_t cuda_status = cudaRuntimeGetVersion(&ort_cuda_rt_version); + if (cuda_status != cudaSuccess) { + return false; + } + + if (ort_cuda_rt_version < 11020) { + return false; + } + + int ort_cuda_driver_version = 0; + cuda_status = cudaDriverGetVersion(&ort_cuda_driver_version); + if (cuda_status != cudaSuccess) { + return false; + } + + if (ort_cuda_driver_version < 11020) { + return false; + } + + // Check if the driver version supports the runtime version + if (ort_cuda_rt_version >= 12000 && ort_cuda_driver_version < 12000) { + return false; + } + + if (ort_cuda_rt_version >= 13000 && ort_cuda_driver_version < 13000) { + return false; + } + + // Creating a cuda mempool in some pipelines fails with + // CUDA failure 801: operation not supported ; GPU=0 ; hostname=af14bbb1c000000 ; + // Even though CUDA version may be 12.8 possibly due to the driver. + cudaMemPoolProps props{}; + // Pinned is not the same as pinned allocator, cudaMemLocationTypeDevice actually does not exist + // even though is present in some internet docs. + props.allocType = cudaMemAllocationTypePinned; + props.handleTypes = cudaMemHandleTypeNone; // local to process + props.location.type = cudaMemLocationTypeDevice; // Device memory + props.location.id = 0; // test device 0 + cudaMemPool_t pool; + auto cuda_error = cudaMemPoolCreate(&pool, &props); + if (cuda_error != cudaSuccess) { + return false; + } + cuda_error = cudaMemPoolDestroy(pool); + + return true; +} + +static ::cudaStream_t NewCudaStream() { + ::cudaStream_t s{}; + const cudaError_t st = ::cudaStreamCreate(&s); + EXPECT_EQ(st, cudaSuccess); + return s; +} + +static void DestroyCudaStream(::cudaStream_t s) { + if (s) (void)::cudaStreamDestroy(s); +} + +static void TouchDevice(void* p, size_t bytes, ::cudaStream_t s, unsigned char value = 0xAB) { + ASSERT_NE(p, nullptr); + ASSERT_EQ(::cudaSuccess, ::cudaMemsetAsync(p, static_cast(value), bytes, s)); +} + +// --------- Test parameters --------- + +struct MPArenaParams { + uint64_t release_threshold = 1ull << 20; // 1 MB (recommended in allocator docs) + size_t bytes_to_keep = 4ull << 20; // 4 MB (small trim target for tests) +}; + +OrtArenaCfg CreateArenaCfgFromParams(const MPArenaParams& params) { + OrtArenaCfg cfg; + cfg.initial_chunk_size_bytes = 0; // Make BFCArena for CUDAPinned not to allocate anything here + cfg.use_cuda_mempool = 1; // Key switch + cfg.cuda_mempool_release_threshold = params.release_threshold; + cfg.cuda_mempool_bytes_to_keep_on_shrink = params.bytes_to_keep; + return cfg; +} + +std::unique_ptr CreateCudaExecutionProvider(OrtArenaCfg& arena_cfg) { + OrtCUDAProviderOptionsV2 cuda_options; + cuda_options.device_id = 0; // single-device tests + cuda_options.default_memory_arena_cfg = &arena_cfg; + cuda_options.do_copy_in_default_stream = true; + cuda_options.use_tf32 = false; + if (auto factory = CudaProviderFactoryCreator::Create(&cuda_options)) + return factory->CreateProvider(); + return nullptr; +} + +AllocatorPtr GetCudaMempoolArena(IExecutionProvider& cuda_ep) { + auto allocators = cuda_ep.CreatePreferredAllocators(); + EXPECT_EQ(allocators.size(), 2u); + const auto& mem_info = allocators[0]->Info(); + EXPECT_EQ("CUDAMemPoolArena", mem_info.name); + return allocators[0]; +} + +// --------- Minimal test Stream adapter --------- +// +// Adapts a cudaStream_t to ORT's Stream interface. +// If your Stream interface has additional pure virtuals on the work branch, +// add trivial overrides here (returning defaults / no-ops) so tests compile. +class TestCudaStream final : public onnxruntime::Stream { + public: + TestCudaStream(::cudaStream_t s, const OrtDevice& device) : Stream(s, device) {} + + ~TestCudaStream() { + DestroyCudaStream(static_cast<::cudaStream_t>(GetHandle())); + } + + void* GetHandle() const { + // ORT expects GetHandle() to return the native handle (cast to void*). + return Stream::GetHandle(); + } +}; + +// --------- Test fixture --------- + +class CudaMempoolArenaTest : public ::testing::Test { + protected: + void SetUp() override { + if (!IsCudaMemPoolSupported()) { + GTEST_SKIP() << "CUDA memory pools not supported on this device/driver."; + } + + const auto& logger = onnxruntime::logging::LoggingManager::DefaultLogger(); + orig_severity_ = logger.GetSeverity(); + orig_verbosity_ = logger.VLOGMaxLevel(); + logging::LoggingManager::SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + logging::LoggingManager::SetDefaultLoggerVerbosity(0); + cuda_ep_ = CreateCudaExecutionProvider(arena_cfg_); + cuda_ep_->SetLogger(&logger); + arena_ = GetCudaMempoolArena(*cuda_ep_); + mem_info_ = arena_->Info(); + } + + void TearDown() override { + arena_.reset(); + cuda_ep_.reset(); + ::cudaDeviceSynchronize(); + logging::LoggingManager::SetDefaultLoggerSeverity(orig_severity_); + logging::LoggingManager::SetDefaultLoggerVerbosity(orig_verbosity_); + } + + logging::Severity orig_severity_; + int orig_verbosity_; + OrtArenaCfg arena_cfg_ = CreateArenaCfgFromParams(MPArenaParams()); + std::unique_ptr cuda_ep_; + AllocatorPtr arena_; + OrtMemoryInfo mem_info_; +}; + +// --------- Tests --------- + +TEST_F(CudaMempoolArenaTest, AllocAndFree_OnDefaultStream) { + const size_t kBytes = 1 << 20; // 1 MB + void* p = arena_->Alloc(kBytes); + ASSERT_NE(p, nullptr); + + // default (legacy) stream 0 + ASSERT_EQ(::cudaSuccess, ::cudaMemsetAsync(p, 0xCD, kBytes, /*stream=*/0)); + arena_->Free(p); + + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); + + onnxruntime::AllocatorStats stats{}; + arena_->GetStats(&stats); + EXPECT_GE(stats.num_allocs, 1u); +} + +TEST_F(CudaMempoolArenaTest, AllocOnTwoStreams_OrderedFree) { + const size_t kBytes = 2 << 20; // 2 MB + + ::cudaStream_t s0 = NewCudaStream(); + ::cudaStream_t s1 = NewCudaStream(); + { + TestCudaStream ort_s0(s0, mem_info_.device); + TestCudaStream ort_s1(s1, mem_info_.device); + + void* p0 = arena_->AllocOnStream(kBytes, &ort_s0); + void* p1 = arena_->AllocOnStream(kBytes, &ort_s1); + ASSERT_NE(p0, nullptr); + ASSERT_NE(p1, nullptr); + + TouchDevice(p0, kBytes, s0, 0x11); + TouchDevice(p1, kBytes, s1, 0x22); + + // Enqueue ordered frees (no sync needed here). + arena_->Free(p0); + arena_->Free(p1); + + // Ensure queued frees completed on each stream. + ASSERT_EQ(::cudaSuccess, ::cudaStreamSynchronize(s0)); + ASSERT_EQ(::cudaSuccess, ::cudaStreamSynchronize(s1)); + + // Destroy streams here + } + + ASSERT_EQ(::cudaSuccess, ::cudaGetLastError()); +} + +TEST_F(CudaMempoolArenaTest, Shrink_TrimsPool_And_AllowsFurtherUse) { + const size_t kBytes = 2 << 20; + + InlinedVector ptrs; + for (size_t i = 0; i < ptrs.capacity(); ++i) { + void* p = arena_->Alloc(kBytes); + ASSERT_NE(p, nullptr); + ASSERT_EQ(::cudaSuccess, ::cudaMemsetAsync(p, 0xEF, kBytes, /*stream=*/0)); + ptrs.push_back(p); + } + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); + + for (void* p : ptrs) { + arena_->Free(p); + } + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); + + // Trim and sanity-check future allocations still work. + auto* arena_cast = IArena::SafeArenaCast(arena_.get()); + ASSERT_STATUS_OK(arena_cast->Shrink()); + + void* p_check = arena_->Alloc(kBytes); + ASSERT_NE(p_check, nullptr); + arena_->Free(p_check); + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); +} + +TEST_F(CudaMempoolArenaTest, Reserve_DelegatesToAlloc) { + const size_t kBytes = 512 * 1024; + void* p = arena_->Reserve(kBytes); + ASSERT_NE(p, nullptr); + arena_->Free(p); + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); +} + +// Validates allocator dtor guarantees completion of queued frees even when +// streams are destroyed prior to allocator destruction. +TEST_F(CudaMempoolArenaTest, Destructor_CompletesQueuedFrees_EvenIfStreamDestroyed) { + const size_t kBytes = 1 << 20; + ::cudaStream_t s = NewCudaStream(); + + { + auto cuda_prov = CreateCudaExecutionProvider(arena_cfg_); + cuda_prov->SetLogger(&onnxruntime::logging::LoggingManager::DefaultLogger()); + auto alloc = GetCudaMempoolArena(*cuda_ep_); + { + TestCudaStream ort_s(s, mem_info_.device); + + InlinedVector ptrs; + for (size_t i = 0; i < ptrs.capacity(); ++i) { + void* p = alloc->AllocOnStream(kBytes, &ort_s); + ASSERT_NE(p, nullptr); + TouchDevice(p, kBytes, s); + ptrs.push_back(p); + } + + for (void* p : ptrs) { + alloc->Free(p); + } + + // Destroy the stream *before* the frees have a chance to run. + } + + // arena goes out of scope here; its destructor must: + // - sync known streams (best-effort), + // - device-wide synchronize as a safety net, + // - then trim and destroy the pool. + } + + ASSERT_EQ(::cudaSuccess, ::cudaGetLastError()); + ASSERT_EQ(::cudaSuccess, ::cudaDeviceSynchronize()); +} + +} // namespace test +} // namespace onnxruntime + +#endif // USE_CUDA From 81a04ca45d5c30a4cbe9ebd06c6f37d80d9caf91 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 20 Nov 2025 08:59:16 +0800 Subject: [PATCH 03/17] [webgpu] Fix the wrong fallback in Attention (#26608) Attention input handling updates: * Corrected the input indices for `past` from `input[5]` to `input[4]` in the fallback logic, ensuring the code reflects the actual input order. With this change, the Attention ops in phi-4-mm-vision.onnx can go to the gpu instead of cpu. --- .../core/providers/webgpu/webgpu_execution_provider.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 3df194217933e..e0b84fef51f1f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -878,9 +878,9 @@ std::vector> WebGpuExecutionProvider::GetCapa const auto& inputs = node.InputDefs(); const auto& outputs = node.OutputDefs(); - // Current implementation does not support mask_index(input[3]), past(input[5]) and past_seq_len(input[6]) + // Current implementation does not support mask_index(input[3]), past(input[4]) and past_seq_len(input[6]) FALLBACK_TO_CPU_IF_EXIST_INPUT(3); - FALLBACK_TO_CPU_IF_EXIST_INPUT(5); + FALLBACK_TO_CPU_IF_EXIST_INPUT(4); FALLBACK_TO_CPU_IF_EXIST_INPUT(6); // Current implementation does not support present(output[1]) From 607d5e4de96caad7b44f9f492d0bb7ec06f07d7e Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 21 Nov 2025 00:09:53 +0800 Subject: [PATCH 04/17] [WebGPU] Implement Split-K on Conv|MatMul (#26461) ### Description This patch implements the `Split-K` optimization on `Conv|MatMul`. With `Split-K` we can re-arrange the computation into multiple workgroups when `K` is large to increase the parallelism on the platforms that `Split-K` is confirmed to be useful. 1. Support `Split-K` in `MakeMatMulPackedVec4Source()` to split a workgroup with large K into smaller ones. In this patch we only support `Split-K` with `batch_size == 1` and `vec4` on `Conv|MatMul`. 2. Support `Split-K` in `MatMulWriteFnSource()` (add the partial result to output with atomic built-in functions) 3. Implement `SplitKConfig` to decide whether `Split-K` should be used or not, and all the related thresholds. 4. Implement `MatMulFillBiasBeforeSplitKProgram` to initialize the output with `bias` or 0 when `Split-K` is used. ### Motivation and Context In current implementation, when `K` or `dim_inner` is large, in each invocation we always do the computation one by one in a very large loop, which may not make full use of all EUs on a GPU. With `Split-K` we can split such large amount of computation (`K`) into multiple workgroups with less computation (`kSplitK`, smaller than K), which can greatly improve the parallelism. With this patch we can get about 15% performance improvement on `efficientnet-lite-f16-demo` and 9% improvement on `mobilenetv2-12-f16-demo` on Lunar Lake and Meteor Lake. --- .../core/providers/webgpu/compute_context.cc | 4 + .../core/providers/webgpu/compute_context.h | 7 + .../core/providers/webgpu/math/gemm_utils.cc | 131 ++++++++++++++++-- .../core/providers/webgpu/math/gemm_utils.h | 4 +- .../core/providers/webgpu/math/matmul.cc | 96 +++++++++++-- .../core/providers/webgpu/math/matmul.h | 11 +- .../providers/webgpu/math/matmul_packed.cc | 58 +++++++- .../providers/webgpu/math/matmul_packed.h | 36 ++++- onnxruntime/core/providers/webgpu/nn/conv.cc | 3 +- .../core/providers/webgpu/nn/fuse_utils.cc | 1 + .../core/providers/webgpu/nn/fuse_utils.h | 8 +- .../core/providers/webgpu/shader_helper.cc | 2 +- .../core/providers/webgpu/webgpu_context.cc | 7 + .../core/providers/webgpu/webgpu_context.h | 11 ++ .../core/providers/webgpu/webgpu_utils.cc | 59 ++++++++ .../core/providers/webgpu/webgpu_utils.h | 21 +++ .../test/providers/cpu/nn/conv_fp16_test.cc | 122 +++++++++++++++- .../test/providers/cpu/nn/conv_op_test.cc | 100 ++++++++++++- 18 files changed, 633 insertions(+), 48 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index d8e58f0d0a170..ebe71c6ccfacd 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -20,5 +20,9 @@ const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const Co return context.ep_.BufferManager(); } +const SplitKConfig& ComputeContext::GetSplitKConfig() { + return webgpu_context_.GetSplitKConfig(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 01cae1e337439..ed16f2f0a1345 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -152,6 +152,13 @@ class ComputeContext final { return webgpu_context_.Run(*this, program); } + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. + // + const SplitKConfig& GetSplitKConfig(); + private: WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 7dd3b50c656f4..7cbc7f6a4a821 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -13,7 +13,7 @@ namespace webgpu { // which are used in the MatMulWriteFnSource function. namespace { -void HanldeMaybeHaveBiasForGEMM(ShaderHelper& shader, +void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader, const ShaderVariableHelper& output, bool has_bias, int c_components, @@ -53,6 +53,70 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader, << output.SetByIndices("coords", "value") << "\n"; } +void HandleMatMulWithSplitK( + ShaderHelper& shader, + ProgramVariableDataType output_variable_type) { + shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n"; + + // With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups, + // so we must add them with atomic built-in functions. Because currently WebGPU doesn't support + // atomic built-in functions on `f32` or `f16`, we implement the `atomicAdd` on `f32` and `f16` + // with `atomicLoad` and `atomicCompareExchangeWeak`: + // 1. Get `old_output_i32` from `output[offset]` with `atomicLoad`. + // 2. Convert `old_output_i32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`). + // 3. Add incoming `value` into `old_output_f32` or `old_output_vec2h`. + // 4. Convert the result of step 3 into `i32` values. + // 5. Try assigning the result of step 4 into `output[offset]` with `atomicCompareExchangeWeak` + // and `old_output_i32`. The assignment will fail if at this time `output[offset]` is not + // equal to `old_output_i32` (it is updated in another invocation). If the assignment fails + // we have to go to step 1 and repeat all the above steps. + switch (output_variable_type) { + case ProgramVariableDataType::Float32x4: { + shader.AdditionalImplementation() << R"( + let offset0 = i2o_output(coords) * 4u; + for (var i = 0u; i < 4u; i++) { + let offset = offset0 + i; + while (true) { + let old_output_i32 = atomicLoad(&output[offset]); + let old_output_f32 = bitcast(old_output_i32); + let new_output_f32 = old_output_f32 + value[i]; + let new_output_i32 = bitcast(new_output_f32); + let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); + if (output_compare_exchange.old_value == old_output_i32) { + break; + } + } + } +)"; + break; + } + case ProgramVariableDataType::Float16x4: { + shader.AdditionalImplementation() << R"( + let offset0 = i2o_output(coords) * 2u; + var vec2h_values : array; + vec2h_values[0] = value.xy; + vec2h_values[1] = value.zw; + for (var i = 0u; i < 2u; i++) { + let offset = offset0 + i; + while (true) { + let old_output_i32 = atomicLoad(&output[offset]); + let old_output_vec2h = bitcast(old_output_i32); + let new_output_vec2h = old_output_vec2h + vec2h_values[i]; + let new_output_i32 = bitcast(new_output_vec2h); + let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32); + if (output_compare_exchange.old_value == old_output_i32) { + break; + } + } + } +)"; + break; + } + default: + break; + } +} + } // namespace void MatMulReadFnSource(ShaderHelper& shader, @@ -125,7 +189,9 @@ void MatMulWriteFnSource(ShaderHelper& shader, int output_components, bool c_is_scalar, std::string activation_snippet, - bool is_channels_last) { + bool is_channels_last, + bool use_split_k, + ProgramVariableDataType output_variable_type) { shader.AdditionalImplementation() << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n"; @@ -134,8 +200,17 @@ void MatMulWriteFnSource(ShaderHelper& shader, shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n" << " var value = valueIn; \n"; - if (is_gemm) { - HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); + if (use_split_k) { + // Set output when MatMul is performed with Split-K. + // When Split-K is used in MatMul, the bias will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram` + // instead of here, so `has_bias` and `is_channels_last` is not used for Split-K. Note that we + // still need to handle `has_bias` (and `is_channels_last` in the future) in + // `MatMulFillBiasOrZeroBeforeSplitKProgram`. + ORT_ENFORCE(!has_bias, "Bias is not supported in MatMulProgram when Split-K is enabled."); + ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled."); + HandleMatMulWithSplitK(shader, output_variable_type); + } else if (is_gemm) { + HandleMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar); } else { HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last); } @@ -159,9 +234,6 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, uint32_t tile_inner, bool split_k, uint32_t split_dim_inner) { - ORT_UNUSED_PARAMETER(split_k); - ORT_UNUSED_PARAMETER(split_dim_inner); - const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type); std::string write_data_to_sub_a_vec4_snippet = @@ -208,14 +280,51 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, << " let tileCol = i32(local_id.x);\n" << " let globalRow = i32(global_id.y) * rowPerThread;\n" << " let globalCol = i32(global_id.x);\n" - << " let batch = i32(global_id.z);\n" - << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" << " let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" - << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" - << " var kStart = 0;\n" << " var acc: array, rowPerThread>;\n"; + if (split_k) { + // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into + // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from + // `kSplitK * i32(global_id.z)`. + // + // For example: considering computing Y = (X * W + B) in one workgroup. + // Let kSplitK = 2, B = [d1, d2] + // Let X = [[a1 a1 b1 b1 c1 c1] = [ A1 B1 C1 ], W = [[a2 a2] = [ A2 + // [a1 a1 b1 b1 c1 c1]] [a2 a2] B2 + // [b2 b2] C2 ] + // [b2 b2] + // [c2 c2] + // [c2 c2]] + // + // With Split-K: + // 1. Initialize output Y with B in `MatMulFillBiasOrZeroBeforeSplitKProgram`: Y = [[d1, d2] + // [d1, d2]] + // 2. Split the original 1 workgroup into 3 workgroups (now `dispatch_z = 3` in API side) + // Workgroup1: compute (A1 * A2) Workgroup2: compute (B1 * B2) + // Workgroup3: compute (C1 * C2) + // In each workgroup: + // - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z` + // - When the computation in each workgroup is completed, add the result to Y with several + // atomic built-in functions in `HandleMatMulWithSplitK()`. + shader.MainFunctionBody() + << "const kSplitK = " << split_dim_inner << ";\n" + << " let num_tiles = (kSplitK - 1) / tileInner + 1;\n" + << " var kStart = kSplitK * i32(global_id.z);\n" + + // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate + // the index of split-k instead of batch. + << " let batch = 0;\n" + << " let batchIndices = 0u;\n"; + } else { + shader.MainFunctionBody() + << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" + << " var kStart = 0;\n" + << " let batch = i32(global_id.z);\n" + << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : ""); + } + // Loop over shared dimension. shader.MainFunctionBody() << " let tileRowB = localRow * " << row_per_thread_b << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h index ed4cf997d2f00..7075debeb9952 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h @@ -24,7 +24,9 @@ void MatMulWriteFnSource(ShaderHelper& shader, int output_components, bool c_is_scalar, std::string activation_snippet = "", - bool is_channels_last = false); + bool is_channels_last = false, + bool use_split_k = false, + ProgramVariableDataType output_variable_type = ProgramVariableDataType::Float32x4); // The two following functions are used to generate shader code for vec4 and scalar. // It is used in GEMM, Matmul, and Conv. diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index cf4b9d3fae2d2..55c2c5773cc1f 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -161,14 +161,14 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { const auto* bias = context.Input(2); inputs.push_back(bias); } - auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false); - return context.RunProgram(program); + return ComputeMatMul(&context, Activation(), inputs, output_tensor, false); } -MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, - const TensorShape& input_a_reshape, - const TensorShape& input_b_reshape) { +Status ComputeMatMul(ComputeContext* context, + const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, + const TensorShape& input_a_reshape, + const TensorShape& input_b_reshape) { const auto* a = inputs[0]; const auto* b = inputs[1]; bool has_bias = inputs.size() > 2; @@ -226,31 +226,97 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) / (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1])); - const uint32_t dispatch_z = narrow((static_cast(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / - (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); + uint32_t dispatch_z = narrow((static_cast(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / + (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); const int components = is_vec4 ? 4 : 1; const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components); const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components); const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); - MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last}; - program - .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last) + ProgramOutput output(output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components); + const Tensor* bias = has_bias ? inputs[2] : nullptr; + bool use_bias_in_matmul = has_bias; + uint32_t split_dim_inner = 1; + + const SplitKConfig& split_k_config = context->GetSplitKConfig(); + const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + if (need_split_k) { + ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); + ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format."); + ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); + + // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled. + const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp); + ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program)); + + // `bias` has been handled in the execution of `fill_bias_program` so we don't need to set + // `bias` again in `MatMulProgram`. + use_bias_in_matmul = false; + + // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the + // number of splits along `dim_inner`. + // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize + // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`. + split_dim_inner = split_k_config.GetSplitDimInner(); + dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner; + + // The output should be declared in atomic types in `MatMulProgram` for the use of atomic + // built-in functions. + output.is_atomic = true; + } + + MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner}; + matmul_program + .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) - .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z); + .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z) + .AddOutput(std::move(output)); - if (has_bias) { + if (use_bias_in_matmul) { auto bias_components = is_channels_last ? components : 1; - const auto* bias = inputs[2]; TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); - program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); + matmul_program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); + } + + return context->RunProgram(matmul_program); +} + +MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( + const Tensor* bias, + Tensor* output, + const TensorShape& output_shape_vec4) { + const bool has_bias = bias != nullptr; + + // Currently we only support bias in vec4 and channels last format for Split-K MatMul. + constexpr uint32_t bias_components = 4; + MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias); + + const uint32_t dim_a_outer = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]); + const uint32_t dim_b_outer_vec4 = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]); + + // Fill one value (currently only vec4) per invocation. Now we use default workgroup size (64) for + // this program. + const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4; + const uint32_t dispatch_x = (total_outputs_vec4 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE; + + // To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar + // instead of vec4, while use `output_shape_vec4` directly as the output shape. + const uint32_t dim_b_outer = narrow(dim_b_outer_vec4 * bias_components); + program.CacheHint(has_bias) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast(bias_components)}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}}) + .SetDispatchGroupSize(dispatch_x); + + if (has_bias) { + const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast(bias_components)}); } + return program; } diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 8ab8c3a6ba2d0..0b65827be7f17 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -14,9 +14,14 @@ namespace onnxruntime { namespace webgpu { -MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, - const TensorShape& input_a_reshape = TensorShape(), - const TensorShape& input_b_reshape = TensorShape()); +Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, + const TensorShape& input_a_reshape = TensorShape(), + const TensorShape& input_b_reshape = TensorShape()); + +MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( + const Tensor* bias, + Tensor* output, + const TensorShape& output_shape_vec4); class MatMul final : public WebGpuKernel { public: diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 585f8f1e011c4..4daabe8246aa7 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -14,25 +14,77 @@ namespace webgpu { Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + const bool need_split_k = NeedSplitK(); + ShaderUsage output_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias; + if (need_split_k) { + // When Split-K is enabled, we will declare output as `atomic` to call atomic built-in + // functions on it, so we need below information to correctly compute the index on the output. + output_usage |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; + } + const auto& output = shader.AddOutput("output", output_usage); + const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); } std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); + ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; // declare the read and write functions MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_); - MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_); + MatMulWriteFnSource(shader, output, has_bias_, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type); std::string data_type = "a_element_t"; // generate the main function if (is_vec4_) { - ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); + ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source( + shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims, + /*transA = */ false, /*transB = */ false, /*alpha = */ 1.f, /*need_handle_matmul = */ true, + /*output_components = */ 4, /*tile_inner = */ 32, need_split_k, split_dim_inner_)); } else { ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); } return Status::OK(); } +bool MatMulProgram::NeedSplitK() const { + return split_dim_inner_ > 1; +} + +Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + + // Handle bias with `MatMulWriteFnSource()`. + // Here `use_split_k` is false because we just initialize `output` with bias. + // `use_split_k` is true only when we do the actual MatMul with Split-K. + // Currently we only support bias in vec4 and channels last format for Split-K MatMul. + MatMulWriteFnSource( + shader, output, has_bias_, /*is_gemm*/ false, /*c_components*/ 4, /*output_components*/ 4, /*c_is_scalar*/ false, + /*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false); + + shader.MainFunctionBody() << R"( + let output_components = 4; + let output_id = i32(global_idx); + + let dim_a_outer = i32(uniforms.dim_a_outer); + let dim_b_outer = i32(uniforms.dim_b_outer) / output_components; + if (output_id >= dim_a_outer * dim_b_outer) { + return; + } + + let output_row = output_id / dim_b_outer; + let output_col = output_id % dim_b_outer; + let output_batch = 0; + let output_value = output_value_t(); + mm_write(output_batch, output_row, output_col, output_value); +)"; + + return Status::OK(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 767fdd8802e5b..143ba61c99e13 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -13,24 +13,48 @@ namespace onnxruntime { namespace webgpu { class MatMulProgram final : public Program { public: - MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false) : Program{"MatMul"}, - activation_(activation), - has_bias_{bias}, - is_vec4_{is_vec4}, - elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), - is_channels_last_(is_channels_last) {} + MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false, uint32_t split_dim_inner = 1) : Program{"MatMul"}, + activation_(activation), + has_bias_{bias}, + is_vec4_{is_vec4}, + elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), + is_channels_last_(is_channels_last), + split_dim_inner_(split_dim_inner) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, {"dim_inner", ProgramUniformVariableDataType::Uint32}); + bool NeedSplitK() const; + private: const Activation activation_; const bool has_bias_; const bool is_vec4_; const InlinedVector elements_per_thread_; bool is_channels_last_ = false; + uint32_t split_dim_inner_ = 1; +}; + +// The program to initialize the output with 0 or bias before doing MatMul with Split-K. In Split-K, +// we set the output values with `atomicLoad` and `atomicCompareExchangeWeak` instead of a direct +// assignment (see the function `HandleMatMulWithSplitK()` in `gemm_utils.cc`), so we must initialize +// the output with 0 or bias first to make sure `atomicLoad` won't return garbage data. +class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program { + public: + explicit MatMulFillBiasOrZeroBeforeSplitKProgram(bool has_bias) + : Program{"MatMul_Fill_Bias_Or_Zero_Before_Split_K"}, + has_bias_(has_bias) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}); + + private: + bool has_bias_ = false; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index a2777979ae983..77fa46cb87518 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -200,8 +200,7 @@ Status Conv::ComputeInternal(ComputeContext& context .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast(K)}}); return context.RunProgram(program); } else { - MatMulProgram program = CreateMatMulProgram(activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); - return context.RunProgram(program); + return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); } } // Transpose weights diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc index 9e934e9eb5db7..aa0caab39a88e 100644 --- a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/framework/op_kernel_info.h" #include namespace onnxruntime { namespace webgpu { diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h index f5d2585bb9b45..fad7d3d145bc6 100644 --- a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h @@ -1,11 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include -#include "core/providers/webgpu/webgpu_kernel.h" +#include + +#include "core/common/status.h" #pragma once namespace onnxruntime { + +class OpKernelInfo; + namespace webgpu { + enum class ActivationKind { None, Relu, diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 5447966b91aa7..b08649cbd5d5b 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -472,7 +472,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } ss << ": array<"; if (is_atomic) { - if (output->type_ == ProgramVariableDataType::Float32) { + if (output->type_ == ProgramVariableDataType::Float32 || output->type_ == ProgramVariableDataType::Float16x4 || output->type_ == ProgramVariableDataType::Float32x4) { ss << "atomic"; // emulate float atomic via i32 } else if (output->type_ == ProgramVariableDataType::Uint32) { ss << "atomic"; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 9af9cd455b5a4..28decb076951e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -910,6 +910,13 @@ void WebGpuContext::ReleaseGraphResources(std::vector WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; std::once_flag WebGpuContextFactory::init_default_flag_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 1ead7b3a005bb..bd7dae75f2e2d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -5,12 +5,14 @@ #include #include +#include #include "core/providers/webgpu/webgpu_external_header.h" #include "core/common/common.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/webgpu_utils.h" #if defined(ENABLE_PIX_FOR_WEBGPU_EP) #include "core/providers/webgpu/webgpu_pix_frame_generator.h" @@ -171,6 +173,13 @@ class WebGpuContext final { Status Run(ComputeContext& context, const ProgramBase& program); void OnRunEnd(); + // + // Get Split-K configuration. + // + // `split_k_config_` won't be initialized until the first call to this method. + // + const SplitKConfig& GetSplitKConfig(); + private: enum class TimestampQueryType { None = 0, @@ -268,6 +277,8 @@ class WebGpuContext final { uint32_t num_pending_dispatches_ = 0; const uint32_t max_num_pending_dispatches_ = 16; + std::optional split_k_config_; + // profiling TimestampQueryType query_type_; wgpu::QuerySet query_set_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 53b96dfe7a346..568d29a96cb88 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -21,5 +21,64 @@ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components return TensorShape(shape_vector); } +SplitKConfig SplitKConfig::GetSplitKConfig(const wgpu::AdapterInfo& adapter_info) { + SplitKConfig config = {}; + + if (adapter_info.vendor == std::string_view{"intel"}) { + if (adapter_info.architecture == std::string_view{"xe-2lpg"} || + adapter_info.architecture == std::string_view{"xe-2hpg"} || + adapter_info.architecture == std::string_view{"xe-lpg"} || + adapter_info.architecture == std::string_view{"gen-12hp"}) { + config.enable_split_k_ = true; + + // Below thresholds are only verified on the above Intel GPUs without any regressions. The + // proper value of `max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_` may be + // reduced when we support a larger `dim_inner` because larger `dim_inner` will bring more + // atomic calls for each output value. + config.split_dim_inner_ = 256; + config.min_dim_inner_with_split_k_ = config.split_dim_inner_ * 2; + config.max_dim_inner_with_split_k_ = config.split_dim_inner_ * 9; + config.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 35.0f; + } + } + return config; +} + +bool SplitKConfig::UseSplitK( + bool is_vec4, + ActivationKind activation_kind, + uint64_t batch_size, + bool is_channels_last, + uint32_t dim_a_outer, + uint32_t dim_b_outer, + uint32_t dim_inner) const { + if (!enable_split_k_) { + return false; + } + + bool use_split_k = true; + + // TODO: support the cases below. + use_split_k &= activation_kind == ActivationKind::None; + use_split_k &= is_vec4; + use_split_k &= batch_size == 1; + // Now `is_channels_last` is only supported because we only generate vec4 shaders in + // `MatMulFillBiasOrZeroBeforeSplitKProgram`. + use_split_k &= is_channels_last; + + // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and + // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and + // `dim_inner)` as the metric to decide whether to use Split-K or not. + use_split_k &= (dim_inner >= min_dim_inner_with_split_k_); + use_split_k &= (dim_inner <= max_dim_inner_with_split_k_); + use_split_k &= ((dim_a_outer * dim_b_outer * 1.0f / dim_inner) <= max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_); + + return use_split_k; +} + +uint32_t SplitKConfig::GetSplitDimInner() const { + return split_dim_inner_; +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 86eb57f99f3b3..d45b9bf4dd119 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -7,6 +7,8 @@ #include "core/common/common.h" #include "core/framework/tensor.h" #include "core/framework/tensor_shape.h" +#include "core/providers/webgpu/webgpu_external_header.h" +#include "core/providers/webgpu/nn/fuse_utils.h" namespace onnxruntime { namespace webgpu { @@ -89,5 +91,24 @@ inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, c return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; } +class SplitKConfig { + public: + static SplitKConfig GetSplitKConfig(const wgpu::AdapterInfo& adapter_info); + + bool UseSplitK( + bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, + bool is_channels_last, uint32_t dim_a_outer, + uint32_t dim_b_outer, uint32_t dim_inner) const; + + uint32_t GetSplitDimInner() const; + + private: + bool enable_split_k_ = false; + uint32_t split_dim_inner_ = 0; + uint32_t min_dim_inner_with_split_k_ = 0; + uint32_t max_dim_inner_with_split_k_ = 0; + float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner_ = 0.0f; +}; + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 8382258bf39b4..0847c15ba7cc6 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -3,9 +3,10 @@ #include "core/mlas/inc/mlas.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML) || defined(USE_XNNPACK) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(USE_COREML) || defined(USE_XNNPACK) || defined(USE_WEBGPU) #include "gtest/gtest.h" +#include "test/common/random_generator.h" #include "test/providers/provider_test_utils.h" #include "default_providers.h" @@ -39,12 +40,13 @@ please add the EP to the excluded_providers list. void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const vector& expected_output, const vector& expected_output_shape, bool weight_is_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", - int opset = 11) { + int opset = 11, + float rel_error = 0.002f) { std::unique_ptr tester; if (!attributes.activation.empty()) { tester = std::make_unique("NhwcFusedConv", 1, onnxruntime::kMSDomain); @@ -84,7 +86,7 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, if (inputs.size() >= 4) tester->AddInput(szNames[3], input_shapes[3], inputs[3]); - tester->AddOutput("Y", expected_output_shape, expected_output, /*no sort*/ false, 0.002f, 0.0f); + tester->AddOutput("Y", expected_output_shape, expected_output, /*no sort*/ false, rel_error, 0.0f); std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported @@ -424,6 +426,118 @@ TEST(ConvFp16Test, Conv2D_2) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvFp16Test, Conv2D_MatMul_SplitK_No_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + + RandomValueGenerator random{1234}; + vector X_float32(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W_float32(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + + vector X = FloatsToMLFloat16s(X_float32); + vector W = FloatsToMLFloat16s(W_float32); + + // Calculate expected output values + vector expected_vals_float32; + expected_vals_float32.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum{}; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index].ToFloat() * W[w_index].ToFloat(); + } + int y_index = n * M + m; + expected_vals_float32[y_index] = sum; + } + } + vector expected_vals = FloatsToMLFloat16s(expected_vals_float32); + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, + OpTester::ExpectResult::kExpectSuccess, "", 11); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true, + OpTester::ExpectResult::kExpectSuccess, "", 11); +} + +TEST(ConvFp16Test, Conv2D_MatMul_SplitK_With_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + vector B_shape = {N}; + + RandomValueGenerator random{1234}; + vector X_float32(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W_float32(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + vector B_float32(random.Gaussian(AsSpan(B_shape), 0.0f, 0.25f)); + + vector X = FloatsToMLFloat16s(X_float32); + vector W = FloatsToMLFloat16s(W_float32); + vector B = FloatsToMLFloat16s(B_float32); + + // Calculate expected output values + vector expected_vals_float32; + expected_vals_float32.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum{}; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index].ToFloat() * W[w_index].ToFloat(); + } + sum += B[n].ToFloat(); + int y_index = n * M + m; + expected_vals_float32[y_index] = sum; + } + } + vector expected_vals = FloatsToMLFloat16s(expected_vals_float32); + + // Using a higher relative error threshold for the Linux arm64 bots + constexpr float rel_error = 0.02f; + TestConvFp16Op( + attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, false, + OpTester::ExpectResult::kExpectSuccess, "", 11, rel_error); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvFp16Op( + attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true, + OpTester::ExpectResult::kExpectSuccess, "", 11, rel_error); +} + TEST(ConvFp16Test, Conv2D_Bias_1) { ConvOpAndTestAttributes attrs = { "", // auto_pad diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index 7c84aefa1c01f..4efbb8cfd5c19 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/graph/constants.h" #include "gtest/gtest.h" +#include "test/common/random_generator.h" #include "test/providers/provider_test_utils.h" using namespace std; @@ -23,7 +24,7 @@ struct ConvOpAndTestAttributes { void TestConvOp(const ConvOpAndTestAttributes& attributes, const vector>& inputs, const vector>& input_shapes, - const std::initializer_list& expected_output, + const vector& expected_output, const vector& expected_output_shape, bool weight_is_initializer = false, optional epsilon = optional(), @@ -535,6 +536,103 @@ TEST(ConvTest, Conv2D_AutoPad2) { TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvTest, Conv2D_MatMul_SplitK_No_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + + // Fill X and W + RandomValueGenerator random{1234}; + vector X(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + + // Calculate expected output values + vector expected_vals; + expected_vals.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum = 0.0f; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index] * W[w_index]; + } + int y_index = n * M + m; + expected_vals[y_index] = sum; + } + } + + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvTest, Conv2D_MatMul_SplitK_With_Bias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + // Define the matrix shapes to test a matmul-like convolution + constexpr int64_t M = 16; + constexpr int64_t K = 768; + constexpr int64_t N = 64; + + vector X_shape = {1, K, M, 1}; + vector W_shape = {N, K, 1, 1}; + vector Y_shape = {1, N, M, 1}; + vector B_shape = {N}; + + // Fill X, W and B + RandomValueGenerator random{1234}; + vector X(random.Gaussian(AsSpan(X_shape), 0.0f, 0.025f)); + vector W(random.Gaussian(AsSpan(W_shape), 0.0f, 0.025f)); + vector B(random.Gaussian(AsSpan(B_shape), 0.0f, 0.25f)); + + // Calculate expected output values + vector expected_vals; + expected_vals.resize(M * N); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum = 0.0f; + for (int k = 0; k < K; ++k) { + int x_index = k * M + m; + int w_index = n * K + k; + sum += X[x_index] * W[w_index]; + } + sum += B[n]; + int y_index = n * M + m; + expected_vals[y_index] = sum; + } + } + + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + // NNAPI/CoreML EP requires weight to be an initializer + TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, true); +} + // Conv10 TEST(ConvTest, Conv3D_1) { ConvOpAndTestAttributes attrs = { From 4dbb05f45f3de62e8a2d54613ff915233c74234c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:08:30 -0800 Subject: [PATCH 05/17] [core] allow using initializer allocator for prepack (#26617) ### Description This PR makes ORT to prefer initializer allocator when calling `OpKernel::PrePack`. If an EP does not register an initializer allocator (currently only WebGPU does this), the behavior is kept unchanged. ### Motivation and Context Helps to improve the memory usage when doing prepack. --- onnxruntime/core/framework/session_state.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 01ba492eb166e..8fb3dc63aa4d1 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -588,7 +588,7 @@ Status SessionState::PrepackConstantInitializedTensors( // within this session. Or if the weight is not present on disk, // we store the newly minted pre-packed data. - AllocatorPtr session_cpu_alloc = GetAllocator(kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); + AllocatorPtr session_initializer_alloc = GetInitializerAllocator(kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); PrePackedWeights weights_to_be_filled_in; // The reason we invoke PrePack() before looking into the container for any pre-packed weight // cached by another instance of the same op_type (for the same constant initializer) is because @@ -596,7 +596,7 @@ Status SessionState::PrepackConstantInitializedTensors( // pre-packed weight with the pre-packed weight generated by this instance of the same op_type because // other static properties of the node like node attributes could play a role in the pre-packed // weights' contents. - ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_cpu_alloc, + ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, session_initializer_alloc, is_packed, &weights_to_be_filled_in)); From 1d2a4341e7df6a85f75715c1954e957cb3f38065 Mon Sep 17 00:00:00 2001 From: Ranjit Ranjan <165394499+ranjitshs@users.noreply.github.com> Date: Fri, 21 Nov 2025 06:10:43 +0530 Subject: [PATCH 06/17] [AIX]Blocking the call of dladdr under _AIX (#26513) ### Description In AIX, dladdr() is not supported so blocking the call of dladdr API under _AIX. we don't have support of cpuinfo pkg also which generates a warning at runtime. This PR is to fox the issues mentioned above. ### Motivation and Context 1. Fix for below compilation error ``` /home/buildusr/jenkins/workspace/onnxruntime-openxl/onnxruntime/onnxruntime/core/platform/posix/env.cc:562:9: error: unknown type name 'Dl_info' 562 | if (Dl_info dl_info{}; ``` 2. Fix for below warning during test application executions. `2025-11-06 07:23:44.176700000 [W:onnxruntime:Default, cpuid_info.cc:95 LogEarlyWarning] Unknown CPU vendor. cpuinfo_vendor value: 0` --- cmake/onnxruntime_unittests.cmake | 10 ++++++++-- onnxruntime/core/common/cpuid_info_vendor.cc | 5 ++++- onnxruntime/core/platform/posix/env.cc | 4 +++- onnxruntime/test/shared_lib/test_runtime_path.cc | 4 ++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 754669fffbf8d..4913d38939792 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1528,8 +1528,14 @@ endif() onnxruntime_add_shared_library(onnxruntime_runtime_path_test_shared_library ${onnxruntime_runtime_path_test_shared_library_src}) - target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE - onnxruntime_common cpuinfo ${CMAKE_DL_LIBS}) + if (CMAKE_SYSTEM_NAME MATCHES "AIX") + target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE + onnxruntime_common ${CMAKE_DL_LIBS}) + set_target_properties(onnxruntime_runtime_path_test_shared_library PROPERTIES AIX_SHARED_LIBRARY_ARCHIVE OFF) + else() + target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE + onnxruntime_common cpuinfo ${CMAKE_DL_LIBS}) + endif() target_include_directories(onnxruntime_runtime_path_test_shared_library PRIVATE ${ONNXRUNTIME_ROOT}) if(UNIX) diff --git a/onnxruntime/core/common/cpuid_info_vendor.cc b/onnxruntime/core/common/cpuid_info_vendor.cc index d4d940eedfe28..8675f129da770 100644 --- a/onnxruntime/core/common/cpuid_info_vendor.cc +++ b/onnxruntime/core/common/cpuid_info_vendor.cc @@ -198,7 +198,7 @@ constexpr std::array kCpuVendorInfos{ CpuVendorInfo{cpuinfo_vendor_nvidia, "Nvidia", 0x10DE}, CpuVendorInfo{cpuinfo_vendor_apple, "Apple", 0x106B}, CpuVendorInfo{cpuinfo_vendor_arm, "ARM", 0x13B5}, - + CpuVendorInfo{cpuinfo_vendor_ibm, "IBM", 0x1014}, // TODO add more as needed }; @@ -228,6 +228,9 @@ void CPUIDInfo::VendorInfoInit() { } } #endif // defined(CPUINFO_SUPPORTED) +#if defined(_AIX) + result = cpuinfo_vendor_ibm; +#endif return result; }(); diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 34b6b2de64a92..aeddef0c5188f 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -556,6 +556,8 @@ class PosixEnv : public Env { } PathString GetRuntimePath() const override { +// In AIX, dladdr is not supported. +#if !defined(_AIX) // Use dladdr() to look up the file that contains an address from this binary. const void* const address_from_this_binary = reinterpret_cast(Env::Default); @@ -568,7 +570,7 @@ class PosixEnv : public Env { runtime_path.remove_filename(); return runtime_path; } - +#endif return PathString{}; } diff --git a/onnxruntime/test/shared_lib/test_runtime_path.cc b/onnxruntime/test/shared_lib/test_runtime_path.cc index 621d006a8659a..f004c96041ef6 100644 --- a/onnxruntime/test/shared_lib/test_runtime_path.cc +++ b/onnxruntime/test/shared_lib/test_runtime_path.cc @@ -21,7 +21,11 @@ bool IsDirectorySeparator(PathChar c) { } } // namespace +#if !defined(_AIX) TEST(GetRuntimePathFromSharedLibraryTest, Basic) { +#else +TEST(GetRuntimePathFromSharedLibraryTest, DISABLED_Basic) { +#endif const auto* runtime_path_cstr = OrtTestGetSharedLibraryRuntimePath(); ASSERT_NE(runtime_path_cstr, nullptr); From ee0ffd5851510d59d8c267fb433edb9767fab5ac Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 21 Nov 2025 09:53:47 +0800 Subject: [PATCH 07/17] [WebNN] Update unit tests list (#26566) Only list no fallback and pass tests for WebNN, as a pre-requisite for enabling WebNN CI tests in future. --- js/web/test/suite-test-list.jsonc | 506 +++++++++++++++--------------- 1 file changed, 258 insertions(+), 248 deletions(-) diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index f4ccff1d7770d..9f2f0afa61604 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1533,7 +1533,7 @@ // // "test_adam_multiple", // // "test_adam", "test_add_bcast", - // "test_add_uint8", + "test_add_uint8", "test_add", "test_and_bcast3v1d", "test_and_bcast3v2d", @@ -1543,37 +1543,38 @@ "test_and2d", "test_and3d", "test_and4d", - "test_argmax_default_axis_example_select_last_index", + // tests "test_arg*_select_last_index" are excluded because WebNN spec does not support select_last_index attribute. + // "test_argmax_default_axis_example_select_last_index", "test_argmax_default_axis_example", - "test_argmax_default_axis_random_select_last_index", + // "test_argmax_default_axis_random_select_last_index", "test_argmax_default_axis_random", - "test_argmax_keepdims_example_select_last_index", + // "test_argmax_keepdims_example_select_last_index", "test_argmax_keepdims_example", - "test_argmax_keepdims_random_select_last_index", + // "test_argmax_keepdims_random_select_last_index", "test_argmax_keepdims_random", - "test_argmax_negative_axis_keepdims_example_select_last_index", + // "test_argmax_negative_axis_keepdims_example_select_last_index", "test_argmax_negative_axis_keepdims_example", - "test_argmax_negative_axis_keepdims_random_select_last_index", + // "test_argmax_negative_axis_keepdims_random_select_last_index", "test_argmax_negative_axis_keepdims_random", - "test_argmax_no_keepdims_example_select_last_index", + // "test_argmax_no_keepdims_example_select_last_index", "test_argmax_no_keepdims_example", - "test_argmax_no_keepdims_random_select_last_index", + // "test_argmax_no_keepdims_random_select_last_index", "test_argmax_no_keepdims_random", - "test_argmin_default_axis_example_select_last_index", + // "test_argmin_default_axis_example_select_last_index", "test_argmin_default_axis_example", - "test_argmin_default_axis_random_select_last_index", + // "test_argmin_default_axis_random_select_last_index", "test_argmin_default_axis_random", - "test_argmin_keepdims_example_select_last_index", + // "test_argmin_keepdims_example_select_last_index", "test_argmin_keepdims_example", - "test_argmin_keepdims_random_select_last_index", + // "test_argmin_keepdims_random_select_last_index", "test_argmin_keepdims_random", - "test_argmin_negative_axis_keepdims_example_select_last_index", + // "test_argmin_negative_axis_keepdims_example_select_last_index", "test_argmin_negative_axis_keepdims_example", - "test_argmin_negative_axis_keepdims_random_select_last_index", + // "test_argmin_negative_axis_keepdims_random_select_last_index", "test_argmin_negative_axis_keepdims_random", - "test_argmin_no_keepdims_example_select_last_index", + // "test_argmin_no_keepdims_example_select_last_index", "test_argmin_no_keepdims_example", - "test_argmin_no_keepdims_random_select_last_index", + // "test_argmin_no_keepdims_random_select_last_index", "test_argmin_no_keepdims_random", // "test_asin_example", // "test_asin", @@ -1587,21 +1588,21 @@ // "test_averagepool_2d_ceil", "test_averagepool_2d_default", "test_averagepool_2d_pads_count_include_pad", - "test_averagepool_2d_pads", + // "test_averagepool_2d_pads", // unsupported by TFLite backend. "test_averagepool_2d_precomputed_pads_count_include_pad", "test_averagepool_2d_precomputed_pads", "test_averagepool_2d_precomputed_same_upper", "test_averagepool_2d_precomputed_strides", - "test_averagepool_2d_same_lower", + // "test_averagepool_2d_same_lower", // unsupported by TFLite backend. "test_averagepool_2d_same_upper", "test_averagepool_2d_strides", // "test_averagepool_3d_default", "test_basic_conv_with_padding", "test_basic_conv_without_padding", "test_basic_convinteger", - "test_batchnorm_epsilon_training_mode", + // "test_batchnorm_epsilon_training_mode", // unsupported training_mode by WebNN. "test_batchnorm_epsilon", - "test_batchnorm_example_training_mode", + // "test_batchnorm_example_training_mode", // unsupported training_mode by WebNN. "test_batchnorm_example", // // "test_bernoulli_double_expanded", // // "test_bernoulli_double", @@ -1622,10 +1623,10 @@ // // "test_blackmanwindow_symmetric", // // "test_blackmanwindow", // // "test_cast_BFLOAT16_to_FLOAT", - "test_cast_DOUBLE_to_FLOAT", + // "test_cast_DOUBLE_to_FLOAT", // "test_cast_DOUBLE_to_FLOAT16", // // "test_cast_FLOAT_to_BFLOAT16", - "test_cast_FLOAT_to_DOUBLE", + // "test_cast_FLOAT_to_DOUBLE", // // "test_cast_FLOAT_to_FLOAT16", // // "test_cast_FLOAT_to_STRING", // "test_cast_FLOAT16_to_DOUBLE", @@ -1657,15 +1658,16 @@ // "test_celu", "test_clip_default_inbounds", "test_clip_default_int8_inbounds", - "test_clip_default_int8_max", - "test_clip_default_int8_min", - "test_clip_default_max", - "test_clip_default_min", - "test_clip_example", - "test_clip_inbounds", - "test_clip_outbounds", - "test_clip_splitbounds", - "test_clip", + // "test_clip_default_int8_max", + // "test_clip_default_int8_min", + // tests "test_clip*" on opset > 10 are excluded because max and min are non-constant inputs. + "opset{7,8,9,10}/test_clip_default_max", + "opset{7,8,9,10}/test_clip_default_min", + "opset{7,8,9,10}/test_clip_example", + "opset{7,8,9,10}/test_clip_inbounds", + "opset{7,8,9,10}/test_clip_outbounds", + "opset{7,8,9,10}/test_clip_splitbounds", + "opset{7,8,9,10}/test_clip", // // "test_compress_0", // // "test_compress_1", // // "test_compress_default_axis", @@ -1690,32 +1692,33 @@ "test_convinteger_without_padding", "test_convtranspose_1d", // // "test_convtranspose_3d", - // "test_convtranspose_autopad_same", - "test_convtranspose_dilations", + "!(opset14)/test_convtranspose_autopad_same", + // "test_convtranspose_dilations", // unsupported by TFLite backend. "test_convtranspose_kernel_shape", "opset{9,17}/test_convtranspose_output_shape", "test_convtranspose_pad", - "test_convtranspose_pads", + // "test_convtranspose_pads", // unsupported by TFLite backend. "test_convtranspose_with_kernel", "test_convtranspose", "test_cos_example", "test_cos", // "test_cosh_example", // "test_cosh", - "test_cumsum_1d_exclusive", - "test_cumsum_1d_reverse_exclusive", - "test_cumsum_1d_reverse", - "test_cumsum_1d", - "test_cumsum_2d_axis_0", - "test_cumsum_2d_axis_1", - "test_cumsum_2d_negative_axis", + // tests "test_cumsum*" are excluded because they use float64. + // "test_cumsum_1d_exclusive", + // "test_cumsum_1d_reverse_exclusive", + // "test_cumsum_1d_reverse", + // "test_cumsum_1d", + // "test_cumsum_2d_axis_0", + // "test_cumsum_2d_axis_1", + // "test_cumsum_2d_negative_axis", // "test_depthtospace_crd_mode_example", // "test_depthtospace_crd_mode", // "test_depthtospace_dcr_mode", // "test_depthtospace_example", // "test_depthtospace", - // // "test_dequantizelinear_axis", - // // "test_dequantizelinear", + "test_dequantizelinear_axis", + "test_dequantizelinear", // // "test_det_2d", // // "test_det_nd", // // "test_dft_axis", @@ -1723,27 +1726,27 @@ // // "test_dft", "test_div_bcast", "test_div_example", - // "test_div_uint8", + "test_div_uint8", "test_div", - // // "test_dropout_default_mask_ratio", - // // "test_dropout_default_mask", - // // "test_dropout_default_old", - // // "test_dropout_default_ratio", - // // "test_dropout_default", - // // "test_dropout_random_old", - // // "test_dropout_random", + "test_dropout_default_mask_ratio", + "test_dropout_default_mask", + "test_dropout_default_old", + "test_dropout_default_ratio", + "test_dropout_default", + "test_dropout_random_old", + "test_dropout_random", // // "test_dynamic_slice_default_axes", // // "test_dynamic_slice_end_out_of_bounds", // // "test_dynamic_slice_neg", // // "test_dynamic_slice_start_out_of_bounds", // // "test_dynamic_slice", - // // "test_dynamicquantizelinear_expanded", - // // "test_dynamicquantizelinear_max_adjusted_expanded", - // // "test_dynamicquantizelinear_max_adjusted", - // // "test_dynamicquantizelinear_min_adjusted_expanded", - // // "test_dynamicquantizelinear_min_adjusted", - // // "test_dynamicquantizelinear", - // "test_edge_pad", + "test_dynamicquantizelinear_expanded", + "test_dynamicquantizelinear_max_adjusted_expanded", + "test_dynamicquantizelinear_max_adjusted", + "test_dynamicquantizelinear_min_adjusted_expanded", + "test_dynamicquantizelinear_min_adjusted", + "test_dynamicquantizelinear", + // "opset{7,8,9,10}/test_edge_pad", // The edge padding model is unsupported by TFLite backend. // "test_einsum_batch_diagonal", // "test_einsum_batch_matmul", // "test_einsum_inner_prod", @@ -1754,7 +1757,7 @@ "test_elu", "test_equal_bcast", "test_equal", - // "test_erf", + "test_erf", "test_exp_example", "test_exp", // "test_expand_dim_changed", @@ -1777,11 +1780,11 @@ "test_gather_1", "test_gather_2d_indices", "test_gather_negative_indices", - "test_gather_elements_0", - "test_gather_elements_1", - "test_gather_elements_negative_indices", + // "test_gather_elements_0", // TFLite backend only supports constant indices. + // "test_gather_elements_1", // TFLite backend only supports constant indices. + // "test_gather_elements_negative_indices", // TFLite backend only supports constant indices. "test_gathernd_example_float32", - "test_gathernd_example_int32_batch_dim1", + // "test_gathernd_example_int32_batch_dim1", "test_gathernd_example_int32", "test_gemm_all_attributes", "test_gemm_alpha", @@ -1789,7 +1792,7 @@ "test_gemm_broadcast", "test_gemm_default_matrix_bias", "test_gemm_default_no_bias", - // "test_gemm_default_scalar_bias", + "test_gemm_default_scalar_bias", "test_gemm_default_single_elem_vector_bias", "test_gemm_default_vector_bias", "test_gemm_default_zero_bias", @@ -1845,48 +1848,49 @@ // "test_if_opt", "test_instancenorm_epsilon", "test_instancenorm_example", - // "test_isinf_negative", - // "test_isinf_positive", - // "test_isinf", - // "test_isnan", + "test_isinf_negative", + "test_isinf_positive", + "test_isinf", + "test_isnan", + // tests "test_layernorm*" are excluded because they produce 3 outputs. // "test_layer_normalization_2d_axis_negative_1_expanded", - "test_layer_normalization_2d_axis_negative_1", + // "test_layer_normalization_2d_axis_negative_1", // "test_layer_normalization_2d_axis_negative_2_expanded", - "test_layer_normalization_2d_axis_negative_2", + // "test_layer_normalization_2d_axis_negative_2", // "test_layer_normalization_2d_axis0_expanded", - "test_layer_normalization_2d_axis0", + // "test_layer_normalization_2d_axis0", // "test_layer_normalization_2d_axis1_expanded", - "test_layer_normalization_2d_axis1", + // "test_layer_normalization_2d_axis1", // "test_layer_normalization_3d_axis_negative_1_epsilon_expanded", - "test_layer_normalization_3d_axis_negative_1_epsilon", + // "test_layer_normalization_3d_axis_negative_1_epsilon", // "test_layer_normalization_3d_axis_negative_2_epsilon_expanded", - "test_layer_normalization_3d_axis_negative_2_epsilon", + // "test_layer_normalization_3d_axis_negative_2_epsilon", // "test_layer_normalization_3d_axis_negative_3_epsilon_expanded", - "test_layer_normalization_3d_axis_negative_3_epsilon", + // "test_layer_normalization_3d_axis_negative_3_epsilon", // "test_layer_normalization_3d_axis0_epsilon_expanded", - "test_layer_normalization_3d_axis0_epsilon", + // "test_layer_normalization_3d_axis0_epsilon", // "test_layer_normalization_3d_axis1_epsilon_expanded", - "test_layer_normalization_3d_axis1_epsilon", + // "test_layer_normalization_3d_axis1_epsilon", // "test_layer_normalization_3d_axis2_epsilon_expanded", - "test_layer_normalization_3d_axis2_epsilon", + // "test_layer_normalization_3d_axis2_epsilon", // "test_layer_normalization_4d_axis_negative_1_expanded", - "test_layer_normalization_4d_axis_negative_1", + // "test_layer_normalization_4d_axis_negative_1", // "test_layer_normalization_4d_axis_negative_2_expanded", - "test_layer_normalization_4d_axis_negative_2", + // "test_layer_normalization_4d_axis_negative_2", // "test_layer_normalization_4d_axis_negative_3_expanded", - "test_layer_normalization_4d_axis_negative_3", + // "test_layer_normalization_4d_axis_negative_3", // "test_layer_normalization_4d_axis_negative_4_expanded", - "test_layer_normalization_4d_axis_negative_4", + // "test_layer_normalization_4d_axis_negative_4", // "test_layer_normalization_4d_axis0_expanded", - "test_layer_normalization_4d_axis0", + // "test_layer_normalization_4d_axis0", // "test_layer_normalization_4d_axis1_expanded", - "test_layer_normalization_4d_axis1", + // "test_layer_normalization_4d_axis1", // "test_layer_normalization_4d_axis2_expanded", - "test_layer_normalization_4d_axis2", + // "test_layer_normalization_4d_axis2", // "test_layer_normalization_4d_axis3_expanded", - "test_layer_normalization_4d_axis3", + // "test_layer_normalization_4d_axis3", // "test_layer_normalization_default_axis_expanded", - "test_layer_normalization_default_axis", + // "test_layer_normalization_default_axis", "test_leakyrelu_default", "test_leakyrelu_example", "test_leakyrelu", @@ -1912,42 +1916,42 @@ // // "test_logsoftmax_large_number", // // "test_logsoftmax_negative_axis_expanded", // // "test_logsoftmax_negative_axis", - // "test_lrn_default", - // "test_lrn", + "test_lrn_default", + "test_lrn", // // "test_lstm_batchwise", "test_lstm_defaults", "test_lstm_with_initial_bias", - "test_lstm_with_peepholes", + // "test_lstm_with_peepholes", "test_matmul_2d", "test_matmul_3d", "test_matmul_4d", - // // "test_matmulinteger", + "test_matmulinteger", "test_max_example", // "test_max_float16", "test_max_float32", - "test_max_float64", + // "test_max_float64", // "test_max_int16", - // "test_max_int32", - // "test_max_int64", - // "test_max_int8", + "test_max_int32", + "test_max_int64", + "test_max_int8", "test_max_one_input", "test_max_two_inputs", // "test_max_uint16", - // "test_max_uint32", + "test_max_uint32", // "test_max_uint64", - // "test_max_uint8", + "test_max_uint8", // "test_maxpool_1d_default", // "test_maxpool_2d_ceil", "test_maxpool_2d_default", - "test_maxpool_2d_dilations", - "test_maxpool_2d_pads", + // "test_maxpool_2d_dilations", // unsupported by TFLite backend. + // "test_maxpool_2d_pads", // unsupported by TFLite backend. "test_maxpool_2d_precomputed_pads", "test_maxpool_2d_precomputed_same_upper", "test_maxpool_2d_precomputed_strides", - "test_maxpool_2d_same_lower", + // "test_maxpool_2d_same_lower", // unsupported by TFLite backend. "test_maxpool_2d_same_upper", "test_maxpool_2d_strides", - // "test_maxpool_2d_uint8", + "test_maxpool_2d_uint8", // "test_maxpool_3d_default", // "test_maxpool_with_argmax_2d_precomputed_pads", // "test_maxpool_with_argmax_2d_precomputed_strides", @@ -1960,17 +1964,17 @@ "test_min_example", // "test_min_float16", "test_min_float32", - "test_min_float64", + // "test_min_float64", // "test_min_int16", - // "test_min_int32", - // "test_min_int64", - // "test_min_int8", + "test_min_int32", + "test_min_int64", + "test_min_int8", "test_min_one_input", "test_min_two_inputs", // "test_min_uint16", - // "test_min_uint32", + "test_min_uint32", // "test_min_uint64", - // "test_min_uint8", + "test_min_uint8", // "test_mod_bcast", // "test_mod_broadcast", // "test_mod_float_mixed_sign_example", @@ -1992,9 +1996,9 @@ // // "test_momentum", "test_mul_bcast", "test_mul_example", - // "test_mul_uint8", + "test_mul_uint8", "test_mul", - // "test_mvn_expanded", + "test_mvn_expanded", // "test_mvn", "test_neg_example", "test_neg", @@ -2110,17 +2114,17 @@ // "test_pow_types_float32_uint64", // "test_pow_types_int", // "test_pow_types_int32_float32", - // "test_pow_types_int32_int32", + "test_pow_types_int32_int32", // "test_pow_types_int64_float32", - // "test_pow_types_int64_int64", + "test_pow_types_int64_int64", "test_pow", "test_prelu_broadcast", "test_prelu_example", // // "test_qlinearconv", // // "test_qlinearmatmul_2D", // // "test_qlinearmatmul_3D", - // // "test_quantizelinear_axis", - // // "test_quantizelinear", + "test_quantizelinear_axis", + "test_quantizelinear", // "test_range_float_type_positive_delta_expanded", // "test_range_float_type_positive_delta", // "test_range_int32_type_negative_delta_expanded", @@ -2129,23 +2133,24 @@ "test_reciprocal", "test_reduce_l1_default_axes_keepdims_example", "test_reduce_l1_default_axes_keepdims_random", - "test_reduce_l1_do_not_keepdims_example", - "test_reduce_l1_do_not_keepdims_random", - "test_reduce_l1_keep_dims_example", - "test_reduce_l1_keep_dims_random", - "test_reduce_l1_negative_axes_keep_dims_example", - "test_reduce_l1_negative_axes_keep_dims_random", + // tests "test_reduce_*" on opset > 13 are excluded because the axes is non-constant input. + "opset{7,8,9,10,11,12,13}/test_reduce_l1_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_keep_dims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_keep_dims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_negative_axes_keep_dims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l1_negative_axes_keep_dims_random", "test_reduce_l2_default_axes_keepdims_example", "test_reduce_l2_default_axes_keepdims_random", - "test_reduce_l2_do_not_keepdims_example", - "test_reduce_l2_do_not_keepdims_random", - "test_reduce_l2_keep_dims_example", - "test_reduce_l2_keep_dims_random", - "test_reduce_l2_negative_axes_keep_dims_example", - "test_reduce_l2_negative_axes_keep_dims_random", - "test_reduce_log_sum_asc_axes", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_keep_dims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_keep_dims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_negative_axes_keep_dims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_l2_negative_axes_keep_dims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_log_sum_asc_axes", "test_reduce_log_sum_default", - "test_reduce_log_sum_desc_axes", + "opset{7,8,9,10,11,12,13}/test_reduce_log_sum_desc_axes", // tests "test_reduce_log_sum_exp_*" on opset17/opset18 are excluded because they use float64. "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_example", "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_random", @@ -2155,116 +2160,118 @@ "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_random", "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_example", "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_random", - "test_reduce_log_sum_negative_axes", + "opset{7,8,9,10,11,12,13}/test_reduce_log_sum_negative_axes", "test_reduce_log_sum", "test_reduce_max_default_axes_keepdim_example", "test_reduce_max_default_axes_keepdims_random", - "test_reduce_max_do_not_keepdims_example", - "test_reduce_max_do_not_keepdims_random", - "test_reduce_max_keepdims_example", - "test_reduce_max_keepdims_random", - "test_reduce_max_negative_axes_keepdims_example", - "test_reduce_max_negative_axes_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_max_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_max_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_max_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_max_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_max_negative_axes_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_max_negative_axes_keepdims_random", "test_reduce_mean_default_axes_keepdims_example", "test_reduce_mean_default_axes_keepdims_random", - "test_reduce_mean_do_not_keepdims_example", - "test_reduce_mean_do_not_keepdims_random", - "test_reduce_mean_keepdims_example", - "test_reduce_mean_keepdims_random", - "test_reduce_mean_negative_axes_keepdims_example", - "test_reduce_mean_negative_axes_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_negative_axes_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_mean_negative_axes_keepdims_random", "test_reduce_min_default_axes_keepdims_example", "test_reduce_min_default_axes_keepdims_random", - "test_reduce_min_do_not_keepdims_example", - "test_reduce_min_do_not_keepdims_random", - "test_reduce_min_keepdims_example", - "test_reduce_min_keepdims_random", - "test_reduce_min_negative_axes_keepdims_example", - "test_reduce_min_negative_axes_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_min_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_min_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_min_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_min_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_min_negative_axes_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_min_negative_axes_keepdims_random", "test_reduce_prod_default_axes_keepdims_example", "test_reduce_prod_default_axes_keepdims_random", - "test_reduce_prod_do_not_keepdims_example", - "test_reduce_prod_do_not_keepdims_random", - "test_reduce_prod_keepdims_example", - "test_reduce_prod_keepdims_random", - "test_reduce_prod_negative_axes_keepdims_example", - "test_reduce_prod_negative_axes_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_negative_axes_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_prod_negative_axes_keepdims_random", "test_reduce_sum_default_axes_keepdims_example", "test_reduce_sum_default_axes_keepdims_random", - "test_reduce_sum_do_not_keepdims_example", - "test_reduce_sum_do_not_keepdims_random", + "opset{7,8,9,10,11}/test_reduce_sum_do_not_keepdims_example", + "opset{7,8,9,10,11}/test_reduce_sum_do_not_keepdims_random", "test_reduce_sum_empty_axes_input_noop_example", "test_reduce_sum_empty_axes_input_noop_random", - "test_reduce_sum_keepdims_example", - "test_reduce_sum_keepdims_random", - "test_reduce_sum_negative_axes_keepdims_example", - "test_reduce_sum_negative_axes_keepdims_random", + "opset{7,8,9,10,11}/test_reduce_sum_keepdims_example", + "opset{7,8,9,10,11}/test_reduce_sum_keepdims_random", + "opset{7,8,9,10,11}/test_reduce_sum_negative_axes_keepdims_example", + "opset{7,8,9,10,11}/test_reduce_sum_negative_axes_keepdims_random", "test_reduce_sum_square_default_axes_keepdims_example", "test_reduce_sum_square_default_axes_keepdims_random", - "test_reduce_sum_square_do_not_keepdims_example", - "test_reduce_sum_square_do_not_keepdims_random", - "test_reduce_sum_square_keepdims_example", - "test_reduce_sum_square_keepdims_random", - "test_reduce_sum_square_negative_axes_keepdims_example", - "test_reduce_sum_square_negative_axes_keepdims_random", - // "test_reflect_pad", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_keepdims_random", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_negative_axes_keepdims_example", + "opset{7,8,9,10,11,12,13}/test_reduce_sum_square_negative_axes_keepdims_random", + "opset{7,8,9,10}/test_reflect_pad", "test_relu", - "test_reshape_allowzero_reordered", - "test_reshape_extended_dims", - "test_reshape_negative_dim", - "test_reshape_negative_extended_dims", - "test_reshape_one_dim", - "test_reshape_reduced_dims", - "test_reshape_reordered_all_dims", - "test_reshape_reordered_dims", - "test_reshape_reordered_last_dims", - "test_reshape_zero_and_negative_dim", - "test_reshape_zero_dim", - "test_resize_downsample_linear", - "test_resize_downsample_nearest", - "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", + // tests "test_reshape*" are excluded because the shape is non-constant input. + // "test_reshape_allowzero_reordered", + // "test_reshape_extended_dims", + // "test_reshape_negative_dim", + // "test_reshape_negative_extended_dims", + // "test_reshape_one_dim", + // "test_reshape_reduced_dims", + // "test_reshape_reordered_all_dims", + // "test_reshape_reordered_dims", + // "test_reshape_reordered_last_dims", + // "test_reshape_zero_and_negative_dim", + // "test_reshape_zero_dim", + // tests "test_resize*" are excluded because scales and sizes are non-constant inputs. + // "test_resize_downsample_linear", + // "test_resize_downsample_nearest", + // "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", // "test_resize_downsample_scales_cubic_align_corners", - "test_resize_downsample_scales_cubic", + // "test_resize_downsample_scales_cubic", // "test_resize_downsample_scales_linear_align_corners", - "test_resize_downsample_scales_linear", - "test_resize_downsample_scales_nearest", - "test_resize_downsample_sizes_cubic", - "test_resize_downsample_sizes_linear_pytorch_half_pixel", - "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", - "test_resize_downsample_sizes_nearest", - "test_resize_nearest", - "test_resize_tf_crop_and_resize", - "test_resize_upsample_linear", - "test_resize_upsample_nearest", - "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", - "test_resize_upsample_scales_cubic_align_corners", - "test_resize_upsample_scales_cubic_asymmetric", - "test_resize_upsample_scales_cubic", - "test_resize_upsample_scales_linear_align_corners", - "test_resize_upsample_scales_linear", - "test_resize_upsample_scales_nearest", - "test_resize_upsample_sizes_cubic", - "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel", - "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners", - "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", - "test_resize_upsample_sizes_nearest", + // "test_resize_downsample_scales_linear", + // "test_resize_downsample_scales_nearest", + // "test_resize_downsample_sizes_cubic", + // "test_resize_downsample_sizes_linear_pytorch_half_pixel", + // "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", + // "test_resize_downsample_sizes_nearest", + // "test_resize_nearest", + // "test_resize_tf_crop_and_resize", + // "test_resize_upsample_linear", + // "test_resize_upsample_nearest", + // "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", + // "test_resize_upsample_scales_cubic_align_corners", + // "test_resize_upsample_scales_cubic_asymmetric", + // "test_resize_upsample_scales_cubic", + // "test_resize_upsample_scales_linear_align_corners", + // "test_resize_upsample_scales_linear", + // "test_resize_upsample_scales_nearest", + // "test_resize_upsample_sizes_cubic", + // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel", + // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners", + // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", + // "test_resize_upsample_sizes_nearest", // // "test_reversesequence_batch", // // "test_reversesequence_time", // // "test_rnn_seq_length", // // "test_roialign_aligned_false", // // "test_roialign_aligned_true", // // "test_roialign", - // // "test_round", + "test_round", // // "test_scan_sum", // // "test_scan9_sum", - "test_scatter_elements_with_axis", - "test_scatter_elements_with_duplicate_indices", - "test_scatter_elements_with_negative_indices", - "test_scatter_elements_without_axis", + // "test_scatter_elements_with_axis", // TFLite backend does not support non-constant indices. + // "test_scatter_elements_with_duplicate_indices", // WebNN only supports reduction type 'none'. + // "test_scatter_elements_with_negative_indices", // TFLite backend does not support non-constant indices. + // "test_scatter_elements_without_axis", // TFLite backend does not support non-constant indices. // // "test_scatter_with_axis", // // "test_scatter_without_axis", - "test_scatternd_add", - "test_scatternd_multiply", + // "test_scatternd_add", // WebNN only supports reduction type 'none'. + // "test_scatternd_multiply", // WebNN only supports reduction type 'none'. "test_scatternd", // // "test_sce_mean_3d_expanded", // // "test_sce_mean_3d_log_prob_expanded", @@ -2365,14 +2372,15 @@ // "test_sinh", // // "test_size_example", // // "test_size", - "test_slice_default_axes", - "test_slice_default_steps", - "test_slice_end_out_of_bounds", - "test_slice_neg_steps", - "test_slice_neg", - "test_slice_negative_axes", - "test_slice_start_out_of_bounds", - "test_slice", + // tests "test_slice_*" on opset > 9 are excluded because starts, ends, axes and steps are non-constant inputs. + "opset{7,8,9}/test_slice_default_axes", + // "test_slice_default_steps", + "opset{7,8,9}/test_slice_end_out_of_bounds", + // "test_slice_neg_steps", + "opset{7,8,9}/test_slice_neg", + // "test_slice_negative_axes", + "opset{7,8,9}/test_slice_start_out_of_bounds", + "opset{7,8,9}/test_slice", "test_softmax_axis_0_expanded", "test_softmax_axis_0", "test_softmax_axis_1_expanded", @@ -2455,23 +2463,24 @@ "test_softmax_large_number", "test_softmax_negative_axis_expanded", "test_softmax_negative_axis", - // // "test_softplus_example", - // // "test_softplus", - // // "test_softsign_example", - // // "test_softsign", + "test_softplus_example", + "test_softplus", + "test_softsign_example", + "test_softsign", // "test_spacetodepth_example", // "test_spacetodepth", - "test_split_equal_parts_1d", - "test_split_equal_parts_2d", + // tests "test_split_*" on opset > 10 are excluded because the split input is non-constant input. + "opset{7,8,9,10}/test_split_equal_parts_1d", + "opset{7,8,9,10}/test_split_equal_parts_2d", "test_split_equal_parts_default_axis", - "test_split_variable_parts_1d", - "test_split_variable_parts_2d", - "test_split_variable_parts_default_axis", - "test_split_zero_size_splits", + "opset{7,8,9,10}/test_split_variable_parts_1d", + "opset{7,8,9,10}/test_split_variable_parts_2d", + "opset{7,8,9,10}/test_split_variable_parts_default_axis", + // "test_split_zero_size_splits", "test_sqrt_example", "test_sqrt", - "test_squeeze_negative_axes", - "test_squeeze", + // "test_squeeze_negative_axes", + "opset{7,8,9,10}/test_squeeze", // // "test_stft_with_window", // // "test_stft", // // "test_strnormalizer_export_monday_casesensintive_lower", @@ -2482,7 +2491,7 @@ // // "test_strnormalizer_nostopwords_nochangecase", "test_sub_bcast", "test_sub_example", - // "test_sub_uint8", + "test_sub_uint8", "test_sub", // "test_sum_example", // "test_sum_one_input", @@ -2501,8 +2510,9 @@ // "test_thresholdedrelu_default", // "test_thresholdedrelu_example", // "test_thresholdedrelu", - "test_tile_precomputed", - "test_tile", + // tests "test_tile*" are excluded because the repeats is non-constant input. + // "test_tile_precomputed", + // "test_tile", // // "test_top_k_negative_axis", // // "test_top_k_smallest", // // "test_top_k", @@ -2520,41 +2530,41 @@ "test_transpose_all_permutations_5", "test_transpose_default", // "test_tril_neg", - // "test_tril_one_row_neg", + "test_tril_one_row_neg", // "test_tril_out_neg", // "test_tril_out_pos", // "test_tril_pos", // "test_tril_square_neg", - // "test_tril_square", + "test_tril_square", // "test_tril_zero", - // "test_tril", + "test_tril", // "test_triu_neg", // "test_triu_one_row", // "test_triu_out_neg_out", // "test_triu_out_pos", // "test_triu_pos", // "test_triu_square_neg", - // "test_triu_square", + "test_triu_square", // "test_triu_zero", - // "test_triu", + "test_triu", // // "test_unique_not_sorted_without_axis", // // "test_unique_sorted_with_axis_3d", // // "test_unique_sorted_with_axis", // // "test_unique_sorted_with_negative_axis", // // "test_unique_sorted_without_axis", - "test_unsqueeze_axis_0", - "test_unsqueeze_axis_1", - "test_unsqueeze_axis_2", - "test_unsqueeze_axis_3", - "test_unsqueeze_negative_axes", - "test_unsqueeze_three_axes", - "test_unsqueeze_two_axes", - "test_unsqueeze_unsorted_axes", - "test_unsqueeze", + // "test_unsqueeze_axis_0", + // "test_unsqueeze_axis_1", + // "test_unsqueeze_axis_2", + // "test_unsqueeze_axis_3", + // "test_unsqueeze_negative_axes", + // "test_unsqueeze_three_axes", + // "test_unsqueeze_two_axes", + // "test_unsqueeze_unsorted_axes", + "opset{7,8,9,10}/test_unsqueeze", // "test_wrap_pad" // "test_upsample_nearest", "test_where_example", - // "test_where_long_example", + "test_where_long_example", "test_xor_bcast3v1d", "test_xor_bcast3v2d", "test_xor_bcast4v2d", From ff0715d3f6b8cec9d274085314e8c808307b975b Mon Sep 17 00:00:00 2001 From: adrastogi Date: Thu, 20 Nov 2025 19:38:33 -0800 Subject: [PATCH 08/17] Adding candidate metadata key for tracking EP's OS driver version (#26616) ### Description This change adds a well-known key name (`os_driver_version`) corresponding to the OS driver version associated with an EP. We will eventually flesh this out to enable retrieving it from the `OrtEpDevice` if it's been populated, but for starters we reserve the name. ### Motivation and Context We have a scenario in WebNN where the browser would like to get the driver version associated with a given EP (this is to enable policy against the driver, e.g. for maintaining a blocklist if a particular driver version has a scenario-blocking bug in it). Having a mechanism to retrieve the driver version via ORT would help with implementing this feature. --------- Co-authored-by: Aditya Rastogi --- .../core/session/onnxruntime_ep_device_ep_metadata_keys.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index f6afd84dabc5e..5ea4261840299 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -9,6 +9,9 @@ // Key for the execution provider version string. This should be available for all plugin EPs. static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; +// Key for the execution provider OS driver version. +static const char* const kOrtEpDevice_EpMetadataKey_OSDriverVersion = "os_driver_version"; + // Prefix for execution provider compatibility information stored in model metadata. // Used when generating EP context models to store compatibility strings for each EP. // Full key format: "ep_compatibility_info." From bdf8dc2c05fa46f31d4d9860b35b8c869ee52c17 Mon Sep 17 00:00:00 2001 From: Peishen Yan Date: Fri, 21 Nov 2025 13:20:58 +0800 Subject: [PATCH 09/17] [WebNN EP] Support local attention feature for GQA (#26565) ### Description Support the `local_window_size` attribute in **GroupQueryAttention** Operator, which is designed for sliding window attention and may influence the attention mask pattern. For local window size not equal to -1, new attention mask pattern will be created as follows for applying sliding window. ``` condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true) | | | Lesser <--- local_window_size | | LogicalAnd <----------------- condition_2 | new attn_mask ``` ### Motivation and Context --- .../webnn/builders/impl/gqa_op_builder.cc | 60 +++++++++++++++---- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc index 0b927075402fe..a29fbdb91e79f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc @@ -107,6 +107,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b ORT_RETURN_IF_NOT(GetShape(*input_defs[3], input_past_k_shape, logger), "Cannot get past_key shape"); NodeAttrHelper helper(node); + const int32_t local_window_size = helper.Get("local_window_size", -1); const uint32_t kv_num_heads = helper.Get("kv_num_heads", 0); const uint32_t num_heads = helper.Get("num_heads", 0); @@ -290,18 +291,17 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b | | +-------------------------------> Lesser <---------------------Transpose (1,0) | - 1 ---> Where <--- finfo_min (minimum value of FP32) + 1 ---> Where (attn_mask) <--- finfo_min (minimum value of FP32) | attention_bias */ - const std::vector mask_shape_ones_shape(batch_size * num_heads * qkv_sequence_length * past_sequence_length, - 1); - std::string mask_shape_ones_shape_name = "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(batch_size) + - "_" + std::to_string(num_heads) + "_" + std::to_string(qkv_sequence_length) + - "_" + std::to_string(past_sequence_length); - emscripten::val mask_shape_ones_shape_constant = model_builder.CreateOrGetConstant( - ONNX_NAMESPACE::TensorProto_DataType_INT32, mask_shape_ones_shape_name, mask_shape_ones_shape, - std::vector({batch_size, num_heads, qkv_sequence_length, past_sequence_length})); + emscripten::val value_int_one_constant = + model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_INT32, 1, {1}); + + std::vector mask_shape_ones_shape = {batch_size, num_heads, qkv_sequence_length, past_sequence_length}; + common_options.set("label", node.Name() + "_/GQA/GQA_mask_shape_ones/expand"); + emscripten::val mask_shape_ones_shape_constant = model_builder.GetBuilder().call( + "expand", value_int_one_constant, emscripten::val::array(mask_shape_ones_shape), common_options); emscripten::val cumsum_options = emscripten::val::object(); cumsum_options.set("label", node.Name() + "_range_of_mask_shape"); @@ -315,7 +315,7 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b std::iota(pre_neq_right_data_range.begin(), pre_neq_right_data_range.end(), 1); std::string pre_neq_right_data_range_name = - "webnn_GQA_left_constant_of_scatter_indices_" + std::to_string(qkv_sequence_length); + "webnn_GQA_pre_neq_right_data_range_" + std::to_string(qkv_sequence_length); emscripten::val pre_neq_right_data_range_constant = model_builder.CreateOrGetConstant( ONNX_NAMESPACE::TensorProto_DataType_INT32, pre_neq_right_data_range_name, pre_neq_right_data_range, std::vector({qkv_sequence_length})); @@ -333,10 +333,46 @@ Status GroupQueryAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_b emscripten::val neq_right = model_builder.GetBuilder().call("transpose", expanded_neq_right, transpose_options); - common_options.set("label", node.Name() + "_/GQA/attn_mask/condition"); - emscripten::val condition = + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_1"); + emscripten::val condition_1 = model_builder.GetBuilder().call("lesser", neq_left, neq_right, common_options); + emscripten::val condition = condition_1; + // For local window size not equal to -1, new attention mask pattern for applying sliding window + /* + condition_1 (old attn_mask) ---> CumSum (axis=3, exclusive=true, reversed=true) + | | + | Lesser <--- local_window_size + | | + LogicalAnd <----------------- condition_2 + | + new attn_mask + */ + if (local_window_size != -1) { + // Cast condition + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cast"); + emscripten::val casted_condition_1 = + model_builder.GetBuilder().call("cast", condition_1, emscripten::val("int32"), common_options); + + cumsum_options = emscripten::val::object(); + cumsum_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2/cumsum"); + cumsum_options.set("exclusive", true); + cumsum_options.set("reversed", true); + emscripten::val neq_left_2 = model_builder.GetBuilder().call( + "cumulativeSum", casted_condition_1, gsl::narrow(3), cumsum_options); + + emscripten::val local_window_size_constant = + model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_INT32, local_window_size, {1}); + + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition_2"); + emscripten::val condition_2 = + model_builder.GetBuilder().call("lesser", neq_left_2, local_window_size_constant, common_options); + + common_options.set("label", node.Name() + "_/GQA/attn_mask/condition/and"); + condition = model_builder.GetBuilder().call( + "logicalAnd", condition_1, condition_2, common_options); + } + emscripten::val value_one_constant = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1, {1}); From 96926a073329f45b532ecf7f6595c767d1b071a7 Mon Sep 17 00:00:00 2001 From: Xiaofei Han Date: Sat, 22 Nov 2025 00:05:39 +0800 Subject: [PATCH 10/17] [webgpu] Fused CopyKVCache and SplitPackedQKVWithRotaryEmbedding as SplitPackedQKVWithRotaryEmbeddingAndCopyKV (#26563) ### Description Create a ultimated fused path called SplitPackedQKVWithRotaryEmbeddingAndCopyKV which fused SplitPackedQKVWithRotaryEmbedding and CopyKVCache. When use flash attention and static kv cache is enabled, run it. We did the following things: - Support components for existed SplitPackedQKVWithRotaryEmbedding - Fused it and copykvcache as new SplitPackedQKVWithRotaryEmbeddingAndCopyKV ### Motivation and Context On NV5080, the token generation speed improve ~4%. | generation tps | Before | After | |--------|--------|-------| | NV5080 | 135 | **141** | | Intel | 15.3 | 15.4 | | Mac | 71.2 | 71.8 | --- .../webgpu/bert/flash_attention.cc | 163 +++++++++++++++--- .../contrib_ops/webgpu/bert/flash_attention.h | 42 ++++- .../webgpu/bert/group_query_attention.cc | 80 ++++++--- ...ed_qkv_with_rotary_embedding.wgsl.template | 6 +- ..._rotary_embedding_and_copykv.wgsl.template | 111 ++++++++++++ 5 files changed, 349 insertions(+), 53 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 00f60142df159..606dbfde15c2c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -16,6 +16,32 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform); + const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); + const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform); + const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform); + + const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform); + const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform); + const auto& present_value = sh.AddOutput("present_value", ShaderUsage::UseUniform); + + if (prepare_indirect_dispatch_) { + sh.AddOutput("indirect_buffer", ShaderUsage::None); + } + + return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template", + WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_), + WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_), + WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache), + WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv), + WGSL_TEMPLATE_VARIABLE(present_key, present_key), + WGSL_TEMPLATE_VARIABLE(present_value, present_value), + WGSL_TEMPLATE_VARIABLE(query, query), + WGSL_TEMPLATE_VARIABLE(seqlens, seqlens), + WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache)); +} + Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { // Expectations are // qkv have same number of heads and hidden dimension (head size). @@ -351,17 +377,54 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, - const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, + const Tensor* cos_cache, const Tensor* sin_cache) { + constexpr uint32_t tile_size = 64; + // Extract present_sequence_length directly from present_key tensor shape: // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size) const uint32_t present_sequence_length = static_cast(present_key->Shape()[2]); const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled(); + // Declare query_output at function scope to ensure it persists throughout the function + Tensor query_output; + + // Create indirect dispatch buffer if using indirect dispatch + Tensor* indirect_buffer_ptr = nullptr; + Tensor indirect_buffer; + + // Prepare indirect dispatch buffer for decode path with static KV cache + const bool use_indirect_dispatch = parameters.sequence_length_ == 1 && + parameters.past_present_share_buffer_ && + seqlen_k != nullptr && + context.IsGraphCaptureEnabled(); + if (use_indirect_dispatch) { + const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions + indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType(), indirect_buffer_shape); + indirect_buffer_ptr = &indirect_buffer; + } + + const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr); + + if (do_rotary) { + ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input."); + ORT_ENFORCE(parameters.past_present_share_buffer_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache."); + + // Q points to the packed QKV tensor in this case, create query output tensor + query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); + + ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters, + Q, seqlen_k, + cos_cache, sin_cache, + &query_output, present_key, present_value, + indirect_buffer_ptr, tile_size)); + Q = &query_output; + } else { + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr)); + } + if (parameters.sequence_length_ > 1) { - const uint32_t tile_size = 64; - // For encode path, use the original CopyKVCache without indirect dispatch preparation - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr)); bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; @@ -406,29 +469,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co parameters.sequence_length_, present_sequence_length}); const TensorShape qk_shape(qk_dims); Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape); - constexpr uint32_t tile_size = 64; const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; - // Determine if we should use indirect dispatch - const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && - seqlen_k != nullptr && - context.IsGraphCaptureEnabled(); - - // Create indirect dispatch buffer if using indirect dispatch - Tensor* indirect_buffer_ptr = nullptr; - Tensor indirect_buffer; - if (use_indirect_dispatch) { - const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions - indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType(), indirect_buffer_shape); - indirect_buffer_ptr = &indirect_buffer; - // Use the fused CopyKVCache that also prepares the indirect dispatch buffer - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr)); - } else { - // Use the original CopyKVCache without indirect dispatch preparation - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr)); - } - // The metadata is used to store the max and sum of each tile. const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, num_present_sequence_length_tile, 2}); @@ -467,6 +510,78 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } +Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context, + const WebgpuAttentionParameters& params, + const Tensor* packedQKV, + const Tensor* seqlen_k, + const Tensor* cos_cache, + const Tensor* sin_cache, + Tensor* query, + Tensor* present_key, + Tensor* present_value, + Tensor* indirect_buffer, + uint32_t tile_size) { + const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); + const auto head_size = params.head_size_; + + int components = 1; + // Currently we only support vectorization when RoPE is not interleaved + if (!params.rotary_interleaved_) { + if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) { + components = 4; + } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) { + components = 2; + } + } + // Adjust dimensions for vectorization + const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components; + const auto head_size_vec = head_size / components; + + // Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim) + // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) + // = head_size - half_rotary_dim + const auto work_per_head = head_size_vec - half_rotary_embedding_dim_vec; + auto dispatch_size = static_cast(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head); + + // Extract present_sequence_length from present_key tensor shape + const uint32_t present_sequence_length = gsl::narrow_cast(present_key->Shape()[2]); + + const bool prepare_indirect_dispatch = (indirect_buffer != nullptr); + + SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch); + program + .CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components}) + .AddInputs({ + {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, + {cos_cache, ProgramTensorMetadataDependency::Rank, components}, + {sin_cache, ProgramTensorMetadataDependency::Rank, components}, + }); + program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, + {present_key, ProgramTensorMetadataDependency::None, components}, + {present_value, ProgramTensorMetadataDependency::None, components}}); + + if (prepare_indirect_dispatch) { + program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None}); + } + + program.AddUniformVariables({ + {static_cast(params.sequence_length_)}, + {static_cast(params.hidden_size_ / components)}, + {static_cast(params.kv_hidden_size_ / components)}, + {static_cast(params.num_heads_)}, + {static_cast(params.kv_num_heads_)}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, + {present_sequence_length}, + {tile_size}, + {static_cast(dispatch_size)}, + }); + + program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 9599c10533351..a936a91695921 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -15,6 +15,32 @@ namespace webgpu { using namespace onnxruntime::webgpu; +class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program { + public: + SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch) + : Program{"SplitPackedQKVWithRotaryEmbeddingAndCopyKV"}, + interleaved_(interleaved), + prepare_indirect_dispatch_(prepare_indirect_dispatch) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"kv_num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"half_rotary_dim", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"tile_size", ProgramUniformVariableDataType::Uint32}, + {"dispatch_size", ProgramUniformVariableDataType::Uint32}); + + private: + const bool interleaved_; + const bool prepare_indirect_dispatch_; +}; + class CopyKVCacheProgram final : public Program { public: CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, @@ -145,10 +171,24 @@ class FlashAttentionDecodeVxReduceProgram final : public ProgramShape().Size(); program - .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank}) - .AddOutputs({{query, ProgramTensorMetadataDependency::Rank}, {key, ProgramTensorMetadataDependency::Rank}, {val, ProgramTensorMetadataDependency::Rank}}) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank}) + .AddOutputs({{query, ProgramTensorMetadataDependency::None}, {key, ProgramTensorMetadataDependency::None}, {val, ProgramTensorMetadataDependency::None}}) .AddUniformVariables({ {static_cast(params.hidden_size_)}, {static_cast(params.kv_hidden_size_)}, @@ -90,32 +90,46 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; + int components = 1; + // Currently we only support vectorization when RoPE is not interleaved + if (!params.rotary_interleaved_) { + if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) { + components = 4; + } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) { + components = 2; + } + } + + // Adjust dimensions for vectorization + const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components; + const auto head_size_vec = head_size / components; + // Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim) // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) // = head_size - half_rotary_dim - const auto work_per_head = head_size - half_rotary_embedding_dim; - auto dispatch_size = static_cast(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head); + const auto work_per_head_vec = head_size_vec - half_rotary_embedding_dim_vec; + auto dispatch_size = static_cast(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head_vec); SplitPackedQKVWithRotaryEmbeddingProgram program(params.rotary_interleaved_); program .CacheHint(params.rotary_interleaved_) - .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank}) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components}) .AddInputs({ - {seqlen_k, ProgramTensorMetadataDependency::Rank}, - {cos_cache, ProgramTensorMetadataDependency::Rank}, - {sin_cache, ProgramTensorMetadataDependency::Rank}, + {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, + {cos_cache, ProgramTensorMetadataDependency::Rank, components}, + {sin_cache, ProgramTensorMetadataDependency::Rank, components}, }) - .AddOutputs({{query, ProgramTensorMetadataDependency::Rank}, - {key, ProgramTensorMetadataDependency::Rank}, - {val, ProgramTensorMetadataDependency::Rank}}) + .AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, + {key, ProgramTensorMetadataDependency::None, components}, + {val, ProgramTensorMetadataDependency::None, components}}) .AddUniformVariables({ {static_cast(params.sequence_length_)}, - {static_cast(params.hidden_size_)}, - {static_cast(params.kv_hidden_size_)}, + {static_cast(params.hidden_size_ / components)}, + {static_cast(params.kv_hidden_size_ / components)}, {static_cast(params.num_heads_)}, {static_cast(params.kv_num_heads_)}, - {static_cast(head_size)}, - {half_rotary_embedding_dim}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, {static_cast(dispatch_size)}, }) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); @@ -177,15 +191,15 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, program .CacheHint(params.rotary_interleaved_) .AddInputs({ - {query_in, ProgramTensorMetadataDependency::Rank}, + {query_in, ProgramTensorMetadataDependency::TypeAndRank}, {key_in, ProgramTensorMetadataDependency::Rank}, - {seqlen_k, ProgramTensorMetadataDependency::Rank}, + {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, {cos_cache, ProgramTensorMetadataDependency::Rank}, {sin_cache, ProgramTensorMetadataDependency::Rank}, }) .AddOutputs({ - {query_out, ProgramTensorMetadataDependency::Rank}, - {key_out, ProgramTensorMetadataDependency::Rank}, + {query_out, ProgramTensorMetadataDependency::None}, + {key_out, ProgramTensorMetadataDependency::None}, }) .SetDispatchGroupSize((q_domain_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({ @@ -265,7 +279,26 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor qRotary; Tensor kRotary; + + // Use a sliding window if the total sequence exceeds the window's length. + bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_); + bool will_use_flash_attention = false; + if (head_sink == nullptr && !use_smooth_softmax_ && !use_sliding_window) { + // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking + WebgpuAttentionParameters temp_params = parameters; + temp_params.is_packed_qkv_ = false; + will_use_flash_attention = CanApplyFlashAttention(attention_bias, present_key, present_value, temp_params, context); + } + if (parameters.is_packed_qkv_ && do_rotary_) { + // Use the ultimate fused operation when FlashAttention and static KV cache is enabled. + if (will_use_flash_attention && parameters.past_present_share_buffer_) { + // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled + // query points to packed QKV, K and V are nullptr since they're not needed + return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value, + present_value, parameters, context, seqlen_k, cos_cache, sin_cache); + } + // Fused: splitQKV + rotary QK qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); @@ -279,8 +312,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& key = &kSplit; value = &vSplit; } else { - // Original separate path if (parameters.is_packed_qkv_) { + // splitQKV qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); @@ -292,6 +325,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& value = &vSplit; } if (do_rotary_) { + // rotary QK qRotary = context.CreateGPUTensor(query->DataType(), query->Shape()); kRotary = context.CreateGPUTensor(key->DataType(), key->Shape()); ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, parameters, @@ -304,11 +338,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& } } - // Use a sliding window if the total sequence exceeds the window's length. - bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_); - if (head_sink == nullptr && !use_smooth_softmax_ && - !use_sliding_window && - CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { + if (will_use_flash_attention) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context, seqlen_k); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template index b64448611079f..777be41ffb456 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template @@ -36,11 +36,11 @@ $MAIN { // Calculate actual indices in the head for i and j #if interleaved - let idx_i = in_head_idx; - let idx_j = in_head_idx + 1u; + let idx_i = in_head_idx + in_head_idx; + let idx_j = idx_i + 1u; #else let idx_i = in_head_idx; - let idx_j = in_head_idx + uniforms.half_rotary_dim; + let idx_j = idx_i + uniforms.half_rotary_dim; #endif // Process Q pair diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template new file mode 100644 index 0000000000000..d6cb654afa756 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -0,0 +1,111 @@ +#param interleaved +#param prepare_indirect_dispatch + +#use guardAgainstOutOfBoundsWorkgroupSizes +#use .setByIndices .getByIndices .getByOffset + +$MAIN { + guardAgainstOutOfBoundsWorkgroupSizes(uniforms.dispatch_size); + + // Dispatch: batch * seq * num_heads * (half_rotary_dim + need_copy_dim) + // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) + let work_per_head = uniforms.head_size - uniforms.half_rotary_dim; + let total_work = uniforms.num_heads * work_per_head; + + let batch_idx = global_idx / (uniforms.sequence_length * total_work); + let remainder1 = global_idx % (uniforms.sequence_length * total_work); + let seq_idx = remainder1 / total_work; + let remainder2 = remainder1 % total_work; + let head_idx = remainder2 / work_per_head; + let in_head_idx = remainder2 % work_per_head; + + // Calculate base offset in packed_qkv for this token + // Layout per token: [Q(hidden_size), K(kv_hidden_size), V(kv_hidden_size)] + let token_size = uniforms.hidden_size + 2u * uniforms.kv_hidden_size; + let base_offset = batch_idx * uniforms.sequence_length * token_size + seq_idx * token_size; + + // Calculate position_id (needed for rotary embedding) + let seqlen_i = seqlens.getByOffset(batch_idx); + let seqlen = u32(seqlen_i); + let total_seqlen = seqlen + 1u; + + let past_seqlen = total_seqlen - uniforms.sequence_length; + // `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value + let position_id = past_seqlen + seq_idx; + +#if prepare_indirect_dispatch + // Prepare indirect dispatch buffer for thread 0 + if (global_idx == 0u) { + let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; + indirect_buffer[0] = num_total_seq_length_tile; + indirect_buffer[1] = uniforms.num_heads; + indirect_buffer[2] = 1u; + } +#endif + + if (in_head_idx < uniforms.half_rotary_dim) { + // Process a rotary pair (i, j) + let cos_v = cos_cache.getByIndices(vec2(position_id, in_head_idx)); + let sin_v = sin_cache.getByIndices(vec2(position_id, in_head_idx)); + + // Calculate actual indices in the head for i and j +#if interleaved + let idx_i = in_head_idx + in_head_idx; + let idx_j = idx_i + 1u; +#else + let idx_i = in_head_idx; + let idx_j = idx_i + uniforms.half_rotary_dim; +#endif + + // Process Q pair + let q_base = base_offset + head_idx * uniforms.head_size; + let q_i_offset = q_base + idx_i; + let q_j_offset = q_base + idx_j; + let q_i = packed_qkv.getByOffset(q_i_offset); + let q_j = packed_qkv.getByOffset(q_j_offset); + let q_re = q_i * cos_v - q_j * sin_v; + let q_im = q_i * sin_v + q_j * cos_v; + query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_i), q_re); + query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_j), q_im); + + // Process K and V pairs if within kv_num_heads + if (head_idx < uniforms.kv_num_heads) { + let k_base = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size; + let k_i_offset = k_base + idx_i; + let k_j_offset = k_base + idx_j; + let k_i = packed_qkv.getByOffset(k_i_offset); + let k_j = packed_qkv.getByOffset(k_j_offset); + let k_re = k_i * cos_v - k_j * sin_v; + let k_im = k_i * sin_v + k_j * cos_v; + // Write K directly to present_key cache + present_key.setByIndices(vec4(batch_idx, head_idx, position_id, idx_i), k_re); + present_key.setByIndices(vec4(batch_idx, head_idx, position_id, idx_j), k_im); + + // V doesn't need rotary, just copy the pair to present_value cache + let v_base = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size; + let v_i = packed_qkv.getByOffset(v_base + idx_i); + let v_j = packed_qkv.getByOffset(v_base + idx_j); + present_value.setByIndices(vec4(batch_idx, head_idx, position_id, idx_i), v_i); + present_value.setByIndices(vec4(batch_idx, head_idx, position_id, idx_j), v_j); + } + } else { + // Process non-rotary elements (direct copy) + let actual_idx = uniforms.half_rotary_dim + in_head_idx; + + // Copy Q + let q_offset = base_offset + head_idx * uniforms.head_size + actual_idx; + let q_data = packed_qkv.getByOffset(q_offset); + query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + actual_idx), q_data); + + // Copy K and V if within kv_num_heads directly to present cache + if (head_idx < uniforms.kv_num_heads) { + let k_offset = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size + actual_idx; + let k_data = packed_qkv.getByOffset(k_offset); + present_key.setByIndices(vec4(batch_idx, head_idx, position_id, actual_idx), k_data); + + let v_offset = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size + actual_idx; + let v_data = packed_qkv.getByOffset(v_offset); + present_value.setByIndices(vec4(batch_idx, head_idx, position_id, actual_idx), v_data); + } + } +} // MAIN From 4665804592d1ae341094b93c0717b1c1c6656559 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 21 Nov 2025 09:08:15 -0800 Subject: [PATCH 11/17] Udpate MS Wil dependency for FETCH_CONTENT to the latest (#26623) ### Description Since this is a MS component, I thinkg vcpkg is already updated ### Motivation and Context The older URL is now failing. --- cmake/deps.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index e1870bf2df0cf..078a66a4c4d85 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -31,7 +31,7 @@ googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f6 googlexnnpack;https://github.com/google/XNNPACK/archive/3cf85e705098622d59056dcb8f5f963ea7bb0a00.zip;6f6bbba627241f89463ca845febaf063982b34fe json;https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip;5e88795165cc8590138d1f47ce94ee567b85b4d6 microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 -microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 +microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.250325.1.zip;826c8bd47c2258ec61b8b218e031e5b33d27f761 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.19.1.zip;c5215b5697dcdfd71799f001b8c4054a6bba6b09 From 977efe4788b2ee24371523b5fa14dd02efcd4942 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Sat, 22 Nov 2025 04:02:47 +0800 Subject: [PATCH 12/17] [webgpu] Throw errors for graph catpure when not implemented (#26604) ### Description ### Motivation and Context --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 9b9d755498366..a5ab63d74df24 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -522,6 +522,9 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k, int local_window_size) { + if (context.IsGraphCaptureEnabled()) { + ORT_NOT_IMPLEMENTED("Graph capture not implemented for non flash attention path"); + } const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; const int total_sequence_length = From 4870d455a536cebd1e058ef00c6ad59ccbd016a3 Mon Sep 17 00:00:00 2001 From: r stroh Date: Mon, 24 Nov 2025 05:35:55 +0200 Subject: [PATCH 13/17] Add int8 support to ConvInteger (#26585) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This change extends the `ConvInteger` implementation to match the [ONNX operator spec](https://onnx.ai/onnx/operators/onnx__ConvInteger.html), which allows both `int8` and `uint8` for the input tensors: - The ONNX `ConvInteger` schema defines: - `T1`: `tensor(int8)` or `tensor(uint8)` - `T2`: `tensor(int8)` or `tensor(uint8)` - `T3`: `tensor(int32)` - Previously, only the `uint8` × `uint8` combination was supported. - This PR adds support for all 8-bit combinations: - `uint8` × `uint8` (existing behavior) - `uint8` × `int8` - `int8` × `uint8` - `int8` × `int8` ### Motivation and Context Fixes https://github.com/microsoft/onnxruntime/issues/24183 Fixes https://github.com/microsoft/onnxruntime/issues/15888 Fixes https://github.com/microsoft/onnxruntime/issues/12558 Fixes https://github.com/microsoft/onnxruntime/issues/3130 Fixes https://github.com/microsoft/onnxruntime/issues/12362 The ONNX ConvInteger operator schema allows both int8 and uint8 element types for its inputs, but the current implementation only supports uint8 × uint8. This leads to a gap where valid ONNX models using ConvInteger with int8 tensors cannot be executed. This PR closes that gap by: Aligning the implementation with the official ConvInteger type constraints. Enabling models that use int8 (or mixed int8/uint8) for X and W to run without needing operator rewrites or additional custom kernels. Keeping existing uint8 behavior unchanged, so the change is backwards compatible for current users. ### Implementation details 1. Templated core implementation (ComputeInner) The core logic of ConvInteger::Compute is moved into a templated helper: ```text class ConvInteger : public OpKernel { public: ... private: template Status ComputeInner(OpKernelContext* context) const }; ``` XT is the element type of X (uint8_t or int8_t). WT is the element type of W (uint8_t or int8_t). 2. Zero-point handling Zero points are still treated as per-tensor scalar values, with the same validation, The values are read via `DataRaw()` and stored as 8-bit scalars, preserving the previous behavior. Interpretation of these raw bytes as signed or unsigned is delegated to the GEMM implementation via explicit signedness flags (see below). 3. Im2col templated on XT The Im2col call now uses the runtime input type XT. 4. Quantized GEMM with signedness flags: ```text gemm_shape.AIsSigned = W->IsDataType(); gemm_shape.BIsSigned = X->IsDataType(); ``` AIsSigned and BIsSigned are derived from the runtime types of W and X. Data for A and B is passed as raw bytes, the GEMM implementation uses the signedness flags to interpret them correctly (In a manner similar to the implementation in `MatMulInteger`). 5. Runtime dispatch in Compute() The public Compute method becomes a thin dispatcher that selects the appropriate ComputeInner instantiation based on the actual input types. In addition, a small set of unit tests is added on top of the existing ConvInteger tests to cover the new type combinations, including cases where the first input tensor contains negative values (for the int8 × int8 path). --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/OperatorKernels.md | 2 +- .../cpu/quantization/conv_integer.cc | 115 ++-- onnxruntime/core/util/math_cpu.cc | 1 + .../providers/cpu/nn/conv_integer_test.cc | 589 ++++++++++++++++++ 4 files changed, 669 insertions(+), 38 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 92093ec5464f7..d7be243323c17 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -88,7 +88,7 @@ Do not modify directly.* |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[11, 21]|**T** = tensor(float)| |||[1, 10]|**T** = tensor(float)| -|ConvInteger|*in* x:**T1**
*in* w:**T2**
*in* x_zero_point:**T1**
*in* w_zero_point:**T2**
*out* y:**T3**|10+|**T1** = tensor(uint8)
**T2** = tensor(uint8)
**T3** = tensor(int32)| +|ConvInteger|*in* x:**T1**
*in* w:**T2**
*in* x_zero_point:**T1**
*in* w_zero_point:**T2**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int32)| |ConvTranspose|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[11, 21]|**T** = tensor(float)| |||[1, 10]|**T** = tensor(float)| diff --git a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc index f3c6b18f8e753..dc2cec1852fed 100644 --- a/onnxruntime/core/providers/cpu/quantization/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/quantization/conv_integer.cc @@ -28,8 +28,10 @@ ONNX_OPERATOR_KERNEL_EX( 10, kCpuExecutionProvider, KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) .TypeConstraint("T3", DataTypeImpl::GetTensorType()), ConvInteger); @@ -43,12 +45,12 @@ Status ConvInteger::Compute(OpKernelContext* context) const { if (num_inputs >= 3 && input_defs[2]->Exists()) { const auto* X_Zero_Point = context->Input(2); ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1."); - input_offset = *(X_Zero_Point->Data()); + input_offset = *static_cast(X_Zero_Point->DataRaw()); } if (num_inputs >= 4 && input_defs[3]->Exists()) { const auto* W_Zero_Point = context->Input(3); ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now."); - filter_offset = *(W_Zero_Point->Data()); + filter_offset = *static_cast(W_Zero_Point->DataRaw()); } const int64_t N = X->Shape()[0]; @@ -110,45 +112,82 @@ Status ConvInteger::Compute(OpKernelContext* context) const { concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); - const auto* Xdata = X->Data(); - const auto* Wdata = W->Data(); + const auto* Xdata = static_cast(X->DataRaw()); + const auto* Wdata = static_cast(W->DataRaw()); + bool X_is_signed = X->IsDataType(); auto* Ydata = Y->MutableData(); for (int image_id = 0; image_id < N; ++image_id) { for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) { if (col_buffer_data != nullptr) { if (kernel_rank == 2) { - math::Im2col()( - Xdata, - C / conv_attrs_.group, - input_shape[0], - input_shape[1], - kernel_shape[0], - kernel_shape[1], - dilations[0], - dilations[1], - pads[0], - pads[1], - pads[2], - pads[3], - strides[0], - strides[1], - col_buffer_data, - input_offset); + if (X_is_signed) { + math::Im2col()( + reinterpret_cast(Xdata), + C / conv_attrs_.group, + input_shape[0], + input_shape[1], + kernel_shape[0], + kernel_shape[1], + dilations[0], + dilations[1], + pads[0], + pads[1], + pads[2], + pads[3], + strides[0], + strides[1], + reinterpret_cast(col_buffer_data), + static_cast(input_offset)); + } else { + math::Im2col()( + Xdata, + C / conv_attrs_.group, + input_shape[0], + input_shape[1], + kernel_shape[0], + kernel_shape[1], + dilations[0], + dilations[1], + pads[0], + pads[1], + pads[2], + pads[3], + strides[0], + strides[1], + col_buffer_data, + input_offset); + } } else { - math::Im2col()( - Xdata, - input_shape.GetDims().data(), - output_shape.GetDims().data(), - kernel_dim, - kernel_shape.data(), - strides.data(), - dilations.data(), - pads.data(), - static_cast(kernel_rank), - col_buffer_data, - false, - input_offset); + if (X_is_signed) { + math::Im2col()( + reinterpret_cast(Xdata), + input_shape.GetDims().data(), + output_shape.GetDims().data(), + kernel_dim, + kernel_shape.data(), + strides.data(), + dilations.data(), + pads.data(), + static_cast(kernel_rank), + reinterpret_cast(col_buffer_data), + false, + static_cast(input_offset)); + } else { + math::Im2col()( + Xdata, + input_shape.GetDims().data(), + output_shape.GetDims().data(), + kernel_dim, + kernel_shape.data(), + strides.data(), + dilations.data(), + pads.data(), + static_cast(kernel_rank), + col_buffer_data, + false, + input_offset); + } } } @@ -156,12 +195,14 @@ Status ConvInteger::Compute(OpKernelContext* context) const { gemm_shape.M = static_cast(M / conv_attrs_.group); gemm_shape.N = static_cast(output_image_size); gemm_shape.K = static_cast(kernel_dim); + gemm_shape.AIsSigned = W->IsDataType(); + gemm_shape.BIsSigned = X_is_signed; MLAS_GEMM_QUANT_DATA_PARAMS gemm_params; gemm_params.A = Wdata + group_id * W_offset; gemm_params.lda = static_cast(kernel_dim); gemm_params.ZeroPointA = filter_offset; - gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data, + gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data; gemm_params.ldb = static_cast(output_image_size); gemm_params.ZeroPointB = &input_offset; gemm_params.C = Ydata; diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index dcb4a495c23c5..045dc98a3501e 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -527,6 +527,7 @@ void Im2col::operator()( template struct Im2col; template struct Im2col; +template struct Im2col; template void Im2col::operator()( diff --git a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc index c98d9e28b2f46..8155ac41318f6 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_integer_test.cc @@ -254,6 +254,595 @@ TEST(ConvIntegerTest, WithStride3_2D_u8u8) { test.Run(); } +TEST(ConvIntegerTest, WithoutPadding_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {-9, -9, + -9, -9}); + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {12, 16, + 24, 28}); + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10}); + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {-9, -9, + -9, -9}); + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {1, 1, 1, 1}); + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {1, 3, 5, 3, + 5, 12, 16, 9, + 11, 24, 28, 15, + 7, 15, 17, 9}); + test.Run(); +} + +TEST(ConvIntegerTest, WithGroup_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + std::vector w_dims{3, 1, 2, 2}; + test.AddInput("w", w_dims, + {-9, -8, + -8, -9, + -7, -6, + -6, -7, + -5, -4, + -4, -5}); + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute("group", static_cast(3)); + std::vector y_dims{1, 3, 4, 4}; + test.AddOutput("y", y_dims, + {1, 4, 7, 6, + 6, 18, 24, 15, + 15, 36, 42, 24, + 14, 23, 26, 9, + 30, 73, 80, 48, + 79, 168, 182, 96, + 100, 210, 224, 117, + 64, 116, 123, 54, + 95, 214, 225, 126, + 224, 462, 484, 249, + 257, 528, 550, 282, + 150, 281, 292, 135}); + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_3D_u8s8) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect."; + } + + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + std::vector w_dims{1, 1, 2, 2, 2}; + test.AddInput("w", w_dims, + {-9, -9, + -9, -9, + -9, -9, + -9, -9}); + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {1, 1, 1, 1, 1, 1}); + std::vector y_dims{1, 1, 4, 4, 4}; + test.AddOutput("y", y_dims, + {1, 3, 5, 3, + 5, 12, 16, 9, + 11, 24, 28, 15, + 7, 15, 17, 9, + 11, 24, 28, 15, + 28, 60, 68, 36, + 40, 84, 92, 48, + 23, 48, 52, 27, + 29, 60, 64, 33, + 64, 132, 140, 72, + 76, 156, 164, 84, + 41, 84, 88, 45, + 19, 39, 41, 21, + 41, 84, 88, 45, + 47, 96, 100, 51, + 25, 51, 53, 27}); + test.Run(); +} + +TEST(ConvIntegerTest, WithStride2_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {-9, -8, -9, + -8, -7, -8, + -9, -8, -9}); + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute>("strides", {2, 2}); + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {33, 62, 84, 75, + 224, 330, 360, 282, + 444, 630, 660, 502, + 453, 642, 664, 495}); + // Exercise the (stride_w = 2) path inside Math::Im2col. + test.Run(); +} + +TEST(ConvIntegerTest, WithStride3_2D_u8s8) { + OpTester test("ConvInteger", 10); + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {-9, -8, -9, + -8, -7, -8, + -9, -8, -9}); + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {-10}); + test.AddAttribute>("pads", {2, 2, 1, 1}); + test.AddAttribute>("strides", {3, 3}); + std::vector y_dims{1, 1, 3, 3}; + test.AddOutput("y", y_dims, + {0, 8, 20, + 80, 330, 375, + 200, 780, 825}); + // Exercise the (stride_w > 2) path inside Math::Im2col. + test.Run(); +} + +TEST(ConvIntegerTest, WithoutPadding_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {-1, 2, -3, + 4, -5, 6, + -7, 8, -9}); + + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {1, -2, + 3, -4}); + + test.AddInput("x_zero_point", {}, {0}); + test.AddInput("w_zero_point", {}, {0}); + + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {27, -31, + -39, 43}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {-1, 2, -3, + 4, -5, 6, + -7, 8, -9}); + + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {1, -2, + 3, -4}); + + test.AddInput("x_zero_point", {}, {0}); + test.AddInput("w_zero_point", {}, {0}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {4, -11, 18, -9, + -14, 27, -31, 15, + 20, -39, 43, -21, + 14, -23, 26, -9}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithGroup_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + + std::vector w_dims{3, 1, 2, 2}; + test.AddInput("w", w_dims, + {-9, -8, + -8, -9, + -7, -6, + -6, -7, + -5, -4, + -4, -5}); + + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute("group", static_cast(3)); + + std::vector y_dims{1, 3, 4, 4}; + test.AddOutput("y", y_dims, + {1, 4, 7, 6, + 6, 18, 24, 15, + 15, 36, 42, 24, + 14, 23, 26, 9, + 30, 73, 80, 48, + 79, 168, 182, 96, + 100, 210, 224, 117, + 64, 116, 123, 54, + 95, 214, 225, 126, + 224, 462, 484, 249, + 257, 528, 550, 282, + 150, 281, 292, 135}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_3D_s8s8) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect."; + } + + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + + std::vector w_dims{1, 1, 2, 2, 2}; + test.AddInput("w", w_dims, + {-9, -9, + -9, -9, + -9, -9, + -9, -9}); + + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {-10}); + + test.AddAttribute>("pads", {1, 1, 1, 1, 1, 1}); + + std::vector y_dims{1, 1, 4, 4, 4}; + test.AddOutput("y", y_dims, + {1, 3, 5, 3, + 5, 12, 16, 9, + 11, 24, 28, 15, + 7, 15, 17, 9, + 11, 24, 28, 15, + 28, 60, 68, 36, + 40, 84, 92, 48, + 23, 48, 52, 27, + 29, 60, 64, 33, + 64, 132, 140, 72, + 76, 156, 164, 84, + 41, 84, 88, 45, + 19, 39, 41, 21, + 41, 84, 88, 45, + 47, 96, 100, 51, + 25, 51, 53, 27}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithStride2_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {-9, -8, -9, + -8, -7, -8, + -9, -8, -9}); + + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {-10}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute>("strides", {2, 2}); + + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {33, 62, 84, 75, + 224, 330, 360, 282, + 444, 630, 660, 502, + 453, 642, 664, 495}); + // Exercise the (stride_w = 2) path inside Math::Im2col. + + test.Run(); +} + +TEST(ConvIntegerTest, WithStride3_2D_s8s8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {-9, -8, -9, + -8, -7, -8, + -9, -8, -9}); + + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {-10}); + + test.AddAttribute>("pads", {2, 2, 1, 1}); + test.AddAttribute>("strides", {3, 3}); + + std::vector y_dims{1, 1, 3, 3}; + test.AddOutput("y", y_dims, + {0, 8, 20, + 80, 330, 375, + 200, 780, 825}); + // Exercise the (stride_w > 2) path inside Math::Im2col. + + test.Run(); +} + +TEST(ConvIntegerTest, WithoutPadding_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {-1, 2, -3, + 4, -5, 6, + -7, 8, -9}); + + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {1, 2, + 3, 4}); + + test.AddInput("x_zero_point", {}, {0}); + test.AddInput("w_zero_point", {}, {0}); + + std::vector y_dims{1, 1, 2, 2}; + test.AddOutput("y", y_dims, + {-5, 5, + 5, -5}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithPadding_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 3, 3}; + test.AddInput("x", x_dims, + {-1, 2, -3, + 4, -5, 6, + -7, 8, -9}); + + std::vector w_dims{1, 1, 2, 2}; + test.AddInput("w", w_dims, + {1, 2, + 3, 4}); + + test.AddInput("x_zero_point", {}, {0}); + test.AddInput("w_zero_point", {}, {0}); + test.AddAttribute>("pads", {1, 1, 1, 1}); + + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {-4, 5, -6, -9, + 14, -5, 5, 15, + -20, 5, -5, -21, + -14, 9, -10, -9}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithGroup_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 3, 3, 3}; + test.AddInput("x", x_dims, + {2, 3, 4, + 5, 6, 7, + 8, 9, 10, + 11, 12, 13, + 14, 15, 16, + 17, 18, 19, + 20, 21, 22, + 23, 24, 25, + 26, 27, 28}); + + std::vector w_dims{3, 1, 2, 2}; + test.AddInput("w", w_dims, + {11, 12, + 12, 11, + 13, 14, + 14, 13, + 15, 16, + 16, 15}); + + test.AddInput("x_zero_point", {}, {1}); + test.AddInput("w_zero_point", {}, {10}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute("group", static_cast(3)); + + std::vector y_dims{1, 3, 4, 4}; + test.AddOutput("y", y_dims, + {1, 4, 7, 6, + 6, 18, 24, 15, + 15, 36, 42, 24, + 14, 23, 26, 9, + 30, 73, 80, 48, + 79, 168, 182, 96, + 100, 210, 224, 117, + 64, 116, 123, 54, + 95, 214, 225, 126, + 224, 462, 484, 249, + 257, 528, 550, 282, + 150, 281, 292, 135}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithStride2_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {11, 12, 11, + 12, 13, 12, + 11, 12, 11}); + + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {10}); + + test.AddAttribute>("pads", {1, 1, 1, 1}); + test.AddAttribute>("strides", {2, 2}); + + std::vector y_dims{1, 1, 4, 4}; + test.AddOutput("y", y_dims, + {33, 62, 84, 75, + 224, 330, 360, 282, + 444, 630, 660, 502, + 453, 642, 664, 495}); + + test.Run(); +} + +TEST(ConvIntegerTest, WithStride3_2D_s8u8) { + OpTester test("ConvInteger", 10); + + std::vector x_dims{1, 1, 7, 7}; + test.AddInput("x", x_dims, + {10, 11, 12, 13, 14, 15, 16, + 20, 21, 22, 23, 24, 25, 26, + 30, 31, 32, 33, 34, 35, 36, + 40, 41, 42, 43, 44, 45, 46, + 50, 51, 52, 53, 54, 55, 56, + 60, 61, 62, 63, 64, 65, 66, + 70, 71, 72, 73, 74, 75, 76}); + + std::vector w_dims{1, 1, 3, 3}; + test.AddInput("w", w_dims, + {11, 12, 11, + 12, 13, 12, + 11, 12, 11}); + + test.AddInput("x_zero_point", {}, {10}); + test.AddInput("w_zero_point", {}, {10}); + + test.AddAttribute>("pads", {2, 2, 1, 1}); + test.AddAttribute>("strides", {3, 3}); + + std::vector y_dims{1, 1, 3, 3}; + test.AddOutput("y", y_dims, + {0, 8, 20, + 80, 330, 375, + 200, 780, 825}); + + test.Run(); +} + TEST(ConvIntegerTest, NoXZeroPoint) { OpTester test("ConvInteger", 10); std::vector x_dims{1, 1, 3, 3}; From 5834bfe8782ccce88558753343216b89a3ec84d7 Mon Sep 17 00:00:00 2001 From: zpye <986790855@qq.com> Date: Tue, 25 Nov 2025 01:36:36 +0800 Subject: [PATCH 14/17] Add API to access config entries from KernelInfo (#26589) ## Description This PR adds a new API function `KernelInfo_GetConfigEntries` that allows custom operators to access all configuration entries from the `OrtKernelInfo` object during kernel construction. ## Motivation and Context Custom operators may need to access session configuration options to adjust their behavior. Previously, there was no way to retrieve all config entries from `KernelInfo`. This PR provides a convenient method to get all configuration key-value pairs that were set on the `OrtSessionOptions`. ## Changes ### API Additions - **C API**: Added `KernelInfo_GetConfigEntries` function to `OrtApi` (Version 1.24) - Takes an `OrtKernelInfo*` as input - Returns all config entries as `OrtKeyValuePairs*` - Properly documented with usage examples - **C++ API**: Added `GetConfigEntries()` method to `KernelInfoImpl` template class - Returns `KeyValuePairs` object - Follows existing C++ wrapper patterns ### Implementation - Implemented in `onnxruntime/core/session/custom_ops.cc` - Iterates through `config_options_map` from `OpKernelInfo` - Creates and populates `OrtKeyValuePairs` with all configuration entries ### Testing - Updated `shape_inference_test.cc` with test case - Verifies config entries can be retrieved in custom kernel constructor - Tests both existing and non-existing config keys ## Files Changed - `include/onnxruntime/core/session/onnxruntime_c_api.h` - API declaration - `include/onnxruntime/core/session/onnxruntime_cxx_api.h` - C++ wrapper declaration - `include/onnxruntime/core/session/onnxruntime_cxx_inline.h` - C++ wrapper implementation - `onnxruntime/core/session/custom_ops.cc` - Core implementation - `onnxruntime/core/session/onnxruntime_c_api.cc` - API registration - `onnxruntime/core/session/ort_apis.h` - API header declaration - `onnxruntime/test/framework/shape_inference_test.cc` - Test coverage ## API Version This change is part of ORT API Version 1.24. ## Breaking Changes None. This is a backward-compatible addition to the API. --- .../core/session/onnxruntime_c_api.h | 17 +++++++++++++++++ .../core/session/onnxruntime_cxx_api.h | 2 ++ .../core/session/onnxruntime_cxx_inline.h | 7 +++++++ onnxruntime/core/session/custom_ops.cc | 15 +++++++++++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 1 + onnxruntime/core/session/ort_apis.h | 3 +++ .../test/framework/shape_inference_test.cc | 15 ++++++++++++++- 7 files changed, 59 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 02915f2f1882e..d1b652229e4b6 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6591,6 +6591,23 @@ struct OrtApi { * \since Version 1.24 */ ORT_API_T(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info); + + /** \brief Get all config entries from ::OrtKernelInfo. + * + * Gets all configuration entries from the ::OrtKernelInfo object as key-value pairs. + * Config entries are set on the ::OrtSessionOptions and are accessible in custom operator kernels. + * + * Used in the CreateKernel callback of an OrtCustomOp to access all session configuration entries + * during kernel construction. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] out A pointer to a newly created OrtKeyValuePairs instance containing all config entries. + * Note: the user should call OrtApi::ReleaseKeyValuePairs. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d3a8856455c49..22708bbf06a3d 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2768,6 +2768,8 @@ struct KernelInfoImpl : Base { std::string GetNodeName() const; Logger GetLogger() const; + + KeyValuePairs GetConfigEntries() const; }; } // namespace detail diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 8ee057f51eb20..5144418db2b58 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2822,6 +2822,13 @@ inline Logger KernelInfoImpl::GetLogger() const { return Logger{out}; } +template +inline KeyValuePairs KernelInfoImpl::GetConfigEntries() const { + OrtKeyValuePairs* out = nullptr; + Ort::ThrowOnError(GetApi().KernelInfo_GetConfigEntries(this->p_, &out)); + return KeyValuePairs{out}; +} + inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) { Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out)); } diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 9bc6c8d0a96a1..6c6c589ffcad4 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -755,6 +755,21 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* i }); } +ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + const auto& config_options_map = op_info->GetConfigOptions().GetConfigOptionsMap(); + + auto kvps = std::make_unique(); + for (const auto& kv : config_options_map) { + kvps->Add(kv.first.c_str(), kv.second.c_str()); + } + + *out = kvps.release(); + return nullptr; + }); +} + ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) { if (count_or_bytes == 0) { *out = nullptr; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4891ece8bcda3..394f69bb15b19 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4237,6 +4237,7 @@ static constexpr OrtApi ort_api_1_to_24 = { // End of Version 23 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::TensorTypeAndShape_HasShape, + &OrtApis::KernelInfo_GetConfigEntries, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index f016bb3215330..c0e4d32ac0167 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -751,4 +751,7 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); + +ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); + } // namespace OrtApis diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index 2d5c3a43ee8ed..37c3825101ba4 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -78,7 +78,18 @@ TEST_F(ShapeInferenceTest, BasicTest) { namespace { struct MyCustomKernelWithOptionalInput { - MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) { + MyCustomKernelWithOptionalInput(const OrtKernelInfo* info) { + Ort::ConstKernelInfo k_info(info); + + Ort::KeyValuePairs kvp = k_info.GetConfigEntries(); + + EXPECT_NE(nullptr, kvp.GetValue("session.inter_op.allow_spinning")); + EXPECT_STREQ("0", kvp.GetValue("session.inter_op.allow_spinning")); + + EXPECT_NE(nullptr, kvp.GetValue("session.intra_op.allow_spinning")); + EXPECT_STREQ("0", kvp.GetValue("session.intra_op.allow_spinning")); + + EXPECT_EQ(nullptr, kvp.GetValue("__not__exist__")); } OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const { @@ -143,6 +154,8 @@ TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) { SessionOptions sess_opts; sess_opts.inter_op_param.thread_pool_size = 1; sess_opts.intra_op_param.thread_pool_size = 1; + ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.inter_op.allow_spinning", "0")); + ASSERT_STATUS_OK(sess_opts.config_options.AddConfigEntry("session.intra_op.allow_spinning", "0")); InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2}; ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains))); From e8bcd0dea206863727a2a146469b7958fd589d4d Mon Sep 17 00:00:00 2001 From: quic-tirupath Date: Mon, 24 Nov 2025 13:02:22 -0800 Subject: [PATCH 15/17] [QNN EP] Fuse Gelu pattern into a QNN Gelu Node (#26417) ### Description - ONNX models exported with older Opset version contains Gelu operator decomposed into multiple operators (Div, Erf, Add, Mul). - QNN doesn't support Erf operator but supports Gelu operator - Since QNN doesn't support Erf operator, the graphs contain Gelu pattern partition between QNN and CPU EPs and degrading the inference time. ### Motivation and Context - Identify and fuse the Gelu pattern into a QNN Gelu node improves the inference time. --------- Co-authored-by: Tirupathi Reddy T --- .../qnn/builder/qnn_node_group/gelu_fusion.cc | 480 ++++++++++++++++++ .../qnn/builder/qnn_node_group/gelu_fusion.h | 73 +++ .../builder/qnn_node_group/qnn_node_group.cc | 6 +- .../qnn/builder/qnn_node_group/utils.cc | 86 +++- .../qnn/builder/qnn_node_group/utils.h | 6 + .../qnn/qnn_node_group/gelu_fusion_test.cc | 407 +++++++++++++++ 6 files changed, 1053 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h create mode 100644 onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc new file mode 100644 index 0000000000000..619e3eaf5fad4 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc @@ -0,0 +1,480 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +namespace onnxruntime { +namespace qnn { + +// Helper function to extract value from raw data based on QNN data type +static Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const uint8_t* raw_ptr, + double& value) { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_FLOAT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_FLOAT_16: { + value = static_cast(reinterpret_cast(raw_ptr)->ToFloat()); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Qnn Data Type: ", qnn_data_type, " not supported."); + } + return Status::OK(); +} + +// Helper function to extract a scalar float value from a constant initializer +// Handles both float and quantized (INT type) constant inputs +static std::optional GetConstantInitializerFloatScalar(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& io_def) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const auto& name = io_def.node_arg.Name(); + + if (!graph_viewer.IsConstantInitializer(name, true)) { + return std::nullopt; + } + + // Get tensor info to check if it's quantized + TensorInfo tensor_info = {}; + if (!qnn_model_wrapper.GetTensorInfo(io_def, tensor_info).IsOK()) { + return std::nullopt; + } + + // Must be an initializer + if (!tensor_info.is_initializer || !tensor_info.initializer_tensor) { + return std::nullopt; + } + + // Unpack the initializer data + std::vector unpacked_tensor; + if (!qnn_model_wrapper.UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor).IsOK()) { + return std::nullopt; + } + + if (unpacked_tensor.empty()) { + return std::nullopt; + } + + // Extract the value using GetValueOnQnnDataType + double extracted_value = 0.0; + if (!GetValueOnQnnDataType(tensor_info.qnn_data_type, unpacked_tensor.data(), extracted_value).IsOK()) { + return std::nullopt; + } + + // Check if quantized and dequantize if needed + const bool is_quantized = tensor_info.quant_param.IsQuantized(); + if (is_quantized) { + // For quantized tensors, dequantize the value + if (!tensor_info.quant_param.IsPerTensor()) { + return std::nullopt; // Only support per-tensor quantization + } + + const Qnn_QuantizeParams_t& quant_param = tensor_info.quant_param.Get(); + double dequantized_value = utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, + extracted_value); + return static_cast(dequantized_value); + } + + // For non-quantized tensors, return the extracted value directly + return static_cast(extracted_value); +} + +// Helper function to check if a constant initializer has the expected float value +static bool IsInitializerWithExpectedValue(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& io_def, + float expected_value, + float tolerance = 1e-5f) { + std::optional actual_value = GetConstantInitializerFloatScalar(qnn_model_wrapper, io_def); + if (!actual_value.has_value()) { + return false; + } + + // Compare with expected value within tolerance + return std::fabs(actual_value.value() - expected_value) <= tolerance; +} + +// Forward declaration. +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output, + bool validate); + +// Helper function to validate on QNN +static Status ValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output) { + return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, true); +} + +// Helper function to create on QNN +static Status CreateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output) { + return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, false); +} + +// Gets the parent and child of the Erf node. Can handle the following sequences +// - Parent -> Erf -> Child. +// - Parent -> DQ -> Erf -> Q -> Child. +// +// Also returns the outputs of the Erf. For the sequence `DQ -> Erf -> Q`, returns the outputs of the Q. +static bool GetErfParentAndChild(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& erf_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + /*out*/ const NodeUnit*& parent_node_unit, + /*out*/ const NodeUnit*& child_node_unit, + /*out*/ const NodeUnit*& dq_node_unit, + /*out*/ const NodeUnit*& q_node_unit, + /*out*/ gsl::span& erf_outputs) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + auto get_first_parent = [&](const NodeUnit& node_unit) -> const NodeUnit* { + const auto& inputs = node_unit.Inputs(); + if (inputs.empty()) { + return nullptr; + } + return GetParentOfInput(graph_viewer, node_unit, inputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + }; + + auto get_first_child = [&](const NodeUnit& node_unit) -> const NodeUnit* { + const auto& outputs = node_unit.Outputs(); + if (outputs.empty()) { + return nullptr; + } + + return GetOnlyChildOfOutput(graph_viewer, node_unit, outputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + }; + + const NodeUnit* erf_parent_node_unit = get_first_parent(erf_node_unit); + if (erf_parent_node_unit == nullptr) { + return false; + } + + const NodeUnit* erf_child_node_unit = get_first_child(erf_node_unit); + if (erf_child_node_unit == nullptr) { + return false; + } + + if (erf_node_unit.UnitType() == NodeUnit::Type::SingleNode && + erf_parent_node_unit->OpType() == "DequantizeLinear" && + erf_child_node_unit->OpType() == "QuantizeLinear") { + // This is the explicit sequence DQ -> Erf -> Q. + // Look past the DQ and Q nodes to get the actual parent and child. + // We do this because ORT utils do not automatically group DQ -> Erf -> Q into a NodeUnit. + dq_node_unit = erf_parent_node_unit; + q_node_unit = erf_child_node_unit; + erf_parent_node_unit = get_first_parent(*erf_parent_node_unit); + erf_child_node_unit = get_first_child(*erf_child_node_unit); + + erf_outputs = q_node_unit->Outputs(); + } else { + erf_outputs = erf_node_unit.Outputs(); + } + + parent_node_unit = erf_parent_node_unit; + child_node_unit = erf_child_node_unit; + return parent_node_unit != nullptr && child_node_unit != nullptr; +} + +std::unique_ptr GeluFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& erf_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& /*logger*/) { + if (erf_node_unit.OpType() != "Erf") { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const NodeUnit* div_node_unit = nullptr; + const NodeUnit* add_node_unit = nullptr; + const NodeUnit* dq_node_unit = nullptr; + const NodeUnit* q_node_unit = nullptr; + gsl::span erf_outputs; + + if (!GetErfParentAndChild(qnn_model_wrapper, erf_node_unit, node_to_node_unit, node_unit_to_qnn_node_group, + div_node_unit, add_node_unit, dq_node_unit, q_node_unit, erf_outputs)) { + return nullptr; + } + + // Erf must have a Div parent. + if (div_node_unit == nullptr || div_node_unit->OpType() != "Div") { + return nullptr; + } + + // Div must have 2 inputs + const auto& div_inputs = div_node_unit->Inputs(); + if (div_inputs.size() < 2) { + return nullptr; + } + + // Check second input of Div is sqrt(2) ≈ 1.4142 + if (!IsInitializerWithExpectedValue(qnn_model_wrapper, div_inputs[1], static_cast(M_SQRT2))) { + return nullptr; + } + + // Erf must have an Add child consuming its output + if (add_node_unit == nullptr || add_node_unit->OpType() != "Add") { + return nullptr; + } + + // Add must have 2 inputs + const auto& add_inputs = add_node_unit->Inputs(); + if (add_inputs.size() < 2) { + return nullptr; + } + + // Check the other input node (e.g. not the Erf) is 1.0f + bool is_erf_first_input = (add_inputs[0].node_arg.Name() == erf_outputs[0].node_arg.Name()); + const auto& add_const_input = add_inputs[is_erf_first_input ? 1 : 0]; + if (!IsInitializerWithExpectedValue(qnn_model_wrapper, add_const_input, 1.0f)) { + return nullptr; + } + + // Add must have a Mul child consuming its output + const auto& add_outputs = add_node_unit->Outputs(); + if (add_outputs.empty()) { + return nullptr; + } + + const NodeUnit* mul_node_unit = GetOnlyChildOfOutput(graph_viewer, *add_node_unit, add_outputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + if (mul_node_unit == nullptr || mul_node_unit->OpType() != "Mul") { + return nullptr; + } + + // Now check which pattern we have + const auto& root_input_name = div_inputs[0].node_arg.Name(); + const auto& mul_inputs = mul_node_unit->Inputs(); + + if (mul_inputs.size() < 2) { + return nullptr; + } + + // Try to match Pattern 1: root -> Mul(0.5) -> ... -> Mul + // In this case, one input to the final Mul should be from a Mul node + const NodeUnit* mul2_node_unit = nullptr; + + // Check if either input to mul_node_unit comes from a Mul node + for (size_t i = 0; i < 2; ++i) { + const auto& mul_input = mul_inputs[i]; + + const NodeUnit* producer_unit = GetParentOfInput(graph_viewer, *mul_node_unit, mul_input, + node_to_node_unit, node_unit_to_qnn_node_group); + if (producer_unit && producer_unit->OpType() == "Mul") { + const auto& mul2_inputs = producer_unit->Inputs(); + if (mul2_inputs.size() >= 2) { + bool has_root_input = (mul2_inputs[0].node_arg.Name() == root_input_name || + mul2_inputs[1].node_arg.Name() == root_input_name); + if (has_root_input) { + int root_index = (mul2_inputs[0].node_arg.Name() == root_input_name) ? 0 : 1; + const auto& mul_const_input = mul2_inputs[1 - root_index]; + + if (IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) { + mul2_node_unit = producer_unit; + break; + } + } + } + } + if (mul2_node_unit != nullptr) break; + } + + std::vector node_units; + const NodeUnit* final_mul_node_unit = nullptr; + + if (mul2_node_unit != nullptr) { + // Pattern 1: root -> Mul(0.5) -> ... -> Mul + if (dq_node_unit != nullptr) { + assert(q_node_unit != nullptr); + node_units = {div_node_unit, dq_node_unit, &erf_node_unit, q_node_unit, add_node_unit, mul2_node_unit, + mul_node_unit}; + } else { + node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul2_node_unit, mul_node_unit}; + } + final_mul_node_unit = mul_node_unit; + } else { + // Try Pattern 2: root -> ... -> Mul -> Mul(0.5) + // Check if one input to mul_node_unit is root + bool has_root_input = (mul_inputs[0].node_arg.Name() == root_input_name || + mul_inputs[1].node_arg.Name() == root_input_name); + + if (!has_root_input) { + return nullptr; + } + + // mul_node_unit must have a Mul child consuming its output + const auto& mul_outputs = mul_node_unit->Outputs(); + if (mul_outputs.empty()) { + return nullptr; + } + + const NodeUnit* mul2_node_unit_pattern2 = GetOnlyChildOfOutput(graph_viewer, *mul_node_unit, mul_outputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + if (mul2_node_unit_pattern2 == nullptr || mul2_node_unit_pattern2->OpType() != "Mul") { + return nullptr; + } + + // Verify this final Mul has 2 inputs + const auto& mul2_inputs = mul2_node_unit_pattern2->Inputs(); + if (mul2_inputs.size() < 2) { + return nullptr; + } + + // Check the constant input is 0.5f + int mul_const_input_index = 0; + if (mul2_inputs[0].node_arg.Name() == mul_outputs[0].node_arg.Name()) { + mul_const_input_index = 1; + } + const auto& mul_const_input = mul2_inputs[mul_const_input_index]; + if (!IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) { + return nullptr; + } + + // Pattern 2 + if (dq_node_unit != nullptr) { + assert(q_node_unit != nullptr); + node_units = {div_node_unit, dq_node_unit, &erf_node_unit, q_node_unit, add_node_unit, + mul_node_unit, mul2_node_unit_pattern2}; + } else { + node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul_node_unit, mul2_node_unit_pattern2}; + } + + final_mul_node_unit = mul2_node_unit_pattern2; + } + + // Validate on QNN + const NodeUnitIODef& root_input = div_inputs[0]; + const NodeUnitIODef& final_output = final_mul_node_unit->Outputs()[0]; + + if (Status status = ValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(std::move(node_units), &erf_node_unit); +} + +GeluFusion::GeluFusion(std::vector&& node_units, const NodeUnit* target_node_unit) + : node_units_(std::move(node_units)), target_node_unit_(target_node_unit) { +} + +Status GeluFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const { + ORT_RETURN_IF_NOT(!node_units_.empty(), "GeluFusion node_units_ is empty"); + const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0]; + const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0]; + return ValidateOnQnn(qmw, node_units_, root_input, final_output); +} + +Status GeluFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const { + ORT_RETURN_IF_NOT(!node_units_.empty(), "GeluFusion node_units_ is empty"); + const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0]; + const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0]; + return CreateOnQnn(qmw, node_units_, root_input, final_output); +} + +gsl::span GeluFusion::GetNodeUnits() const { + return gsl::span(node_units_.data(), node_units_.size()); +} + +const NodeUnit* GeluFusion::GetTargetNodeUnit() const { + return target_node_unit_; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output, + bool validate) { + assert(node_units.size() >= 4); + const auto& node_name = utils::GetUniqueName(*node_units[0]); + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(root_input, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(final_output, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_GELU, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + // Only add tensor wrappers if they don't already exist + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(root_input.node_arg.Name())) { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + } + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(final_output.node_arg.Name())) { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + } + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_GELU, + {root_input.node_arg.Name()}, + {final_output.node_arg.Name()}, + {}, + validate), + "Failed to add fused Gelu node."); + } + + return Status::OK(); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h new file mode 100644 index 0000000000000..508b1fca48a67 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of the Gelu pattern expanded into ONNX operators. +/// This fusion handles two patterns: +/// Pattern 1: +/// +-------Mul(0.5)---------------------+ +/// | | +/// | v +/// [root] --> Div -----> Erf --> Add --> Mul ==> +/// (B=1.4142...) (1) +/// +/// Pattern 2: +/// +------------------------------------+ +/// | | +/// | v +/// [root] --> Div -----> Erf --> Add --> Mul -->Mul ==> +/// (B=1.4142...) (1) (0.5) +/// +/// Both patterns are translated into a QNN Gelu operator. +/// The contained NodeUnits can be of type SingleNode or QDQGroup (with Q-DQ nodes). +/// The second inputs to Div, Add, and Mul Node Units should be constant. +/// +class GeluFusion : public IQnnNodeGroup { + public: + GeluFusion(std::vector&& node_units, const NodeUnit* target_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(GeluFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "GeluFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid Gelu pattern. + /// If so, returns a IQnnNodeGroup that contains all the NodeUnits in the pattern. + /// + /// Used for validation and traverse/query the graph + /// Erf node unit that could be part of the sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. + /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& erf_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::vector node_units_; + const NodeUnit* target_node_unit_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 368caa518b7ba..4297801ce4cdc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -22,6 +22,7 @@ #include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_transpose_rank5.h" +#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/ort_api.h" @@ -83,6 +84,7 @@ static std::unordered_map> fusions = { {"Gemm", {LowPowerBlockQuantizedGemmFusion::TryFusion, ReshapeGemmFusion::TryFusion}}, {"Mul", {ScaleSoftmaxFusion::TryFusion}}, {"Cast", {CastLoneQFusion::TryFusion}}, + {"Erf", {GeluFusion::TryFusion}}, {"Reshape", {Rank6ToRank5Fusion::TryFusion}}, {"Transpose", {ChannelShuffleFusion::TryFusion}}}; @@ -119,9 +121,11 @@ static std::unique_ptr TryQnnFusions( const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings and Reshape + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except + // MatMul w/ LPBQ encodings, Erf and Reshape if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode && starting_node_unit.OpType() != "MatMul" && + starting_node_unit.OpType() != "Erf" && starting_node_unit.OpType() != "Reshape") { return nullptr; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index 10e1633e4b57d..7b77164a38545 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -226,14 +226,92 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer, return nullptr; } - // parent must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (p_parent_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return p_parent_node_unit; + } + return nullptr; +} + +const NodeUnit* GetOnlyChildOfOutput(const GraphViewer& graph_viewer, + const NodeUnit& node_unit, + const NodeUnitIODef& output, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map) { + const Node* p_parent_node = nullptr; + + for (auto node : node_unit.GetAllNodesInGroup()) { + for (auto node_output : node->OutputDefs()) { + if (node_output->Name() == output.node_arg.Name()) { + p_parent_node = node; + break; + } + } + // break the loop if producer node of output is found + if (p_parent_node != nullptr) { + break; + } + } + + // return if the given output tensor is not produced by any node in the given node_unit + if (p_parent_node == nullptr) { + return nullptr; + } + + const Node& parent_node = *p_parent_node; + + if (graph_viewer.NodeProducesGraphOutput(parent_node)) { + // Node is producing a graph output + return nullptr; + } + + // First pass: count how many children consume this specific output + int child_count = 0; + const NodeUnit* p_child_node_unit = nullptr; + + for (auto edge = parent_node.OutputEdgesBegin(); edge != parent_node.OutputEdgesEnd(); ++edge) { + const Node& child_node = edge->GetNode(); + + // Check if this edge corresponds to the output we're looking for + bool is_matching_output = false; + for (auto child_input : child_node.InputDefs()) { + if (child_input->Name() == output.node_arg.Name()) { + is_matching_output = true; + break; + } + } + + if (!is_matching_output) { + continue; + } + + if (graph_viewer.GetNode(child_node.Index()) == nullptr) { + // Node is not in this GraphViewer return nullptr; } - return p_parent_node_unit; + const auto child_node_unit_it = node_unit_map.find(&child_node); + if (child_node_unit_it == node_unit_map.end()) { + return nullptr; + } + const NodeUnit* current_child_node_unit = child_node_unit_it->second; + + // Check if child node has already been handled. Should not be the case if the calling + // fusion function has been called in topological order, but check to be safe. + if (qnn_node_group_map.count(current_child_node_unit) != 0) { + return nullptr; + } + + // Store the child node unit and increment count + p_child_node_unit = current_child_node_unit; + child_count++; + + // If we found more than one child, return nullptr immediately + if (child_count > 1) { + return nullptr; + } } - return nullptr; + + // Return the child only if there's exactly one child + return (child_count == 1) ? p_child_node_unit : nullptr; } } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h index 14e2a3f25e7db..b52cdd5fa3ec6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -51,5 +51,11 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const std::unordered_map& qnn_node_group_map); +const NodeUnit* GetOnlyChildOfOutput(const GraphViewer& graph_viewer, + const NodeUnit& node_unit, + const NodeUnitIODef& output, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map); + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc new file mode 100644 index 0000000000000..e28cf00aa070b --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc @@ -0,0 +1,407 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +namespace { + +// Helper function to build GELU Pattern 1: root -> Mul -> Div -> Erf -> Add -> Mul +// Pattern 1: +// +-------Mul(0.5)---------------------+ +// | | +// | v +// [root] --> Div -----> Erf --> Add --> Mul ==> +// (B=1.4142...) (1) +GetTestModelFn BuildGeluPattern1TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + + // Create Mul(0.5) branch: input * 0.5 + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* mul_half_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input, half_initializer}, {mul_half_output}); + + // Create main branch: input / sqrt(2) + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input, sqrt2_initializer}, {div_output}); + + // Erf + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_output}, {erf_output}); + + // Add 1.0 + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_output, one_initializer}, {add_output}); + + // Final Mul: (add_output) * (mul_half_output) + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Mul", {add_output, mul_half_output}, {output}); + }; +} + +// Helper function to build GELU Pattern 2: Mul(0.5) after the main sequence +// Pattern 2: +// +------------------------------------+ +// | | +// | v +// [root] --> Div -----> Erf --> Add --> Mul -->Mul ==> +// (B=1.4142...) (1) (0.5) +GetTestModelFn BuildGeluPattern2TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + + // Main branch: input / sqrt(2) + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input, sqrt2_initializer}, {div_output}); + + // Erf + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_output}, {erf_output}); + + // Add 1.0 + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_output, one_initializer}, {add_output}); + + // Mul with input: input * add_output + NodeArg* mul_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input, add_output}, {mul_output}); + + // Final Mul with 0.5: mul_output * 0.5 + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Mul", {mul_output, half_initializer}, {output}); + }; +} + +// Helper function to build QDQ GELU Pattern 1 +template +GetTestQDQModelFn BuildQDQGeluPattern1TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder, std::vector>& output_qparams) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + + // Quantize input once + NodeArg* input_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input, input_qparams.scale, input_qparams.zero_point, input_q); + + // Create quantized constants with individual quantization parameters + // For scalar constants, use range [0, value] to ensure proper quantization + QuantParams sqrt2_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, sqrt_2)); + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* sqrt2_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(sqrt2_initializer, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_q); + + QuantParams one_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, one)); + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* one_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(one_initializer, one_qparams.scale, one_qparams.zero_point, one_q); + + QuantParams half_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, half)); + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* half_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(half_initializer, half_qparams.scale, half_qparams.zero_point, half_q); + + NodeArg* input_dq_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_1); + NodeArg* sqrt2_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(sqrt2_q, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_dq); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input_dq_1, sqrt2_dq}, {div_output}); + NodeArg* div_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(div_output, input_qparams.scale, input_qparams.zero_point, div_q); + + // DQ -> Erf -> Q + NodeArg* div_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(div_q, input_qparams.scale, input_qparams.zero_point, div_dq); + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_dq}, {erf_output}); + NodeArg* erf_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(erf_output, input_qparams.scale, input_qparams.zero_point, erf_q); + + // DQ -> Add -> Q + NodeArg* erf_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(erf_q, input_qparams.scale, input_qparams.zero_point, erf_dq); + NodeArg* one_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(one_q, one_qparams.scale, one_qparams.zero_point, one_dq); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_dq, one_dq}, {add_output}); + NodeArg* add_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(add_output, input_qparams.scale, input_qparams.zero_point, add_q); + + // DQ -> Mul (with input) -> Q + NodeArg* input_dq_2 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_2); + NodeArg* add_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(add_q, input_qparams.scale, input_qparams.zero_point, add_dq); + NodeArg* mul_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input_dq_2, add_dq}, {mul_output}); + NodeArg* mul_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(mul_output, input_qparams.scale, input_qparams.zero_point, mul_q); + + // Final DQ -> Mul (with 0.5) -> Q + NodeArg* mul_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(mul_q, input_qparams.scale, input_qparams.zero_point, mul_dq); + NodeArg* half_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(half_q, half_qparams.scale, half_qparams.zero_point, half_dq); + NodeArg* mul_final_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {mul_dq, half_dq}, {mul_final_output}); + + // Add output QDQ + AddQDQNodePairWithOutputAsGraphOutput(builder, mul_final_output, output_qparams[0].scale, + output_qparams[0].zero_point); + }; +} + +// Helper function to build QDQ GELU Pattern 2 +template +GetTestQDQModelFn BuildQDQGeluPattern2TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder, std::vector>& output_qparams) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + + // Quantize input once + NodeArg* input_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input, input_qparams.scale, input_qparams.zero_point, input_q); + + // Create quantized constants with individual quantization parameters + // For scalar constants, use range [0, value] to ensure proper quantization + QuantParams sqrt2_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, sqrt_2)); + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* sqrt2_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(sqrt2_initializer, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_q); + + QuantParams one_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, one)); + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* one_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(one_initializer, one_qparams.scale, one_qparams.zero_point, one_q); + + QuantParams half_qparams = GetTestInputQuantParams(TestInputDef({}, true, 0.0f, half)); + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* half_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(half_initializer, half_qparams.scale, half_qparams.zero_point, half_q); + + // Main branch: DQ -> Div -> Q + NodeArg* input_dq_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_1); + NodeArg* sqrt2_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(sqrt2_q, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_dq); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input_dq_1, sqrt2_dq}, {div_output}); + NodeArg* div_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(div_output, input_qparams.scale, input_qparams.zero_point, div_q); + + // DQ -> Erf -> Q + NodeArg* div_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(div_q, input_qparams.scale, input_qparams.zero_point, div_dq); + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_dq}, {erf_output}); + NodeArg* erf_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(erf_output, input_qparams.scale, input_qparams.zero_point, erf_q); + + // DQ -> Add -> Q + NodeArg* erf_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(erf_q, input_qparams.scale, input_qparams.zero_point, erf_dq); + NodeArg* one_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(one_q, one_qparams.scale, one_qparams.zero_point, one_dq); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_dq, one_dq}, {add_output}); + NodeArg* add_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(add_output, input_qparams.scale, input_qparams.zero_point, add_q); + + // DQ -> Mul (with input) -> Q + NodeArg* input_dq_2 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_2); + NodeArg* add_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(add_q, input_qparams.scale, input_qparams.zero_point, add_dq); + NodeArg* mul_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input_dq_2, add_dq}, {mul_output}); + NodeArg* mul_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(mul_output, input_qparams.scale, input_qparams.zero_point, mul_q); + + // Final DQ -> Mul (with 0.5) + NodeArg* mul_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(mul_q, input_qparams.scale, input_qparams.zero_point, mul_dq); + NodeArg* half_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(half_q, half_qparams.scale, half_qparams.zero_point, half_dq); + NodeArg* mul_final_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {mul_dq, half_dq}, {mul_final_output}); + + // Add output QDQ + AddQDQNodePairWithOutputAsGraphOutput(builder, mul_final_output, output_qparams[0].scale, + output_qparams[0].zero_point); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +// Test GELU Pattern 1 with float32 model (for baseline comparison) +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_Float32) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); +} + +// Test GELU Pattern 2 with float32 model (for baseline comparison) +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_Float32) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); +} + +// Test GELU Pattern 1 with larger input shape +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_LargeInput) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 128, 768}, false, -1.5f, 1.5f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/2e-3f); +} + +// Test GELU Pattern 2 with larger input shape +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_LargeInput) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 128, 768}, false, -1.5f, 1.5f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/2e-3f); +} + +// Test GELU Pattern 1 with 3D input +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_3D) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 16, 32}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); +} + +// Test GELU Pattern 2 with 3D input +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_3D) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 16, 32}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); +} + +// Test GELU Pattern 1 with 2D input (typical for linear layers) +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_2D) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({32, 512}, false, -1.5f, 1.5f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/2e-3f); +} + +// Test GELU Pattern 2 with 2D input (typical for linear layers) +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_2D) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({32, 512}, false, -1.5f, 1.5f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/2e-3f); +} + +// Test GELU Pattern 1 with QDQ +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_QDQ_U8) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + TestQDQModelAccuracy(BuildGeluPattern1TestCase(input_def), + BuildQDQGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +// Test GELU Pattern 2 with QDQ +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_QDQ_U8) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + TestQDQModelAccuracy(BuildGeluPattern2TestCase(input_def), + BuildQDQGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) From 8e951ef6a0e61580943634954dc6f71a74be5356 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 24 Nov 2025 16:29:18 -0800 Subject: [PATCH 16/17] Update weight sharing tool to support plugin EPs (#26614) ### Description - Updates the `ep_weight_sharing_ctx_gen` tool to support specifying a plugin EP configuration (via JSON). - Mark the `ep_weight_sharing_ctx_gen` tool as deprecated and add notification to README that recommends the use the public Python ORT APIs instead. - Note we no longer publish a binary for this tool [as of ORT 1.22.2](https://github.com/microsoft/onnxruntime/pull/24895). - Added an example Python script in the README. - Added a Python unit test that tests compiling models with weight sharing using an example plugin EP. #### Tool usage Create a JSON file that contains information about the plugin EP to load/use (e.g., `example_plugin_ep_config.json`): ```json { "ep_library_registration_name": "example_plugin_ep", "ep_library_path": "example_plugin_ep.dll", "selected_ep_name": "example_plugin_ep", "default_ep_options": { "option_key": "option_value" } } ``` Call the `ep_weight_sharing_ctx_gen` tool with the `-p` command-line option to specify the location of the above configuration file: ```console $ ep_weight_sharing_ctx_gen.exe -p example_plugin_ep_config.json model_1.onnx,model_2.onnx ``` ### Motivation and Context Close the functionality gap between traditional provider-bridge EPs and plugin EPs. This PR allows using plugin EPs with the tool that compiles models with weight sharing. --- .../test/ep_weight_sharing_ctx_gen/README.md | 66 +++++++++++++ .../command_args_parser.cc | 95 ++++++++++++++++++- .../example_plugin_ep_config.json | 6 ++ .../test/ep_weight_sharing_ctx_gen/main.cc | 77 ++++++++++++++- .../test_configuration.h | 18 ++++ onnxruntime/test/python/helper.py | 12 +++ .../onnxruntime_test_python_compile_api.py | 48 +++++++++- 7 files changed, 319 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/test/ep_weight_sharing_ctx_gen/example_plugin_ep_config.json diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md index 66b8467bda335..51a405613bea1 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md @@ -1,5 +1,8 @@ # ONNXRuntime EP Context Model Generation with Weight Sharing +> [!NOTE] +> This tool is deprecated. Please use the public ONNX Runtime Python APIs to compile models with resource sharing. Refer to the example Python script at the end of this document. + [EP context with weight sharing design doc](https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html#epcontext-with-weight-sharing) OnnxRuntime provides the ep_weight_sharing_ctx_gen tool to automate the weight-sharing workflow. This tool handles the entire process. This tool is specifically designed for weight sharing scenarios, streamlining the EPContext model generation process. @@ -13,6 +16,23 @@ Example: ./ep_weight_sharing_ctx_gen -e qnn -i "soc_model|60 htp_graph_finalizat Options: -e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn', 'tensorrt', 'openvino', 'vitisai'. Default: 'qnn'. + -p [plugin_ep_config_json_file]: Specify JSON configuration file for a plugin EP. Takes precedence over the '-e' and '-i' options. + + Example JSON configuration that selects plugin EP devices via name: + { + "ep_library_registration_name": "example_plugin_ep", + "ep_library_path": "example_plugin_ep.dll", + "selected_ep_name": "example_plugin_ep", + "default_ep_options": { "key": "value" } + } + + Example JSON configuration that selects plugin EP devices via index: + { + "ep_library_registration_name": "example_plugin_ep", + "ep_library_path": "example_plugin_ep.dll", + "selected_ep_device_indices": [ 0 ], + "default_ep_options": { "key": "value" } + } -v: Show verbose information. -C: Specify session configuration entries as key-value pairs: -C "| |" Refer to onnxruntime_session_options_config_keys.h for valid keys and values. @@ -36,3 +56,49 @@ Options: -h: help ``` + +# Example: Use Python APIs to compile models with resource sharing +Use of the public ORT Python APIs is now recommended for compiling models with resource (e.g., "weight") sharing. +The following snippet shows an example that compiles two models using an example plugin EP. + +```Python +import onnxruntime +import os + +def main(): + ep_name = "example_ep" + ep_lib_path = "example_plugin_ep.dll" + + onnxruntime.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + # Find one or more EP devices that correspond to the EP of interest. + # In this example, we pick the first one. + ep_device = next((d for d in onnxruntime.get_ep_devices() if d.ep_name == ep_name), None) + + # These are the names/paths to the input and output models. + input_models = ["model_0.onnx", "model_1.onnx"] + output_models = ["model_0_ctx.onnx", "model_1_ctx.onnx"] + + num_models = len(input_models) + session_options = onnxruntime.SessionOptions() + provider_options = {} # Empty for this example + + # Set option that tells EP to share resources (e.g., weights) across sessions. + session_options.add_session_config_entry("ep.share_ep_contexts", "1") + session_options.add_provider_for_devices([ep_device], provider_options) + + # Compile individual models + for i in range(len(input_models)): + if i == num_models - 1: + # Tell EP that this is the last compiling session that will be sharing resources. + session_options.add_session_config_entry("ep.stop_share_ep_contexts", "1") + + model_compiler = onnxruntime.ModelCompiler( + session_options, + input_models[i], + embed_compiled_data_into_model=False, + ) + model_compiler.compile_to_file(output_models[i]) + + onnxruntime.unregister_execution_provider_library(ep_name) +``` diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc index cecf5575d42a5..15bce163ba16a 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc @@ -4,6 +4,7 @@ #include "command_args_parser.h" #include +#include #include #include #include @@ -21,6 +22,7 @@ #include #include +#include "nlohmann/json.hpp" #include "test_configuration.h" namespace onnxruntime { @@ -35,6 +37,23 @@ namespace qnnctxgen { "\n" "Options:\n" "\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn', 'tensorrt', 'openvino', 'vitisai'. Default: 'qnn'.\n" + "\t-p [plugin_ep_config_json_file]: Specify JSON configuration file for a plugin EP. Takes precedence over the '-e' and '-i' options.\n" + "\n" + "\t Example JSON configuration that selects plugin EP devices via EP name:\n" + "\t {\n" + "\t \"ep_library_registration_name\": \"example_plugin_ep\",\n" + "\t \"ep_library_path\": \"example_plugin_ep.dll\",\n" + "\t \"selected_ep_name\": \"example_plugin_ep\",\n" + "\t \"default_ep_options\": { \"key\": \"value\" }\n" + "\t }\n" + "\n" + "\t Example JSON configuration that selects plugin EP devices via index:\n" + "\t {\n" + "\t \"ep_library_registration_name\": \"example_plugin_ep\",\n" + "\t \"ep_library_path\": \"example_plugin_ep.dll\",\n" + "\t \"selected_ep_device_indices\": [ 0 ],\n" + "\t \"default_ep_options\": { \"key\": \"value\" }\n" + "\t }\n" "\t-v: Show verbose information.\n" "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" @@ -58,6 +77,7 @@ namespace qnnctxgen { "\n" "\t-h: help\n"); } + #ifdef _WIN32 static const ORTCHAR_T* delimiter = L","; #else @@ -110,9 +130,63 @@ static bool ParseSessionConfigs(const std::string& configs_string, return true; } +static bool ParsePluginEpConfig(const std::string& json_file_path, PluginEpConfig& config_out) { + using json = nlohmann::json; + bool success = true; + + ORT_TRY { + std::ifstream ifs{json_file_path}; + if (!ifs) { + std::cerr << "ERROR: Failed to open plugin EP configuration file at path: " + << json_file_path.c_str() << std::endl; + return false; + } + + std::string content(std::istreambuf_iterator{ifs}, + std::istreambuf_iterator{}); + PluginEpConfig config{}; + const auto parsed_json = json::parse(content); + + // required keys + parsed_json.at("ep_library_registration_name").get_to(config.ep_library_registration_name); + parsed_json.at("ep_library_path").get_to(config.ep_library_path); + + // optional keys + config.default_ep_options = parsed_json.value("default_ep_options", {}); + config.selected_ep_name = parsed_json.value("selected_ep_name", {}); + config.selected_ep_device_indices = + parsed_json.value("selected_ep_device_indices", {}); + + if (config.selected_ep_name.empty() == config.selected_ep_device_indices.empty()) { + std::cerr << "ERROR: Plugin EP configuration must specify exactly one of 'selected_ep_name' " + << "or 'selected_ep_device_indices'" << std::endl; + return false; + } + + config_out = std::move(config); + return success; + } + ORT_CATCH(const json::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + std::string kExampleValidJsonStr = + "{\n" + " \"ep_library_registration_name\": \"example_plugin_ep\",\n" + " \"ep_library_path\": \"/path/to/example_plugin_ep.dll\",\n" + " \"selected_ep_name\": \"example_plugin_ep\"\n" + "}"; + + success = false; + std::cerr << "ERROR: JSON parse error: " << e.what() << std::endl; + std::cerr << "This is an example valid JSON configuration:\n" + << kExampleValidJsonStr.c_str() << std::endl; + }); + } + return success; +} + /*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("e:p:o:u:i:C:vh"))) != -1) { switch (ch) { case 'e': if (!CompareCString(optarg, ORT_TSTR("qnn"))) { @@ -128,6 +202,20 @@ static bool ParseSessionConfigs(const std::string& configs_string, return false; } break; + case 'p': { +#ifdef _MSC_VER + std::string plugin_ep_config_file_path = ToUTF8String(optarg); +#else + std::string plugin_ep_config_file_path = optarg; +#endif + PluginEpConfig plugin_ep_config{}; + if (!ParsePluginEpConfig(plugin_ep_config_file_path, plugin_ep_config)) { + return false; + } + + test_config.machine_config.plugin_ep_config = std::move(plugin_ep_config); + break; + } case 'v': test_config.run_config.f_verbose = true; break; @@ -202,6 +290,11 @@ static bool ParseSessionConfigs(const std::string& configs_string, argc -= optind; argv += optind; + if (argc == 0) { + std::cerr << "ERROR: Did not specify model paths" << std::endl; + return false; + } + ParsePaths(argv[0], test_config.model_file_paths); return true; diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/example_plugin_ep_config.json b/onnxruntime/test/ep_weight_sharing_ctx_gen/example_plugin_ep_config.json new file mode 100644 index 0000000000000..f8967d1831582 --- /dev/null +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/example_plugin_ep_config.json @@ -0,0 +1,6 @@ +{ + "ep_library_registration_name": "example_plugin_ep", + "ep_library_path": "example_plugin_ep.dll", + "selected_ep_name": "example_plugin_ep", + "default_ep_options": { "option_key": "option_value" } +} diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc index 18abe1eb131d8..3f2cda26fe9df 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc @@ -10,6 +10,7 @@ // onnx dependencies #include "onnx/onnx_pb.h" +#include #include using namespace onnxruntime; @@ -81,6 +82,72 @@ static void UpdateEpContextModel(const std::vector> } } +using PluginEpLibraryRegistrationHandle = std::unique_ptr>; + +static PluginEpLibraryRegistrationHandle RegisterPluginEpLibrary(Ort::Env& env, + const std::string& ep_library_registration_name, + const std::basic_string& ep_library_path) { + env.RegisterExecutionProviderLibrary(ep_library_registration_name.c_str(), ep_library_path); + + auto unregister_ep_library = [&env, registration_name = ep_library_registration_name](void* p) { + if (p == nullptr) { + return; + } + + ORT_TRY { + env.UnregisterExecutionProviderLibrary(registration_name.c_str()); + } + ORT_CATCH(const Ort::Exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + std::cerr << "Failed to unregister EP library with name '" << registration_name << "': " + << e.what() << std::endl; + }); + } + }; + + // Set `handle_value` to something not equal to nullptr. The particular value doesn't really matter. + // We are just using the unique_ptr deleter to unregister the EP library. + void* const handle_value = reinterpret_cast(0x1); + return PluginEpLibraryRegistrationHandle{handle_value, unregister_ep_library}; +} + +static bool SetPluginEpSessionOptions(Ort::Env& env, Ort::SessionOptions& session_options, + const qnnctxgen::PluginEpConfig& config, + PluginEpLibraryRegistrationHandle& plugin_ep_library_registration_handle) { + auto lib_registration_handle = RegisterPluginEpLibrary(env, config.ep_library_registration_name, + ToPathString(config.ep_library_path)); + + std::vector ep_devices = env.GetEpDevices(); + std::vector selected_ep_devices{}; + + if (!config.selected_ep_device_indices.empty()) { + for (const auto idx : config.selected_ep_device_indices) { + if (idx >= ep_devices.size()) { + std::cerr << "ERROR: Selected EP device index is out of range (max is " << ep_devices.size() - 1 << "): " + << idx << std::endl; + return false; + } + + selected_ep_devices.push_back(ep_devices[idx]); + } + } else { + std::copy_if(ep_devices.begin(), ep_devices.end(), std::back_inserter(selected_ep_devices), + [&selected_ep_name = std::as_const(config.selected_ep_name)](Ort::ConstEpDevice ep_device) { + return ep_device.EpName() == selected_ep_name; + }); + } + + if (selected_ep_devices.empty()) { + std::cerr << "ERROR: No EP devices were selected" << std::endl; + return false; + } + + session_options.AppendExecutionProvider_V2(env, selected_ep_devices, config.default_ep_options); + plugin_ep_library_registration_handle = std::move(lib_registration_handle); + + return true; +} + #ifdef _WIN32 int real_main(int argc, wchar_t* argv[]) { #else @@ -98,6 +165,7 @@ int real_main(int argc, char* argv[]) { Ort::Env env(logging_level, "ep_weight_sharing"); ORT_TRY { + PluginEpLibraryRegistrationHandle plugin_ep_library_registration_handle{}; Ort::SessionOptions so; so.SetLogId("ep_weight_sharing_ctx_gen_session_logger"); // Set default session option to dump EPContext model with non-embed mode @@ -136,7 +204,14 @@ int real_main(int argc, char* argv[]) { // The context binary file generated later includes all graphs from previous models { std::string provider_name_ = test_config.machine_config.provider_type_name; - if (provider_name_ == onnxruntime::kQnnExecutionProvider) { + + if (const auto& plugin_ep_config = test_config.machine_config.plugin_ep_config; plugin_ep_config.has_value()) { + if (!SetPluginEpSessionOptions(env, so, *plugin_ep_config, plugin_ep_library_registration_handle)) { + std::cerr << "ERROR: Failed to initialize session for plugin EP " + << test_config.machine_config.plugin_ep_config->ep_library_path << std::endl; + return 1; + } + } else if (provider_name_ == onnxruntime::kQnnExecutionProvider) { #ifdef USE_QNN so.AppendExecutionProvider("QNN", provider_options); #else diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h index 198d03211f561..6dfb7b60ddc27 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -14,8 +15,25 @@ namespace onnxruntime { namespace qnnctxgen { +// Configuration for initializing the dynamic plugin EP infrastructure. +struct PluginEpConfig { + std::string ep_library_registration_name{}; + std::string ep_library_path{}; + + // Note: Exactly one of `selected_ep_name` or `selected_ep_device_indices` should be set. + // An empty value for either means it is unset. + + // Specifies the EP devices matching this EP name as the selected EP devices. + std::string selected_ep_name{}; + // Specifies the selected EP devices by their indices. + std::vector selected_ep_device_indices{}; + + std::unordered_map default_ep_options{}; +}; + struct MachineConfig { std::string provider_type_name{onnxruntime::kQnnExecutionProvider}; + std::optional plugin_ep_config = std::nullopt; }; struct RunConfig { diff --git a/onnxruntime/test/python/helper.py b/onnxruntime/test/python/helper.py index 2a2c3fc9b4532..99960640fe92e 100644 --- a/onnxruntime/test/python/helper.py +++ b/onnxruntime/test/python/helper.py @@ -1,4 +1,5 @@ import os +import sys def get_name(name): @@ -13,3 +14,14 @@ def get_name(name): if os.path.exists(res): return res raise FileNotFoundError(f"Unable to find '{name}' or '{rel}' or '{res}'") + + +def get_shared_library_filename_for_platform(base_name): + if sys.platform.startswith("win"): + return base_name + ".dll" + + if sys.platform.startswith("darwin"): + return "lib" + base_name + ".dylib" + + # Else, assume linux + return "lib" + base_name + ".so" diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index e46cdb4f98850..c60307d3c0116 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -10,7 +10,7 @@ import onnx from autoep_helper import AutoEpTestCase -from helper import get_name +from helper import get_name, get_shared_library_filename_for_platform import onnxruntime as onnxrt from onnxruntime.capi.onnxruntime_pybind11_state import Fail, ModelRequiresCompilation @@ -53,6 +53,52 @@ def test_compile_with_files_prefer_npu_policy(self): self.assertTrue(os.path.exists(output_model_path)) self.unregister_execution_provider_library(ep_name) + def test_compile_shared_resources_plugin_ep(self): + """ + Test compiling two example models using weight sharing (via example plugin EP) + """ + ep_lib_path = get_shared_library_filename_for_platform("example_plugin_ep") + try: + ep_lib_path = get_name(ep_lib_path) + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + ep_name = "example_ep" + self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + ep_device = next((d for d in onnxrt.get_ep_devices() if d.ep_name == ep_name), None) + self.assertIsNotNone(ep_device) + + input_models = [get_name("add_mul_add.onnx"), get_name("mul_1.onnx")] + output_models = [ + os.path.join(self._tmp_dir_path, "output_model_0_ctx.onnx"), + os.path.join(self._tmp_dir_path, "output_model_1_ctx.onnx"), + ] + + num_models = len(input_models) + session_options = onnxrt.SessionOptions() + + # Set option that tells EP to share resources (e.g., weights) across sessions. The example plugin EP + # doesn't actually do anything special, but we do this to test the API + session_options.add_session_config_entry("ep.share_ep_contexts", "1") + session_options.add_provider_for_devices([ep_device], {}) + + # Compile individual models + for i in range(num_models): + if i == num_models - 1: + # Tell EP that this is the last session that will be sharing resources. + session_options.add_session_config_entry("ep.stop_share_ep_contexts", "1") + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_models[i], + embed_compiled_data_into_model=False, + ) + model_compiler.compile_to_file(output_models[i]) + self.assertTrue(os.path.exists(output_models[i])) + + self.unregister_execution_provider_library(ep_name) + def test_compile_with_ep_selection_delegate(self): """ Tests compiling a model (to/from files) using an EP selection delegate callback. From e6e048e073fd9e787195d809a66904990d48b082 Mon Sep 17 00:00:00 2001 From: Colm Donelan <52702205+Colm-in-Arm@users.noreply.github.com> Date: Tue, 25 Nov 2025 03:04:42 +0000 Subject: [PATCH 17/17] KFI-203 Improve thread safety of packing in convolve_kleidiai.cpp (#26575) ### Description Making cache objects of packed data thread_local rather than static. ### Motivation and Context Both LHS and RHS packing utilize a cache mechanism based on a static unordered map. There's the potential for interference between parallel inference sessions. Made both structures thread_local. Signed-off-by: Colm Donelan --- onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index fb3b1d1d29eec..487e1533f5967 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -332,8 +332,8 @@ static std::shared_ptr RhsPackWeightsBiasSme(const size_t co, const const float* weights, const float* bias, MLAS_THREADPOOL* ThreadPool) { - //cache of prepacked kai rhs weights and biases - static std::unordered_map> rhs_cache; + // Cache of prepacked kai rhs weights and biases. thread_local to prevent interference from parallel sessions. + thread_local std::unordered_map> rhs_cache; RhsCacheKey key = { co, ci, kh, kw, dilationh, dilationw, HashWeights(weights) }; @@ -474,8 +474,8 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool); - //cache of computed lhs ptr offsets - static std::unordered_map> lhs_ptrs_cache; + // Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions. + thread_local std::unordered_map> lhs_ptrs_cache; std::shared_ptr lhs_ptrs; if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {