diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml
index fc9bb53659442..1f57b4c6d2ba2 100644
--- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml
+++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml
@@ -15,6 +15,14 @@ on:
required: false
type: boolean
default: false
+ use_vcpkg:
+ required: false
+ type: boolean
+ default: true
+ enable_wasm_threads:
+ required: false
+ type: boolean
+ default: true
build_jsep:
required: false
type: boolean
@@ -29,7 +37,7 @@ jobs:
runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"]
env:
buildArch: x64
- common_build_args: --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --config ${{ inputs.build_config }} --skip_submodule_sync --build_wasm --enable_wasm_simd --enable_wasm_threads ${{ inputs.extra_build_args }}
+ common_build_args: --parallel ${{ inputs.use_vcpkg == true && '--use_vcpkg --use_vcpkg_ms_internal_asset_cache' || '' }} --config ${{ inputs.build_config }} --skip_submodule_sync --build_wasm --enable_wasm_simd ${{ inputs.enable_wasm_threads == true && '--enable_wasm_threads' || '' }} ${{ inputs.extra_build_args }}
steps:
- name: Checkout code
diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml
index 8f922ef26cd7e..0133e4994e5e9 100644
--- a/.github/workflows/web.yml
+++ b/.github/workflows/web.yml
@@ -52,6 +52,16 @@ jobs:
build_jsep: true
build_webgpu: true
+ wasm_Release_static_library:
+ needs: precheck
+ uses: ./.github/workflows/linux-wasm-ci-build-and-test-workflow.yml
+ with:
+ build_config: Release
+ extra_build_args: "--skip_tests --enable_wasm_api_exception_catching --disable_rtti --build_wasm_static_lib"
+ use_vcpkg: false
+ enable_wasm_threads: false
+ skip_publish: true
+
web_Debug:
needs:
- precheck
diff --git a/README.md b/README.md
index f1817282b61a0..019bc8291354e 100644
--- a/README.md
+++ b/README.md
@@ -20,26 +20,6 @@
- ONNX Runtime Inferencing: [microsoft/onnxruntime-inference-examples](https://github.com/microsoft/onnxruntime-inference-examples)
- ONNX Runtime Training: [microsoft/onnxruntime-training-examples](https://github.com/microsoft/onnxruntime-training-examples)
-## Builtin Pipeline Status
-
-|System|Inference|Training|
-|---|---|---|
-|Windows|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=218)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=228)||
-|Linux|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)|
-|Mac|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)||
-|Android|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)||
-|iOS|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)||
-|Web|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=161)||
-|Other|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)||
-
-This project is tested with [BrowserStack](https://www.browserstack.com/home).
-
-## Third-party Pipeline Status
-
-|System|Inference|Training|
-|---|---|---|
-|Linux|[](https://github.com/Ascend/onnxruntime/actions/workflows/build-and-test.yaml)||
-
## Releases
The current release and past releases can be found here: https://github.com/microsoft/onnxruntime/releases.
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 416ed5e49f25a..47bfa3f312eec 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -372,6 +372,7 @@ if (onnxruntime_USE_ROCM)
if (HIPIFY_PERL_PATH-NOTFOUND)
MESSAGE(FATAL_ERROR "hipify-perl not found")
endif()
+ MESSAGE("HIPIFY PATH:"${HIPIFY_PERL_PATH}/hipify-perl)
set(onnxruntime_HIPIFY_PERL ${HIPIFY_PERL_PATH}/hipify-perl)
endif()
diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake
index 8f5ef15c53ef2..6647312e99d8f 100644
--- a/cmake/adjust_global_compile_flags.cmake
+++ b/cmake/adjust_global_compile_flags.cmake
@@ -4,6 +4,8 @@ if (ANDROID)
# Build shared libraries with support for 16 KB ELF alignment
# https://source.android.com/docs/core/architecture/16kb-page-size/16kb#build-lib-16kb-alignment
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,max-page-size=16384")
+ # Also apply to MODULE libraries (like libonnxruntime4j_jni.so)
+ set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -Wl,-z,max-page-size=16384")
endif()
# Enable space optimization for gcc/clang
diff --git a/cmake/deps.txt b/cmake/deps.txt
index 6e045f6dcdc9d..2df433b0353c6 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -9,9 +9,6 @@
#since the file contains a version string: "lts_20230802". However, the file is for debugging purposes only and would
#not affect built binaries.
#
-# NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI.
-# See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29
-#
abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240722.0.zip;36ee53eb1466fb6e593fc5c286680de31f8a494a
coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a
cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0
@@ -29,7 +26,7 @@ flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip
fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494
fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177
-googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349
+googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f638fa0e724760e2ba07ff8cfba32cd644e1ce28
#xnnpack 2024.09.04
googlexnnpack;https://github.com/google/XNNPACK/archive/fe98e0b93565382648129271381c14d6205255e3.zip;14f61dcf17cec2cde34ba2dcf61d6f24bf6059f3
json;https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip;5e88795165cc8590138d1f47ce94ee567b85b4d6
@@ -37,7 +34,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
-onnx;https://github.com/onnx/onnx/archive/7fc2b81a275223f5b02a522d9d2649837542a7be.zip;555338a12903941bb45f57540476244f9ffee17b
+onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.18.0.zip;f156d032a3af91b66d554e11158b33ca77bbb1f2
# Use the latest commit of 10.9-GA
onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/d5dce67db7c2e64b07e055571f5ec06f7f254de2.zip;01114d3b67650857281fa50faa2e412130a63b69
protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa
@@ -58,6 +55,6 @@ cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.9.2.zip;b7f8dc4a8
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0c12f53da76d0c31b03b9f0f8ec8f3b4.zip;239063aee4946a9af147b473a4c3da78ba7413b4
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
-cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1
+cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
dawn;https://github.com/google/dawn/archive/4cb1f9be152a4fa6bb695c08cd707ab078a1e2fb.zip;de39336b7715f53c14eec61072293b85cc73b691
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.4.0.tar.gz;22d3b57b54a61c194ab256ff11b0353a3b220244
diff --git a/cmake/external/cudnn_frontend.cmake b/cmake/external/cudnn_frontend.cmake
index 8642607fa0ca0..d89ab0f669f35 100644
--- a/cmake/external/cudnn_frontend.cmake
+++ b/cmake/external/cudnn_frontend.cmake
@@ -6,8 +6,10 @@ onnxruntime_fetchcontent_declare(
EXCLUDE_FROM_ALL
)
+set(CUDNN_FRONTEND_SKIP_JSON_LIB OFF CACHE BOOL "" FORCE)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF CACHE BOOL "" FORCE)
-set(CUDNN_FRONTEND_BUILD_UNIT_TESTS OFF CACHE BOOL "" FORCE)
+set(CUDNN_FRONTEND_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF CACHE BOOL "" FORCE)
set(CUDNN_PATH ${onnxruntime_CUDNN_HOME})
+
onnxruntime_fetchcontent_makeavailable(cudnn_frontend)
diff --git a/cmake/external/onnx b/cmake/external/onnx
index 7fc2b81a27522..e709452ef2bbc 160000
--- a/cmake/external/onnx
+++ b/cmake/external/onnx
@@ -1 +1 @@
-Subproject commit 7fc2b81a275223f5b02a522d9d2649837542a7be
+Subproject commit e709452ef2bbc1d113faf678c24e6d3467696e83
diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake
index e629df4843109..1e26eede8a66f 100644
--- a/cmake/onnxruntime_common.cmake
+++ b/cmake/onnxruntime_common.cmake
@@ -169,10 +169,6 @@ if(APPLE)
target_link_libraries(onnxruntime_common PRIVATE "-framework Foundation")
endif()
-if(MSVC)
- target_link_libraries(onnxruntime_common PRIVATE dxcore.lib)
-endif()
-
if(MSVC)
if(onnxruntime_target_platform STREQUAL "ARM64")
set(ARM64 TRUE)
diff --git a/cmake/onnxruntime_kernel_explorer.cmake b/cmake/onnxruntime_kernel_explorer.cmake
index 62a6d45088052..65a20c4229290 100644
--- a/cmake/onnxruntime_kernel_explorer.cmake
+++ b/cmake/onnxruntime_kernel_explorer.cmake
@@ -64,7 +64,7 @@ elseif (onnxruntime_USE_ROCM)
)
auto_set_source_files_hip_language(${kernel_explorer_kernel_srcs} ${kernel_explorer_rocm_kernel_srcs})
target_sources(kernel_explorer PRIVATE ${kernel_explorer_rocm_kernel_srcs})
- target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1 HIPBLAS_V2)
+ target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1 HIPBLAS)
if (onnxruntime_USE_COMPOSABLE_KERNEL)
target_compile_definitions(kernel_explorer PRIVATE USE_COMPOSABLE_KERNEL)
if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE)
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 3279a17f8cd5e..f8f5546ae9465 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -281,6 +281,9 @@ function(setup_kleidiai)
${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
)
target_link_libraries(onnxruntime_mlas PRIVATE kleidiai)
+
+ list(APPEND onnxruntime_EXTERNAL_LIBRARIES kleidiai)
+ set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES} PARENT_SCOPE)
endfunction()
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
diff --git a/cmake/onnxruntime_providers_dml.cmake b/cmake/onnxruntime_providers_dml.cmake
index 62136c5c568d7..ac4f0103ea323 100644
--- a/cmake/onnxruntime_providers_dml.cmake
+++ b/cmake/onnxruntime_providers_dml.cmake
@@ -59,7 +59,7 @@
if (GDK_PLATFORM STREQUAL Scarlett)
target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs})
else()
- target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib)
+ target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib)
endif()
target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib)
diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake
index 108b8b46deb27..03f1e288f4d0d 100644
--- a/cmake/onnxruntime_providers_rocm.cmake
+++ b/cmake/onnxruntime_providers_rocm.cmake
@@ -154,7 +154,7 @@
set_target_properties(onnxruntime_providers_rocm PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_rocm PROPERTIES FOLDER "ONNXRuntime")
- target_compile_definitions(onnxruntime_providers_rocm PRIVATE HIPBLAS_V2)
+ target_compile_definitions(onnxruntime_providers_rocm PRIVATE HIPBLAS)
if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_providers_rocm PRIVATE ${ORTTRAINING_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining ${MPI_CXX_INCLUDE_DIRS})
diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake
index f6eac2c24eca2..5639b295f0787 100644
--- a/cmake/onnxruntime_python.cmake
+++ b/cmake/onnxruntime_python.cmake
@@ -189,7 +189,10 @@ set(onnxruntime_pybind11_state_static_providers
if(onnxruntime_BUILD_QNN_EP_STATIC_LIB)
list(APPEND onnxruntime_pybind11_state_static_providers PRIVATE onnxruntime_providers_qnn)
endif()
-
+if(WIN32)
+ # onnxruntime_pybind11_state is a DLL
+ target_sources(onnxruntime_pybind11_state PRIVATE "${ONNXRUNTIME_ROOT}/core/dll/dllmain.cc")
+endif()
target_link_libraries(onnxruntime_pybind11_state PRIVATE
onnxruntime_session
${onnxruntime_libs}
@@ -1064,12 +1067,6 @@ if (onnxruntime_USE_QNN)
${QNN_LIB_FILES}
$/onnxruntime/capi/
)
- add_custom_command(
- TARGET onnxruntime_pybind11_state POST_BUILD
- COMMAND ${CMAKE_COMMAND} -E copy
- $
- $/onnxruntime/capi/
- )
if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf")
add_custom_command(
TARGET onnxruntime_pybind11_state POST_BUILD
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index c8de91d6c6eb6..26ef7970fa2b6 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1334,6 +1334,14 @@ endif()
# shared lib
if (onnxruntime_BUILD_SHARED_LIB)
+ if(WIN32)
+ AddTest(DYN
+ TARGET onnxruntime_shared_lib_dlopen_test
+ SOURCES ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/dlopen_main.cc
+ LIBS onnxruntime
+ DEPENDS ${all_dependencies}
+ )
+ endif()
onnxruntime_add_static_library(onnxruntime_mocked_allocator ${TEST_SRC_DIR}/util/test_allocator.cc)
target_include_directories(onnxruntime_mocked_allocator PUBLIC ${TEST_SRC_DIR}/util/include)
target_link_libraries(onnxruntime_mocked_allocator PRIVATE ${GSL_TARGET})
diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake
index f00292fade52d..c0b6efb0eb75d 100644
--- a/cmake/onnxruntime_webassembly.cmake
+++ b/cmake/onnxruntime_webassembly.cmake
@@ -503,58 +503,60 @@ jsepDownload:_pp_")
set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME ${target_name} SUFFIX ".mjs")
- #
- # The following POST_BUILD script is a workaround for enabling:
- # - using onnxruntime-web with Multi-threading enabled when import from CDN
- # - using onnxruntime-web when consumed in some frameworks like Vite
- #
- # In the use case mentioned above, the file name of the script may be changed. So we need to replace the line:
- # `new Worker(new URL("ort-wasm-*.mjs", import.meta.url),`
- # with
- # `new Worker(new URL(import.meta.url),`
- #
- # This behavior is introduced in https://github.com/emscripten-core/emscripten/pull/22165. Since it's unlikely to be
- # reverted, and there is no config to disable this behavior, we have to use a post-build script to workaround it.
- #
-
- # Generate a script to do the post-build work
- file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/wasm_post_build.js "
- const fs = require('fs');
- const path = require('path');
-
- // node wasm_post_build.js
- const mjsFilePath = process.argv[2];
- let contents = fs.readFileSync(mjsFilePath).toString();
-
- const regex = 'new Worker\\\\(new URL\\\\(\".+?\", ?import\\\\.meta\\\\.url\\\\),';
- const matches = [...contents.matchAll(new RegExp(regex, 'g'))];
- if (matches.length !== 1) {
- throw new Error(
- `Unexpected number of matches for \"${regex}\" in \"${filepath}\": ${matches.length}.`,
+ if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
+ #
+ # The following POST_BUILD script is a workaround for enabling:
+ # - using onnxruntime-web with Multi-threading enabled when import from CDN
+ # - using onnxruntime-web when consumed in some frameworks like Vite
+ #
+ # In the use case mentioned above, the file name of the script may be changed. So we need to replace the line:
+ # `new Worker(new URL("ort-wasm-*.mjs", import.meta.url),`
+ # with
+ # `new Worker(new URL(import.meta.url),`
+ #
+ # This behavior is introduced in https://github.com/emscripten-core/emscripten/pull/22165. Since it's unlikely to be
+ # reverted, and there is no config to disable this behavior, we have to use a post-build script to workaround it.
+ #
+
+ # Generate a script to do the post-build work
+ file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/wasm_post_build.js "
+ const fs = require('fs');
+ const path = require('path');
+
+ // node wasm_post_build.js
+ const mjsFilePath = process.argv[2];
+ let contents = fs.readFileSync(mjsFilePath).toString();
+
+ const regex = 'new Worker\\\\(new URL\\\\(\".+?\", ?import\\\\.meta\\\\.url\\\\),';
+ const matches = [...contents.matchAll(new RegExp(regex, 'g'))];
+ if (matches.length !== 1) {
+ throw new Error(
+ `Unexpected number of matches for \"\${regex}\" in \"\${mjsFilePath}\": \${matches.length}.`,
+ );
+ }
+
+ // Replace the only occurrence.
+ contents = contents.replace(
+ new RegExp(regex),
+ `new Worker(new URL(import.meta.url),`,
);
- }
- // Replace the only occurrence.
- contents = contents.replace(
- new RegExp(regex),
- `new Worker(new URL(import.meta.url),`,
- );
+ fs.writeFileSync(mjsFilePath, contents);
+ "
+ )
- fs.writeFileSync(mjsFilePath, contents);
- "
- )
+ find_program(NODE_EXECUTABLE node required)
+ if (NOT NODE_EXECUTABLE)
+ message(FATAL_ERROR "Node is required to run the post-build script")
+ endif()
- find_program(NODE_EXECUTABLE node required)
- if (NOT NODE_EXECUTABLE)
- message(FATAL_ERROR "Node is required to run the post-build script")
+ add_custom_command(
+ TARGET onnxruntime_webassembly
+ POST_BUILD
+ # Backup file at $.bak
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" "$.bak"
+ COMMAND ${CMAKE_COMMAND} -E echo "Performing post-process for $"
+ COMMAND ${NODE_EXECUTABLE} "${CMAKE_CURRENT_BINARY_DIR}/wasm_post_build.js" "$"
+ )
endif()
-
- add_custom_command(
- TARGET onnxruntime_webassembly
- POST_BUILD
- # Backup file at $.bak
- COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" "$.bak"
- COMMAND ${CMAKE_COMMAND} -E echo "Performing post-process for $"
- COMMAND ${NODE_EXECUTABLE} "${CMAKE_CURRENT_BINARY_DIR}/wasm_post_build.js" "$"
- )
endif()
diff --git a/cmake/vcpkg-ports/onnx/portfile.cmake b/cmake/vcpkg-ports/onnx/portfile.cmake
index 0cd6bfa305843..7df4fd0898bde 100644
--- a/cmake/vcpkg-ports/onnx/portfile.cmake
+++ b/cmake/vcpkg-ports/onnx/portfile.cmake
@@ -3,8 +3,8 @@ vcpkg_check_linkage(ONLY_STATIC_LIBRARY)
vcpkg_from_github(
OUT_SOURCE_PATH SOURCE_PATH
REPO onnx/onnx
- REF 7fc2b81a275223f5b02a522d9d2649837542a7be
- SHA512 6911b4e532a7735ef40660dee904877850234a600b39d46a8dab91f6506c6547e3bd10af5d5f0f0abc0c6e7e6e1fc04c0ea307eb9f4aef5c614eaaa50403804d
+ REF "v${VERSION}"
+ SHA512 2f38664947c8d1efc40620a7c1b1953d2aa4b0a37b67c4886b86e77c1d697363c26413413ddda8eabc545892fb1bcb43afc7e93e62f0901527524a2727e1ea8d
PATCHES
fix-cmakelists.patch
fix-dependency-protobuf.patch
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index 9794d2c184d5d..6e325f7fe9646 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -17,6 +17,7 @@ public enum GraphOptimizationLevel
ORT_DISABLE_ALL = 0,
ORT_ENABLE_BASIC = 1,
ORT_ENABLE_EXTENDED = 2,
+ ORT_ENABLE_LAYOUT = 3,
ORT_ENABLE_ALL = 99
}
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8c1ab002bce67..b657c828fbde1 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -298,6 +298,7 @@ Do not modify directly.*
|||[19, 20]|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)|
|||[13, 18]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|||[10, 12]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
+|RMSNormalization|*in* X:**T**<br/> *in* scale:**V**<br/> *out* Y:**V**|23+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **V** = tensor(double), tensor(float), tensor(float16)|
|RNN|*in* X:**T**<br/> *in* W:**T**<br/> *in* R:**T**<br/> *in* B:**T**<br/> *in* sequence_lens:**T1**<br/> *in* initial_h:**T**<br/> *out* Y:**T**<br/> *out* Y_h:**T**|22+|**T** = tensor(float)<br/> **T1** = tensor(int32)|
|||[14, 21]|**T** = tensor(float)<br/> **T1** = tensor(int32)|
|||[7, 13]|**T** = tensor(float)<br/> **T1** = tensor(int32)|
@@ -437,7 +438,7 @@ Do not modify directly.*
|||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SplitToSequence|*in* input:**T**<br/> *in* split:**I**<br/> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)|
+|SplitToSequence|*in* input:**T**<br/> *in* split:**I**<br/> *out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)|
|Sqrt|*in* X:**T**<br/> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
|Squeeze|*in* data:**T**<br/> *in* axes:**tensor(int64)**<br/> *out* squeezed:**T**<br/> or<br/> *in* data:**T**<br/> *out* squeezed:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -782,6 +783,7 @@ Do not modify directly.*
|||[19, 20]|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int8), tensor(uint8)|
|||[13, 18]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|||[10, 12]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
+|RMSNormalization|*in* X:**T**<br/> *in* scale:**V**<br/> *out* Y:**V**|23+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|RNN|*in* X:**T**<br/> *in* W:**T**<br/> *in* R:**T**<br/> *in* B:**T**<br/> *in* sequence_lens:**T1**<br/> *in* initial_h:**T**<br/> *out* Y:**T**<br/> *out* Y_h:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
|RandomNormal|*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h
index 15c15c6c143d2..c84d34cfd3cbe 100644
--- a/include/onnxruntime/core/framework/allocator.h
+++ b/include/onnxruntime/core/framework/allocator.h
@@ -101,7 +101,9 @@ class IAllocator {
const OrtMemoryInfo& Info() const { return memory_info_; };
// Each implementation of IAllocator can override and provide their own implementation
- virtual void GetStats(AllocatorStats* /*stats*/) { return; }
+ virtual void GetStats(AllocatorStats* stats) {
+ *stats = {};
+ }
static bool CalcMemSizeForArray(size_t nmemb, size_t size, size_t* out) noexcept {
return CalcMemSizeForArrayWithAlignment(nmemb, size, 0, out);
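With this change, allocators that do not override GetStats() report zeroed statistics instead of leaving the output untouched; allocators that track their own usage can override it to surface those numbers. A minimal sketch, assuming the AllocatorStats field names used below (num_allocs, bytes_in_use, total_allocated_bytes) and the OrtMemoryInfo constructor shown:

```
#include <atomic>
#include <cstdlib>
#include "core/framework/allocator.h"

namespace {

// Sketch only: a counting CPU allocator that overrides GetStats(). Free() does
// not adjust bytes_in_use here, to keep the example short.
class CountingCpuAllocator : public onnxruntime::IAllocator {
 public:
  CountingCpuAllocator()
      : IAllocator(OrtMemoryInfo("CountingCpu", OrtAllocatorType::OrtDeviceAllocator)) {}

  void* Alloc(size_t size) override {
    ++num_allocs_;
    bytes_in_use_ += static_cast<int64_t>(size);
    total_allocated_ += static_cast<int64_t>(size);
    return malloc(size);
  }

  void Free(void* p) override { free(p); }

  void GetStats(onnxruntime::AllocatorStats* stats) override {
    *stats = {};
    stats->num_allocs = num_allocs_;
    stats->bytes_in_use = bytes_in_use_;
    stats->total_allocated_bytes = total_allocated_;
  }

 private:
  std::atomic<int64_t> num_allocs_{0};
  std::atomic<int64_t> bytes_in_use_{0};
  std::atomic<int64_t> total_allocated_{0};
};

}  // namespace
```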
diff --git a/include/onnxruntime/core/optimizer/graph_transformer_level.h b/include/onnxruntime/core/optimizer/graph_transformer_level.h
index 111f38f9ccb6e..3f2126ce494a6 100644
--- a/include/onnxruntime/core/optimizer/graph_transformer_level.h
+++ b/include/onnxruntime/core/optimizer/graph_transformer_level.h
@@ -12,8 +12,9 @@ enum class TransformerLevel : int {
Level1, // basic optimizations
Level2, // extended optimizations
Level3, // layout optimizations
+ Level4, // unsupported datatypes optimizations
// The max level should always be same as the last level.
- MaxLevel = Level3
+ MaxLevel = Level4
};
} // namespace onnxruntime
diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
index 31b0f22340510..6f07ead935f4a 100644
--- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h
+++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -36,7 +36,8 @@ namespace optimizer_utils {
TODO: This is visible for testing at the moment, but we should rather make it private. */
InlinedVector> GenerateRewriteRules(
TransformerLevel level,
- const InlinedHashSet& rules_to_disable = {});
+ const InlinedHashSet& rules_to_disable = {},
+ const bool enable_cast_chain_elimination = false);
/** Given a TransformerLevel, this method generates a name for the rule-based graph transformer of that level. */
std::string GenerateRuleBasedTransformerName(TransformerLevel level);
@@ -45,7 +46,8 @@ std::string GenerateRuleBasedTransformerName(TransformerLevel level);
std::unique_ptr GenerateRuleBasedGraphTransformer(
TransformerLevel level,
const InlinedHashSet& rules_to_disable,
- const InlinedHashSet& compatible_execution_providers);
+ const InlinedHashSet& compatible_execution_providers,
+ const bool enable_cast_chain_elimination = false);
/** Generates all predefined (both rule-based and non-rule-based) transformers for this level.
Any transformers or rewrite rules named in rules_and_transformers_to_disable will be excluded. */
diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h
index 0c9095f566fad..11cc6f131dab3 100644
--- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h
+++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h
@@ -8,6 +8,7 @@
* - `kHasUserComputeStream`: Indicates whether a user-provided compute stream is used.
* - `kUserComputeStream`: Specifies the user-provided compute stream.
* - `kMaxWorkspaceSize`: Sets the maximum workspace size for GPU memory allocation.
+ * - `kMaxSharedMemSize`: Sets the maximum amount of shared memory that TensorRT kernels are allowed to use.
* - `kDumpSubgraphs`: Enables or disables dumping of subgraphs for debugging.
* - `kDetailedBuildLog`: Enables or disables detailed build logs for debugging.
* - `kProfilesMinShapes`: Specifies the minimum shapes for profiling.
@@ -24,6 +25,7 @@ constexpr const char* kDeviceId = "device_id";
constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
constexpr const char* kUserComputeStream = "user_compute_stream";
constexpr const char* kMaxWorkspaceSize = "nv_max_workspace_size";
+constexpr const char* kMaxSharedMemSize = "nv_max_shared_mem_size";
constexpr const char* kDumpSubgraphs = "nv_dump_subgraphs";
constexpr const char* kDetailedBuildLog = "nv_detailed_build_log";
constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes";
diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h
index 776045a97cae5..a0053ffd3e3e3 100644
--- a/include/onnxruntime/core/session/environment.h
+++ b/include/onnxruntime/core/session/environment.h
@@ -6,6 +6,8 @@
#include
#include
#include
+#include
+#include
#include "core/common/common.h"
#include "core/common/basic_types.h"
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index a2f518ae09a4b..0892accec40b0 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -340,6 +340,26 @@ typedef struct OrtAllocator {
* those made during session initialization. This allows for separate memory management strategies for these allocations.
*/
void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes
+
+ /**
+ * @brief Function used to get the statistics of the allocator.
+ *
+ * Return a pointer to the OrtKeyValuePairs structure that contains the statistics of the allocator
+ * and the user should call OrtApi::ReleaseKeyValuePairs.
+ * Supported keys are:
+ * - Limit: Bytes limit of the allocator. -1 if no limit is set.
+ * - InUse: Number of bytes in use.
+ * - TotalAllocated: The total number of allocated bytes by the allocator.
+ * - MaxInUse: The maximum bytes in use.
+ * - NumAllocs: Number of allocations.
+ * - NumReserves: Number of reserves. (Number of calls to Reserve() in arena-based allocators)
+ * - NumArenaExtensions: Number of arena extensions (Relevant only for arena based allocators)
+ * - NumArenaShrinkages: Number of arena shrinkages (Relevant only for arena based allocators)
+ * - MaxAllocSize: The max single allocation seen.
+ *
+ * NOTE: If the allocator does not implement this function, the OrtKeyValuePairs instance will be empty.
+ */
+ ORT_API2_STATUS(GetStats, _In_ const struct OrtAllocator* this_, _Outptr_ OrtKeyValuePairs** out);
} OrtAllocator;
typedef void(ORT_API_CALL* OrtLoggingFunction)(
@@ -355,6 +375,7 @@ typedef enum GraphOptimizationLevel {
ORT_DISABLE_ALL = 0,
ORT_ENABLE_BASIC = 1,
ORT_ENABLE_EXTENDED = 2,
+ ORT_ENABLE_LAYOUT = 3,
ORT_ENABLE_ALL = 99
} GraphOptimizationLevel;
@@ -672,6 +693,7 @@ typedef struct OrtTensorRTProviderOptions {
typedef struct OrtMIGraphXProviderOptions {
int device_id; // hip device id.
int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
+ int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
@@ -680,6 +702,21 @@ typedef struct OrtMIGraphXProviderOptions {
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
const char* migraphx_load_model_path; // migraphx model path name
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false
+
+ /** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t)
+ * Defaults to SIZE_MAX.
+ * \note If a ::OrtArenaCfg has been applied, it will override this field
+ */
+ size_t migraphx_mem_limit;
+
+ /** \brief Strategy used to grow the memory arena
+ * 0 = kNextPowerOfTwo
+ * 1 = kSameAsRequested
+ * Defaults to 0.
+ * \note If a ::OrtArenaCfg has been applied, it will override this field
+ */
+ int migraphx_arena_extend_strategy;
+
} OrtMIGraphXProviderOptions;
/** \brief OpenVINO Provider Options
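A short usage sketch for the new MIGraphX fields, assuming the existing Ort::SessionOptions::AppendExecutionProvider_MIGraphX wrapper is available; the values set below mirror the documented defaults:

```
#include <cstdint>
#include <onnxruntime_cxx_api.h>

// Zero-initialize the options struct, then set the fields we care about
// explicitly before appending the MIGraphX execution provider.
Ort::SessionOptions MakeMIGraphXSessionOptions() {
  OrtMIGraphXProviderOptions migx_options{};
  migx_options.device_id = 0;
  migx_options.migraphx_fp16_enable = 1;
  migx_options.migraphx_fp8_enable = 0;             // new: FP8 precision off
  migx_options.migraphx_mem_limit = SIZE_MAX;       // new: no arena memory limit
  migx_options.migraphx_arena_extend_strategy = 0;  // new: kNextPowerOfTwo

  Ort::SessionOptions so;
  so.AppendExecutionProvider_MIGraphX(migx_options);
  return so;
}
```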
@@ -5275,6 +5312,22 @@ struct OrtApi {
* \since Version 1.23
*/
ORT_API2_STATUS(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ size_t* size);
+
+ /** \brief Calls OrtAllocator::GetStats function
+ *
+ * Return a pointer to the OrtKeyValuePairs structure that contains the statistics of the allocator
+ * and the user should call OrtApi::ReleaseKeyValuePairs.
+ *
+ * NOTE: If the allocator does not implement this function, the OrtKeyValuePairs instance will be empty.
+ *
+ * \param[in] ort_allocator The allocator to get stats from
+ * \param[out] out A pointer to the OrtKeyValuePairs instance that contains the stats
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(AllocatorGetStats, _In_ const OrtAllocator* ort_allocator, _Outptr_ OrtKeyValuePairs** out);
};
/*
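A minimal sketch of the intended call sequence for the new stats API, using only the entry points named here plus the long-standing GetAllocatorWithDefaultOptions and ReleaseStatus functions:

```
#include <onnxruntime_c_api.h>

// Query the default CPU allocator's statistics through AllocatorGetStats and
// release the returned OrtKeyValuePairs, as the documentation above requires.
void QueryDefaultAllocatorStats() {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtAllocator* allocator = nullptr;
  OrtStatus* status = api->GetAllocatorWithDefaultOptions(&allocator);
  if (status != nullptr) {
    api->ReleaseStatus(status);
    return;
  }

  OrtKeyValuePairs* stats = nullptr;
  status = api->AllocatorGetStats(allocator, &stats);
  if (status == nullptr) {
    // `stats` holds entries under the documented keys (Limit, InUse, TotalAllocated,
    // MaxInUse, NumAllocs, ...), or is empty if the allocator does not implement GetStats.
    api->ReleaseKeyValuePairs(stats);
  } else {
    api->ReleaseStatus(status);
  }
}
```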
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index c7f81264115c6..08e8736e9e591 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -1740,15 +1740,11 @@ struct ConstValueImpl : Base {
size_t GetStringTensorElementLength(size_t element_index) const;
///
- /// Returns the total size of the tensor data in bytes.
+ /// Returns the total size of the tensor data in bytes. Throws an exception if the OrtValue
+ /// does not contain a tensor or if it contains a tensor that contains strings.
+ /// For numeric tensors, this is sizeof(element_type) * total_element_count.
///
/// The total size of the tensor data in bytes
- /// Throws an exception if the OrtValue does not contain a tensor or
- /// if it contains a tensor that contains strings
- ///
- /// For numeric tensors, this is sizeof(element_type) * total_element_count.
- ///
- ///
size_t GetTensorSizeInBytes() const; ///< Wraps OrtApi::GetTensorSizeInBytes
#if !defined(DISABLE_SPARSE_TENSORS)
@@ -2155,6 +2151,12 @@ struct AllocatorImpl : Base {
MemoryAllocation GetAllocation(size_t size);
void Free(void* p);
ConstMemoryInfo GetInfo() const;
+
+ /** \brief Function that returns the statistics of the allocator.
+ *
+ * \return A KeyValuePairs object containing the allocator statistics.
+ */
+ KeyValuePairs GetStats() const;
};
} // namespace detail
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 6cd52732b923b..25936038ba297 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -243,6 +243,12 @@ inline ConstMemoryInfo AllocatorImpl::GetInfo() const {
return ConstMemoryInfo{out};
}
+template
+inline KeyValuePairs AllocatorImpl::GetStats() const {
+ OrtKeyValuePairs* out;
+ ThrowOnError(GetApi().AllocatorGetStats(this->p_, &out));
+ return KeyValuePairs(out);
+}
} // namespace detail
inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
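The same query through the C++ wrapper added above; AllocatorWithDefaultOptions derives from AllocatorImpl, so GetStats is available on it (a sketch, assuming nothing beyond what the headers in this change declare):

```
#include <onnxruntime_cxx_api.h>

// The new AllocatorImpl::GetStats wrapper throws on error and returns an
// Ort::KeyValuePairs wrapping the OrtKeyValuePairs produced by the C API.
void DumpDefaultAllocatorStats() {
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::KeyValuePairs stats = allocator.GetStats();
  // `stats` is empty if the underlying allocator does not implement GetStats.
}
```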
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 5497d7c71a393..97e53e6acee5a 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -67,6 +67,10 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
+// Enable or disable Cast chain elimination in graph optimization. "0": disable; "1": enable. The default is "0".
+// CastElimination with chain elimination has side effects which may change the inference results. It is disabled by default due to this.
+static const char* const kOrtSessionOptionsEnableCastChainElimination = "optimization.enable_cast_chain_elimination";
+
// This setting controls whether to enable AheadOfTime function inlining.
// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
// as possible with the help of enabled execution providers.
@@ -107,6 +111,37 @@ static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimiz
// Default is an empty string which means no optimizers are disabled.
static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers";
+// It controls whether to run graph optimizations in a loop or not.
+//
+// "0": disable. Graph Optimization Loop is disabled.
+// ```
+// Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
+// ^ |
+// | "No Loop" |
+// | |
+// X xxxxxxxxxxx X
+// ```
+// "1": enable. Graph Optimization Loop is enabled, such that, if optimizations at Level 4 are applied then
+// the loop will check for any other valid optimization that can happen.
+// ```
+// Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
+// ^ |
+// | "Loop only depending on Level 4" |
+// | |
+// ---------------------------------------------------
+// ```
+// "2": enable. Graph Optimization Loop is enabled, such that, if optimizations at Level 2 or above are applied then
+// the loop will check for any other valid optimization that can happen.
+// ```
+// Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
+// ^ |
+// | "Loop" |
+// | |
+// ---------------------------------------------------
+// ```
+// Default value is set to "1".
+static const char* const kOrtSessionOptionsGraphOptimizationsLoopLevel = "session.graph_optimizations_loop_level";
+
// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
// Using device allocators means the memory allocation is made using malloc/new.
static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
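A sketch of setting the new options from the C++ API; AddConfigEntry and SetGraphOptimizationLevel already exist, and the key strings are the ones defined in this header. The comments restate the behavior documented above:

```
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

// Opt in to the optimization loop and the cast-chain elimination rule.
Ort::SessionOptions MakeTunedSessionOptions() {
  Ort::SessionOptions so;

  // ORT_ENABLE_LAYOUT (new enum value 3) stops after the layout optimizations;
  // ORT_ENABLE_ALL still applies everything, including the new Level 4 passes.
  so.SetGraphOptimizationLevel(ORT_ENABLE_ALL);

  // "1" (default) re-runs the loop only when a Level 4 optimization fired;
  // "2" re-runs it whenever Level 2 or above changed the graph; "0" disables the loop.
  so.AddConfigEntry(kOrtSessionOptionsGraphOptimizationsLoopLevel, "2");

  // Disabled by default because it may change inference results.
  so.AddConfigEntry(kOrtSessionOptionsEnableCastChainElimination, "1");

  return so;
}
```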
diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java
index f15ad938463a7..a399d5080ca16 100644
--- a/java/src/main/java/ai/onnxruntime/OrtSession.java
+++ b/java/src/main/java/ai/onnxruntime/OrtSession.java
@@ -652,6 +652,8 @@ public enum OptLevel {
* graph.
*/
EXTENDED_OPT(2),
+ /** Applies all the layout optimizations like NCHW and NCHWC to the ONNX graph. */
+ LAYOUT_OPT(3),
/** Applies all available optimizations to the ONNX graph. */
ALL_OPT(99);
diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c
index 6a3c279073860..fe19015d642f0 100644
--- a/java/src/main/native/OrtJniUtil.c
+++ b/java/src/main/native/OrtJniUtil.c
@@ -47,6 +47,8 @@ GraphOptimizationLevel convertOptimizationLevel(jint level) {
return ORT_ENABLE_BASIC;
case 2:
return ORT_ENABLE_EXTENDED;
+ case 3:
+ return ORT_ENABLE_LAYOUT;
case 99:
return ORT_ENABLE_ALL;
default:
diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts
index 4ef4891b5b46a..4a670e24aa6b7 100644
--- a/js/common/lib/inference-session.ts
+++ b/js/common/lib/inference-session.ts
@@ -81,7 +81,7 @@ export declare namespace InferenceSession {
*
* This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend
*/
- graphOptimizationLevel?: 'disabled' | 'basic' | 'extended' | 'all';
+ graphOptimizationLevel?: 'disabled' | 'basic' | 'extended' | 'layout' | 'all';
/**
* Whether enable CPU memory arena.
diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc
index b189b45556306..7fff751a29186 100644
--- a/js/node/src/session_options_helper.cc
+++ b/js/node/src/session_options_helper.cc
@@ -31,6 +31,7 @@ const std::unordered_map GRAPH_OPT_LEVEL_NA
{"disabled", ORT_DISABLE_ALL},
{"basic", ORT_ENABLE_BASIC},
{"extended", ORT_ENABLE_EXTENDED},
+ {"layout", ORT_ENABLE_LAYOUT},
{"all", ORT_ENABLE_ALL}};
const std::unordered_map EXECUTION_MODE_NAME_TO_ID_MAP = {{"sequential", ORT_SEQUENTIAL},
diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java
index 1be8c22b40da8..496db5a6087e6 100644
--- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java
+++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java
@@ -326,6 +326,7 @@ public WritableMap run(String key, ReadableMap input, ReadableArray output, Read
{"disabled", SessionOptions.OptLevel.NO_OPT},
{"basic", SessionOptions.OptLevel.BASIC_OPT},
{"extended", SessionOptions.OptLevel.EXTENDED_OPT},
+ // {"layout", SessionOptions.OptLevel.LAYOUT_OPT},
{"all", SessionOptions.OptLevel.ALL_OPT},
})
.collect(Collectors.toMap(p -> (String)p[0], p -> (SessionOptions.OptLevel)p[1]));
diff --git a/js/react_native/ios/OnnxruntimeModule.mm b/js/react_native/ios/OnnxruntimeModule.mm
index d3527aad6ae38..b1b55075d26bc 100644
--- a/js/react_native/ios/OnnxruntimeModule.mm
+++ b/js/react_native/ios/OnnxruntimeModule.mm
@@ -301,6 +301,7 @@ - (NSDictionary*)run:(NSString*)url
@"disabled" : @(ORT_DISABLE_ALL),
@"basic" : @(ORT_ENABLE_BASIC),
@"extended" : @(ORT_ENABLE_EXTENDED),
+ @"layout" : @(ORT_ENABLE_LAYOUT),
@"all" : @(ORT_ENABLE_ALL)
};
diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts
index cd787379220c1..26d07b4347131 100644
--- a/js/web/lib/wasm/session-options.ts
+++ b/js/web/lib/wasm/session-options.ts
@@ -14,6 +14,8 @@ const getGraphOptimzationLevel = (graphOptimizationLevel: string | unknown): num
return 1;
case 'extended':
return 2;
+ case 'layout':
+ return 3;
case 'all':
return 99;
default:
diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts
index 23eb2f0978feb..f2a28396d7486 100644
--- a/js/web/lib/wasm/wasm-factory.ts
+++ b/js/web/lib/wasm/wasm-factory.ts
@@ -151,7 +151,12 @@ export const initializeWebAssembly = async (flags: Env.WebAssemblyFlags): Promis
const wasmPathOverride = (wasmPathOverrideFlag as URL)?.href ?? wasmPathOverrideFlag;
const wasmBinaryOverride = flags.wasmBinary;
- const [objectUrl, ortWasmFactory] = await importWasmModule(mjsPathOverride, wasmPrefixOverride, numThreads > 1);
+ const [objectUrl, ortWasmFactory] = await importWasmModule(
+ mjsPathOverride,
+ wasmPrefixOverride,
+ numThreads > 1,
+ !!wasmBinaryOverride || !!wasmPathOverride,
+ );
let isTimeout = false;
diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts
index a8e27f6f334bc..d9180e220c80c 100644
--- a/js/web/lib/wasm/wasm-utils-import.ts
+++ b/js/web/lib/wasm/wasm-utils-import.ts
@@ -234,9 +234,45 @@ export const importWasmModule = async (
urlOverride: string | undefined,
prefixOverride: string | undefined,
isMultiThreaded: boolean,
+ isWasmOverridden: boolean,
): Promise<[undefined | string, EmscriptenModuleFactory]> => {
- if (!urlOverride && !prefixOverride && embeddedWasmModule && scriptSrc && isSameOrigin(scriptSrc)) {
- return [undefined, embeddedWasmModule];
+ //
+ // Check if we should use the embedded module.
+ //
+
+ // To use the embedded module, it should be available, and no URL override or prefix override should be specified.
+ let useEmbeddedModule = embeddedWasmModule && !(urlOverride || prefixOverride);
+ if (useEmbeddedModule) {
+ if (!scriptSrc) {
+ // no URL info available.
+ //
+ // Note: when the embedded module is available, it means the current script is ESM. Usually, in ESM, the
+ // `import.meta.url` is available. But in some cases (eg. Cloudflare Workers), the value of `import.meta.url`
+ // can be `null` or `undefined`. In this case, we can only load the embedded module when:
+ //
+ // 1. The WebAssembly module binary is overridden:
+ // ```js
+ // env.wasm.wasmPaths = undefined; // or not specified
+ // env.wasm.wasmBinary = /* a Uint8Array containing the WebAssembly binary */;
+ // ```
+ //
+ // 2. The ".wasm" only is overridden.
+ // ```js
+ // env.wasm.wasmPaths = { wasm: /* URL of the .wasm file */ };
+ // ```
+ //
+ if (isWasmOverridden && !isMultiThreaded) {
+ useEmbeddedModule = true;
+ } else {
+ throw new Error('cannot determine the script source URL.');
+ }
+ } else {
+ // if the script source is available, we can check if it is from the same origin.
+ useEmbeddedModule = isSameOrigin(scriptSrc);
+ }
+ }
+ if (useEmbeddedModule) {
+ return [undefined, embeddedWasmModule!];
} else {
const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP
? 'ort-wasm-simd-threaded.jsep.mjs'
diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts
index f546f58a28bfa..3bfb89164393e 100644
--- a/js/web/script/test-runner-cli-args.ts
+++ b/js/web/script/test-runner-cli-args.ts
@@ -58,7 +58,7 @@ Options:
*** Session Options ***
-u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model.
-o=<...>, --graph-optimization-level=<...> Specify graph optimization level.
- Default is 'all'. Valid values are 'disabled', 'basic', 'extended', 'all'.
+ Default is 'all'. Valid values are 'disabled', 'basic', 'extended', 'layout', 'all'.
-i=<...>, --io-binding=<...> Specify the IO binding testing type. Should be one of the following:
none (default)
gpu-tensor use pre-allocated GPU tensors for inputs and outputs
@@ -195,7 +195,7 @@ export interface TestRunnerCliArgs {
/**
* Specify graph optimization level
*/
- graphOptimizationLevel: 'disabled' | 'basic' | 'extended' | 'all';
+ graphOptimizationLevel: 'disabled' | 'basic' | 'extended' | 'layout' | 'all';
cpuOptions?: InferenceSession.CpuExecutionProviderOption;
cudaOptions?: InferenceSession.CudaExecutionProviderOption;
@@ -480,7 +480,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
const graphOptimizationLevel = args['graph-optimization-level'] || args.o || 'all';
if (
typeof graphOptimizationLevel !== 'string' ||
- ['disabled', 'basic', 'extended', 'all'].indexOf(graphOptimizationLevel) === -1
+ ['disabled', 'basic', 'extended', 'layout', 'all'].indexOf(graphOptimizationLevel) === -1
) {
throw new Error(`graph optimization level is invalid: ${graphOptimizationLevel}`);
}
diff --git a/objectivec/include/ort_enums.h b/objectivec/include/ort_enums.h
index 78de233972ccf..61a127f1a4b55 100644
--- a/objectivec/include/ort_enums.h
+++ b/objectivec/include/ort_enums.h
@@ -50,6 +50,7 @@ typedef NS_ENUM(int32_t, ORTGraphOptimizationLevel) {
ORTGraphOptimizationLevelNone,
ORTGraphOptimizationLevelBasic,
ORTGraphOptimizationLevelExtended,
+ ORTGraphOptimizationLevelLayout,
ORTGraphOptimizationLevelAll,
};
diff --git a/objectivec/ort_enums.mm b/objectivec/ort_enums.mm
index 60939812df531..5fcbe34e5e8a4 100644
--- a/objectivec/ort_enums.mm
+++ b/objectivec/ort_enums.mm
@@ -68,6 +68,7 @@
{ORTGraphOptimizationLevelNone, ORT_DISABLE_ALL},
{ORTGraphOptimizationLevelBasic, ORT_ENABLE_BASIC},
{ORTGraphOptimizationLevelExtended, ORT_ENABLE_EXTENDED},
+ {ORTGraphOptimizationLevelLayout, ORT_ENABLE_LAYOUT},
{ORTGraphOptimizationLevelAll, ORT_ENABLE_ALL},
};
diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc
similarity index 99%
rename from onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu
rename to onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc
index aabbe4cc7582a..ec5deccf655ff 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc
@@ -313,8 +313,7 @@ struct BytesHash {
// Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe.
// TODO(tianleiwu): since the key includes sequence lengths, we may want to limit the cache size.
-thread_local
-std::unordered_map, BytesHash > mha_graph_cache;
+thread_local std::unordered_map, BytesHash > mha_graph_cache;
void run(
void* output,
@@ -341,7 +340,6 @@ void run(
cudnnHandle_t handle,
Stream* stream,
AllocatorPtr allocator) {
-
GraphParams params;
params.batch_size = batch_size;
params.num_heads_q = num_heads_q;
diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
index 8a17e945df3f3..e6f1798f6ef72 100644
--- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
@@ -39,7 +39,13 @@ __global__ void MaskIndexKernelSmall(int sequence_length, const int* mask, int*
// blockIdx.x is b
const int offset = blockIdx.x * sequence_length; // batch strides of sequence_length
+#if CUDA_VERSION >= 12090
+ ::cuda::minimum min;
+#else
+ // Deprecated on CUDA 12.9
cub::Min min;
+#endif
+
int thread_data(sequence_length);
const int idx = offset + threadIdx.x;
@@ -66,7 +72,13 @@ __global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_
// blockIdx.x is b
const int offset = blockIdx.x * sequence_length; // batch strides of sequence_length
+#if CUDA_VERSION >= 12090
+ ::cuda::minimum min;
+#else
+ // Deprecated on CUDA 12.9
cub::Min min;
+#endif
+
int thread_data(sequence_length);
for (int i = threadIdx.x; i < sequence_length; i += TPB) {
diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h
index e48ef3f154883..e7ba5d4b54f05 100644
--- a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h
+++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h
@@ -376,24 +376,19 @@ struct CutlassGemmConfig {
};
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) {
- // clang-format off
- if (config.is_tma_warp_specialized)
- {
- out << "tile_config_sm90_enum: " << config.getTileConfigAsInt()
- << ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
- << ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
- << ", cluster_shape_enum: " << int(config.cluster_shape)
- << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
- }
- else
- {
- out << "tile_config_enum: " << config.getTileConfigAsInt()
- << ", split_k_style_enum: " << int(config.split_k_style)
- << ", split_k_factor: " << config.split_k_factor
- << ", stages: " << config.stages
- << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
- }
- // clang-format on
+ if (config.is_tma_warp_specialized) {
+ out << "tile_config_sm90_enum: " << config.getTileConfigAsInt()
+ << ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
+ << ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
+ << ", cluster_shape_enum: " << int(config.cluster_shape)
+ << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
+ } else {
+ out << "tile_config_enum: " << config.getTileConfigAsInt()
+ << ", split_k_style_enum: " << int(config.split_k_style)
+ << ", split_k_factor: " << config.split_k_factor
+ << ", stages: " << config.stages
+ << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
+ }
return out;
}
diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc
deleted file mode 100644
index 50ee944161538..0000000000000
--- a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc
+++ /dev/null
@@ -1,687 +0,0 @@
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "contrib_ops/cuda/llm/cutlass_preprocessors.h"
-
-#include
-
-#include "core/common/common.h"
-#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h"
-#include "contrib_ops/cuda/llm/common/logger.h"
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wunused-parameter"
-#endif
-
-#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic pop
-#endif
-
-using namespace onnxruntime::llm::common;
-
-namespace onnxruntime::llm {
-namespace kernels {
-namespace cutlass_kernels {
-
-struct LayoutDetails {
- enum class Layout {
- UNKNOWN,
- ROW_MAJOR,
- COLUMN_MAJOR
- };
-
- Layout layoutB = Layout::UNKNOWN;
- int rows_per_column_tile = 1;
- int columns_interleaved = 1;
-
- bool uses_imma_ldsm = false;
-};
-
-template
-struct getLayoutDetails {
-};
-
-template <>
-struct getLayoutDetails {
- LayoutDetails operator()() {
- LayoutDetails layout_details;
- layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR;
- return layout_details;
- }
-};
-
-template <>
-struct getLayoutDetails {
- LayoutDetails operator()() {
- LayoutDetails layout_details;
- layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
- return layout_details;
- }
-};
-
-template
-struct getLayoutDetails> {
- LayoutDetails operator()() {
- LayoutDetails layout_details;
- layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
- layout_details.rows_per_column_tile = RowsPerTile;
- layout_details.columns_interleaved = ColumnsInterleaved;
- return layout_details;
- }
-};
-
-template
-LayoutDetails getLayoutDetailsForArchAndQuantType() {
- using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB;
- using LayoutB = typename CompileTraits::Layout;
- using MmaOperator = typename CompileTraits::Operator;
- LayoutDetails details = getLayoutDetails()();
- details.uses_imma_ldsm = std::is_same::value;
- return details;
-}
-
-template
-LayoutDetails getLayoutDetailsForArch(QuantType quant_type) {
- LayoutDetails details;
- switch (quant_type) {
- case QuantType::W8_A16:
- details = getLayoutDetailsForArchAndQuantType();
- break;
- case QuantType::W4_A16:
- details = getLayoutDetailsForArchAndQuantType();
- break;
- case QuantType::W4_AFP8:
- details = getLayoutDetailsForArchAndQuantType();
- break;
- default:
- ORT_THROW("Unsupported quantization type");
- }
- return details;
-}
-
-LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) {
- if (arch >= 75 && arch < 80) {
- return getLayoutDetailsForArch(quant_type);
- } else if (arch >= 80 && arch < 90) {
- return getLayoutDetailsForArch(quant_type);
- } else if (arch >= 90 && arch < 100) {
- return getLayoutDetailsForArch(quant_type);
- } else if (arch >= 100) {
- return getLayoutDetailsForArch(quant_type);
- } else {
- ORT_THROW("Unsupported Arch");
- return LayoutDetails();
- }
-}
-
-// Permutes the rows of B in a way that is compatible with Turing+ architectures.
-//
-// Throws an error for other architectures.
-// The data is permuted such that:
-// For W8_A16, each group of 16 rows is permuted using the map below:
-// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
-// For W4_A16, each group of 32 rows is permuted using the map below:
-// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
-// For W4_A8, see the map in the code. The idea is similar to above.
-// The goal of this permutation is to ensure data ends up in the correct threads after
-// we execute LDSM. It counteracts the effect of the data being of different widths.
-// For more information about the expected layouts, see the MMA section in the PTX docs.
-std::vector<int> get_permutation_map(QuantType quant_type) {
- if (quant_type == QuantType::W8_A16) {
- return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
- } else if (quant_type == QuantType::W4_A16) {
- return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15,
- 22, 23, 30, 31};
- } else if (quant_type == QuantType::W4_AFP8) {
- return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
- 28, 29, 30, 31};
- } else {
- ORT_THROW("Invalid quantization type for LDSM permutation");
- }
-}
-
-void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
- std::vector<size_t> const& shape, QuantType quant_type) {
- ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__);
- // We only want to run this step for weight only quant.
- std::vector<int> row_permutation = get_permutation_map(quant_type);
-
- ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
- const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
- const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
- const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
-
- int const BITS_PER_ELT = get_weight_quant_bits(quant_type);
- int const K = 16 / BITS_PER_ELT;
-
- uint32_t const* input_byte_ptr = reinterpret_cast<uint32_t const*>(quantized_tensor);
- uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(permuted_quantized_tensor);
-
- int MMA_SHAPE_N = 8;
- int B_ROWS_PER_MMA = 8 * K;
- int const elts_in_int32 = 32 / BITS_PER_ELT;
-
- int const num_vec_cols = num_cols / elts_in_int32;
-
- ORT_ENFORCE(num_rows % B_ROWS_PER_MMA == 0,
- "Invalid shape for quantized tensor. Number of rows of quantized matrix must be a multiple of ",
- B_ROWS_PER_MMA);
- ORT_ENFORCE(num_cols % MMA_SHAPE_N == 0,
- "Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of ",
- MMA_SHAPE_N);
-
- ORT_ENFORCE(size_t(B_ROWS_PER_MMA) == row_permutation.size(), "Unexpected number of LDSM rows permuted.");
-
- for (int expert = 0; expert < static_cast<int>(num_experts); ++expert) {
- const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols);
- for (int base_row = 0; base_row < static_cast<int>(num_rows); base_row += B_ROWS_PER_MMA) {
- for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
- for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
- int const write_row = base_row + tile_row;
- int const tile_read_row = row_permutation[tile_row];
- int const read_row = base_row + tile_read_row;
- int const read_col = write_col;
-
- const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col;
- const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col;
-
- output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
- }
- }
- }
- }
-}
-
-// We need to use this transpose to correctly handle packed int4 and int8 data
-// The reason this code is relatively complex is that the "trivial" loops took a substantial
-// amount of time to transpose leading to long preprocessing times. This seemed to be a big
-// issue for relatively large models.
-template <QuantType quant_type>
-void subbyte_transpose_impl(
- int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector<size_t> const& shape) {
- ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__);
- constexpr int bits_per_elt = get_weight_quant_bits(quant_type);
-
- ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
- const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
- const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
- const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
-
- const size_t col_bytes = num_cols * bits_per_elt / 8;
- const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
-
- uint8_t const* input_byte_ptr = reinterpret_cast<uint8_t const*>(quantized_tensor);
- uint8_t* output_byte_ptr = reinterpret_cast<uint8_t*>(transposed_quantized_tensor);
-
- static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt;
-
- static constexpr int M_TILE_L1 = 64;
- static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
- uint8_t cache_buf[M_TILE_L1][N_TILE_L1];
-
- static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);
-
- // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples
- // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it
- // allows GCC to emit vector instructions.
- ORT_ENFORCE(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH),
- "Number of bytes for rows and cols must be a multiple of ", VECTOR_WIDTH, ". However, num_rows_bytes = ",
- col_bytes_trans, " and num_col_bytes = ", col_bytes);
-
- for (size_t expert = 0; expert < num_experts; ++expert) {
- const size_t matrix_offset = expert * num_rows * col_bytes;
- for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) {
- for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) {
- int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
- int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
-
- for (int ii = 0; ii < M_TILE_L1; ++ii) {
- int const row = row_tile_start + ii;
-
- for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
- int const col = col_tile_start_byte + jj;
-
- const size_t logical_src_offset = matrix_offset + row * col_bytes + col;
-
- if (row < row_limit && col < col_limit) {
- for (int v = 0; v < VECTOR_WIDTH; ++v) {
- cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
- }
- }
- }
- }
-
- if constexpr (bits_per_elt == 8) {
- for (int ii = 0; ii < M_TILE_L1; ++ii) {
- for (int jj = ii + 1; jj < N_TILE_L1; ++jj) {
- std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
- }
- }
- } else if constexpr (bits_per_elt == 4) {
- for (int ii = 0; ii < M_TILE_L1; ++ii) {
- // Using M_TILE_L1 here is deliberate since we assume that the cache tile
- // is square in the number of elements (not necessarily the number of bytes).
- for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
- int const ii_byte = ii / ELTS_PER_BYTE;
- int const ii_bit_offset = ii % ELTS_PER_BYTE;
-
- int const jj_byte = jj / ELTS_PER_BYTE;
- int const jj_bit_offset = jj % ELTS_PER_BYTE;
-
- uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
- uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
-
- cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
- cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));
-
- cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
- cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
- }
- }
- } else {
- ORT_THROW("Unsupported quantization type.");
- }
-
- const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
- const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
-
- int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols);
- int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
-
- for (int ii = 0; ii < M_TILE_L1; ++ii) {
- int const row = row_tile_start_trans + ii;
- for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
- int const col = col_tile_start_byte_trans + jj;
-
- const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col;
-
- if (row < row_limit_trans && col < col_limit_trans) {
- for (int v = 0; v < VECTOR_WIDTH; ++v) {
- output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
- }
- }
- }
- }
- }
- }
- }
-}
-
-void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor,
- std::vector<size_t> const& shape, QuantType quant_type) {
- ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__);
-
- if (quant_type == QuantType::W8_A16) {
- subbyte_transpose_impl<QuantType::W8_A16>(transposed_quantized_tensor, quantized_tensor, shape);
- } else if (quant_type == QuantType::W4_A16) {
- subbyte_transpose_impl<QuantType::W4_A16>(transposed_quantized_tensor, quantized_tensor, shape);
- } else if (quant_type == QuantType::W4_AFP8) {
- subbyte_transpose_impl<QuantType::W4_AFP8>(transposed_quantized_tensor, quantized_tensor, shape);
- } else {
- ORT_THROW("Invalid quant_type");
- }
-}
-
-void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num_elts) {
- for (size_t ii = 0; ii < num_elts; ++ii) {
- int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128);
- }
-
- // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no
- // performance benefit and is purely so that int4 and int8 have the same layout.
- // Pictorially, this does the following:
- // bit 32 0
- // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
- //
- // And it will rearrange the output 32 bit register to be the following:
- // bit 32 0
- // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
-
- ORT_ENFORCE(num_elts % 4 == 0, "Dimensions of int8 tensor must be a multiple of 4 for register relayout");
- for (size_t base = 0; base < num_elts; base += 4) {
- std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
- }
-}
-
-void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) {
- size_t const num_bytes = num_elts / 2;
-
- // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little
- // instructions as possible in the CUDA code.
- for (size_t ii = 0; ii < num_bytes; ++ii) {
- int8_t transformed_packed_int4s = 0;
- int8_t transformed_first_elt = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8; // The double shift here is to ensure sign extension
- int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8;
-
- ORT_ENFORCE(
- transformed_first_elt >= 0 && transformed_first_elt <= 15, "Illegal result for int4 transform (first elt)");
- ORT_ENFORCE(transformed_second_elt >= 0 && transformed_second_elt <= 15,
- "Illegal result for int4 transform (second elt)");
-
- // We don't need to mask in these ops since everything should be in the range 0-15
- transformed_packed_int4s |= transformed_first_elt;
- transformed_packed_int4s |= (transformed_second_elt << 4);
- packed_int4_tensor[ii] = transformed_packed_int4s;
- }
-
- // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical
- // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the
- // following: Take as input a 32 bit register with layout: bit 32 0
- // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits)
- //
- // And it will rearrange the output 32 bit register to be the following:
- // bit 32 0
- // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits)
-
- ORT_ENFORCE(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout");
- const size_t num_registers = num_bytes / 4;
-
- uint32_t* register_ptr = reinterpret_cast<uint32_t*>(packed_int4_tensor);
- for (size_t ii = 0; ii < num_registers; ++ii) {
- const uint32_t current_register = register_ptr[ii];
- uint32_t transformed_register = 0;
-
- for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
- int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
- int const src_shift = 4 * src_idx;
- int const dest_shift = 4 * dest_idx;
-
- const uint32_t src_bits = (current_register >> src_shift) & 0xF;
- transformed_register |= (src_bits << dest_shift);
- }
- register_ptr[ii] = transformed_register;
- }
-}
-
-void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) {
- ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__);
- if (quant_type == QuantType::W8_A16) {
- add_bias_and_interleave_int8s_inplace(tensor, num_elts);
- } else if (quant_type == QuantType::W4_A16 || quant_type == QuantType::W4_AFP8) {
- // W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must
- // be converted to FP16 before the scales can be applied using CUDA cores.
- // As a result, we still want to permute the data so that it is well aligned
- // for conversion to FP16.
- add_bias_and_interleave_int4s_inplace(tensor, num_elts);
- } else {
- ORT_THROW("Invalid quantization type for interleaving.");
- }
-}
-
-void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor,
- std::vector<size_t> const& shape, QuantType quant_type, LayoutDetails details) {
- ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__);
-
- ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
- const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
- const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
- const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
-
- int const BITS_PER_ELT = get_weight_quant_bits(quant_type);
- int const elts_in_int32 = 32 / BITS_PER_ELT;
-
- int const rows_per_tile = details.rows_per_column_tile;
-
- ORT_ENFORCE(!(num_rows % elts_in_int32),
- "The number of rows must be a multiple of ", elts_in_int32, " but the number of rows is ", num_rows);
-
- uint32_t const* input_byte_ptr = reinterpret_cast<uint32_t const*>(quantized_tensor);
- uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(interleaved_quantized_tensor);
-
- ORT_ENFORCE(!(num_rows % rows_per_tile),
- "The number of rows must be a multiple of ", rows_per_tile, " but the number of rows is ", num_rows);
-
- int const num_vec_rows = num_rows / elts_in_int32;
- int const vec_rows_per_tile = rows_per_tile / elts_in_int32;
- int const interleave = details.columns_interleaved;
-
- for (int expert = 0; expert < static_cast<int>(num_experts); ++expert) {
- const int64_t matrix_offset = expert * int64_t(num_vec_rows) * int64_t(num_cols);
- for (int64_t read_col = 0; read_col < static_cast<int64_t>(num_cols); ++read_col) {
- const int64_t write_col = read_col / interleave;
- for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) {
- for (int vec_read_row = base_vec_row;
- vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); ++vec_read_row) {
- const int64_t vec_write_row = interleave * base_vec_row + vec_rows_per_tile * (read_col % interleave) + vec_read_row % vec_rows_per_tile;
-
- const int64_t read_offset = matrix_offset + read_col * num_vec_rows + vec_read_row;
- const int64_t write_offset = matrix_offset + int64_t(write_col) * num_vec_rows * interleave + vec_write_row;
- output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
- }
- }
- }
- }
-}
-
-void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight,
- std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave) {
- int arch = getSMVersion();
- if (force_interleave && arch >= 90) {
- // Workaround for MOE which doesn't have specialized Hopper/Blackwell kernels yet
- arch = 80;
- }
- // Force use sm80 kernel for GB20x.
- if (arch >= 100) {
- arch = 80;
- }
- LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);
-
- ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
-
- size_t num_elts = 1;
- for (auto const& dim : shape) {
- num_elts *= dim;
- }
-
- const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8;
-
- std::vector<int8_t> src_buf(num_bytes);
- std::vector<int8_t> dst_buf(num_bytes);
- std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin());
-
- // Works on row major data, so issue this permutation first.
- if (details.uses_imma_ldsm) {
- permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type);
- src_buf.swap(dst_buf);
- }
-
- if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) {
- subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type);
- src_buf.swap(dst_buf);
- }
-
- if (details.columns_interleaved > 1 && arch != 90) {
- interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details);
- src_buf.swap(dst_buf);
- }
-
- add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
- std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
-}
-
-/*
- Arguments:
- input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16.
-
- quant_type - the type of the output quantization weight.
-
- This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the
- zero-point is zero and will automatically construct the scales.
-
- It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is
- viewed as a stack of matrices and a scale is produced for each column of every matrix.
-
-Outputs
- processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM
- unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking.
- scale_ptr - scales for the quantized weight.
-
- Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data
- layout may not make sense if printed.
-
- Shapes:
- quant_type == int8:
- If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n]
- If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n]
- quant_type == int4:
- If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n]
- If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape
- [b,n]
-
- The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the
- reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind
- of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors
- must have a dimension of 1, which breaks the semantics we need for batched weights.
- */
-
-template <typename ComputeType, typename WeightType>
-void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
- ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector<size_t> const& shape, QuantType quant_type,
- bool force_interleave) {
- ORT_ENFORCE(processed_quantized_weight, "Processed quantized tensor is NULL");
- ORT_ENFORCE(scale_ptr, "Scale output pointer is NULL");
- ORT_ENFORCE(input_weight_ptr, "Input weight pointer is NULL");
-
- ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
- const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
- const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
- const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
-
- int const bits_in_type = get_weight_quant_bits(quant_type);
- int const bytes_per_out_col = num_cols * bits_in_type / 8;
-
- int const bits_per_weigtht_element = get_weight_quant_bits(quant_type);
-
- std::vector<int8_t> weight_buf;
- if (unprocessed_quantized_weight == nullptr) {
- weight_buf.resize(num_experts * num_rows * num_cols);
- unprocessed_quantized_weight = weight_buf.data();
- }
-
- int const input_mat_size = num_rows * num_cols;
- int const quantized_mat_size = num_rows * bytes_per_out_col;
- float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1));
-
- std::vector<float> per_col_max(num_cols);
-
- for (int expert = 0; expert < static_cast<int>(num_experts); ++expert) {
- WeightType const* current_weight = input_weight_ptr + expert * input_mat_size;
- int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;
-
- // First we find the per column max for this expert weight.
- for (size_t jj = 0; jj < num_cols; ++jj) {
- per_col_max[jj] = 0.f;
- }
-
- for (size_t ii = 0; ii < num_rows; ++ii) {
- WeightType const* current_weight_row = current_weight + ii * num_cols;
- for (size_t jj = 0; jj < num_cols; ++jj) {
- per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
- }
- }
-
- // Then, we construct the scales
- ComputeType* current_scales = scale_ptr + expert * num_cols;
- for (size_t jj = 0; jj < num_cols; ++jj) {
- per_col_max[jj] *= quant_range_scale;
- current_scales[jj] = ComputeType(per_col_max[jj]);
- }
-
- // Finally, construct the weights.
- for (size_t ii = 0; ii < num_rows; ++ii) {
- int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
- WeightType const* current_weight_row = current_weight + ii * num_cols;
- for (int jj = 0; jj < bytes_per_out_col; ++jj) {
- if (bits_per_weigtht_element == 8) {
- float const col_scale = per_col_max[jj];
- float const weight_elt = float(current_weight_row[jj]);
- float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f;
- const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
- current_quantized_weight_row[jj] = clipped_weight;
- } else if (bits_per_weigtht_element == 4) {
- // We will pack two int4 elements per iteration of the inner loop.
- int8_t packed_int4s = 0;
- for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
- int const input_idx = 2 * jj + packed_idx;
- if (input_idx < static_cast<int>(num_cols)) {
- float const col_scale = per_col_max[input_idx];
- float const weight_elt = float(current_weight_row[input_idx]);
- float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f;
- int int_weight = int(scaled_weight);
- const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));
-
- // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits
- // if packing the second int4 and or the bits into the final result.
- packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
- }
- }
- current_quantized_weight_row[jj] = packed_int4s;
- } else {
- ORT_THROW("Unsupported quantization type");
- }
- }
- }
- }
-
- preprocess_weights_for_mixed_gemm(
- processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave);
-}
-
-template void symmetric_quantize<half, float>(
- int8_t*, int8_t*, half*, float const*, std::vector<size_t> const&, QuantType, bool);
-
-template void symmetric_quantize<half, half>(
- int8_t*, int8_t*, half*, half const*, std::vector<size_t> const&, QuantType, bool);
-
-template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
- int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);
-
-template void symmetric_quantize<__nv_bfloat16, float>(
- int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector<size_t> const&, QuantType, bool);
-
-template <typename ComputeType, typename WeightType>
-void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr,
- std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave) {
- symmetric_quantize(
- processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave);
-}
-
-template void symmetric_quantize<float, float>(
- int8_t*, float*, float const*, std::vector<size_t> const&, QuantType, bool);
-
-template void symmetric_quantize<half, float>(
- int8_t*, half*, float const*, std::vector<size_t> const&, QuantType, bool);
-
-template void symmetric_quantize<half, half>(int8_t*, half*, half const*, std::vector<size_t> const&, QuantType, bool);
-
-template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
- int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);
-
-template void symmetric_quantize<__nv_bfloat16, half>(
- int8_t*, __nv_bfloat16*, half const*, std::vector<size_t> const&, QuantType, bool);
-
-template void symmetric_quantize<half, __nv_bfloat16>(
- int8_t*, half*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);
-
-template void symmetric_quantize<__nv_bfloat16, float>(
- int8_t*, __nv_bfloat16*, float const*, std::vector<size_t> const&, QuantType, bool);
-
-} // namespace cutlass_kernels
-} // namespace kernels
-} // namespace onnxruntime::llm
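
For orientation, the removed symmetric_quantize computes one scale per output column (per-column max divided by 2^(bits-1)) and then rounds and clamps each element against that scale, exactly as its doc comment describes. A minimal standalone sketch of the int8 path follows; the function name and plain-float interface are illustrative and are not part of the removed API:

// Sketch only: per-column symmetric int8 quantization mirroring the removed routine.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

inline void quantize_columns_int8(const float* w, size_t rows, size_t cols,
                                  std::vector<float>& scales, std::vector<int8_t>& q) {
  scales.assign(cols, 0.f);
  q.assign(rows * cols, 0);
  // Per-column absolute maximum.
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      scales[j] = std::max(scales[j], std::abs(w[i * cols + j]));
    }
  }
  // Full symmetric range: map the column max onto 2^(8-1) = 128.
  for (size_t j = 0; j < cols; ++j) {
    scales[j] /= 128.f;
  }
  // Round, then clamp to [-128, 127].
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      const float s = scales[j];
      const float v = (s != 0.f) ? std::round(w[i * cols + j] / s) : 0.f;
      q[i * cols + j] = static_cast<int8_t>(std::max(-128.f, std::min(127.f, v)));
    }
  }
}
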
diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h
deleted file mode 100644
index 3e83852228e24..0000000000000
--- a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <cstddef>
-#include <cstdint>
-#include <vector>
-
-#include "core/common/common.h"
-
-namespace onnxruntime::llm {
-namespace kernels {
-namespace cutlass_kernels {
-
-enum class QuantType {
- W8_A16,
- W4_A16,
- W4_AFP8
-};
-
-constexpr int get_weight_quant_bits(QuantType quant_type) {
- switch (quant_type) {
- case QuantType::W8_A16:
- return 8;
- case QuantType::W4_A16:
- return 4;
- case QuantType::W4_AFP8:
- return 4;
- default:
- ORT_THROW("Invalid quant_type");
- return -1;
- }
-}
-
-// Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols]
-// 3-D shapes are [num_experts, num_rows, num_cols]
-void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
- std::vector<size_t> const& shape, QuantType quant_type);
-
-void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor,
- std::vector<size_t> const& shape, QuantType quant_type);
-
-void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type);
-
-void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight,
- std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave = false);
-
-template <typename ComputeType, typename WeightType>
-void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr,
- std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave);
-
-// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight
-// to implement a simple reference implementation.
-template <typename ComputeType, typename WeightType>
-void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
- ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector<size_t> const& shape, QuantType quant_type,
- bool force_interleave);
-
-} // namespace cutlass_kernels
-} // namespace kernels
-} // namespace onnxruntime::llm
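
For reference, a hedged sketch of how a caller would have driven the removed declarations to quantize an fp16 [k, n] weight to packed int4 with per-column scales; k, n, and weights_fp16 are placeholder variables, and buffer sizing follows the header's get_weight_quant_bits contract:

// Illustrative usage of the removed API (not ORT code as-is).
std::vector<size_t> shape = {k, n};
const size_t packed_bytes = k * n * get_weight_quant_bits(QuantType::W4_A16) / 8;  // two int4 values per byte
std::vector<int8_t> processed(packed_bytes);
std::vector<half> scales(n);
symmetric_quantize<half, half>(processed.data(), scales.data(), weights_fp16,
                               shape, QuantType::W4_A16, /*force_interleave=*/false);
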
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu
new file mode 100644
index 0000000000000..de834db4b7440
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
+
+namespace onnxruntime::llm {
+namespace kernels {
+namespace cutlass_kernels {
+#ifdef ENABLE_BF16
+template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>;
+#endif
+} // namespace cutlass_kernels
+} // namespace kernels
+} // namespace onnxruntime::llm
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu
new file mode 100644
index 0000000000000..97c71615ce54d
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
+
+namespace onnxruntime::llm {
+namespace kernels {
+namespace cutlass_kernels {
+#ifdef ENABLE_BF16
+template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>;
+#endif
+} // namespace cutlass_kernels
+} // namespace kernels
+} // namespace onnxruntime::llm
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scaleonly.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scaleonly.cu
new file mode 100644
index 0000000000000..5905f48b9b479
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scaleonly.cu
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
+
+namespace onnxruntime::llm {
+namespace kernels {
+namespace cutlass_kernels {
+template class CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>;
+} // namespace cutlass_kernels
+} // namespace kernels
+} // namespace onnxruntime::llm
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scaleonly.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scaleonly.cu
new file mode 100644
index 0000000000000..aa3e984ab2945
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scaleonly.cu
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
+
+namespace onnxruntime::llm {
+namespace kernels {
+namespace cutlass_kernels {
+template class CutlassFpAIntBGemmRunner<half, uint8_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>;
+} // namespace cutlass_kernels
+} // namespace kernels
+} // namespace onnxruntime::llm
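
Each of the new *_scaleonly.cu files adds exactly one explicit instantiation, presumably so the CUTLASS compile cost of every (activation type, weight type, quant-op) combination stays isolated in its own translation unit. A hedged sketch of the runner type a kernel would name to hit the fp16 x int4, scale-only path; the alias name is illustrative:

// Illustrative: the runner matching the instantiation in fp16_int4_gemm_scaleonly.cu.
using ScaleOnlyFp16Int4Runner = onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<
    half, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>;
ScaleOnlyFp16Int4Runner runner;  // constructed and invoked like the existing scale-and-zeros runners
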
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu
index 468d53f336e55..ba513a831b432 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu
@@ -70,6 +70,69 @@ half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGe
);
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
@@ -133,6 +196,69 @@ half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGe
);
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
@@ -196,6 +322,69 @@ __nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::
);
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
@@ -259,6 +448,69 @@ __nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::
);
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> (
+const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
} // namespace cutlass_kernels
} // namespace kernels
} // namespace onnxruntime::llm
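
The instantiation lines appended above lost their leading template arguments during extraction. For orientation, one of them has roughly the following shape: the activation/weight types, quant op, and kernel/epilogue schedules come from the surrounding lines, while the leading CTA tile extent shown here is only a placeholder:

// Approximate shape of one generated sm90 instantiation; the first tile extent is illustrative.
template void sm90_generic_mixed_gemm_kernelLauncher<half, cutlass::uint4b_t, half, half, half,
    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
    cute::Shape<cute::Int<64>, cute::Int<16>, cute::Int<64>>, cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>,
    cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized>(
    const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
    half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*);
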
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu
index 0156c83840b09..6c6318fa6c589 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu
@@ -133,6 +133,132 @@ half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGe
);
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
@@ -259,6 +385,132 @@ half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGe
);
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const half*, const uint8_t*, const half*, const half*, const half*, const float,
+half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
@@ -385,6 +637,132 @@ __nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::
);
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>,
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> (
+const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float,
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
+);
+
+
+template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16,
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, onnxruntime::llm::cutlass_extensions::EpilogueOpBias,
+cute::Shape