diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index f83847c3d6df4..feba60d58b9a2 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -124,7 +124,7 @@ jobs: - uses: actions/checkout@v5 - name: Use jdk 17 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' @@ -206,7 +206,7 @@ jobs: - uses: actions/checkout@v5 - name: Use jdk 17 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index efe580c1b3b0c..8630abba416e1 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -55,7 +55,7 @@ jobs: # Setup Java to use a version that is not too old for the project - if: ${{ matrix.language == 'java' }} name: Setup Java 11 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: java-version: '11' distribution: 'microsoft' diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index c107535473786..742069542ecb9 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -25,7 +25,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Set up JDK 11 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: java-version: '11' distribution: 'adopt' diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 92429de408cda..fb5ad36d2ab1d 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -60,7 +60,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' @@ -172,7 +172,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index fe45f4af0cd46..c3db42fdb15a3 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -51,7 +51,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index 0aef550576d21..3a176ac41ebcf 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -65,7 +65,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' @@ -177,7 +177,7 @@ jobs: with: node-version: '20.x' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 730dafa0ddc43..1924e5d617679 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -61,7 +61,7 @@ jobs: node-version: "20.x" - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: "temurin" java-version: "17" @@ -249,7 +249,7 @@ jobs: node-version: "20.x" - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: "temurin" java-version: "17" diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml 
b/.github/workflows/windows_x64_debug_build_x64_debug.yml index c31027ce32fcd..1f71d3020bce2 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index 8a9d79aa4b8d1..dec37b8f5620d 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index a3639a9bcbfac..978f3e345141f 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index e356d9ad15c99..983bfaf411983 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index fa08c7e47cd36..c9d76ee450015 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -43,7 +43,7 @@ jobs: node-version: '20.x' - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index f7774a70dbd43..762762ebcd435 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -44,7 +44,7 @@ jobs: architecture: x86 #Add architecture - name: Setup Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '17' diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1f37432cc530c..98548957d0b42 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -22,7 +22,9 @@ if("${CMAKE_CXX_COMPILER_ID}" MATCHES "IntelLLVM") endif() # Needed for Java -set(CMAKE_C_STANDARD 99) +if (NOT CMAKE_CXX_STANDARD) + set(CMAKE_C_STANDARD 99) +endif() include(CheckCXXCompilerFlag) include(CheckLanguage) @@ -32,11 +34,13 @@ include(CheckFunctionExists) include(CheckSymbolExists) include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables -# TODO: update this once all system adapt c++20 -if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") -set(CMAKE_CXX_STANDARD 20) 
-else() -set(CMAKE_CXX_STANDARD 17) +if (NOT CMAKE_CXX_STANDARD) + # TODO: update this once all system adapt c++20 + if (CMAKE_SYSTEM_NAME STREQUAL "Darwin") + set(CMAKE_CXX_STANDARD 20) + else() + set(CMAKE_CXX_STANDARD 17) + endif() endif() if (MSVC) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 5b78235988413..6847db64004ca 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1232,6 +1232,12 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ${onnxruntime_perf_test_src_patterns} ) onnxruntime_add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/platform/path_lib.cc) + + # ABSL_FLAGS_STRIP_NAMES is set to 1 by default to disable flag registration when building for Android, iPhone, and "embedded devices". + # See the issue: https://github.com/abseil/abseil-cpp/issues/1875 + # We set it to 0 for all builds to be able to use ABSL flags for onnxruntime_perf_test. + target_compile_definitions(onnxruntime_perf_test PRIVATE ABSL_FLAGS_STRIP_NAMES=0) + if(MSVC) target_compile_options(onnxruntime_perf_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") diff --git a/csharp/ApiDocs/_exported_templates/default/styles/docfx.css b/csharp/ApiDocs/_exported_templates/default/styles/docfx.css index 64dcde3385eff..0d3cf530038d9 100644 --- a/csharp/ApiDocs/_exported_templates/default/styles/docfx.css +++ b/csharp/ApiDocs/_exported_templates/default/styles/docfx.css @@ -323,6 +323,8 @@ article section { } .docs-search > .search-query:focus { outline: 0; + border: 2px solid #0050C5; + background-color: #f8f9fa; } .search-results-frame { clear: both; @@ -597,6 +599,8 @@ body .toc{ } .toc-filter > input:focus { outline: 0; + border: 2px solid #0050C5; + background-color: #f8f9fa; } .toc-filter > .filter-icon { position: absolute; diff --git a/csharp/ApiDocs/_exported_templates/default/styles/docfx.js b/csharp/ApiDocs/_exported_templates/default/styles/docfx.js index b6167dbd6193d..6412c718be37d 100644 --- a/csharp/ApiDocs/_exported_templates/default/styles/docfx.js +++ b/csharp/ApiDocs/_exported_templates/default/styles/docfx.js @@ -802,6 +802,7 @@ $(function () { } } container.addEventListener('click', function (event) { return handleClick(event, state); }); + container.addEventListener('keydown', function (event) { return handleKeyDown(event, state); }); if (state.groups.length === 0) { return state; } @@ -820,6 +821,7 @@ $(function () { while (li) { var a = li.firstElementChild; a.setAttribute(contentAttrs.name, 'tab'); + a.setAttribute('role', 'tab'); var dataTab = a.getAttribute('data-tab').replace(/\+/g, ' '); a.setAttribute('data-tab', dataTab); var section = element.querySelector("[id=\"" + a.getAttribute('aria-controls') + "\"]"); @@ -915,6 +917,91 @@ $(function () { } } + function handleKeyDown(event, state) { + var info = getTabInfoFromEvent(event); + if (info === null) { + return; + } + + var handled = false; + var tabGroup = info.group; + var currentTabIndex = tabGroup.tabs.findIndex(function(tab) { return tab.a === info.anchor; }); + + switch (event.key) { + case 'ArrowLeft': + case 'ArrowUp': + // Move to previous tab + handled = true; + var prevIndex = currentTabIndex - 1; + if (prevIndex < 0) { + prevIndex = tabGroup.tabs.length - 1; + } + while (prevIndex !== currentTabIndex && !tabGroup.tabs[prevIndex].visible) { + prevIndex--; + if (prevIndex < 0) { + prevIndex = tabGroup.tabs.length - 1; + } + } + if (tabGroup.tabs[prevIndex].visible) { + 
tabGroup.tabs[prevIndex].focus(); + } + break; + + case 'ArrowRight': + case 'ArrowDown': + // Move to next tab + handled = true; + var nextIndex = currentTabIndex + 1; + if (nextIndex >= tabGroup.tabs.length) { + nextIndex = 0; + } + while (nextIndex !== currentTabIndex && !tabGroup.tabs[nextIndex].visible) { + nextIndex++; + if (nextIndex >= tabGroup.tabs.length) { + nextIndex = 0; + } + } + if (tabGroup.tabs[nextIndex].visible) { + tabGroup.tabs[nextIndex].focus(); + } + break; + + case 'Home': + // Move to first visible tab + handled = true; + for (var i = 0; i < tabGroup.tabs.length; i++) { + if (tabGroup.tabs[i].visible) { + tabGroup.tabs[i].focus(); + break; + } + } + break; + + case 'End': + // Move to last visible tab + handled = true; + for (var i = tabGroup.tabs.length - 1; i >= 0; i--) { + if (tabGroup.tabs[i].visible) { + tabGroup.tabs[i].focus(); + break; + } + } + break; + + case 'Enter': + case ' ': // Space key + // Activate the current tab + handled = true; + handleClick(event, state); + break; + } + + if (handled) { + event.preventDefault(); + event.stopPropagation(); + } + } + function selectTabs(tabIds) { for (var _i = 0, tabIds_1 = tabIds; _i < tabIds_1.length; _i++) { var tabId = tabIds_1[_i]; diff --git a/csharp/ApiDocs/_exported_templates/default/styles/docfx.vendor.css b/csharp/ApiDocs/_exported_templates/default/styles/docfx.vendor.css index 609602eb134dd..0c9f7c0c0aec3 100644 --- a/csharp/ApiDocs/_exported_templates/default/styles/docfx.vendor.css +++ b/csharp/ApiDocs/_exported_templates/default/styles/docfx.vendor.css @@ -1220,7 +1220,7 @@ to{background-position:0 0} .list-group-item.active .list-group-item-text,.list-group-item.active:focus .list-group-item-text,.list-group-item.active:hover .list-group-item-text{color:#c7ddef} a.list-group-item,button.list-group-item{color:#555} a.list-group-item .list-group-item-heading,button.list-group-item .list-group-item-heading{color:#333} -a.list-group-item:focus,a.list-group-item:hover,button.list-group-item:focus,button.list-group-item:hover{color:#555;text-decoration:none;background-color:#f5f5f5} +a.list-group-item:focus,a.list-group-item:hover,button.list-group-item:focus,button.list-group-item:hover{color:#333;text-decoration:none;background-color:#f5f5f5} button.list-group-item{width:100%;text-align:left} .list-group-item-success{color:#3c763d;background-color:#dff0d8} a.list-group-item-success,button.list-group-item-success{color:#3c763d} @@ -1230,7 +1230,7 @@ a.list-group-item-success.active,a.list-group-item-success.active:focus,a.list-g .list-group-item-info{color:#31708f;background-color:#d9edf7} a.list-group-item-info,button.list-group-item-info{color:#31708f} a.list-group-item-info .list-group-item-heading,button.list-group-item-info .list-group-item-heading{color:inherit} -a.list-group-item-info:focus,a.list-group-item-info:hover,button.list-group-item-info:focus,button.list-group-item-info:hover{color:#31708f;background-color:#c4e3f3} +a.list-group-item-info:focus,a.list-group-item-info:hover,button.list-group-item-info:focus,button.list-group-item-info:hover{color:#1e4a5f;background-color:#c4e3f3} a.list-group-item-info.active,a.list-group-item-info.active:focus,a.list-group-item-info.active:hover,button.list-group-item-info.active,button.list-group-item-info.active:focus,button.list-group-item-info.active:hover{color:#fff;background-color:#31708f;border-color:#31708f} .list-group-item-warning{color:#8a6d3b;background-color:#fcf8e3} 
a.list-group-item-warning,button.list-group-item-warning{color:#8a6d3b} diff --git a/include/onnxruntime/core/graph/model_saving_options.h b/include/onnxruntime/core/graph/model_saving_options.h index 6c041ec96a035..06c1b1ac6475f 100644 --- a/include/onnxruntime/core/graph/model_saving_options.h +++ b/include/onnxruntime/core/graph/model_saving_options.h @@ -9,36 +9,30 @@ class PrepackedWeightsForGraph; // These options affect how the model initializers are written to the external file. // This includes options to align external initializer offset. -// For models running on CPU, ORT will try to use mmap to load external -// initializers. To use mmap, external initializer need to be offset aligned. +// ORT will try to use mmap to load external initializers. +// // ORT saves external initializers into single data file, each initializer is // accessed with offset(start position of initializer) and length(byte length of -// initializer) of the data file. To use mmap, each offset need to be aligned -// which means offset need to divisible by allocation granularity(64KB for -// windows and 4K for other OSes). With align_offset to true, ORT will align -// offset for large initializer when save ONNX model with external data file. +// initializer) of the data file. With align_offset set to true, ORT will align +// offsets for large initializers (larger than align_threshold) +// when saving an ONNX model with an external data file. It will align them to +// the on_disk_alignment value. struct ModelSavingOptions { explicit ModelSavingOptions(size_t size_threshold) : initializer_size_threshold(size_threshold) {} // Minimal initializer size in bytes to be externalized on disk size_t initializer_size_threshold; - // Offset will always be page aligned and allocation granularity aligned for - // mmap support. This is done by padding previous tensor data with zeros - // keeping same length. + // Offset will always be aligned for mmap support. + // This is done by padding previous tensor data with zeros keeping same length. bool align_offset = false; // Alignment threshold for size of data. // Having a low threshold will waste file space for small initializers. // Only when tensor's data size is > the page_align_threshold it will be force // aligned. Default to 1MB. int64_t align_threshold = 1048576; - // The allocation Granularity for mmap() support. - // Typically 64KB for Windows & 4KB for other OSes. Default to 64KB. -#ifdef _WIN32 - int64_t allocation_granularity = 65536; -#else - int64_t allocation_granularity = 4096; -#endif + // Alignment factor for big tensors (bigger than align_threshold). Defaults to 4K.
+ int64_t on_disk_alignment = 4096; // Force embed all external initializer into the Onnx file // Used for EPContext model generation while some nodes fallback on CPU which has external data dependency bool force_embed_external_ini = false; diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index dc27204017caa..a32f465e44adf 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -31,7 +31,7 @@ constexpr const char* kDetailedBuildLog = "nv_detailed_build_log"; constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes"; constexpr const char* kProfilesMaxShapes = "nv_profile_max_shapes"; constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes"; -constexpr const char* kCudaGraphEnable = "nv_cuda_graph_enable"; +constexpr const char* kCudaGraphEnable = "enable_cuda_graph"; constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable"; constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer"; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bedeeb972c3a7..9ae6174817b7c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -87,7 +87,7 @@ extern "C" { #else #define ORT_EXPORT #endif -#define ORT_API_CALL _stdcall +#define ORT_API_CALL __stdcall #define ORT_MUST_USE_RESULT #define ORTCHAR_T wchar_t #else @@ -902,6 +902,16 @@ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t n * * \nosubgrouping */ +/* + * Public enum for compiled model compatibility across EPs. + */ +typedef enum OrtCompiledModelCompatibility { + OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0, + OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL, + OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION, + OrtCompiledModelCompatibility_EP_UNSUPPORTED, +} OrtCompiledModelCompatibility; + struct OrtApi { /// \name OrtStatus /// @{ @@ -6480,6 +6490,24 @@ struct OrtApi { * \since Version 1.23. */ ORT_API2_STATUS(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); + + /** \brief Validate a compiled model's compatibility information for one or more EP devices. + * + * \param[in] ep_devices The EP devices to validate against (e.g., from GetEpDevices). + * All devices must belong to the same execution provider. + * \param[in] num_ep_devices The number of EP devices provided. + * \param[in] compatibility_info The compatibility info string produced when the model was compiled. + * \param[out] out_status The resulting compatibility status for the EP devices. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. 
+ */ + ORT_API2_STATUS(GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 2f4fd36c8115f..c39e27088e8bc 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -725,9 +725,7 @@ using AllocatedStringPtr = std::unique_ptr; * constructors to construct an instance of a Status object from exceptions. */ struct Status : detail::Base { - using Base = detail::Base; - using Base::Base; - + Status() = default; // Same as with std::nullptr_t. But can be used in re-sizable containers and represent success. explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. explicit Status(const Exception&); ///< Creates status instance out of exception diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 73200d8852223..d0089726812a3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -823,7 +823,7 @@ inline Status Env::CopyTensors(const std::vector& src_tensors, return Status("Source and destination tensor vectors must have the same size", ORT_INVALID_ARGUMENT); } if (src_tensors.empty()) { - return Status(); + return Status(nullptr); } const OrtValue* const* src_tensors_ptr = reinterpret_cast(src_tensors.data()); diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 620cb5fcf13cc..975f6b453a88d 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -482,18 +482,6 @@ typedef enum OrtEpDataLayout { OrtEpDataLayout_Default = OrtEpDataLayout_NCHW, } OrtEpDataLayout; -/** - * \brief Enumeration describing the compatibility state of a compiled model relative to an execution provider. - * - * \since Version 1.23. - */ -typedef enum OrtCompiledModelCompatibility { - OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0, - OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL, - OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION, - OrtCompiledModelCompatibility_EP_UNSUPPORTED, -} OrtCompiledModelCompatibility; - /** * \brief The OrtEp struct provides functions to implement for an execution provider. * \since Version 1.22. @@ -901,20 +889,28 @@ struct OrtEpFactory { */ ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); - /** \brief Validate the compatibility of a compiled model with the execution provider. + /** \brief Validate the compatibility of a compiled model with the execution provider factory for one or more devices. + * + * Given a compatibility info string produced during model compilation, the EP factory should determine whether the + * compiled model is compatible with the EP factory when targeting the provided hardware devices. All devices provided + * must belong to the same execution provider instance that this factory creates. 
* - * This function validates if a model produced with the supplied compatibility info string is supported by the underlying EP. - * The EP should check if a compiled model is compatible with the EP and set the model_compatibility parameter accordingly. + * The EP factory implementation should consider the set of devices (e.g., multi-adapter or multi-GPU scenarios) when + * evaluating compatibility and set `model_compatibility` accordingly. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] compatibility_info The compatibility information string that will be used - * \param[out] model_compatibility OrtCompiledModelCompatibility enum value describing the compatibility of the model with the EP. + * \param[in] devices Array of OrtHardwareDevice pointers that the EP would run on. All must map to this EP. + * \param[in] num_devices Number of entries in `devices`. + * \param[in] compatibility_info The compatibility information string produced when the model was compiled. + * \param[out] model_compatibility OrtCompiledModelCompatibility value describing the compatibility of the model with the EP. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(ValidateCompiledModelCompatibilityInfo, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, _In_ const char* compatibility_info, _Out_ OrtCompiledModelCompatibility* model_compatibility); diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index a537c53b5e1ba..9abe23940e0f2 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -57,14 +57,14 @@ } }, "node_modules/@babel/code-frame": { - "version": "7.26.2", - "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", - "integrity": "sha512-RJlIHRueQgwWitWgF8OdFYGZX328Ax5BCemNGlqHfplnRT9ESi8JkFlvaVYbS+UubVY6dpv87Fs2u5M29iNFVQ==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.27.1.tgz", + "integrity": "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==", "license": "MIT", "dependencies": { - "@babel/helper-validator-identifier": "^7.25.9", + "@babel/helper-validator-identifier": "^7.27.1", "js-tokens": "^4.0.0", - "picocolors": "^1.0.0" + "picocolors": "^1.1.1" }, "engines": { "node": ">=6.9.0" @@ -417,18 +417,18 @@ } }, "node_modules/@babel/helper-string-parser": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", - "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", "license": "MIT", "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/helper-validator-identifier": { - "version": "7.25.9", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", - "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.27.1.tgz", + "integrity": 
"sha512-D2hP9eA+Sqx1kBZgzxZh0y1trbuU+JoDkiEwqhQ36nodYqJwyEIhPSdMNd7lOm/4io72luTPWH20Yda0xOuUow==", "license": "MIT", "engines": { "node": ">=6.9.0" @@ -458,25 +458,25 @@ } }, "node_modules/@babel/helpers": { - "version": "7.25.6", - "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.25.6.tgz", - "integrity": "sha512-Xg0tn4HcfTijTwfDwYlvVCl43V6h4KyVVX2aEm4qdO/PC6L2YvzLHFdmxhoeSA3eslcE6+ZVXHgWwopXYLNq4Q==", + "version": "7.28.3", + "resolved": "https://registry.npmjs.org/@babel/helpers/-/helpers-7.28.3.tgz", + "integrity": "sha512-PTNtvUQihsAsDHMOP5pfobP8C6CM4JWXmP8DrEIt46c3r2bf87Ua1zoqevsMo9g+tWDwgWrFP5EIxuBx5RudAw==", "license": "MIT", "dependencies": { - "@babel/template": "^7.25.0", - "@babel/types": "^7.25.6" + "@babel/template": "^7.27.2", + "@babel/types": "^7.28.2" }, "engines": { "node": ">=6.9.0" } }, "node_modules/@babel/parser": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.9.tgz", - "integrity": "sha512-81NWa1njQblgZbQHxWHpxxCzNsa3ZwvFqpUg7P+NNUU6f3UU2jBEg4OlF/J6rl8+PQGh1q6/zWScd001YwcA5A==", + "version": "7.28.3", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.3.tgz", + "integrity": "sha512-7+Ey1mAgYqFAx2h0RuoxcQT5+MlG3GTV0TQrgr7/ZliKsm/MNDxVVutlWaziMq7wJNAz8MTqz55XLpWvva6StA==", "license": "MIT", "dependencies": { - "@babel/types": "^7.26.9" + "@babel/types": "^7.28.2" }, "bin": { "parser": "bin/babel-parser.js" @@ -2189,14 +2189,14 @@ } }, "node_modules/@babel/template": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.26.9.tgz", - "integrity": "sha512-qyRplbeIpNZhmzOysF/wFMuP9sctmh2cFzRAZOn1YapxBsE1i9bJIY586R/WBLfLcmcBlM8ROBiQURnnNy+zfA==", + "version": "7.27.2", + "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.2.tgz", + "integrity": "sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==", "license": "MIT", "dependencies": { - "@babel/code-frame": "^7.26.2", - "@babel/parser": "^7.26.9", - "@babel/types": "^7.26.9" + "@babel/code-frame": "^7.27.1", + "@babel/parser": "^7.27.2", + "@babel/types": "^7.27.1" }, "engines": { "node": ">=6.9.0" @@ -2240,13 +2240,13 @@ "license": "MIT" }, "node_modules/@babel/types": { - "version": "7.26.9", - "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.9.tgz", - "integrity": "sha512-Y3IR1cRnOxOCDvMmNiym7XpXQ93iGDDPHx+Zj+NM+rg0fBaShfQLkg+hKPaZCEvg5N/LeCo4+Rj/i3FuJsIQaw==", + "version": "7.28.2", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.2.tgz", + "integrity": "sha512-ruv7Ae4J5dUYULmeXw1gmb7rYRz57OWCPM57pHojnLq/3Z1CK2lNSLTCVjxVk1F/TZHwOZZrOWi0ur95BbLxNQ==", "license": "MIT", "dependencies": { - "@babel/helper-string-parser": "^7.25.9", - "@babel/helper-validator-identifier": "^7.25.9" + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.27.1" }, "engines": { "node": ">=6.9.0" @@ -10443,9 +10443,9 @@ } }, "node_modules/picocolors": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.0.tgz", - "integrity": "sha512-TQ92mBOW0l3LeMeyLV6mzy/kWr8lkd/hp3mTg7wYK7zJhuBStmGMBG0BdeDZS/dZx1IukaX6Bk11zcln25o1Aw==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", "license": "ISC" }, "node_modules/picomatch": { diff --git a/js/web/docs/webnn-operators.md 
b/js/web/docs/webnn-operators.md index d9a030f320c6c..0fffe99ec4f78 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -80,6 +80,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s | PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | | | QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear | The shape of x_scale should be a subsample of the shape of input | | Reciprocal | ai.onnx(7-12, 13+) | reciprocal | | +| Round | ai.onnx(11-21, 22+) | roundEven | | | ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | Input 'axes' if present should be a constant | | ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | Input 'axes' if present should be a constant | | ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum | Input 'axes' if present should be a constant | diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 7623e2d88f3cd..34410a5f42630 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -106,7 +106,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE); // ******** End: Quantization ******************* // #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -272,7 +273,8 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h index 73e7ee6014b95..eae96c186d471 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h @@ -6,7 +6,8 @@ #include "core/common/common.h" #include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" -#include "contrib_ops/cpu/quantization/moe_helper.h" +#include "contrib_ops/cpu/moe/moe_helper.h" +#include namespace onnxruntime { namespace contrib { @@ -46,12 +47,21 @@ class MoEBaseCPU { if (use_sparse_mixer_) { ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2"); } + + swiglu_fusion_ = op_kernel_info.GetAttrOrDefault("swiglu_fusion", 0); + swiglu_limit_ = op_kernel_info.GetAttrOrDefault("swiglu_limit", std::numeric_limits::infinity()); + activation_alpha_ = op_kernel_info.GetAttrOrDefault("activation_alpha", 1.0f); + activation_beta_ = op_kernel_info.GetAttrOrDefault("activation_beta", 0.0f); } bool normalize_routing_weights_; bool use_sparse_mixer_; int64_t k_; ActivationType activation_type_; + float activation_alpha_; + float activation_beta_; + float swiglu_limit_; + int64_t swiglu_fusion_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/quantization/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h 
similarity index 100% rename from onnxruntime/contrib_ops/cpu/quantization/moe_helper.h rename to onnxruntime/contrib_ops/cpu/moe/moe_helper.h diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc new file mode 100644 index 0000000000000..9b35a40f64f2a --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -0,0 +1,393 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/moe/moe_quantization_cpu.h" + +#include "core/framework/allocator.h" +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "contrib_ops/cpu/moe/moe_utils.h" +#include "contrib_ops/cpu/moe/moe_helper.h" + +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +// Helper function to dequantize weights. Supports 4-bit and 8-bit symmetric quantization. +// The source quantized weights are stored as a row-major representation of the transposed +// logical weight matrix (W^T). This function dequantizes it into a float row-major W^T matrix. +template +void DequantizeBlock(const uint8_t* quantized_data, + const TScale* scales, + int64_t /*block_size*/, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data) { + const float zero_point = num_bits == 8 ? 128.0f : 8.0f; + if (num_bits == 8) { + for (int64_t r = 0; r < rows; ++r) { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; ++c) { + // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) + dequantized_data[r * cols + c] = scale * (static_cast(quantized_data[r * cols + c]) - zero_point); + } + } + } else if (num_bits == 4) { + const int64_t packed_cols = (cols + 1) / 2; + for (int64_t r = 0; r < rows; ++r) { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; ++c) { + const uint8_t packed_val = quantized_data[r * packed_cols + c / 2]; + // Unpack the 4-bit value. Low nibble for even columns, high nibble for odd columns. + const uint8_t quantized_val = (c % 2 == 0) ? (packed_val & 0x0F) : (packed_val >> 4); + // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) + dequantized_data[r * cols + c] = scale * (static_cast(quantized_val) - zero_point); + } + } + } +} + +template +QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) + : OpKernel(op_kernel_info), + MoEBaseCPU(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); + ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, + "Attribute 'expert_weight_bits' must be 4 or 8."); + block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); +} + +template +Status QMoECPU::Compute(OpKernelContext* context) const { + // --- 1. 
Get Inputs and Attributes --- + const auto* input = context->Input(0); + const auto* router_probs = context->Input(1); + const auto* fc1_experts_weights = context->Input(2); + const auto* fc1_scales = context->Input(3); + const auto* fc1_experts_bias = context->Input(4); + const auto* fc2_experts_weights = context->Input(5); + const auto* fc2_scales = context->Input(6); + const auto* fc2_experts_bias = context->Input(7); + const auto* fc3_experts_weights = context->Input(8); + const auto* fc3_scales = context->Input(9); + const auto* fc3_experts_bias = context->Input(10); + + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias, fc1_scales, + fc2_experts_weights, fc2_experts_bias, fc2_scales, + fc3_experts_weights, fc3_experts_bias, fc3_scales, + expert_weight_bits_ == 4 ? 2 : 1, + true)); + + if (fc3_experts_weights || fc3_experts_bias || fc3_scales) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); + } + + const auto& input_shape = input->Shape(); + const int64_t num_tokens = moe_params.num_rows; + const int64_t hidden_size = moe_params.hidden_size; + const int64_t inter_size = moe_params.inter_size; + const int64_t num_experts = moe_params.num_experts; + const int64_t fc1_out_features = inter_size * (swiglu_fusion_ > 0 ? 2 : 1); + + auto* output = context->Output(0, input_shape); + auto* tp = context->GetOperatorThreadPool(); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + const size_t output_buffer_size = static_cast(output->Shape().Size()); + + const T* input_data = input->Data(); + const T* router_probs_data = router_probs->Data(); + + // --- 2. Routing Logic: Assign tokens to experts --- + IAllocatorUniquePtr router_logits_float_buffer; + const float* router_logits_float; + if constexpr (std::is_same_v) { + router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + router_logits_float = router_logits_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs_data), + const_cast(router_logits_float), + static_cast(num_tokens * num_experts)); + } else { + router_logits_float = reinterpret_cast(router_probs_data); + } + + auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + int* route_expert = route_expert_ptr.get(); + auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + float* route_scale = route_scale_ptr.get(); + + // Parallelize the routing logic to improve performance for large token batches. + // Minor performance regression for single-token decoding is an acceptable trade-off + int num_routing_threads = (tp == nullptr || num_tokens < 4096) ? 1 : std::min(static_cast(num_tokens), concurrency::ThreadPool::DegreeOfParallelism(tp)); + + std::vector>> thread_local_expert_token_maps(num_routing_threads); + for (auto& map : thread_local_expert_token_maps) { + map.resize(static_cast(num_experts)); + } + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { + auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); + auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; + + // Pre-allocate buffers for this thread to reuse, avoiding allocations inside the loop. 
+ std::vector> sorted_logits(static_cast(num_experts)); + std::vector top_k_exp(static_cast(k_)); + + for (int64_t i = work.start; i < work.end; ++i) { + const float* logits = router_logits_float + i * num_experts; + for (int64_t j = 0; j < num_experts; ++j) { + sorted_logits[static_cast(j)] = {logits[j], j}; + } + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); + + float max_logit = -std::numeric_limits::infinity(); + for (int64_t j = 0; j < k_; ++j) { + if (sorted_logits[static_cast(j)].first > max_logit) { + max_logit = sorted_logits[static_cast(j)].first; + } + } + + float sum_exp = 0.0f; + for (int64_t j = 0; j < k_; ++j) { + top_k_exp[static_cast(j)] = std::exp(sorted_logits[static_cast(j)].first - max_logit); + sum_exp += top_k_exp[static_cast(j)]; + } + + float scale = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + route_expert[route_idx] = static_cast(expert_idx); + route_scale[route_idx] = top_k_exp[static_cast(j)] * scale; + if (route_scale[route_idx] > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } + }); + + // Merge the maps from each thread into a single global map. + std::vector> expert_token_map(static_cast(num_experts)); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + size_t total_tokens_for_expert = 0; + for (int t = 0; t < num_routing_threads; ++t) { + total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); + } + expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); + } + + for (int t = 0; t < num_routing_threads; ++t) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; + if (!local_tokens.empty()) { + expert_token_map[static_cast(expert_idx)].insert(expert_token_map[static_cast(expert_idx)].end(), local_tokens.begin(), local_tokens.end()); + } + } + } + + // --- 3. Parallel Expert Computation --- + IAllocatorUniquePtr input_float_buffer; + const float* input_float; + if constexpr (std::is_same_v) { + input_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + input_float = input_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(input_data), + const_cast(input_float), + static_cast(num_tokens * hidden_size)); + } else { + input_float = reinterpret_cast(input_data); + } + + int num_expert_threads = (tp == nullptr) ? 
1 : std::min(static_cast(num_experts), concurrency::ThreadPool::DegreeOfParallelism(tp)); + if (num_expert_threads == 0) num_expert_threads = 1; + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); + float* thread_local_outputs = thread_local_outputs_ptr.get(); + memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + + // Pre-calculate workspace size per thread to avoid allocations inside the loop + size_t max_tokens_per_expert = 0; + for (const auto& tokens : expert_token_map) { + if (tokens.size() > max_tokens_per_expert) { + max_tokens_per_expert = tokens.size(); + } + } + + const size_t A1_size = static_cast(max_tokens_per_expert * hidden_size); + const size_t C1_size = static_cast(max_tokens_per_expert * fc1_out_features); + const size_t A2_size = static_cast(max_tokens_per_expert * inter_size); + const size_t C2_size = static_cast(max_tokens_per_expert * hidden_size); + const size_t B1_dequant_size = static_cast(fc1_out_features * hidden_size); + const size_t B2_dequant_size = static_cast(hidden_size * inter_size); + const size_t bias1_size = static_cast(fc1_out_features); + const size_t bias2_size = static_cast(hidden_size); + + const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size + bias1_size + bias2_size; + auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * workspace_elements_per_thread); + float* workspace = workspace_ptr.get(); + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { + int thread_id = static_cast(thread_id_pd); + auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + + float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; + + for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + if (routes.empty()) { + continue; + } + + const int64_t num_expert_tokens = routes.size(); + + // Partition the workspace for the current expert + float* A1 = thread_workspace; + float* C1 = A1 + num_expert_tokens * hidden_size; + float* A2 = C1 + num_expert_tokens * fc1_out_features; + float* C2 = A2 + num_expert_tokens * inter_size; + float* B1_dequant = C2 + num_expert_tokens * hidden_size; + float* B2_dequant = B1_dequant + fc1_out_features * hidden_size; + float* bias1_float = B2_dequant + hidden_size * inter_size; + float* bias2_float = bias1_float + fc1_out_features; + + // --- Gather input tokens for the current expert --- + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + memcpy(A1 + i * hidden_size, + input_float + token_idx * hidden_size, + static_cast(hidden_size) * sizeof(float)); + } + + // --- FC1 GEMM (X * W1^T) --- + DequantizeBlock(fc1_experts_weights->Data() + expert_idx * fc1_out_features * (hidden_size / (8 / expert_weight_bits_)), + fc1_scales->Data() + expert_idx * fc1_out_features * (block_size_ > 0 ? 
hidden_size / block_size_ : 1), + block_size_, expert_weight_bits_, + fc1_out_features, hidden_size, B1_dequant); + + MlasGemm(CblasNoTrans, CblasTrans, + static_cast(num_expert_tokens), static_cast(fc1_out_features), static_cast(hidden_size), + 1.0f, A1, static_cast(hidden_size), + B1_dequant, static_cast(hidden_size), + 0.0f, C1, static_cast(fc1_out_features), + nullptr); + + const T* B1_bias = (fc1_experts_bias) ? fc1_experts_bias->Data() + expert_idx * fc1_out_features : nullptr; + if (B1_bias) { + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), bias1_float, static_cast(fc1_out_features)); + } else { + memcpy(bias1_float, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + for (int64_t i = 0; i < num_expert_tokens; ++i) { + for (int64_t j = 0; j < fc1_out_features; ++j) { + C1[i * fc1_out_features + j] += bias1_float[j]; + } + } + } + + // --- Activation --- + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + + // --- FC2 GEMM (A2 * W2^T) --- + DequantizeBlock(fc2_experts_weights->Data() + expert_idx * hidden_size * (inter_size / (8 / expert_weight_bits_)), + fc2_scales->Data() + expert_idx * hidden_size * (block_size_ > 0 ? inter_size / block_size_ : 1), + block_size_, expert_weight_bits_, + hidden_size, inter_size, B2_dequant); + + MlasGemm(CblasNoTrans, CblasTrans, + static_cast(num_expert_tokens), static_cast(hidden_size), static_cast(inter_size), + 1.0f, A2, static_cast(inter_size), + B2_dequant, static_cast(inter_size), + 0.0f, C2, static_cast(hidden_size), + nullptr); + + const T* B2_bias = (fc2_experts_bias) ? fc2_experts_bias->Data() + expert_idx * hidden_size : nullptr; + if (B2_bias) { + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), bias2_float, static_cast(hidden_size)); + } else { + memcpy(bias2_float, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + } + + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t route_idx = routes[static_cast(i)]; + const int64_t token_idx = route_idx / k_; + const float weight = route_scale[route_idx]; + + float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + token_idx * hidden_size; + const float* src = C2 + i * hidden_size; + for (int64_t j = 0; j < hidden_size; ++j) { + dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); + } + } + } + }); + + // --- 4. Final Reduction (accumulate expert outputs to a float buffer) --- + auto accumulate = [&](float* buffer) { + memset(buffer, 0, output_buffer_size * sizeof(float)); + for (int i = 0; i < num_expert_threads; ++i) { + for (size_t j = 0; j < output_buffer_size; ++j) { + buffer[j] += thread_local_outputs[static_cast(i) * output_buffer_size + j]; + } + } + }; + + if constexpr (std::is_same_v) { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + // --- 5. 
Convert final float buffer to output type T --- + MlasConvertFloatToHalfBuffer(final_output_float, + reinterpret_cast(output->MutableData()), + static_cast(output_buffer_size)); + } else { // T is float + accumulate(output->MutableData()); + } + + return Status::OK(); +} + +// Explicit template instantiation +template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); +template Status QMoECPU::Compute(OpKernelContext* context) const; +template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); +template Status QMoECPU::Compute(OpKernelContext* context) const; + +// Kernel Registration +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, kMSDomain, 1, float, kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoECPU); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, kMSDomain, 1, MLFloat16, kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoECPU); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h new file mode 100644 index 0000000000000..890580e051a8e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/moe/moe_base_cpu.h" + +namespace onnxruntime { +namespace contrib { + +/** + * @brief QMoE is the templated CPU implementation of the Quantized Mixture of Experts operator. + * + * This kernel supports both float and MLFloat16 data types for activations, scales, and outputs. + * It parallelizes expert computation using the ONNX Runtime thread pool and minimizes memory + * usage through on-the-fly block dequantization of weights. + * + * @tparam T The data type for the kernel (float or MLFloat16). 
+ */ +template +class QMoECPU final : public OpKernel, public MoEBaseCPU { + public: + explicit QMoECPU(const OpKernelInfo& op_kernel_info); + Status Compute(OpKernelContext* context) const override; + + private: + int64_t expert_weight_bits_; + int64_t block_size_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc index 6214b7819b765..2c59210bfabd4 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc @@ -4,6 +4,7 @@ #include "contrib_ops/cpu/moe/moe_utils.h" #include #include +#include "core/common/common.h" namespace onnxruntime { namespace contrib { @@ -19,74 +20,31 @@ float ApplyActivation(float x, ActivationType activation_type) { case ActivationType::Identity: return x; case ActivationType::SwiGLU: - // SwiGLU: This is handled specially as it requires gating, not applied here + // SwiGLU is a special case handled by ApplySwiGLUActivation, this is just a placeholder return x; default: - return x; // Default to identity + return x; } } -// Helper method for applying SwiGLU activation with different memory layouts -void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format) { - constexpr float swiglu_alpha = 1.702f; - constexpr float clamp_limit = 7.0f; // Clamping limit as specified - +void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format, + float activation_alpha, float activation_beta, float clamp_limit) { if (is_interleaved_format) { - // For interleaved format [linear, gate, linear, gate, ...], process directly - // Make a temporary copy of each pair of values before modifying them for (int64_t i = 0; i < inter_size; ++i) { - const size_t idx = static_cast(i); - const size_t linear_idx = 2 * idx; - const size_t gate_idx = linear_idx + 1; + float gate_val = input_data[2 * i]; + float linear_val = input_data[2 * i + 1]; - // Store original values - float linear_val = data[linear_idx]; // Interleaved: even index - float gate_val = data[gate_idx]; // Interleaved: odd index + gate_val = std::min(gate_val, clamp_limit); + linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit); - // Apply clamping to the values - if (gate_val > clamp_limit) gate_val = clamp_limit; // Clamp gate max only - if (linear_val > clamp_limit) linear_val = clamp_limit; // Clamp linear min/max - if (linear_val < -clamp_limit) linear_val = -clamp_limit; - - // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1) - float sigmoid_arg = swiglu_alpha * gate_val; + float sigmoid_arg = activation_alpha * gate_val; float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg)); float swish_out = gate_val * sigmoid_out; - float result = swish_out * (linear_val + 1.0f); - // Store result in first element (linear position) - data[idx] = result; + output_data[i] = swish_out * (linear_val + activation_beta); } } else { - // For chunked layout [linear..., gate...], handle separately - // Need to work with original data in-place - // First, store all the gate computations since they depend on original gate values - std::vector computed_gates(static_cast(inter_size)); - - for (int64_t i = 0; i < inter_size; ++i) { - const size_t idx = static_cast(i); - float gate_val = data[idx + static_cast(inter_size)]; - - // Apply clamping to the gate value (max only) - if (gate_val > clamp_limit) gate_val = clamp_limit; - - // Compute the gate part of SwiGLU - float sigmoid_arg = 
swiglu_alpha * gate_val; - float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg)); - computed_gates[idx] = gate_val * sigmoid_out; - } - - // Now apply the full activation with the precomputed gate values - for (int64_t i = 0; i < inter_size; ++i) { - const size_t idx = static_cast(i); - float linear_val = data[idx]; - - // Apply clamping to the linear value (min/max) - if (linear_val > clamp_limit) linear_val = clamp_limit; - if (linear_val < -clamp_limit) linear_val = -clamp_limit; - - data[idx] = computed_gates[idx] * (linear_val + 1.0f); - } + ORT_NOT_IMPLEMENTED("Non-interleaved format not supported for SwiGLU activation"); } } diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.h b/onnxruntime/contrib_ops/cpu/moe/moe_utils.h index e20dc101c7412..de238e8d7ae66 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_utils.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.h @@ -9,7 +9,9 @@ namespace onnxruntime { namespace contrib { float ApplyActivation(float x, ActivationType activation_type); -void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format); + +void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format, + float activation_alpha, float activation_beta, float clamp_limit); } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc deleted file mode 100644 index 8bd4dcf1afbab..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc +++ /dev/null @@ -1,596 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/cpu/quantization/moe_quantization_cpu.h" -#include "core/framework/allocator.h" -#include "core/framework/buffer_deleter.h" -#include "core/mlas/inc/mlas.h" -#include "core/mlas/inc/mlas_q4.h" -#include "core/mlas/inc/mlas_qnbit.h" -#include "core/platform/threadpool.h" -#include "contrib_ops/cpu/moe/moe_utils.h" -#include - -using namespace onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { - -#define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCpuExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", BuildKernelDefConstraints()) \ - .TypeConstraint("T1", BuildKernelDefConstraints()) \ - .TypeConstraint("T2", BuildKernelDefConstraints()), \ - QMoE); - -REGISTER_KERNEL(); - -// QMoE CPU kernel registration is handled in cpu_contrib_kernels.cc - -QMoE::QMoE(const OpKernelInfo& op_kernel_info) - : OpKernel(op_kernel_info), - MoEBaseCPU(op_kernel_info), - prepacked_fc1_weights_data_(nullptr), - prepacked_fc2_weights_data_(nullptr), - weights_allocator_(nullptr), - is_prepacked_(false) { - ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); - ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4, - "expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_); -} - -Status QMoE::Compute(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* router_probs = context->Input(1); - const Tensor* fc1_experts_weights = context->Input(2); - const Tensor* fc1_scales = context->Input(3); - const Tensor* fc1_experts_bias_optional = context->Input(4); - const Tensor* fc2_experts_weights = context->Input(5); - const Tensor* fc2_scales = context->Input(6); - 
const Tensor* fc2_experts_bias_optional = context->Input(7); - const Tensor* fc3_experts_weights_optional = context->Input(8); - const Tensor* fc3_scales_optional = context->Input(9); - const Tensor* fc3_experts_bias_optional = context->Input(10); - - MoEParameters moe_params; - ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( - moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc1_scales, - fc2_experts_weights, fc2_experts_bias_optional, fc2_scales, - fc3_experts_weights_optional, fc3_experts_bias_optional, fc3_scales_optional, - expert_weight_bits_ == 4 ? 2 : 1, - activation_type_ == ActivationType::SwiGLU)); - - // Dispatch based on input data type - if (input->IsDataType()) { - if (expert_weight_bits_ == 4) { - return QuantizedMoEImpl(context, moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, - fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional); - } else { - return QuantizedMoEImpl(context, moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, - fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional); - } - } else if (input->IsDataType()) { - if (expert_weight_bits_ == 4) { - return QuantizedMoEImpl(context, moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, - fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional); - } else { - return QuantizedMoEImpl(context, moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, - fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "QMoE only supports float and MLFloat16 data types, but got ", - DataTypeImpl::ToString(input->DataType())); - } -} - -template -Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const { - // SwiGLU validation - FC3 not supported - bool is_swiglu = (activation_type_ == ActivationType::SwiGLU); - if (is_swiglu && fc3_experts_weights_optional != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "SwiGLU activation is not supported with fc3."); - } - if (!is_swiglu && fc3_experts_weights_optional != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "FC3 gating is not yet implemented on CPU."); - } - - // Check if we need to repack weights - if (!is_prepacked_ || - cached_num_experts_ != moe_params.num_experts || - cached_hidden_size_ != moe_params.hidden_size || - cached_inter_size_ != moe_params.inter_size || - cached_is_swiglu_ != is_swiglu) { - // Need to prepack weights - Status status = const_cast(this)->PrepackAndDequantizeWeights( - context, moe_params, fc1_experts_weights, fc2_experts_weights, - fc1_scales, fc2_scales, is_swiglu); 
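
The deleted kernel above dequantized the expert weights lazily and cached the result, rebuilding only when the MoE dimensions or the activation changed. A small self-contained sketch of that shape-keyed check, with struct and function names invented for illustration:

    #include <cstdint>

    struct PrepackCacheSketch {
      bool is_prepacked = false;
      int64_t num_experts = 0, hidden_size = 0, inter_size = 0;
      bool is_swiglu = false;
    };

    // Returns true when the cached dequantized weight buffers must be rebuilt.
    bool NeedsRepack(const PrepackCacheSketch& c, int64_t num_experts, int64_t hidden_size,
                     int64_t inter_size, bool is_swiglu) {
      return !c.is_prepacked || c.num_experts != num_experts ||
             c.hidden_size != hidden_size || c.inter_size != inter_size ||
             c.is_swiglu != is_swiglu;
    }
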
- ORT_RETURN_IF_ERROR(status); - } - // Get thread pool - auto* thread_pool = context->GetOperatorThreadPool(); - - // Get input data pointers - const T* input_data = input->Data(); - const T* router_probs_data = router_probs->Data(); - const T* fc1_bias_data = fc1_experts_bias_optional ? fc1_experts_bias_optional->Data() : nullptr; - const T* fc2_bias_data = fc2_experts_bias_optional ? fc2_experts_bias_optional->Data() : nullptr; - - Tensor* output = context->Output(0, input->Shape()); - T* output_data = output->MutableData(); - - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - - const int64_t num_threads = std::min( - static_cast(concurrency::ThreadPool::DegreeOfParallelism(thread_pool)), - moe_params.num_rows); - - const int64_t total_output_size = moe_params.num_rows * moe_params.hidden_size; - std::fill_n(output_data, total_output_size, MLFloat16(0.0f)); - - // Using prepacked weights - no need to convert scales - - auto thread_fc1_buffers = IAllocator::MakeUniquePtr(allocator, static_cast(num_threads * moe_params.inter_size * (is_swiglu ? 2 : 1))); - auto thread_fc2_buffers = IAllocator::MakeUniquePtr(allocator, static_cast(num_threads * moe_params.hidden_size)); - - // Set up output buffer - IAllocatorUniquePtr output_float; - float* output_float_ptr = nullptr; - - if constexpr (std::is_same_v) { - // For MLFloat16, we need a separate float buffer - output_float = IAllocator::MakeUniquePtr(allocator, static_cast(total_output_size)); - output_float_ptr = output_float.get(); - } else { - // For float, we can write directly to output_data - output_float = IAllocatorUniquePtr(output_data, [](float*) {}); - output_float_ptr = output_data; - } - - // Initialize output to zeros - std::fill_n(output_float_ptr, total_output_size, 0.0f); - - // Prepare float buffers for input data and biases - IAllocatorUniquePtr input_float; - IAllocatorUniquePtr router_probs_float; - - // Pointers for easier access - float* input_float_ptr = nullptr; - float* router_probs_float_ptr = nullptr; - - // Pre-convert bias tensors to float (if they exist) - const int64_t fc1_bias_size = is_swiglu ? 
2 * moe_params.inter_size : moe_params.inter_size; - const int64_t fc2_bias_size = moe_params.hidden_size; - - // Allocate buffers for converted biases using ORT allocator - IAllocatorUniquePtr fc1_bias_float; - IAllocatorUniquePtr fc2_bias_float; - - if (fc1_bias_data) { - fc1_bias_float = IAllocator::MakeUniquePtr(allocator, static_cast(moe_params.num_experts * fc1_bias_size)); - } - - if (fc2_bias_data) { - fc2_bias_float = IAllocator::MakeUniquePtr(allocator, static_cast(moe_params.num_experts * fc2_bias_size)); - } - - // Convert input and router_probs based on type - if constexpr (std::is_same_v) { - // For MLFloat16, convert to float - need to allocate buffers first - input_float = IAllocator::MakeUniquePtr(allocator, static_cast(moe_params.num_rows * moe_params.hidden_size)); - router_probs_float = IAllocator::MakeUniquePtr(allocator, static_cast(moe_params.num_rows * moe_params.num_experts)); - - input_float_ptr = input_float.get(); - router_probs_float_ptr = router_probs_float.get(); - - // Convert MLFloat16 to float - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(input_data), - input_float_ptr, - static_cast(moe_params.num_rows * moe_params.hidden_size), - thread_pool); - - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(router_probs_data), - router_probs_float_ptr, - static_cast(moe_params.num_rows * moe_params.num_experts), - thread_pool); - - // Convert biases to float once (if they exist) - if (fc1_bias_data) { - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(fc1_bias_data), - fc1_bias_float.get(), - static_cast(moe_params.num_experts * fc1_bias_size), - thread_pool); - } - - if (fc2_bias_data) { - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(fc2_bias_data), - fc2_bias_float.get(), - static_cast(moe_params.num_experts * fc2_bias_size), - thread_pool); - } - } else { - // For float, point to original input and router_probs directly instead of copying - input_float = IAllocatorUniquePtr(const_cast(input_data), [](float*) {}); - router_probs_float = IAllocatorUniquePtr(const_cast(router_probs_data), [](float*) {}); - - // Set pointers to the original data - input_float_ptr = const_cast(input_data); - router_probs_float_ptr = const_cast(router_probs_data); - - // For float, just point to the original bias data directly without copying - // No need to allocate or copy, just reuse the original pointers - if (fc1_bias_data) { - // Release previously allocated memory if any - fc1_bias_float.reset(); - // Direct pointer to original data - fc1_bias_float = IAllocatorUniquePtr(const_cast(fc1_bias_data), [](float*) {}); - } - - if (fc2_bias_data) { - // Release previously allocated memory if any - fc2_bias_float.reset(); - // Direct pointer to original data - fc2_bias_float = IAllocatorUniquePtr(const_cast(fc2_bias_data), [](float*) {}); - } - } - - // No need to initialize thread results - using direct output buffer - - // Determine activation related parameters - const bool is_4bit = UseUInt4x2; - const int64_t act_multiplier = is_swiglu ? 2 : 1; - const int64_t fc1_output_size = is_swiglu ? 
2 * moe_params.inter_size : moe_params.inter_size; - - // Use prepacked dequantized weights - no need to dequantize here - const float* dequant_fc1_weights = prepacked_fc1_weights_data_; - const float* dequant_fc2_weights = prepacked_fc2_weights_data_; - - // Process tokens in parallel - concurrency::ThreadPool::TryParallelFor( - thread_pool, static_cast(moe_params.num_rows), - static_cast(std::max(1, moe_params.num_rows / num_threads)), - [&](ptrdiff_t start_token, ptrdiff_t end_token) { - const int64_t thread_id = start_token / ((moe_params.num_rows + num_threads - 1) / num_threads); - const int64_t thread_fc1_size = is_4bit ? (moe_params.inter_size * (is_swiglu ? 2 : 1)) : (moe_params.inter_size * act_multiplier); - float* thread_fc1_output = thread_fc1_buffers.get() + thread_id * thread_fc1_size; - float* thread_fc2_output = thread_fc2_buffers.get() + thread_id * moe_params.hidden_size; - - // Process each token in this thread's range - for (std::ptrdiff_t token_idx = start_token; token_idx < end_token; ++token_idx) { - const float* token_input = input_float_ptr + static_cast(SafeInt(token_idx)) * moe_params.hidden_size; - float* token_result = output_float_ptr + static_cast(SafeInt(token_idx)) * moe_params.hidden_size; - - // Process all experts for this token - for (std::ptrdiff_t expert_idx = 0; expert_idx < moe_params.num_experts; ++expert_idx) { - float routing_weight = router_probs_float_ptr[static_cast(SafeInt(token_idx)) * moe_params.num_experts + static_cast(SafeInt(expert_idx))]; - if (routing_weight <= 1e-6f) continue; // Skip experts with negligible routing weight - - // FC1: input -> intermediate using pre-dequantized weights + MLAS SGEMM - const int64_t fc1_weight_offset = is_4bit ? (static_cast(SafeInt(expert_idx)) * moe_params.hidden_size * fc1_output_size) : (static_cast(SafeInt(expert_idx)) * moe_params.hidden_size * moe_params.inter_size * act_multiplier); - const float* fc1_expert_weights = dequant_fc1_weights + fc1_weight_offset; - - // Bias size is always equal to output size (fc1_output_size), regardless of bit width - const int64_t fc1_bias_size = fc1_output_size; - - // Use MLAS SGEMM for FC1 - MLAS_SGEMM_DATA_PARAMS fc1_params; - fc1_params.A = token_input; - fc1_params.lda = static_cast(moe_params.hidden_size); - fc1_params.B = fc1_expert_weights; - fc1_params.ldb = static_cast(moe_params.hidden_size); - fc1_params.C = thread_fc1_output; - fc1_params.ldc = static_cast(fc1_bias_size); - fc1_params.alpha = 1.0f; - fc1_params.beta = 0.0f; - - MlasGemm(CblasNoTrans, CblasNoTrans, 1, static_cast(fc1_bias_size), static_cast(moe_params.hidden_size), fc1_params, nullptr); - - // Handle different activation types - if (is_swiglu) { - // Add bias if present - if (fc1_bias_data) { - // Use the pre-converted float bias data - const float* fc1_expert_bias_float = fc1_bias_float.get() + static_cast(SafeInt(expert_idx)) * fc1_bias_size; - for (int64_t i = 0; i < fc1_bias_size; ++i) { - thread_fc1_output[i] += fc1_expert_bias_float[i]; - } - } - contrib::ApplySwiGLUActivation(thread_fc1_output, moe_params.inter_size, is_4bit); - } else { - // Standard activation (non-SwiGLU) - if (fc1_bias_data) { - // Use the pre-converted float bias data - const float* fc1_expert_bias_float = fc1_bias_float.get() + static_cast(SafeInt(expert_idx)) * moe_params.inter_size; - for (int64_t i = 0; i < moe_params.inter_size; ++i) { - thread_fc1_output[i] += fc1_expert_bias_float[i]; - thread_fc1_output[i] = ApplyActivation(thread_fc1_output[i], activation_type_); - } - } else { - for (int64_t i 
= 0; i < moe_params.inter_size; ++i) { - thread_fc1_output[i] = ApplyActivation(thread_fc1_output[i], activation_type_); - } - } - } - - // FC2: intermediate -> output using pre-dequantized weights + MLAS SGEMM - const float* fc2_expert_weights = dequant_fc2_weights + static_cast(SafeInt(expert_idx)) * moe_params.inter_size * moe_params.hidden_size; - - // Use MLAS SGEMM for FC2 - MLAS_SGEMM_DATA_PARAMS fc2_params; - fc2_params.A = thread_fc1_output; - fc2_params.lda = static_cast(moe_params.inter_size); - fc2_params.B = fc2_expert_weights; - fc2_params.ldb = static_cast(moe_params.inter_size); - fc2_params.C = thread_fc2_output; - fc2_params.ldc = static_cast(moe_params.hidden_size); - fc2_params.alpha = 1.0f; - fc2_params.beta = 0.0f; - - MlasGemm(CblasNoTrans, CblasNoTrans, 1, static_cast(moe_params.hidden_size), static_cast(moe_params.inter_size), fc2_params, nullptr); - - // Add bias, apply routing weight, and accumulate to final result - if (fc2_bias_data) { - // Use the pre-converted float bias data - const float* fc2_expert_bias_float = fc2_bias_float.get() + static_cast(SafeInt(expert_idx)) * moe_params.hidden_size; - for (int64_t i = 0; i < moe_params.hidden_size; ++i) { - token_result[i] += routing_weight * (thread_fc2_output[i] + fc2_expert_bias_float[i]); - } - } else { - for (int64_t i = 0; i < moe_params.hidden_size; ++i) { - token_result[i] += routing_weight * thread_fc2_output[i]; - } - } - } - } - }); - - // No need for accumulation since threads write directly to output_float - - // Convert results back to the appropriate output type, if needed - if constexpr (std::is_same_v) { - // For MLFloat16, convert from float to half - MlasConvertFloatToHalfBuffer(output_float_ptr, reinterpret_cast(output_data), static_cast(total_output_size)); - } - // For float, no conversion needed as we directly wrote to output_data - - // Suppress unused parameter warnings for optional parameters that are not used in non-SwiGLU modes - if (!is_swiglu) { - ORT_UNUSED_PARAMETER(fc3_experts_bias_optional); - ORT_UNUSED_PARAMETER(fc3_scales_optional); - } - - return Status::OK(); -} - -template -Status QMoE::PrepackAndDequantizeWeights(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* fc1_experts_weights, - const Tensor* fc2_experts_weights, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - bool is_swiglu) { - // Get thread pool - auto* thread_pool = context->GetOperatorThreadPool(); - - // Get input data pointers - const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); - const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); - const void* fc1_scales_data_typed = fc1_scales->DataRaw(); - const void* fc2_scales_data_typed = fc2_scales->DataRaw(); - bool is_fp32_scales = fc1_scales->IsDataType(); - - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - - const int64_t num_threads = std::min( - static_cast(concurrency::ThreadPool::DegreeOfParallelism(thread_pool)), - moe_params.num_experts); - - // Prepare scales in float format - const int64_t fc1_scales_size = moe_params.num_experts * (is_swiglu ? 
2 * moe_params.inter_size : moe_params.inter_size); - const int64_t fc2_scales_size = moe_params.num_experts * moe_params.hidden_size; - - auto fc1_scales_float = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_scales_size)); - auto fc2_scales_float = IAllocator::MakeUniquePtr(allocator, static_cast(fc2_scales_size)); - - if (is_fp32_scales) { - // For float scales, just copy - std::memcpy(fc1_scales_float.get(), fc1_scales_data_typed, static_cast(fc1_scales_size) * sizeof(float)); - std::memcpy(fc2_scales_float.get(), fc2_scales_data_typed, static_cast(fc2_scales_size) * sizeof(float)); - } else { - // For MLFloat16 scales, convert to float - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(fc1_scales_data_typed), - fc1_scales_float.get(), - static_cast(fc1_scales_size), - thread_pool); - MlasConvertHalfToFloatBufferInParallel(reinterpret_cast(fc2_scales_data_typed), - fc2_scales_float.get(), - static_cast(fc2_scales_size), - thread_pool); - } - - const float* fc1_scales_data = fc1_scales_float.get(); - const float* fc2_scales_data = fc2_scales_float.get(); - - // Determine quantization parameters based on bit width - using symmetric quantization for TensorRT compatibility - const bool is_4bit = UseUInt4x2; - const int64_t act_multiplier = is_swiglu ? 2 : 1; - const int64_t fc1_output_size = is_swiglu ? 2 * moe_params.inter_size : moe_params.inter_size; - - // Calculate weight sizes and strides based on quantization type - const int64_t fc1_weight_stride = is_4bit ? (moe_params.hidden_size * fc1_output_size / 2) : (moe_params.hidden_size * moe_params.inter_size * act_multiplier); - const int64_t fc2_weight_stride = is_4bit ? (moe_params.inter_size * moe_params.hidden_size / 2) : (moe_params.inter_size * moe_params.hidden_size); - - // Get or create a persistent allocator for weights - if (weights_allocator_ == nullptr) { - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&weights_allocator_)); - } - - // Allocate prepacked weight buffers using ORT allocator - const size_t fc1_weights_size = static_cast(moe_params.num_experts * moe_params.hidden_size * (is_4bit ? fc1_output_size : moe_params.inter_size * act_multiplier)); - const size_t fc2_weights_size = static_cast(moe_params.num_experts * moe_params.inter_size * moe_params.hidden_size); - - prepacked_fc1_weights_ = IAllocator::MakeUniquePtr(weights_allocator_, fc1_weights_size); - prepacked_fc2_weights_ = IAllocator::MakeUniquePtr(weights_allocator_, fc2_weights_size); - - // Store pointers for easy access - prepacked_fc1_weights_data_ = prepacked_fc1_weights_.get(); - prepacked_fc2_weights_data_ = prepacked_fc2_weights_.get(); - - // Helper lambda for dequantizing a single weight value - updated for symmetric quantization - auto DequantizeWeight = [&](const uint8_t* weights, size_t linear_idx, - const float* scales, int64_t scale_idx) -> float { - if (is_4bit) { - // For Int4, two values are packed in each uint8 - size_t packed_idx = linear_idx / 2; - uint8_t packed_value = weights[packed_idx]; - uint8_t quantized_weight = (linear_idx % 2 == 0) ? 
(packed_value & 0x0F) : ((packed_value >> 4) & 0x0F); - // Convert uint4 to int4 with proper mapping for symmetric quantization - int8_t signed_weight = static_cast(quantized_weight); - if (signed_weight >= 8) { - signed_weight -= 16; // Map [8, 15] to [-8, -1] for proper signed representation - } - return static_cast(signed_weight) * scales[scale_idx]; - } else { - // For Int8, convert uint8 to int8 for symmetric quantization - int8_t signed_weight = static_cast(weights[linear_idx]); - return static_cast(signed_weight) * scales[scale_idx]; - } - }; - - // Dequantize FC1 weights for all experts - concurrency::ThreadPool::TryParallelFor( - thread_pool, static_cast(moe_params.num_experts), - static_cast(std::max(1, moe_params.num_experts / num_threads)), - [&](ptrdiff_t expert_start, ptrdiff_t expert_end) { - for (std::ptrdiff_t expert_idx = expert_start; expert_idx < expert_end; ++expert_idx) { - const uint8_t* fc1_expert_weights = fc1_weights_data + static_cast(SafeInt(expert_idx)) * fc1_weight_stride; - const float* fc1_expert_scales = fc1_scales_data + static_cast(SafeInt(expert_idx)) * (is_4bit ? fc1_output_size : moe_params.inter_size * act_multiplier); - float* dequant_fc1_expert = prepacked_fc1_weights_data_ + static_cast(SafeInt(expert_idx)) * moe_params.hidden_size * (is_4bit ? fc1_output_size : moe_params.inter_size * act_multiplier); - - const int64_t output_cols = is_4bit ? fc1_output_size : moe_params.inter_size * act_multiplier; - for (int64_t out_col = 0; out_col < output_cols; ++out_col) { - for (int64_t in_col = 0; in_col < moe_params.hidden_size; ++in_col) { - size_t linear_idx = static_cast(out_col * moe_params.hidden_size + in_col); - dequant_fc1_expert[linear_idx] = DequantizeWeight(fc1_expert_weights, linear_idx, fc1_expert_scales, out_col); - } - } - } - }); - - // Dequantize FC2 weights for all experts - concurrency::ThreadPool::TryParallelFor( - thread_pool, static_cast(moe_params.num_experts), - static_cast(std::max(1, moe_params.num_experts / num_threads)), - [&](ptrdiff_t expert_start, ptrdiff_t expert_end) { - for (std::ptrdiff_t expert_idx = expert_start; expert_idx < expert_end; ++expert_idx) { - const uint8_t* fc2_expert_weights = fc2_weights_data + static_cast(SafeInt(expert_idx)) * fc2_weight_stride; - const float* fc2_expert_scales = fc2_scales_data + static_cast(SafeInt(expert_idx)) * moe_params.hidden_size; - float* dequant_fc2_expert = prepacked_fc2_weights_data_ + static_cast(SafeInt(expert_idx)) * moe_params.inter_size * moe_params.hidden_size; - - for (int64_t out_col = 0; out_col < moe_params.hidden_size; ++out_col) { - for (int64_t in_col = 0; in_col < moe_params.inter_size; ++in_col) { - size_t linear_idx = static_cast(out_col * moe_params.inter_size + in_col); - dequant_fc2_expert[linear_idx] = DequantizeWeight(fc2_expert_weights, linear_idx, fc2_expert_scales, out_col); - } - } - } - }); - - // Update cached parameters - cached_num_experts_ = moe_params.num_experts; - cached_hidden_size_ = moe_params.hidden_size; - cached_inter_size_ = moe_params.inter_size; - cached_is_swiglu_ = is_swiglu; - is_prepacked_ = true; - - return Status::OK(); -} - -// Explicit template instantiations -template Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* 
fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - -template Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - -template Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - -template Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.h deleted file mode 100644 index 19caa86c0fd98..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
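
The DequantizeWeight lambda in the removed implementation above uses symmetric quantization: for 4-bit weights two values share one byte, and each nibble is reinterpreted as a signed value in [-8, 7] before scaling. A standalone sketch of that mapping (function name is illustrative, not part of the codebase):

    #include <cstddef>
    #include <cstdint>

    // Hedged sketch of symmetric int4 dequantization as done in the removed CPU QMoE path.
    float DequantizeInt4Symmetric(const uint8_t* packed, size_t linear_idx, float scale) {
      const uint8_t byte = packed[linear_idx / 2];
      const uint8_t nibble = (linear_idx % 2 == 0) ? (byte & 0x0F) : ((byte >> 4) & 0x0F);
      int8_t q = static_cast<int8_t>(nibble);
      if (q >= 8) q -= 16;  // map [8, 15] to [-8, -1]
      return static_cast<float>(q) * scale;
    }
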
- -#pragma once - -#include "contrib_ops/cpu/moe/moe_base_cpu.h" -#include "core/common/common.h" -#include "core/framework/op_kernel.h" - -namespace onnxruntime { -namespace contrib { - -class QMoE final : public OpKernel, public MoEBaseCPU { - public: - explicit QMoE(const OpKernelInfo& op_kernel_info); - Status Compute(OpKernelContext* ctx) const override; - - private: - template - Status PrepackAndDequantizeWeights(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* fc1_experts_weights, - const Tensor* fc2_experts_weights, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - bool is_swiglu); - - template - Status QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional) const; - - // Prepacked dequantized weights stored for reuse - IAllocatorUniquePtr prepacked_fc1_weights_; - IAllocatorUniquePtr prepacked_fc2_weights_; - float* prepacked_fc1_weights_data_{nullptr}; - float* prepacked_fc2_weights_data_{nullptr}; - - // Persistent allocator for weights - AllocatorPtr weights_allocator_; - - // Cached parameters to detect changes requiring repack - mutable int64_t cached_num_experts_{0}; - mutable int64_t cached_hidden_size_{0}; - mutable int64_t cached_inter_size_{0}; - mutable bool cached_is_swiglu_{false}; - mutable bool is_prepacked_{false}; - - int64_t expert_weight_bits_; -}; - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 2c6a28f1c55f4..ce8c0270f5c32 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -540,7 +540,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ if (normalize_routing_weights && k_idx == k - 1) { #pragma unroll for (int ki = 0; ki < k; ++ki) { - output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + float old_val = static_cast(output[idx - ki]); + output[idx - ki] = T(old_val / output_row_sum); } } } diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 3ca7cee46b22b..5f0c30b16a8f4 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -7,7 +7,7 @@ #include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" #include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h" -#include "contrib_ops/cpu/quantization/moe_helper.h" +#include "contrib_ops/cpu/moe/moe_helper.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index ef72eea76b2d3..ab611a8e5a7c0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -139,6 +139,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_), WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), + 
WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_), WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_), WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_)); } @@ -283,6 +284,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const uint32_t tile_size = 64; bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; + bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); FlashAttentionProgram program{"FlashAttention", has_attention_bias, @@ -290,7 +292,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co is_fp16, parameters.head_size_, parameters.num_heads_, - parameters.is_unidirectional_}; + parameters.is_unidirectional_, + is_nvidia}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}}); @@ -303,7 +306,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile) .SetWorkgroupSize(tile_size) - .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm) + .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(present_sequence_length)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 9839b43ee8a69..c75494df253c1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -40,14 +40,16 @@ class FlashAttentionProgram final : public Program { bool is_fp16, int qkv_head_size, int qkv_num_heads, - bool is_unidirectional) + bool is_unidirectional, + bool is_nvidia) : Program{kernel_name}, has_attention_bias_(has_attention_bias), is_qualcomm_(is_qualcomm), is_fp16_(is_fp16), qkv_head_size_(qkv_head_size), qkv_num_heads_(qkv_num_heads), - is_unidirectional_(is_unidirectional) { + is_unidirectional_(is_unidirectional), + is_nvidia_(is_nvidia) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -67,6 +69,7 @@ class FlashAttentionProgram final : public Program { int qkv_head_size_; int qkv_num_heads_; bool is_unidirectional_; + bool is_nvidia_; }; class FlashAttentionDecodeQKTProgram final : public Program { diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index a8c28388292f9..0674702bd6030 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -3,6 +3,7 @@ #param is_fp16 #param is_qualcomm #param is_unidirectional +#param prefer_subgroupshuffle #param qkv_head_size #param qkv_num_heads @@ -110,6 +111,22 @@ fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> } #endif +fn fetchKTile(k_idx: u32, vec_idx: u32, k_val: q_value_t) -> q_value_t { +#if prefer_subgroupshuffle 
+ return subgroupShuffle(k_val, k_idx); +#else + return k_tile[k_idx][vec_idx]; +#endif +} + +fn fetchVTile(k_idx: u32, vec_idx: u32, v_val: q_value_t) -> q_value_t { +#if prefer_subgroupshuffle + return subgroupShuffle(v_val, k_idx); +#else + return v_tile[k_idx][vec_idx]; +#endif +} + $MAIN { let head_idx = u32(workgroup_idx / uniforms.num_seq_tile); let capped_sg_id = min(sg_id, max_k_step - 1u); @@ -149,47 +166,54 @@ $MAIN { var qk_4 : vec4; if (sg_size > 8) { for (var i : u32 = 0u; i < head_size_vec; i++) { -#if is_qualcomm +#if prefer_subgroupshuffle + #if is_qualcomm var k_local = q_value_t(0); if (sg_id < max_k_step) { k_local = k_tile[sg_id][i]; } -#else + #else var k_local = k_tile[capped_sg_id][i]; + #endif +#else + var k_local = q_value_t(0); #endif var q_own = q_tile[i]; - qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); - qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); - qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); - qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); - qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); - qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); - qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); - qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); - qk_3[0] += dot(q_own, subgroupShuffle(k_local, 8)); - qk_3[1] += dot(q_own, subgroupShuffle(k_local, 9)); - qk_3[2] += dot(q_own, subgroupShuffle(k_local, 10)); - qk_3[3] += dot(q_own, subgroupShuffle(k_local, 11)); - qk_4[0] += dot(q_own, subgroupShuffle(k_local, 12)); - qk_4[1] += dot(q_own, subgroupShuffle(k_local, 13)); - qk_4[2] += dot(q_own, subgroupShuffle(k_local, 14)); - qk_4[3] += dot(q_own, subgroupShuffle(k_local, 15)); + qk_1[0] += dot(q_own, fetchKTile(0, i, k_local)); + qk_1[1] += dot(q_own, fetchKTile(1, i, k_local)); + qk_1[2] += dot(q_own, fetchKTile(2, i, k_local)); + qk_1[3] += dot(q_own, fetchKTile(3, i, k_local)); + qk_2[0] += dot(q_own, fetchKTile(4, i, k_local)); + qk_2[1] += dot(q_own, fetchKTile(5, i, k_local)); + qk_2[2] += dot(q_own, fetchKTile(6, i, k_local)); + qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); + qk_3[0] += dot(q_own, fetchKTile(8, i, k_local)); + qk_3[1] += dot(q_own, fetchKTile(9, i, k_local)); + qk_3[2] += dot(q_own, fetchKTile(10, i, k_local)); + qk_3[3] += dot(q_own, fetchKTile(11, i, k_local)); + qk_4[0] += dot(q_own, fetchKTile(12, i, k_local)); + qk_4[1] += dot(q_own, fetchKTile(13, i, k_local)); + qk_4[2] += dot(q_own, fetchKTile(14, i, k_local)); + qk_4[3] += dot(q_own, fetchKTile(15, i, k_local)); } } else { for (var i : u32 = 0u; i < head_size_vec; i++) { +#if prefer_subgroupshuffle var k_local = k_tile[capped_sg_id][i]; +#else + var k_local = q_value_t(0); +#endif var q_own = q_tile[i]; - qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); - qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); - qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); - qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); - qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); - qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); - qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); - qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); + qk_1[0] += dot(q_own, fetchKTile(0, i, k_local)); + qk_1[1] += dot(q_own, fetchKTile(1, i, k_local)); + qk_1[2] += dot(q_own, fetchKTile(2, i, k_local)); + qk_1[3] += dot(q_own, fetchKTile(3, i, k_local)); + qk_2[0] += dot(q_own, fetchKTile(4, i, k_local)); + qk_2[1] += dot(q_own, fetchKTile(5, i, k_local)); + qk_2[2] += dot(q_own, fetchKTile(6, i, k_local)); + qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); } } - 
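
The new prefer_subgroupshuffle parameter is set to !is_nvidia: on most adapters the K/V values held in a lane's registers are broadcast across the subgroup with subgroupShuffle, while on NVIDIA the shader reads the workgroup-shared k_tile/v_tile directly. A scalar C++ model of the two fetch strategies (conceptual only, not the WGSL; sizes and names are illustrative):

    #include <array>

    constexpr int kSubgroupSize = 16;
    constexpr int kHeadSizeVec = 64;  // illustrative head_size_vec

    // Models fetchKTile: either take lane k_idx's register value (the shuffle path)
    // or read the shared-memory tile written earlier by that lane.
    float FetchKModel(bool prefer_subgroupshuffle, int k_idx, int vec_idx,
                      const std::array<float, kSubgroupSize>& k_val_per_lane,      // lane registers
                      const float (&k_tile)[kSubgroupSize][kHeadSizeVec]) {        // shared tile
      return prefer_subgroupshuffle ? k_val_per_lane[k_idx]   // ~ subgroupShuffle(k_val, k_idx)
                                    : k_tile[k_idx][vec_idx];
    }
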
qk_1 = qk_1 + loadAttentionBias(q_idx_global, k_start, head_idx); qk_2 = qk_2 + loadAttentionBias(q_idx_global, k_start + 4, head_idx); if (sg_size > 8) { @@ -326,36 +350,44 @@ $MAIN { #else if (sg_size > 8) { for (var i : u32 = 0; i < head_size_vec; i++) { + #if prefer_subgroupshuffle var val = v_tile[capped_sg_id][i]; - var sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; - sum += subgroupShuffle(val, 8) * qk_3[0]; - sum += subgroupShuffle(val, 9) * qk_3[1]; - sum += subgroupShuffle(val, 10) * qk_3[2]; - sum += subgroupShuffle(val, 11) * qk_3[3]; - sum += subgroupShuffle(val, 12) * qk_4[0]; - sum += subgroupShuffle(val, 13) * qk_4[1]; - sum += subgroupShuffle(val, 14) * qk_4[2]; - sum += subgroupShuffle(val, 15) * qk_4[3]; + #else + var val = q_value_t(0); + #endif + var sum = fetchVTile(0, i, val) * qk_1[0]; + sum += fetchVTile(1, i, val) * qk_1[1]; + sum += fetchVTile(2, i, val) * qk_1[2]; + sum += fetchVTile(3, i, val) * qk_1[3]; + sum += fetchVTile(4, i, val) * qk_2[0]; + sum += fetchVTile(5, i, val) * qk_2[1]; + sum += fetchVTile(6, i, val) * qk_2[2]; + sum += fetchVTile(7, i, val) * qk_2[3]; + sum += fetchVTile(8, i, val) * qk_3[0]; + sum += fetchVTile(9, i, val) * qk_3[1]; + sum += fetchVTile(10, i, val) * qk_3[2]; + sum += fetchVTile(11, i, val) * qk_3[3]; + sum += fetchVTile(12, i, val) * qk_4[0]; + sum += fetchVTile(13, i, val) * qk_4[1]; + sum += fetchVTile(14, i, val) * qk_4[2]; + sum += fetchVTile(15, i, val) * qk_4[3]; o_tile[i] = o_tile[i] * o_ratio + sum; } } else { for (var i : u32 = 0; i < head_size_vec; i++) { + #if prefer_subgroupshuffle var val = v_tile[capped_sg_id][i]; - var sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; + #else + var val = q_value_t(0); + #endif + var sum = fetchVTile(0, i, val) * qk_1[0]; + sum += fetchVTile(1, i, val) * qk_1[1]; + sum += fetchVTile(2, i, val) * qk_1[2]; + sum += fetchVTile(3, i, val) * qk_1[3]; + sum += fetchVTile(4, i, val) * qk_2[0]; + sum += fetchVTile(5, i, val) * qk_2[1]; + sum += fetchVTile(6, i, val) * qk_2[2]; + sum += fetchVTile(7, i, val) * qk_2[3]; o_tile[i] = o_tile[i] * o_ratio + sum; } } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template index e4e3730eba808..ee6dde3788157 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template @@ -4,6 +4,7 @@ #param block_size #param n_bits #param has_zero_points +#param is_qualcomm #include "quantization/dp4a_matmul_common.wgsl.template" @@ -108,7 +109,26 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) } #endif +#if n_bits == 2 + fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let b_global = b_global_base + row; + if (b_global >= uniforms.N) + { + return; + } + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + tile_B[col][row] = 
DequantizedFrom2BitsTo8Bits(b_value); + let block_idx = kidx_v/(block_size/16); + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx]; + } +#endif + $MAIN { +#if n_bits == 2 + LoadDequantizationTable(local_idx); + workgroupBarrier(); +#endif // During the load phase we use all 256 threads to load 64 rows of A/B. // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. let a_global_base = u32(workgroup_idx / uniforms.num_N_tile) * tile_size; @@ -119,19 +139,36 @@ $MAIN { // During the compute phase, we have the 64x64 tile split into // subtiles of 16x16. We have a grid of 4x4 subtiles. - let subtile_id = u32(local_idx / subtile_size); - let subtile_idx = u32(subtile_id / 4); - let subtile_idy = u32(subtile_id % 4); - let base_A = subtile_idx * 16; - let base_B = subtile_idy * 16; + var subtile_id = u32(local_idx / subtile_size); + var subtile_idx = u32(subtile_id / 4); + var subtile_idy = u32(subtile_id % 4); + var base_A = subtile_idx * 16; + var base_B = subtile_idy * 16; // For each subtile we have 16 threads assigned. - let a_idx = u32(local_idx % subtile_size); + var a_idx = u32(local_idx % subtile_size); +#if is_qualcomm + // subtile_idx is always 0 + // subtile_idy is one of {0,1,2,3} + // The subtile is now rectangular 64x16 for qualcomm case and we have 4 subtiles, this way we don't need to + // increase the number of lane_output each thread needs to track. That is if we want to use a subtile that is 64x64 + // we would need var lane_outputs: array; + if (sg_size == 64) { + subtile_id = u32(local_idx / sg_size); + subtile_idx = u32(subtile_id / 4); + subtile_idy = u32(subtile_id % 4); + base_A = subtile_idx * sg_size; + base_B = subtile_idy * 16; + a_idx = sg_id; + } + var lane_outputs: array; +#else var lane_output1: vec4; var lane_output2: vec4; var lane_output3: vec4; var lane_output4: vec4; - // K's vectrorization is 16 items per index. See input_a/input_b. +#endif + // K's vectorization is 16 items per index. See input_a/input_b. // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is // k tile size is 32. In vectorized space that is 32/16 = 2. for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) @@ -154,6 +191,34 @@ $MAIN { var own_scale_a: output_element_t = scale_A[base_A + a_idx]; #if has_zero_points && n_bits == 8 + #if is_qualcomm + if (sg_size == 64) + { + var own_b0: vec4; + var own_b1: vec4; + var own_scale_b: output_element_t; + var zero: i32; + if (sg_id < 16) + { + own_b0 = tile_B[0][base_B + sg_id]; + own_b1 = tile_B[1][base_B + sg_id]; + own_scale_b = scale_B[base_B + sg_id]; + zero = zeroes[base_B + sg_id]; + } + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. 
+ for (var i = 0u; i < 16u; i++) + { + lane_outputs[i] += SDP8AI(own_a0, subgroupShuffle(own_b0, i), own_a1, subgroupShuffle(own_b1, i), subgroupShuffle(own_scale_b, i) * own_scale_a, subgroupShuffle(zero, i)); + } + } + else + { + for (var i = 0u; i < 16u; i++) + { + lane_outputs[i] += SDP8AI(own_a0, tile_B[0][base_B + i], own_a1, tile_B[1][base_B + i], own_scale_a * scale_B[base_B + i], zeroes[base_B + i]); + } + } + #else if (sg_size == 16) { var own_b0: vec4 = tile_B[0][base_B + sg_id]; @@ -206,7 +271,34 @@ $MAIN { lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14], zeroes[base_B + 14]); lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15], zeroes[base_B + 15]); } + #endif #else + #if is_qualcomm + if (sg_size == 64) + { + var own_b0: vec4; + var own_b1: vec4; + var own_scale_b: output_element_t; + if (sg_id < 16) + { + own_b0 = tile_B[0][base_B + sg_id]; + own_b1 = tile_B[1][base_B + sg_id]; + own_scale_b = scale_B[base_B + sg_id]; + } + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + for (var i = 0u; i < 16u; i++) + { + lane_outputs[i] += SDP8AI(own_a0, subgroupShuffle(own_b0, i), own_a1, subgroupShuffle(own_b1, i), subgroupShuffle(own_scale_b, i) * own_scale_a); + } + } + else + { + for (var i = 0u; i < 16u; i++) + { + lane_outputs[i] += SDP8AI(own_a0, tile_B[0][base_B + i], own_a1, tile_B[1][base_B + i], own_scale_a * scale_B[base_B + i]); + } + } + #else if (sg_size == 16) { var own_b0: vec4 = tile_B[0][base_B + sg_id]; @@ -258,6 +350,7 @@ $MAIN { lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); } + #endif #endif workgroupBarrier(); } @@ -268,9 +361,16 @@ $MAIN { // This creates a shader requirement that uniforms.N % 16 == 0 if (a_global < uniforms.M && b_global < uniforms.N) { +#if is_qualcomm + output[output_idx] = vec4(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]); + output[output_idx+1] = vec4(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]); + output[output_idx+2] = vec4(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]); + output[output_idx+3] = vec4(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]); +#else output[output_idx] = lane_output1; output[output_idx+1] = lane_output2; output[output_idx+2] = lane_output3; output[output_idx+3] = lane_output4; +#endif } } // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template index 38fe0388c5954..186685b9a8dd4 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template @@ -7,6 +7,7 @@ #include "quantization/matmul_nbits_zero_pt.wgsl.template" #if n_bits == 4 + alias mul_precision = output_element_t; fn DequantizedFrom4BitsTo8Bits(in: vec2, zero: i32) -> vec4 { var out = vec4(0); @@ -23,6 +24,9 @@ #endif #if n_bits == 8 + // For 8bits, in case data overflow when converting from int32 (output of dot4I8Packed) to f16, we force it convert to f32. + // Then do the scale. Finally, convert to output element type. 
+ alias mul_precision = f32; fn AlignWithZeroPoint(in: vec4) -> vec4 { var out = vec4(0); @@ -34,8 +38,282 @@ } #endif -// For 8bits, in case data overflow when converting from int32 (output of dot4I8Packed) to f16, we force it convert to f32. -// Then do the scale. Finally, convert to output element type. +#if n_bits == 2 + alias mul_precision = output_element_t; + const lut_size = 256; + var shm_dequantization_table : array; + const q2_dequantization_table = array( + 0xFEFEFEFE, + 0xFEFEFEFF, + 0xFEFEFE00, + 0xFEFEFE01, + 0xFEFEFFFE, + 0xFEFEFFFF, + 0xFEFEFF00, + 0xFEFEFF01, + 0xFEFE00FE, + 0xFEFE00FF, + 0xFEFE0000, + 0xFEFE0001, + 0xFEFE01FE, + 0xFEFE01FF, + 0xFEFE0100, + 0xFEFE0101, + 0xFEFFFEFE, + 0xFEFFFEFF, + 0xFEFFFE00, + 0xFEFFFE01, + 0xFEFFFFFE, + 0xFEFFFFFF, + 0xFEFFFF00, + 0xFEFFFF01, + 0xFEFF00FE, + 0xFEFF00FF, + 0xFEFF0000, + 0xFEFF0001, + 0xFEFF01FE, + 0xFEFF01FF, + 0xFEFF0100, + 0xFEFF0101, + 0xFE00FEFE, + 0xFE00FEFF, + 0xFE00FE00, + 0xFE00FE01, + 0xFE00FFFE, + 0xFE00FFFF, + 0xFE00FF00, + 0xFE00FF01, + 0xFE0000FE, + 0xFE0000FF, + 0xFE000000, + 0xFE000001, + 0xFE0001FE, + 0xFE0001FF, + 0xFE000100, + 0xFE000101, + 0xFE01FEFE, + 0xFE01FEFF, + 0xFE01FE00, + 0xFE01FE01, + 0xFE01FFFE, + 0xFE01FFFF, + 0xFE01FF00, + 0xFE01FF01, + 0xFE0100FE, + 0xFE0100FF, + 0xFE010000, + 0xFE010001, + 0xFE0101FE, + 0xFE0101FF, + 0xFE010100, + 0xFE010101, + 0xFFFEFEFE, + 0xFFFEFEFF, + 0xFFFEFE00, + 0xFFFEFE01, + 0xFFFEFFFE, + 0xFFFEFFFF, + 0xFFFEFF00, + 0xFFFEFF01, + 0xFFFE00FE, + 0xFFFE00FF, + 0xFFFE0000, + 0xFFFE0001, + 0xFFFE01FE, + 0xFFFE01FF, + 0xFFFE0100, + 0xFFFE0101, + 0xFFFFFEFE, + 0xFFFFFEFF, + 0xFFFFFE00, + 0xFFFFFE01, + 0xFFFFFFFE, + 0xFFFFFFFF, + 0xFFFFFF00, + 0xFFFFFF01, + 0xFFFF00FE, + 0xFFFF00FF, + 0xFFFF0000, + 0xFFFF0001, + 0xFFFF01FE, + 0xFFFF01FF, + 0xFFFF0100, + 0xFFFF0101, + 0xFF00FEFE, + 0xFF00FEFF, + 0xFF00FE00, + 0xFF00FE01, + 0xFF00FFFE, + 0xFF00FFFF, + 0xFF00FF00, + 0xFF00FF01, + 0xFF0000FE, + 0xFF0000FF, + 0xFF000000, + 0xFF000001, + 0xFF0001FE, + 0xFF0001FF, + 0xFF000100, + 0xFF000101, + 0xFF01FEFE, + 0xFF01FEFF, + 0xFF01FE00, + 0xFF01FE01, + 0xFF01FFFE, + 0xFF01FFFF, + 0xFF01FF00, + 0xFF01FF01, + 0xFF0100FE, + 0xFF0100FF, + 0xFF010000, + 0xFF010001, + 0xFF0101FE, + 0xFF0101FF, + 0xFF010100, + 0xFF010101, + 0x00FEFEFE, + 0x00FEFEFF, + 0x00FEFE00, + 0x00FEFE01, + 0x00FEFFFE, + 0x00FEFFFF, + 0x00FEFF00, + 0x00FEFF01, + 0x00FE00FE, + 0x00FE00FF, + 0x00FE0000, + 0x00FE0001, + 0x00FE01FE, + 0x00FE01FF, + 0x00FE0100, + 0x00FE0101, + 0x00FFFEFE, + 0x00FFFEFF, + 0x00FFFE00, + 0x00FFFE01, + 0x00FFFFFE, + 0x00FFFFFF, + 0x00FFFF00, + 0x00FFFF01, + 0x00FF00FE, + 0x00FF00FF, + 0x00FF0000, + 0x00FF0001, + 0x00FF01FE, + 0x00FF01FF, + 0x00FF0100, + 0x00FF0101, + 0x0000FEFE, + 0x0000FEFF, + 0x0000FE00, + 0x0000FE01, + 0x0000FFFE, + 0x0000FFFF, + 0x0000FF00, + 0x0000FF01, + 0x000000FE, + 0x000000FF, + 0x00000000, + 0x00000001, + 0x000001FE, + 0x000001FF, + 0x00000100, + 0x00000101, + 0x0001FEFE, + 0x0001FEFF, + 0x0001FE00, + 0x0001FE01, + 0x0001FFFE, + 0x0001FFFF, + 0x0001FF00, + 0x0001FF01, + 0x000100FE, + 0x000100FF, + 0x00010000, + 0x00010001, + 0x000101FE, + 0x000101FF, + 0x00010100, + 0x00010101, + 0x01FEFEFE, + 0x01FEFEFF, + 0x01FEFE00, + 0x01FEFE01, + 0x01FEFFFE, + 0x01FEFFFF, + 0x01FEFF00, + 0x01FEFF01, + 0x01FE00FE, + 0x01FE00FF, + 0x01FE0000, + 0x01FE0001, + 0x01FE01FE, + 0x01FE01FF, + 0x01FE0100, + 0x01FE0101, + 0x01FFFEFE, + 0x01FFFEFF, + 0x01FFFE00, + 0x01FFFE01, + 0x01FFFFFE, + 0x01FFFFFF, + 0x01FFFF00, + 0x01FFFF01, + 0x01FF00FE, + 0x01FF00FF, + 0x01FF0000, + 0x01FF0001, + 
0x01FF01FE, + 0x01FF01FF, + 0x01FF0100, + 0x01FF0101, + 0x0100FEFE, + 0x0100FEFF, + 0x0100FE00, + 0x0100FE01, + 0x0100FFFE, + 0x0100FFFF, + 0x0100FF00, + 0x0100FF01, + 0x010000FE, + 0x010000FF, + 0x01000000, + 0x01000001, + 0x010001FE, + 0x010001FF, + 0x01000100, + 0x01000101, + 0x0101FEFE, + 0x0101FEFF, + 0x0101FE00, + 0x0101FE01, + 0x0101FFFE, + 0x0101FFFF, + 0x0101FF00, + 0x0101FF01, + 0x010100FE, + 0x010100FF, + 0x01010000, + 0x01010001, + 0x010101FE, + 0x010101FF, + 0x01010100, + 0x01010101); + fn LoadDequantizationTable(local_idx:u32) + { + // Move dequantization table into on chip memory. + shm_dequantization_table[local_idx] = q2_dequantization_table[local_idx]; + } + fn DequantizedFrom2BitsTo8Bits(in: u32) -> vec4 + { + let unpacked = unpack4xU8(in); + return vec4(shm_dequantization_table[unpacked[0]], + shm_dequantization_table[unpacked[1]], + shm_dequantization_table[unpacked[2]], + shm_dequantization_table[unpacked[3]]); + } +#endif + #if has_zero_points && n_bits == 8 // If has_zero_points is true, vec4(unpack4xU8(b_data)) - vec4(zero) may be out of the range [-128, 127] since zero can be any value between [0, 255]. // To avoid the data overflow when use pack4xI8, we still use |pack4xI8(vec4(unpack4xU8(xxx)) - vec4(128))| to process the b data. In SDP8AI, we use the @@ -61,7 +339,7 @@ local_sum += dot4I8Packed(a2[3], b2[3]); dequantized_a_sum += vec4(unpack4xI8(a2[3])); local_sum -= dot(dequantized_a_sum, vec4(bias_zero)); - return output_element_t(f32(local_sum) * f32(scale)); + return output_element_t(mul_precision(local_sum) * mul_precision(scale)); } #else // Scaled dot product of 8 packed unsigned integers. @@ -75,6 +353,6 @@ local_sum += dot4I8Packed(a2[1], b2[1]); local_sum += dot4I8Packed(a2[2], b2[2]); local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(f32(local_sum) * f32(scale)); + return output_element_t(mul_precision(local_sum) * mul_precision(scale)); } #endif diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 6d0107a779f33..84954946fa6be 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -28,6 +28,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template", WGSL_TEMPLATE_PARAMETER(block_size, block_size_), WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), + WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), WGSL_TEMPLATE_PARAMETER(output_type_i32, true)); } @@ -51,6 +52,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), WGSL_TEMPLATE_PARAMETER(output_type_i32, true), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec_)); @@ -86,6 +88,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor .AddUniformVariable({M * K / kU32Components}); ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); const bool has_zero_points = zero_points != nullptr; + const bool single_scale_weights = (block_size == K * N); if (M < min_M_for_tile_optimization) { uint32_t 
tile_size_k_vec = 16; uint32_t tile_size_n = 32; @@ -94,18 +97,18 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor tile_size_k_vec = 32; tile_size_n = 4; } - - DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points}; + const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components); + DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, single_scale_weights}; uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n; mul_program.SetWorkgroupSize(128); mul_program.SetDispatchGroupSize(M * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components * kU32Components)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(b_components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}) - .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points); + .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } @@ -116,12 +119,13 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor TensorShape reshaped_y_shape{1, M, N / kVec4Components}; uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize; uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; - DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points}; + bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; + DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, is_qualcomm}; mul_program.SetWorkgroupSize(256); mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? 
kVec2Components * kU32Components : kVec4Components * kU32Components)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast((nbits / 2) * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({{static_cast(M)}, {static_cast(N)}, @@ -131,7 +135,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {num_N_tile}, {zero_blocks_per_col}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(kVec4Components)}) - .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points); + .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 5c07885011aac..b00392cbb291e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -21,10 +21,11 @@ class DP4AMatMulQuantizeProgram final : public Program { public: - DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points) : Program{"DP4AMatMulNBits"}, - block_size_(block_size), - nbits_(nbits), - has_zero_points_(has_zero_points) {} + DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points, bool is_qualcomm) : Program{"DP4AMatMulNBits"}, + block_size_(block_size), + nbits_(nbits), + has_zero_points_(has_zero_points), + is_qualcomm_(is_qualcomm) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -39,15 +40,17 @@ class DP4AMatMulNBitsProgram final : public Program { uint32_t block_size_; uint32_t nbits_; bool has_zero_points_; + bool is_qualcomm_; }; class DP4AMatMulNBitsSmallMProgram final : public Program { public: - DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits, bool has_zero_points) : Program{"DP4AMatMulNBitsSmallMProgram"}, - tile_size_k_vec_(tile_size_k_vec), - tile_size_(tile_size), - nbits_(nbits), - has_zero_points_(has_zero_points) {} + DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool single_scale_weights) : Program{"DP4AMatMulNBitsSmallMProgram"}, + tile_size_k_vec_(tile_size_k_vec), + tile_size_(tile_size), + nbits_(nbits), + has_zero_points_(has_zero_points), + single_scale_weights_(single_scale_weights) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -64,6 +67,7 @@ class DP4AMatMulNBitsSmallMProgram final : public Program(bits_); if (has_zero_points) { ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType()); + ORT_ENFORCE(nbits != 2, "Currently, zero points are not supported for Q2 quantization."); } MatMulComputeHelper helper; @@ -124,10 +127,12 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const uint32_t N = onnxruntime::narrow(helper.N()); const uint32_t K = onnxruntime::narrow(helper.K()); const uint32_t block_size = onnxruntime::narrow(block_size_); - const uint32_t nbits = onnxruntime::narrow(bits_); - 
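
For the new 2-bit path, the 256-entry q2_dequantization_table added to dp4a_matmul_common above maps one packed byte of four 2-bit weights to four signed int8 lanes (0 → -2, 1 → -1, 2 → 0, 3 → +1), so DequantizedFrom2BitsTo8Bits becomes a single table lookup per byte. A hedged offline generator for such a table, assuming unpack4xU8-style lane order (least-significant byte is lane 0):

    #include <cstdint>
    #include <cstdio>

    // Emits 256 u32 entries: byte b holds four 2-bit fields (LSB first); each field q
    // is remapped to q - 2 and stored as a signed byte in the corresponding lane.
    int main() {
      for (uint32_t b = 0; b < 256; ++b) {
        uint32_t entry = 0;
        for (int lane = 0; lane < 4; ++lane) {
          const int32_t q = static_cast<int32_t>((b >> (2 * lane)) & 0x3) - 2;  // in [-2, 1]
          entry |= (static_cast<uint32_t>(q) & 0xFFu) << (8 * lane);
        }
        std::printf("0x%08X,%s", static_cast<unsigned>(entry), (b % 4 == 3) ? "\n" : " ");
      }
      return 0;
    }

Spot-checking against the listed table: b = 0 gives 0xFEFEFEFE, b = 1 gives 0xFEFEFEFF, and b = 255 gives 0x01010101, matching the first, second, and last entries.
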
const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; - const uint32_t blob_size = (block_size / 8) * nbits; + // Special case matrix used by bitnets where there is a single scale for the entire matrix + const bool single_scale_weights = (block_size == K * N); + const uint32_t block_size_per_col = single_scale_weights ? K : block_size; + const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; + const uint32_t blob_size = (block_size_per_col / 8) * nbits; const uint32_t blob_size_in_words = blob_size / 4; const uint32_t components_a = GetMaxComponents(K); const uint32_t components_b = GetMaxComponents(blob_size_in_words); @@ -157,6 +162,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && + nbits != 2 && M >= kMinMForTileOptimization; if (use_wide_tile_program) { // Enforce output components to 1. @@ -212,7 +218,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t kU32Components = 4; uint32_t components_b_with_u32 = components_b * kU32Components; uint32_t num_N_tile = (N + tile_size - 1) / tile_size; - MatMulNBitsProgram program{tile_size, nbits, has_zero_points}; + uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; + MatMulNBitsProgram program{tile_size, nbits, has_zero_points, single_scale_weights}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count); program @@ -220,8 +227,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, {scales, ProgramTensorMetadataDependency::TypeAndRank}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) - .AddUniformVariables({{M}, {N}, {K}, {K / components_a}, {n_blocks_per_col * blob_size / components_b_with_u32}, {block_size}, {n_blocks_per_col}, {zero_blocks_per_col}, {num_N_tile}, {batch_count}}) - .CacheHint(nbits, has_zero_points); + .AddUniformVariables({{M}, {N}, {K}, {K / components_a}, {K_of_b}, {block_size}, {n_blocks_per_col}, {zero_blocks_per_col}, {num_N_tile}, {batch_count}}) + .CacheHint(nbits, has_zero_points, single_scale_weights); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index aabc73ca05d03..1f7bd16d9cb6f 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -37,7 +37,7 @@ class MatMulNBitsWideTileProgram final : public Program { public: - MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points) : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points) {} + MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool single_scale_weights) : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), single_scale_weights_(single_scale_weights) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -55,6 +55,7 @@ class MatMulNBitsProgram final : public Program { uint32_t tile_size_;
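A note on the 2-bit path introduced above before the header changes continue. The single_scale_weights special case covers BitNet-style models: when block_size == K * N there is exactly one scale for the whole weight matrix, so the kernel treats the full K dimension as one block. With nbits = 2, each packed uint32 holds 16 weights (four 2-bit fields per byte), so a block_size of 32 gives blob_size = (32 / 8) * 2 = 8 bytes, i.e. two packed words per block. The standalone sketch below is not part of the patch (the names are illustrative); it dequantizes one such word the way the WGSL template does, using the default zero point of 2 since the kernel currently rejects explicit zero points for Q2:

// Illustrative sketch (not the ORT kernel): dequantize one packed uint32_t of
// 2-bit weights, 16 weights per word, four 2-bit fields per byte starting
// from the least-significant bits, matching the unpack4xU8 masking above.
#include <array>
#include <cstdint>

std::array<float, 16> DequantizeQ2Word(uint32_t word, float scale, int zero_point = 2) {
  std::array<float, 16> out{};
  for (int byte = 0; byte < 4; ++byte) {        // mirrors unpack4xU8 over the 4 bytes
    for (int field = 0; field < 4; ++field) {   // masks 0x03, then >>2, >>4, >>6
      const uint32_t q = (word >> (8 * byte + 2 * field)) & 0x3u;
      out[4 * byte + field] = (static_cast<int>(q) - zero_point) * scale;
    }
  }
  return out;
}

Byte 0's four 2-bit fields come first, then byte 1's, and so on, which is the same ordering the shader uses when it dots the unpacked values against consecutive chunks of tile_A.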
uint32_t nbits_; bool has_zero_points_; + bool single_scale_weights_; }; class MatMulNBits final : public WebGpuKernel { @@ -65,8 +66,8 @@ class MatMulNBits final : public WebGpuKernel { block_size_ = info.GetAttr("block_size"); bits_ = info.GetAttr("bits"); accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); - ORT_ENFORCE(bits_ == 4 || bits_ == 8, - "Only 4b/8b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(bits_ == 4 || bits_ == 8 || bits_ == 2, + "Only 4b/8b/2b quantization is supported for MatMulNBits op, additional bits support is planned."); } Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template index 5fd0a9b3dd788..aba6e3d57c72a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template @@ -6,6 +6,7 @@ #param component_b #param elements_in_value_b #param n_bits +#param single_scale_weights #param sub_tile_count #param tile_size_k_vec #param tile_size_k @@ -35,6 +36,12 @@ $MAIN { let idx = local_idx % tile_size_k_vec; let idy = local_idx / tile_size_k_vec; +#if single_scale_weights + let block_idx = 0; + let scale_b = scales_b[0]; + let zero = mm_read_zero(0, 0, uniforms.N, uniforms.zero_blocks_per_col); +#endif + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { for (var id = local_idx; id < a_length_per_tile; id += workgroup_size_x) @@ -49,9 +56,11 @@ $MAIN { var k_offset = kidx / elements_in_value_b + idx; if (b_global < uniforms.N && k_offset < uniforms.K_of_b) { +#if !single_scale_weights let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; let scale_b = scales_b[b_global * uniforms.blocks_per_col + block_idx]; let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); +#endif var b_value = input_b[b_global * uniforms.K_of_b + k_offset]; #if n_bits == 4 @@ -91,6 +100,38 @@ $MAIN { a_offset += 1; #endif } +#elif n_bits == 2 + var sum = output_element_t(0); + var a_offset = idx * (16 / component_a) * component_b; + for (var i = 0u; i < component_b; i++) { + let b_data_0 = vec4(unpack4xU8(b_value[i] & 0x03030303u)) - vec4(zero); + let b_data_1 = vec4(unpack4xU8((b_value[i] >> 2) & 0x03030303u)) - vec4(zero); + let b_data_2 = vec4(unpack4xU8((b_value[i] >> 4) & 0x03030303u)) - vec4(zero); + let b_data_3 = vec4(unpack4xU8((b_value[i] >> 6) & 0x03030303u)) - vec4(zero); + + let b0 = vec4(b_data_0[0], b_data_1[0], b_data_2[0], b_data_3[0]) * scale_b; + let b1 = vec4(b_data_0[1], b_data_1[1], b_data_2[1], b_data_3[1]) * scale_b; + let b2 = vec4(b_data_0[2], b_data_1[2], b_data_2[2], b_data_3[2]) * scale_b; + let b3 = vec4(b_data_0[3], b_data_1[3], b_data_2[3], b_data_3[3]) * scale_b; + +#if component_a == 1 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1) + + dot(vec4(tile_A[a_offset + 8], tile_A[a_offset + 9], tile_A[a_offset + 10], tile_A[a_offset + 11]), b2) + + dot(vec4(tile_A[a_offset + 12], tile_A[a_offset + 13], tile_A[a_offset + 14], tile_A[a_offset + 15]), b3); + a_offset += 16; +#elif component_a == 2 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b0) + + dot(vec4(tile_A[a_offset + 2], 
tile_A[a_offset + 3]), b1) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5]), b2) + + dot(vec4(tile_A[a_offset + 6], tile_A[a_offset + 7]), b3); + a_offset += 8; +#elif component_a == 4 + sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1) + + dot(tile_A[a_offset + 2], b2) + dot(tile_A[a_offset + 3], b3); + a_offset += 4; +#endif + } #endif inter_results[local_row_offset + idy][idx] += sum; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template index 0da5bd09609af..9135708adf153 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template @@ -17,6 +17,8 @@ #elif n_bits == 8 const default_zero_point = 128; const bit_mask = 0xFFu; +#elif n_bits == 2 + const default_zero_point = 2; #endif #if has_zero_points diff --git a/onnxruntime/core/common/safeint.h b/onnxruntime/core/common/safeint.h index 3ee70f369b65d..6aba5871ac62e 100644 --- a/onnxruntime/core/common/safeint.h +++ b/onnxruntime/core/common/safeint.h @@ -13,11 +13,11 @@ class SafeIntExceptionHandler; template <> class SafeIntExceptionHandler { public: - static void SafeIntOnOverflow() { + [[noreturn]] static void SafeIntOnOverflow() { ORT_THROW("Integer overflow"); } - static void SafeIntOnDivZero() { + [[noreturn]] static void SafeIntOnDivZero() { ORT_THROW("Divide by zero"); } }; diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index 971851db62437..d7f5b23d56c70 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -107,7 +107,7 @@ void ExternalDataInfo::SetExternalLocationToProto(const std::filesystem::path& e std::ostream& ExternalDataInfo::WritePrepackedToFileAndAddToProto( const PrepackedWeightsForGraph& prepacked_for_graph, const InlinedHashSet& blob_keys, bool align, - int64_t align_threshold, int64_t allocation_granularity, + int64_t align_threshold, int64_t on_disk_alignment, std::ostream& os, int64_t& external_offset, ::ONNX_NAMESPACE::TensorProto& proto) { size_t key_count = 0; for (const auto& key : blob_keys) { @@ -120,7 +120,7 @@ std::ostream& ExternalDataInfo::WritePrepackedToFileAndAddToProto( const auto size_in_bytes = prepacked_weights->buffer_sizes_[i]; if (align && static_cast(size_in_bytes) > align_threshold) { // return early on error - if (!AlignAndPad(os, allocation_granularity, external_offset)) { + if (!AlignAndPad(os, on_disk_alignment, external_offset)) { return os; } } diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index 2de1e01f381ec..784b3f352a78e 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -41,15 +41,13 @@ class ExternalDataInfo { size_t tensor_bytes_size, ::ONNX_NAMESPACE::TensorProto& proto); - // Pads the output with zeros according to the specified allocation_granularity + // Pads the output with zeros according to the specified alignment_factor // It updates external_offset for alignment. 
// need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and allocation granularity aligned) like below: + // large tensors (offset need to be page aligned) like below: // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX // |<---smaller tensor---->|<---padding--->|<------------------large tensor----------------------------->| - static std::ostream& AlignAndPad(std::ostream& stream, int64_t allocation_granularity, int64_t& external_offset) { - // Align to the larger of the page size or the allocation granularity - int64_t alignment_factor = std::max(static_cast(4096), allocation_granularity); + static std::ostream& AlignAndPad(std::ostream& stream, int64_t alignment_factor, int64_t& external_offset) { // Align to the next page or alloc granularity boundary SafeInt safe_external_offset = external_offset; int64_t new_external_offset = ((safe_external_offset + alignment_factor - 1) / alignment_factor) * @@ -66,7 +64,7 @@ class ExternalDataInfo { static std::ostream& WritePrepackedToFileAndAddToProto( const PrepackedWeightsForGraph& prepacked_for_graph, const InlinedHashSet& blob_keys, - bool align, int64_t align_threshold, int64_t allocation_granularity, + bool align, int64_t align_threshold, int64_t on_disk_alignment, std::ostream& os, int64_t& external_offset, ::ONNX_NAMESPACE::TensorProto& proto); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e4f8cd6df678e..0a228176175eb 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4536,14 +4536,14 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( continue; } - // update external_offset for alignment + // update external_offset for alignment (if enabled) // need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and allocation granularity aligned) like below: + // large tensors (offset need to be page aligned) like below: // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX // |<---smaller tensor---->|<---padding--->|<------------------large tensor----------------------------->| if (model_saving_options.align_offset && static_cast(tensor_bytes_size) > model_saving_options.align_threshold) { - ORT_RETURN_IF_NOT(ExternalDataInfo::AlignAndPad(external_stream, model_saving_options.allocation_granularity, + ORT_RETURN_IF_NOT(ExternalDataInfo::AlignAndPad(external_stream, model_saving_options.on_disk_alignment, external_offset), "Failed writing external data to: ", model_external_file_path); } @@ -4576,7 +4576,7 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( auto& os = ExternalDataInfo::WritePrepackedToFileAndAddToProto( *prepacked_weights_for_graph_, blob_keys_to_external_data, model_saving_options.align_offset, model_saving_options.align_threshold, - model_saving_options.allocation_granularity, + model_saving_options.on_disk_alignment, external_stream, external_offset, *output_proto); ORT_RETURN_IF_NOT(os.good(), "Failed to write pre-packed blobs to external file"); } diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 36c6b54a1fce0..aa237fc6441b2 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -29,6 +29,7 @@ limitations under the License. 
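Before the Windows memory-mapping changes continue, a note on the alignment helper above. AlignAndPad now receives the on-disk alignment directly from the caller instead of deriving it internally from the larger of the page size and the allocation granularity, and it pads the stream with zeros so the next large tensor starts at an aligned offset; for example, an external_offset of 5000 with an alignment of 4096 is padded up to 8192. A minimal standalone sketch of the same arithmetic (illustrative names, not the ORT API; the real helper also guards the addition with SafeInt):

// Pad the stream with zeros until external_offset is a multiple of
// alignment_factor, then report the new offset back to the caller.
#include <cstdint>
#include <ostream>

bool PadToAlignment(std::ostream& os, int64_t alignment_factor, int64_t& external_offset) {
  const int64_t aligned_offset =
      ((external_offset + alignment_factor - 1) / alignment_factor) * alignment_factor;
  for (int64_t i = external_offset; i < aligned_offset && os.good(); ++i) {
    os.put('\0');  // zero padding between the previous tensor and the aligned one
  }
  external_offset = aligned_offset;
  return os.good();
}

The memory-mapping change that follows applies the same idea in the other direction: the requested offset is rounded down to dwAllocationGranularity, the difference is added to the mapped length, and the returned pointer is advanced by that difference.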
#include #include "core/common/logging/logging.h" #include "core/common/narrow.h" +#include "core/common/safeint.h" #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" @@ -439,30 +440,28 @@ Status WindowsEnv::MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, SYSTEM_INFO sysinfo; GetSystemInfo(&sysinfo); - static const DWORD page_size = sysinfo.dwPageSize; static const DWORD allocation_granularity = sysinfo.dwAllocationGranularity; - const FileOffsetType offset_to_page = offset % static_cast(page_size); - const size_t mapped_length = length + static_cast(offset_to_page); - const FileOffsetType mapped_offset = offset - offset_to_page; - if (mapped_offset % allocation_granularity != 0) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "mapped offset must be a multiple of the allocation granularity", - " , mapped_offset = ", mapped_offset, - " , allocation_granularity = ", allocation_granularity, - " , errcode = ", error_code, - " - ", std::system_category().message(error_code)); - } + const FileOffsetType offset_to_granularity = offset % static_cast(allocation_granularity); + const SIZE_T mapped_length = SafeInt(offset_to_granularity) + length; + const FileOffsetType mapped_offset = offset - offset_to_granularity; + assert((mapped_offset % allocation_granularity) == 0); void* const mapped_base = MapViewOfFile(file_mapping_handle.get(), FILE_MAP_READ, static_cast((mapped_offset >> 32) & 0xFFFFFFFF), static_cast(mapped_offset & 0xFFFFFFFF), mapped_length); - GSL_SUPPRESS(r.11) + + if (mapped_base == nullptr) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "MapViewOfFile ", ToUTF8String(Basename(file_path)), + " fail, errcode = ", error_code, + " - ", std::system_category().message(error_code)); + } mapped_memory = - MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_page, + MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_granularity, [mapped_base](void*) { UnmapFile(mapped_base); }}; diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index c691be6ffd0e8..dbf2be74f7362 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -520,14 +520,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const T* zero_point = x_zero_point ? 
x_zero_point->Data() : nullptr; #if !defined(DISABLE_FLOAT8_TYPES) - if constexpr (boost::mp11::mp_contains>, - T>::value) { + if constexpr (boost::mp11::mp_contains::value) { ORT_ENFORCE(zero_point == nullptr || std::all_of(zero_point, zero_point + x_zero_point->Shape().Size(), [](T zp) { return zp == T{0}; }), - "DequantizeLinear with type int32 or float8 should have no zero point or all zero points should be 0"); + "DequantizeLinear with type float8 should have no zero point or all zero points should be 0"); } #endif diff --git a/onnxruntime/core/providers/cuda/cuda_graph.cc b/onnxruntime/core/providers/cuda/cuda_graph.cc index 8353c654681fc..88e58aec70550 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.cc +++ b/onnxruntime/core/providers/cuda/cuda_graph.cc @@ -72,7 +72,7 @@ void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec); } -Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) { +Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag) { // Although this function is not thread safe, the lock is not needed here because // CUDA EP maintains a separate cuda graph per thread LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id " @@ -81,7 +81,9 @@ Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id); CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + if (sync_status_flag) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/cuda_graph.h b/onnxruntime/core/providers/cuda/cuda_graph.h index 064b526e604bc..6b61a66671de4 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.h +++ b/onnxruntime/core/providers/cuda/cuda_graph.h @@ -38,7 +38,7 @@ struct CUDAGraphManager { void SetStream(cudaStream_t stream); void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); - Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id); + Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag = true); void Reset(); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 451be69c81cfb..b7997ce86737a 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -19,6 +19,7 @@ #include "nv_data_transfer.h" #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" +#include "core/providers/cuda/cuda_graph.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" @@ -255,7 +256,8 @@ bool ApplyProfileShapesFromProviderOptions(std::vector>>& profile_min_shapes, std::unordered_map>>& profile_max_shapes, std::unordered_map>>& profile_opt_shapes, - ShapeRangesMap& input_explicit_shape_ranges) { + ShapeRangesMap& input_explicit_shape_ranges, + bool& cuda_graph_flag) { if (trt_profiles.size() == 0) { LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Number of optimization profiles should be greater than 0, but it's 0."; return false; @@ -282,6 +284,10 @@ 
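One more note before the TensorRT RTX provider changes. CUDAGraphManager::Replay above gains a flag so callers can skip the trailing cudaStreamSynchronize and keep a replayed graph fully asynchronous. In plain CUDA runtime terms the pattern is roughly the following sketch (not the ORT wrapper itself):

// Launch a previously instantiated graph, optionally blocking the host until
// the stream has drained. Assumes graph_exec and stream are valid.
#include <cuda_runtime_api.h>

cudaError_t ReplayCapturedGraph(cudaGraphExec_t graph_exec, cudaStream_t stream,
                                bool sync_after_launch) {
  cudaError_t status = cudaGraphLaunch(graph_exec, stream);
  if (status == cudaSuccess && sync_after_launch) {
    // Block the host only when the caller needs completion before returning.
    status = cudaStreamSynchronize(stream);
  }
  return status;
}

The NV EP below passes sync_stream_after_enqueue_ for this flag, so external-stream users are not forced into a host synchronization on every replay.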
bool ApplyProfileShapesFromProviderOptions(std::vectorisShapeTensor()) { + if (cuda_graph_flag) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Shape tensor detected on input '" << input->getName() << "'. Disabling CUDA Graph."; + cuda_graph_flag = false; + } int shape_size = nb_dims == 0 ? 1 : static_cast(profile_min_shapes[input_name][i].size()); std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); @@ -715,10 +721,10 @@ Status BindKernelOutput(Ort::KernelContext& ctx, } NvExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) { - // TODO: figure out if PerThreadContext is used at all. If not, just clean it up. + // Only set device if user hasn't provided a compute stream if (has_user_compute_stream) { CUDA_CALL_THROW(cudaSetDevice(device_id)); - (void)(stream); + (void)stream; } } @@ -745,6 +751,86 @@ bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fu return false; } +void NvExecutionProvider::PerThreadContext::DeleteCapturedGraph(CudaGraphAnnotation_t cuda_graph_annotation_id) { + graph_id_to_run_count_.erase(cuda_graph_annotation_id); + cuda_graph_.Reset(); +} + +void NvExecutionProvider::PerThreadContext::ResetWarmupRuns(CudaGraphAnnotation_t cuda_graph_annotation_id) { + if (graph_id_to_run_count_.find(cuda_graph_annotation_id) == graph_id_to_run_count_.end()) { + return; + } + graph_id_to_run_count_[cuda_graph_annotation_id] = 0; +} + +bool NvExecutionProvider::PerThreadContext::IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + if (!IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id)) { + return false; + } + + // Safe access to map - return false if key doesn't exist yet + auto it = graph_id_to_run_count_.find(cuda_graph_annotation_id); + if (it == graph_id_to_run_count_.end()) { + return false; // Entry doesn't exist yet, not ready for capture + } + + bool allowed = it->second >= min_num_runs_before_cuda_graph_capture_; + if (allowed) { + LOGS_DEFAULT(VERBOSE) << "NvTensorRTRTX EP Graph capture allowed for ID: " << cuda_graph_annotation_id + << ", run count: " << it->second; + } + return allowed; +} + +bool NvExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_.IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id); +} + +CudaGraphAnnotation_t NvExecutionProvider::PerThreadContext::GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const { + // Actual implementation + auto graph_annotation_str = run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + CudaGraphAnnotation_t cuda_graph_annotation_id = kCudaGraphAnnotationDefault; + + // Kind of debugging head implementation, can be cleaned and made robust like CUDA EP + if (graph_annotation_str.has_value() && !graph_annotation_str->empty()) { + if (!TryParseStringWithClassicLocale(*graph_annotation_str, cuda_graph_annotation_id)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to parse cuda graph annotation id: " + << *graph_annotation_str << ", using default: " << kCudaGraphAnnotationDefault; + cuda_graph_annotation_id = kCudaGraphAnnotationDefault; + } + } + return cuda_graph_annotation_id; +} + +void NvExecutionProvider::PerThreadContext::SetCurrentGraphAnnotationId(CudaGraphAnnotation_t cuda_graph_annotation_id) { + current_graph_annotation_id_ = cuda_graph_annotation_id; +} + +CudaGraphAnnotation_t 
NvExecutionProvider::PerThreadContext::GetCurrentGraphAnnotationId() const { + return current_graph_annotation_id_; +} + +void NvExecutionProvider::PerThreadContext::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.Reset(); + cuda_graph_.CaptureBegin(cuda_graph_annotation_id); +} + +void NvExecutionProvider::PerThreadContext::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.CaptureEnd(cuda_graph_annotation_id); +} + +bool NvExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_.IsGraphCaptured(cuda_graph_annotation_id); +} + +Status NvExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag) { + return cuda_graph_.Replay(cuda_graph_annotation_id, sync_status_flag); +} + +void NvExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id) { + graph_id_to_run_count_[cuda_graph_annotation_id]++; +} + bool NvExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::string fused_node) { auto it = trt_context_map_.find(fused_node); if (it != trt_context_map_.end()) { @@ -846,6 +932,12 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); + } else if (cuda_graph_enable_) { + external_stream_ = false; + CUDA_CALL_THROW(cudaStreamCreate(&stream_)); + } else { + external_stream_ = false; + stream_ = nullptr; // Will be created in compute function } std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -1010,7 +1102,7 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) // external stream: // If user provides "external" cuda stream, only this cuda stream will be used even if multiple threads are running InferenceSession.Run() concurrently. // So, no need to synchronize different streams after enqueueV3. - if (cuda_graph_enable_ || external_stream_) { + if (external_stream_) { sync_stream_after_enqueue_ = false; } @@ -1038,7 +1130,7 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) << ", nv_force_sequential_engine_build: " << force_sequential_engine_build_ << ", nv_sparsity_enable: " << sparsity_enable_ << ", nv_auxiliary_streams: " << auxiliary_streams_ - << ", nv_cuda_graph_enable: " << cuda_graph_enable_ + << ", enable_cuda_graph: " << cuda_graph_enable_ << ", nv_dump_ep_context_model: " << dump_ep_context_model_ << ", nv_ep_context_file_path: " << ep_context_file_path_ << ", nv_ep_context_embed_mode: " << ep_context_embed_mode_ @@ -1060,7 +1152,7 @@ NvExecutionProvider::~NvExecutionProvider() { } } - if (!external_stream_ && stream_) { + if (!external_stream_ && stream_ != nullptr) { ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_))); } ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); @@ -1072,41 +1164,82 @@ NvExecutionProvider::~NvExecutionProvider() { } } +void NvExecutionProvider::HandleCudaGraphStart(cudaStream_t stream, bool require_io_binding, + CudaGraphAnnotation_t cuda_graph_annotation_id, bool& graph_replay_on_this_run, bool& should_start_capture) { + graph_replay_on_this_run = false; + should_start_capture = false; + + // Case 1: CUDA Graph capture is enabled AND IO binding is required. + // In this case, we force graph re-capture by resetting warmup runs. 
+ // If a graph for this annotation ID already exists, delete it before proceeding. + if (require_io_binding && cuda_graph_enable_) { + GetPerThreadContext().ResetWarmupRuns(cuda_graph_annotation_id); + + if (GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Graph already captured and required_io_binding is true, resetting warmup runs and deleting graph"; + GetPerThreadContext().DeleteCapturedGraph(cuda_graph_annotation_id); + } + // Case 2: CUDA Graph capture is enabled AND IO binding is NOT required + } else if (cuda_graph_enable_ && !require_io_binding) { + // If the graph is not yet captured, increment the regular run counter + if (cuda_graph_annotation_id != kCudaGraphAnnotationSkip && + !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(cuda_graph_annotation_id); + } + + // If capture is allowed and graph not already captured, + // set the stream and begin capture + if (!GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) && + GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) { + GetPerThreadContext().SetCudaGraphStream(stream); + GetPerThreadContext().CaptureBegin(cuda_graph_annotation_id); + should_start_capture = true; + } + + // If a graph is already captured for this ID, mark it for replay in this run. + if (GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + graph_replay_on_this_run = true; + } + } +} + bool NvExecutionProvider::IsGraphCaptureEnabled() const { return cuda_graph_enable_; } -bool NvExecutionProvider::IsGraphCaptureAllowed() const { - return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +bool NvExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + // This is hardcoded to always return false because we are not allowing the ORT framework to have the CUDA graph control. + (void)graph_annotation_id; + return false; } -void NvExecutionProvider::CaptureBegin(int) { - cuda_graph_.Reset(); - cuda_graph_.CaptureBegin(0); +Status NvExecutionProvider::ReplayGraph(int graph_annotation_id) { + // This is hardcoded to always return OK because we are not allowing the ORT framework to have the CUDA graph control. + (void)graph_annotation_id; + return Status::OK(); } -void NvExecutionProvider::CaptureEnd(int) { - cuda_graph_.CaptureEnd(0); - is_graph_captured_ = true; -} +Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { + if (cuda_graph_enable_) { + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options); + GetPerThreadContext().SetCurrentGraphAnnotationId(cuda_graph_annotation_id); + } -bool NvExecutionProvider::IsGraphCaptured(int) const { - return is_graph_captured_; + if (multi_profile_enable_ == true) { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(nv::run_option_names::kProfileIndex); + TryParseStringWithClassicLocale(*graph_annotation_str, nv_profile_index_); + } + return Status::OK(); } -Status NvExecutionProvider::ReplayGraph(int) { - ORT_ENFORCE(IsGraphCaptured(0)); - // Please note that CUDAGraph::Replay() is not thread safe. - // ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced due to lock_guard(), - // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. 
- return cuda_graph_.Replay(0); -} +Status NvExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) { + (void)run_options; -void NvExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { - // Please note that this function is not thread safe. - // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), - // therefore following increment is guaranteed to be thread safe. - ++regular_run_count_before_graph_capture_; + if (sync_stream && external_stream_) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + } + return Status::OK(); } std::vector NvExecutionProvider::CreatePreferredAllocators() { @@ -1133,22 +1266,6 @@ std::unique_ptr NvExecutionProvider::GetDataTransfer() const { return std::make_unique(); } -Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { - if (multi_profile_enable_ == true) { - auto graph_annotation_str = - run_options.GetConfigOptions().GetConfigEntry(nv::run_option_names::kProfileIndex); - TryParseStringWithClassicLocale(*graph_annotation_str, nv_profile_index_); - } - return Status::OK(); -} - -Status NvExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (sync_stream && external_stream_) { - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); - } - return Status::OK(); -} - // Get the pointer to the IBuilder instance. // Note: This function is not thread safe. Calls to this function from different threads must be serialized // even though it doesn't make sense to have multiple threads initializing the same inference session. @@ -2519,7 +2636,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr profile_opt_shapes_.find(input_name) != profile_opt_shapes_.end() && profile_max_shapes_.find(input_name) != profile_max_shapes_.end(); if (has_explicit_profile && tensor_has_profile) { - apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); + apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges, cuda_graph_enable_); } else { LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Creating implicit profile for tensor " << input_name; profile_min_shapes_[input_name] = std::vector>{{}}; @@ -2546,7 +2663,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr profile_max_shapes_[input_name][0][idx_dim] = dim_value; } } - apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); + apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges, cuda_graph_enable_); } if (!apply_profile) { std::ostringstream msg; @@ -2600,6 +2717,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // Otherwise engine will be handled at inference time. 
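Before the engine-creation changes continue, here is the run-time shape of the new CUDA graph handling in one place. HandleCudaGraphStart decides, per run and per annotation id, whether this run still counts as warmup, whether capture should begin on the compute stream, or whether an already-captured graph should simply be replayed; the compute function then wraps enqueueV3 accordingly and replays immediately after CaptureEnd, since work issued to a capturing stream does not actually execute. The condensed sketch below leans on the PerThreadContext helpers declared later in this patch and omits the IO-binding re-capture and annotation-skip branches; it is an illustration, not the EP code:

// Condensed per-run flow: warm up, capture around the TensorRT enqueue,
// then replay the captured graph on subsequent runs.
Status RunWithOptionalCudaGraph(PerThreadContext& ctx,
                                nvinfer1::IExecutionContext& trt_context,
                                cudaStream_t stream,
                                CudaGraphAnnotation_t annotation_id,
                                bool sync_after_enqueue) {
  const bool replay_this_run = ctx.IsGraphCaptured(annotation_id);
  bool capture_this_run = false;
  if (!replay_this_run) {
    // Count a warmup run; capture only once enough regular runs have happened.
    ctx.IncrementRegularRunCountBeforeGraphCapture(annotation_id);
    if (ctx.IsGraphCaptureAllowed(annotation_id)) {
      ctx.SetCudaGraphStream(stream);
      ctx.CaptureBegin(annotation_id);
      capture_this_run = true;
    }
  }
  if (replay_this_run) {
    ORT_RETURN_IF_ERROR(ctx.ReplayGraph(annotation_id, sync_after_enqueue));
  } else if (!trt_context.enqueueV3(stream)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "enqueueV3 failed");
  }
  if (capture_this_run) {
    ctx.CaptureEnd(annotation_id);
    // Work issued to a capturing stream never ran, so replay the graph now.
    ORT_RETURN_IF_ERROR(ctx.ReplayGraph(annotation_id, sync_after_enqueue));
  }
  return Status::OK();
}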
std::unique_ptr trt_engine; std::unique_ptr trt_context; + std::unique_ptr trt_runtime_config; // Generate file name for dumping ep context model if (dump_ep_context_model_ && ctx_model_path_.empty()) { @@ -2622,6 +2740,13 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP failed to deserialize engine for fused node: " + fused_node.Name()); } + + trt_runtime_config = std::unique_ptr(trt_engine->createRuntimeConfig()); + if (trt_runtime_config && cuda_graph_enable_) { + trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER); + } + trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED); + if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); LOGS_DEFAULT(INFO) << "TensorRT engine build for " << fused_node.Name() << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; @@ -2681,7 +2806,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // Build context // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); + trt_context = std::unique_ptr(trt_engine->createExecutionContext(trt_runtime_config.get())); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); @@ -2777,9 +2902,17 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } OrtAllocator* alloc = alloc_; - void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); + cudaStream_t stream; + if (stream_ != nullptr) { + // Use our existing stream (either user's or our early-created) + stream = stream_; + } else { + // Create stream now (lazy creation case) + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + stream = static_cast(cuda_stream); + stream_ = stream; + } if (multi_profile_enable_ == true) { if (!trt_context->setOptimizationProfileAsync(nv_profile_index_, stream)) @@ -2860,7 +2993,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr nvinfer1::Dims dims; void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); if (status != Status::OK()) { @@ -2886,18 +3018,23 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size); } } - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. 
- if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - cuda_graph_.SetStream(stream); - CaptureBegin(0); - } - // Run TRT inference - if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + // Start CUDA graph capture with the correct stream + // Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream + // Get the graph annotation ID that was stored during OnRunStart + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCurrentGraphAnnotationId(); + bool graph_replay_on_this_run = false; + bool should_start_capture = false; + + HandleCudaGraphStart(stream, require_io_binding, cuda_graph_annotation_id, + graph_replay_on_this_run, should_start_capture); + + if (!graph_replay_on_this_run) { + if (!trt_context->enqueueV3(stream)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + } + } else { + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); } /* @@ -2914,10 +3051,15 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. */ + + if (cuda_graph_enable_ && should_start_capture) { + GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); + } + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } - // Assign TRT output back to ORT output // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output @@ -2951,21 +3093,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } } - // End CUDA graph capture. - // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture - // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. - // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { - CaptureEnd(0); - // CUDA work issued to a capturing stream doesn't actually run on the GPU, - // so run the captured graph here to actually execute the work. 
- ORT_RETURN_IF_ERROR(ReplayGraph(0)); - } else { - IncrementRegularRunCountBeforeGraphCapture(); - } - } - return Status::OK(); }; @@ -3098,9 +3225,16 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra } OrtAllocator* alloc = alloc_; - void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); + cudaStream_t stream; + if (stream_ != nullptr) { + // Use our existing stream (either user's or our early-created) + stream = stream_; + } else { + // Create stream now (lazy creation case) + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + stream = static_cast(cuda_stream); + } // Check before using trt_engine if (trt_engine == nullptr) { @@ -3203,18 +3337,23 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size); } } - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - cuda_graph_.SetStream(stream); - CaptureBegin(0); - } - // Run TRT inference - if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + // Start CUDA graph capture with the correct stream + // Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream + // Get the graph annotation ID that was stored during OnRunStart + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCurrentGraphAnnotationId(); + bool graph_replay_on_this_run = false; + bool should_start_capture = false; + + HandleCudaGraphStart(stream, require_io_binding, cuda_graph_annotation_id, + graph_replay_on_this_run, should_start_capture); + + if (!graph_replay_on_this_run) { + if (!trt_context->enqueueV3(stream)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + } + } else { + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); } /* @@ -3231,10 +3370,15 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. */ + + if (cuda_graph_enable_ && should_start_capture) { + GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); + } + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } - // Assign TRT output back to ORT output // (1) Bind TRT DDS output to ORT kernel context output. 
(It needs to wait until enqueueV3 is finished) // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output @@ -3268,21 +3412,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra } } - // End CUDA graph capture. - // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture - // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. - // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { - CaptureEnd(0); - // CUDA work issued to a capturing stream doesn't actually run on the GPU, - // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(ReplayGraph(0)); - } else { - IncrementRegularRunCountBeforeGraphCapture(); - } - } - return Status::OK(); }; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index e3dd38eb837ff..22b8314649757 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -12,7 +12,7 @@ typedef void* cublasHandle_t; typedef void* cudnnStatus_t; #endif #include "core/providers/nv_tensorrt_rtx/nv_includes.h" - +#include "core/session/onnxruntime_run_options_config_keys.h" #include #include "core/providers/cuda/cuda_graph.h" #include "nv_execution_provider_info.h" @@ -305,9 +305,11 @@ class NvExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; + // CUDA Graph support bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured(int graph_annotation_id) const override; Status ReplayGraph(int graph_annotation_id) override; + void HandleCudaGraphStart(cudaStream_t stream, bool require_io_binding, CudaGraphAnnotation_t cuda_graph_annotation_id, bool& graph_replay_on_this_run, bool& should_start_capture); static common::Status RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, @@ -405,15 +407,6 @@ class NvExecutionProvider : public IExecutionProvider { // Call cudaStreamSynchronize() after TRT enqueueV3() mutable bool sync_stream_after_enqueue_ = true; - CUDAGraph cuda_graph_; - bool is_graph_captured_ = false; - int regular_run_count_before_graph_capture_ = 0; - // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: - // (1) memory pattern is enabled. (2) arena allocation for stream. - // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs - // to allocate enough memory in Arena before graph capturing. - const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. - // [Note] We don't use PerThreadContext for now since it has issue with multithreading // // TRT or CUDA objects that must be maintained on a per thread basis will be put under this PerThreadContext data structure. 
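A short aside on why the graph state moves into PerThreadContext rather than living on the provider: each thread driving InferenceSession::Run gets its own CUDAGraph and its own warmup counters, so CaptureBegin and Replay never need a lock. ORT's actual per-thread context management is more involved than this, but the motivation is roughly the following hypothetical sketch (annotation ids shown as plain ints):

// Illustration only, not the EP implementation: thread-local bookkeeping means
// no locking around capture or replay, at the cost of one graph and one set of
// warmup counters per calling thread.
#include <unordered_map>

struct PerThreadGraphState {
  // annotation id -> regular runs completed before capture is allowed
  std::unordered_map<int, int> run_count;
  // the per-thread CUDAGraph wrapper would live here as well
};

PerThreadGraphState& GetPerThreadGraphState() {
  thread_local PerThreadGraphState state;  // one instance per calling thread
  return state;
}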
@@ -436,14 +429,20 @@ class NvExecutionProvider : public IExecutionProvider { bool UpdateTensorRTContext(std::string fused_node, std::unique_ptr context); void ResetTensorRTContext(std::string fused_node); - void InitCUDAGraph(); - void SetGraphStream(cudaStream_t stream); - bool IsGraphCaptureAllowed() const; - void CaptureBegin(int graph_annotation_id); - void CaptureEnd(int graph_annotation_id); - bool IsGraphCaptured(int graph_annotation_id) const; - Status ReplayGraph(int graph_annotation_id); - void IncrementRegularRunCountBeforeGraphCapture(); + // CUDA Graph management + void SetCudaGraphStream(cudaStream_t stream) { cuda_graph_.SetStream(stream); } + bool IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const; + void SetCurrentGraphAnnotationId(CudaGraphAnnotation_t cuda_graph_annotation_id); + CudaGraphAnnotation_t GetCurrentGraphAnnotationId() const; + void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); + void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); + bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag); + void IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id); + void ResetWarmupRuns(CudaGraphAnnotation_t cuda_graph_annotation_id); + void DeleteCapturedGraph(CudaGraphAnnotation_t cuda_graph_annotation_id); private: cudnnHandle_t external_cudnn_handle_ = nullptr; @@ -466,13 +465,18 @@ class NvExecutionProvider : public IExecutionProvider { // Cuda graph with multi threads will be supported in the future, so cuda_graph_ is put under PerThreadContext. // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph instance is enough (no need to maintain one CUDAGraph instance per TRT subgraph) CUDAGraph cuda_graph_; + // Map of graph id to regular_run_count_before_graph_capture + std::unordered_map graph_id_to_run_count_; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; + // Current graph annotation ID for this run + CudaGraphAnnotation_t current_graph_annotation_id_ = kCudaGraphAnnotationDefault; // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: // (1) memory pattern is enabled. (2) arena allocation for stream. // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs // to allocate enough memory in Arena before graph capturing. - const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + const int min_num_runs_before_cuda_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations. 
+ // https://github.com/NVIDIA/TensorRT/blob/main/samples/common/sampleInference.cpp#L1258-L1291 Based on the trtexec code }; using PerThreadContextMap = std::unordered_map>; @@ -606,11 +610,6 @@ class NvExecutionProvider : public IExecutionProvider { std::unordered_map& output_map, std::vector& node_compute_funcs); - bool IsGraphCaptureAllowed() const; - void CaptureBegin(int graph_annotation_id); - void CaptureEnd(int graph_annotation_id); - void IncrementRegularRunCountBeforeGraphCapture(); - /** * Get the pointer to the IBuilder instance. * This function only creates the instance at the first time it's being called." diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc index a22a331b2453f..0925d8e1a6062 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc @@ -48,13 +48,13 @@ std::optional ParseEquation(std::string_view equation_string) { if (term_1.size() < 2 || term_2.size() < 2 || result.size() < 2) { return std::nullopt; } - if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::islower(c); })) { + if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::isalpha(c); })) { return std::nullopt; } - if (!std::all_of(term_2.begin(), term_2.end(), [](unsigned char c) { return std::islower(c); })) { + if (!std::all_of(term_2.begin(), term_2.end(), [](unsigned char c) { return std::isalpha(c); })) { return std::nullopt; } - if (!std::all_of(result.begin(), result.end(), [](unsigned char c) { return std::islower(c); })) { + if (!std::all_of(result.begin(), result.end(), [](unsigned char c) { return std::isalpha(c); })) { return std::nullopt; } return std::make_tuple(term_1, term_2, result); diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 5ac4541e8323e..7d4ae8c2197ff 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -26,7 +26,7 @@ static TensorShape GetOverrideShape(const TensorShape& shape, int components) { } Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AddInput("scale", ShaderUsage::UseUniform); if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); @@ -39,35 +39,113 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddOutput("inv_std_dev_output", ShaderUsage::None); } - int components = x.NumComponents(); - std::string bias = (has_bias_) ? " + bias[j]" : ""; - std::string simpl1 = (simplified_) ? "" : " - mean * mean"; - std::string simpl2 = (simplified_) ? "" : " - mean"; - - shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") - << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? 
"vec2" : "f32")) << ";\n"; - - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count") - << "let offset = global_idx * uniforms.norm_size_vectorized;\n" - << "var mean_vector = f32_val_t(0);\n" - << "var mean_square_vector = f32_val_t(0);\n" - << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" - << " let value = f32_val_t(x[h + offset]);\n" - << " mean_vector += value;\n" - << " mean_square_vector += value * value;\n" - << "}\n" - << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" - << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n" - << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" - << " let f32input = f32_val_t(x[j + offset]);\n" - << " let f32scale = f32_val_t(scale[j]);\n" - << " y[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" - << "}\n"; - if (has_mean_output_) { - shader.MainFunctionBody() << "mean_output[global_idx] = mean;\n"; - } - if (has_inv_std_dev_output_) { - shader.MainFunctionBody() << "inv_std_dev_output[global_idx] = inv_std_dev;\n"; + std::string simpl1 = (simplified_) ? "" : "- mean * mean "; + std::string simpl2 = (simplified_) ? "" : "- x_element_t(mean) "; + + if (split_norm_dim_) { + shader.AdditionalImplementation() + << "var sum_shared : array;\n" + << "var sum_squared_shared : array;\n"; + + shader.MainFunctionBody() + << " var sum_vec4 = vec4(0);\n" + << " var sum_squared_vec4 = vec4(0);\n" + << " var cur_input = x_value_t(0);\n" + << " for (var i: u32 = 0; i < uniforms.norm_size / (workgroup_size_x * 4); i++) {\n" + << " let input_offset = i * workgroup_size_x + local_idx;\n" + << " let input_value = x[input_offset];\n" + << " if (i == workgroup_idx) {\n" + << " cur_input = input_value;\n" + << " }\n" + << " let f32_value = vec4(input_value);\n" + << " sum_vec4 += f32_value;\n" + << " sum_squared_vec4 += f32_value * f32_value;\n" + << " }\n" + << " var sum = " << SumVector("sum_vec4", 4) << ";\n" + << " var sum_squared = " << SumVector("sum_squared_vec4", 4) << ";\n" + << " sum_shared[local_idx] = sum;\n" + << " sum_squared_shared[local_idx] = sum_squared;\n" + << " workgroupBarrier();\n" + << " var reduce_size : u32 = workgroup_size_x;\n" + << " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (local_idx < curr_size) {\n" + << " sum_shared[local_idx] += sum_shared[local_idx + reduce_size];\n" + << " sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + << " let mean = sum_shared[0] / f32(uniforms.norm_size);\n" + << " let inv_std_dev = inverseSqrt(sum_squared_shared[0] / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n" + << " let offset = workgroup_idx * workgroup_size_x + local_idx;\n" + << " y[offset] = ((cur_input " << simpl2 << ") * x_element_t(inv_std_dev) * scale[offset]" << (has_bias_ ? 
" + bias[offset] " : "") << ");\n"; + + if (has_mean_output_) { + shader.MainFunctionBody() << " if (local_idx == 0 && workgroup_idx == 0) {\n" + << " mean_output[global_idx / uniforms.norm_size] = mean;\n" + << " }\n"; + } + if (has_inv_std_dev_output_) { + shader.MainFunctionBody() << " if (local_idx == 0 && workgroup_idx == 0) {\n" + << " inv_std_dev_output[global_idx / uniforms.norm_size] = inv_std_dev;\n" + << " }\n"; + } + } else { + int components = x.NumComponents(); + std::string bias = (has_bias_) ? " + bias[offset1d + i] " : ""; + + shader.AdditionalImplementation() + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n" + << "var sum_shared : array;\n" + << "var sum_squared_shared : array;\n"; + + shader.MainFunctionBody() + << "let ix = local_idx;\n" + << "let iy = global_idx / workgroup_size_x;\n" + << "let norm_size_vectorized: u32 = uniforms.norm_size / uniforms.components;\n" + << "var stride = norm_size_vectorized / workgroup_size_x;\n" + << "let offset = ix * stride + iy * norm_size_vectorized;\n" + << "let offset1d = stride * ix;\n" + << "sum_shared[ix] = f32_val_t(0);\n" + << "sum_squared_shared[ix] = f32_val_t(0);\n" + << "if (ix == workgroup_size_x - 1) {\n" + << " stride = norm_size_vectorized - stride * ix;\n" + << "}\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " let input_value = x[offset + i];\n" + << " y[offset + i] = input_value;\n" + << " let f32_value = f32_val_t(input_value);\n" + << " sum_shared[ix] += f32_value;\n" + << " sum_squared_shared[ix] += f32_value * f32_value;\n" + << "}\n" + << "workgroupBarrier();\n" + << "var reduce_size : u32 = workgroup_size_x;\n" + << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (ix < curr_size) {\n" + << " sum_shared[ix] += sum_shared[ix + reduce_size];\n" + << " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "let sum = sum_shared[0];\n" + << "let square_sum = sum_squared_shared[0];\n" + << "let mean = " << SumVector("sum", components) << " / f32(uniforms.norm_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " y[offset + i] = (y[offset + i] " << simpl2 << ") * x_element_t(inv_std_dev) * scale[offset1d + i]" << bias << ";\n" + << "};\n"; + + if (has_mean_output_) { + shader.MainFunctionBody() << "if (ix == 0) {\n" + << " mean_output[iy] = mean;\n" + << "}\n"; + } + if (has_inv_std_dev_output_) { + shader.MainFunctionBody() << "if (ix == 0) {\n" + << " inv_std_dev_output[iy] = inv_std_dev;\n" + << "}\n"; + } } return Status::OK(); @@ -81,8 +159,6 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex const auto x_shape = x->Shape(); - const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis)); const int64_t norm_size = x_shape.SizeFromDimension(axis); @@ -116,14 +192,19 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex return Status::OK(); } - LayerNormProgram program{bias != nullptr, is_fp16, simplified, mean != nullptr, inv_std_dev != nullptr}; + // Check if we should use split norm dimension optimization + const 
bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; + + LayerNormProgram program{bias != nullptr, simplified, mean != nullptr, inv_std_dev != nullptr, split_norm_dim}; - program.CacheHint(components, simplified) + program.CacheHint(components, simplified, split_norm_dim) .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x->Shape(), components), components}}) .AddInputs( {{scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}}) .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}}) - .SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(components)}, + }) .AddUniformVariables({ {static_cast(norm_count)}, }) @@ -137,6 +218,15 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex {static_cast(epsilon_)}, }); + if (split_norm_dim) { + const uint32_t workgroup_size_x = 128; + const uint32_t dispatch_size_x = onnxruntime::narrow(norm_size / (workgroup_size_x * components)); + program.SetDispatchGroupSize(dispatch_size_x, 1, 1) + .SetWorkgroupSize(workgroup_size_x); + } else { + program.SetDispatchGroupSize(norm_count); + } + if (bias != nullptr) { program.AddInput( {bias, ProgramTensorMetadataDependency::Type, GetOverrideShape(bias->Shape(), components), components}); diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h index c7cb9280a0b77..112b152d37130 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.h +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -11,28 +11,29 @@ namespace webgpu { class LayerNormProgram final : public Program { public: - LayerNormProgram(bool has_bias, bool is_fp16, bool simplified, bool has_mean_output, - bool has_inv_std_dev_output) + LayerNormProgram(bool has_bias, bool simplified, bool has_mean_output, + bool has_inv_std_dev_output, bool split_norm_dim = false) : Program{"LayerNorm"}, has_bias_{has_bias}, - is_fp16_{is_fp16}, simplified_{simplified}, has_mean_output_{has_mean_output}, - has_inv_std_dev_output_{has_inv_std_dev_output} {} + has_inv_std_dev_output_{has_inv_std_dev_output}, + split_norm_dim_{split_norm_dim} {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"norm_count", ProgramUniformVariableDataType::Uint32}, + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"components", ProgramUniformVariableDataType::Uint32}, + {"norm_count", ProgramUniformVariableDataType::Uint32}, {"norm_size", ProgramUniformVariableDataType::Uint32}, {"norm_size_vectorized", ProgramUniformVariableDataType::Uint32}, {"epsilon", ProgramUniformVariableDataType::Float32}); private: bool has_bias_; - bool is_fp16_; bool simplified_; bool has_mean_output_; bool has_inv_std_dev_output_; + bool split_norm_dim_; }; template diff --git a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc index 4bcef4fd79296..b89623110217d 100644 --- a/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc @@ -11,7 +11,31 @@ namespace webgpu { ONNX_OPERATOR_KERNEL_EX( Unsqueeze, kOnnxDomain, - 13, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + 
+ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Unsqueeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Unsqueeze, + kOnnxDomain, + 13, 20, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedNumberTypes()) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 2afe8a964a003..bbb3fbdd221d3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -249,7 +249,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 15, Where); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where); @@ -548,7 +550,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index 91af452c64efd..87d91178f07bc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -27,42 +27,14 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& op_type(node.OpType()); emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name()); - emscripten::val output = emscripten::val::object(); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - if (op_type == "Abs") { - output = model_builder.GetBuilder().call("abs", input, options); - } else if (op_type == "Ceil") { - output = model_builder.GetBuilder().call("ceil", input, options); - } else if (op_type == "Cos") { - output = model_builder.GetBuilder().call("cos", input, options); - } else if (op_type == "Erf") { - output = model_builder.GetBuilder().call("erf", input, options); - } else if (op_type == "Exp") { - output = model_builder.GetBuilder().call("exp", input, options); - } else if (op_type == "Floor") { - output = model_builder.GetBuilder().call("floor", input, options); - } else if (op_type == "Identity") { - output = model_builder.GetBuilder().call("identity", input, options); - } else if (op_type == "Log") { - output = model_builder.GetBuilder().call("log", input, options); - } else if (op_type == "Neg") { - output = 
model_builder.GetBuilder().call("neg", input, options); - } else if (op_type == "Reciprocal") { - output = model_builder.GetBuilder().call("reciprocal", input, options); - } else if (op_type == "Sign") { - output = model_builder.GetBuilder().call("sign", input, options); - } else if (op_type == "Sin") { - output = model_builder.GetBuilder().call("sin", input, options); - } else if (op_type == "Sqrt") { - output = model_builder.GetBuilder().call("sqrt", input, options); - } else if (op_type == "Tan") { - output = model_builder.GetBuilder().call("tan", input, options); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "UnaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } + const std::string_view webnn_op_type = GetWebNNOpType(op_type); + ORT_RETURN_IF(webnn_op_type.empty(), "Cannot get WebNN op type"); + + emscripten::val output = model_builder.GetBuilder().call( + std::string(webnn_op_type).c_str(), input, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); @@ -84,6 +56,7 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op "Log", "Neg", "Reciprocal", + "Round", "Sign", "Sin", "Sqrt", diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 1c30fed7a7916..590614edf851c 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -174,6 +174,7 @@ const std::unordered_map op_inputs_map = { {"Greater", {"greater", {{0, "a"}, {1, "b"}}}}, {"Reciprocal", {"reciprocal", {{0, "input"}}}}, {"ReduceMean", {"reduceMean", {{0, "input"}}}}, + {"Round", {"roundEven", {{0, "input"}}}}, {"GlobalMaxPool", {"maxPool2d", {{0, "input"}}}}, {"HardSigmoid", {"hardSigmoid", {{0, "input"}}}}, {"ReduceProd", {"reduceProduct", {{0, "input"}}}}, @@ -200,7 +201,7 @@ const std::unordered_map op_inputs_map = { {"DequantizeLinear", {"dequantizeLinear", {{0, "input"}, {1, "scale"}, {2, "zeroPoint"}}}}, {"InstanceNormalization", {"instanceNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}}}}, {"GRU", {"gru", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}}}}, - {"BatchNormalization", {"batchNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}, {3, "input_mean"}, {4, "input_var"}}}}, + {"BatchNormalization", {"batchNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}, {3, "mean"}, {4, "variance"}}}}, {"LSTM", {"lstm", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}, {6, "initialCellState"}, {7, "peepholeWeight"}}}}, }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 67e3123ee7af6..c761213ef4dc2 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -26,6 +26,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateUnaryOpBuilder("Log", op_registrations); CreateUnaryOpBuilder("Neg", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); + CreateUnaryOpBuilder("Round", op_registrations); CreateUnaryOpBuilder("Sign", op_registrations); CreateUnaryOpBuilder("Sin", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc 
b/onnxruntime/core/session/onnxruntime_c_api.cc index 1f491bc788870..ad0a1ad137f06 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3423,25 +3423,86 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, API_IMPL_END } +// Validate compiled model compatibility info for specific EP device(s) +ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status) { + API_IMPL_BEGIN + if (ep_devices == nullptr || num_ep_devices == 0 || compatibility_info == nullptr || out_status == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to GetModelCompatibilityForEpDevices."); + } + + // Validate inputs and ensure all devices belong to the same EP/factory + const OrtEpFactory* first_factory = nullptr; + for (size_t i = 0; i < num_ep_devices; ++i) { + if (ep_devices[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_devices contains a null entry."); + } + const OrtEpFactory* f = ep_devices[i]->GetMutableFactory(); + if (i == 0) { + first_factory = f; + } else if (f != first_factory) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All ep_devices must be from the same execution provider."); + } + } + + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + OrtStatus* ort_status = nullptr; + OrtEpFactory* factory = ep_devices[0]->GetMutableFactory(); + if (factory && factory->ValidateCompiledModelCompatibilityInfo) { + // collect hardware devices corresponding to the ep_devices + InlinedVector hardware_devices; + hardware_devices.reserve(num_ep_devices); + for (size_t i = 0; i < num_ep_devices; ++i) { + hardware_devices.push_back(ep_devices[i]->device); + } + ort_status = factory->ValidateCompiledModelCompatibilityInfo(factory, + hardware_devices.data(), + hardware_devices.size(), + compatibility_info, + &status); + } + if (ort_status != nullptr) { + return ToOrtStatus(ToStatusAndRelease(ort_status)); + } + + *out_status = status; + return nullptr; + API_IMPL_END +} + #else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "RegisterExecutionProviderLibrary is not supported in a minimal build."); API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "UnregisterExecutionProviderLibrary is not supported in a minimal build."); API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::GetEpDevices, _In_ const OrtEnv* /*env*/, _Outptr_ const OrtEpDevice* const** /*ep_devices*/, _Out_ size_t* /*num_ep_devices*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetEpDevices is not supported in a minimal build."); + API_IMPL_END +} + +// Minimal build stub for GetModelCompatibilityForEpDevices 
to satisfy symbol references from the API table +ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* /*ep_devices*/, + _In_ size_t /*num_ep_devices*/, + _In_ const char* /*compatibility_info*/, + _Out_ OrtCompiledModelCompatibility* /*out_status*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetModelCompatibilityForEpDevices is not supported in a minimal build."); API_IMPL_END } @@ -3453,7 +3514,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS _In_reads_(num_op_options) const char* const* /*ep_option_vals*/, size_t /*num_ep_options*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionOptionsAppendExecutionProvider_V2 is not supported in a minimal build."); API_IMPL_END } @@ -3466,7 +3527,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* _Out_writes_(num_values) const OrtEpDevice** /*inputs_ep_devices*/, _In_ size_t /*num_values*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionGetEpDeviceForInputs is not supported in a minimal build."); API_IMPL_END } @@ -3474,7 +3535,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice _In_opt_ const OrtKeyValuePairs* /*stream_options*/, _Outptr_ OrtSyncStream** /*ort_stream*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateSyncStreamForEpDevice is not supported in a minimal build."); API_IMPL_END } @@ -3493,7 +3554,7 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, _In_opt_ OrtSyncStream* /*stream*/, _In_ size_t /*num_tensors*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CopyTensors is not supported in a minimal build."); API_IMPL_END } @@ -4108,6 +4169,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::CopyTensors, &OrtApis::Graph_GetModelMetadata, + &OrtApis::GetModelCompatibilityForEpDevices, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
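For reference, here is a minimal caller-side sketch of the new GetModelCompatibilityForEpDevices API added above. It is not part of this change; it only uses C API entry points that already appear in this patch and its tests (CreateEnv, GetEpDevices, GetModelCompatibilityForEpDevices, GetErrorMessage, ReleaseStatus, ReleaseEnv). The log id and the "example-compat-info" string are placeholders, and all devices passed in a single call must belong to the same execution provider.

// Sketch only: query compatibility of a previously compiled model against the
// first EP device reported by the environment.
#include <cstdio>
#include "onnxruntime_c_api.h"

int main() {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtEnv* env = nullptr;
  OrtStatus* st = api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "compat-check", &env);
  if (st != nullptr) { api->ReleaseStatus(st); return 1; }

  const OrtEpDevice* const* devices = nullptr;
  size_t num_devices = 0;
  st = api->GetEpDevices(env, &devices, &num_devices);
  if (st != nullptr || num_devices == 0) {
    if (st != nullptr) api->ReleaseStatus(st);
    api->ReleaseEnv(env);
    return 1;
  }

  // Devices passed together must come from the same execution provider; this
  // sketch simply uses the first reported device.
  OrtCompiledModelCompatibility compat = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
  const OrtEpDevice* selected[] = {devices[0]};
  st = api->GetModelCompatibilityForEpDevices(selected, 1, "example-compat-info", &compat);
  if (st != nullptr) {
    std::printf("error: %s\n", api->GetErrorMessage(st));
    api->ReleaseStatus(st);
  } else {
    std::printf("compatibility enum value: %d\n", static_cast<int>(compat));
  }

  api->ReleaseEnv(env);
  return 0;
}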
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index b3b0036c68247..e62149d04a16c 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -636,6 +636,13 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); ORT_API_STATUS_IMPL(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); + +// EP Compatibility Info APIs +ORT_API_STATUS_IMPL(GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status); ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 23e5e95af2903..093bfce462d32 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -80,9 +80,11 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } - OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, + OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _In_ const char* compatibility_info, _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { - return impl_->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + return impl_->ValidateCompiledModelCompatibilityInfo(devices, num_devices, compatibility_info, model_compatibility); } // Function ORT calls to release an EP instance. 
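As a companion to the widened factory callback above, the following sketch (not part of this change, and assuming the ORT EP API headers are available) shows how an EP factory might implement ValidateCompiledModelCompatibilityInfo with the new device-array parameters. The function name and the argument checks are illustrative; an EP that actually understands the compatibility string would map its parsing result to the appropriate OrtCompiledModelCompatibility value instead of always reporting "not applicable".

// Hypothetical factory-side implementation matching the new signature.
static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfoImpl(
    OrtEpFactory* /*this_ptr*/,
    const OrtHardwareDevice* const* devices,
    size_t num_devices,
    const char* compatibility_info,
    OrtCompiledModelCompatibility* model_compatibility) noexcept {
  if (devices == nullptr || num_devices == 0 || compatibility_info == nullptr ||
      model_compatibility == nullptr) {
    // Illustrative fallback; a real implementation would return an error status
    // built through the usual EP API helpers instead of silently succeeding.
    if (model_compatibility != nullptr) {
      *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
    }
    return nullptr;
  }
  // A real EP would compare compatibility_info against what it can produce for
  // the given hardware devices and report the matching compatibility value.
  *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
  return nullptr;
}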
diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 6c55730d83979..f29154d19c53c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -62,8 +62,13 @@ class EpFactoryInternalImpl { return false; } - virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, - _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + virtual OrtStatus* ValidateCompiledModelCompatibilityInfo( + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + ORT_UNUSED_PARAMETER(devices); + ORT_UNUSED_PARAMETER(num_devices); ORT_UNUSED_PARAMETER(compatibility_info); // Default implementation: mark as not applicable *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 3bfca62a4d011..c8829423fbe26 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -668,8 +668,15 @@ Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std // Plugin EP did not provide an implementation of this function, so we call a default implementation. return Base::ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); } - // Delegate to the EP factory's validation method + // Delegate to the EP factory's validation method, passing hardware devices derived from our ep_devices_ + std::vector hardware_devices; + hardware_devices.reserve(ep_devices_.size()); + for (const auto* ep_device : ep_devices_) { + hardware_devices.push_back(ep_device->device); + } ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory_.ValidateCompiledModelCompatibilityInfo(&ep_factory_, + hardware_devices.data(), + hardware_devices.size(), compatibility_info.c_str(), &model_compatibility))); return Status::OK(); diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 29793b503c9d1..2cceb1d08d536 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -46,9 +46,12 @@ struct ForwardToFactoryImpl { } static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfo(OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + size_t num_devices, const char* compatibility_info, OrtCompiledModelCompatibility* model_compatibility) noexcept { - return static_cast(this_ptr)->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + return static_cast(this_ptr)->ValidateCompiledModelCompatibilityInfo(devices, num_devices, + compatibility_info, model_compatibility); } static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr, diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 884922ec5c098..57b1edfd0edce 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -198,6 +198,9 @@ void RunTest2Bits(const 
TestOptions2Bits& opts) { std::vector> execution_providers; if constexpr (std::is_same::value) { execution_providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); +#endif test.ConfigEps(std::move(execution_providers)); test.RunWithConfig(); } @@ -244,22 +247,47 @@ void TestMatMul2BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) { } // namespace -TEST(MatMulNBits, Float32_2Bits_Accuracy0) { - // Currently, only fallback option enabled for 2bit datatypes - // where the 2bits are dequantized to fp32 - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); - TestMatMul2BitsTyped(); +template +struct TypedTestParams { + static constexpr int batch_size = BatchSize; + static constexpr int M = MVal; + static constexpr int N = NVal; + static constexpr int K = KVal; +}; + +using TestTypes = ::testing::Types< + TypedTestParams<1, 1, 16, 16>, + TypedTestParams<1, 2, 16, 16>, + TypedTestParams<1, 32, 16, 16>, + TypedTestParams<1, 32, 32, 16>, + TypedTestParams<1, 32, 16, 128>, + TypedTestParams<1, 288, 16, 16>, + TypedTestParams<4, 1, 16, 16>, + TypedTestParams<4, 2, 16, 16>, + TypedTestParams<4, 32, 16, 16>, + TypedTestParams<4, 32, 32, 16>, + TypedTestParams<4, 32, 16, 128>, + TypedTestParams<4, 288, 16, 16>>; + +template +class MatMulNBits : public ::testing::Test { + public: + static constexpr int batch_size = T::batch_size; + static constexpr int M = T::M; + static constexpr int N = T::N; + static constexpr int K = T::K; +}; + +TYPED_TEST_SUITE(MatMulNBits, TestTypes); + +TYPED_TEST(MatMulNBits, Float32_2Bits_Accuracy0) { + TestMatMul2BitsTyped(); +} + +TYPED_TEST(MatMulNBits, Float32_2Bits_Accuracy4) { + TestMatMul2BitsTyped(); } + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 05fff886080d7..ed7ca998e0b86 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -9,13 +9,16 @@ namespace onnxruntime { namespace test { +// Note: QMoE CPU implementation now always applies softmax normalization to top-k selected experts +// regardless of the normalize_routing_weights parameter value for mathematical correctness. 
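The note above describes the routing order used by the QMoE CPU path: select the top-k experts first, then softmax-normalize only the selected logits so their weights sum to 1. A small self-contained sketch of that behaviour follows; it is separate from the QMoE kernel and uses made-up logits purely for illustration.

// Illustrative only: top-k selection followed by softmax over the k selected logits.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

std::vector<std::pair<int, float>> TopKSoftmax(const std::vector<float>& logits, int k) {
  std::vector<int> idx(logits.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Order expert indices by descending logit; only the first k are needed.
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return logits[a] > logits[b]; });

  const float max_logit = logits[idx[0]];  // subtract max for numerical stability
  std::vector<std::pair<int, float>> selected;
  float denom = 0.0f;
  for (int i = 0; i < k; ++i) {
    const float e = std::exp(logits[idx[i]] - max_logit);
    selected.emplace_back(idx[i], e);
    denom += e;
  }
  for (auto& p : selected) p.second /= denom;  // normalized routing weights
  return selected;
}

int main() {
  // Hypothetical router logits for 4 experts, k = 2.
  for (const auto& [expert, weight] : TopKSoftmax({0.3f, 0.7f, 0.1f, -0.2f}, 2)) {
    std::printf("expert %d -> weight %.4f\n", expert, weight);
  }
  return 0;
}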
+ #ifndef ENABLE_TRAINING static void RunMoETest(const std::vector& input, const std::vector& router_probs, const std::vector& fc1_experts_weights, const std::vector& fc2_experts_weights, const std::vector& fc3_experts_weights, const std::vector& fc1_experts_bias, const std::vector& fc2_experts_bias, const std::vector& output_data, int num_rows, int num_experts, int hidden_size, int inter_size, std::string activation_type, - int normalize_routing_weights = 0, int top_k = 1, bool use_float16 = false) { + int normalize_routing_weights = 1, int top_k = 1, bool use_float16 = false) { constexpr int min_cuda_arch = 700; bool enable_cuda = HasCudaEnvironment(min_cuda_arch); @@ -27,8 +30,8 @@ static void RunMoETest(const std::vector& input, const std::vector std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc1_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size}; std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; std::vector fc1_experts_bias_dims = {num_experts, inter_size}; std::vector fc2_experts_bias_dims = {num_experts, hidden_size}; @@ -90,7 +93,7 @@ static void RunQMoETest(const std::vector& input, const std::vector& fc3_experts_weights, const std::vector& fc1_scales, const std::vector& fc2_scales, const std::vector& fc3_scales, const std::vector& output_data, int num_rows, int num_experts, int hidden_size, - int inter_size, std::string activation_type, int normalize_routing_weights = 0, int top_k = 1, int expert_weight_bits = 4) { + int inter_size, std::string activation_type, int normalize_routing_weights = 1, int top_k = 1, int expert_weight_bits = 4) { constexpr int min_cuda_arch = 700; // Test CUDA execution provider @@ -103,7 +106,6 @@ static void RunQMoETest(const std::vector& input, const std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - // Adjust weight dimensions based on quantization type for CUDA as well std::vector fc1_experts_weights_dims = {num_experts, hidden_size, expert_weight_bits == 4 ? inter_size / 2 : inter_size}; std::vector fc2_experts_weights_dims = {num_experts, inter_size, expert_weight_bits == 4 ? hidden_size / 2 : hidden_size}; std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; @@ -150,7 +152,6 @@ static void RunQMoETest(const std::vector& input, const std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - // Adjust weight dimensions based on quantization type std::vector fc1_experts_weights_dims = {num_experts, hidden_size, expert_weight_bits == 4 ? inter_size / 2 : inter_size}; std::vector fc2_experts_weights_dims = {num_experts, inter_size, expert_weight_bits == 4 ? 
hidden_size / 2 : hidden_size}; std::vector fc1_scales_dims = {num_experts, inter_size}; @@ -355,7 +356,7 @@ TEST(MoETest, MoETest_Gelu) { 1.3354061f, 0.5049282f, 0.72775036f, 0.90331376f, 1.2945517f, 0.9123066f, 1.1995136f, 0.7708638f}; RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, - output, num_rows, num_experts, hidden_size, inter_size, "gelu"); + output, num_rows, num_experts, hidden_size, inter_size, "gelu", 0); } TEST(MoETest, MoETest_Relu) { @@ -533,7 +534,7 @@ TEST(MoETest, MoETest_Relu) { 4.8571277f, 5.649453f, 5.485141f, 5.306299f, 4.767025f, 6.9010167f, 5.3520975f, 6.711155f}; RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, - output, num_rows, num_experts, hidden_size, inter_size, "relu"); + output, num_rows, num_experts, hidden_size, inter_size, "relu", 0); } TEST(MoETest, MoETest_Mixtral) { @@ -1322,7 +1323,6 @@ TEST(MoETest, QMoETest_Mixtral_Int4) { // CPU-specific QMoE tests TEST(MoETest, QMoETest_CPU_Int4_MLAS) { - // Test CPU implementation with 4-bit quantization (MLAS optimized path) - CPU only int num_rows = 2; int num_experts = 2; int hidden_size = 32; @@ -1336,31 +1336,32 @@ TEST(MoETest, QMoETest_CPU_Int4_MLAS) { const std::vector router_probs = {0.3f, 0.7f, 0.6f, 0.4f}; - // Generate simple test weights for 4-bit symmetric quantization - // Use 0x00 which unpacks to 0,0 (both 0 for 4-bit) - std::vector fc1_experts_weights(num_experts * hidden_size * inter_size / 2, 0x00); - std::vector fc2_experts_weights(num_experts * inter_size * hidden_size / 2, 0x00); // 0,0 values to produce zero output + // Generate simple test weights for 4-bit symmetric quantization with SwiGLU + // Use 0x88 which unpacks to 8,8 -> 0,0 in signed form (8-8=0) for zero weights + // For SwiGLU: FC1 outputs 2*inter_size (gate + linear), FC2 takes inter_size input + std::vector fc1_experts_weights(num_experts * hidden_size * inter_size, 0x88); // 2*inter_size for SwiGLU, packed into /2 + std::vector fc2_experts_weights(num_experts * inter_size * hidden_size / 2, 0x88); // 8,8 values to produce zero output std::vector fc3_experts_weights; // Empty for CPU (FC3 not supported) - std::vector fc1_scales(num_experts * inter_size, 0.01f); // Smaller scale factor - std::vector fc2_scales(num_experts * hidden_size, 0.01f); // Smaller scale factor + std::vector fc1_scales(num_experts * inter_size * 2, 0.01f); // 2x for SwiGLU (gate + linear) + std::vector fc2_scales(num_experts * hidden_size, 0.01f); // Smaller scale factor std::vector fc3_scales; - // With zero weights (0x00), the current implementation will produce all zero outputs + // With zero weights (0x88 -> 8,8 -> 0,0 signed), the implementation will produce all zero outputs std::vector output(num_rows * hidden_size, 0.0f); // Test CPU execution provider ONLY (don't use RunQMoETest which tests both CUDA and CPU) OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); cpu_tester.AddAttribute("k", 2); - cpu_tester.AddAttribute("activation_type", "gelu"); - cpu_tester.AddAttribute("normalize_routing_weights", 1); - cpu_tester.AddAttribute("expert_weight_bits", 4); // Test 4-bit quantization + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Always use 1 - softmax normalization always applied + cpu_tester.AddAttribute("expert_weight_bits", 4); // Test 4-bit quantization std::vector input_dims = {num_rows, hidden_size}; 
std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; // legacy shape + std::vector fc1_experts_weights_dims = {num_experts, 2 * inter_size, hidden_size / 2}; // SwiGLU: 2*inter_size output, 4-bit packed std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2}; - std::vector fc1_scales_dims = {num_experts, inter_size}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU std::vector fc2_scales_dims = {num_experts, hidden_size}; std::vector output_dims = {num_rows, hidden_size}; @@ -1376,8 +1377,8 @@ TEST(MoETest, QMoETest_CPU_Int4_MLAS) { cpu_tester.AddOptionalInputEdge(); // fc3_scales (use float for CPU) cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias - // When using 0x00 for 4-bit quantized weights with the current implementation, - // all dequantized values should be 0.0f, and thus output should be all zeros + // When using 0x88 for 4-bit quantized weights with the current implementation, + // all dequantized values should be 0.0f (8-8=0), and thus output should be all zeros std::vector expected_output(num_rows * hidden_size, 0.0f); cpu_tester.AddOutput("output", output_dims, ToFloat16(expected_output)); @@ -1400,31 +1401,31 @@ TEST(MoETest, QMoETest_CPU_Int8_MLAS) { const std::vector router_probs = {0.4f, 0.6f}; - // For 8-bit symmetric quantization, dimensions don't need /2 - // Use quantized weights close to zero for reasonable dequantization - std::vector fc1_experts_weights(num_experts * hidden_size * inter_size, 2); // 2 = small positive value - std::vector fc2_experts_weights(num_experts * inter_size * hidden_size, 254); // 254 = -2 in 8-bit signed - std::vector fc3_experts_weights; // Empty for CPU + // For 8-bit symmetric quantization with SwiGLU + // Use quantized weights at zero point for zero outputs (128 = 0 in signed) + std::vector fc1_experts_weights(num_experts * 2 * inter_size * hidden_size, 128); // 2*inter_size for SwiGLU, no packing for 8-bit + std::vector fc2_experts_weights(num_experts * inter_size * hidden_size, 128); // 128 = 0 in signed + std::vector fc3_experts_weights; // Empty for CPU - std::vector fc1_scales(num_experts * inter_size, 0.1f); + std::vector fc1_scales(num_experts * inter_size * 2, 0.1f); // 2x for SwiGLU std::vector fc2_scales(num_experts * hidden_size, 0.1f); std::vector fc3_scales; - // Expected output should be close to zero since we're using small weights around zero point + // Expected output should be zero since we're using zero weights (128-128=0) std::vector output(num_rows * hidden_size, 0.0f); // Test with different attributes for 8-bit OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); cpu_tester.AddAttribute("k", 1); - cpu_tester.AddAttribute("activation_type", "relu"); - cpu_tester.AddAttribute("normalize_routing_weights", 0); - cpu_tester.AddAttribute("expert_weight_bits", 8); // Test 8-bit quantization + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Always use 1 - softmax normalization always applied + cpu_tester.AddAttribute("expert_weight_bits", 8); // Test 8-bit quantization std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; // legacy shape + std::vector fc1_experts_weights_dims = {num_experts, inter_size * 2, hidden_size}; // SwiGLU: 
2*inter_size output, 8-bit no packing std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; - std::vector fc1_scales_dims = {num_experts, inter_size}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU std::vector fc2_scales_dims = {num_experts, hidden_size}; std::vector output_dims = {num_rows, hidden_size}; @@ -1457,27 +1458,30 @@ TEST(MoETest, QMoETest_CPU_FC3_Error) { const std::vector input = {0.1f, -0.2f, 0.3f, -0.4f, 0.5f, -0.6f, 0.7f, -0.8f}; const std::vector router_probs = {0.5f, 0.5f}; - std::vector fc1_experts_weights(num_experts * hidden_size * inter_size / 2, 0x01); // 0,1 in symmetric quantization - std::vector fc2_experts_weights(num_experts * inter_size * hidden_size / 2, 0x10); // 1,0 in symmetric quantization - std::vector fc3_experts_weights(num_experts * hidden_size * inter_size / 2, 0x21); // 2,1 in symmetric quantization, FC3 provided + // Using new layout: fc1 has fused swiglu doubling (2*inter_size) and 4-bit pack_size=2 so hidden_size packed dimension is hidden_size/2 + const int pack_size = 2; // for 4-bit + const int fc1_inter_size = 2 * inter_size; // swiglu fused + std::vector fc1_experts_weights(num_experts * fc1_inter_size * (hidden_size / pack_size), 0x01); + std::vector fc2_experts_weights(num_experts * hidden_size * (inter_size / pack_size), 0x10); + std::vector fc3_experts_weights(num_experts * inter_size * (hidden_size / pack_size), 0x21); // FC3 provided - std::vector fc1_scales(num_experts * inter_size, 0.1f); + std::vector fc1_scales(num_experts * fc1_inter_size, 0.1f); std::vector fc2_scales(num_experts * hidden_size, 0.05f); std::vector fc3_scales(num_experts * inter_size, 0.08f); // FC3 scales provided // Test CPU execution provider ONLY (designed to test CPU-specific error handling) OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); cpu_tester.AddAttribute("k", 1); - cpu_tester.AddAttribute("activation_type", "relu"); - cpu_tester.AddAttribute("normalize_routing_weights", 0); + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Use 1 for consistency, though this test focuses on FC3 error cpu_tester.AddAttribute("expert_weight_bits", 4); std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; // legacy shape - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2}; - std::vector fc3_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; - std::vector fc1_scales_dims = {num_experts, inter_size}; + std::vector fc1_experts_weights_dims = {num_experts, fc1_inter_size, hidden_size / pack_size}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / pack_size}; + std::vector fc3_experts_weights_dims = {num_experts, inter_size, hidden_size / pack_size}; + std::vector fc1_scales_dims = {num_experts, fc1_inter_size}; std::vector fc2_scales_dims = {num_experts, hidden_size}; std::vector fc3_scales_dims = {num_experts, inter_size}; std::vector output_dims = {num_rows, hidden_size}; @@ -1522,9 +1526,9 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { const int fc1_weight_size_per_expert = hidden_size * inter_size * 2 / 2; // For 4-bit SwiGLU const int fc2_weight_size_per_expert = inter_size * hidden_size / 2; // For 4-bit FC2 - // Generate test weights for symmetric quantization (zero point is 0) - std::vector 
fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 0x12); // 1,2 -> small positive weights - std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 0xFF); // -1,0 -> small mixed weights + // Generate test weights for symmetric quantization (zero point is 8 for 4-bit) + std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 0x88); // 8,8 -> 0,0 signed weights + std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 0x88); // 8,8 -> 0,0 signed weights std::vector fc3_experts_weights; // Empty for SwiGLU (gate weights concatenated with FC1) // Scales: for SwiGLU, FC1 has 2*inter_size outputs (linear + gate) @@ -1532,7 +1536,10 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { std::vector fc2_scales(num_experts * hidden_size, 0.05f); std::vector fc3_scales; - // Expected output should be small but non-zero due to SwiGLU nonlinearity + // For SwiGLU with zero weights (0x88 -> 8,8 -> 0,0 signed): + // Gate output = 0, Linear output = 0 + // SwiGLU = gate * sigmoid(gate) * (linear + 1) = 0 * sigmoid(0) * (0 + 1) = 0 * 0.5 * 1 = 0 + // So output should be zero std::vector output(num_rows * hidden_size, 0.0f); OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); @@ -1582,10 +1589,11 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { const int fc1_weight_size_per_expert = hidden_size * inter_size * 2; // For 8-bit SwiGLU const int fc2_weight_size_per_expert = inter_size * hidden_size; // For 8-bit FC2 - // Generate test weights at zero (for symmetric quantization) to produce zero output - std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 0); // Zero in symmetric quantization - std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 0); // Zero in symmetric quantization - std::vector fc3_experts_weights; // Empty for SwiGLU + // Generate test weights at zero (for symmetric quantization storage format: uint8 with zero point 128) + // Fill with 128 so dequantized value (val - 128) == 0 => zero output + std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 128); + std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 128); + std::vector fc3_experts_weights; // Empty for SwiGLU // Scales: for SwiGLU, FC1 has 2*inter_size outputs std::vector fc1_scales(num_experts * inter_size * 2, 0.1f); @@ -1627,65 +1635,6 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); } -// Test for Float32 input and output type with QMoE operator -TEST(MoETest, QMoETest_CPU_Float32) { - // Test CPU implementation with float32 input/output - int num_rows = 1; - int num_experts = 2; - int hidden_size = 8; - int inter_size = 8; - - const std::vector input = {0.2f, -0.3f, 0.4f, -0.5f, 0.6f, -0.7f, 0.8f, -0.9f}; - const std::vector router_probs = {0.0f, 0.0f}; - - // For 8-bit quantization weights - const int fc1_weight_size_per_expert = hidden_size * inter_size; - const int fc2_weight_size_per_expert = inter_size * hidden_size; - - // Generate test weights at zero point (128 for 8-bit) to produce zero output - std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 128); - std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 128); - - // Scales - std::vector fc1_scales(num_experts * inter_size, 0.1f); - std::vector fc2_scales(num_experts * hidden_size, 0.1f); - - std::vector output(num_rows * hidden_size, 0.0f); - - OpTester cpu_tester("QMoE", 1, 
onnxruntime::kMSDomain); - cpu_tester.AddAttribute("k", 2); - cpu_tester.AddAttribute("activation_type", "gelu"); - cpu_tester.AddAttribute("normalize_routing_weights", 1); - cpu_tester.AddAttribute("expert_weight_bits", 8); - - std::vector input_dims = {num_rows, hidden_size}; - std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; // legacy shape - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; - std::vector fc1_scales_dims = {num_experts, inter_size}; - std::vector fc2_scales_dims = {num_experts, hidden_size}; - std::vector output_dims = {num_rows, hidden_size}; - - // Use float directly instead of MLFloat16 - cpu_tester.AddInput("input", input_dims, input); - cpu_tester.AddInput("router_probs", router_probs_dims, router_probs); - cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); - cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); - cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias - cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); - cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); - cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias - cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights - cpu_tester.AddOptionalInputEdge(); // fc3_scales - cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias - cpu_tester.AddOutput("output", output_dims, output); - cpu_tester.SetOutputTolerance(0.02f); - - std::vector> cpu_execution_providers; - cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); - cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); -} - #endif } // namespace test diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc index be97cf2620881..ee82d4683ab73 100644 --- a/onnxruntime/test/framework/ep_compatibility_test.cc +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -408,3 +408,94 @@ TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) { EXPECT_TRUE(has_config); EXPECT_EQ(config_value, "0"); } + +// ----------------------------- +// C API unit tests +// ----------------------------- + +namespace { + +// Helper to create an OrtEnv and fetch a CPU EP device pointer via the C API. +// Returns a pair of (env, cpu_device). Caller releases env via api->ReleaseEnv. +static std::pair CreateEnvAndGetCpuEpDevice(const OrtApi* api) { + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpCompatCapiTest", &env)); + EXPECT_NE(env, nullptr); + + const OrtEpDevice* const* devices = nullptr; + size_t num_devices = 0; + EXPECT_EQ(nullptr, api->GetEpDevices(env, &devices, &num_devices)); + EXPECT_GT(num_devices, 0u); + + const OrtEpDevice* cpu_device = nullptr; + for (size_t i = 0; i < num_devices; ++i) { + const char* name = api->EpDevice_EpName(devices[i]); + if (name && std::string(name) == "CPUExecutionProvider") { + cpu_device = devices[i]; + break; + } + } + + // Fallback: just pick the first device if CPU wasn't found (environment-dependent builds). 
+ if (!cpu_device && num_devices > 0) { + cpu_device = devices[0]; + } + + EXPECT_NE(cpu_device, nullptr); + return {env, cpu_device}; +} + +} // namespace + +TEST(EpCompatibilityCapiTest, InvalidArguments) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtCompiledModelCompatibility out_status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + + // ep_devices == nullptr + OrtStatus* st = api->GetModelCompatibilityForEpDevices(nullptr, 0, "info", &out_status); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // Prepare a valid device + auto [env, device] = CreateEnvAndGetCpuEpDevice(api); + ASSERT_NE(env, nullptr); + ASSERT_NE(device, nullptr); + + // compatibility_info == nullptr + const OrtEpDevice* devices1[] = {device}; + st = api->GetModelCompatibilityForEpDevices(devices1, 1, nullptr, &out_status); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // out_status == nullptr + st = api->GetModelCompatibilityForEpDevices(devices1, 1, "some-info", nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + auto [env, device] = CreateEnvAndGetCpuEpDevice(api); + ASSERT_NE(env, nullptr); + ASSERT_NE(device, nullptr); + + OrtCompiledModelCompatibility out_status = static_cast(-1); + const OrtEpDevice* devices2[] = {device}; + OrtStatus* st = api->GetModelCompatibilityForEpDevices(devices2, 1, "arbitrary-compat-string", &out_status); + ASSERT_EQ(st, nullptr) << (st ? api->GetErrorMessage(st) : ""); + + // For providers that don't implement validation, API should return EP_NOT_APPLICABLE. 
+ EXPECT_EQ(out_status, OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc index 98874874d50e9..e70d870ef6988 100644 --- a/onnxruntime/test/framework/save_model_with_external_initializers.cc +++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc @@ -84,7 +84,7 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, size_t tensor_offset; std::stringstream stream(entry.value()); stream >> tensor_offset; - ORT_RETURN_IF_NOT(tensor_offset % model_saving_options.allocation_granularity == 0, + ORT_RETURN_IF_NOT(tensor_offset % model_saving_options.on_disk_alignment == 0, "tensor offset not align"); } } diff --git a/onnxruntime/test/platform/file_io_test.cc b/onnxruntime/test/platform/file_io_test.cc index ccc703716844f..a1a863d2442d1 100644 --- a/onnxruntime/test/platform/file_io_test.cc +++ b/onnxruntime/test/platform/file_io_test.cc @@ -157,7 +157,6 @@ TEST(FileIoTest, MapFileIntoMemory) { SYSTEM_INFO sysinfo; GetSystemInfo(&sysinfo); static const auto page_size = sysinfo.dwPageSize; - static const auto allocation_granularity = sysinfo.dwAllocationGranularity; ASSERT_GT(page_size, static_cast(0)); TempFilePath tmp(ORT_TSTR("map_file_test_")); @@ -167,21 +166,10 @@ TEST(FileIoTest, MapFileIntoMemory) { const auto offsets_and_lengths = GenerateValidOffsetLengthPairs( 0, expected_data.size(), page_size / 10); - for (const auto& offset_and_length : offsets_and_lengths) { - const auto offset = offset_and_length.first; - const auto length = offset_and_length.second; - - // The offset must be a multiple of the allocation granularity - if (offset % allocation_granularity != 0) { - continue; - } - + for (const auto& [offset, length] : offsets_and_lengths) { Env::MappedMemoryPtr mapped_memory{}; - auto status = Env::Default().MapFileIntoMemory( - tmp.path.c_str(), offset, length, mapped_memory); - ASSERT_TRUE(status.IsOK()) - << "MapFileIntoMemory failed for offset " << offset << " and length " << length - << " with error: " << status.ErrorMessage(); + ASSERT_STATUS_OK(Env::Default().MapFileIntoMemory( + tmp.path.c_str(), offset, length, mapped_memory)); auto mapped_span = gsl::make_span(mapped_memory.get(), length); @@ -190,20 +178,11 @@ TEST(FileIoTest, MapFileIntoMemory) { ASSERT_TRUE(SpanEq(mapped_span, expected_data_span)); } - { - Env::MappedMemoryPtr mapped_memory{}; - - // invalid - offset is not a multiple of the allocation granularity - ASSERT_FALSE(Env::Default().MapFileIntoMemory( - tmp.path.c_str(), allocation_granularity * 3 / 2, page_size / 10, mapped_memory) - .IsOK()); - } - { Env::MappedMemoryPtr mapped_memory{}; // invalid - negative offset - ASSERT_FALSE(Env::Default().MapFileIntoMemory(tmp.path.c_str(), -1, 0, mapped_memory).IsOK()); + ASSERT_STATUS_NOT_OK(Env::Default().MapFileIntoMemory(tmp.path.c_str(), -1, 0, mapped_memory)); } } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 8fdbf0060eaa0..0e17aa835028e 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -137,7 +137,7 @@ TEST(DequantizeLinearOpTest, Uint16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// scalar zero & scale with int8 +// scalar zero & 
scale with int32 TEST(DequantizeLinearOpTest, Int32) { OpTester test("DequantizeLinear", 10); std::vector dims{4}; @@ -147,6 +147,17 @@ TEST(DequantizeLinearOpTest, Int32) { test.Run(); } +// non-zero zero point with int32 +TEST(DequantizeLinearOpTest, Int32_Non_Zero_Zero_Point) { + OpTester test("DequantizeLinear", 10); + std::vector dims{4}; + test.AddInput("x", dims, {-30, -3, 100, 127}); + test.AddInput("x_scale", {}, {2.0f}, true); + test.AddInput("x_zero_point", {}, {1}, true); + test.AddOutput("y", dims, {-62.f, -8.f, 198.f, 252.f}); + test.Run(); +} + TEST(DequantizeLinearOpTest_BroadcastTensor, Int32) { OpTester test("DequantizeLinear", 13); test.AddInput("x", {4}, {-30, -3, 100, 127}); diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc index 11a3d5a083aab..fefd682de203e 100644 --- a/onnxruntime/test/providers/qnn/einsum_op_test.cc +++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc @@ -150,6 +150,32 @@ TEST_F(QnnCPUBackendTests, EinsumRank2) { /*tolerance=*/1e-4f); } +TEST_F(QnnCPUBackendTests, EinsumRank3MatMul) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hij,hjk->hik", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnCPUBackendTests, EinsumRank3MatMul_QK) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hQK,hKd->hQd", + /*tolerance=*/1e-4f); +} + TEST_F(QnnCPUBackendTests, EinsumRank4MatMul) { const std::vector shape0{3, 4, 5, 6}; const std::vector shape1{3, 4, 6, 5}; @@ -189,6 +215,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) { /*tolerance=*/1e-4f); } +TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeY_QK) { + const std::vector shape0{2, 3, 4, 6}; + const std::vector shape1{2, 3, 5, 6}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bnQd,bnKd->bnQK", + /*tolerance=*/1e-4f); +} + TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) { const std::vector shape0{1, 7, 1, 7}; const std::vector shape1{1, 9, 1, 7}; @@ -273,6 +312,60 @@ TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY) { /*tolerance=*/1e-2f); } +TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY_QK) { + const std::vector shape0{2, 3, 4, 2}; + const std::vector shape1{2, 3, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = 
GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bnQd,bnKd->bnQK", + /*tolerance=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, EinsumRank3MatMulTransposeY) { + const std::vector shape0{2, 4, 2}; + const std::vector shape1{2, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bid,bjd->bij", + /*tolerance=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, EinsumRank3MatMulTransposeY_QK) { + const std::vector shape0{2, 4, 2}; + const std::vector shape1{2, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bQd,bKd->bQK", + /*tolerance=*/1e-2f); +} + +// The value pair (65.1049271, 65.0625076) at index #51 don't match, which is -0.0424194 from 65.1049 +// Disable this Rank3 test on HTP since it has accuracy issue. +TEST_F(QnnHTPBackendTests, DISABLED_EinsumRank3MatMul_QK) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hQK,hKd->hQd", + /*tolerance=*/1e-2f); +} + TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll1) { const std::vector shape0{1, 3, 1, 7}; const std::vector shape1{1, 7, 1, 3}; @@ -365,6 +458,66 @@ TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY) { /*tolerance=*/QDQTolerance()); } +TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY_QK) { + const std::vector shape0{2, 3, 4, 2}; + const std::vector shape1{2, 3, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bnQd,bnKd->bnQK", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMulTransposeY) { + const std::vector shape0{2, 4, 2}; + const std::vector shape1{2, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, 
std::move(data1)), + /*equation=*/"bid,bjd->bij", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMulTransposeY_QK) { + const std::vector shape0{2, 4, 2}; + const std::vector shape1{2, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bQd,bKd->bQK", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMul) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hij,hjk->hik", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank3MatMul_QK) { + const std::vector shape0{4, 5, 6}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"hQK,hKd->hQd", + /*tolerance=*/QDQTolerance()); +} + TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll1) { const std::vector shape0{1, 3, 1, 7}; const std::vector shape1{1, 7, 1, 3}; diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 909795e6639bb..efaaca29a01b6 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -10,28 +10,26 @@ # license information. # -------------------------------------------------------------------------- # -# Note on QMoE quantization approaches: +# QMoE quantization implementation notes: # -# Both CPU and CUDA implementations of QMoE use symmetric quantization: +# Both CPU and CUDA implementations use symmetric quantization centered around 0: +# - 4-bit: range [-8, 7] with no zero-point (symmetric around 0) +# - 8-bit: range [-128, 127] with no zero-point (symmetric around 0) # -# 1. CPU (this file): Symmetric quantization -# - 4-bit: range = [-8, 7] -# - 8-bit: range = [-128, 127] +# This follows the _symmetric_quantize_last_axis_of_batched_matrix pattern. +# Tolerance values account for numerical differences between implementations. # -# 2. CUDA: Symmetric quantization -# - 4-bit: range = [-8, 7] -# - 8-bit: range = [-128, 127] -# -# This aligned approach ensures better compatibility with TensorRT. -# The tolerance values used in testing account for minor numerical differences. +# Routing Logic: CPU implementation uses top-k selection first, then softmax +# normalization on the selected experts. This provides proper weight distribution +# while maintaining computational efficiency. 
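Aside (not part of the patch): a minimal sketch of the symmetric scheme these notes describe, assuming per-row scales over the last weight axis as `quant_dequant` below uses; the helper name, shapes, and packing snippet are illustrative only.

```python
import torch

def symmetric_quant(w: torch.Tensor, bits: int = 4):
    """Per-row symmetric quantization along the last axis, no zero-point (illustrative)."""
    qmax = 7.0 if bits == 4 else 127.0            # 4-bit: [-8, 7], 8-bit: [-128, 127]
    qmin = -8.0 if bits == 4 else -128.0
    scale = w.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-8) / qmax
    q = torch.round(w / scale).clamp(qmin, qmax)  # signed integer grid
    return scale, q

w = torch.randn(3, 8)
scale, q = symmetric_quant(w, bits=4)
dq = scale * q                                    # dequantize: no zero-point to subtract
assert (w - dq).abs().max() <= scale.max() / 2 + 1e-6   # error within half a step

# Storage: shift [-8, 7] to [0, 15] and pack two nibbles per byte (low nibble = even index).
u = (q + 8).to(torch.uint8)
packed = (u[..., 0::2] & 0xF) | ((u[..., 1::2] & 0xF) << 4)
```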
# -------------------------------------------------------------------------- -import itertools -import os +import time import unittest from collections import OrderedDict import numpy import torch +import torch.nn.functional as F from onnx import helper from parameterized import parameterized from torch import nn @@ -41,27 +39,41 @@ try: from onnx import TensorProto - HAS_ONNX = True + has_onnx = True except ImportError: - print("ONNX is not installed. Some functionality will not be available.") - HAS_ONNX = False + has_onnx = False + + class TensorProtoPlaceholder: + FLOAT16 = 10 + FLOAT = 1 + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "silu": nn.SiLU, + "gelu": nn.GELU, +} +ACT2FN = ClassInstantier(ACT2CLS) + +if not has_onnx: - # Define placeholder constants if onnx is not available class TensorProtoPlaceholder: FLOAT16 = 10 FLOAT = 1 - # BF16 not supported in QMoE CPU UINT8 = 2 TensorProto = TensorProtoPlaceholder -# Reduces number of tests to run for faster pipeline checks -pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" - onnxruntime.preload_dlls() -# Force CPU execution provider regardless of CUDA availability device = torch.device("cpu") + ort_provider = ["CPUExecutionProvider"] torch.manual_seed(42) @@ -70,7 +82,6 @@ class TensorProtoPlaceholder: onnx_to_torch_type_map = { TensorProto.FLOAT16: torch.float16, TensorProto.FLOAT: torch.float, - # BF16 not supported in QMoE CPU TensorProto.UINT8: torch.uint8, } @@ -83,24 +94,17 @@ class TensorProtoPlaceholder: ort_dtype_name_map = { TensorProto.FLOAT16: "FP16", TensorProto.FLOAT: "FP32", - # QMoE CPU does not support BF16 } def quant_dequant(weights, is_4_bit_quantization: bool = True): """ Quantize and dequantize weights for testing purposes. - This function exactly matches the C++ implementation in QMoE CPU. - - This uses symmetric quantization to match the C++ implementation and for TensorRT compatibility: - - 4-bit: range = [-8, 7] - - 8-bit: range = [-128, 127] + This function uses symmetric quantization centered around 0 (no zero-point). - This implementation aims to precisely match the C++ implementation by: - 1. Using symmetric quantization (zero point = 0) - 2. Using the same scale calculation methodology - 3. Using consistent rounding behavior - 4. 
Properly handling edge cases + This uses symmetric quantization similar to _symmetric_quantize_last_axis_of_batched_matrix: + - 4-bit: range = [-8, 7], no zero-point (symmetric around 0) + - 8-bit: range = [-128, 127], no zero-point (symmetric around 0) """ # Handle edge case of all-zero weights tensor if torch.all(weights == 0): @@ -122,119 +126,107 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): torch.zeros_like(weights), ) - # Get absolute maximum for scale calculation + # Calculate scale like C++ implementation abs_max = weights.abs().max(dim=-1, keepdim=True)[0] + abs_max = torch.clamp(abs_max, min=1e-8) # More conservative clamping for better precision if is_4_bit_quantization: - # For 4-bit symmetric quantization, range is [-8, 7] - scale = abs_max / 7.0 # Scale factor ensures max value maps to 7 + # 4-bit: scale = abs_max / 7.0 (using 7.0 as max positive value for symmetric range) + # Use higher precision computation for better accuracy + scale = (abs_max.double() / 7.0).float() + 1e-12 # Handle potential edge cases for zero or very small weights - if torch.max(abs_max) < 1e-10: - # For extremely small values, avoid division by near-zero + if torch.max(abs_max) < 1e-8: packed_size = (weights.shape[-1] + 1) // 2 - # Just return zeros with appropriate scale to avoid numerical issues return ( - torch.ones_like(weights[..., 0:1]) * 1e-6, # Very small non-zero scale - torch.full( + torch.ones_like(weights[..., 0:1]) * 1e-8, + torch.zeros( (weights.shape[0], weights.shape[1], packed_size), - fill_value=8 | (8 << 4), # 8 = 0 in symmetric quantization dtype=torch.uint8, device=weights.device, ), torch.zeros_like(weights), ) - # Convert to int4 range (-8 to 7) - scaled_weights = torch.round(weights / scale) - clipped_weights = torch.clamp(scaled_weights, -8, 7) + # Quantize: round(weight / scale) then clamp to [-8, 7] + # Use higher precision for the division to reduce accumulated errors + scaled_weights = weights.double() / scale.double() + quantized_weights = torch.round(scaled_weights).clamp(-8, 7).float() - # Convert from int4 signed range [-8,7] to uint4 storage range [0,15] - # by adding 8 to map -8->0, -7->1, ..., 7->15 - quant_weights = (clipped_weights + 8).to(torch.uint8) + # For symmetric quantization, we use signed int4 representation + # Convert to uint8 storage for packing: shift [-8,7] -> [0,15] for storage only + storage_weights = (quantized_weights + 8).to(torch.uint8) # Pack 4-bit values into uint8 (every two elements) even_indices = torch.arange(0, weights.shape[-1], 2) odd_indices = torch.arange(1, weights.shape[-1], 2) - # Handle odd length by padding + # Handle odd length by padding with zero (which is 8 in storage representation) if odd_indices.shape[0] < even_indices.shape[0]: - # Pad with 8 (which represents 0 in symmetric quantization) - # Create a new padding tensor for more predictable behavior padding = torch.full( - (quant_weights.shape[0], quant_weights.shape[1], 1), - fill_value=8, + (storage_weights.shape[0], storage_weights.shape[1], 1), + fill_value=8, # 0 in symmetric quantization, stored as 8 dtype=torch.uint8, - device=quant_weights.device, + device=storage_weights.device, ) - quant_weights = torch.cat([quant_weights, padding], dim=-1) - odd_indices = torch.arange(1, quant_weights.shape[-1], 2) + storage_weights = torch.cat([storage_weights, padding], dim=-1) + odd_indices = torch.arange(1, storage_weights.shape[-1], 2) - even_weights = quant_weights[..., even_indices] - odd_weights = quant_weights[..., odd_indices] + even_weights = 
storage_weights[..., even_indices] + odd_weights = storage_weights[..., odd_indices] - # Pack two 4-bit values into each byte + # Pack: low nibble = even, high nibble = odd packed_weights = (even_weights & 0xF) | ((odd_weights & 0xF) << 4) - # For dequantization, unpack + # Dequantize: scale * quantized_value (no zero-point subtraction) + # Unpack for dequantization lower = packed_weights & 0xF upper = (packed_weights >> 4) & 0xF - # Restore original shape, taking care to handle dimensions correctly + # Restore original shape and convert back to signed representation unpacked_weights = torch.zeros_like(weights, dtype=torch.uint8) - - # Assign values ensuring we don't go out of bounds unpacked_weights[..., even_indices] = lower - # Calculate valid odd indices that fit within our original tensor dimensions valid_odd_length = min(odd_indices.shape[0], weights.shape[-1] - even_indices.shape[0]) - valid_odd_indices = odd_indices[:valid_odd_length] - - # Only assign upper bits to valid positions if valid_odd_length > 0: + valid_odd_indices = odd_indices[:valid_odd_length] unpacked_weights[..., valid_odd_indices] = upper[..., :valid_odd_length] - # Convert back from uint4 to int4 by subtracting 8 - int4_weights = unpacked_weights.float() - 8 - - # Dequantize with proper broadcasting - # Make sure scale has the right shape for broadcasting - scale_expanded = scale.float() - if scale_expanded.dim() < int4_weights.dim(): - for _ in range(int4_weights.dim() - scale_expanded.dim()): - scale_expanded = scale_expanded.unsqueeze(-1) - result = (int4_weights * scale_expanded).to(dtype=weights.dtype) - return scale.to(torch.float16), packed_weights, result + # Convert back to signed values: [0,15] -> [-8,7] and apply scale + signed_weights = unpacked_weights.float() - 8.0 # Convert storage back to signed + dequant_scale = scale.float() # Ensure FP32 precision for computation + result = dequant_scale * signed_weights # No zero-point in symmetric quantization + + return scale.to(torch.float16), packed_weights, result.to(weights.dtype) else: - # 8-bit symmetric quantization, range is [-128, 127] - scale = abs_max / 127.0 # Scale factor ensures max value maps to 127 + # 8-bit: scale = abs_max / 127.0 (using 127.0 as max positive value for symmetric range) + # Use higher precision computation for better accuracy + scale = (abs_max.double() / 127.0).float() + 1e-12 # Handle potential edge cases for zero or very small weights - if torch.max(abs_max) < 1e-10: - # For extremely small values, avoid division by near-zero - # Just return zeros with appropriate scale to avoid numerical issues + if torch.max(abs_max) < 1e-8: return ( - torch.ones_like(weights[..., 0:1]) * 1e-6, # Very small non-zero scale - torch.full_like(weights, fill_value=128, dtype=torch.uint8), # 128 = 0 in symmetric + torch.ones_like(weights[..., 0:1]) * 1e-8, + torch.zeros_like(weights, dtype=torch.uint8), torch.zeros_like(weights), ) - # Convert to int8 range (-128 to 127) - scaled_weights = torch.round(weights / scale) - clipped_weights = torch.clamp(scaled_weights, -128, 127) + # Quantize: round(weight / scale) then clamp to [-128, 127] + # Use higher precision for the division to reduce accumulated errors + scaled_weights = weights.double() / scale.double() + quantized_weights = torch.round(scaled_weights).clamp(-128, 127).float() - # Convert from int8 signed range [-128,127] to uint8 storage range [0,255] - # by adding 128 to map -128->0, -127->1, ..., 127->255 - quant_weights = (clipped_weights + 128).to(torch.uint8) + # For symmetric 
quantization, we use signed int8 representation + # Convert to uint8 storage: shift [-128,127] -> [0,255] for storage only + storage_weights = (quantized_weights + 128).to(torch.uint8) - # Dequantize - convert back from uint8 to int8 by subtracting 128, then multiply by scale - # Make sure scale has the right shape for broadcasting - scale_expanded = scale.float() - if scale_expanded.dim() < quant_weights.dim(): - for _ in range(quant_weights.dim() - scale_expanded.dim()): - scale_expanded = scale_expanded.unsqueeze(-1) - result = ((quant_weights.float() - 128) * scale_expanded).to(dtype=weights.dtype) - return scale.to(torch.float16), quant_weights, result + # Dequantize: scale * quantized_value (no zero-point subtraction) + # Convert back to signed values: [0,255] -> [-128,127] and apply scale + signed_weights = storage_weights.float() - 128.0 # Convert storage back to signed + dequant_scale = scale.float() # Ensure FP32 precision for computation + result = dequant_scale * signed_weights # No zero-point in symmetric quantization + + return scale.to(torch.float16), storage_weights, result.to(weights.dtype) def create_cpu_moe_onnx_graph( @@ -254,38 +246,23 @@ def create_cpu_moe_onnx_graph( use_swiglu=False, use_quant=False, quant_bits=4, + swiglu_interleaved=False, ): - # Make sure we have onnx available before proceeding - if not HAS_ONNX: - print("ONNX not found, skipping graph creation") + if not has_onnx: return None - # Define intermediate_size variable consistently inter_size = intermediate_size topk = top_k - # Note: SwiGLU requires 2 components (gate and value) - # Force use_quant to True - we only want to test QMoE use_quant = True - # Note: In QMoE, biases are not used at all, only scales - # The following parameters are only relevant when use_quant=False (which is never the case here) - # fc1_bias and fc2_bias are completely ignored for QMoE - - # Ensure all variables are properly initialized for safety if fc1_scales is None and use_quant: - print("Warning: fc1_scales is None but quantization is enabled") return None if fc2_scales is None and use_quant: - print("Warning: fc2_scales is None but quantization is enabled") return None - if not HAS_ONNX: - print("ONNX not found, skipping graph creation") + if not has_onnx: return None - # Using uint8 storage type with symmetric quantization - # 4-bit: range = [-8, 7] (stored as uint8 values [0, 15]) - # 8-bit: range = [-128, 127] (stored as uint8 values [0, 255]) assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" assert fc1_scales is not None, "FC1 scales must be provided for QMoE" @@ -293,12 +270,9 @@ def create_cpu_moe_onnx_graph( assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" - # Make sure we have onnx available before proceeding - if not HAS_ONNX: - print("ONNX not found, skipping graph creation") + if not has_onnx: return None - # Always use QMoE, never MoE op_name = "QMoE" inputs = [ "input", @@ -311,10 +285,6 @@ def create_cpu_moe_onnx_graph( "", ] - # Note: In QMoE mode, biases are not used at all - # This code path is never executed since use_quant is always True - - # Use SwiGLU activation if specified, otherwise use SiLU activation = "swiglu" if use_swiglu else "silu" nodes = [ @@ -324,8 +294,13 @@ def create_cpu_moe_onnx_graph( ["output"], "MoE_0", k=topk, - normalize_routing_weights=0, + 
normalize_routing_weights=1, # Use proper routing normalization to match PyTorch behavior activation_type=activation, + # Add new attributes with backwards-compatible default values + swiglu_fusion=1 if (use_swiglu and swiglu_interleaved) else 0, # 1 = fused and interleaved + swiglu_limit=7.0, + activation_alpha=1.702, + activation_beta=1.0, domain="com.microsoft", ), ] @@ -333,15 +308,10 @@ def create_cpu_moe_onnx_graph( if use_quant: nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) - # For 4-bit quantization, we need to pack 2 values into each byte - pack_factor = 2 if quant_bits == 4 else 1 - - # For SwiGLU, we need to double the FC1 dimension to accommodate both gate and value paths - act_factor = 2 if use_swiglu else 1 - - # FC1 shape needs to account for both SwiGLU and quantization packing - fc1_shape = [num_experts, hidden_size, (act_factor * inter_size) // pack_factor] - fc2_shape = [num_experts, inter_size, hidden_size // pack_factor] + # Weights are store in column major order. Need pack 2 int4 values into uint8. + # Use the actual tensor shapes instead of calculating them to avoid size mismatches + fc1_shape = list(fc1_experts_weights.shape) + fc2_shape = list(fc2_experts_weights.shape) torch_dtype = onnx_to_torch_type_map[onnx_dtype] @@ -354,39 +324,90 @@ def create_cpu_moe_onnx_graph( weight_onnx_type, fc1_shape, fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), - raw=False, ), helper.make_tensor( "fc2_experts_weights", weight_onnx_type, fc2_shape, fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), - raw=False, ), ] - # QMoE always uses scales, never biases - # For SwiGLU, FC1 scales shape needs to be doubled to account for gate and value components fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] fc2_scale_shape = [num_experts, hidden_size] + + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) + fc2_scale_size = num_experts * hidden_size + + # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions + # Handle different possible scale tensor structures for fc1_scales + if len(fc1_scales.shape) == 4: + # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output + if use_swiglu: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() + else: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc1_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: + # For SwiGLU, duplicate the scales to cover both gate and value components + fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() + elif fc1_scale_tensor.size > fc1_scale_size: + # Truncate to expected size + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc1_scale_tensor.size > fc1_scale_size: + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + elif fc1_scale_tensor.size < fc1_scale_size: + # Pad with ones if too small + pad_size = fc1_scale_size - fc1_scale_tensor.size + fc1_scale_tensor = 
numpy.concatenate([fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)]) + + # Process scale tensor for proper shape + fc1_scale_data_list = fc1_scale_tensor.tolist() + fc1_scale_data = fc1_scale_data_list + + # Handle different possible scale tensor structures for fc2_scales + if len(fc2_scales.shape) == 4: + # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output + fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc2_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + # Truncate to expected size + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + elif fc2_scale_tensor.size < fc2_scale_size: + # Pad with ones if too small + pad_size = fc2_scale_size - fc2_scale_tensor.size + fc2_scale_tensor = numpy.concatenate([fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)]) + + # Process scale tensor for proper shape + fc2_scale_data_list = fc2_scale_tensor.tolist() + fc2_scale_data = fc2_scale_data_list + initializers.extend( [ helper.make_tensor( "fc1_scales", onnx_dtype, fc1_scale_shape, - fc1_scales.to(torch_dtype).flatten().tolist() - if fc1_scales is not None - else [1.0] * (num_experts * inter_size), + fc1_scale_data, raw=False, ), helper.make_tensor( "fc2_scales", onnx_dtype, fc2_scale_shape, - fc2_scales.to(torch_dtype).flatten().tolist() - if fc2_scales is not None - else [1.0] * (num_experts * hidden_size), + fc2_scale_data, raw=False, ), ] @@ -427,13 +448,6 @@ def __getitem__(self, key): return cls(**kwargs) -ACT2CLS = { - "silu": nn.SiLU, - "gelu": nn.GELU, -} -ACT2FN = ClassInstantier(ACT2CLS) - - class PhiMoEConfig: def __init__( self, @@ -452,7 +466,96 @@ def __init__( self.router_jitter_noise = router_jitter_noise +class SwigluMoeConfig: + def __init__( + self, + hidden_size=4096, + intermediate_size=14336, + num_local_experts=8, + num_experts_per_token=2, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_local_experts = num_local_experts + self.num_experts_per_token = num_experts_per_token + + +def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) + return y + + +class MoEBlockSparseTop2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class 
PhiMoEBlockSparseTop2MLP(MoEBlockSparseTop2MLP): + def __init__(self, config: PhiMoEConfig): + super().__init__(config) + + +class PhiMoESwiGLUMLP(nn.Module): + """ + Phi3 MoE expert converted to 2-weight SwiGLU structure for CPU compatibility. + This converts the traditional 3-weight Phi3 structure to SwiGLU format. + """ + + def __init__(self, config: PhiMoEConfig): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): + """ + Updated to match the CUDA implementation's routing logic for fair comparison. + This now uses the same complex jitter-based masking approach as the CUDA tests. + """ assert top_k == 2 assert not training @@ -467,8 +570,6 @@ def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): dim=-1, index=selected_experts[:, 0].unsqueeze(-1) ) - ################ second expert gating ################ - mask_logits_threshold_2 = mask_logits_threshold[:, 1].unsqueeze(-1) factor = scores.abs().clamp(min=mask_logits_threshold_2) @@ -489,60 +590,6 @@ def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): ) -class MoEBlockSparseTop2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.ffn_dim = config.intermediate_size - self.hidden_dim = config.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - -class PhiMoEBlockSparseTop2MLP(MoEBlockSparseTop2MLP): - def __init__(self, config: PhiMoEConfig, use_swiglu=False): - super().__init__(config) - self.use_swiglu = use_swiglu - - def forward(self, hidden_states): - if self.use_swiglu: - # SwiGLU implementation matching C++ implementation exactly - gate_output = self.w1(hidden_states) # Gate - value_output = self.w3(hidden_states) # Value - - # Apply SwiGLU exactly as in the C++ implementation - # C++ uses swiglu_alpha = 1.702f and clamp_limit = 7.0f - swiglu_alpha = 1.702 - clamp_limit = 7.0 - - # Apply clamping to match C++ implementation - gate_output = torch.clamp(gate_output, max=clamp_limit) # Clamp max only for gate - value_output = torch.clamp(value_output, min=-clamp_limit, max=clamp_limit) # Clamp both for value - - # Compute gate activation: gate * sigmoid(alpha * gate) - sigmoid_input = swiglu_alpha * gate_output - sigmoid_output = torch.sigmoid(sigmoid_input) - swish_output = gate_output * sigmoid_output - - # Multiply by (value + 1) as done in C++ - current_hidden_states = swish_output * 
(value_output + 1.0) - - # Apply FC2 - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - else: - # Original implementation with standard activation - return super().forward(hidden_states) - - class SparseMoeBlockORTHelper(nn.Module): def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() @@ -555,7 +602,6 @@ def __init__(self, quant_bits=0, onnx_dtype=None): def create_ort_session(self, moe_onnx_graph): if moe_onnx_graph is None: - print("No ONNX graph provided, skipping session creation") return None sess_options = onnxruntime.SessionOptions() @@ -563,9 +609,7 @@ def create_ort_session(self, moe_onnx_graph): try: ort_session = onnxruntime.InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) - except Exception as e: - print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") - print("Skipping ONNX Runtime execution for this test case.") + except Exception: return None return ort_session @@ -574,20 +618,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: pass def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: - # If session creation failed, we can't run inference if self.ort_sess is None: - print("No ORT session available, skipping ONNX Runtime execution") return None batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states_flat) - # Determine the correct torch dtype from the onnx_dtype torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] - # Prepare tensors on the correct device for ORT inference with the CORRECT dtype tensors = { "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), @@ -595,11 +634,9 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False } try: - # Bind inputs and outputs to torch tensors directly. iobinding = self.ort_sess.io_binding() for name, tensor in tensors.items(): - # Ensure tensor is on the globally defined device if name == "output": iobinding.bind_output( name=name, @@ -624,200 +661,63 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False iobinding.synchronize_outputs() if enable_performance_test: - import time # noqa: PLC0415 - - repeat = 100 # Using fewer repeats for CPU tests + repeat = 100 s = time.time() for _ in range(repeat): iobinding.synchronize_inputs() self.ort_sess.run_with_iobinding(iobinding) iobinding.synchronize_outputs() e = time.time() - print(f"QMoE CPU kernel time: {(e - s) / repeat * 1000} ms") + time_ms = (e - s) / repeat * 1000 + is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu + is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" + print(f"ORT Performance - {act_type} {self.quant_bits}-bit: {time_ms:.3f} ms/inference") - # The output tensor is on `device`. Reshape and return it. 
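Aside (not part of the patch): the performance path above drives the session through `io_binding`; a minimal, self-contained way to bind CPU buffers and average latency over repeats is sketched below. The model path and tensor names are placeholders, not the QMoE graph built in this file.

```python
import time
import numpy as np
import onnxruntime as ort

# Placeholder model and tensor names, for illustration only.
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
x = np.random.randn(1, 32, 128).astype(np.float32)

binding = sess.io_binding()
binding.bind_cpu_input("input", x)     # bind a CPU numpy buffer without copying
binding.bind_output("output")          # let ORT allocate the output

repeat = 100
start = time.time()
for _ in range(repeat):
    sess.run_with_iobinding(binding)
avg_ms = (time.time() - start) / repeat * 1000
y = binding.copy_outputs_to_cpu()[0]
print(f"average latency: {avg_ms:.3f} ms, output shape: {y.shape}")
```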
return tensors["output"].reshape(batch_size, sequence_length, hidden_dim) except Exception as e: - print(f"Error running ORT session: {e!s}") raise - def parity_check(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) - torch_output = self.forward(hidden_state) - ort_output = self.ort_forward(hidden_state) - - # If no ORT output was produced, we can't do a parity check - if ort_output is None: - print("ORT execution failed or is not supported, skipping parity check") - return - - dtype_str = ort_dtype_name_map[self.onnx_dtype] - max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max() - non_finite = torch.isnan(max_diff) or torch.isinf(max_diff) - - print( - f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," - f" batch: {self.batch_size}, seq_len: {self.sequence_length}," - f" max_diff: {max_diff}" - ) - - # Report if NaN or Inf values are detected - if non_finite: - print( - "Warning: NaN or Inf values detected in the output difference. Numerical comparisons will be limited." - ) - - # Maps "ort_type:quant_bits" to (atol, rtol) - # Note: Now that both CPU and CUDA use symmetric quantization, - # we can use more consistent tolerances across implementations. - ort_dtype_quant_bits_tolerance_map = { - "FP32:0": (5e-3, 1e-3), - "FP16:0": (5e-2, 1e-3), - "FP16:4": (2.0, 8e-3), # Improved tolerance with symmetric quantization - "FP16:8": (1.5, 8e-3), # Improved tolerance with symmetric quantization - } - - tolerance_key = f"{dtype_str}:{self.quant_bits}" - if tolerance_key not in ort_dtype_quant_bits_tolerance_map: - print(f"Warning: No tolerance defined for {tolerance_key}, using default") - atol, rtol = 10.0, 1e-1 - else: - atol, rtol = ort_dtype_quant_bits_tolerance_map[tolerance_key] - - # Report stats but don't assert (just for information) - # Handle NaN/Inf values more gracefully - try: - diff = (torch_output.cpu() - ort_output.cpu()).abs() - mean_diff = diff.mean().item() if not torch.isnan(diff.mean()) else float("nan") - median_diff = diff.median().item() if not torch.isnan(diff.median()) else float("nan") - p95_diff = ( - torch.quantile(diff, 0.95).item() if not torch.isnan(torch.quantile(diff, 0.95)) else float("nan") - ) - - print(f"Stats - Mean diff: {mean_diff}, Median diff: {median_diff}, 95th percentile: {p95_diff}") - - # Check if results are within tolerance - max_diff_val = max_diff.item() - if not non_finite and max_diff_val > atol: - print(f"Warning: Maximum difference ({max_diff_val:.6f}) exceeds absolute tolerance ({atol:.6f})") - elif not non_finite: - print(f"Success: All values within absolute tolerance ({atol:.6f})") - - # For quantized models, the relative difference can be very large for small values - # This is because quantization has a greater effect on small values than large ones - # Add a larger epsilon to prevent misleading large relative differences for near-zero values - # Safely compute relative differences - if not non_finite: - relative_diff = diff / torch.max(torch_output.cpu().abs(), torch.tensor(1e-3)) - max_rel_diff = relative_diff.max().item() - rel_exceeds = (relative_diff > rtol).float().mean().item() * 100 - - if max_rel_diff > rtol: - print( - f"Warning: Maximum relative difference ({max_rel_diff:.6f}) exceeds relative tolerance ({rtol:.6f})" - ) - print(f"Percentage of values exceeding relative tolerance: {rel_exceeds:.2f}%") - else: - print(f"Success: All relative differences within relative tolerance ({rtol:.6f})") - except Exception as e: - # If 
any calculation fails, just log it but don't crash the test - print(f"Warning: Error calculating statistics: {e}") - - # Note: Higher relative differences are expected in quantized models - # This is because quantization inherently introduces error, especially for small values - # The key metric is the absolute difference, which we've significantly improved - - def benchmark_ort(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) - self.ort_forward(hidden_state, enable_performance_test=True) - - -class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): - """ - This implementation is - strictly equivalent to standard MoE with full capacity (no - dropped tokens). It's faster since it formulates MoE operations - in terms of block-sparse operations to accommodate imbalanced - assignments of tokens to experts, whereas standard MoE either - (1) drop tokens at the cost of reduced performance or (2) set - capacity factor to number of experts and thus waste computation - and memory on padding. - - CPU version: Modified to use only FC1 and FC2 for CPU compatibility. - - Quantization: Uses symmetric quantization to exactly match the C++ implementation: - - 4-bit: range = [-8, 7] (stored as uint8 values [0, 15]) - - 8-bit: range = [-128, 127] (stored as uint8 values [0, 255]) - This ensures the test exactly simulates the C++ implementation with full - compatibility with the CUDA implementation and TensorRT. - """ - - def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype=None, use_swiglu=False): - # Ensure we always have a valid quantization bits value (4 or 8) before passing to parent - if quant_bits <= 0: - print("Warning: quant_bits was set to 0 or negative, forcing to 4-bit") - quant_bits = 4 - - # Now pass the validated quant_bits to parent constructor - super().__init__(quant_bits, onnx_dtype) - self.hidden_dim = config.hidden_size - self.ffn_dim = config.intermediate_size - self.num_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - self.router_jitter_noise = config.router_jitter_noise - self.use_swiglu = use_swiglu - - # gating - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - - # Use PhiMoEBlockSparseTop2MLP for all experts - self.experts = nn.ModuleList( - [PhiMoEBlockSparseTop2MLP(config, use_swiglu=self.use_swiglu) for _ in range(self.num_experts)] - ) + def recreate_onnx_model(self): + """Recreate the ONNX model with the current weights to reflect any changes to the quantization code.""" w1_list, w2_list = [], [] w1_scale_list, w2_scale_list = [], [] - # Always use quantization for QMoE is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - # Quantize the weights w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) - # For SwiGLU, we also need to quantize w3 (value) weights - w3_qdq = None # Initialize w3_qdq to avoid unbound variable error if self.use_swiglu: - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) - # Combine gate (w1) and value (w3) for SwiGLU - if is_4_bit: - # For 4-bit, we need to combine the packed weights in the right format - # Double the intermediate size for SwiGLU (gate + value) - # Each byte contains two 4-bit values + if self.swiglu_interleaved: + pass + else: + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) + gate_weights = pre_qweight1 
value_weights = pre_qweight3 - # Create a new tensor with double the last dimension - combined_shape = list(gate_weights.shape) - combined_shape[-1] *= 2 # Double the last dimension for gate+value - combined_weights = torch.zeros(combined_shape, dtype=torch.uint8, device=gate_weights.device) - combined_weights[..., : gate_weights.shape[-1]] = gate_weights - combined_weights[..., gate_weights.shape[-1] :] = value_weights - pre_qweight1 = combined_weights - else: - # For 8-bit, we can just concatenate along the last dimension - pre_qweight1 = torch.cat([pre_qweight1, pre_qweight3], dim=-1) + gate_scales = w1_scale + value_scales = w3_scale + + pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) + w1_scale = torch.cat([gate_scales, value_scales], dim=0) + + if self.swiglu_interleaved: + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) - # Same for scales - combine gate and value scales - w1_scale = torch.cat([w1_scale, w3_scale], dim=-1) + else: + intermediate_size = self.experts[i].w1.weight.shape[0] + gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() + value_dequant = w1_qdq[intermediate_size:].contiguous().clone() + self.experts[i].w1.weight.data = gate_dequant + self.experts[i].w3.weight.data = value_dequant + else: + self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() - # Update the expert weights with dequantized values for PyTorch execution - self.experts[i].w1.weight.data = w1_qdq - self.experts[i].w2.weight.data = w2_qdq - if self.use_swiglu and w3_qdq is not None: - self.experts[i].w3.weight.data = w3_qdq + self.experts[i].w2.weight.data = w2_qdq.contiguous().clone() - # Store the quantized weights and scales for ONNX model w1_list.append(pre_qweight1) w2_list.append(pre_qweight2) w1_scale_list.append(w1_scale) @@ -826,14 +726,14 @@ def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype self.moe_experts_weight1 = torch.stack(w1_list, dim=0) self.moe_experts_weight2 = torch.stack(w2_list, dim=0) - # Always use scales for QMoE moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) - self.batch_size = batch_size - self.sequence_length = sequence_length + if moe_experts_weight_scale1.dim() == 3: + moe_experts_weight_scale1 = moe_experts_weight_scale1.squeeze(-1) + if moe_experts_weight_scale2.dim() == 3: + moe_experts_weight_scale2 = moe_experts_weight_scale2.squeeze(-1) - # Use CPU specific graph creation try: self.moe_onnx_graph = create_cpu_moe_onnx_graph( hidden_size=self.hidden_dim, @@ -841,49 +741,277 @@ def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype num_experts=self.num_experts, top_k=self.top_k, intermediate_size=self.ffn_dim, - torch_dtype=torch.float32, # Assuming float32 as default + torch_dtype=torch.float32, onnx_dtype=self.onnx_dtype, fc1_experts_weights=self.moe_experts_weight1, fc2_experts_weights=self.moe_experts_weight2, - # Biases are not used in QMoE, only passed as None for API compatibility + # Biases are not used in QMoE fc1_bias=None, fc2_bias=None, # Scales are used for dequantization fc1_scales=moe_experts_weight_scale1, fc2_scales=moe_experts_weight_scale2, - use_swiglu=self.use_swiglu, # Use SwiGLU if specified + use_swiglu=self.use_swiglu, use_quant=True, # Always use QMoE quant_bits=self.quant_bits, + swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, ) - except Exception as e: - print(f"Error creating ONNX graph: {e}") + except 
Exception: self.moe_onnx_graph = None + return False self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None + return self.ort_sess is not None + + def parity_check(self): + model_updated = self.recreate_onnx_model() + if not model_updated: + return + + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + torch_output = self.forward(hidden_state) + ort_output = self.ort_forward(hidden_state) + + if ort_output is None: + return + + torch_has_nan = torch.isnan(torch_output).any() + ort_has_nan = torch.isnan(ort_output).any() + torch_has_inf = torch.isinf(torch_output).any() + ort_has_inf = torch.isinf(ort_output).any() + + if torch_has_nan or ort_has_nan or torch_has_inf or ort_has_inf: + torch_output_clean = torch.where( + torch.isnan(torch_output) | torch.isinf(torch_output), torch.zeros_like(torch_output), torch_output + ) + ort_output_clean = torch.where( + torch.isnan(ort_output) | torch.isinf(ort_output), torch.zeros_like(ort_output), ort_output + ) + max_diff = (torch_output_clean.cpu() - ort_output_clean.cpu()).abs().max() + + if (torch_has_nan and ort_has_nan) or (torch_has_inf and ort_has_inf): + problematic_torch = torch.isnan(torch_output) | torch.isinf(torch_output) + problematic_ort = torch.isnan(ort_output) | torch.isinf(ort_output) + if torch.equal(problematic_torch, problematic_ort): + max_diff = 0.0 + else: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max() + + is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu + is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" + + print(f"Parity check - {act_type} {self.quant_bits}-bit: max_diff = {max_diff:.6f}") + + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + "FP16:0": (5e-2, 1e-3), + "FP16:4": (0.05, 0.01), + "FP16:8": (0.02, 0.01), + "FP32:4": (0.11, 0.01), + "FP32:8": (0.11, 0.01), + } + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + tolerance_key = f"{dtype_str}:{self.quant_bits}" + if tolerance_key in ort_dtype_quant_bits_tolerance_map: + base_atol, rtol = ort_dtype_quant_bits_tolerance_map[tolerance_key] + + if max_diff > base_atol: + raise AssertionError( + f"QMoE parity check failed: max difference {max_diff:.6f} exceeds " + f"tolerance {base_atol:.6f} for {tolerance_key}" + ) + else: + fallback_atol = 0.1 + if max_diff > fallback_atol: + raise AssertionError( + f"QMoE parity check failed: max difference {max_diff:.6f} exceeds " + f"fallback tolerance {fallback_atol:.6f} for unknown config {tolerance_key}" + ) + + def benchmark_ort(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) + + +def small_test_cases(): + for batch_size in [1, 4]: + for sequence_length in [32, 128]: + yield batch_size, sequence_length + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + self.use_swiglu = True + self.swiglu_interleaved = True + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + 
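Aside (not part of the patch): the `swiglu_interleaved=True` flag set just above matters because the `swiglu` helper earlier in this file reads gate/linear pairs interleaved along the last axis, while the non-interleaved path concatenates the gate block and the value block. A small illustrative contrast of the two layouts:

```python
import torch

inter = 4
gate = torch.arange(inter, dtype=torch.float32)          # g0..g3
value = 10 + torch.arange(inter, dtype=torch.float32)    # v0..v3

# Interleaved layout: [g0, v0, g1, v1, ...] -- what swiglu(x) expects, since it
# does x.view(-1, dim // 2, 2) and reads x[..., 0] (gate) and x[..., 1] (linear).
interleaved = torch.stack([gate, value], dim=-1).reshape(-1)

# Concatenated layout: [g0..g3, v0..v3] -- the non-interleaved weight packing,
# which needs a reorder (or separate handling) before the SwiGLU activation.
concatenated = torch.cat([gate, value], dim=-1)

print(interleaved)   # tensor([ 0., 10.,  1., 11.,  2., 12.,  3., 13.])
print(concatenated)  # tensor([ 0.,  1.,  2.,  3., 10., 11., 12., 13.])
```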
self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + self.batch_size = batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = None + self.ort_sess = None def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) + routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) + + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: PhiMoEConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.router_jitter_noise = config.router_jitter_noise + self.use_swiglu = True + self.swiglu_interleaved = True + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([PhiMoESwiGLUMLP(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + 
scale_1_list.append(scale1) + scale_2_list.append(scale2) - routing_weights, selected_experts = masked_sampling_omp_inference( - router_logits, + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = create_cpu_moe_onnx_graph( + hidden_size=self.hidden_dim, + sequence_length=self.batch_size * self.sequence_length, + num_experts=self.num_experts, top_k=self.top_k, - jitter_eps=self.router_jitter_noise, - training=False, + intermediate_size=self.ffn_dim, + torch_dtype=torch.float32, + onnx_dtype=self.onnx_dtype, + fc1_experts_weights=fc1_experts_weights, + fc2_experts_weights=fc2_experts_weights, + fc1_bias=fc1_experts_bias, + fc2_bias=fc2_experts_bias, + fc1_scales=moe_experts_weight_scale1, + fc2_scales=moe_experts_weight_scale2, + use_swiglu=self.use_swiglu, + use_quant=use_quant, + quant_bits=self.quant_bits, + swiglu_interleaved=self.swiglu_interleaved, ) + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """PyTorch reference forward pass using SwiGLU-style routing""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) @@ -891,106 +1019,99 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if top_x.shape[0] == 0: continue - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. 
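Aside (not part of the patch): both reference forward passes in this file dispatch tokens to experts with the same pattern of a one-hot expert mask, `torch.where`, and `index_add_`. A self-contained sketch of that pattern, using top-k-then-softmax routing as in the header notes and a dummy identity expert so the result can be checked exactly:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, top_k, hidden = 4, 2, 8
tokens = torch.randn(6, hidden)              # (num_tokens, hidden)
logits = torch.randn(6, num_experts)         # router logits

# Top-k first, then softmax over the selected logits.
topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
weights = F.softmax(topk_vals, dim=-1)       # (num_tokens, top_k), rows sum to 1

out = torch.zeros_like(tokens)
# expert_mask[e, k, t] == 1 iff expert e is the k-th choice of token t.
expert_mask = F.one_hot(topk_idx, num_classes=num_experts).permute(2, 1, 0)

for e in range(num_experts):
    k_slot, tok = torch.where(expert_mask[e])
    if tok.numel() == 0:
        continue
    expert_out = tokens[tok]                 # identity "expert", for illustration
    out.index_add_(0, tok, expert_out * weights[tok, k_slot, None])

# With an identity expert, each token is rebuilt as the weighted sum of its
# top-k copies, and the weights sum to 1, so the output equals the input.
assert torch.allclose(out, tokens, atol=1e-6)
```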
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states -def small_test_cases(): - for batch_size in [1, 4]: - for sequence_length in [32, 128]: - yield batch_size, sequence_length +disable_cpu_qmoe_tests = False +# Define test cases for different MoE types +phi3_test_cases = [ + (1, 32, 4), + (1, 32, 8), + (2, 16, 4), + (2, 16, 8), +] -# Define our test cases for QMoE (4-bit and 8-bit quantization) -# Only test QMoE since standard MoE is not supported on CPU -cpu_phi3_test_cases = list( - itertools.product( - [1, 4], # batch_size - [8, 32], # sequence_length - smaller sequence lengths for CPU - [4, 8], # quant_bits - only test QMoE (4-bit and 8-bit) - [False], # use_swiglu - standard SiLU cases - ) -) - -# Additional test cases for SwiGLU activation -cpu_phi3_swiglu_test_cases = list( - itertools.product( - [1, 4], # batch_size - [8, 32], # sequence_length - smaller sequence lengths for CPU - [4, 8], # quant_bits - only test QMoE (4-bit and 8-bit) - [True], # use_swiglu - SwiGLU activation - ) -) -# Temporarily disable CPU qMoE tests. A fix will come soon. -disable_cpu_qmoe_tests = True +@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") +class TestPhiQMoECPU(unittest.TestCase): + @parameterized.expand(phi3_test_cases) + def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + print(f"Running Phi3 QMoE test: {test_config}") + + config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) + + phi3_moe = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = phi3_moe.forward(hidden_states) + + # Verify output shape and basic properties + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + phi3_moe.parity_check() + + +disable_cpu_qmoe_tests = False + +swiglu_test_cases = [ + (1, 32, 4), + (1, 32, 8), + (2, 16, 4), + (2, 16, 8), +] @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") -class TestPhiQMoECPU(unittest.TestCase): - @parameterized.expand(cpu_phi3_test_cases + cpu_phi3_swiglu_test_cases) - def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, use_swiglu=False): - activation_type = "SwiGLU" if use_swiglu else "SiLU" - print( - f"Running PhiMoE CPU test with batch_size={batch_size}, sequence_length={sequence_length}, " - f"quant_bits={quant_bits}, activation={activation_type}" +class TestSwigluQMoECPU(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + print(f"Running SwiGLU test: {test_config}") + + config = SwigluMoeConfig(hidden_size=128, 
intermediate_size=256, num_local_experts=4, num_experts_per_token=2) + + swiglu_moe = SwigluMoEBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, ) - config = PhiMoEConfig(hidden_size=256, intermediate_size=512, hidden_act="silu") # Smaller sizes for CPU tests - phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits, use_swiglu=use_swiglu) - phi3_moe.to(device) - # Skip tests if ONNX is not available - if not HAS_ONNX: - self.skipTest("ONNX is not installed") + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) - # Skip if the session creation failed - if phi3_moe.ort_sess is None: - self.skipTest("Failed to create ONNX Runtime session - CPU MoE operator not available") + torch_result = swiglu_moe.forward(hidden_states) - try: - phi3_moe.parity_check() - except RuntimeError as e: - if "FC3 gating is not yet implemented on CPU" in str(e): - self.skipTest("FC3 gating is not yet implemented on CPU") - else: - raise - - @parameterized.expand([(8, False), (4, False), (8, True), (4, True)]) - def test_phi3_qmoe_cpu_benchmark(self, quant_bits, use_swiglu=False): - activation_type = "SwiGLU" if use_swiglu else "SiLU" - print(f"Benchmarking PhiMoE CPU with quant_bits={quant_bits}, activation={activation_type}") - batch_size = 1 - sequence_length = 32 - config = PhiMoEConfig(hidden_size=256, intermediate_size=512) - phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits, use_swiglu=use_swiglu) - phi3_moe.to(device) - - # Skip tests if ONNX is not available or session creation failed - if not HAS_ONNX or phi3_moe.ort_sess is None: - self.skipTest("ONNX not installed or CPU MoE operator not available") - return + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) - try: - phi3_moe.benchmark_ort() - except RuntimeError as e: - if "FC3 gating is not yet implemented on CPU" in str(e): - self.skipTest("FC3 gating is not yet implemented on CPU") - else: - raise + swiglu_moe.parity_check() if __name__ == "__main__":