diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index ae6684b061883..010696a61022c 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -22,13 +22,14 @@ endif() function(get_c_cxx_api_headers HEADERS_VAR) set(_headers "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_ep_c_api.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_ep_c_api.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h" ) if (onnxruntime_ENABLE_TRAINING_APIS) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index c53a2f42247d9..44c7bb6ee424a 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -350,12 +350,12 @@ struct OrtEp { uint32_t ort_version_supported; /** \brief Get the execution provider name. + * + * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it. * * \param[in] this_ptr The OrtEp instance. * \return The execution provider name. * - * \note Returned string is owned by ORT and valid until UnregisterExecutionProviderLibrary is called. - * * \since Version 1.22. */ const char*(ORT_API_CALL* GetName)(_In_ const OrtEp* this_ptr); @@ -578,6 +578,8 @@ struct OrtEpFactory { uint32_t ort_version_supported; /** \brief Get the name of the execution provider that the factory creates. + * + * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it. * * \param[in] this_ptr The OrtEpFactory instance. * \return The name of the execution provider the factory creates. @@ -587,6 +589,8 @@ struct OrtEpFactory { const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); /** \brief Get the name of vendor who owns the execution provider that the factory creates. + * + * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it. * * \param[in] this_ptr The OrtEpFactory instance. * \return vendor The vendor name of the execution provider the factory creates. @@ -659,6 +663,20 @@ struct OrtEpFactory { */ void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); + /** \brief Get the version of the execution provider that the factory creates. + * + * The version string should adhere to the Semantic Versioning 2.0 specification + * (https://github.com/semver/semver/blob/v2.0.0/semver.md). + * + * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \return The execution provider version string. + * + * \since Version 1.23. + */ + const char*(ORT_API_CALL* GetVersion)(_In_ const OrtEpFactory* this_ptr); + /** \brief Create an OrtAllocator for the given OrtMemoryInfo. * * This is used to create an allocator that an execution provider requires. 
The factory that creates the EP is
diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h
new file mode 100644
index 0000000000000..f0992f05f31e5
--- /dev/null
+++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h
@@ -0,0 +1,10 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+// This file contains well-known keys for OrtEpDevice EP metadata entries.
+// It does NOT specify all available metadata keys.
+
+// Key for the execution provider version string. This should be available for all plugin EPs.
+static const char* const kOrtEpDevice_EpMetadataKey_Version = "version";
diff --git a/onnxruntime/core/common/semver.cc b/onnxruntime/core/common/semver.cc
new file mode 100644
index 0000000000000..618d9dc29ea74
--- /dev/null
+++ b/onnxruntime/core/common/semver.cc
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/semver.h"
+
+#include <regex>
+
+#include "core/common/common.h"
+#include "core/common/narrow.h"
+#include "core/common/parse_string.h"
+
+namespace onnxruntime {
+
+Status ParseSemVerVersion(std::string_view version_string, SemVerVersion* semver_version_out) {
+  // Semantic Versioning version regex was copied from here:
+  // https://github.com/semver/semver/blob/d58db1686379c8c6d52e32d42d3a530a964264e5/semver.md?plain=1#L357
+  static const std::regex semver_pattern{
+      R"(^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$)"};
+
+  std::cmatch match_result{};
+  ORT_RETURN_IF_NOT(std::regex_match(version_string.data(), version_string.data() + version_string.size(),
+                                     match_result, semver_pattern),
+                    "Version string is not in semantic versioning format: '", version_string, "'");
+
+  auto sub_match_to_string_view = [](const std::csub_match& sub_match) -> std::optional<std::string_view> {
+    if (!sub_match.matched) {
+      return std::nullopt;
+    }
+    return std::string_view{sub_match.first, narrow<size_t>(sub_match.length())};
+  };
+
+  auto parse_version_component =
+      [&sub_match_to_string_view](const std::csub_match& sub_match, uint32_t& component) -> Status {
+    const auto component_str = sub_match_to_string_view(sub_match);
+    ORT_RETURN_IF_NOT(component_str.has_value(), "sub_match does not match anything.");
+    return ParseStringWithClassicLocale(*component_str, component);
+  };
+
+  SemVerVersion semver_version{};
+
+  ORT_RETURN_IF_ERROR(parse_version_component(match_result[1], semver_version.major));
+  ORT_RETURN_IF_ERROR(parse_version_component(match_result[2], semver_version.minor));
+  ORT_RETURN_IF_ERROR(parse_version_component(match_result[3], semver_version.patch));
+
+  semver_version.prerelease = sub_match_to_string_view(match_result[4]);
+  semver_version.build_metadata = sub_match_to_string_view(match_result[5]);
+
+  if (semver_version_out) {
+    *semver_version_out = std::move(semver_version);
+  }
+  return Status::OK();
+}
+
+SemVerVersion ParseSemVerVersion(std::string_view version_string) {
+  SemVerVersion result{};
+  ORT_THROW_IF_ERROR(ParseSemVerVersion(version_string, &result));
+  return result;
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/common/semver.h b/onnxruntime/core/common/semver.h
new file mode 100644
index 0000000000000..a07c24f016886
--- /dev/null
+++
b/onnxruntime/core/common/semver.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/common/status.h" + +namespace onnxruntime { + +// Semantic Versioning version utilities. +// See https://github.com/semver/semver/blob/v2.0.0/semver.md. + +// Semantic Versioning version components. +struct SemVerVersion { + uint32_t major{}; + uint32_t minor{}; + uint32_t patch{}; + std::optional prerelease{}; + std::optional build_metadata{}; +}; + +// Parse a Semantic Versioning version from `version_string`. +// If provided, the parsed version components will be written to `semver_version`. +Status ParseSemVerVersion(std::string_view version_string, SemVerVersion* semver_version); + +// Parse a Semantic Versioning version from `version_string`. +SemVerVersion ParseSemVerVersion(std::string_view version_string); + +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index 80bb3f13814d1..993020278eb03 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -11,7 +11,8 @@ void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args, - ExecutionOrder order) { + ExecutionOrder order, + bool include_initializer_data) { graph_proto.set_name(graph_view.Name()); graph_proto.set_doc_string(graph_view.Description()); @@ -92,7 +93,25 @@ void GraphViewerToProto(const GraphViewer& graph_view, const auto& [name, init] = *it; current_scope_initializer_set.insert(name); auto* p_initializer = graph_proto.add_initializer(); - ORT_THROW_IF_ERROR(get_initializer_with_data(*init, *p_initializer)); + + // Do not save raw or external data into the graph, only the metadata + if (!include_initializer_data && (init->has_raw_data() || init->has_data_location())) { + // Set datatype + if (init->has_data_type()) { + p_initializer->set_data_type(init->data_type()); + } + // Set name + if (init->has_name()) { + p_initializer->set_name(init->name()); + } + + // Set dims + for (int i = 0; i < init->dims_size(); ++i) { + p_initializer->add_dims(init->dims()[i]); + } + } else { + ORT_THROW_IF_ERROR(get_initializer_with_data(*init, *p_initializer)); + } } // handle outer scope value which is a constant initializer diff --git a/onnxruntime/core/graph/graph_proto_serializer.h b/onnxruntime/core/graph/graph_proto_serializer.h index ce21e1b609b26..2a8180477c476 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.h +++ b/onnxruntime/core/graph/graph_proto_serializer.h @@ -11,5 +11,6 @@ void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args, - ExecutionOrder order = ExecutionOrder::DEFAULT); + ExecutionOrder order = ExecutionOrder::DEFAULT, + bool include_initializer_data = true); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 6ba2dd8176590..2de496a9168a0 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -308,12 +308,14 @@ CUDA_Provider* GetProvider() { } // namespace onnxruntime #include "core/framework/error_code_helper.h" +#include "onnxruntime_config.h" // for ORT_VERSION // OrtEpApi 
infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection. struct CudaEpFactory : OrtEpFactory { CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; @@ -329,6 +331,10 @@ struct CudaEpFactory : OrtEpFactory { return factory->vendor.c_str(); } + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return ORT_VERSION; + } + static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 8a5f83f636824..c679ea1adb286 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -116,6 +116,7 @@ ORT_API(onnxruntime::Provider*, GetProvider) { } #include "core/framework/error_code_helper.h" +#include "onnxruntime_config.h" // for ORT_VERSION // OrtEpApi infrastructure to be able to use the QNN EP as an OrtEpFactory for auto EP selection. struct QnnEpFactory : OrtEpFactory { @@ -126,6 +127,7 @@ struct QnnEpFactory : OrtEpFactory { : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} { GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; @@ -143,6 +145,10 @@ struct QnnEpFactory : OrtEpFactory { return factory->vendor.c_str(); } + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return ORT_VERSION; + } + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. // An EP created with this factory is expected to be able to execute a model with *all* supported // hardware devices at once. 
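The in-tree CUDA and QNN factories above simply report `ORT_VERSION` from `GetVersionImpl`. For a standalone plugin EP, `GetVersion` would instead report the plugin library's own semver string. The sketch below is illustrative only: `MyPluginEpFactory`, the vendor string, and the version value are invented, and the include is assumed to pull in the EP C API declarations shown earlier; only the `OrtEpFactory` fields themselves come from the API in this patch.

```cpp
// Hypothetical plugin EP factory wiring the new GetVersion callback (added in 1.23).
#include "onnxruntime_c_api.h"  // assumed to also bring in onnxruntime_ep_c_api.h

struct MyPluginEpFactory : OrtEpFactory {
  MyPluginEpFactory() {
    ort_version_supported = ORT_API_VERSION;
    GetName = GetNameImpl;
    GetVendor = GetVendorImpl;
    GetVersion = GetVersionImpl;  // must return a null-terminated, UTF-8 semver string; ORT copies it
    // GetSupportedDevices / CreateEp / ReleaseEp omitted for brevity.
  }

  static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
    return "MyPluginEp";  // hypothetical EP name
  }
  static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
    return "ExampleVendor";  // hypothetical vendor
  }
  static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
    return "0.1.0";  // the plugin's own version, not ORT_VERSION
  }
};
```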
A single instance of QNN EP is not currently setup to partition a model among diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index dba26b3982d86..44dd70211327e 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -1097,7 +1097,8 @@ struct ProviderHost { ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args, - int execution_order) noexcept = 0; + int execution_order, + bool include_initializer_data) noexcept = 0; virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; virtual IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 80b5e26db8680..23fbead1e9707 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1150,8 +1150,9 @@ class GraphViewer final { void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args, - int execution_order = 0) const { - g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order); + int execution_order = 0, + bool include_initializer_data = true) const { + g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order, include_initializer_data); } const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->GraphViewer__GetSchemaRegistry(this); } diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 315d0cd75e946..48d884858f493 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -3,7 +3,7 @@ #include "core/framework/session_state.h" #include "core/providers/webgpu/allocator.h" -#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/buffer_manager.h" namespace onnxruntime { namespace webgpu { @@ -15,18 +15,17 @@ void* GpuBufferAllocator::Alloc(size_t size) { stats_.num_allocs++; -#if !defined(__wasm__) - if (!session_initialized_ && context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages)) { - return context_.BufferManager().CreateUMA(size); + // Check if the buffer manager supports UMA and we're not yet in an initialized session + if (!session_initialized_ && buffer_manager_.SupportsUMA()) { + return buffer_manager_.CreateUMA(size); } -#endif // !defined(__wasm__) - return context_.BufferManager().Create(size); + return buffer_manager_.Create(size); } void GpuBufferAllocator::Free(void* p) { if (p != nullptr) { - context_.BufferManager().Release(static_cast(p)); + buffer_manager_.Release(static_cast(p)); stats_.num_allocs--; } } diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 0b27f713777bc..de9b0a800ef64 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -9,27 +9,26 @@ namespace 
onnxruntime { namespace webgpu { -class WebGpuContext; +class BufferManager; class GpuBufferAllocator : public IAllocator { public: - GpuBufferAllocator(const WebGpuContext& context) + GpuBufferAllocator(const BufferManager& buffer_manager) : IAllocator( OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), OrtMemTypeDefault)), - context_{context} { + buffer_manager_{buffer_manager} { } virtual void* Alloc(size_t size) override; virtual void Free(void* p) override; void GetStats(AllocatorStats* stats) override; - void OnSessionInitializationEnd(); private: AllocatorStats stats_; - const WebGpuContext& context_; + const BufferManager& buffer_manager_; bool session_initialized_ = false; }; diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index 1d8c689cbd909..e8140a4d59eab 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -37,7 +37,7 @@ class DisabledCacheManager : public IBufferCacheManager { wgpuBufferRelease(buffer); } - void OnRefresh() override { + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { // no-op } }; @@ -59,7 +59,7 @@ class LazyReleaseCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh() override { + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { Release(); pending_buffers_.clear(); } @@ -103,7 +103,7 @@ class SimpleCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh() override { + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { for (auto& buffer : pending_buffers_) { buffers_[static_cast(wgpuBufferGetSize(buffer))].emplace_back(buffer); } @@ -196,12 +196,9 @@ class BucketCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh() override { - // TODO: consider graph capture. 
currently not supported - + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { for (auto& buffer : pending_buffers_) { auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - auto it = buckets_.find(buffer_size); if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { it->second.emplace_back(buffer); @@ -249,6 +246,155 @@ class BucketCacheManager : public IBufferCacheManager { std::vector buckets_keys_; }; +class GraphCacheManager : public IBufferCacheManager { + public: + GraphCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} { + Initialize(); + } + GraphCacheManager(std::unordered_map&& buckets_limit) : buckets_limit_{buckets_limit} { + Initialize(); + } + + size_t CalculateBufferSize(size_t request_size) override { + // binary serch size + auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size); + if (it == buckets_keys_.end()) { + return NormalizeBufferSize(request_size); + } else { + return *it; + } + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { + // Initialize buckets if they don't exist yet + if (buckets_.empty()) { + for (const auto& pair : buckets_limit_) { + buckets_.emplace(pair.first, std::vector()); + } + } + + for (auto& buffer : pending_buffers_) { + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); + auto it = buckets_.find(buffer_size); + if (it != buckets_.end()) { + it->second.emplace_back(buffer); + } else { + // insert a new bucket if it doesn't exist + buckets_[buffer_size] = std::vector{buffer}; + } + } + + pending_buffers_.clear(); + } + + ~GraphCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + for (auto& pair : buckets_) { + for (auto& buffer : pair.second) { + wgpuBufferRelease(buffer); + } + } + } + + protected: + void Initialize() { + buckets_keys_.reserve(buckets_limit_.size()); + for (const auto& pair : buckets_limit_) { + buckets_keys_.push_back(pair.first); + } + std::sort(buckets_keys_.begin(), buckets_keys_.end()); + +#ifndef NDEBUG // if debug build + ORT_ENFORCE(std::all_of(buckets_keys_.begin(), buckets_keys_.end(), [](size_t size) { return size % 16 == 0; }), + "Bucket sizes must be multiples of 16."); + + for (size_t i = 1; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order."); + } +#endif + } + std::unordered_map buckets_limit_; + std::unordered_map> buckets_; + std::vector pending_buffers_; + std::vector buckets_keys_; +}; + +class GraphSimpleCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buffers_.find(buffer_size); + if (it != buffers_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + 
void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh(GraphCaptureState graph_capture_state) override { + for (auto& buffer : pending_buffers_) { + if (graph_capture_state == GraphCaptureState::Default) { + buffers_[static_cast(wgpuBufferGetSize(buffer))].emplace_back(buffer); + } else { + captured_buffers_.emplace_back(buffer); + } + } + pending_buffers_.clear(); + } + + public: + ~GraphSimpleCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + for (auto& pair : buffers_) { + for (auto& buffer : pair.second) { + wgpuBufferRelease(buffer); + } + } + for (auto& buffer : captured_buffers_) { + wgpuBufferRelease(buffer); + } + } + + protected: + std::map> buffers_; + std::vector pending_buffers_; + std::vector captured_buffers_; +}; + std::unique_ptr CreateBufferCacheManager(BufferCacheMode cache_mode) { switch (cache_mode) { case BufferCacheMode::Disabled: @@ -259,6 +405,10 @@ std::unique_ptr CreateBufferCacheManager(BufferCacheMode ca return std::make_unique(); case BufferCacheMode::Bucket: return std::make_unique(); + case BufferCacheMode::Graph: + return std::make_unique(); + case BufferCacheMode::GraphSimple: + return std::make_unique(); default: ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode"); } @@ -278,6 +428,12 @@ std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) { case BufferCacheMode::Bucket: os << "Bucket"; break; + case BufferCacheMode::Graph: + os << "Graph"; + break; + case BufferCacheMode::GraphSimple: + os << "GraphSimple"; + break; default: os << "Unknown(" << static_cast(mode) << ")"; } @@ -292,7 +448,7 @@ BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buf default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} { } -void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { +void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) const { // If the buffer is mapped, we can directly write to it. 
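To make the difference between the new cache policies concrete, here is a minimal, self-contained sketch of the `GraphSimple` refresh behavior defined above, with plain integers standing in for `WGPUBuffer` handles and sizes passed explicitly instead of queried via `wgpuBufferGetSize`. It only illustrates the bookkeeping: buffers released under a `Default` refresh return to a size-keyed free list for reuse, while buffers released during capture or replay stay parked so the recorded commands keep valid buffers.

```cpp
#include <cstdio>
#include <map>
#include <utility>
#include <vector>

enum class GraphCaptureState { Default, Capturing, Replaying };

struct ToyGraphSimpleCache {
  std::map<size_t, std::vector<int>> free_list;  // size -> reusable handles
  std::vector<std::pair<int, size_t>> pending;   // released since the last refresh
  std::vector<int> captured;                     // kept alive for graph replay

  void Release(int handle, size_t size) { pending.push_back({handle, size}); }

  void OnRefresh(GraphCaptureState state) {
    for (auto& [handle, size] : pending) {
      if (state == GraphCaptureState::Default) {
        free_list[size].push_back(handle);  // reusable across later runs
      } else {
        captured.push_back(handle);  // still referenced by captured commands
      }
    }
    pending.clear();
  }
};

int main() {
  ToyGraphSimpleCache cache;
  cache.Release(/*handle*/ 1, /*size*/ 256);
  cache.OnRefresh(GraphCaptureState::Capturing);  // parked, not reusable
  cache.Release(2, 256);
  cache.OnRefresh(GraphCaptureState::Default);    // returned to the free list
  std::printf("reusable@256: %zu, parked: %zu\n",
              cache.free_list[256].size(), cache.captured.size());  // prints 1, 1
  return 0;
}
```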
void* mapped_data = wgpuBufferGetMappedRange(dst, 0, WGPU_WHOLE_MAP_SIZE); // ensure the buffer is mapped if (mapped_data) { @@ -317,10 +473,10 @@ void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { auto& command_encoder = context_.GetCommandEncoder(); context_.EndComputePass(); command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size); - context_.Flush(); + context_.Flush(*this); } -void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { +void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const { ORT_ENFORCE(src != dst, "Source and destination buffers must be different."); EnforceBufferUnmapped(context_, src); EnforceBufferUnmapped(context_, dst); @@ -337,7 +493,7 @@ void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size); } -WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { +WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const { auto& cache = GetCacheManager(usage); auto buffer_size = cache.CalculateBufferSize(size); @@ -358,7 +514,7 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { return buffer; } -WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) { +WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) const { ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must be a storage buffer."); auto& cache = GetCacheManager(usage); auto buffer_size = cache.CalculateBufferSize(size); @@ -378,12 +534,21 @@ WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) { return buffer; } -void BufferManager::Release(WGPUBuffer buffer) { +bool BufferManager::SupportsUMA() const { +#if !defined(__wasm__) + // Check if the device supports the BufferMapExtendedUsages feature + return context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages); +#else + return false; +#endif // !defined(__wasm__) +} + +void BufferManager::Release(WGPUBuffer buffer) const { EnforceBufferUnmapped(context_, buffer); GetCacheManager(buffer).ReleaseBuffer(buffer); } -void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { +void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const { EnforceBufferUnmapped(context_, src); auto buffer_size = NormalizeBufferSize(size); @@ -395,7 +560,7 @@ void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { auto& command_encoder = context_.GetCommandEncoder(); context_.EndComputePass(); command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size); - context_.Flush(); + context_.Flush(*this); // TODO: revise wait in whole project @@ -405,13 +570,14 @@ void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { auto mapped_data = staging_buffer.GetConstMappedRange(); memcpy(dst, mapped_data, size); + staging_buffer.Unmap(); } -void BufferManager::RefreshPendingBuffers() { - storage_cache_->OnRefresh(); - uniform_cache_->OnRefresh(); - query_resolve_cache_->OnRefresh(); - default_cache_->OnRefresh(); +void BufferManager::RefreshPendingBuffers(GraphCaptureState graph_capture_state) const { + storage_cache_->OnRefresh(graph_capture_state); + uniform_cache_->OnRefresh(graph_capture_state); + query_resolve_cache_->OnRefresh(graph_capture_state); + default_cache_->OnRefresh(graph_capture_state); } IBufferCacheManager& BufferManager::GetCacheManager(wgpu::BufferUsage usage) const { diff --git 
a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h index b9028ad5de858..e854139496726 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.h +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -14,11 +14,20 @@ namespace webgpu { class WebGpuContext; +// For command capture and replay +enum class GraphCaptureState { + Default, + Capturing, + Replaying +}; + enum class BufferCacheMode { Disabled, LazyRelease, Simple, - Bucket + Bucket, + Graph, + GraphSimple, }; std::ostream& operator<<(std::ostream& os, BufferCacheMode mode); @@ -26,12 +35,13 @@ std::ostream& operator<<(std::ostream& os, BufferCacheMode mode); // IBufferCacheManager is an interface for buffer cache management. // // By implementing this interface, we can have different buffer cache management strategies. -// Currently, we have 3 strategies: +// Currently, we have 5 strategies: // - Disabled: no cache. always allocate a new buffer and release it immediately after use. // - LazyRelease: no cache. the difference from Disabled is that it delays the release of buffers until the next refresh. // - Simple: a simple cache that always keeps buffers. when a buffer is requested, it tries to find a buffer in the cache. // - Bucket: a cache that keeps buffers in different buckets based on the buffer size, with a maximum number of buffers in each bucket. -// +// - Graph: used for graph capturing storage buffer cache mode. All buffers will be cached. Buffers can be reused across runs and in one run. +// - GraphSimple: used for graph capturing uniform buffer cache mode. All buffers will be cached. Buffers can be reused across runs but can't be reused in one run. class IBufferCacheManager { public: virtual ~IBufferCacheManager() = default; @@ -49,7 +59,7 @@ class IBufferCacheManager { virtual void ReleaseBuffer(WGPUBuffer buffer) = 0; // when a stream refresh is requested - virtual void OnRefresh() = 0; + virtual void OnRefresh(GraphCaptureState graph_capture_state) = 0; }; // @@ -58,16 +68,16 @@ class IBufferCacheManager { class BufferManager { public: BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); - - void Upload(void* src, WGPUBuffer dst, size_t size); - void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size); - WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst); + void Upload(void* src, WGPUBuffer dst, size_t size) const; + void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const; + WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst) const; // Create a buffer mapped for writing. 
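The `CreateUMA`/`SupportsUMA` pair above is consumed by the allocator change earlier in this patch. As a compact restatement of that decision (reduced to an abstract interface rather than the real `webgpu::BufferManager`, so it compiles standalone), the idea is: before session initialization completes, prefer a mapped-at-creation UMA buffer when the device exposes `BufferMapExtendedUsages`; otherwise create a regular storage buffer.

```cpp
#include <cstddef>

// Stand-in for the subset of webgpu::BufferManager the decision needs; not the real ORT class.
struct IBufferSource {
  virtual ~IBufferSource() = default;
  virtual bool SupportsUMA() const = 0;           // device has BufferMapExtendedUsages
  virtual int Create(size_t size) const = 0;      // regular storage buffer
  virtual int CreateUMA(size_t size) const = 0;   // storage buffer mapped for direct writes
};

// Mirrors the logic GpuBufferAllocator::Alloc uses in this patch.
int AllocLikeGpuBufferAllocator(const IBufferSource& mgr, bool session_initialized, size_t size) {
  if (!session_initialized && mgr.SupportsUMA()) {
    return mgr.CreateUMA(size);  // initializers can be written through the mapping, no staging copy
  }
  return mgr.Create(size);
}
```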
- WGPUBuffer CreateUMA(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | - wgpu::BufferUsage::CopyDst); - void Release(WGPUBuffer buffer); - void Download(WGPUBuffer src, void* dst, size_t size); - void RefreshPendingBuffers(); + WGPUBuffer CreateUMA(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst) const; + // Check if CreateUMA is supported (i.e., the device has BufferMapExtendedUsages feature) + bool SupportsUMA() const; + void Release(WGPUBuffer buffer) const; + void Download(WGPUBuffer src, void* dst, size_t size) const; + void RefreshPendingBuffers(GraphCaptureState graph_capture_state) const; private: IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const; diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index 1713a9a1ad050..25caa9b954fc0 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -5,12 +5,16 @@ #include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/allocator.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" namespace onnxruntime { namespace webgpu { -ComputeContext::ComputeContext(OpKernelContext& kernel_context) +ComputeContext::ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep) : webgpu_context_{WebGpuContextFactory::GetContext(kernel_context.GetDeviceId())}, - kernel_context_{kernel_context} { + kernel_context_{kernel_context}, + ep_{ep} { } void ComputeContext::PushErrorScope() { @@ -26,5 +30,9 @@ Status ComputeContext::PopErrorScope() { return Status::OK(); } +const webgpu::BufferManager& ComputeContext::BufferManager() const { + return ep_.BufferManager(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 7a9cf1ecf85ba..fe95917e4e906 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -16,14 +16,16 @@ namespace onnxruntime { class Tensor; +class WebGpuExecutionProvider; namespace webgpu { class WebGpuContext; +class BufferManager; class ComputeContext { public: - ComputeContext(OpKernelContext& kernel_context); + ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep); virtual ~ComputeContext() = default; @@ -115,7 +117,6 @@ class ComputeContext { ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceAllocator(&allocator)); return {data_type, std::forward(shape), allocator}; } - // // Run a compute shader program. // @@ -123,6 +124,11 @@ class ComputeContext { return webgpu_context_.Run(*this, program); } + // + // Get the buffer manager from the GPU allocator. + // + const webgpu::BufferManager& BufferManager() const; + // // Push error scope. 
// @@ -140,6 +146,7 @@ class ComputeContext { protected: WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; + const WebGpuExecutionProvider& ep_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc index ac376b4fce069..6d66a7308f1de 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.cc +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/data_transfer.h" -#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/buffer_manager.h" namespace onnxruntime { namespace webgpu { @@ -25,15 +25,15 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::GPU) { // copy from GPU to GPU - context_.BufferManager().MemCpy(static_cast(const_cast(src_data)), - static_cast(dst_data), bytes); + buffer_manager_.MemCpy(static_cast(const_cast(src_data)), + static_cast(dst_data), bytes); } else { // copy from CPU to GPU - context_.BufferManager().Upload(const_cast(src_data), static_cast(dst_data), bytes); + buffer_manager_.Upload(const_cast(src_data), static_cast(dst_data), bytes); } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU - context_.BufferManager().Download(static_cast(const_cast(src_data)), dst_data, bytes); + buffer_manager_.Download(static_cast(const_cast(src_data)), dst_data, bytes); } } diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h index f9949576aa60b..0adf380149acf 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.h +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -9,11 +9,11 @@ namespace onnxruntime { namespace webgpu { -class WebGpuContext; +class BufferManager; class DataTransfer : public IDataTransfer { public: - DataTransfer(const WebGpuContext& context) : context_{context} {}; + DataTransfer(const BufferManager& buffer_manager) : buffer_manager_{buffer_manager} {}; ~DataTransfer() {}; bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; @@ -21,7 +21,7 @@ class DataTransfer : public IDataTransfer { common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; private: - const WebGpuContext& context_; + const BufferManager& buffer_manager_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 4bb41c2eb0ba6..4bd79a627df22 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -401,6 +401,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; WGPUBuffer uniform_buffer = nullptr; + const webgpu::BufferManager& buffer_mgr = context.BufferManager(); if (uniform_buffer_total_size > 0) { std::vector uniform_data_buffer(uniform_buffer_total_size); @@ -408,7 +409,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { memcpy(uniform_data_buffer.data() + offset, uniform.data.data(), uniform.data.size()); } - uniform_buffer = buffer_mgr_->Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + uniform_buffer = 
buffer_mgr.Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); device_queue_.WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); } @@ -429,13 +430,11 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { } LaunchComputePipeline(compute_pass_encoder, bind_buffers, *program_artifact, x, y, z); - if (uniform_buffer) { - buffer_mgr_->Release(uniform_buffer); + buffer_mgr.Release(uniform_buffer); } WriteTimestamp(num_pending_dispatches_ * 2 + 1); - ++num_pending_dispatches_; if (num_pending_dispatches_ >= max_num_pending_dispatches_ || @@ -443,7 +442,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { EndComputePass(); } if (num_pending_dispatches_ >= max_num_pending_dispatches_) { - Flush(); + Flush(buffer_mgr); num_pending_dispatches_ = 0; } @@ -659,7 +658,7 @@ Status WebGpuContext::PopErrorScope() { return status; } -void WebGpuContext::Flush() { +void WebGpuContext::Flush(const webgpu::BufferManager& buffer_mgr) { if (!current_command_encoder_) { return; } @@ -690,10 +689,11 @@ void WebGpuContext::Flush() { pending_queries_.emplace_back(std::move(pending_kernels_), query_read_buffer); pending_kernels_.clear(); } - auto command_buffer = current_command_encoder_.Finish(); device_queue_.Submit(1, &command_buffer); - BufferManager().RefreshPendingBuffers(); + if (graph_capture_state_ != GraphCaptureState::Replaying) { + buffer_mgr.RefreshPendingBuffers(graph_capture_state_); + } current_command_encoder_ = nullptr; num_pending_dispatches_ = 0; } @@ -724,15 +724,90 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput bind_group_desc.label = {program_artifact.name.data(), program_artifact.name.length()}; auto bind_group = wgpuDeviceCreateBindGroup(Device().Get(), &bind_group_desc); + if (graph_capture_state_ == GraphCaptureState::Capturing) { + external_captured_commands_->push_back({program_artifact.compute_pipeline, + bind_group, + bind_group_layout, + {x, y, z}}); + } else { + compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline); + wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr); + compute_pass_encoder.DispatchWorkgroups(x, y, z); + + wgpuBindGroupRelease(bind_group); + wgpuBindGroupLayoutRelease(bind_group_layout); + } +} + +void WebGpuContext::CaptureBegin(std::vector* captured_commands, const webgpu::BufferManager& buffer_manager) { + LOGS_DEFAULT(VERBOSE) << "CaptureBegin with external storage"; + // Flush any pending commands before we change the status + Flush(buffer_manager); + + external_captured_commands_ = captured_commands; + + // Make sure the external vector is empty before we start capturing + if (external_captured_commands_) { + external_captured_commands_->clear(); + } + + // TODO: support profiling with graph capture. 
+ ORT_ENFORCE(!is_profiling_, "profiling is not supported yet under graph capture mode"); + + graph_capture_state_ = GraphCaptureState::Capturing; +} + +void WebGpuContext::Replay(const std::vector& captured_commands, const webgpu::BufferManager& buffer_manager) { + LOGS_DEFAULT(VERBOSE) << "Replay with external storage"; + graph_capture_state_ = GraphCaptureState::Replaying; + // Replay all captured commands from the provided vector + const size_t command_count = captured_commands.size(); + for (size_t i = 0; i < command_count; ++i) { + auto& command = captured_commands[i]; + const auto& compute_pass_encoder = GetComputePassEncoder(); + WriteTimestamp(num_pending_dispatches_ * 2); + compute_pass_encoder.SetPipeline(command.compute_pipeline); + wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, command.bind_group, 0, nullptr); + compute_pass_encoder.DispatchWorkgroups(command.dispatch_group[0], command.dispatch_group[1], command.dispatch_group[2]); + WriteTimestamp(num_pending_dispatches_ * 2 + 1); + ++num_pending_dispatches_; + if (num_pending_dispatches_ >= max_num_pending_dispatches_ || + (is_profiling_ && query_type_ == TimestampQueryType::AtPasses)) { + EndComputePass(); + } + if (num_pending_dispatches_ >= max_num_pending_dispatches_) { + Flush(buffer_manager); + num_pending_dispatches_ = 0; + } + } + + // Flush any remaining commands + Flush(buffer_manager); + + graph_capture_state_ = GraphCaptureState::Default; +} + +void WebGpuContext::CaptureEnd() { + LOGS_DEFAULT(VERBOSE) << "CaptureEnd"; - // TODO support graph capture + graph_capture_state_ = GraphCaptureState::Default; + external_captured_commands_ = nullptr; +} - compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline); - wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr); - compute_pass_encoder.DispatchWorkgroups(x, y, z); +void WebGpuContext::ReleaseGraphResources(std::vector& captured_commands) { + LOGS_DEFAULT(VERBOSE) << "ReleaseGraphResources: Releasing " << captured_commands.size() << " captured command resources"; - wgpuBindGroupRelease(bind_group); - wgpuBindGroupLayoutRelease(bind_group_layout); + for (auto& command : captured_commands) { + if (command.bind_group != nullptr) { + wgpuBindGroupRelease(command.bind_group); + command.bind_group = nullptr; + } + + if (command.bind_group_layout != nullptr) { + wgpuBindGroupLayoutRelease(command.bind_group_layout); + command.bind_group_layout = nullptr; + } + } } std::unordered_map WebGpuContextFactory::contexts_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 4111f809b1627..3084483db522d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -10,7 +10,6 @@ #include "core/common/common.h" #include "core/framework/library_handles.h" -#include "core/providers/webgpu/webgpu_execution_provider.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/program_manager.h" @@ -26,6 +25,14 @@ class WebGpuContext; class ComputeContext; class ProgramBase; +// Definition for CapturedCommandInfo in the webgpu namespace +struct CapturedCommandInfo { + wgpu::ComputePipeline compute_pipeline; + WGPUBindGroup bind_group; + WGPUBindGroupLayout bind_group_layout; + std::array dispatch_group; +}; + struct WebGpuContextConfig { int context_id; WGPUInstance instance; @@ -118,8 +125,12 @@ class WebGpuContext final { current_compute_pass_encoder_ = nullptr; } } + 
void CaptureBegin(std::vector* captured_commands, const webgpu::BufferManager& buffer_manager); + void CaptureEnd(); + void Replay(const std::vector& captured_commands, const webgpu::BufferManager& buffer_manager); + void ReleaseGraphResources(std::vector& captured_commands); - void Flush(); + void Flush(const webgpu::BufferManager& buffer_mgr); webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; } @@ -243,6 +254,10 @@ class WebGpuContext final { uint64_t gpu_timestamp_offset_ = 0; bool is_profiling_ = false; bool preserve_device_; + GraphCaptureState graph_capture_state_{GraphCaptureState::Default}; + + // External vector to store captured commands, owned by EP + std::vector* external_captured_commands_ = nullptr; #if defined(ENABLE_PIX_FOR_WEBGPU_EP) std::unique_ptr pix_frame_generator_ = nullptr; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 13c746a6b1d31..460d220ecf1b9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -772,11 +772,21 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, context_{context}, preferred_data_layout_{config.data_layout}, force_cpu_node_names_{std::move(config.force_cpu_node_names)}, - enable_graph_capture_{config.enable_graph_capture} {} + enable_graph_capture_{config.enable_graph_capture} { + // If graph capture is enabled, create a dedicated buffer manager for graph mode + if (enable_graph_capture_) { + // Create buffer manager for graph capture mode with appropriate cache modes + graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create( + context_, + webgpu::BufferCacheMode::Graph, + webgpu::BufferCacheMode::GraphSimple, + webgpu::BufferCacheMode::Disabled); + } +} std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) { - return std::make_unique(context_); + return std::make_unique(BufferManager()); }, 0, false); auto preferred_allocators = std::vector{CreateAllocator(gpuBufferAllocatorCreationInfo)}; @@ -846,7 +856,7 @@ std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() con } std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { - return std::make_unique(context_); + return std::make_unique(BufferManager()); } #if defined(__wasm__) @@ -871,6 +881,12 @@ std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s } WebGpuExecutionProvider::~WebGpuExecutionProvider() { + // Release all resources associated with the captured graph + if (!captured_commands_.empty()) { + context_.ReleaseGraphResources(captured_commands_); + } + // The graph_buffer_mgr_ will be automatically cleaned up by unique_ptr + WebGpuContextFactory::ReleaseContext(context_id_); } @@ -897,23 +913,24 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_ } if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - ORT_NOT_IMPLEMENTED("graph capture not implemented"); + context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); } + return Status::OK(); } Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { + context_.Flush(BufferManager()); + if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { - ORT_NOT_IMPLEMENTED("graph capture not implemented"); - // is_graph_captured_ = true; + context_.CaptureEnd(); 
+ is_graph_captured_ = true; } else { IncrementRegularRunCountBeforeGraphCapture(); } } - context_.Flush(); - if (profiler_->Enabled()) { context_.CollectProfilingData(profiler_->Events()); } @@ -937,10 +954,18 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int) const { Status WebGpuExecutionProvider::ReplayGraph(int) { ORT_ENFORCE(IsGraphCaptured(0)); - ORT_ENFORCE(false); + context_.Replay(captured_commands_, *graph_buffer_mgr_); return Status::OK(); } +webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const { + if (graph_buffer_mgr_) { + return *graph_buffer_mgr_; + } else { + return context_.BufferManager(); + } +} + bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 2003f9b2ebcc6..2567be2a1eb18 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -8,6 +8,7 @@ #include "core/framework/session_options.h" #include "core/graph/constants.h" #include "core/providers/providers.h" +#include "core/providers/webgpu/buffer_manager.h" struct pthreadpool; namespace onnxruntime { @@ -18,9 +19,11 @@ template KernelCreateInfo BuildKernelCreateInfo(); class WebGpuContext; -enum class BufferCacheMode; class WebGpuProfiler; class GpuBufferAllocator; + +// Forward declare CapturedCommandInfo which is now defined in webgpu_context.h +struct CapturedCommandInfo; } // namespace webgpu struct WebGpuExecutionProviderConfig { @@ -81,10 +84,12 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured(int graph_annotation_id) const override; Status ReplayGraph(int graph_annotation_id) override; + webgpu::BufferManager& BufferManager() const; private: bool IsGraphCaptureAllowed() const; void IncrementRegularRunCountBeforeGraphCapture(); + int context_id_; webgpu::WebGpuContext& context_; webgpu::WebGpuProfiler* profiler_ = nullptr; @@ -95,6 +100,12 @@ class WebGpuExecutionProvider : public IExecutionProvider { int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. 
webgpu::GpuBufferAllocator* allocator_ = nullptr; + + // Buffer manager specifically for graph capture mode + std::unique_ptr graph_buffer_mgr_ = nullptr; + + // Store captured commands directly in the EP instead of in WebGpuContext + std::vector captured_commands_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index d7682e751d9e4..e37be2944a22b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -9,6 +9,8 @@ #include "core/framework/op_kernel.h" namespace onnxruntime { + +class WebGpuExecutionProvider; namespace webgpu { // ----------------------------------------------------------------------- @@ -17,11 +19,12 @@ namespace webgpu { class WebGpuKernel : public OpKernel { public: explicit WebGpuKernel(const OpKernelInfo& info) - : OpKernel(info) { + : OpKernel(info), + ep_(*static_cast(info.GetExecutionProvider())) { } Status Compute(OpKernelContext* p_op_kernel_context) const override { - ComputeContext context{*p_op_kernel_context}; + ComputeContext context{*p_op_kernel_context, ep_}; context.PushErrorScope(); Status s = ComputeInternal(context); @@ -31,6 +34,9 @@ class WebGpuKernel : public OpKernel { } virtual Status ComputeInternal(ComputeContext& context) const = 0; + + private: + const WebGpuExecutionProvider& ep_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index d6812b2d0704d..80b3988215c6b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -220,10 +220,12 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( webgpu::WebGpuBufferCacheConfig buffer_cache_config; - buffer_cache_config.storage.mode = parse_buffer_cache_mode(kStorageBufferCacheMode, webgpu::BufferCacheMode::Bucket); + buffer_cache_config.storage.mode = parse_buffer_cache_mode(kStorageBufferCacheMode, + webgpu::BufferCacheMode::Bucket); LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << buffer_cache_config.storage.mode; - buffer_cache_config.uniform.mode = parse_buffer_cache_mode(kUniformBufferCacheMode, webgpu::BufferCacheMode::Simple); + buffer_cache_config.uniform.mode = parse_buffer_cache_mode(kUniformBufferCacheMode, + webgpu::BufferCacheMode::Simple); LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << buffer_cache_config.uniform.mode; buffer_cache_config.query_resolve.mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); diff --git a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc index 521aa4a4bfc5a..111d03571e974 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc @@ -100,20 +100,25 @@ Status MatMulNBitsBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // x_zero_point has the same shape as x_scale const bool has_zero_points = TensorExists(input_defs, 3); emscripten::val x_zero_point = emscripten::val::undefined(); + emscripten::val zero_points_desc = emscripten::val::object(); + zero_points_desc.set("dataType", emscripten::val("uint4")); + zero_points_desc.set("shape", x_scale_shape_array); + zero_points_desc.set("dimensions", 
x_scale_shape_array); if (has_zero_points) { // zero_points is an initializer with data type 'uint8', we need to register it as 'uint4' WebNN constant const auto zero_points_tensor = *initializers.at(input_defs[3]->Name()); - emscripten::val zero_points_desc = emscripten::val::object(); - zero_points_desc.set("dataType", emscripten::val("uint4")); - zero_points_desc.set("shape", x_scale_shape_array); - zero_points_desc.set("dimensions", x_scale_shape_array); ORT_RETURN_IF_ERROR(model_builder.RegisterConstant(zero_points_tensor, x_zero_point, zero_points_desc, logger)); } else { // zero_points' default value is 8, referred from CPU EP const int8_t default_zero_point = 8; - x_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT4, - default_zero_point, - x_scale_shape); + // Always create a new WebNN constant for zero_points to facilitate MatMulNBits fusion in Chromium + auto num_elements = (Product(x_scale_shape) + 1) / 2; + emscripten::val default_zero_point_buffer = emscripten::val::global("Uint8Array").new_(num_elements); + default_zero_point_buffer.call("fill", + emscripten::val(PackInt8ToUint8DoubledNibbles( + default_zero_point, ONNX_NAMESPACE::TensorProto_DataType_UINT4))); + x_zero_point = + model_builder.GetBuilder().call("constant", zero_points_desc, default_zero_point_buffer); } // DequantizeLinear diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc index bbadfbee70656..ad965845041f7 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/ep_api.cc @@ -5,6 +5,8 @@ #include #include + +#include "core/common/semver.h" #include "core/framework/error_code_helper.h" #include "core/framework/func_api.h" #include "core/framework/ort_value.h" @@ -14,6 +16,7 @@ #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" #include "core/session/abi_ep_types.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/ort_apis.h" using namespace onnxruntime; @@ -34,6 +37,21 @@ ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, ep_device->ep_metadata = *ep_metadata; } + // Add EP version from OrtEpFactory to metadata. OrtEpFactory::GetVersion is supported since 1.23. 
+ if (ep_factory->ort_version_supported >= uint32_t{23}) { + if (ep_device->ep_metadata.Entries().find(kOrtEpDevice_EpMetadataKey_Version) != + ep_device->ep_metadata.Entries().end()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "The provided EP metadata should not explicitly specify the EP version."); + } + + { + std::string ep_version = ep_factory->GetVersion(ep_factory); + ORT_API_RETURN_IF_STATUS_NOT_OK(ParseSemVerVersion(ep_version, nullptr)); + ep_device->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_Version, std::move(ep_version)); + } + } + if (ep_options) { ep_device->ep_options = *ep_options; } diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index 366f934fc610e..daccd24453371 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -16,6 +16,10 @@ struct ForwardToFactory { return static_cast(this_ptr)->GetVendor(); } + static const char* ORT_API_CALL GetVersion(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->GetVersion(); + } + static OrtStatus* ORT_API_CALL GetSupportedDevices(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index b906f25935983..b289010cc6c5b 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -8,6 +8,7 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/ep_api_utils.h" #include "core/session/ort_apis.h" +#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { @@ -24,11 +25,16 @@ EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::stri OrtEpFactory::GetName = Forward::GetFactoryName; OrtEpFactory::GetVendor = Forward::GetVendor; + OrtEpFactory::GetVersion = Forward::GetVersion; OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; OrtEpFactory::ReleaseEp = Forward::ReleaseEp; } +const char* EpFactoryInternal::GetVersion() const noexcept { + return ORT_VERSION; +} + OrtStatus* EpFactoryInternal::GetSupportedDevices(const OrtHardwareDevice* const* devices, size_t num_devices, OrtEpDevice** ep_devices, diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index 1951b51a38bee..087c0c60f8f4e 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -39,6 +39,7 @@ class EpFactoryInternal : public OrtEpFactory { const char* GetName() const noexcept { return ep_name_.c_str(); } const char* GetVendor() const noexcept { return vendor_.c_str(); } + const char* GetVersion() const noexcept; OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 8cd16fb4e7347..3db35ae8769e0 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1389,8 +1389,9 @@ struct ProviderHostImpl : ProviderHost { ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args, - int execution_order) noexcept override { - GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast(execution_order)); + int execution_order, + bool 
+    GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast<ExecutionOrder>(execution_order), include_initializer_data);
   }
   const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); }
   IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const override { return p->GetSchemaRegistry(); }
diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index b0a78281041d0..148e4c06a8051 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -28,6 +28,7 @@
     get_qmin_qmax_for_qType,
     get_qrange_for_qType,
     ms_domain,
+    quantize_onnx_initializer,
     save_and_reload_model_with_shape_infer,
     tensor_proto_to_array,
 )
@@ -635,6 +636,137 @@ def find_quantized_value(self, input_name):
             return self.parent.find_quantized_value(input_name)
         return None
 
+    def adjust_single_weight_scale_if_needed(
+        self,
+        bias_val,
+        input_scale,
+        weight_scale,
+        weight_scale_dtype,
+        weight_name,
+        bias_name,
+        qrange,
+        multiplicative_epsilon,
+        idx=None,
+    ):
+        """Adjust a single weight scale to ensure the int32 bias does not overflow."""
+        absmax = np.abs(bias_val)
+        bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange
+
+        input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64)
+        weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64)
+        bias_candidate_scale = input_scale_fp64 * weight_scale_fp64
+
+        if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0):
+            ratio = bias_smallest_valid_scale / bias_candidate_scale
+            new_scale = weight_scale_fp64 * ratio
+            if idx is None:
+                logging.info(
+                    f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to "
+                    f"ensure bias `{bias_name}` has a valid scale."
+                )
+                return True, np.array(new_scale, dtype=weight_scale_dtype)
+            else:
+                logging.info(
+                    f"Increased scale[{idx}] for weight `{weight_name}` by ratio {ratio} "
+                    f"to ensure bias `{bias_name}` has a valid scale."
+ ) + return True, new_scale.astype(weight_scale_dtype) + return False, weight_scale + + def _adjust_weight_scale_for_int32_bias( + self, + input_scale: np.ndarray, + weight_scale: np.ndarray, + weight_name: str, + bias_tp: onnx.TensorProto, + is_per_channel: bool, + ) -> tuple[bool, np.ndarray | None]: + """Checks if the bias scale is too small and increases the weight scale if needed.""" + + if not weight_scale.size: + return False, None + + bias_float_data = tensor_proto_to_array(bias_tp) + int32_info = np.iinfo(np.int32) + multiplicative_epsilon = 1.0001 + qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64) + weight_scale_dtype = weight_scale.dtype + updated = False + + if not is_per_channel: + rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64)) + rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64)) + absmax = np.maximum(np.abs(rmin), np.abs(rmax)) + changed, new_scale = self.adjust_single_weight_scale_if_needed( + absmax, + input_scale, + weight_scale, + weight_scale_dtype, + weight_name, + bias_tp.name, + qrange, + multiplicative_epsilon, + ) + if changed: + weight_scale = new_scale + updated = True + elif weight_scale.shape and len(weight_scale.shape) == 1: + for i in range(weight_scale.shape[0]): + changed, new_scale = self.adjust_single_weight_scale_if_needed( + bias_float_data[i], + input_scale, + weight_scale[i], + weight_scale_dtype, + weight_name, + bias_tp.name, + qrange, + multiplicative_epsilon, + idx=i, + ) + if changed: + weight_scale[i] = new_scale + updated = True + + return updated, weight_scale + + def _requantize_weight(self, weight_name: str, new_scale: np.ndarray) -> None: + """Re-quantizes the given weight initializer using the provided scale.""" + + if weight_name not in self.quantized_value_map: + return + + qv = self.quantized_value_map[weight_name] + + weight_tp = find_by_name(weight_name, self.model.initializer()) + scale_init = find_by_name(qv.scale_name, self.model.initializer()) + zp_init = find_by_name(qv.zp_name, self.model.initializer()) + q_weight_init = find_by_name(qv.q_name, self.model.initializer()) + + if weight_tp is None or scale_init is None or zp_init is None or q_weight_init is None: + return + + self.model.remove_initializer(scale_init) + self.model.remove_initializer(q_weight_init) + + weight_zero_point = onnx.numpy_helper.to_array(zp_init) + axis = qv.axis + + # Add new scale initializer + scale_np = np.asarray(new_scale, dtype=onnx.helper.tensor_dtype_to_np_dtype(weight_tp.data_type)) + new_scale_init = onnx.numpy_helper.from_array(scale_np.reshape(scale_init.dims), qv.scale_name) + self.model.add_initializer(new_scale_init) + + # Add new quantized weight initializer + new_q_weight = quantize_onnx_initializer( + weight_tp, + self.weight_qType, + weight_zero_point, + scale_np, + axis, + quant_weight_name=qv.q_name, + ) + self.model.add_initializer(new_q_weight) + def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): """ Quantized the bias. 
        Zero Point == 0 and Scale == Input_Scale * Weight_Scale
@@ -660,6 +792,29 @@ def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0):
         inputscale_initializer = find_by_name(input_scale_name, self.model.initializer())
         input_scale = tensor_proto_to_array(inputscale_initializer)
 
+        # Adjust the weight scale if quantizing the bias to int32 could overflow due to a too-small scale
+        weight_zp_name = self.quantized_value_map[weight_name].zp_name
+        weight_zp_init = find_by_name(weight_zp_name, self.model.initializer())
+        weight_zero_point = onnx.numpy_helper.to_array(weight_zp_init) if weight_zp_init is not None else None
+        is_per_channel = self.per_channel
+        if (
+            weight_zero_point is not None
+            and weight_zero_point.size
+            and not weight_zero_point.any()
+            and self.weight_qType in (onnx_proto.TensorProto.INT8,)
+        ):
+            bias_initializer = find_by_name(bias_name, self.model.initializer())
+            did_update, new_weight_scale = self._adjust_weight_scale_for_int32_bias(
+                input_scale,
+                weight_scale,
+                weight_name,
+                bias_initializer,
+                is_per_channel,
+            )
+            if did_update:
+                self._requantize_weight(weight_name, new_weight_scale)
+                weight_scale = new_weight_scale
+
         (
             quantized_bias_name,
             quantized_bias_scale_name,
diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc
index c2fa5ec88a0d8..d4895102b0bf1 100644
--- a/onnxruntime/test/autoep/library/ep_factory.cc
+++ b/onnxruntime/test/autoep/library/ep_factory.cc
@@ -14,6 +14,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis)
   ort_version_supported = ORT_API_VERSION;  // set to the ORT version we were compiled with.
   GetName = GetNameImpl;
   GetVendor = GetVendorImpl;
+  GetVersion = GetVersionImpl;
 
   GetSupportedDevices = GetSupportedDevicesImpl;
@@ -86,6 +87,12 @@ const char* ORT_API_CALL ExampleEpFactory::GetVendorImpl(const OrtEpFactory* thi
   return factory->vendor_.c_str();
 }
 
+/*static*/
+const char* ORT_API_CALL ExampleEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept {
+  const auto* factory = static_cast<const ExampleEpFactory*>(this_ptr);
+  return factory->ep_version_.c_str();
+}
+
 /*static*/
 OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* this_ptr,
                                                                   const OrtHardwareDevice* const* devices,
@@ -107,7 +114,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory*
   factory->ort_api.CreateKeyValuePairs(&ep_options);
 
   // random example using made up values
-  factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1");
+  factory->ort_api.AddKeyValuePair(ep_metadata, "supported_devices", "CrackGriffin 7+");
   factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true");
 
   // OrtEpDevice copies ep_metadata and ep_options.
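For third-party plugin EP authors, the factory-side wiring demonstrated by the example library above reduces to implementing GetVersion next to GetName and GetVendor. The following is only a hypothetical sketch (the class name MyEpFactory, the "my_ep"/"MyVendor"/"1.0.0" strings, and the header includes are illustrative and not part of this change). As the registration code earlier in this patch shows, ORT parses the returned string with ParseSemVerVersion and publishes it under kOrtEpDevice_EpMetadataKey_Version, so the factory must not add a "version" entry to ep_metadata itself.

// Minimal, hypothetical plugin factory sketch; only the string-returning callbacks are shown.
#include <string>

#include "onnxruntime_c_api.h"
#include "onnxruntime_ep_c_api.h"

struct MyEpFactory : OrtEpFactory {
  MyEpFactory() {
    ort_version_supported = ORT_API_VERSION;
    GetName = GetNameImpl;
    GetVendor = GetVendorImpl;
    GetVersion = GetVersionImpl;  // new callback; the value must be a SemVer 2.0 string
    // GetSupportedDevices, CreateEp, ReleaseEp, ... omitted here for brevity.
  }

  static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept {
    return static_cast<const MyEpFactory*>(this_ptr)->name_.c_str();
  }

  static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept {
    return static_cast<const MyEpFactory*>(this_ptr)->vendor_.c_str();
  }

  // ORT validates this string with ParseSemVerVersion and stores it as the "version"
  // (kOrtEpDevice_EpMetadataKey_Version) entry of the OrtEpDevice EP metadata.
  static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept {
    return static_cast<const MyEpFactory*>(this_ptr)->version_.c_str();
  }

  std::string name_{"my_ep"};
  std::string vendor_{"MyVendor"};
  std::string version_{"1.0.0"};
};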
@@ -136,7 +143,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { // Ort::KeyValuePairs ep_metadata; // Ort::KeyValuePairs ep_options; - // ep_metadata.Add("version", "0.1"); + // ep_metadata.Add("supported_devices", "CrackGriffin 7+"); // ep_options.Add("run_really_fast", "true"); // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; // ep_devices[num_ep_devices++] = ep_device.release(); diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 8ab67fc9d8ce6..fda77f12c4814 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -22,6 +22,8 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, @@ -49,8 +51,9 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept; - const std::string ep_name_; // EP name - const std::string vendor_{"Contoso"}; // EP vendor name + const std::string ep_name_; // EP name + const std::string vendor_{"Contoso"}; // EP vendor name + const std::string ep_version_{"0.1.0"}; // EP version // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. using MemoryInfoUniquePtr = std::unique_ptr>; diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index be20d2c7c5a60..01dece34e50b0 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -14,6 +14,7 @@ #include "core/session/abi_key_value_pairs.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "test_allocator.h" #include "test/shared_lib/utils.h" @@ -564,7 +565,8 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); auto metadata = test_ep_device->EpMetadata(); - ASSERT_STREQ(metadata.GetValue("version"), "0.1"); + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); + ASSERT_STREQ(metadata.GetValue("supported_devices"), "CrackGriffin 7+"); auto options = test_ep_device->EpOptions(); ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); diff --git a/onnxruntime/test/common/semver_test.cc b/onnxruntime/test/common/semver_test.cc new file mode 100644 index 0000000000000..5ec066e59b838 --- /dev/null +++ b/onnxruntime/test/common/semver_test.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/common/semver.h" + +#include "gtest/gtest.h" + +#include "test/util/include/asserts.h" + +namespace onnxruntime::test { + +TEST(SemVerParsingTest, Basic) { + { + auto semver = ParseSemVerVersion("1.2.3-abcde+fghij"); + EXPECT_EQ(semver.major, 1); + EXPECT_EQ(semver.minor, 2); + EXPECT_EQ(semver.patch, 3); + EXPECT_EQ(semver.prerelease, "abcde"); + EXPECT_EQ(semver.build_metadata, "fghij"); + } + + { + auto semver = ParseSemVerVersion("1.2.3"); + EXPECT_EQ(semver.major, 1); + EXPECT_EQ(semver.minor, 2); + EXPECT_EQ(semver.patch, 3); + EXPECT_EQ(semver.prerelease, std::nullopt); + EXPECT_EQ(semver.build_metadata, std::nullopt); + } +} + +TEST(SemVerParsingTest, Invalid) { + SemVerVersion semver{}; + ASSERT_STATUS_NOT_OK(ParseSemVerVersion("version one point zero", &semver)); +} + +} // namespace onnxruntime::test diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 2ce3c4859394d..add9fa6a504c9 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -217,7 +217,7 @@ static void CreateMatMulModel(std::unique_ptr& p_model, Prov if (provider_type == kCpuExecutionProvider) { node.SetExecutionProviderType(provider_type); } else { -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) node.SetExecutionProviderType(provider_type); #endif } @@ -286,55 +286,89 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, bool is_preallocate_output_vec, ProviderType allocation_provider, IExecutionProvider* gpu_provider, - OrtDevice* output_device) { + OrtDevice* output_device, + bool enable_graph_capture) { std::unique_ptr io_binding; Status st = session_object.NewIOBinding(&io_binding); ASSERT_TRUE(st.IsOK()); - auto input_allocator = io_binding->GetCPUAllocator(bind_provider_type); // bind a value to A with input that will produce invalid output in order to test replacement of a feed std::vector values_mul_x_tmp = {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}; std::vector dims_mul_x_A_tmp = {3, 4}; - OrtValue input_tmp; - CreateMLValue(input_allocator, dims_mul_x_A_tmp, values_mul_x_tmp, &input_tmp); - ASSERT_STATUS_OK(io_binding->BindInput("A", input_tmp)); - const void* tmp_A = io_binding->GetInputs()[0].Get().DataRaw(); // location of data post binding - - // prepare inputs std::vector values_mul_x = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; - - /* - 0 1 2 3 0 1 2 - 4 5 6 7 3 4 5 - 8 9 10 11 6 7 8 - 9 10 11 - */ - // bind one input to cpu allocator from bind_provider_type, and another on user provided CPU memory - // so both code pathes are covered - OrtValue input_ml_value_A; std::vector dims_mul_x_A = {3, 4}; - CreateMLValue(input_allocator, dims_mul_x_A, values_mul_x, &input_ml_value_A); - - OrtValue input_ml_value_B; std::vector dims_mul_x_B = {4, 3}; - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_mul_x_B, values_mul_x, - &input_ml_value_B); - - ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); - ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); - // check location of 'A' post-binding has changed to validate that the previous value was replaced - ASSERT_TRUE(io_binding->GetInputs()[0].Get().DataRaw() != tmp_A); + auto cpu_alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + onnxruntime::AllocatorPtr gpu_alloc = nullptr; + if 
(allocation_provider == kWebGpuExecutionProvider) { + // Use session_object.GetAllocator to get the OrtAllocator for WebGPU. + // Otherwise, gpu_provider->CreatePreferredAllocators() will create a new OrtAllocator which will go to the create UMA path. + // And it can't be used for copying buffer to buffer since the target buffer is still in mapped state. + OrtMemoryInfo mem_info(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)); + gpu_alloc = session_object.GetAllocator(mem_info); + } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider) { + gpu_alloc = gpu_provider->CreatePreferredAllocators()[0]; + } + if (enable_graph_capture) { + // For graph capture, all inputs/outputs should be in preallocated gpu memory. + ASSERT_TRUE(is_preallocate_output_vec); + OrtValue input_ml_value_A_cpu; + CreateMLValue(cpu_alloc, dims_mul_x_A, values_mul_x, &input_ml_value_A_cpu); + auto& cpu_tensor_a = input_ml_value_A_cpu.Get(); + Tensor gpu_tensor_a(cpu_tensor_a.DataType(), cpu_tensor_a.Shape(), gpu_alloc); + st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_a, gpu_tensor_a); + ASSERT_TRUE(st.IsOK()); + OrtValue input_ml_value_A; + Tensor::InitOrtValue(std::move(gpu_tensor_a), input_ml_value_A); + + OrtValue input_ml_value_B_cpu; + CreateMLValue(cpu_alloc, dims_mul_x_B, values_mul_x, &input_ml_value_B_cpu); + auto& cpu_tensor_b = input_ml_value_B_cpu.Get(); + Tensor gpu_tensor_b(cpu_tensor_b.DataType(), cpu_tensor_b.Shape(), gpu_alloc); + st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_b, gpu_tensor_b); + ASSERT_TRUE(st.IsOK()); + OrtValue input_ml_value_B; + Tensor::InitOrtValue(std::move(gpu_tensor_b), input_ml_value_B); + ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); + ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); + } else { + auto input_allocator = io_binding->GetCPUAllocator(bind_provider_type); + OrtValue input_tmp; + CreateMLValue(input_allocator, dims_mul_x_A_tmp, values_mul_x_tmp, &input_tmp); + ASSERT_STATUS_OK(io_binding->BindInput("A", input_tmp)); + const void* tmp_A = io_binding->GetInputs()[0].Get().DataRaw(); // location of data post binding + + // prepare inputs + /* + 0 1 2 3 0 1 2 + 4 5 6 7 3 4 5 + 8 9 10 11 6 7 8 + 9 10 11 + */ + // bind one input to cpu allocator from bind_provider_type, and another on user provided CPU memory + // so both code pathes are covered + OrtValue input_ml_value_A; + CreateMLValue(input_allocator, dims_mul_x_A, values_mul_x, &input_ml_value_A); + + OrtValue input_ml_value_B; + CreateMLValue(cpu_alloc, dims_mul_x_B, values_mul_x, &input_ml_value_B); + + ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); + ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); + + // check location of 'A' post-binding has changed to validate that the previous value was replaced + ASSERT_TRUE(io_binding->GetInputs()[0].Get().DataRaw() != tmp_A); + } // prepare outputs std::vector expected_output_dims = {3, 3}; OrtValue output_ml_value; if (is_preallocate_output_vec) { if (allocation_provider == kCpuExecutionProvider) { - AllocateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], expected_output_dims, - &output_ml_value); - } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider) { - AllocateMLValue(gpu_provider->CreatePreferredAllocators()[0], expected_output_dims, &output_ml_value); + 
AllocateMLValue(cpu_alloc, expected_output_dims, &output_ml_value); + } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { + AllocateMLValue(gpu_alloc, expected_output_dims, &output_ml_value); } else { ORT_THROW("Unsupported provider"); } @@ -351,6 +385,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, // prepare expected inputs and outputs std::vector expected_values_mul_y = {42, 48, 54, 114, 136, 158, 186, 224, 262}; + std::vector expected_values_mul_y_2 = {174, 216, 258, 102, 128, 154, 30, 40, 50}; // Now run st = session_object.Run(run_options, *io_binding.get()); @@ -358,24 +393,24 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; ASSERT_TRUE(st.IsOK()); - if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider)) || + if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider)) || (output_device && output_device->Type() == OrtDevice::GPU)) { -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) // in this case we need to copy the tensor from cuda to cpu std::vector& outputs = io_binding->GetOutputs(); ASSERT_EQ(1u, outputs.size()); auto& rtensor = outputs.front().Get(); auto element_type = rtensor.DataType(); auto& shape = rtensor.Shape(); - auto cpu_allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; - std::unique_ptr cpu_tensor = std::make_unique(element_type, - shape, - cpu_allocator); + std::unique_ptr cpu_tensor = std::make_unique(element_type, shape, cpu_alloc); #ifdef USE_CUDA st = GetProviderInfo_CUDA().CreateGPUDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); #endif #ifdef USE_ROCM st = GetProviderInfo_ROCM().CreateGPUDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); +#endif +#ifdef USE_WEBGPU + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); #endif ASSERT_TRUE(st.IsOK()); OrtValue ml_value; @@ -385,11 +420,40 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y); #endif } else { - if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider) { + if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { ASSERT_STATUS_OK(gpu_provider->Sync()); } VerifyOutputs(io_binding->GetOutputs(), expected_output_dims, expected_values_mul_y); } + + if (enable_graph_capture) { + // Update input_a's value. Run again. 
Replay the captured graph + OrtValue input_a2; + CreateMLValue(cpu_alloc, dims_mul_x_A_tmp, values_mul_x_tmp, &input_a2); + auto& cpu_tensor_a2 = input_a2.Get(); + st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_a2, const_cast(io_binding->GetInputs()[0].Get())); + ASSERT_TRUE(st.IsOK()); + + st = session_object.Run(run_options, *io_binding.get()); + + std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; + ASSERT_TRUE(st.IsOK()); + + // Copy the tensor from gpu to cpu + std::vector& outputs = io_binding->GetOutputs(); + ASSERT_EQ(1u, outputs.size()); + auto& rtensor = outputs.front().Get(); + auto element_type = rtensor.DataType(); + auto& shape = rtensor.Shape(); + std::unique_ptr cpu_tensor = std::make_unique(element_type, shape, cpu_alloc); + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); + ASSERT_TRUE(st.IsOK()); + OrtValue ml_value; + ml_value.Init(cpu_tensor.release(), + DataTypeImpl::GetType(), + DataTypeImpl::GetType()->GetDeleteFunc()); + VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y_2); + } } TEST(InferenceSessionTests, NoTimeout) { @@ -1059,16 +1123,16 @@ static void TestBindHelper(const std::string& log_str, ProviderType run_provider_type, bool preallocate_output, ProviderType allocation_provider = kCpuExecutionProvider, - OrtDevice* output_device = nullptr) { + OrtDevice* output_device = nullptr, + bool enable_graph_capture = false) { SessionOptions so; so.session_logid = "InferenceSessionTests." + log_str; so.session_log_verbosity_level = 1; // change to 1 for detailed logging - InferenceSession session_object{so, GetEnvironment()}; IExecutionProvider* gpu_provider{}; - if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kRocmExecutionProvider) { + if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kRocmExecutionProvider || bind_provider_type == kWebGpuExecutionProvider) { #ifdef USE_CUDA auto provider = DefaultCudaExecutionProvider(); gpu_provider = provider.get(); @@ -1078,6 +1142,15 @@ static void TestBindHelper(const std::string& log_str, auto provider = DefaultRocmExecutionProvider(); gpu_provider = provider.get(); ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); +#endif +#ifdef USE_WEBGPU + ConfigOptions config_options{}; + ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kEnableGraphCapture, + enable_graph_capture ? 
webgpu::options::kEnableGraphCapture_ON : webgpu::options::kEnableGraphCapture_OFF) + .IsOK()); + auto provider = WebGpuExecutionProviderWithOptions(config_options); + gpu_provider = provider.get(); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); #endif } @@ -1100,7 +1173,8 @@ static void TestBindHelper(const std::string& log_str, preallocate_output, allocation_provider, gpu_provider, - output_device); + output_device, + enable_graph_capture); } TEST(InferenceSessionTests, TestBindCpu) { @@ -1187,12 +1261,15 @@ TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) { ASSERT_TRUE(!st.IsOK()); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) #if USE_CUDA constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; #elif USE_ROCM constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; +#elif USE_WEBGPU +constexpr const char* kGpuExecutionProvider = kWebGpuExecutionProvider; #endif + TEST(InferenceSessionTests, TestBindCuda) { TestBindHelper("TestBindCuda", kGpuExecutionProvider, @@ -1223,7 +1300,7 @@ TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu2) { true /* preallocate output on CPU */, kCpuExecutionProvider); } - +#ifndef USE_WEBGPU TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) { OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0); @@ -1234,7 +1311,17 @@ TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) { kGpuExecutionProvider, &device /* specify output device */); } - +#else +TEST(InferenceSessionTests, TestGraphCapture) { + TestBindHelper("TestGraphCapture", + kGpuExecutionProvider, + kGpuExecutionProvider, + true /* preallocate output on GPU */, + kGpuExecutionProvider, + nullptr, + true /* enable graph capture*/); +} +#endif // !USE_WEBGPU #endif TEST(InferenceSessionTests, ModelWithoutOpset) { diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index f1c924a1ade94..0c52740398b7a 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -210,7 +210,8 @@ def test_example_plugin_ep_devices(self): self.assertEqual(test_ep_device.ep_vendor, "Contoso") ep_metadata = test_ep_device.ep_metadata - self.assertEqual(ep_metadata["version"], "0.1") + self.assertEqual(ep_metadata["version"], "0.1.0") + self.assertEqual(ep_metadata["supported_devices"], "CrackGriffin 7+") ep_options = test_ep_device.ep_options self.assertEqual(ep_options["run_really_fast"], "true") diff --git a/onnxruntime/test/python/quantization/test_qoperator_adjust_int32_bias.py b/onnxruntime/test/python/quantization/test_qoperator_adjust_int32_bias.py new file mode 100644 index 0000000000000..e4c958996f773 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_qoperator_adjust_int32_bias.py @@ -0,0 +1,105 @@ +import os +import tempfile +import unittest + +import numpy as np +import onnx +from op_test_utils import TestDataFeeds, check_model_correctness + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestAdjustWeightScaleForInt32BiasQOperator(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qop.adj_int32_bias_") + cls._tmp_dir_path = cls._tmp_model_dir.name + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + 
def build_conv_test_model(self, input_shape, weight_shape, onnx_float_type): + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(onnx_float_type) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx_float_type, input_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx_float_type, None) + + tiny_value = 1e-7 if np_float_type == np.float32 else 0.007782 + + # Step 1: reshape to (C_out, -1) to ensure per-channel broadcasting + weight_data = np.full(weight_shape, tiny_value, dtype=np_float_type) + weight_data = weight_data.reshape(weight_shape[0], -1) + for i in range(weight_data.shape[0]): + for j in range(weight_data.shape[1]): + if j % 2 == 0: + weight_data[i, j] = -weight_data[i, j] + # Step 2: reshape back to original shape + weight_data = weight_data.reshape(weight_shape) + weight = onnx.numpy_helper.from_array(weight_data, "weight") + + bias_shape = [weight_shape[0]] + bias_data = np.ones(bias_shape, dtype=np_float_type) + for i in range(len(bias_data)): + bias_data[i] = 5.0 if (i % 2 == 0) else -4.5 + if np_float_type == np.float16: + bias_data[i] = 1400 if (i % 2 == 0) else -1200 + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph([conv_node], "Convfloat", [input_0], [output_0], initializer=[weight, bias]) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_adjust_weight_scale_for_int32_bias_qop(self): + test_configs = [ + (onnx.TensorProto.FLOAT, True), + (onnx.TensorProto.FLOAT, False), + (onnx.TensorProto.FLOAT, True), + (onnx.TensorProto.FLOAT, False), + ] + + for float_type, per_channel in test_configs: + with self.subTest(float_type=float_type, per_channel=per_channel): + label = f"_f{float_type}_perchannel{per_channel}" + float_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.float.onnx") + qop_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.qop.onnx") + + input_shape = [1, 1, 128, 128] + weight_shape = [8, 1, 1, 1] + float_model = self.build_conv_test_model(input_shape, weight_shape, float_type) + onnx.save_model(float_model, float_model_path) + + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(float_type) + input_rmin = 0.0 + input_scale = 0.05 if float_type == onnx.TensorProto.FLOAT else 0.01 + input_rmax = (input_scale * 255.0) + input_rmin + input_data_list = [ + {"input_0": np.full(input_shape, input_rmin, dtype=np_float_type)}, + {"input_0": np.full(input_shape, (input_rmax - input_rmin) / 2.0, dtype=np_float_type)}, + {"input_0": np.full(input_shape, input_rmax, dtype=np_float_type)}, + ] + data_reader = TestDataFeeds(input_data_list) + + quantize_static( + float_model_path, + qop_model_path, + data_reader, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + per_channel=per_channel, + quant_format=QuantFormat.QOperator, + extra_options={ + "ActivationSymmetric": True, + "WeightSymmetric": True, + }, + ) + + data_reader.rewind() + check_model_correctness(self, float_model_path, qop_model_path, data_reader.get_next()) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 81cb56d34c925..2e4aa3923b649 100644 --- a/onnxruntime/test/util/default_providers.cc +++ 
b/onnxruntime/test/util/default_providers.cc
@@ -313,6 +313,15 @@ std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider(bool is_nhwc)
 #endif
 }
 
+std::unique_ptr<IExecutionProvider> WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options) {
+#ifdef USE_WEBGPU
+  return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider();
+#else
+  ORT_UNUSED_PARAMETER(config_options);
+  return nullptr;
+#endif
+}
+
 std::unique_ptr<IExecutionProvider> DefaultCannExecutionProvider() {
 #ifdef USE_CANN
   OrtCANNProviderOptions provider_options{};
diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h
index ce6434991051c..67d85edb4b8ef 100644
--- a/onnxruntime/test/util/include/default_providers.h
+++ b/onnxruntime/test/util/include/default_providers.h
@@ -5,6 +5,7 @@
 #include "core/common/optional.h"
 #include "core/providers/providers.h"
 #include "core/providers/provider_factory_creators.h"
+#include "core/framework/config_options.h"
 #include "core/framework/execution_provider.h"
 
 namespace onnxruntime {
@@ -64,6 +65,7 @@ std::unique_ptr<IExecutionProvider> QnnExecutionProviderWithOptions(const Provid
                                                                     const SessionOptions* session_options = nullptr);
 std::unique_ptr<IExecutionProvider> DefaultXnnpackExecutionProvider();
 std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider(bool is_nhwc = true);
+std::unique_ptr<IExecutionProvider> WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options);
 std::unique_ptr<IExecutionProvider> DefaultCannExecutionProvider();
 std::unique_ptr<IExecutionProvider> DefaultDmlExecutionProvider();
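The new WebGpuExecutionProviderWithOptions helper exists so tests can pass explicit ConfigOptions (for example, to turn graph capture on) instead of relying on the defaults used by DefaultWebGpuExecutionProvider. A rough usage sketch mirroring the TestBindHelper changes earlier in this patch follows; the wrapper function name is made up, the option constants are the ones referenced in that hunk, and the helper returns nullptr in builds without USE_WEBGPU.

// Hypothetical test-side helper, not part of the patch.
#include <memory>

#include "core/common/common.h"
#include "core/framework/config_options.h"
#include "test/util/include/default_providers.h"

namespace onnxruntime::test {

std::unique_ptr<IExecutionProvider> MakeWebGpuProviderForGraphCapture(bool enable_graph_capture) {
  ConfigOptions config_options{};
  // kEnableGraphCapture_ON / kEnableGraphCapture_OFF are the string constants used by TestBindHelper above.
  ORT_ENFORCE(config_options
                  .AddConfigEntry(webgpu::options::kEnableGraphCapture,
                                  enable_graph_capture ? webgpu::options::kEnableGraphCapture_ON
                                                       : webgpu::options::kEnableGraphCapture_OFF)
                  .IsOK());
  // Returns nullptr when ORT was built without USE_WEBGPU.
  return WebGpuExecutionProviderWithOptions(config_options);
}

}  // namespace onnxruntime::test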