diff --git a/.github/actions/locate-vcvarsall-and-setup-env/action.yml b/.github/actions/locate-vcvarsall-and-setup-env/action.yml index 3066721e797ea..bf1016bf2265b 100644 --- a/.github/actions/locate-vcvarsall-and-setup-env/action.yml +++ b/.github/actions/locate-vcvarsall-and-setup-env/action.yml @@ -14,7 +14,7 @@ runs: steps: - name: Setup VCPKG - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 69ff9a1cec976..092b6fc8f5ce5 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -37,7 +37,7 @@ jobs: ndk-version: 28.0.13004108 - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -131,7 +131,7 @@ jobs: architecture: x64 - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index e53626d879dd1..fc9bb53659442 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -42,7 +42,7 @@ jobs: with: python-version: "3.12" architecture: ${{ env.buildArch }} - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.03.19' vcpkg-hash: '17e96169cd3f266c4716fcdc1bb728e6a64f103941ece463a2834d50694eba4fb48f30135503fd466402afa139abc847ef630733c442595d1c34979f261b0114' @@ -57,29 +57,17 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Install EMSDK + - name: Build (simd + threads) run: | - set -ex - cd ${{ github.workspace }}/cmake/external/emsdk - ./emsdk install 4.0.4 - ./emsdk activate 4.0.4 - - - name: Build and test (browser) (simd + threads) - run: | - set -e -x - source ${{ github.workspace }}/cmake/external/emsdk/emsdk_env.sh - cd '${{ github.workspace }}' python ./tools/ci_build/build.py \ ${{ env.common_build_args }} \ --build_dir ${{ github.workspace }}/build/wasm_inferencing \ - --wasm_run_tests_in_browser + --skip_tests + working-directory: ${{ github.workspace }} - name: Build (simd + threads + JSEP) if: ${{ inputs.build_jsep == true }} run: | - set -e -x - source ${{ github.workspace }}/cmake/external/emsdk/emsdk_env.sh - cd '${{ github.workspace }}' python ./tools/ci_build/build.py \ ${{ env.common_build_args }} \ --build_dir ${{ github.workspace }}/build/wasm_inferencing_jsep \ @@ -87,13 +75,11 @@ jobs: --use_webnn \ --target onnxruntime_webassembly \ --skip_tests + working-directory: ${{ github.workspace }} - name: Build (simd + 
threads + WebGPU experimental) if: ${{ inputs.build_webgpu == true }} run: | - set -e -x - source ${{ github.workspace }}/cmake/external/emsdk/emsdk_env.sh - cd '${{ github.workspace }}' python ./tools/ci_build/build.py \ ${{ env.common_build_args }} \ --build_dir ${{ github.workspace }}/build/wasm_inferencing_webgpu \ @@ -102,6 +88,7 @@ jobs: --use_webnn \ --target onnxruntime_webassembly \ --skip_tests + working-directory: ${{ github.workspace }} - name: Create Artifacts if: ${{ inputs.skip_publish != true }} @@ -135,6 +122,28 @@ jobs: name: ${{ inputs.build_config }}_wasm_webgpu path: ${{ github.workspace }}/artifacts/wasm_webgpu + - name: Test (Node.js) (simd + threads) + # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. + if: ${{ inputs.build_config == 'Debug' }} + run: | + python ./tools/ci_build/build.py \ + ${{ env.common_build_args }} \ + --build_dir ${{ github.workspace }}/build/wasm_inferencing \ + --test + working-directory: ${{ github.workspace }} + + - name: Test (browser) (simd + threads) + # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. + if: ${{ inputs.build_config == 'Debug' }} + run: | + python ./tools/ci_build/build.py \ + ${{ env.common_build_args }} \ + --build_dir ${{ github.workspace }}/build/wasm_inferencing \ + --wasm_run_tests_in_browser \ + --target onnxruntime_test_all \ + --update --build --test + working-directory: ${{ github.workspace }} + - name: Publish test results if: ${{ always() && inputs.build_config == 'Debug' }} uses: actions/upload-artifact@v4 diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 0dbe63371c7b8..38526e7a5c00f 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -50,7 +50,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -93,7 +93,7 @@ jobs: # So build.py --build_dir build/Release inside the container correctly finds the artifacts. 
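
The restructured WebAssembly jobs above split building from testing: every build step now passes --skip_tests, and only Debug builds re-invoke build.py against the same --build_dir to run onnxruntime_test_all under Node.js (--test) or in a browser (--wasm_run_tests_in_browser), since the test binary needs exceptions, which Release builds disable. The following Python sketch is a rough local approximation of that two-phase flow, not part of the workflow itself; COMMON_ARGS stands in for the workflow's common_build_args, which is defined elsewhere in the file and not reproduced here.

# Sketch only: build once with --skip_tests, then re-run build.py against the
# same --build_dir to execute the tests, mirroring the CI steps above.
import subprocess
import sys

COMMON_ARGS = ["--config", "Debug"]  # placeholder for env.common_build_args (assumption)
BUILD_DIR = "build/wasm_inferencing"

def run_build_py(*extra):
    cmd = [sys.executable, "tools/ci_build/build.py", *COMMON_ARGS,
           "--build_dir", BUILD_DIR, *extra]
    subprocess.run(cmd, check=True)

run_build_py("--skip_tests")                    # Build (simd + threads)
run_build_py("--test")                          # Test (Node.js), Debug only
run_build_py("--wasm_run_tests_in_browser",     # Test (browser), Debug only
             "--target", "onnxruntime_test_all",
             "--update", "--build", "--test")
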
- name: Test ONNX Runtime id: test_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: Release diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index e68ef56cdb1ce..5f90d9430342e 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -43,7 +43,7 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' @@ -53,7 +53,7 @@ jobs: disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-and-prep-ort-files@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-and-prep-ort-files@v0.0.7 - name: Upload Test Data Artifact uses: actions/upload-artifact@v4 @@ -80,7 +80,7 @@ jobs: node-version: 20 - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -98,7 +98,7 @@ jobs: core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Run Build 2 (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -113,7 +113,7 @@ jobs: --enable_training_ops - name: Run Build 2 (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -151,7 +151,7 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' @@ -161,7 +161,7 @@ jobs: disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.7 with: reduced-ops-config-file: required_ops.ort_models.config enable-custom-ops: 'true' @@ -191,7 +191,7 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 
+ - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' @@ -200,7 +200,7 @@ jobs: add-cmake-to-path: 'true' disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.7 with: reduced-ops-config-file: required_ops_and_types.ort_models.config enable-type-reduction: 'true' @@ -229,7 +229,7 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 with: vcpkg-version: '2025.04.09' vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' @@ -239,7 +239,7 @@ jobs: disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.7 with: globally_allowed_types: 'bool,float,int8_t,uint8_t' enable-type-reduction: 'true' @@ -264,7 +264,7 @@ jobs: node-version: 20 - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -282,7 +282,7 @@ jobs: core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Run Build 5 (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -295,7 +295,7 @@ jobs: --minimal_build extended - name: Run Build 5 (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -307,7 +307,7 @@ jobs: --use_binskim_compliant_compile_flags --minimal_build extended - name: Run Build 5 (Test) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -334,7 +334,7 @@ jobs: submodules: false - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -358,7 +358,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Run Build 6a (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: 
microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -374,7 +374,7 @@ jobs: --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - name: Run Build 6a (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -391,7 +391,7 @@ jobs: - name: Run Build 6a (Test) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -427,7 +427,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -445,7 +445,7 @@ jobs: core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Run Build 6b (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -464,7 +464,7 @@ jobs: --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - name: Run Build 6b (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -503,7 +503,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -526,7 +526,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Run Build 6c (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -545,7 +545,7 @@ jobs: --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - name: Run Build 6c (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -588,7 +588,7 @@ jobs: path: ${{ runner.temp }}/.test_data/ - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 
405de75e95454..1df467043329a 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -52,7 +52,7 @@ jobs: # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -95,7 +95,7 @@ jobs: # So build.py --build_dir build/Release inside the container correctly finds the artifacts. - name: Test ONNX Runtime id: test_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: Release diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index 27595254800f9..af24e3a3d901a 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -83,7 +83,7 @@ jobs: python-version: ${{ inputs.python_version }} - name: Build Docker Image (${{ inputs.architecture }} / ${{ inputs.build_config }}) - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/${{ inputs.dockerfile_path }} @@ -103,7 +103,7 @@ jobs: # ------------- Update Step (CMake Generation) ------------- - name: Generate Build Files (CMake) (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: update_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -115,7 +115,7 @@ jobs: # ------------- Build Step (Compilation) ------------- - name: Build ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: build_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -128,7 +128,7 @@ jobs: - name: Test ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: test_step if: inputs.run_tests == true - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 999025f560674..70e8ea7e2792f 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -19,6 +19,9 @@ jobs: webgpu_build_x64_RelWithDebInfo: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] timeout-minutes: 300 + strategy: + matrix: + vcpkg_option: [novcpkg, vcpkg] env: OrtPackageId: Microsoft.ML.OnnxRuntime OnnxRuntimeBuildDirectory: ${{ github.workspace }} @@ -107,7 +110,23 @@ jobs: - name: Build and Test shell: pwsh run: | - python.exe ${{ 
github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --build_dir ${{ github.workspace }} --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_nodejs --use_webgpu --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=ON + python.exe ${{ github.workspace }}\tools\ci_build\build.py ` + --config RelWithDebInfo ` + --build_dir ${{ github.workspace }} ` + --skip_submodule_sync ` + --build_csharp ` + --parallel ` + --use_binskim_compliant_compile_flags ` + --cmake_generator "Visual Studio 17 2022" ` + --build_shared_lib ` + --enable_onnx_tests ` + --build_nodejs ` + --build_java ` + --use_webgpu ` + ${{ matrix.vcpkg_option == 'vcpkg' && '--use_vcpkg' || '' }} ` + --cmake_extra_defines ` + onnxruntime_BUILD_UNIT_TESTS=ON ` + onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=ON if ($lastExitCode -ne 0) { exit $lastExitCode } diff --git a/.gitmodules b/.gitmodules index 7656fc429d005..b5bff01d89850 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ [submodule "cmake/external/emsdk"] path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git - branch = 4.0.4 + branch = 4.0.8 diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index f29857a231eb9..bf889e9fb61a8 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "bee4d1dd8dc1ee4a1fd8fa6a96476c2f8b7492a3", + "commitHash": "5c210da409e7f1e51ddf445134a4376fdbd70d7d", "repositoryUrl": "https://github.com/dmlc/dlpack.git" } } @@ -316,16 +316,6 @@ "comments": "gtest-ios-framework" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c", - "repositoryUrl": "https://github.com/dmlc/dlpack.git" - }, - "comments": "dlpack" - } - }, { "component": { "Type": "other", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 121799e16ee97..416ed5e49f25a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -258,7 +258,6 @@ option(onnxruntime_USE_OPENVINO_INTERFACE "Build ONNXRuntime shared lib which is option(onnxruntime_USE_VITISAI_INTERFACE "Build ONNXRuntime shared lib which is compatible with Vitis-AI EP interface" OFF) option(onnxruntime_USE_QNN_INTERFACE "Build ONNXRuntime shared lib which is compatible with QNN EP interface" OFF) - if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 11.1) message(FATAL_ERROR "GCC version must be greater than or equal to 11.1") endif() @@ -859,6 +858,10 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu) set(ORT_PROVIDER_FLAGS) if (onnxruntime_USE_CUDA) + include(cuda_configuration) + setup_cuda_compiler() + setup_cuda_architectures() + enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") @@ -878,9 +881,6 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) endif() - if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) - message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") - endif() if (WIN32) message( STATUS "Lean Attention unsupported in Windows") set(onnxruntime_USE_LEAN_ATTENTION OFF) @@ -1066,6 +1066,34 @@ endif() if (onnxruntime_USE_WEBGPU) list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu) + + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure 
out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (FALSE) + if (NOT onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY=OFF") + endif() + if (onnxruntime_USE_EXTERNAL_DAWN) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_USE_EXTERNAL_DAWN=ON") + endif() + if (onnxruntime_CUSTOM_DAWN_SRC_PATH) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with a custom dawn source path") + endif() + if (WIN32) + if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_ENABLE_DAWN_BACKEND_VULKAN=ON") + endif() + if (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_ENABLE_DAWN_BACKEND_D3D12=OFF") + endif() + if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP) + message(FATAL_ERROR "onnxruntime_USE_VCPKG is not supported with onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP=ON") + endif() + endif() + endif() + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) list(APPEND ORT_PROVIDER_FLAGS -DBUILD_DAWN_MONOLITHIC_LIBRARY=1) endif() @@ -1562,25 +1590,17 @@ if (onnxruntime_USE_CUDA) file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME}) endif() find_package(CUDAToolkit REQUIRED) - if (NOT CMAKE_CUDA_ARCHITECTURES) - # Note that we generate SASS+PTX code for specified cuda architectures by assigning "xy" - # To add SASS only, assign "xy-real" - # To add PTX only, assign "xy-virtual" - if (CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu") - # Support for Jetson/Tegra ARM devices - set(CMAKE_CUDA_ARCHITECTURES "53-real;62-real;72-real;87") # TX1/Nano, TX2, Xavier, Orin - else() - if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) - # 37, 50 still work in CUDA 11 but are marked deprecated and will be removed in future CUDA version. 
- set(CMAKE_CUDA_ARCHITECTURES "37-real;50-real;52-real;60-real;70-real;75-real;80-real;86-real;89") - elseif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8) - set(CMAKE_CUDA_ARCHITECTURES "52-real;60-real;70-real;75-real;80-real;86-real;89-real;90") - else() - # https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html - set(CMAKE_CUDA_ARCHITECTURES "all") # Supporting all, including latest Blackwell B series & RTX 50 series - endif() - endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.8) + add_definitions("-DENABLE_FP8") + message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_FP8 flag") endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) + add_definitions("-DENABLE_FP4") + message(STATUS "CUDA Toolkit version is greater or equal than 12.8, enable -DENABLE_FP4 flag") + endif() + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xfatbin=-compress-all") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror default-stream-launch") diff --git a/cmake/deps.txt b/cmake/deps.txt index a10bede254007..6e045f6dcdc9d 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -16,13 +16,15 @@ abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240722.0.zip coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 -dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 +dlpack;https://github.com/dmlc/dlpack/archive/5c210da409e7f1e51ddf445134a4376fdbd70d7d.zip;e499c86e4e5c5268a87661d7ea39c27fae10907c # This Eigen commit id matches the eigen archive being consumed from https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip # prior to the 3.4.1 RC changing the bits and invalidating the hash. # it contains changes on top of 3.4.0 which are required to fix build issues. # Until the 3.4.1 release this is the best option we have. # Issue link: https://gitlab.com/libeigen/eigen/-/issues/2744 -eigen;https://gitlab.com/libeigen/eigen/-/archive/1d8b82b0740839c0de7f1242a3585e3390ff5f33/eigen-1d8b82b0740839c0de7f1242a3585e3390ff5f33.zip;5ea4d05e62d7f954a46b3213f9b2535bdd866803 +# Moved to github mirror to avoid gitlab issues. 
+# Issue link: https://github.com/bazelbuild/bazel-central-registry/issues/4355 +eigen;https://github.com/eigen-mirror/eigen/archive/1d8b82b0740839c0de7f1242a3585e3390ff5f33/eigen-1d8b82b0740839c0de7f1242a3585e3390ff5f33.zip;05b19b49e6fbb91246be711d801160528c135e34 flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip;59422c3b5e573dd192fead2834d25951f1c1670c fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 @@ -52,7 +54,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/8a1772a0c5c447df2d18e re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 -cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.zip;e49b2b964163d27765a5002d210a2f3c73771835 +cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.9.2.zip;b7f8dc4a879765127ce31dfeabd31c556c80ec79 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0c12f53da76d0c31b03b9f0f8ec8f3b4.zip;239063aee4946a9af147b473a4c3da78ba7413b4 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 5cfb9e78b4720..488df5a4e0de8 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -15,7 +15,9 @@ set(ABSL_USE_EXTERNAL_GOOGLETEST ON) if (onnxruntime_USE_XNNPACK) set(ABSL_ENABLE_INSTALL OFF) else() - set(ABSL_ENABLE_INSTALL ON) + if (NOT CMAKE_SYSTEM_NAME MATCHES "AIX") + set(ABSL_ENABLE_INSTALL ON) + endif() endif() if(Patch_FOUND AND WIN32) diff --git a/cmake/external/cuda_configuration.cmake b/cmake/external/cuda_configuration.cmake new file mode 100644 index 0000000000000..ef94ec25132e3 --- /dev/null +++ b/cmake/external/cuda_configuration.cmake @@ -0,0 +1,172 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+#
+
+macro(setup_cuda_compiler)
+  # Determine CUDA version before enabling the language extension check_language(CUDA) clears CMAKE_CUDA_HOST_COMPILER
+  # if CMAKE_CUDA_COMPILER is not set
+  include(CheckLanguage)
+  if(NOT CMAKE_CUDA_COMPILER AND CMAKE_CUDA_HOST_COMPILER)
+    set(CMAKE_CUDA_HOST_COMPILER_BACKUP ${CMAKE_CUDA_HOST_COMPILER})
+  endif()
+  check_language(CUDA)
+  if(CMAKE_CUDA_HOST_COMPILER_BACKUP)
+    set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CUDA_HOST_COMPILER_BACKUP})
+    check_language(CUDA)
+  endif()
+  if(CMAKE_CUDA_COMPILER)
+    message(STATUS "CUDA compiler: ${CMAKE_CUDA_COMPILER}")
+    if(NOT WIN32) # Linux
+      execute_process(
+        COMMAND "bash" "-c" "${CMAKE_CUDA_COMPILER} --version | grep -E -o 'V[0-9]+.[0-9]+.[0-9]+' | cut -c2-"
+        RESULT_VARIABLE _BASH_SUCCESS
+        OUTPUT_VARIABLE CMAKE_CUDA_COMPILER_VERSION
+        OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+      if(NOT _BASH_SUCCESS EQUAL 0)
+        message(FATAL_ERROR "Failed to determine CUDA version")
+      endif()
+
+    else() # Windows
+      execute_process(
+        COMMAND ${CMAKE_CUDA_COMPILER} --version
+        OUTPUT_VARIABLE versionString
+        RESULT_VARIABLE versionResult)
+
+      if(versionResult EQUAL 0 AND versionString MATCHES "V[0-9]+\\.[0-9]+\\.[0-9]+")
+        string(REGEX REPLACE "V" "" version ${CMAKE_MATCH_0})
+        set(CMAKE_CUDA_COMPILER_VERSION "${version}")
+      else()
+        message(FATAL_ERROR "Failed to determine CUDA version")
+      endif()
+    endif()
+  else()
+    message(FATAL_ERROR "No CUDA compiler found")
+  endif()
+
+  set(CUDA_REQUIRED_VERSION "11.4")
+  if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS CUDA_REQUIRED_VERSION)
+    message(FATAL_ERROR "CUDA version ${CMAKE_CUDA_COMPILER_VERSION} must be at least ${CUDA_REQUIRED_VERSION}")
+  endif()
+endmacro()
+
+macro(setup_cuda_architectures)
+  # cmake-format: off
+  # Initialize and normalize CMAKE_CUDA_ARCHITECTURES before enabling CUDA.
+  # Special values:
+  # (1) `native` is resolved to HIGHEST available architecture. Fallback to `all` if detection failed.
+  # (2) `all` / `all-major` / unset is resolved to a default set of architectures we optimized and compiler supports.
+  # Numerical architectures:
+  # * For `-virtual` architectures, the last one is kept as it is, and the others are ignored.
+  # * `-real` suffix is automatically added for other cases.
+  # * Always use accelerated (`-a` suffix) target for supported real architectures.
+  # cmake-format: on
+
+  if(CMAKE_CUDA_ARCHITECTURES STREQUAL "native")
+    # Detect highest available compute capability
+    set(OUTPUTFILE ${PROJECT_BINARY_DIR}/detect_cuda_arch)
+    set(CUDAFILE ${CMAKE_SOURCE_DIR}/utils/detect_cuda_arch.cu)
+    execute_process(COMMAND ${CMAKE_CUDA_COMPILER} -lcuda ${CUDAFILE} -o ${OUTPUTFILE})
+    message(VERBOSE "Detecting native CUDA compute capability")
+    execute_process(
+      COMMAND ${OUTPUTFILE}
+      RESULT_VARIABLE CUDA_RETURN_CODE
+      OUTPUT_VARIABLE CUDA_ARCH_OUTPUT)
+    if(NOT ${CUDA_RETURN_CODE} EQUAL 0)
+      message(WARNING "Detecting native CUDA compute capability - fail")
+      message(WARNING "CUDA compute capability detection failed, compiling for all optimized architectures")
+      unset(CMAKE_CUDA_ARCHITECTURES)
+    else()
+      message(STATUS "Detecting native CUDA compute capability - done")
+      set(CMAKE_CUDA_ARCHITECTURES "${CUDA_ARCH_OUTPUT}")
+    endif()
+  elseif(CMAKE_CUDA_ARCHITECTURES STREQUAL "all")
+    unset(CMAKE_CUDA_ARCHITECTURES)
+    message(STATUS "Setting CMAKE_CUDA_ARCHITECTURES to all enables a list of architectures OnnxRuntime optimized for, "
+                   "not all architectures CUDA compiler supports.")
+  elseif(CMAKE_CUDA_ARCHITECTURES STREQUAL "all-major")
+    unset(CMAKE_CUDA_ARCHITECTURES)
+    message(
+      STATUS "Setting CMAKE_CUDA_ARCHITECTURES to all-major enables a list of architectures OnnxRuntime optimized for, "
+             "not all major architectures CUDA compiler supports.")
+  else()
+    message(STATUS "Original CMAKE_CUDA_ARCHITECTURES : ${CMAKE_CUDA_ARCHITECTURES}")
+  endif()
+
+  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+    if(CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu")
+      # Support for Jetson/Tegra ARM devices
+      set(CMAKE_CUDA_ARCHITECTURES "53;62;72;87") # TX1/Nano, TX2, Xavier, Orin
+    else()
+      if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
+        # 37, 50 still work in CUDA 11 but are marked deprecated and will be removed in future CUDA version.
+        set(CMAKE_CUDA_ARCHITECTURES "37;50;52;60;70;75;80;86;89")
+      elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8)
+        set(CMAKE_CUDA_ARCHITECTURES "52;60;70;75;80;86;89;90")
+      else()
+        set(CMAKE_CUDA_ARCHITECTURES "60;70;75;80;86;89;90;100;120")
+      endif()
+    endif()
+  endif()
+
+  unset(CMAKE_CUDA_ARCHITECTURES_CLEAN)
+  unset(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL)
+  foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
+    if(CUDA_ARCH STREQUAL "")
+      continue()
+    endif()
+
+    if(CUDA_ARCH MATCHES "^([1-9])([0-9])+a?-virtual$")
+      set(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL ${CUDA_ARCH})
+    elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?-real$")
+      list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1})
+    elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?$")
+      list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1})
+    else()
+      message(FATAL_ERROR "Unrecognized CUDA architecture: ${CUDA_ARCH}")
+    endif()
+  endforeach()
+  list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES_CLEAN)
+  set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_CLEAN})
+
+  # CMAKE_CUDA_ARCHITECTURES_ORIG contains all architectures enabled, without automatically added -real or -a suffix.
+  set(CMAKE_CUDA_ARCHITECTURES_ORIG "${CMAKE_CUDA_ARCHITECTURES}")
+  message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}")
+
+  set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "120")
+  foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
+    if(NOT "${CUDA_ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
+      add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}")
+      message(STATUS "Excluding SM ${CUDA_ARCH}")
+    endif()
+  endforeach()
+
+  # Enable accelerated features (like WGMMA, TMA and setmaxnreg) for SM >= 90.
+  set(ARCHITECTURES_WITH_ACCEL "90" "100" "101" "120")
+  unset(CMAKE_CUDA_ARCHITECTURES_NORMALIZED)
+  foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
+    if("${CUDA_ARCH}" IN_LIST ARCHITECTURES_WITH_ACCEL)
+      list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}a-real")
+    else()
+      list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}-real")
+    endif()
+  endforeach()
+
+  if(DEFINED CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL)
+    list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL}")
+  endif()
+
+  set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NORMALIZED})
+
+  message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
+endmacro()
diff --git a/cmake/external/emsdk b/cmake/external/emsdk
index 074211759c17c..419021fa04042 160000
--- a/cmake/external/emsdk
+++ b/cmake/external/emsdk
@@ -1 +1 @@
-Subproject commit 074211759c17c646164d3271ca1d155cc174f78e
+Subproject commit 419021fa040428bc69ef1559b325addb8e10211f
diff --git a/cmake/external/extensions.cmake b/cmake/external/extensions.cmake
index 8c00c1c8a530b..bd3c47d53f53d 100644
--- a/cmake/external/extensions.cmake
+++ b/cmake/external/extensions.cmake
@@ -69,7 +69,7 @@ set_target_properties(ortcustomops PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "")
 # target library or executable are defined in CMakeLists.txt of onnxruntime-extensions
 target_include_directories(ocos_operators PRIVATE ${RE2_INCLUDE_DIR} ${json_SOURCE_DIR}/include)
-target_include_directories(ortcustomops PUBLIC $)
+target_include_directories(ortcustomops PUBLIC $)
 if(OCOS_ENABLE_SPM_TOKENIZER)
   onnxruntime_add_include_to_target(sentencepiece-static ${PROTOBUF_LIB} ${ABSEIL_LIBS})
 endif()
diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index d967e806eb5a3..4f6bcc8c90419 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -132,6 +132,9 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE AND NOT onnxruntime_USE_VCPKG)
   elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86")
     onnxruntime_fetchcontent_declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32} EXCLUDE_FROM_ALL)
     FetchContent_Populate(protoc_binary)
+  elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "ARM64")
+    onnxruntime_fetchcontent_declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64} EXCLUDE_FROM_ALL)
+    FetchContent_Populate(protoc_binary)
   endif()
 
   if(protoc_binary_SOURCE_DIR)
@@ -625,118 +628,143 @@ endif()
 
 if (onnxruntime_USE_WEBGPU)
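
The cmake-format comment block in setup_cuda_architectures above spells out the normalization rules: only the last -virtual entry is kept, every other entry is reduced to its bare SM number, and the result is re-emitted as <sm>a-real for the accelerated architectures and <sm>-real otherwise. The same rules, expressed as a small Python sketch for illustration only (ACCEL_ARCHS mirrors ARCHITECTURES_WITH_ACCEL; this is not part of the change):

import re

ACCEL_ARCHS = {"90", "100", "101", "120"}  # taken from ARCHITECTURES_WITH_ACCEL above

def normalize_cuda_architectures(archs):
    """Mimic the CMake macro: keep only the last '-virtual' entry, strip
    '-real'/'a' suffixes from the rest, then re-emit '<sm>a-real' for
    accelerated architectures and '<sm>-real' otherwise."""
    cleaned = []
    last_virtual = None
    for arch in archs:
        if not arch:
            continue
        if re.fullmatch(r"[1-9][0-9]+a?-virtual", arch):
            last_virtual = arch                  # only the last -virtual entry survives
        elif m := re.fullmatch(r"([1-9][0-9]+)a?(-real)?", arch):
            if m.group(1) not in cleaned:        # same effect as REMOVE_DUPLICATES
                cleaned.append(m.group(1))
        else:
            raise ValueError(f"Unrecognized CUDA architecture: {arch}")

    normalized = [f"{a}a-real" if a in ACCEL_ARCHS else f"{a}-real" for a in cleaned]
    if last_virtual is not None:
        normalized.append(last_virtual)
    return normalized

# ["90", "86-real", "120a", "90-virtual"] -> ["90a-real", "86-real", "120a-real", "90-virtual"]
print(normalize_cuda_architectures(["90", "86-real", "120a", "90-virtual"]))
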
(onnxruntime_USE_EXTERNAL_DAWN) - message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.") + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (FALSE) + # vcpkg does not support Emscripten yet + find_package(dawn REQUIRED) + else() + # + # Please keep the following in sync with cmake/vcpkg-ports/dawn/portfile.cmake + # + set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE) + set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE) + set(DAWN_BUILD_TESTS OFF CACHE BOOL "" FORCE) + if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE) + + if (onnxruntime_USE_EXTERNAL_DAWN) + message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.") + endif() + else() + # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size + set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) endif() - else() - # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size - set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) - endif() - if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP) - set(DAWN_ENABLE_DESKTOP_GL ON CACHE BOOL "" FORCE) - set(DAWN_ENABLE_OPENGLES ON CACHE BOOL "" FORCE) - set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING ON CACHE BOOL "" FORCE) - set(DAWN_USE_GLFW ON CACHE BOOL "" FORCE) - set(DAWN_USE_WINDOWS_UI ON CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_WRITER ON CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_VALIDATOR ON CACHE BOOL "" FORCE) - else() - set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) - set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) - set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) - set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) - endif() - - # disable things we don't use - set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) - set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) - - set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving - set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. - - # SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V. - set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE) - - if (WIN32) - # building this requires the HLSL writer to be enabled in Tint. TBD if that we need either of these to be ON. 
- set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE) - set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE) - - if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12)) - message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.") - endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) - set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE) - set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE) + if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP) + set(DAWN_ENABLE_DESKTOP_GL ON CACHE BOOL "" FORCE) + set(DAWN_ENABLE_OPENGLES ON CACHE BOOL "" FORCE) + set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING ON CACHE BOOL "" FORCE) + set(DAWN_USE_GLFW ON CACHE BOOL "" FORCE) + set(DAWN_USE_WINDOWS_UI ON CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_WRITER ON CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_VALIDATOR ON CACHE BOOL "" FORCE) else() - set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) + set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) + set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) + set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE) - else() - set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE) + + # disable things we don't use + set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) + set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) + + set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving + set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. + + # SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V. + set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE) + + if (WIN32) + # building this requires the HLSL writer to be enabled in Tint. TBD if that we need either of these to be ON. + set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE) + set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE) + + if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12)) + message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.") + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE) + set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE) + endif() + # We are currently always using the D3D12 backend. + set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE) endif() - # We are currently always using the D3D12 backend. 
- set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE) endif() - endif() - if (onnxruntime_CUSTOM_DAWN_SRC_PATH) - # use the custom dawn source path if provided - # - # specified as: - # build.py --use_webgpu --cmake_extra_defines "onnxruntime_CUSTOM_DAWN_SRC_PATH=" - onnxruntime_fetchcontent_declare( - dawn - SOURCE_DIR ${onnxruntime_CUSTOM_DAWN_SRC_PATH} - EXCLUDE_FROM_ALL - ) - else() - onnxruntime_fetchcontent_declare( - dawn - URL ${DEP_URL_dawn} - URL_HASH SHA1=${DEP_SHA1_dawn} - # # All previous patches are merged into the upstream dawn project. We don't need to apply any patches right now. - # # if we need to apply patches in the future, we can uncomment the following line. + if (onnxruntime_CUSTOM_DAWN_SRC_PATH) + # use the custom dawn source path if provided # - # The dawn.patch contains the following changes: - # - # - (private) Allow WGPUBufferImpl class to destroy the buffer in the destructor - # In native implementation, wgpuBufferRelease will trigger the buffer destroy (if refcount decreased to 0). But - # in emwgpu implementation, the buffer destroy won't happen. This change adds a destructor to the buffer class - # to destroy the buffer when the refcount is 0 for non-external buffers. - # - # - (private) Remove hard-coded CMAKE_OSX_DEPLOYMENT_TARGET in Dawn's CMake files - # https://github.com/microsoft/onnxruntime/pull/23729 - # - # - (private) Reduce unsafe buffer usage warning in aligned_storage.h - # https://github.com/microsoft/onnxruntime/pull/24308 - # The patch disables the UNSAFE_BUFFER_USAGE warning around the AlignedStorage struct in aligned_storage.h. This is done - # by using TINT_BEGIN_DISABLE_WARNING and TINT_END_DISABLE_WARNING macros, which helps in warnings related to unsafe buffer usage - # usage when compiling the code, making the build process cleaner and faster. - # - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch - EXCLUDE_FROM_ALL - ) - endif() + # specified as: + # build.py --use_webgpu --cmake_extra_defines "onnxruntime_CUSTOM_DAWN_SRC_PATH=" + onnxruntime_fetchcontent_declare( + dawn + SOURCE_DIR ${onnxruntime_CUSTOM_DAWN_SRC_PATH} + EXCLUDE_FROM_ALL + ) + else() + set(ONNXRUNTIME_Dawn_PATCH_COMMAND + # The dawn.patch contains the following changes: + # + # - (private) Allow WGPUBufferImpl class to destroy the buffer in the destructor + # In native implementation, wgpuBufferRelease will trigger the buffer destroy (if refcount decreased to 0). But + # in emwgpu implementation, the buffer destroy won't happen. This change adds a destructor to the buffer class + # to destroy the buffer when the refcount is 0 for non-external buffers. + # + # - (private) Remove hard-coded CMAKE_OSX_DEPLOYMENT_TARGET in Dawn's CMake files + # https://github.com/microsoft/onnxruntime/pull/23729 + # + # - (private) Reduce unsafe buffer usage warning in aligned_storage.h + # https://github.com/microsoft/onnxruntime/pull/24308 + # The patch disables the UNSAFE_BUFFER_USAGE warning around the AlignedStorage struct in aligned_storage.h. This is done + # by using TINT_BEGIN_DISABLE_WARNING and TINT_END_DISABLE_WARNING macros, which helps in warnings related to unsafe buffer usage + # usage when compiling the code, making the build process cleaner and faster. 
+ # + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch && + + # The dawn_force_enable_f16_nvidia_vulkan.patch contains the following changes: + # + # - (private) Force enable f16 support for NVIDIA Vulkan + # Dawn disabled f16 support for NVIDIA Vulkan by default because of crashes in f16 CTS tests (crbug.com/tint/2164). + # Since the crashes are limited to specific GPU models, we patched Dawn to remove the restriction. + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch && + + # The dawn_fix_copy_dxil_dll.patch contains the following changes: + # + # - (private) Fix copy of dxil.dll in Dawn + # The patch ensures the copy of dxil.dll to be done after the build step of `dxcompiler` target. + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_fix_copy_dxil_dll.patch) - onnxruntime_fetchcontent_makeavailable(dawn) + onnxruntime_fetchcontent_declare( + dawn + URL ${DEP_URL_dawn} + URL_HASH SHA1=${DEP_SHA1_dawn} + PATCH_COMMAND ${ONNXRUNTIME_Dawn_PATCH_COMMAND} + EXCLUDE_FROM_ALL + ) + endif() + + onnxruntime_fetchcontent_makeavailable(dawn) + endif() if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 1b124e3bb3f74..f6130f8c518a6 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -158,8 +158,8 @@ if(onnxruntime_BUILD_SHARED_LIB) target_link_options(onnxruntime PRIVATE "LINKER:-exported_symbols_list,${SYMBOL_FILE}") set_target_properties(onnxruntime PROPERTIES MACOSX_RPATH TRUE - SKIP_BUILD_RPATH TRUE - INSTALL_RPATH_USE_LINK_PATH FALSE + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH "@loader_path" BUILD_WITH_INSTALL_NAME_DIR TRUE INSTALL_NAME_DIR @rpath) endif() diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index f9cd35fa71aa8..e629df4843109 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -11,6 +11,8 @@ set(onnxruntime_common_src_patterns "${ONNXRUNTIME_ROOT}/core/common/logging/*.cc" "${ONNXRUNTIME_ROOT}/core/common/logging/sinks/*.h" "${ONNXRUNTIME_ROOT}/core/common/logging/sinks/*.cc" + "${ONNXRUNTIME_ROOT}/core/platform/check_intel.h" + "${ONNXRUNTIME_ROOT}/core/platform/check_intel.cc" "${ONNXRUNTIME_ROOT}/core/platform/device_discovery.h" "${ONNXRUNTIME_ROOT}/core/platform/device_discovery.cc" "${ONNXRUNTIME_ROOT}/core/platform/env.h" @@ -100,6 +102,14 @@ if(WIN32) target_compile_options(onnxruntime_common PRIVATE "/Zc:char8_t-") endif() endif() + +if(NOT WIN32 AND NOT APPLE AND NOT ANDROID AND CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") + set_source_files_properties( + ${ONNXRUNTIME_ROOT}/core/common/spin_pause.cc + PROPERTIES COMPILE_FLAGS "-mwaitpkg" + ) +endif() + if (onnxruntime_USE_TELEMETRY) set_target_properties(onnxruntime_common PROPERTIES COMPILE_FLAGS "/FI${ONNXRUNTIME_INCLUDE_DIR}/core/platform/windows/TraceLoggingConfigPrivate.h") endif() diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 1227264e595ed..a65bd9373d1b7 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -58,6 +58,15 @@ file(GLOB onnxruntime4j_native_src onnxruntime_add_shared_library_module(onnxruntime4j_jni ${onnxruntime4j_native_src}) set_property(TARGET onnxruntime4j_jni PROPERTY C_STANDARD 11) +if (APPLE) + set_target_properties(onnxruntime4j_jni PROPERTIES + MACOSX_RPATH TRUE + 
SKIP_BUILD_RPATH TRUE + INSTALL_RPATH_USE_LINK_PATH FALSE + BUILD_WITH_INSTALL_NAME_DIR TRUE + INSTALL_NAME_DIR @rpath) +endif() + # depend on java sources. if they change, the JNI should recompile add_dependencies(onnxruntime4j_jni onnxruntime4j) onnxruntime_add_include_to_target(onnxruntime4j_jni onnxruntime_session) @@ -166,6 +175,32 @@ if (WIN32) if (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() + if (onnxruntime_USE_WEBGPU) + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) + add_custom_command( + TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different + $ + $ + ${JAVA_PACKAGE_LIB_DIR}/ + ) + else() + add_custom_command( + TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different + $/dxil.dll + $/dxcompiler.dll + ${JAVA_PACKAGE_LIB_DIR}/ + ) + endif() + endif() + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) + endif() + endif() endif() else() add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) @@ -188,6 +223,9 @@ else() if (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() + if (onnxruntime_USE_WEBGPU AND onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) + endif() endif() # run the build process (this copies the results back into CMAKE_CURRENT_BINARY_DIR) diff --git a/cmake/onnxruntime_nodejs.cmake b/cmake/onnxruntime_nodejs.cmake index 355575be3bcf7..54ac045ce135f 100644 --- a/cmake/onnxruntime_nodejs.cmake +++ b/cmake/onnxruntime_nodejs.cmake @@ -74,8 +74,17 @@ endif() if (onnxruntime_USE_WEBGPU) set(NODEJS_BINDING_USE_WEBGPU "--use_webgpu") if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - list(APPEND NODEJS_DLL_DEPS "$/dxil.dll") - list(APPEND NODEJS_DLL_DEPS "$/dxcompiler.dll") + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) + list(APPEND NODEJS_DLL_DEPS "$") + list(APPEND NODEJS_DLL_DEPS "$") + else() + list(APPEND NODEJS_DLL_DEPS "$/dxil.dll") + list(APPEND NODEJS_DLL_DEPS "$/dxcompiler.dll") + endif() endif() if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) list(APPEND NODEJS_DLL_DEPS "$") diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 6a7510a5d83bc..da46f29dacf5f 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -179,7 +179,7 @@ set(onnxruntime_NVCC_THREADS "1" CACHE STRING "Number of threads that NVCC can use for compilation.") target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") endif() - + # Since CUDA 12.8, compiling diagnostics become stricter if (CMAKE_CUDA_COMPILER_VERSION 
VERSION_GREATER_EQUAL 12.8) target_compile_options(${target} PRIVATE "$<$:--relocatable-device-code=true>") @@ -261,6 +261,11 @@ set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") + if("90" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) + target_compile_options(${target} PRIVATE $<$:-Xptxas=-w>) + target_compile_definitions(${target} PRIVATE COMPILE_HOPPER_TMA_GEMMS) + endif() + if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling target_link_libraries(${target} PRIVATE CUDA::cupti) endif() diff --git a/cmake/onnxruntime_providers_nv.cmake b/cmake/onnxruntime_providers_nv.cmake index 12d824fc3360e..a804f2d7ae55c 100644 --- a/cmake/onnxruntime_providers_nv.cmake +++ b/cmake/onnxruntime_providers_nv.cmake @@ -1,4 +1,5 @@ # Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Licensed under the MIT License. find_package(CUDAToolkit REQUIRED 12.8) enable_language(CUDA) @@ -9,6 +10,9 @@ if (onnxruntime_NV_PLACEHOLDER_BUILDER) add_definitions(-DORT_NV_PLACEHOLDER_BUILDER) endif() +if (NOT onnxruntime_USE_TENSORRT_BUILTIN_PARSER) + message(FATAL_ERROR "TensorRT RTX can not be used with the open source parser.") +endif () set(BUILD_LIBRARY_ONLY 1) add_definitions("-DONNX_ML=1") add_definitions("-DONNX_NAMESPACE=onnx") diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 59c7db9999b43..3698aaa902922 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -33,12 +33,27 @@ PATH_SUFFIXES include) file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h NVINFER_VER_CONTENT) - string(REGEX MATCH "define NV_TENSORRT_MAJOR * +([0-9]+)" NV_TENSORRT_MAJOR "${NVINFER_VER_CONTENT}") - string(REGEX REPLACE "define NV_TENSORRT_MAJOR * +([0-9]+)" "\\1" NV_TENSORRT_MAJOR "${NV_TENSORRT_MAJOR}") - string(REGEX MATCH "define NV_TENSORRT_MINOR * +([0-9]+)" NV_TENSORRT_MINOR "${NVINFER_VER_CONTENT}") - string(REGEX REPLACE "define NV_TENSORRT_MINOR * +([0-9]+)" "\\1" NV_TENSORRT_MINOR "${NV_TENSORRT_MINOR}") - string(REGEX MATCH "define NV_TENSORRT_PATCH * +([0-9]+)" NV_TENSORRT_PATCH "${NVINFER_VER_CONTENT}") - string(REGEX REPLACE "define NV_TENSORRT_PATCH * +([0-9]+)" "\\1" NV_TENSORRT_PATCH "${NV_TENSORRT_PATCH}") + + # Starting TRT 10.11, TRT version macros have changed + string(REGEX MATCH "TRT_MAJOR_ENTERPRISE" TRT_VER_CHECK "${NVINFER_VER_CONTENT}") + # Pre TRT 10.11 + if("${TRT_VER_CHECK}" STREQUAL "") + string(REGEX MATCH "define NV_TENSORRT_MAJOR * +([0-9]+)" NV_TENSORRT_MAJOR "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define NV_TENSORRT_MAJOR * +([0-9]+)" "\\1" NV_TENSORRT_MAJOR "${NV_TENSORRT_MAJOR}") + string(REGEX MATCH "define NV_TENSORRT_MINOR * +([0-9]+)" NV_TENSORRT_MINOR "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define NV_TENSORRT_MINOR * +([0-9]+)" "\\1" NV_TENSORRT_MINOR "${NV_TENSORRT_MINOR}") + string(REGEX MATCH "define NV_TENSORRT_PATCH * +([0-9]+)" NV_TENSORRT_PATCH "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define NV_TENSORRT_PATCH * +([0-9]+)" "\\1" NV_TENSORRT_PATCH "${NV_TENSORRT_PATCH}") + # TRT 10.11+ + else() + string(REGEX MATCH "define TRT_MAJOR_ENTERPRISE * +([0-9]+)" NV_TENSORRT_MAJOR "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define TRT_MAJOR_ENTERPRISE * +([0-9]+)" "\\1" NV_TENSORRT_MAJOR "${NV_TENSORRT_MAJOR}") + string(REGEX MATCH "define 
TRT_MINOR_ENTERPRISE * +([0-9]+)" NV_TENSORRT_MINOR "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define TRT_MINOR_ENTERPRISE * +([0-9]+)" "\\1" NV_TENSORRT_MINOR "${NV_TENSORRT_MINOR}") + string(REGEX MATCH "define TRT_PATCH_ENTERPRISE * +([0-9]+)" NV_TENSORRT_PATCH "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define TRT_PATCH_ENTERPRISE * +([0-9]+)" "\\1" NV_TENSORRT_PATCH "${NV_TENSORRT_PATCH}") + endif() + math(EXPR NV_TENSORRT_MAJOR_INT "${NV_TENSORRT_MAJOR}") math(EXPR NV_TENSORRT_MINOR_INT "${NV_TENSORRT_MINOR}") math(EXPR NV_TENSORRT_PATCH_INT "${NV_TENSORRT_PATCH}") diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 4bbca7b1b811a..a8a79bc928dd1 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -59,8 +59,24 @@ list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll") endif() - list(APPEND onnxruntime_providers_webgpu_dll_deps "$") + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) + # Fix Dawn vcpkg build issue (missing IMPORTED_IMPLIB and IMPORTED_LOCATION for target dawn::webgpu_dawn) + get_target_property(webgpu_dawn_target_IMPORTED_IMPLIB dawn::webgpu_dawn IMPORTED_IMPLIB) + if (NOT webgpu_dawn_target_IMPORTED_IMPLIB) + set_target_properties(dawn::webgpu_dawn PROPERTIES IMPORTED_IMPLIB "webgpu_dawn.lib") + endif() + get_target_property(webgpu_dawn_target_IMPORTED_LOCATION dawn::webgpu_dawn IMPORTED_LOCATION) + if (NOT webgpu_dawn_target_IMPORTED_LOCATION) + set_target_properties(dawn::webgpu_dawn PROPERTIES IMPORTED_LOCATION "webgpu_dawn.dll") + endif() + endif() endif() + + list(APPEND onnxruntime_providers_webgpu_dll_deps "$") else() if (NOT onnxruntime_USE_EXTERNAL_DAWN) target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native) @@ -70,11 +86,23 @@ if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) # Ensure dxil.dll and dxcompiler.dll exist in the output directory $ - add_dependencies(onnxruntime_providers_webgpu copy_dxil_dll) - add_dependencies(onnxruntime_providers_webgpu dxcompiler) + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) + find_package(directx-dxc CONFIG REQUIRED) + target_link_libraries(onnxruntime_providers_webgpu Microsoft::DirectXShaderCompiler) + target_link_libraries(onnxruntime_providers_webgpu Microsoft::DXIL) + list(APPEND onnxruntime_providers_webgpu_dll_deps "$") + list(APPEND onnxruntime_providers_webgpu_dll_deps "$") + else() + add_dependencies(onnxruntime_providers_webgpu copy_dxil_dll) + add_dependencies(onnxruntime_providers_webgpu dxcompiler) - list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxil.dll") - list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxcompiler.dll") + list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxil.dll") + list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxcompiler.dll") + endif() endif() if (onnxruntime_providers_webgpu_dll_deps) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 8f7a96e052fa1..f6eac2c24eca2 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -453,6 +453,9 @@ endif() file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/*.py" ) +file(GLOB 
onnxruntime_python_tools_qnn_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/qnn/*.py" +) file(GLOB onnxruntime_python_quantization_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/quantization/*.py" ) @@ -564,6 +567,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/qdq_helpers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/ort_format_model COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/ort_format_model/ort_flatbuffers_py + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/tools/qnn COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/bart @@ -649,6 +653,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy_directory ${ONNXRUNTIME_ROOT}/core/flatbuffers/ort_flatbuffers_py $/onnxruntime/tools/ort_format_model/ort_flatbuffers_py + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_tools_qnn_src} + $/onnxruntime/tools/qnn/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_quantization_src} $/onnxruntime/quantization/ @@ -1073,6 +1080,40 @@ if (onnxruntime_USE_QNN) endif() endif() +if (onnxruntime_USE_WEBGPU) + if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + # TODO: the following code is used to disable building Dawn using vcpkg temporarily + # until we figure out how to resolve the packaging pipeline failures + # + # if (onnxruntime_USE_VCPKG) + if (FALSE) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $ + $/onnxruntime/capi/ + ) + else() + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $/dxil.dll + $/dxcompiler.dll + $/onnxruntime/capi/ + ) + endif() + endif() + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $/onnxruntime/capi/ + ) + endif() +endif() + if (onnxruntime_USE_VSINPU) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index b31fdd4ea1ee1..15cc238173f29 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -724,6 +724,7 @@ endif() # or reduced op builds. 
if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/*) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/qnn_node_group/*) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_qnn) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_qnn) if(NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) @@ -1278,6 +1279,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_perf_test_libs onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 gtest absl_failure_signal_handler absl_examine_stack absl_flags_parse absl_flags_usage absl_flags_usage_internal) endif() target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads) + if (onnxruntime_USE_CUDA OR onnxruntime_USE_NV OR onnxruntime_USE_TENSORRT) + target_link_libraries(onnxruntime_perf_test PRIVATE CUDA::cudart) + endif() if(WIN32) target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) endif() diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index bfb73e14ce7a4..f00292fade52d 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -196,7 +196,7 @@ else() onnxruntime_util re2::re2 ) - set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'") + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue','HEAP8','HEAPU8','HEAP32','HEAPU32'") if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'") diff --git a/cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch b/cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch new file mode 100644 index 0000000000000..cd4d53b4cbdb7 --- /dev/null +++ b/cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch @@ -0,0 +1,13 @@ +diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt +index cdfde38819..fc5ff76421 100644 +--- a/third_party/CMakeLists.txt ++++ b/third_party/CMakeLists.txt +@@ -352,6 +352,8 @@ function(AddSubdirectoryDXC) + TARGET copy_dxil_dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${DXIL_DLL_PATH} $ + COMMENT "Copying ${DXIL_DLL_PATH} to $") ++ # Ensure folder "$" exists when copying the dll ++ add_dependencies(copy_dxil_dll dxcompiler) + # Make dxc target depend on copy_dxil_dll + add_dependencies(dxc copy_dxil_dll) + endif() diff --git a/cmake/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch b/cmake/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch new file mode 100644 index 0000000000000..2d999a456fdec --- /dev/null +++ b/cmake/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch @@ -0,0 +1,19 @@ +diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp +index 158f10764c..a324c101ed 100644 +--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp ++++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp +@@ -269,11 +269,9 @@ void PhysicalDevice::InitializeSupportedFeaturesImpl() { + mDeviceInfo.shaderFloat16Int8Features.shaderFloat16 == VK_TRUE && + mDeviceInfo._16BitStorageFeatures.storageBuffer16BitAccess == VK_TRUE && + 
mDeviceInfo._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess == VK_TRUE) { +- // TODO(crbug.com/tint/2164): Investigate crashes in f16 CTS tests to enable on NVIDIA. +- if (!gpu_info::IsNvidia(GetVendorId())) { +- EnableFeature(Feature::ShaderF16); +- shaderF16Enabled = true; +- } ++ // ONNX Runtime Patch: enable shaderF16 on all devices. ++ EnableFeature(Feature::ShaderF16); ++ shaderF16Enabled = true; + } + + if (mDeviceInfo.HasExt(DeviceExt::DrawIndirectCount) && diff --git a/cmake/utils/detect_cuda_arch.cu b/cmake/utils/detect_cuda_arch.cu new file mode 100644 index 0000000000000..83fbc13dbff7f --- /dev/null +++ b/cmake/utils/detect_cuda_arch.cu @@ -0,0 +1,39 @@ +#include +#include +#include +#include +#include + +int main(int argc, char* argv[]) +{ + int n_devices = 0; + int rc = cudaGetDeviceCount(&n_devices); + if (rc != cudaSuccess) + { + cudaError_t error = cudaGetLastError(); + std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; + return rc; + } + + std::vector> arch(n_devices); + for (int cd = 0; cd < n_devices; ++cd) + { + cudaDeviceProp dev; + int rc = cudaGetDeviceProperties(&dev, cd); + if (rc != cudaSuccess) + { + cudaError_t error = cudaGetLastError(); + std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; + return rc; + } + else + { + arch[cd] = {dev.major, dev.minor}; + } + } + + std::pair best_cc = *std::max_element(begin(arch), end(arch)); + std::cout << best_cc.first << best_cc.second; + + return 0; +} diff --git a/cmake/vcpkg-ports/.gitattributes b/cmake/vcpkg-ports/.gitattributes new file mode 100644 index 0000000000000..9812ceb1ffd9b --- /dev/null +++ b/cmake/vcpkg-ports/.gitattributes @@ -0,0 +1 @@ +*.patch text eol=lf diff --git a/cmake/vcpkg-ports/dawn/dawn.patch b/cmake/vcpkg-ports/dawn/dawn.patch new file mode 100644 index 0000000000000..1fe66d2cf917d --- /dev/null +++ b/cmake/vcpkg-ports/dawn/dawn.patch @@ -0,0 +1,59 @@ +diff --git a/src/cmake/DawnCompilerPlatformFlags.cmake b/src/cmake/DawnCompilerPlatformFlags.cmake +index 50638e2456..efa42711e6 100644 +--- a/src/cmake/DawnCompilerPlatformFlags.cmake ++++ b/src/cmake/DawnCompilerPlatformFlags.cmake +@@ -63,7 +63,3 @@ endif () + if (MSVC AND NOT COMPILER_IS_CLANG_CL) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") + endif () +- +-if (TARGET_MACOS) +- set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version" FORCE) +-endif () +\ No newline at end of file +diff --git a/third_party/emdawnwebgpu/webgpu.cpp b/third_party/emdawnwebgpu/webgpu.cpp +index 5bfac41dcc..71a153daaa 100644 +--- a/third_party/emdawnwebgpu/webgpu.cpp ++++ b/third_party/emdawnwebgpu/webgpu.cpp +@@ -692,6 +692,7 @@ struct WGPUBufferImpl final : public EventSource, + WGPUBufferImpl(const EventSource* source, bool mappedAtCreation); + // Injection constructor used when we already have a backing Buffer. 
+ WGPUBufferImpl(const EventSource* source, WGPUBufferMapState mapState); ++ ~WGPUBufferImpl(); + + void Destroy(); + const void* GetConstMappedRange(size_t offset, size_t size); +@@ -1361,6 +1362,12 @@ WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, + RefCountedWithExternalCount(kImportedFromJS), + mMapState(mapState) {} + ++WGPUBufferImpl::~WGPUBufferImpl() { ++ if (!IsImported()) { ++ Destroy(); ++ } ++} ++ + void WGPUBufferImpl::Destroy() { + emwgpuBufferDestroy(this); + AbortPendingMap("Buffer was destroyed before mapping was resolved."); +diff --git a/src/tint/utils/memory/aligned_storage.h b/src/tint/utils/memory/aligned_storage.h +index c532c4fc38..19c950af4c 100644 +--- a/src/tint/utils/memory/aligned_storage.h ++++ b/src/tint/utils/memory/aligned_storage.h +@@ -31,6 +31,9 @@ + #include + + #include "src/tint/utils/memory/bitcast.h" ++#include "src/tint/utils/macros/compiler.h" ++ ++TINT_BEGIN_DISABLE_WARNING(UNSAFE_BUFFER_USAGE); + + namespace tint { + +@@ -50,4 +53,6 @@ struct alignas(alignof(T)) AlignedStorage { + + } // namespace tint + ++TINT_END_DISABLE_WARNING(UNSAFE_BUFFER_USAGE); ++ + #endif // SRC_TINT_UTILS_MEMORY_ALIGNED_STORAGE_H_ diff --git a/cmake/vcpkg-ports/dawn/dawn_force_enable_f16_nvidia_vulkan.patch b/cmake/vcpkg-ports/dawn/dawn_force_enable_f16_nvidia_vulkan.patch new file mode 100644 index 0000000000000..2d999a456fdec --- /dev/null +++ b/cmake/vcpkg-ports/dawn/dawn_force_enable_f16_nvidia_vulkan.patch @@ -0,0 +1,19 @@ +diff --git a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp +index 158f10764c..a324c101ed 100644 +--- a/src/dawn/native/vulkan/PhysicalDeviceVk.cpp ++++ b/src/dawn/native/vulkan/PhysicalDeviceVk.cpp +@@ -269,11 +269,9 @@ void PhysicalDevice::InitializeSupportedFeaturesImpl() { + mDeviceInfo.shaderFloat16Int8Features.shaderFloat16 == VK_TRUE && + mDeviceInfo._16BitStorageFeatures.storageBuffer16BitAccess == VK_TRUE && + mDeviceInfo._16BitStorageFeatures.uniformAndStorageBuffer16BitAccess == VK_TRUE) { +- // TODO(crbug.com/tint/2164): Investigate crashes in f16 CTS tests to enable on NVIDIA. +- if (!gpu_info::IsNvidia(GetVendorId())) { +- EnableFeature(Feature::ShaderF16); +- shaderF16Enabled = true; +- } ++ // ONNX Runtime Patch: enable shaderF16 on all devices. 
++ EnableFeature(Feature::ShaderF16); ++ shaderF16Enabled = true; + } + + if (mDeviceInfo.HasExt(DeviceExt::DrawIndirectCount) && diff --git a/cmake/vcpkg-ports/dawn/dawn_vcpkg_integration.patch b/cmake/vcpkg-ports/dawn/dawn_vcpkg_integration.patch new file mode 100644 index 0000000000000..6e97475c8ad53 --- /dev/null +++ b/cmake/vcpkg-ports/dawn/dawn_vcpkg_integration.patch @@ -0,0 +1,125 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index b46b68204b..3e985ae3cd 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -127,6 +127,8 @@ if (DAWN_SUPPORTS_GLFW_FOR_WINDOWING) + set(BUILD_SAMPLES ON) + endif() + ++option(DAWN_ENABLE_VCPKG "Enable vcpkg integration" OFF) ++ + option(DAWN_ENABLE_ASAN "Enable address sanitizer" OFF) + option(DAWN_ENABLE_INSTALL "Enable install step for Dawn libraries" OFF) + option(DAWN_ENABLE_TSAN "Enable thread sanitizer" OFF) +@@ -439,16 +441,25 @@ set(TINT_SPIRV_TOOLS_DIR ${DAWN_SPIRV_TOOLS_DIR}) + ################################################################################ + # Run on all subdirectories + ################################################################################ +-if (DAWN_BUILD_PROTOBUF AND EXISTS "${DAWN_PROTOBUF_DIR}/cmake") +- if (("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") AND WIN32) +- set(protobuf_HAVE_BUILTIN_ATOMICS 1) ++if (DAWN_ENABLE_VCPKG) ++ find_package(absl REQUIRED) ++ find_package(SPIRV-Headers REQUIRED) ++ find_package(SPIRV-Tools REQUIRED) ++ if (DAWN_USE_BUILT_DXC) ++ find_package(directx-dxc CONFIG REQUIRED) + endif() ++else() ++ if (DAWN_BUILD_PROTOBUF AND EXISTS "${DAWN_PROTOBUF_DIR}/cmake") ++ if (("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") AND WIN32) ++ set(protobuf_HAVE_BUILTIN_ATOMICS 1) ++ endif() + +- # Needs to come before SPIR-V Tools +- include("third_party/protobuf.cmake") +-endif() ++ # Needs to come before SPIR-V Tools ++ include("third_party/protobuf.cmake") ++ endif() + +-add_subdirectory(third_party) ++ add_subdirectory(third_party) ++endif() + + # TODO(crbug.com/tint/455): Tint does not currently build with CMake when + # BUILD_SHARED_LIBS=1, so always build it as static for now. 
+diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt +index d3128bf764..319a847311 100644 +--- a/src/dawn/native/CMakeLists.txt ++++ b/src/dawn/native/CMakeLists.txt +@@ -865,7 +865,9 @@ if (DAWN_ENABLE_D3D12) + if (DAWN_USE_BUILT_DXC) + target_compile_definitions(dawn_native PRIVATE "DAWN_USE_BUILT_DXC") + target_compile_definitions(dawn_native_objects PRIVATE "DAWN_USE_BUILT_DXC") +- add_dependencies(dawn_native copy_dxil_dll) ++ if (NOT DAWN_ENABLE_VCPKG) ++ add_dependencies(dawn_native copy_dxil_dll) ++ endif() + endif() + endif() + +@@ -942,5 +944,9 @@ endif () + # They happen because dxcompiler is declared a shared library and bundle_libraries + # doesn't work well with shared libs + if (DAWN_USE_BUILT_DXC) +- target_link_libraries(dawn_native PRIVATE dxcompiler) ++ if (DAWN_ENABLE_VCPKG) ++ target_link_libraries(dawn_native PRIVATE Microsoft::DirectXShaderCompiler) ++ else() ++ target_link_libraries(dawn_native PRIVATE dxcompiler) ++ endif() + endif() +diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt +index 8692171222..b3da2fbbbf 100644 +--- a/src/tint/CMakeLists.txt ++++ b/src/tint/CMakeLists.txt +@@ -214,13 +214,21 @@ function(tint_default_compile_options TARGET) + endfunction() + + function(tint_spvheaders_compile_options TARGET) +- target_link_libraries(${TARGET} PRIVATE SPIRV-Headers) +- target_include_directories(${TARGET} PRIVATE "${TINT_SPIRV_HEADERS_DIR}/include") ++ if (DAWN_ENABLE_VCPKG) ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Headers::SPIRV-Headers) ++ else () ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Headers) ++ target_include_directories(${TARGET} PRIVATE "${TINT_SPIRV_HEADERS_DIR}/include") ++ endif() + endfunction() + + function(tint_spvtools_compile_options TARGET) +- target_link_libraries(${TARGET} PRIVATE SPIRV-Tools) +- target_include_directories(${TARGET} PRIVATE "${TINT_SPIRV_TOOLS_DIR}/include") ++ if (DAWN_ENABLE_VCPKG) ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Tools-static) ++ else () ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Tools) ++ target_include_directories(${TARGET} PRIVATE "${TINT_SPIRV_TOOLS_DIR}/include") ++ endif() + endfunction() + + function(tint_lib_compile_options TARGET) +@@ -562,12 +570,16 @@ function(tint_target_add_external_dependencies TARGET KIND) + target_link_libraries(${TARGET} PRIVATE + SPIRV-Tools-opt + ) +- target_include_directories(${TARGET} PRIVATE +- "${TINT_SPIRV_TOOLS_DIR}" +- "${TINT_SPIRV_TOOLS_DIR}/include" +- "${TINT_SPIRV_TOOLS_DIR}/source" +- "${spirv-tools_BINARY_DIR}" +- ) ++ if (DAWN_ENABLE_VCPKG) ++ target_link_libraries(${TARGET} PRIVATE SPIRV-Tools-static) ++ else () ++ target_include_directories(${TARGET} PRIVATE ++ "${TINT_SPIRV_TOOLS_DIR}" ++ "${TINT_SPIRV_TOOLS_DIR}/include" ++ "${TINT_SPIRV_TOOLS_DIR}/source" ++ "${spirv-tools_BINARY_DIR}" ++ ) ++ endif() + elseif(${DEPENDENCY} STREQUAL "thread") + find_package(Threads REQUIRED) + target_link_libraries(${TARGET} PRIVATE Threads::Threads) diff --git a/cmake/vcpkg-ports/dawn/portfile.cmake b/cmake/vcpkg-ports/dawn/portfile.cmake new file mode 100644 index 0000000000000..1c53f8316c372 --- /dev/null +++ b/cmake/vcpkg-ports/dawn/portfile.cmake @@ -0,0 +1,138 @@ +# NOTE: dynamic library vs. static library +# +# We are building Dawn as a shared library `webgpu_dawn`. However, we need to set the `BUILD_SHARED_LIBS` option to +# `OFF` in this portfile. See the explanation below. 
+# +# In CMake convention, the `BUILD_SHARED_LIBS` option is used to control whether a library is built as a shared library or a static library. +# However, in the Dawn repository, there are multiple targets. Instead of building each target as a shared library, Dawn +# uses a CMake option `DAWN_BUILD_MONOLITHIC_LIBRARY` to control whether to build a monolithic dynamic library. +# +# When `DAWN_BUILD_MONOLITHIC_LIBRARY` is set to `ON`, a single library is built that contains all the targets. The +# library is always built as a shared library, regardless of the value of `BUILD_SHARED_LIBS`. +# +# In the vcpkg migration, we found that when both `DAWN_BUILD_MONOLITHIC_LIBRARY` and `BUILD_SHARED_LIBS` are set to `ON`, the build process will fail with some unexpected errors. +# So we need to set `BUILD_SHARED_LIBS` to `OFF` in this mode. +# +# The following function call ensures BUILD_SHARED_LIBS is set to OFF. +vcpkg_check_linkage(ONLY_STATIC_LIBRARY) + +if(VCPKG_TARGET_IS_EMSCRIPTEN) + message(FATAL_ERROR "This port is currently not supported on Emscripten.") +endif() + +set(onnxruntime_vcpkg_DAWN_OPTIONS) + +list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + + # enable the vcpkg flag + -DDAWN_ENABLE_VCPKG=ON + + # fetch dependencies is disabled when using vcpkg + -DDAWN_FETCH_DEPENDENCIES=OFF + + -DDAWN_BUILD_SAMPLES=OFF + -DDAWN_ENABLE_NULL=OFF + -DDAWN_BUILD_TESTS=OFF +) + +if (NOT VCPKG_TARGET_IS_EMSCRIPTEN) + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + + -DDAWN_BUILD_MONOLITHIC_LIBRARY=ON + -DDAWN_ENABLE_INSTALL=ON + + -DDAWN_ENABLE_DESKTOP_GL=OFF + -DDAWN_ENABLE_OPENGLES=OFF + -DDAWN_SUPPORTS_GLFW_FOR_WINDOWING=OFF + -DDAWN_USE_GLFW=OFF + -DDAWN_USE_WINDOWS_UI=OFF + -DTINT_BUILD_GLSL_WRITER=OFF + -DTINT_BUILD_GLSL_VALIDATOR=OFF + + -DDAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG=OFF + -DDAWN_USE_X11=OFF + + -DTINT_BUILD_TESTS=OFF + -DTINT_BUILD_CMD_TOOLS=OFF + -DTINT_BUILD_IR_BINARY=OFF + -DTINT_BUILD_SPV_READER=OFF + -DTINT_BUILD_WGSL_WRITER=ON + + -DDAWN_ENABLE_SPIRV_VALIDATION=OFF + + # explicitly set the jinja2 and markupsafe directories to empty strings + # when they are empty, the python script will import them from the system + # + # pip install jinja2 markupsafe + # + -DDAWN_JINJA2_DIR= + -DDAWN_MARKUPSAFE_DIR= + ) +endif() + +if(VCPKG_TARGET_IS_WINDOWS) + # feature detection on Windows + vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS + FEATURES + windows-use-d3d12 onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_D3D12 + windows-use-vulkan onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_VULKAN + ) + + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_USE_BUILT_DXC=ON + -DTINT_BUILD_HLSL_WRITER=ON + ) + + if((NOT onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_VULKAN) AND(NOT onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_D3D12)) + message(FATAL_ERROR "At least one of \"windows-use-d3d12\" or \"windows-use-vulkan\" must be enabled when using Dawn on Windows.") + endif() + + if(onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_VULKAN) + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_VULKAN=ON + -DTINT_BUILD_SPV_WRITER=ON + ) + else() + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_VULKAN=OFF + ) + endif() + + if(onnxruntime_vcpkg_ENABLE_DAWN_BACKEND_D3D12) + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_D3D12=ON + ) + else() + list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_D3D12=OFF + ) + endif() + + # We are currently always using the D3D12 backend. 
+ list(APPEND onnxruntime_vcpkg_DAWN_OPTIONS + -DDAWN_ENABLE_D3D11=OFF + ) +endif() + +vcpkg_from_github( + OUT_SOURCE_PATH SOURCE_PATH + REPO google/dawn + REF "${VERSION}" + SHA512 9771e0be45ad2b85e4d85e12cbf03b9c9b4cc297e8f819e6277d8f02821adb671bf420fd13e241be4f6d7795a3acf0d0a38649c6e0e38a523a6ec0f042591efe + + PATCHES + dawn.patch + dawn_force_enable_f16_nvidia_vulkan.patch + dawn_vcpkg_integration.patch +) + +vcpkg_cmake_configure( + SOURCE_PATH "${SOURCE_PATH}" + WINDOWS_USE_MSBUILD + OPTIONS + ${onnxruntime_vcpkg_DAWN_OPTIONS} + + # MAYBE_UNUSED_VARIABLES +) + +vcpkg_cmake_install() diff --git a/cmake/vcpkg-ports/dawn/vcpkg.json b/cmake/vcpkg-ports/dawn/vcpkg.json new file mode 100644 index 0000000000000..0ea8627f7e17c --- /dev/null +++ b/cmake/vcpkg-ports/dawn/vcpkg.json @@ -0,0 +1,62 @@ +{ + "name": "dawn", + "version-string": "4cb1f9be152a4fa6bb695c08cd707ab078a1e2fb", + "port-version": 1, + "description": "Dawn, a native WebGPU implementation.", + "homepage": "https://dawn.googlesource.com/dawn", + "license": "BSD-3-Clause", + "dependencies": [ + { "name": "vcpkg-cmake", "host": true }, + { "name": "vcpkg-cmake-config", "host": true }, + { "name": "abseil", "version>=": "20250127.1" }, + { "name": "protobuf", "version>=": "3.21.12" }, + { + "name": "spirv-headers", + "version>=": "1.4.304.1", + "platform": "!emscripten" + }, + { + "name": "spirv-tools", + "version>=": "1.4.304.1", + "platform": "!emscripten" + }, + { + "name": "vulkan-headers", + "version>=": "1.4.304.1#1", + "platform": "(windows | linux) & (arm64 | x64)" + }, + { + "name": "vulkan-loader", + "version>=": "1.4.304.1", + "platform": "(windows | linux) & (arm64 | x64)" + }, + { + "name": "vulkan-utility-libraries", + "version>=": "1.4.304.1", + "platform": "(windows | linux) & (arm64 | x64)" + } + ], + "features": { + "windows-use-d3d12": { + "description": "Enable D3D12 backend on Windows.", + "dependencies": [ + { + "name": "directx-dxc", + "version>=": "2025-02-20#1", + "platform": "windows & !arm32" + }, + { + "name": "directx-headers", + "version>=": "1.615.0", + "platform": "windows & !arm32" + } + ] + }, + "windows-use-vulkan": { + "description": "Enable Vulkan backend on Windows." 
+ } + }, + "default-features": [ + { "name": "windows-use-d3d12", "platform": "windows & !arm32" } + ] +} diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index f46abddfa028f..7c6b2fed36d1b 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -77,27 +77,23 @@ "features": { "tests": { "description": "Build ONNXRuntime unit tests", - "dependencies": [ - "gtest" - ] + "dependencies": ["gtest"] }, "xnnpack-ep": { "description": "Build with XNNPack EP", - "dependencies": [ - "xnnpack" - ] + "dependencies": ["xnnpack"] }, "coreml-ep": { "description": "Build with CoreML EP", - "dependencies": [ - "fp16" - ] + "dependencies": ["fp16"] }, "dml-ep": { - "description": "Build with CoreML EP", - "dependencies": [ - "directx-headers" - ] + "description": "Build with DirectML EP", + "dependencies": ["directx-headers"] + }, + "webgpu-ep": { + "description": "Build with WebGPU EP", + "dependencies": [] } }, "overrides": [ diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs index 9f42bf2247529..c348184658e7e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs @@ -6,6 +6,18 @@ namespace Microsoft.ML.OnnxRuntime using System; using System.Runtime.InteropServices; + /// + /// Flags representing options to enable when compiling a model. + /// Matches OrtCompileApiFlags in the ORT C API. + /// + [Flags] + public enum OrtCompileApiFlags : uint + { + NONE = 0, + ERROR_IF_NO_NODES_COMPILED = 1 << 0, + ERROR_IF_OUTPUT_FILE_EXISTS = 1 << 1, + } + /// /// This class is used to set options for model compilation, and to produce a compiled model using those options. /// See https://onnxruntime.ai/docs/api/c/ for further details of various options. @@ -108,6 +120,16 @@ public void SetEpContextEmbedMode(bool embed) NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed)); } + /// + /// Sets flags from OrtCompileApiFlags that represent one or more boolean options to enable. + /// + /// bitwise OR of flags in OrtCompileApiFlags to enable. + public void SetFlags(OrtCompileApiFlags flags) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags)); + } + internal IntPtr Handle => handle; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs index 3a87f87d124e9..3edc25b307a21 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -1,152 +1,166 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace Microsoft.ML.OnnxRuntime.CompileApi; - using System; using System.Runtime.InteropServices; -// NOTE: The order of the APIs in this struct should match exactly that in OrtCompileApi -// See onnxruntime/core/session/compile_api.cc. 
-[StructLayout(LayoutKind.Sequential)] -public struct OrtCompileApi +namespace Microsoft.ML.OnnxRuntime.CompileApi { - public IntPtr ReleaseModelCompilationOptions; - public IntPtr CreateModelCompilationOptionsFromSessionOptions; - public IntPtr ModelCompilationOptions_SetInputModelPath; - public IntPtr ModelCompilationOptions_SetInputModelFromBuffer; - public IntPtr ModelCompilationOptions_SetOutputModelPath; - public IntPtr ModelCompilationOptions_SetOutputModelExternalInitializersFile; - public IntPtr ModelCompilationOptions_SetOutputModelBuffer; - public IntPtr ModelCompilationOptions_SetEpContextEmbedMode; - public IntPtr CompileModel; -} + // NOTE: The order of the APIs in this struct should match exactly that in OrtCompileApi + // See onnxruntime/core/session/compile_api.cc. + [StructLayout(LayoutKind.Sequential)] + public struct OrtCompileApi + { + public IntPtr ReleaseModelCompilationOptions; + public IntPtr CreateModelCompilationOptionsFromSessionOptions; + public IntPtr ModelCompilationOptions_SetInputModelPath; + public IntPtr ModelCompilationOptions_SetInputModelFromBuffer; + public IntPtr ModelCompilationOptions_SetOutputModelPath; + public IntPtr ModelCompilationOptions_SetOutputModelExternalInitializersFile; + public IntPtr ModelCompilationOptions_SetOutputModelBuffer; + public IntPtr ModelCompilationOptions_SetEpContextEmbedMode; + public IntPtr CompileModel; + public IntPtr ModelCompilationOptions_SetFlags; + } -internal class NativeMethods -{ - private static OrtCompileApi _compileApi; - - // - // Define the delegate signatures, and a static member for each to hold the marshaled function pointer. - // - // We populate the static members in the constructor of this class. - // - // The C# code will call the C++ API through the delegate instances in the static members. 
- // - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseModelCompilationOptions(IntPtr /* OrtModelCompilationOptions* */ options); - public DOrtReleaseModelCompilationOptions OrtReleaseModelCompilationOptions; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtCreateModelCompilationOptionsFromSessionOptions( - IntPtr /* const OrtEnv* */ env, - IntPtr /* const OrtSessionOptions* */ sessionOptions, - out IntPtr /* OrtModelCompilationOptions** */ outOptions); - public DOrtCreateModelCompilationOptionsFromSessionOptions - OrtCreateModelCompilationOptionsFromSessionOptions; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelPath( - IntPtr /* OrtModelCompilationOptions* */ options, - byte[] /* const ORTCHAR_T* */ inputModelPath); - public DOrtModelCompilationOptions_SetInputModelPath OrtModelCompilationOptions_SetInputModelPath; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelFromBuffer( - IntPtr /* OrtModelCompilationOptions* */ options, - byte[] /* const void* */ inputModelData, - UIntPtr /* size_t */ inputModelDataSize); - public DOrtModelCompilationOptions_SetInputModelFromBuffer - OrtModelCompilationOptions_SetInputModelFromBuffer; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelPath( - IntPtr /* OrtModelCompilationOptions* */ options, - byte[] /* const ORTCHAR_T* */ outputModelPath); - public DOrtModelCompilationOptions_SetOutputModelPath OrtModelCompilationOptions_SetOutputModelPath; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile( - IntPtr /* OrtModelCompilationOptions* */ options, - byte[] /* const ORTCHAR_T* */ externalInitializersFilePath, - UIntPtr /* size_t */ externalInitializerSizeThreshold); - public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile - OrtModelCompilationOptions_SetOutputModelExternalInitializersFile; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelBuffer( - IntPtr /* OrtModelCompilationOptions* */ options, - IntPtr /* OrtAllocator* */ allocator, - ref IntPtr /* void** */ outputModelBufferPtr, - ref UIntPtr /* size_t* */ outputModelBufferSizePtr); - public DOrtModelCompilationOptions_SetOutputModelBuffer OrtModelCompilationOptions_SetOutputModelBuffer; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextEmbedMode( - IntPtr /* OrtModelCompilationOptions* */ options, - bool embedEpContextInModel); - public DOrtModelCompilationOptions_SetEpContextEmbedMode OrtModelCompilationOptions_SetEpContextEmbedMode; - - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */ DOrtCompileModel( - IntPtr /* const OrtEnv* */ env, - IntPtr /* const OrtModelCompilationOptions* */ modelOptions); - public DOrtCompileModel OrtCompileModel; - - internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) + internal class NativeMethods { + private static OrtCompileApi _compileApi; + + // + // Define the delegate signatures, and a static member for each to hold 
the marshaled function pointer. + // + // We populate the static members in the constructor of this class. + // + // The C# code will call the C++ API through the delegate instances in the static members. + // + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseModelCompilationOptions(IntPtr /* OrtModelCompilationOptions* */ options); + public DOrtReleaseModelCompilationOptions OrtReleaseModelCompilationOptions; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateModelCompilationOptionsFromSessionOptions( + IntPtr /* const OrtEnv* */ env, + IntPtr /* const OrtSessionOptions* */ sessionOptions, + out IntPtr /* OrtModelCompilationOptions** */ outOptions); + public DOrtCreateModelCompilationOptionsFromSessionOptions + OrtCreateModelCompilationOptionsFromSessionOptions; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelPath( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ inputModelPath); + public DOrtModelCompilationOptions_SetInputModelPath OrtModelCompilationOptions_SetInputModelPath; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelFromBuffer( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const void* */ inputModelData, + UIntPtr /* size_t */ inputModelDataSize); + public DOrtModelCompilationOptions_SetInputModelFromBuffer + OrtModelCompilationOptions_SetInputModelFromBuffer; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelPath( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ outputModelPath); + public DOrtModelCompilationOptions_SetOutputModelPath OrtModelCompilationOptions_SetOutputModelPath; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ externalInitializersFilePath, + UIntPtr /* size_t */ externalInitializerSizeThreshold); + public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile + OrtModelCompilationOptions_SetOutputModelExternalInitializersFile; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelBuffer( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* OrtAllocator* */ allocator, + ref IntPtr /* void** */ outputModelBufferPtr, + ref UIntPtr /* size_t* */ outputModelBufferSizePtr); + public DOrtModelCompilationOptions_SetOutputModelBuffer OrtModelCompilationOptions_SetOutputModelBuffer; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextEmbedMode( + IntPtr /* OrtModelCompilationOptions* */ options, + bool embedEpContextInModel); + public DOrtModelCompilationOptions_SetEpContextEmbedMode OrtModelCompilationOptions_SetEpContextEmbedMode; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCompileModel( + IntPtr /* const OrtEnv* */ env, + IntPtr /* const OrtModelCompilationOptions* */ modelOptions); + public DOrtCompileModel OrtCompileModel; + + 
[UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetFlags( + IntPtr /* OrtModelCompilationOptions* */ options, + uint flags); + public DOrtModelCompilationOptions_SetFlags OrtModelCompilationOptions_SetFlags; + + internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) + { #if NETSTANDARD2_0 - IntPtr compileApiPtr = getCompileApi(); - _compileApi = (OrtCompileApi)Marshal.PtrToStructure(compileApiPtr, typeof(OrtCompileApi)); + IntPtr compileApiPtr = getCompileApi(); + _compileApi = (OrtCompileApi)Marshal.PtrToStructure(compileApiPtr, typeof(OrtCompileApi)); #else - _compileApi = (OrtCompileApi)getCompileApi(); + _compileApi = (OrtCompileApi)getCompileApi(); #endif - OrtReleaseModelCompilationOptions = - (DOrtReleaseModelCompilationOptions)Marshal.GetDelegateForFunctionPointer( - _compileApi.ReleaseModelCompilationOptions, - typeof(DOrtReleaseModelCompilationOptions)); - - OrtCreateModelCompilationOptionsFromSessionOptions = - (DOrtCreateModelCompilationOptionsFromSessionOptions)Marshal.GetDelegateForFunctionPointer( - _compileApi.CreateModelCompilationOptionsFromSessionOptions, - typeof(DOrtCreateModelCompilationOptionsFromSessionOptions)); - - OrtModelCompilationOptions_SetInputModelPath = - (DOrtModelCompilationOptions_SetInputModelPath)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetInputModelPath, - typeof(DOrtModelCompilationOptions_SetInputModelPath)); - - OrtModelCompilationOptions_SetInputModelFromBuffer = - (DOrtModelCompilationOptions_SetInputModelFromBuffer)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetInputModelFromBuffer, - typeof(DOrtModelCompilationOptions_SetInputModelFromBuffer)); - - OrtModelCompilationOptions_SetOutputModelPath = - (DOrtModelCompilationOptions_SetOutputModelPath)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetOutputModelPath, - typeof(DOrtModelCompilationOptions_SetOutputModelPath)); - - OrtModelCompilationOptions_SetOutputModelExternalInitializersFile = - (DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetOutputModelExternalInitializersFile, - typeof(DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)); - - OrtModelCompilationOptions_SetOutputModelBuffer = - (DOrtModelCompilationOptions_SetOutputModelBuffer)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetOutputModelBuffer, - typeof(DOrtModelCompilationOptions_SetOutputModelBuffer)); - - OrtModelCompilationOptions_SetEpContextEmbedMode = - (DOrtModelCompilationOptions_SetEpContextEmbedMode)Marshal.GetDelegateForFunctionPointer( - _compileApi.ModelCompilationOptions_SetEpContextEmbedMode, - typeof(DOrtModelCompilationOptions_SetEpContextEmbedMode)); - - OrtCompileModel = - (DOrtCompileModel)Marshal.GetDelegateForFunctionPointer( - _compileApi.CompileModel, - typeof(DOrtCompileModel)); + OrtReleaseModelCompilationOptions = + (DOrtReleaseModelCompilationOptions)Marshal.GetDelegateForFunctionPointer( + _compileApi.ReleaseModelCompilationOptions, + typeof(DOrtReleaseModelCompilationOptions)); + + OrtCreateModelCompilationOptionsFromSessionOptions = + (DOrtCreateModelCompilationOptionsFromSessionOptions)Marshal.GetDelegateForFunctionPointer( + _compileApi.CreateModelCompilationOptionsFromSessionOptions, + typeof(DOrtCreateModelCompilationOptionsFromSessionOptions)); + + 
OrtModelCompilationOptions_SetInputModelPath = + (DOrtModelCompilationOptions_SetInputModelPath)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModelPath, + typeof(DOrtModelCompilationOptions_SetInputModelPath)); + + OrtModelCompilationOptions_SetInputModelFromBuffer = + (DOrtModelCompilationOptions_SetInputModelFromBuffer)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModelFromBuffer, + typeof(DOrtModelCompilationOptions_SetInputModelFromBuffer)); + + OrtModelCompilationOptions_SetOutputModelPath = + (DOrtModelCompilationOptions_SetOutputModelPath)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelPath, + typeof(DOrtModelCompilationOptions_SetOutputModelPath)); + + OrtModelCompilationOptions_SetOutputModelExternalInitializersFile = + (DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelExternalInitializersFile, + typeof(DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)); + + OrtModelCompilationOptions_SetOutputModelBuffer = + (DOrtModelCompilationOptions_SetOutputModelBuffer)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelBuffer, + typeof(DOrtModelCompilationOptions_SetOutputModelBuffer)); + + OrtModelCompilationOptions_SetEpContextEmbedMode = + (DOrtModelCompilationOptions_SetEpContextEmbedMode)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetEpContextEmbedMode, + typeof(DOrtModelCompilationOptions_SetEpContextEmbedMode)); + + OrtCompileModel = + (DOrtCompileModel)Marshal.GetDelegateForFunctionPointer( + _compileApi.CompileModel, + typeof(DOrtCompileModel)); + + OrtModelCompilationOptions_SetFlags = + (DOrtModelCompilationOptions_SetFlags)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetFlags, + typeof(DOrtModelCompilationOptions_SetFlags)); + + } } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs index 72c165df56418..bf576b54d8b45 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs @@ -8,6 +8,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests; using System; using System.Globalization; +using System.IO; using System.Runtime.InteropServices; using Xunit; @@ -61,6 +62,63 @@ public void BasicUsage() allocator.FreeMemory(bytePtr); } + + // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate + // any compiled EPContext nodes, so expect an ORT_FAIL error. + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var output_model_file = "should_not_generate.onnx"; + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(output_model_file); + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED); + + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("Unable to compile any nodes", ex.Message); + } + + Assert.False(File.Exists(output_model_file)); // Output file should not be generated. 
+ } + + // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS. + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var output_model_file = "squeezenet_ctx.onnx"; + + // Compile and generate an output model. + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(output_model_file); + compileOptions.CompileModel(); + Assert.True(File.Exists(output_model_file)); + + // Try to compile again with flag that prevents replacing an existing file. + // Expect failure. + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS); + + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("exists already", ex.Message); + } + + if (File.Exists(output_model_file)) + { + File.Delete(output_model_file); + } + } } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj index ee3c8c69aa2ae..54b9925710296 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj @@ -71,7 +71,7 @@ Include="$(NativeBuildOutputDir)\onnxruntime_providers_*.dll; $(NativeBuildOutputDir)\onnxruntime_providers_*.pdb; $(NativeBuildOutputDir)\custom_op_library*.dll; - $(NativeBuildOutputDir)\example_plugin_ep.dll"> + $(NativeBuildOutputDir)\example_plugin_ep*.dll"> PreserveNewest false diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index dbe7d9b85092a..7ba2f820e9bdb 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2037,9 +2037,9 @@ This version of the operator has been available since version 1 of the 'com.micr GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (https://github.com/onnx/onnx/blob/main/docs/Operators.md#gather) with differences: 1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`. - `block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, .. + `block_size` must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ... 2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants. - If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8. + If `zero_points` is not provided, the default value is 0 for int4/uint4, or 2^(bits-1) for uint8. 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. @@ -2946,29 +2946,20 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.MatMulNBits** - MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: - 1. Input B is a 2D constant Matrix. 
Its input feature count and output feature count are specified by attribute 'K' and 'N'. - 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. - And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. - 3. Input B's scale and zero point are specified by input scales and zero_points. - - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) - For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. - - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. - 4bit example: - |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) - - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. - 3bit example: - |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. - The last uint_8 may have some bits unused. - - - Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] - Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. - - [N * CeilDiv(n_blocks_per_col * bits, 8)] - If zero_points has same type as A, it's not packed and has the same shape as Scales. + MatMulNBits performs a matrix multiplication where the right-hand-side matrix (weights) is quantized to N bits. + + It is a fusion of two operations: + 1. Linear dequantization of the quantized weights using scale and (optionally) zero-point with formula: + dequantized_weight = (quantized_weight - zero_point) * scale + 2. Matrix multiplication between the input matrix A and the dequantized weight matrix. + + The weight matrix is a 2D constant matrix with the input feature count and output feature count specified by attributes 'K' and 'N'. + It is quantized block-wise along the K dimension with a block size specified by the 'block_size' attribute. + The block size must be a power of 2 and not smaller than 16 (e.g., 16, 32, 64, 128). Each block has its own scale and zero-point. + The quantization is performed using a bit-width specified by the 'bits' attribute, which can take values from 2 to 8. + + The quantized weights are stored in a bit-packed format along the K dimension, with each block being represented by a blob of uint8. + For example, for 4 bits, the first 4 bits are stored in the lower 4 bits of a byte, and the second 4 bits are stored in the higher 4 bits of a byte. #### Version @@ -2978,30 +2969,30 @@ This version of the operator has been available since version 1 of the 'com.micr
K : int (required)
-size of each input feature
+Input feature dimension of the weight matrix.
N : int (required)
-size of each output feature
+Output feature dimension of the weight matrix.
accuracy_level : int
The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) (default unset). It is used to control how input A is quantized or downcast internally while doing computation, for example: 0 means input A will not be quantized or downcast while doing computation. 4 means input A can be quantized with the same block_size to int8 internally from type T1.
-bits : int (required)
-number of bits used for weight quantization (default 4)
+bits : int
+Bit-width used to quantize the weights (valid range: 2~8)
block_size : int (required)
-number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
+Size of each quantization block along the K (input feature) dimension. Must be a power of two and ≥ 16 (e.g., 16, 32, 64, 128).
#### Inputs (3 - 6)
A : T1
-The input tensor, not quantized
+The input tensor, not quantized.
B : T2
-1 or 2 dimensional data blob
+Packed uint8 tensor of shape (N, k_blocks, blob_size), where k_blocks = ceil(K / block_size) and blob_size = (block_size * bits / 8). The quantized weights are stored in a bit-packed format along the K dimension, packed within each block_size.
scales : T1
-quantization scale
+Per-block scaling factors for dequantization with shape (N, k_blocks) and same data type as input A.
zero_points (optional) : T3
-quantization zero points
+Per-block zero point for dequantization. It can be either packed or unpacked: Packed (uint8) format has shape (N, ceil(k_blocks * bits / 8)), and it uses the same bit-packing method as Input B. Unpacked (same type as A) format has shape (N, k_blocks). If not provided, a default zero point is used: 2^(bits - 1) (e.g., 8 for 4-bit quantization, 128 for 8-bit).
g_idx (optional) : T4
-group_idx
+group_idx. This input is deprecated.
bias (optional) : T1
Bias to add to result. It should have shape [N].
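To make the packed layout described above concrete, here is a minimal NumPy sketch of the 4-bit case only (bits=4, unpacked zero points, and simple truncation of any padding in the last block). The function name and arguments are illustrative and are not part of the ONNX Runtime API:

```python
import numpy as np

def dequantize_matmulnbits_4bit(packed_b, scales, zero_points=None,
                                K=None, block_size=32):
    """Recover the float weight matrix from the layout described above (bits=4).

    packed_b:    uint8 array of shape (N, k_blocks, blob_size)
    scales:      float array of shape (N, k_blocks)
    zero_points: optional float array of shape (N, k_blocks); when omitted,
                 the default 2**(bits - 1) == 8 from the spec text is used.
    Returns the dequantized weight matrix of shape (N, K).
    """
    N, k_blocks, blob_size = packed_b.shape
    # Two 4-bit values per byte: lower nibble holds the first element,
    # upper nibble holds the second, as stated in the operator description.
    lo = packed_b & 0x0F
    hi = packed_b >> 4
    q = np.stack([lo, hi], axis=-1).reshape(N, k_blocks, blob_size * 2)
    if zero_points is None:
        zero_points = np.full((N, k_blocks), 8.0, dtype=scales.dtype)
    # dequantized_weight = (quantized_weight - zero_point) * scale, per block.
    w = (q.astype(scales.dtype) - zero_points[..., None]) * scales[..., None]
    # Flatten blocks along K and drop padding in the last block.
    w = w.reshape(N, k_blocks * block_size)[:, :K]
    return w

# With A of shape (M, K), the operator output corresponds to A @ w.T,
# i.e. an ordinary MatMul against the dequantized (N, K) weight.
```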
@@ -3016,12 +3007,12 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T1 : tensor(float), tensor(float16)
-
Constrain input and output types to float/half_float tensors.
-
T2 : tensor(uint8), tensor(int32)
-
Constrain quantized weight types to uint8/int32.
-
T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
-
Constrain quantized zero point types to uint8/int32/float16/float.
+
T1 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
T3 : tensor(uint8), tensor(float16), tensor(float), tensor(bfloat16)
+
Constrain quantized zero point types to uint8 or float tensors.
T4 : tensor(int32)
the index tensor.
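
To make the packed layout described for MatMulNBits above concrete, the following minimal C++ sketch dequantizes a 4-bit weight blob back to a row-major [N, K] float matrix using the (quantized - zero_point) * scale formula, low-nibble-first packing, and the default zero point of 2^(bits - 1). The function name, the std::vector containers, and the restriction to bits == 4 are illustrative assumptions; this is a reference sketch, not code from the ONNX Runtime tree.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// B:           bit-packed weights with shape [N, k_blocks, blob_size]
// scales:      per-block scales with shape [N, k_blocks]
// zero_points: optional bit-packed zero points with shape [N, ceil(k_blocks * 4 / 8)]
// Returns the dequantized weight matrix in row-major [N, K] order.
std::vector<float> DequantizeMatMulNBits4(const std::vector<uint8_t>& B,
                                          const std::vector<float>& scales,
                                          const std::vector<uint8_t>& zero_points,  // empty => default zero point
                                          int64_t N, int64_t K, int64_t block_size) {
  const int64_t k_blocks = (K + block_size - 1) / block_size;
  const int64_t blob_size = block_size * 4 / 8;
  const int64_t zp_row_bytes = (k_blocks * 4 + 7) / 8;
  std::vector<float> out(static_cast<size_t>(N * K), 0.0f);

  for (int64_t n = 0; n < N; ++n) {
    for (int64_t kb = 0; kb < k_blocks; ++kb) {
      const float scale = scales[static_cast<size_t>(n * k_blocks + kb)];

      // Default zero point is 2^(bits - 1) == 8 when the zero_points input is absent.
      uint8_t zero_point = 8;
      if (!zero_points.empty()) {
        const uint8_t packed = zero_points[static_cast<size_t>(n * zp_row_bytes + kb / 2)];
        zero_point = static_cast<uint8_t>((kb % 2 == 0) ? (packed & 0x0F) : (packed >> 4));
      }

      const uint8_t* blob = B.data() + (n * k_blocks + kb) * blob_size;
      for (int64_t i = 0; i < block_size; ++i) {
        const int64_t k = kb * block_size + i;
        if (k >= K) break;  // the last block along K may be padded
        const uint8_t packed = blob[i / 2];
        const uint8_t q = static_cast<uint8_t>((i % 2 == 0) ? (packed & 0x0F) : (packed >> 4));  // low nibble first
        out[static_cast<size_t>(n * K + k)] =
            (static_cast<float>(q) - static_cast<float>(zero_point)) * scale;
      }
    }
  }
  return out;
}
```

The op's output then corresponds to A multiplied by the transpose of this [N, K] matrix, with the optional bias broadcast-added along N.
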
@@ -5585,8 +5576,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16)
-
Constrain input and output types to float or half tensors.
+
T : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
U : tensor(float)
Constrain mean and inv_std_var to float tensors.
@@ -6354,5 +6345,3 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
- - diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1c30e67534a0c..8c1ab002bce67 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -943,7 +943,7 @@ Do not modify directly.* |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(float16), tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| @@ -959,7 +959,7 @@ Do not modify directly.* |QOrderedMatMul|*in* A:**Q**
*in* scale_A:**S**
*in* B:**Q**
*in* scale_B:**S**
*in* scale_Y:**S**
*in* bias:**S**
*in* C:**Q**
*in* scale_C:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float16)
**T2** = tensor(int8), tensor(uint8)| |QuantizeWithOrder|*in* input:**F**
*in* scale_input:**S**
*out* output:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| -|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |RelativePositionBias|*in* bias_table:**T**
*in* query_length:**U**
*in* key_length:**U**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -968,7 +968,7 @@ Do not modify directly.* |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| -|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| +|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_row_indices:**M**
*in* block_col_indices:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index 10f658f52e0d9..adfd341451aed 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -302,7 +302,7 @@ inline std::wstring ToWideString(const std::wstring& s) { return s; } inline std::string ToWideString(const std::string& s) { return s; } #endif -constexpr size_t kMaxStrLen = 2048; +constexpr size_t kMaxStrLen = 4096; // Returns whether `key` is in `container`. // Like C++20's map/set contains() member function. diff --git a/include/onnxruntime/core/common/spin_pause.h b/include/onnxruntime/core/common/spin_pause.h index 49b71e5567d3e..4d987f1d12977 100644 --- a/include/onnxruntime/core/common/spin_pause.h +++ b/include/onnxruntime/core/common/spin_pause.h @@ -3,26 +3,11 @@ #pragma once -#if defined(_M_AMD64) -#include -#endif - -#if defined(__x86_64__) -#include -#endif - namespace onnxruntime { - namespace concurrency { // Intrinsic to use in spin-loops - -inline void SpinPause() { -#if defined(_M_AMD64) || defined(__x86_64__) - _mm_pause(); -#endif -} +void SpinPause(); } // namespace concurrency - } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index ce7d4aaf652d0..15c15c6c143d2 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -203,7 +203,8 @@ class IAllocator { @returns std::unique_ptr with allocated memory and deleter. Throws if it cannot allocate memory. 
*/ template - static IAllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes) { + static IAllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes, + bool use_reserve = false) { ValidateAllocator(ort_allocator); size_t alloc_size = count_or_bytes; @@ -215,7 +216,12 @@ class IAllocator { alloc_size = ValidatedCalcMemSizeForArray(count_or_bytes, size); } - T* p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + T* p = nullptr; + if (use_reserve) { + p = static_cast(ort_allocator->Reserve(ort_allocator, alloc_size)); + } else { + p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + } ValidateAllocation(p, alloc_size); return IAllocatorUniquePtr{p, diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index 0b1cbe6afac79..0c9095f566fad 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -32,7 +32,11 @@ constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes"; constexpr const char* kCudaGraphEnable = "nv_cuda_graph_enable"; constexpr const char* kONNXBytestream = "nv_onnx_bytestream"; constexpr const char* kONNXBytestreamSize = "nv_onnx_bytestream_size"; +constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable"; } // namespace provider_option_names +namespace run_option_names { +constexpr const char* kProfileIndex = "nv_profile_index"; +} } // namespace nv } // namespace onnxruntime diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 687f74c94f154..9fb1eb9107774 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -21,6 +21,7 @@ struct OrtTensorRTProviderOptionsV2 { int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs size_t trt_max_workspace_size{0}; // maximum workspace size for TensorRT. Default is 0 means max device memory size int trt_fp16_enable{0}; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true + int trt_bf16_enable{0}; // enable TensorRT BF16 precision. Default 0 = false, nonzero = true int trt_int8_enable{0}; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true const char* trt_int8_calibration_table_name{nullptr}; // TensorRT INT8 calibration table name. int trt_int8_use_native_calibration_table{0}; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0d2da44971b3a..a2f518ae09a4b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3719,6 +3719,7 @@ struct OrtApi { * -# "gpu" * -# "htp": Default. * -# "saver" + * -# "ir" * "backend_path": File path to QNN backend library. Mutually exclusive with "backend_type". * "profiling_level": QNN profiling level. * Available options: @@ -3740,6 +3741,14 @@ struct OrtApi { * -# "low_power_saver" * -# "power_saver" * -# "sustained_high_performance" + * "dump_qnn_ir_dlc": Use the QnnIr backend library to write .dlc files for each subgraph dispatched to QNN. 
When + * enabled, inference results will be incorrect. Use only for debugging. + * -# "0": Default: disabled + * -# "1": enabled + * "dump_qnn_ir_dlc_dir": Set the directory into which QnnIr will be configured to write QNN graphs as .dlc files. + * Default is current working directory. + * "qnn_ir_backend_path": File path to the QnnIr backend library. If "dump_qnn_ir_dlc" is enabled, use this path + * instead of looking for the Ir backend in the standard location. * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and * may alter model/EP partitioning. Use only for debugging. @@ -5782,6 +5791,21 @@ struct OrtModelEditorApi { * ORT Compile API */ +/** \brief Flags representing options to enable when compiling a model. + */ +typedef enum OrtCompileApiFlags { + // Default. Do not enable any additional compilation options. + OrtCompileApiFlags_NONE = 0, + + // Force compilation to return an error (ORT_FAIL) if no nodes were compiled. + // Otherwise, a model with basic optimizations (ORT_ENABLE_BASIC) is still generated by default. + OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED = 1 << 0, + + // Force compilation to return an error (ORT_FAIL) if a file with the same filename as the output model exists. + // Otherwise, compilation will automatically overwrite the output file if it exists. + OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS = 1 << 1, +} OrtCompileApiFlags; + /** * \brief The OrtCompileApi struct provides functions to compile ONNX models. * @@ -5964,6 +5988,18 @@ struct OrtCompileApi { * \since Version 1.22. */ ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); + + /** \brief Sets flags from OrtCompileApiFlags that represent one or more boolean options to enable. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] flags bitwise OR of flags in OrtCompileApiFlags to enable. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, + size_t flags); }; ORT_RUNTIME_CLASS(Ep); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index bf8e57894d384..c7f81264115c6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1160,6 +1160,7 @@ struct ModelCompilationOptions : detail::Base { size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags }; /** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. 
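
As a usage illustration of the flags added here, the sketch below combines the two OrtCompileApiFlags values and passes them through the new ModelCompilationOptions::SetFlags wrapper. It assumes the pre-existing pieces of the 1.22 compile API (the Ort::ModelCompilationOptions constructor, the SetInputModelPath/SetOutputModelPath setters, and the Ort::CompileModel helper); only SetFlags and the two flag values come from this change, and the file names are placeholders.

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;  // register the compiling EP(s) here, e.g. via AppendExecutionProvider

  Ort::ModelCompilationOptions compile_options(env, session_options);
  compile_options.SetInputModelPath(ORT_TSTR("model.onnx"));
  compile_options.SetOutputModelPath(ORT_TSTR("model_ctx.onnx"));

  // Fail instead of silently producing a basic-optimized model when no nodes were compiled,
  // and fail instead of overwriting an existing output file.
  compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED |
                           OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS);

  Ort::Status status = Ort::CompileModel(env, compile_options);
  return status.IsOK() ? 0 : 1;
}
```
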
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 0d0b3198a8736..6cd52732b923b 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -832,6 +832,11 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetFlags(this->p_, flags)); + return *this; +} + namespace detail { template diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index c28c79f1e723e..fd813eff2f575 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -79,6 +79,9 @@ final class OnnxRuntime { /** The short name of the ONNX runtime QNN provider library */ static final String ONNXRUNTIME_LIBRARY_QNN_NAME = "onnxruntime_providers_qnn"; + /** The short name of the WebGPU DAWN library */ + static final String ONNXRUNTIME_LIBRARY_WEBGPU_DAWN_NAME = "webgpu_dawn"; + /** The OS & CPU architecture string */ private static final String OS_ARCH_STR = initOsArch(); @@ -162,6 +165,10 @@ static synchronized void init() throws IOException { // the ONNX Runtime native library will load it extractProviderLibrary(ONNXRUNTIME_LIBRARY_SHARED_NAME); + // Extract and prepare the Dawn shared library (if present) but don't try to load it, + // the ONNX Runtime native library will load it + extractProviderLibrary(ONNXRUNTIME_LIBRARY_WEBGPU_DAWN_NAME); + if (!isAndroid()) { load(ONNXRUNTIME_LIBRARY_NAME); } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 7246738fd4406..c3f9d345078fe 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -2133,6 +2133,9 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet provid case QNN: options.addQnn(Collections.singletonMap("backend_type", "cpu")); break; + case WEBGPU: + options.addWebGPU(Collections.emptyMap()); + break; case VITIS_AI: case RK_NPU: case MI_GRAPH_X: diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt index 52af5dc48a21a..aedb1e35158ef 100644 --- a/js/node/CMakeLists.txt +++ b/js/node/CMakeLists.txt @@ -92,9 +92,21 @@ if (WIN32) endif() message(STATUS "onnxruntime dist dir: ${ONNXRUNTIME_WIN_BIN_DIR}") endif() + +if (APPLE) + if (${ONNXRUNTIME_GENERATOR} MATCHES "Xcode") + set(ONNXRUNTIME_MAC_BIN_DIR ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}) + else() + set(ONNXRUNTIME_MAC_BIN_DIR ${ONNXRUNTIME_BUILD_DIR}) + endif() + message(STATUS "onnxruntime dist dir: ${ONNXRUNTIME_MAC_BIN_DIR}") +endif() + # add libraries if (WIN32) target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_WIN_BIN_DIR}) +elseif (APPLE) + target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_MAC_BIN_DIR}) else() target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_BUILD_DIR}) endif() @@ -114,7 +126,7 @@ if (WIN32) file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/onnxruntime.dll DESTINATION ${dist_folder}) elseif 
(APPLE) - file(COPY ${ONNXRUNTIME_BUILD_DIR}/libonnxruntime.dylib + file(COPY ${ONNXRUNTIME_MAC_BIN_DIR}/libonnxruntime.dylib DESTINATION ${dist_folder} FOLLOW_SYMLINK_CHAIN) elseif (UNIX) file(COPY ${ONNXRUNTIME_BUILD_DIR}/libonnxruntime.so diff --git a/js/node/README.md b/js/node/README.md index b8414546c4729..272ec6ef561c2 100644 --- a/js/node/README.md +++ b/js/node/README.md @@ -27,13 +27,13 @@ The following table lists the supported versions of ONNX Runtime Node.js binding | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 | | ------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | -| WebGPU | ✔️ \[1] | ✔️ \[1] | ❌ \[2] | ❌ \[2] | ✔️ \[1] | ✔️ \[1] | +| WebGPU | ✔️ \[1] | ✔️ \[1] | ✔️ \[1] | ❌ \[2] | ✔️ \[1] | ✔️ \[1] | | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ | | CUDA | ❌ | ❌ | ✔️\[3] | ❌ | ❌ | ❌ | | CoreML | ❌ | ❌ | ❌ | ❌ | ✔️ | ✔️ | - \[1]: WebGPU support is currently experimental. -- \[2]: WebGPU support is not available on Linux x64 and arm64 yet in the pre-built binaries. +- \[2]: WebGPU support is not available on Linux arm64 yet in the pre-built binaries. - \[3]: CUDA v12. See [CUDA EP Installation](#cuda-ep-installation) for details. To use on platforms without pre-built binaries, you can build Node.js binding from source and consume it by `npm install /js/node/`. See also [instructions](https://onnxruntime.ai/docs/build/inferencing.html#apis-and-language-bindings) for building ONNX Runtime Node.js binding locally. diff --git a/js/node/script/install-metadata.js b/js/node/script/install-metadata.js index e0186ec45d1b4..41b905ba88eaf 100644 --- a/js/node/script/install-metadata.js +++ b/js/node/script/install-metadata.js @@ -20,15 +20,15 @@ const metadata = { 'linux/x64:cuda12': { './libonnxruntime_providers_cuda.so': { package: 'nuget:linux/x64:cuda12', - path: 'runtimes/win-x64/native/libonnxruntime_providers_cuda.so', + path: 'runtimes/linux-x64/native/libonnxruntime_providers_cuda.so', }, './libonnxruntime_providers_shared.so': { package: 'nuget:linux/x64:cuda12', - path: 'runtimes/win-x64/native/libonnxruntime_providers_shared.so', + path: 'runtimes/linux-x64/native/libonnxruntime_providers_shared.so', }, './libonnxruntime_providers_tensorrt.so': { package: 'nuget:linux/x64:cuda12', - path: 'runtimes/win-x64/native/libonnxruntime_providers_tensorrt.so', + path: 'runtimes/linux-x64/native/libonnxruntime_providers_tensorrt.so', }, }, }, diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 981a684154df1..d9a030f320c6c 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -25,6 +25,7 @@ platforms. 
Check the [WebNN status](https://webmachinelearning.github.io/webnn-s | Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | | | Concat | ai.onnx(7-10, 11-12, 13+) | concat | | | Conv | ai.onnx(7-10, 11+) | conv2d | Only supports 3-D or 4-D input and 'W' (weight) | +| ConvInteger | ai.onnx(10+) | cast, conv2d, dequantizeLinear | Only supports 3-D or 4-D input and 'W' (weight) | | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | Only supports 3-D or 4-D input and 'W' (weight) | | Cos | ai.onnx(7+) | cos | | | CumSum | ai.onnx(11-13, 14+) | cumulativeSum | 'axis' input should be a constant | diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 3e1f1be22efa2..c8e77d14117bf 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -48,8 +48,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt if (inputs.length === 4) { const zeroPoints = inputs[3]; const zeroPointsShape = zeroPoints.dims; + + // This assumes zero points are packed. + // Unpack format (zero point has same data type and shape as scale) is not supported by webgpu. const expectedZeroPointsSize = - attributes.bits > 4 ? attributes.n * nBlocksPerCol : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + attributes.n * (attributes.bits === 8 ? nBlocksPerCol : Math.floor((nBlocksPerCol * attributes.bits + 7) / 8)); if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { throw new Error('zeroPoints input size error.'); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 7c62d1f7182a7..2056416873df5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -152,7 +152,9 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt // calculate final value for each element in the row for (var col = lindex; col < cols; col += wg) { - let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared; + var value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared; + // max operation protects against NaN since all values should be >=0 + value = max(value, ${valueType}(0.0)); setValue(row, col, row_stride, value); } }`; diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 6a9432c2b5acd..2ea883f739c52 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -644,7 +644,12 @@ async function main() { isProduction: true, outputName: 'ort.wasm.bundle', format: 'esm', - define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'true', + }, }); // ort.webgl[.min].[m]js await addAllWebBuildTasks({ diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc index 63e0a0ed52879..f6671fdab7089 100644 --- a/js/web/test/data/ops/matmulnbits.jsonc +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -35,7 +35,7 @@ ] }, { - "dims": [8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] } @@ -92,12 +92,12 @@ ] }, { - "dims": [8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] }, { - "dims": [8], + "dims": [8, 1], "type": "uint8", "data": [248, 249, 250, 251, 252, 253, 254, 255] } @@ -163,7 +163,7 @@ ] }, { - "dims": [16], + "dims": [8, 2], "type": "float32", 
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] } @@ -229,12 +229,12 @@ ] }, { - "dims": [16], + "dims": [8, 2], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] }, { - "dims": [8], + "dims": [8, 1], "type": "uint8", "data": [0, 1, 2, 3, 4, 5, 6, 7] } @@ -309,7 +309,7 @@ ] }, { - "dims": [24], + "dims": [8, 3], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] } @@ -384,12 +384,12 @@ ] }, { - "dims": [24], + "dims": [8, 3], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] }, { - "dims": [16], + "dims": [8, 2], "type": "uint8", "data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] } @@ -474,7 +474,7 @@ ] }, { - "dims": [32], + "dims": [8, 4], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -562,7 +562,7 @@ ] }, { - "dims": [32], + "dims": [8, 4], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -570,7 +570,7 @@ ] }, { - "dims": [16], + "dims": [8, 2], "type": "uint8", "data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] } @@ -604,7 +604,7 @@ ], "cases": [ { - "name": "MatMulNBits; K=80, N=8, block_size=16, bits=4; asymmetric", + "name": "MatMulNBits; K=80, N=8, block_size=16, bits=4; asymmetric; 1D scale and zero point", "inputs": [ { "data": [ @@ -742,7 +742,7 @@ ] }, { - "dims": [16], + "dims": [16, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] } @@ -822,12 +822,12 @@ ] }, { - "dims": [16], + "dims": [16, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] }, { - "dims": [16], + "dims": [16, 1], "type": "uint8", "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] } @@ -925,7 +925,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1084,7 +1084,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1092,7 +1092,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "uint8", "data": [ 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, @@ -1251,7 +1251,7 @@ ] }, { - "dims": [32], + "dims": [16, 2], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1353,7 +1353,7 @@ ] }, { - "dims": [32], + "dims": [16, 2], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1361,7 +1361,7 @@ ] }, { - "dims": [16], + "dims": [16, 1], "type": "uint8", "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] } @@ -1496,7 +1496,7 @@ ] }, { - "dims": [64], + "dims": [32, 2], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -1701,7 +1701,7 @@ ] }, { - "dims": [64], + "dims": [32, 2], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 
24, 25, 26, 27, 28, @@ -1710,7 +1710,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "uint8", "data": [ 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, @@ -1912,7 +1912,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -2112,7 +2112,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "float32", "data": [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, @@ -2120,7 +2120,7 @@ ] }, { - "dims": [32], + "dims": [32, 1], "type": "uint8", "data": [ 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, @@ -2259,7 +2259,7 @@ ] }, { - "dims": [1, 8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] } @@ -2322,7 +2322,7 @@ ] }, { - "dims": [1, 8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] } @@ -2386,12 +2386,12 @@ ] }, { - "dims": [16], + "dims": [16, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] }, { - "dims": [16], + "dims": [16, 1], "type": "uint8", "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] } @@ -2458,7 +2458,7 @@ ] }, { - "dims": [8], + "dims": [8, 1], "type": "float32", "data": [0, 1, 2, 3, 4, 5, 6, 7] } diff --git a/js/web/test/e2e/src/cjs-js/main.js b/js/web/test/e2e/src/cjs-js/main.js index c9b8d3e85455d..5eea342fdcae7 100644 --- a/js/web/test/e2e/src/cjs-js/main.js +++ b/js/web/test/e2e/src/cjs-js/main.js @@ -6,13 +6,15 @@ const ort = require('onnxruntime-web/wasm'); const { setupMultipleThreads, testInferenceAndValidate } = require('./shared'); -if (typeof SharedArrayBuffer === 'undefined') { - it('Browser package consuming test - single-thread - [js][commonjs]', async function () { - await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); - }); -} else { - it('Browser package consuming test - multi-thread - [js][commonjs]', async function () { - setupMultipleThreads(ort); - await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); - }); +if (typeof it !== 'undefined') { + if (typeof SharedArrayBuffer === 'undefined') { + it('Browser package consuming test - single-thread - [js][commonjs]', async function () { + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); + }); + } else { + it('Browser package consuming test - multi-thread - [js][commonjs]', async function () { + setupMultipleThreads(ort); + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); + }); + } } diff --git a/js/web/test/e2e/src/esm-js/main.js b/js/web/test/e2e/src/esm-js/main.js index 7687b8b731878..54744a2a4b16f 100644 --- a/js/web/test/e2e/src/esm-js/main.js +++ b/js/web/test/e2e/src/esm-js/main.js @@ -6,13 +6,15 @@ import * as ort from 'onnxruntime-web/wasm'; import { setupMultipleThreads, default as testInferenceAndValidate } from './shared.js'; -if (typeof SharedArrayBuffer === 'undefined') { - it('Browser package consuming test - single-thread - [js][esm]', async function () { - await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); - }); -} else { - it('Browser package consuming test - multi-thread - [js][esm]', async function () { - setupMultipleThreads(ort); - await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); - }); +if (typeof it !== 'undefined') { + if (typeof 
SharedArrayBuffer === 'undefined') { + it('Browser package consuming test - single-thread - [js][esm]', async function () { + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); + }); + } else { + it('Browser package consuming test - multi-thread - [js][esm]', async function () { + setupMultipleThreads(ort); + await testInferenceAndValidate(ort, { executionProviders: ['wasm'] }); + }); + } } diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 6ef0707f4b7c6..56268369bf98a 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -30,6 +30,7 @@ NodeArg, # noqa: F401 OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 + OrtCompileApiFlags, # noqa: F401 OrtEpDevice, # noqa: F401 OrtExecutionProviderDevicePolicy, # noqa: F401 OrtHardwareDevice, # noqa: F401 diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 8f013a1426ef8..65e8808190da3 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -15,6 +15,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" namespace onnxruntime { namespace contrib { @@ -677,11 +678,17 @@ template Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(InputIndex::A); + // If B is prepacked, B would have been removed from the context + const bool is_b_prepacked = packed_b_size_ > 0; + const Tensor* b = is_b_prepacked ? nullptr : ctx->Input(InputIndex::B); const Tensor* scales = scales_are_packed_ ? 
nullptr : ctx->Input(InputIndex::scales); const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* reorder_idx = ctx->Input(InputIndex::g_idx); const Tensor* bias = ctx->Input(InputIndex::bias); + ORT_RETURN_IF_ERROR(matmul_nbits_helper::CheckInputs( + a, b, scales, zero_points, reorder_idx, bias, N_, K_, block_size_, nbits_)); + TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); @@ -713,25 +720,22 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { } } - // If B is prepacked, B would have been removed from the context - const Tensor* b = ctx->Input(InputIndex::B); return ComputeBUnpacked(a, b, scales, zero_points, reorder_idx, bias, y, allocator, thread_pool, helper); } -#define REGISTER_MatMulNBits(T1) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - MatMulNBits, \ - kMSDomain, \ - 1, \ - T1, \ - kCpuExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType()}) \ - .TypeConstraint("T4", DataTypeImpl::GetTensorType()), \ +#define REGISTER_MatMulNBits(T1) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MatMulNBits, \ + kMSDomain, \ + 1, \ + T1, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}) \ + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), \ MatMulNBits); REGISTER_MatMulNBits(float); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h new file mode 100644 index 0000000000000..80a360ebb1b29 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "core/util/shape_checker.h" + +namespace onnxruntime { +namespace contrib { +namespace matmul_nbits_helper { + +template +Status CheckInputs(const T* /*activation*/, + const T* quantized_weight, + const T* scales, + const T* zero_points, + const T* group_index, + const T* bias, + int64_t n, + int64_t k, + int64_t block_size, + int64_t bits) { + // activation (A) + // quantized_weight (B) : (N, k_blocks, blob_size), or null after prepacking. + // k_blocks = (K + block_size - 1) / block_size + // blob_size = block_size * bits / 8 + // scales : (N, k_blocks) + // zero_points : (N, (k_blocks * bits + 7) / 8) for uint8, (N, k_blocks) for other types, or null + // group_index : (K) or (k_blocks * block_size), or null + // bias : (N), or null + // Note that scales and zero_points can be 1D for backward compatibility. + if (bits != 4 && bits != 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bits should be 4 or 8, got ", bits); + } + + if (block_size < 16 || (block_size & (block_size - 1)) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_size must be a power of 2, and >= 16. 
Got ", block_size); + } + + int64_t k_blocks = (k + block_size - 1) / block_size; + int64_t blob_size = block_size * bits / 8; + + ASSERT_TENSOR_SHAPE(quantized_weight, make_shape(n, k_blocks, blob_size)); + + // 1D shape is for backward compatibility for existing models. + ASSERT_TENSOR_SHAPE_2(scales, make_shape(n * k_blocks), make_shape(n, k_blocks)); + + if (zero_points != nullptr) { + if (zero_points->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + const int64_t zero_point_blob_size = (k_blocks * bits + 7) / 8; + + ASSERT_TENSOR_SHAPE_2(zero_points, make_shape(n * zero_point_blob_size), make_shape(n, zero_point_blob_size)); + } else { + if (zero_points->GetElementType() != scales->GetElementType()) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'zero_points' and 'scales' should have the same data type when zero_points is not uint8"); + } + + ASSERT_TENSOR_SHAPE_2(zero_points, make_shape(n * k_blocks), make_shape(n, k_blocks)); + } + } + + // Group_index shall be 1D of K, or K padded to multiple of block_size + ASSERT_TENSOR_SHAPE_2(group_index, make_shape(k), make_shape(k_blocks * block_size)); + + ASSERT_TENSOR_SHAPE(bias, make_shape(n)); + + return Status::OK(); +} + +} // namespace matmul_nbits_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h index 64bd2b7b1855e..0c1d6a95dff20 100644 --- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h +++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h @@ -17,34 +17,28 @@ class IConsoleDumper { virtual ~IConsoleDumper() {} void Disable() { is_enabled_ = false; } bool IsEnabled() const { return is_enabled_; } - virtual void Print(const char* name, const float* tensor, int dim0, int dim1) const = 0; - virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const = 0; - virtual void Print(const char* name, const size_t* tensor, int dim0, int dim1) const = 0; - virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const = 0; - virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const = 0; - - virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const = 0; - virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const = 0; - virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const = 0; - virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const = 0; - - virtual void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; - virtual void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; - virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; - virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; - - virtual void Print(const char* name, const int32_t* tensor, gsl::span& dims) const = 0; - virtual void Print(const char* name, const int64_t* tensor, gsl::span& dims) const = 0; - virtual void Print(const char* name, const float* tensor, gsl::span& dims) const = 0; - virtual void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const = 0; virtual void Print(const char* name, const Tensor& value) const = 0; virtual void Print(const char* 
name, const OrtValue& value) const = 0; - virtual void Print(const char* name, int index, bool end_line) const = 0; - virtual void Print(const char* name, const std::string& value, bool end_line) const = 0; - virtual void Print(const std::string& value) const = 0; +#define TENSOR_DUMPER_PRINT_TYPE(dtype) \ + virtual void Print(const char* name, const dtype* tensor, int dim0, int dim1) const = 0; \ + virtual void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const = 0; \ + virtual void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const = 0; \ + virtual void Print(const char* name, const dtype* tensor, gsl::span& dims) const = 0; + + TENSOR_DUMPER_PRINT_TYPE(int8_t) + TENSOR_DUMPER_PRINT_TYPE(uint8_t) + TENSOR_DUMPER_PRINT_TYPE(int32_t) + TENSOR_DUMPER_PRINT_TYPE(int64_t) + TENSOR_DUMPER_PRINT_TYPE(float) + TENSOR_DUMPER_PRINT_TYPE(MLFloat16) + TENSOR_DUMPER_PRINT_TYPE(BFloat16) + TENSOR_DUMPER_PRINT_TYPE(UInt4x2) + TENSOR_DUMPER_PRINT_TYPE(Int4x2) +#undef TENSOR_DUMPER_PRINT_TYPE + protected: bool is_enabled_; }; diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc index 7755f9505d99d..947311b89fbfd 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc @@ -62,6 +62,54 @@ void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di } } +template +void DumpCpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, int dim3) { + std::unique_lock lock(s_mutex); + + if (s_output_thread_id) + std::cout << "Thread ID:" << std::this_thread::get_id() << std::endl; + + if (nullptr != name) { + std::cout << std::string(name) << std::endl; + } + + if (onnxruntime::utils::kDefaultSnippetThreshold < static_cast(dim0 * dim1 * dim2 * dim3)) { + for (int i = 0; i < dim0; i++) { + std::cout << "[" << i << "]:" << std::endl; + onnxruntime::utils::PrintCpuTensorSnippet(tensor + i * dim1 * dim2 * dim3, dim1, dim2, dim3, + onnxruntime::utils::kDefaultSnippetEdgeItems); + } + } else { + for (int i = 0; i < dim0; i++) { + std::cout << "[" << i << "]:" << std::endl; + onnxruntime::utils::PrintCpuTensorFull(tensor + i * dim1 * dim2 * dim3, dim1, dim2, dim3); + } + } +} + +void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2, int dim3) { + MLDataType dataType = tensor.DataType(); + if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3); + } else { + assert(0); + } +} + void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) { MLDataType dataType = tensor.DataType(); if (dataType == 
DataTypeImpl::GetType()) { @@ -72,6 +120,14 @@ void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, i DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); } else if (dataType == DataTypeImpl::GetType()) { DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1, dim2); } else { assert(0); } @@ -87,11 +143,23 @@ void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) { DumpCpuTensor(name, tensor.Data(), dim0, dim1); } else if (dataType == DataTypeImpl::GetType()) { DumpCpuTensor(name, tensor.Data(), dim0, dim1); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1); + } else if (dataType == DataTypeImpl::GetType()) { + DumpCpuTensor(name, tensor.Data(), dim0, dim1); } else { assert(0); } } +void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0) { + DumpCpuTensor(name, tensor, 1, dim0); +} + void DumpCpuTensor(const char* name, const Tensor& tensor) { const auto& shape = tensor.Shape(); @@ -101,21 +169,33 @@ void DumpCpuTensor(const char* name, const Tensor& tensor) { std::cout << "Shape:" << shape << std::endl; size_t num_dims = shape.NumDimensions(); - if (num_dims >= 3) { - int dim0 = static_cast(shape.SizeToDimension(num_dims - 2)); - int dim1 = static_cast(shape[num_dims - 2]); - int dim2 = static_cast(shape[num_dims - 1]); + if (num_dims >= 4) { + int dim0 = static_cast(shape.SizeToDimension(num_dims - 4)); + int dim1 = static_cast(shape[num_dims - 3]); + int dim2 = static_cast(shape[num_dims - 2]); + int dim3 = static_cast(shape[num_dims - 1]); + DumpCpuTensor(nullptr, tensor, dim0, dim1, dim2, dim3); + return; + } + + if (num_dims == 3) { + int dim0 = static_cast(shape[0]); + int dim1 = static_cast(shape[1]); + int dim2 = static_cast(shape[2]); DumpCpuTensor(nullptr, tensor, dim0, dim1, dim2); return; } - auto num_items = shape.Size(); - size_t num_rows = 1; - if (num_dims > 1) { - num_rows = static_cast(shape[0]); + if (num_dims == 2) { + int dim0 = static_cast(shape[0]); + int dim1 = static_cast(shape[1]); + DumpCpuTensor(nullptr, tensor, dim0, dim1); + return; + } + + if (num_dims == 1) { + DumpCpuTensor(nullptr, tensor, static_cast(shape[0])); } - size_t row_size = num_items / num_rows; - DumpCpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } CpuTensorConsoleDumper::CpuTensorConsoleDumper() { @@ -133,84 +213,6 @@ void CpuTensorConsoleDumper::Print(const std::string& value) const { std::cout << value << std::endl; } -void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const 
size_t* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1); -} - -void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1, dim2); -} - -void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1, dim2); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1, dim2); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0, dim1, dim2); -} - -void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0 * dim1, dim2, dim3); -} - -void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0 * dim1, dim2, dim3); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0 * dim1, dim2, dim3); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (!is_enabled_) - return; - DumpCpuTensor(name, tensor, dim0 * dim1, dim2, dim3); -} - void CpuTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { if (!is_enabled_) return; @@ -222,45 +224,33 @@ void CpuTensorConsoleDumper::Print(const char* name, const OrtValue& value) cons Print(name, tensor); } -void CpuTensorConsoleDumper::Print(const char* name, int index, bool end_line) const { - if (!is_enabled_) - return; - - std::unique_lock lock(s_mutex); - std::cout << std::string(name) << "[" << index << "]"; - - if (end_line) { - std::cout << std::endl; +#define TENSOR_DUMPER_PRINT_TYPE(dtype) \ + void CpuTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1) const { \ + if (is_enabled_) \ + DumpCpuTensor(name, tensor, dim0, dim1); \ + } \ + void CpuTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const { \ + if (is_enabled_) \ + DumpCpuTensor(name, tensor, dim0, dim1, dim2); \ + } \ + void CpuTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const { \ + if (is_enabled_) \ + DumpCpuTensor(name, tensor, dim0, dim1, dim2, dim3); \ + } \ + void CpuTensorConsoleDumper::Print(const char* name, const dtype* tensor, gsl::span& dims) const { \ + PrintTensorByDims(this, name, tensor, dims); \ } -} -void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, bool end_line) const { - if (!is_enabled_) - return; 
- - std::unique_lock lock(s_mutex); - std::cout << std::string(name) << "=" << value; - - if (end_line) { - std::cout << std::endl; - } -} - -void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} - -void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} - -void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} - -void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} +TENSOR_DUMPER_PRINT_TYPE(int8_t) +TENSOR_DUMPER_PRINT_TYPE(uint8_t) +TENSOR_DUMPER_PRINT_TYPE(int32_t) +TENSOR_DUMPER_PRINT_TYPE(int64_t) +TENSOR_DUMPER_PRINT_TYPE(float) +TENSOR_DUMPER_PRINT_TYPE(MLFloat16) +TENSOR_DUMPER_PRINT_TYPE(BFloat16) +TENSOR_DUMPER_PRINT_TYPE(UInt4x2) +TENSOR_DUMPER_PRINT_TYPE(Int4x2) +#undef TENSOR_DUMPER_PRINT_TYPE #else @@ -270,68 +260,33 @@ CpuTensorConsoleDumper::CpuTensorConsoleDumper() { void CpuTensorConsoleDumper::Print(const std::string&) const { } -void CpuTensorConsoleDumper::Print(const char*, const float*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const float*, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int, int) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int, int) const { -} - void CpuTensorConsoleDumper::Print(const char*, const Tensor&) const { } void CpuTensorConsoleDumper::Print(const char*, const OrtValue&) const { } -void CpuTensorConsoleDumper::Print(const char*, int, bool) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span&) const { -} - -void CpuTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span&) const { -} +#define TENSOR_DUMPER_PRINT_TYPE(dtype) \ + void CpuTensorConsoleDumper::Print(const char*, const dtype*, int, int) const { \ + } \ + void CpuTensorConsoleDumper::Print(const char*, const dtype*, int, int, int) const { \ + } \ + void CpuTensorConsoleDumper::Print(const char*, const dtype*, int, int, int, int) const { \ + } \ + void CpuTensorConsoleDumper::Print(const char*, const dtype*, gsl::span&) const { \ + } -void CpuTensorConsoleDumper::Print(const char*, const float*, gsl::span&) const { -} +TENSOR_DUMPER_PRINT_TYPE(int8_t) +TENSOR_DUMPER_PRINT_TYPE(uint8_t) +TENSOR_DUMPER_PRINT_TYPE(int32_t) 
+TENSOR_DUMPER_PRINT_TYPE(int64_t) +TENSOR_DUMPER_PRINT_TYPE(float) +TENSOR_DUMPER_PRINT_TYPE(MLFloat16) +TENSOR_DUMPER_PRINT_TYPE(BFloat16) +TENSOR_DUMPER_PRINT_TYPE(UInt4x2) +TENSOR_DUMPER_PRINT_TYPE(Int4x2) +#undef TENSOR_DUMPER_PRINT_TYPE -void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span&) const { -} #endif } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h index 6fc4dfd4a0671..6de0439d7f8ba 100644 --- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h @@ -14,31 +14,9 @@ class CpuTensorConsoleDumper : public IConsoleDumper { public: CpuTensorConsoleDumper(); virtual ~CpuTensorConsoleDumper() {} - void Print(const char* name, const float* tensor, int dim0, int dim1) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; - void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; - - void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; - - void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; - - void Print(const char* name, const int32_t* tensor, gsl::span& dims) const override; - void Print(const char* name, const int64_t* tensor, gsl::span& dims) const override; - void Print(const char* name, const float* tensor, gsl::span& dims) const override; - void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const override; void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; - void Print(const char* name, int index, bool end_line) const override; - void Print(const char* name, const std::string& value, bool end_line) const override; void Print(const std::string& value) const override; @@ -47,6 +25,23 @@ class CpuTensorConsoleDumper : public IConsoleDumper { void Print(const char* name, const std::vector& vec, size_t max_count = 0) const { this->Print(name, vec.data(), 1, static_cast(std::min(max_count, vec.size()))); } + +#define TENSOR_DUMPER_PRINT_TYPE(dtype) \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1) const override; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const override; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const override; \ + void Print(const char* name, const dtype* tensor, gsl::span& dims) const override; + + TENSOR_DUMPER_PRINT_TYPE(int8_t) + TENSOR_DUMPER_PRINT_TYPE(uint8_t) + TENSOR_DUMPER_PRINT_TYPE(int32_t) + 
TENSOR_DUMPER_PRINT_TYPE(int64_t) + TENSOR_DUMPER_PRINT_TYPE(float) + TENSOR_DUMPER_PRINT_TYPE(MLFloat16) + TENSOR_DUMPER_PRINT_TYPE(BFloat16) + TENSOR_DUMPER_PRINT_TYPE(UInt4x2) + TENSOR_DUMPER_PRINT_TYPE(Int4x2) +#undef TENSOR_DUMPER_PRINT_TYPE }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc index 6303858b9bd48..0c4d42b328510 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.cc +++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc @@ -44,7 +44,8 @@ namespace cuda { #define UNARY_ACTIVATION_OP_HFD(name, ver, domain) \ UNARY_ACTIVATION_OP_TYPED(name, ver, domain, MLFloat16) \ UNARY_ACTIVATION_OP_TYPED(name, ver, domain, float) \ - UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double) + UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double) \ + UNARY_ACTIVATION_OP_TYPED(name, ver, domain, BFloat16) UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain); diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu index 36f33fbb24c18..a11691d22d8be 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu @@ -62,7 +62,8 @@ struct OP_QuickGelu : public CtxQuickGelu { #define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \ SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \ SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \ - SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double) + SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double) \ + SPECIALIZED_UNARY_ACTIVATION_IMPL(name, BFloat16) #define UNARY_ACTIVATION_OP_NAME(name) \ UNARY_ACTIVATION_IMPL(name); \ diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 8d8f735e3ed34..5aeda0f74e92b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -142,15 +142,16 @@ template ; + typename Attention::Params p; { // set parameters p.query_ptr = const_cast(reinterpret_cast(params.query)); p.key_ptr = const_cast(reinterpret_cast(params.key)); p.value_ptr = const_cast(reinterpret_cast(params.value)); p.attn_bias_ptr = const_cast(reinterpret_cast(params.attn_bias)); - p.seqstart_q_ptr = params.seqstart_q_ptr; - p.seqstart_k_ptr = params.seqstart_k_ptr; - p.seqlen_k_ptr = params.seqlen_k_ptr; + p.seqstart_q_ptr = const_cast(params.seqstart_q_ptr); + p.seqstart_k_ptr = const_cast(params.seqstart_k_ptr); + p.seqlen_k_ptr = const_cast(params.seqlen_k_ptr); p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward p.output_ptr = reinterpret_cast(params.output); @@ -260,7 +261,7 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { params.v_head_size % AlignedAK::kAlignmentV == 0; DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { - LaunchCutlassFmha(params); + LaunchCutlassFmha(params); })); #if defined(_MSC_VER) && !defined(__clang__) diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h index f35d6c2e6c8dc..41691d823f528 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h @@ -1,5 +1,5 @@ 
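For orientation on the dump_tensor changes above: TENSOR_DUMPER_PRINT_TYPE collapses the former per-type Print overloads into one macro instantiated per supported dtype. A minimal usage sketch follows; the include path and the gsl::span element type are assumptions (the span's template argument is stripped in the text above), so treat it as illustrative rather than authoritative.

// Sketch: an instantiation such as TENSOR_DUMPER_PRINT_TYPE(float) declares the 2-D/3-D/4-D
// and span-based Print overloads, e.g.
//   void Print(const char* name, const float* tensor, int dim0, int dim1) const override;
// A caller can then dump a small tensor like this (illustrative only):
#include "contrib_ops/cpu/utils/dump_tensor.h"  // path assumed from the diff header above

void DumpExampleUsage(const float* data) {  // hypothetical helper, not part of the change
  onnxruntime::contrib::CpuTensorConsoleDumper dumper;
  dumper.Print("attention_probs", data, /*dim0=*/2, /*dim1=*/4);
}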
/*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,6 +38,7 @@ #include #include +#include #include #include "cutlass/fast_math.h" @@ -71,8 +72,6 @@ #include "41_fused_multi_head_attention/gemm_kernel_utils.h" #include "41_fused_multi_head_attention/transform/tile_smem_loader.h" -#include - using namespace gemm_kernel_utils; namespace { @@ -174,9 +173,10 @@ struct AttentionKernel { scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim] scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value] scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] - const int32_t* seqstart_q_ptr = nullptr; - const int32_t* seqstart_k_ptr = nullptr; - const int32_t* seqlen_k_ptr = nullptr; + int32_t* seqstart_q_ptr = nullptr; + int32_t* seqstart_k_ptr = nullptr; + + int32_t* seqlen_k_ptr = nullptr; uint32_t causal_diagonal_offset = 0; // Output tensors @@ -1105,15 +1105,15 @@ struct AttentionKernel { using EpilogueOutputOp = typename cutlass::epilogue:: thread::MemoryEfficientAttentionNormalize< typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, output_t, output_accum_t>::type, output_accum_t, DefaultOp::kCount, typename DefaultOp::ElementAccumulator, ElementCompute, - kIsFirst, - kIsLast, + kIsFirst::value, + kIsLast::value, cutlass::Array>; using Epilogue = typename cutlass::epilogue::threadblock:: EpiloguePipelined< @@ -1121,7 +1121,7 @@ struct AttentionKernel { typename MM1::Mma::Operator, DefaultEpilogue::kPartitionsK, typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, typename MM1::OutputTileIterator, typename MM1::OutputTileIteratorAccum>::type, typename DefaultEpilogue:: @@ -1139,7 +1139,7 @@ struct AttentionKernel { int col = blockN * MM1::Mma::Shape::kN; auto source_iter = createOutputAccumIter(col); auto dest_iter = call_conditional< - kIsLast, + kIsLast::value, decltype(createOutputIter), decltype(createOutputAccumIter)>:: apply(createOutputIter, createOutputAccumIter, col); diff --git a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh index ff3178b56c2a6..0953161dc0d44 100644 --- a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh @@ -25,6 +25,7 @@ limitations under the License. 
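One reading of the kernel_forward.h edit above: kIsFirst and kIsLast are now read through ::value, i.e. as compile-time wrapper types rather than plain bools. The sketch below shows that general idiom with std::integral_constant standing in for whatever wrapper the kernel actually uses; it illustrates the pattern only, not the kernel's real types.

#include <type_traits>

// Compile-time flag passed as a type; reading ::value yields the constant.
template <typename kIsLastT>
struct EpilogueSelectExample {  // hypothetical name
  // Mirrors the conditional in the diff: the last partial epilogue writes output_t,
  // earlier ones write the accumulator type.
  using type = std::conditional_t<kIsLastT::value, float /*output_t*/, double /*output_accum_t*/>;
};

static_assert(std::is_same_v<EpilogueSelectExample<std::true_type>::type, float>, "sanity check");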
#include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/shared_inc/cuda_call.h" #include +#include #include #include @@ -60,6 +61,15 @@ __device__ inline half2 AddHalf2(const half2 a, const half2 b) { #endif } +template <> +__device__ inline nv_bfloat16 Rsqrt(const nv_bfloat16& x) { + return hrsqrt(x); +} + +__device__ inline nv_bfloat162 AddHalf2(const nv_bfloat162 a, const nv_bfloat162 b) { + return __hadd2(a, b); +} + struct KeyValuePairSum { __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, const cub::KeyValuePair& b) { @@ -78,6 +88,14 @@ struct KeyValuePairSum { const cub::KeyValuePair& b) { return cub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); } + + __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { + const nv_bfloat162 a2 = __halves2bfloat162(a.key, a.value); + const nv_bfloat162 b2 = __halves2bfloat162(b.key, b.value); + const nv_bfloat162 res = AddHalf2(a2, b2); + return cub::KeyValuePair(__low2bfloat16(res), __high2bfloat16(res)); + } }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 428b903c03682..92ae7e81fb5bd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -34,6 +34,7 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) using namespace ONNX_NAMESPACE; @@ -106,19 +107,35 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr); } else { - LaunchSkipLayerNormKernel( - Stream(ctx), - reinterpret_cast(output->MutableData()), - sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, - reinterpret_cast(input->Data()), - reinterpret_cast(skip->Data()), - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, - reinterpret_cast(gamma->Data()), - (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - epsilon_, - hidden_size, - row_count, - skip_size); + if constexpr (std::is_same_v) { + LaunchSkipLayerNormKernel( + Stream(ctx), + reinterpret_cast(output->MutableData()), + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, + reinterpret_cast(input->Data()), + reinterpret_cast(skip->Data()), + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, + reinterpret_cast(gamma->Data()), + (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, + epsilon_, + hidden_size, + row_count, + skip_size); + } else { + LaunchSkipLayerNormKernel( + Stream(ctx), + reinterpret_cast(output->MutableData()), + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, + reinterpret_cast(input->Data()), + reinterpret_cast(skip->Data()), + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, + reinterpret_cast(gamma->Data()), + (beta != nullptr) ? 
reinterpret_cast(beta->Data()) : nullptr, + epsilon_, + hidden_size, + row_count, + skip_size); + } } CUDA_RETURN_IF_ERROR(cudaGetLastError()); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index 50c8e4b5e0398..a1dcab0a6bf89 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -30,6 +30,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/layer_norm.cuh" #include "contrib_ops/cuda/bert/skip_layer_norm_impl.h" #include +#include namespace onnxruntime { namespace contrib { @@ -49,6 +50,11 @@ half maybe2half(float x) { return __float2half_rn(x); } +template <> +nv_bfloat16 maybe2half(float x) { + return __float2bfloat16_rn(x); +} + // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192}; @@ -263,7 +269,8 @@ SKIPLAYERNORM_IMPL(float, true); SKIPLAYERNORM_IMPL(float, false); SKIPLAYERNORM_IMPL(half, true); SKIPLAYERNORM_IMPL(half, false); - +SKIPLAYERNORM_IMPL(nv_bfloat16, true); +SKIPLAYERNORM_IMPL(nv_bfloat16, false); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index b8931bf1ea0f8..17f3433aed38a 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -42,6 +42,7 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasAdd); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, QuickGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, QuickGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QuickGelu); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QuickGelu); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, TransposeMatMul); // backward compatibility class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, TransposeMatMul); // backward compatibility class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, TransposeMatMul); // backward compatibility @@ -129,6 +130,7 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipLayerNormalization); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SkipSimplifiedLayerNormalization); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu); @@ -256,6 +258,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, // backward compatibility @@ -339,6 +342,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git 
a/onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h b/onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h new file mode 100644 index 0000000000000..06442c6e02ae0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h @@ -0,0 +1,46 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "core/providers/cuda/shared_inc/cuda_call.h" + +namespace onnxruntime::llm::common { +inline int getDevice() { + int deviceID{0}; + CUDA_CALL_THROW(cudaGetDevice(&deviceID)); + return deviceID; +} + +inline int getSMVersion() { + int device{-1}; + CUDA_CALL_THROW(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + CUDA_CALL_THROW(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + CUDA_CALL_THROW(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline int getMultiProcessorCount() { + int nSM{0}; + int deviceID{0}; + CUDA_CALL_THROW(cudaGetDevice(&deviceID)); + CUDA_CALL_THROW(cudaDeviceGetAttribute(&nSM, cudaDevAttrMultiProcessorCount, deviceID)); + return nSM; +} +} // namespace onnxruntime::llm::common diff --git a/onnxruntime/contrib_ops/cuda/llm/common/logger.h b/onnxruntime/contrib_ops/cuda/llm/common/logger.h new file mode 100644 index 0000000000000..a3992e751926d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/common/logger.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/shared_library/provider_api.h" + +#ifndef NDEBUG +#define ORT_LLM_LOG_TRACE(msg) LOGS_DEFAULT(VERBOSE) << msg +#define ORT_LLM_LOG_DEBUG(msg) LOGS_DEFAULT(VERBOSE) << msg +#else +#define ORT_LLM_LOG_TRACE(msg) +#define ORT_LLM_LOG_DEBUG(msg) +#endif + +#define ORT_LLM_LOG_INFO(msg) LOGS_DEFAULT(INFO) << msg +#define ORT_LLM_LOG_WARNING(msg) LOGS_DEFAULT(WARNING) << msg +#define ORT_LLM_LOG_ERROR(msg) LOGS_DEFAULT(ERROR) << msg diff --git a/onnxruntime/contrib_ops/cuda/llm/common/workspace.h b/onnxruntime/contrib_ops/cuda/llm/common/workspace.h new file mode 100644 index 0000000000000..126884a941336 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/common/workspace.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
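A quick usage sketch for the helpers added in cuda_runtime_utils.h above; the wrapper function name is hypothetical, and error handling follows the helpers themselves (they throw via CUDA_CALL_THROW).

#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h"

// Illustrative only: a dispatcher might gate Hopper-specific kernels on the SM version
// and size per-SM work using the multiprocessor count.
inline bool IsHopperOrNewerExample() {
  const int sm = onnxruntime::llm::common::getSMVersion();        // e.g. 90 on SM90
  const int sm_count = onnxruntime::llm::common::getMultiProcessorCount();
  (void)sm_count;  // an occupancy heuristic could scale work per SM with this value
  return sm >= 90;
}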
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace onnxruntime::llm::common { + +std::uintptr_t constexpr kCudaMemAlign = 128; + +inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) { + uintptr_t addr = (uintptr_t)ptr; + if (addr % to) { + addr += to - addr % to; + } + return reinterpret_cast(addr); +} + +constexpr size_t alignSize(size_t size, size_t to) { + if ((size % to) != 0U) { + size += to - size % to; + } + return size; +} + +inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment) { + uintptr_t addr = (uintptr_t)ptr; + addr += previousWorkspaceSize; + return alignPtr(reinterpret_cast(addr), alignment); +} + +inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) { + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign); +} + +inline int8_t* nextWorkspacePtr( + int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign) { + uintptr_t curr_offset = offset; + uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment; + int8_t* newptr = size == 0 ? nullptr : base + curr_offset; + offset = next_offset; + return newptr; +} + +inline int8_t* nextWorkspacePtrWithAlignment( + int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign) { + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment); +} + +inline size_t calculateTotalWorkspaceSize( + size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign) { + size_t total = 0; + for (int i = 0; i < count; i++) { + total += workspaces[i]; + if (workspaces[i] % alignment) { + total += alignment - (workspaces[i] % alignment); + } + } + return total; +} + +}; // namespace onnxruntime::llm::common diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h new file mode 100644 index 0000000000000..6de056b44339d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Templates exposing architecture support for multiply-add operations +*/ + +#pragma once +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +// Tag which triggers MMA which will trigger +struct OpMultiplyAddDequantizeInterleavedBToA; + +/* + Below we have extra tags to signal what kind of dequantization we want to do + (per col, scale only fine grained, finegrained with zero). This still lets us + the existing template infrastructure (incl. that in CUTLASS). 
However, we + split out the template below into OpMultiplyAddDequantizeInterleavedBToA along + with the quantization op before instantiating the GEMM pieces. + + Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of + code we need to duplicate. + */ +struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; + +// The default just forwards the original operator +template +struct TagOperator { + using TaggedOperator = MmaOp; +}; + +// Specializations below attach more information to the operator +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; +}; + +// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original +// operator + the extra information. If no extra info was tagged, the dequant op per column scaling +// as a default. +template +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +}; + +} // namespace arch +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h new file mode 100644 index 0000000000000..63dca2f458e1a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include + +#include "cutlass/device_kernel.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime::llm { +namespace cutlass_extensions { + +template +inline int compute_occupancy_for_kernel() { + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) { + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + CUDA_CALL_THROW(cudaGetDevice(&device)); + CUDA_CALL_THROW( + cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); + } else { + CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::Kernel)); + } + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) + // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this + // configuration. + return 0; + } + + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW(cudaFuncSetAttribute( + cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } else { + CUDA_CALL_THROW(cudaFuncSetAttribute( + cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + } + + int max_active_blocks = -1; + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, + 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); + } else { + CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); + } + + return max_active_blocks; +} + +} // namespace cutlass_extensions +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h new file mode 100644 index 0000000000000..e0911460ef8a3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Functor performing linear combination with a maximum operation used by epilogues. 
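The compute_occupancy_for_kernel helper above follows the standard CUDA occupancy recipe: opt in to dynamic shared memory above 48KB, then ask the runtime how many blocks of the kernel fit on one SM. Below is a stripped-down sketch of that recipe for an ordinary kernel; the kernel, block size, and lack of error checking are simplifications, not the CUTLASS-aware code in the diff.

#include <cuda_runtime.h>

__global__ void dummy_kernel_example(float* out) { if (out != nullptr) out[0] = 0.f; }

int OccupancyForDummyKernelExample(int dynamic_smem_bytes) {
  if (dynamic_smem_bytes > (48 << 10)) {
    // Kernels that need more than 48KB of dynamic shared memory must opt in explicitly.
    cudaFuncSetAttribute(dummy_kernel_example, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         dynamic_smem_bytes);
  }
  int max_active_blocks = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, dummy_kernel_example,
                                                /*blockSize=*/128, dynamic_smem_bytes);
  return max_active_blocks;  // 0 means the configuration does not fit and should be skipped
}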
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) { +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h new file mode 100644 index 0000000000000..1d7ff42d591e2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h @@ -0,0 +1,122 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. 
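A brief note on tanh_opt in fused_activations.h above: the pre-SM75 / pre-CUDA-11 path is simply the exponential form of tanh with the sign restored afterwards.

// tanh(x) = sign(x) * (1 - exp(-2|x|)) / (1 + exp(-2|x|))
// copysignf_pos re-attaches the sign bit of x to the non-negative magnitude term, and
// __expf trades a little accuracy for speed; newer toolkits/architectures call fast_tanh instead.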
+ * + */ + +#pragma once + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h" +#include + +namespace onnxruntime::llm { +namespace cutlass_extensions { + +struct EpilogueOpBiasSilu { +}; + +struct EpilogueOpBiasReLU { +}; + +struct EpilogueOpBiasFtGelu { +}; + +struct EpilogueOpBias { +}; + +struct EpilogueOpDefaultSilu { +}; + +struct EpilogueOpDefaultReLU { +}; + +struct EpilogueOpDefaultFtGelu { +}; + +struct EpilogueOpDefault { +}; + +template +struct Epilogue { + static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); +}; + +constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace cutlass_extensions +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl new file mode 100644 index 0000000000000..a7146d99224eb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_RS Mixed Scaled GEMM +template +struct CollectiveBuilderInterleaved + || cute::is_same_v + || cute::is_same_v)>> +{ + +private: + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementPairA_>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementPairB_>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementPairA_>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementPairB_>; + static constexpr bool NeitherIsTuple + = !cute::is_tuple::value && !cute::is_tuple::value; + +public: + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementPairA_>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementPairB_>; + static_assert(cute::is_tuple::value ^ cute::is_tuple::value + || (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value)), + "Either A OR B must be a tuple or the widths of A and B must be different."); + + static constexpr bool IsANarrow = sizeof_bits::value < sizeof_bits::value; + + using GmemLayoutATag = GmemLayoutATag_; + using GmemLayoutBTag = GmemLayoutBTag_; + + using ElementPairA = cute::conditional_t, ElementPairA_>; + using ElementPairB = cute::conditional_t, ElementPairB_>; + + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + static constexpr bool IsWarpSpecializedTransposeB = detail::is_warpspecialized_transpose_B(); + static_assert(!IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B."); + + // If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to RF and we must swap the operands. + static constexpr bool SwapAB = !IsATransformed; + + // When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly. + static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? 
GmmaMajorA : GmmaMajorB; + + using ElementMma = cute::conditional_t; + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma + = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA + = decltype(detail::rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB + = decltype(detail::rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + + using RealElementA = cute::conditional_t; + using RealElementB = cute::conditional_t; + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}); + + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + using DispatchPolicy + = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + + // We pack the scale data with the operand that will be optionally scaled and converted before MMA. + using StrideA = TagToStrideA_t; + using StrideB = TagToStrideB_t; + + using CollectiveOp = CollectiveMmaInterleaved; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp new file mode 100644 index 0000000000000..97feaa2498bba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
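The SwapAB logic in the interleaved collective builder above ("If A is scaled, then we don't need to swap...") boils down to a compile-time operand swap so that the scaled/quantized operand is the one sourced from registers. A minimal illustration of that idiom, with hypothetical names and std::conditional_t in place of the cute helpers:

#include <type_traits>

template <bool SwapAB, class ElementA, class ElementB>
struct SwappedOperandsExample {
  // When SwapAB is true the builder treats the original B as its internal A (and vice versa),
  // so the operand that carries scales always lands on the register-fed side of the MMA.
  using InternalA = std::conditional_t<SwapAB, ElementB, ElementA>;
  using InternalB = std::conditional_t<SwapAB, ElementA, ElementB>;
};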
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp" + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveBuilderInterleaved { + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp new file mode 100644 index 0000000000000..ce56a9d717ceb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMmaInterleaved { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000000000..499504439aa46 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -0,0 +1,1372 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop that source A operand from registers +template +struct CollectiveMmaInterleaved, + TileShape_, ElementAOptionalTuple, StrideA_, ElementBOptionalTuple, StrideB_, TiledMma_, GmemTiledCopyA_, + SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_> { + private: + template + static constexpr auto get_logical_ptr(PointerType const* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return subbyte_iterator(ptr); + } else { + return ptr; + } + } + + template + static constexpr auto get_smem_interleave_layout() { + if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 8) { + return Layout(TileShape{})), Shape<_4, _4, _2, _4>>, + Stride<_128, Stride<_1, _8, _4, _32>>>{}; + } else if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 16) { + return Layout(TileShape{})), Shape<_2, _4, _4, _2>>, + Stride<_64, Stride<_1, _8, _2, _32>>>{}; + } else if constexpr (cute::sizeof_bits_v == 8 && cute::sizeof_bits_v == 16) { + return Layout(TileShape{})), Shape<_2, _4, _2, _4>>, + Stride<_64, Stride<_1, _4, _2, _16>>>{}; + } else { + static_assert(dependent_false, + "unsupported weight and activation, must be one of w4a8,w4a16,w8a16"); + } + } + + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + + public: + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + using 
TileShape = TileShape_; + + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is + // void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = StrideA_; + using StrideB = StrideB_; + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the code to compile when the scale is void. + using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert((IsATransformed && cutlass::gemm::detail::is_k_major()) || (!IsATransformed && cutlass::gemm::detail::is_k_major()), + "The transformed type must be K-major."); + + static_assert((IsATransformed && (sizeof(ElementB) == 2)) || (!IsATransformed && (sizeof(ElementA) == 2)) || (cutlass::gemm::detail::is_k_major() && cutlass::gemm::detail::is_k_major()), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + // Scale layout atom set after swapping. + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using SmemCopyAtomScale = Copy_Atom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemCopyAtomA = cute::conditional_t; + using InternalSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
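The comment above about casting to a size-equivalent uint type so TMA cannot round the data can be illustrated with a small standalone trait; the real code uses the cute/cutlass equivalents (and special-cases float to tf32), so this is a sketch of the idea only.

#include <cstdint>
#include <type_traits>

// Map a (>= 1 byte) element type to an unsigned integer of the same width, so a raw
// byte copy such as TMA moves the bits untouched. Sub-byte types need separate handling.
template <typename T>
using BitEquivalentUintExample =
    std::conditional_t<sizeof(T) == 1, std::uint8_t,
    std::conditional_t<sizeof(T) == 2, std::uint16_t,
    std::conditional_t<sizeof(T) == 4, std::uint32_t, std::uint64_t>>>;

static_assert(sizeof(BitEquivalentUintExample<double>) == sizeof(double), "width preserved");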
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealInternalElementA = cute::conditional_t; + using RealInternalElementB = cute::conditional_t; + using InternalElementA = cute::conditional_t; + using InternalElementB = cute::conditional_t; + using InternalStrideA = cute::conditional_t; + using InternalStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using InternalTransformA = cute::conditional_t; + using InternalTransformB = cute::conditional_t; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + using SmemLayoutAtomScale = Layout(InternalSmemLayoutAtomA{})), cute::Int<1>>>; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); + static constexpr int type_factor = sizeof_bits::value / sizeof_bits::value; + + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape(InternalSmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, InternalStrideA>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + using Layout_Interleave = decltype(cute::composition(SmemLayoutA{}.layout_a(), SmemLayoutA{}.offset(), + get_smem_interleave_layout())); + using SmemLayoutA_mma_interleave = decltype(tile_to_shape(Layout_Interleave{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, InternalStrideA>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + using SmemLayoutA_mma = decltype(cute::composition(SmemLayoutA{}.layout_a(), SmemLayoutA{}.offset(), + make_layout(make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + make_stride(get<2>(TileShape{}), _1{}, get<0>(TileShape{}) * get<2>(TileShape{}))))); + // cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), + // Stride<_1, cute::Int(TileShape{})>, cute::Int(TileShape{}) * + // get<2>(TileShape{})>>, Stride(TileShape{})>, _1, + // cute::Int(TileShape{}) * get<2>(TileShape{})>>>{}))); + + using SmemLayoutB = decltype(tile_to_shape(InternalSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, InternalStrideB>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape(SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, NonVoidStrideScale>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(!cute::is_base_of::value && cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. + // We must also handle updating the pipeline transaction bytes on the fly. + // NOTE: Deleting this assertion without required changes will cause the code to hang. 
+ static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + + private: + static constexpr ConversionMode get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + + static constexpr auto elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + return cute::cosize_v; + } else { + static_assert( + cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + static constexpr auto elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || KernelConversionMode == ConversionMode::ConvertAndScale) { + return 0; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } else { + static_assert( + cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + // These methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return baseline_bytes; + } else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return baseline_bytes + scale_tx_bytes; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in tma transaction bytes computation."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in tma transaction bytes computation."); + } + } + + static constexpr uint32_t compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + + public: + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 
128 && SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage { + static constexpr int scale_elements = elements_per_smem_scale(); + static constexpr int zero_elements = elements_per_smem_zero(); + + struct TensorStorage : cute::aligned_struct { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + private: + using Outer = CollectiveMmaInterleaved; + + public: + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(InternalStrideA{}, static_cast(0)), InternalStrideA{}), + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + + using TMA_Scale = decltype(make_tma_copy(GmemTiledCopyScale{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, static_cast(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_, _, cute::Int<0>{}), ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + using TMA_Zero = decltype(make_tma_copy(GmemTiledCopyScale{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, static_cast(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_, _, cute::Int<0>{}), ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. 
Scale is ALWAYS loaded with A for RF kernel + + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(InternalStrideB{}, static_cast(0)), InternalStrideB{}), + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + if constexpr (SwapAB) { + M = get<1>(problem_shape_MNKL); + N = get<0>(problem_shape_MNKL); + } + + InternalElementA const* ptr_A; + InternalStrideA dA; + InternalElementB const* ptr_B; + InternalStrideB dB; + + if constexpr (not SwapAB) { + ptr_A = reinterpret_cast(args.ptr_A); + ptr_B = reinterpret_cast(args.ptr_B); + dA = args.dA; + dB = args.dB; + } else { + ptr_A = reinterpret_cast(args.ptr_B); + ptr_B = reinterpret_cast(args.ptr_A); + dA = args.dB; + dB = args.dA; + } + + Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M, K, L), dA)); + Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N, K, L), dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + + typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + typename Params::TMA_Scale tma_load_scale; + typename Params::TMA_Zero tma_load_zero; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, TmaTransactionBytes, + TmaTransactionBytesMK, TmaTransactionBytesNK}; + } else if constexpr (ModeHasScales) { + auto scale_k = (K + args.group_size - 1) / args.group_size; + ElementScale const* ptr_S = args.ptr_S; + StrideScale dS = args.dS; + Tensor tensor_scale = make_tensor(get_logical_ptr(ptr_S), make_layout(make_shape(M, scale_k, L), dS)); + tma_load_scale = make_tma_copy(GmemTiledCopyScale{}, tensor_scale, SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, + TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK}; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M, scale_k, L), dS)); + tma_load_zero = 
make_tma_copy(GmemTiledCopyScale{}, tensor_zero, SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, + TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK}; + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } else if constexpr (ModeHasScales) { + int const scale_mn = SwapAB ? N : M; + int const scale_k = (K + args.group_size - 1) / args.group_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.group_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!implementable) { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + 
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(gA_mkl, gB_nkl); + } else if constexpr (ModeHasScales) { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. 
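+  /// The size of the load_inputs tuple is checked at the top of this function: 2 tensors (A, B) for
+  /// DirectConvert, 3 (A, B, scales) for ConvertAndScale, and 4 (A, B, scales, zero-points) for
+  /// ConvertAndScaleWithZero.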
+ template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof...(Ts) == 2, "Direct convert needs two inputs"); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof...(Ts) == 3, "Scaled convert needs three inputs"); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof...(Ts) == 4, "Scaled and zero convert needs four inputs"); + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + auto extra_input_partitions = partition_extra_tma_inputs( + mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify + // tma transaction bytes on the fly. We must do a ceiling divide here to correctly handle with + // group_size == K. In that case, we don't require that K is a multiple of the threadblock tile K + int const ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + int const scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. 
+ copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_, _, _, scale_load_k), + tSsS(_, _, _, write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), + tZgZ(_, _, _, scale_load_k), tZsZ(_, _, _, write_stage)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for TMA copy op."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + template + constexpr auto interleave_for_mixed_input() { + if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 8) { + return Layout, _1, Shape<_2, _2>>, + Stride, _0, Stride<_16, _32>>>{}; + } else if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 16) { + return Layout, _1, Shape<_2>>, + Stride, _0, Stride<_16>>>{}; + } else if constexpr (cute::sizeof_bits_v == 8 && cute::sizeof_bits_v == 16) { + return Layout, _1, Shape<_2, _2>>, + Stride, _0, Stride<_8, _16>>>{}; + } else { + static_assert(dependent_false, + "unsupported weight and activation, must be one of w4a8,w4a16,w8a16"); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum, + int k_tile_count, int thread_idx, TensorStorage& shared_tensors, Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + Tensor sA_ = make_tensor( + make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA_mma_interleave{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C 
accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = mma_thread_slice.partition_A(sA); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + auto interleave_layout = interleave_for_mixed_input(); + + auto interleave_remapping = cute::flat_product(interleave_layout, Layout>>{}); + + Tensor tCsA_remapped = tCsA.compose(interleave_remapping); + + auto interleave_remapping_thread = right_inverse(interleave_layout); + + // Allocate fragments and descriptors + Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_, _, Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_load = make_fragment_like(tCrA_mma); + + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + // Compute the max vector length that can be used to copy A. This will match the vector width of the + // conversions used. It helps by allowing the compiler to convert using the same register that was used + // to load the data from smem. This significantly reduces the need to move data among registers. + // Note that this is correct even if copy fails to vectorize, since the granularity at which we perform + // the conversion does not impact correctness. 
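+    // (illustrative) The static_assert just below pins the remapped copy vector to 32 bits' worth of
+    // elements per copy, e.g. 8 elements for a 4-bit type, 4 for an 8-bit type, or 2 for a 16-bit type.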
+ using A_CPY_VEC = decltype(max_common_vector(tCsA, tCrA_copy_view)); + using A_CPY_VEC_remapped = decltype(max_common_vector(tCsA_remapped, tCrA_copy_view)); + static_assert(A_CPY_VEC_remapped{} == 32 / cutlass::sizeof_bits::value, + "max_common_vector(tCsA_remapped, tCrA_copy_view) is 32 / cutlass::sizeof_bits::value"); + auto tCrA_mma_tmp = tCrA_mma.compose(interleave_remapping_thread); + auto tCrA_mma_inverse_mapping = tCrA_mma_tmp.compose(tCrA_mma.layout()); + + auto tCrA_load_tmp = tCrA_load.compose(interleave_remapping_thread); + auto tCrA_load_inverse_mapping = tCrA_load_tmp.compose(tCrA_load.layout()); + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = partition_extra_mma_info(mma_thread_slice, shared_tensors); + auto copy_partitions_extra_info = retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + + constexpr int kNumKIterationsPerWarpBLoad = type_factor / 2; + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 0, read_stage, kNumKIterationsPerWarpBLoad); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 1, read_stage, kNumKIterationsPerWarpBLoad); + } + + transform_A_kblock( + tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, 0, kNumKIterationsPerWarpBLoad); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma_inverse_mapping(_, _, k_block), tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + if (k_block < K_BLOCK_MAX - 2) // prefetch next block + { + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 2, read_stage, kNumKIterationsPerWarpBLoad); + } + if (k_block < K_BLOCK_MAX - 1) { + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, k_block + 1, + kNumKIterationsPerWarpBLoad); + } + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to 
ensure that it is safe to overwrite the A registers for + // the first mma. + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 0, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 1, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + } + warpgroup_wait(); + transform_A_kblock( + tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, 0, kNumKIterationsPerWarpBLoad); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma_inverse_mapping(_, _, k_block), tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this + // stage, so we can release prior barrier + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 0, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 1, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + } + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, 0, + kNumKIterationsPerWarpBLoad); + } else { + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 2, read_stage, kNumKIterationsPerWarpBLoad); + } + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, k_block + 1, + kNumKIterationsPerWarpBLoad); + } + } + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma_inverse_mapping(_, _, k_block), tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) // release prior barrier + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + 
++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 2) // prefetch next block + { + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 2, read_stage, kNumKIterationsPerWarpBLoad); + } + + if (k_block < K_BLOCK_MAX - 1) { + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 1, read_stage, kNumKIterationsPerWarpBLoad); + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, k_block + 1, + kNumKIterationsPerWarpBLoad); + } + } + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + private: + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE auto partition_extra_tma_inputs(Params const& mainloop_params, cute::tuple const& load_inputs, + TensorStorage& shared_tensors, uint2 const& cluster_local_block_id, int const m_coord, int const l_coord) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + Tensor sS = make_tensor( + make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor( + make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for input partitioning."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for input partitioning."); + } + } + + template + constexpr auto scale_remapping() { + if constexpr (cute::sizeof_bits_v == 8) { + return Layout, Stride<_1, _8, _4>>{}; + } else if constexpr (cute::sizeof_bits_v == 16) { + return Layout, Stride<_1, _4, _2>>{}; + } else { + static_assert(dependent_false, "cute::sizeof_bits_v must be 8 or 16"); + } + } + + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. 
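+  /// Returns an empty tuple for DirectConvert, (smem scale view, scale RF fragment) for ConvertAndScale,
+  /// and those plus the corresponding zero-point view and fragment for ConvertAndScaleWithZero.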
+ template + CUTLASS_DEVICE auto partition_extra_mma_info(ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + Tensor sS = make_tensor( + make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + auto remappingScale = scale_remapping(); + Tensor tCsS_remapped = tCsS.compose(remappingScale, _, _, _); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS_remapped, tCrS); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor( + make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCsZ_remapped = tCsZ.compose(remappingScale, _, _, _); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_, _, Int<0>{})).shape()); + return cute::make_tuple(tCsS_remapped, tCrS, tCsZ_remapped, tCrZ); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Returns the tiled copy and copy views for the extra inputs. + template + CUTLASS_DEVICE auto retile_extra_mma_info( + TiledMma const& tiled_mma, cute::tuple& partitioned_extra_info, int const warp_group_thread_idx) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Utilities to copy A and extra inputs from smem to RF + template + CUTLASS_DEVICE void copy_A_and_extra_info(SmemTiledCopyA const& smem_tiled_copy_A, TensorASmemView const& tCsA, + TensorACopyView& tCrA_copy_view, cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, int k_block, int read_stage, int kNumKIterationsPerWarpBLoad) { + if (kNumKIterationsPerWarpBLoad == 1) { + copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block)); + } else { + using reshape_layout = Layout, Int<1>, Int<2>>>; + auto tCrA_copy_view_reshaped = tCrA_copy_view.compose(reshape_layout{}); + if (k_block % kNumKIterationsPerWarpBLoad == 0) + 
copy(smem_tiled_copy_A, tCsA(_, _, k_block / kNumKIterationsPerWarpBLoad, read_stage), + tCrA_copy_view_reshaped(_, _, k_block / kNumKIterationsPerWarpBLoad)); + } + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + } + + /// Utilities to transform A. + template + CUTLASS_DEVICE void transform_A_kblock(TCrA_load const& tCrA_load, cute::Int vec_A, + TCrA_mma& tCrA_mma, cute::tuple const& partitioned_extra_info, int const k_block, + int kNumKIterationsPerWarpBLoad) { + if (kNumKIterationsPerWarpBLoad != 1) { + if (k_block % kNumKIterationsPerWarpBLoad == 0) { + int k_block_load = k_block / kNumKIterationsPerWarpBLoad; + using reshape_layout = Layout, _1, _2>>; + auto tCrA_load_reshaped = tCrA_load.compose(reshape_layout{}); + auto tCra_mma_reshaped = tCrA_mma.compose(reshape_layout{}); + + using scale_reshape = Layout, _1, _1>, Stride, _0, _0>>; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + transform_internal_A( + tCrA_load_reshaped(_, _, k_block_load), vec_A, tCra_mma_reshaped(_, _, k_block_load)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrS_reshaped = tCrS.compose(scale_reshape{}); + transform_internal_A(tCrA_load_reshaped(_, _, k_block_load), vec_A, + make_fragment_like(tCra_mma_reshaped)(_, _, k_block_load), tCrS_reshaped(_, _, 0), + tCra_mma_reshaped(_, _, k_block_load)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrS_reshaped = tCrS.compose(scale_reshape{}); + auto tCrZ = cute::get<3>(partitioned_extra_info); + auto tCrZ_reshaped = tCrZ.compose(scale_reshape{}); + transform_internal_A(tCrA_load_reshaped(_, _, k_block_load), vec_A, + make_fragment_like(tCra_mma_reshaped)(_, _, k_block_load), tCrS_reshaped(_, _, 0), + tCrZ_reshaped(_, _, 0), tCra_mma_reshaped(_, _, k_block_load)); + } else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + } else { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + transform_internal_A(tCrA_load(_, _, k_block), vec_A, tCrA_mma(_, _, k_block)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + auto tCrS = cute::get<1>(partitioned_extra_info); + transform_internal_A(tCrA_load(_, _, k_block), vec_A, + make_fragment_like(tCrA_mma)(_, _, k_block), tCrS(_, _, 0), tCrA_mma(_, _, k_block)); + } 
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrZ = cute::get<3>(partitioned_extra_info); + transform_internal_A(tCrA_load(_, _, k_block), vec_A, + make_fragment_like(tCrA_mma)(_, _, k_block), tCrS(_, _, 0), tCrZ(_, _, 0), + tCrA_mma(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + } + + /// Utilities for transforming the A operand prior to issuing tensorcore math. + template > + CUTLASS_DEVICE void convert_tensor(Tensor const& in, Tensor& out, + cute::Int width = {}) { + /// This is an element-wise conversion where we expect both tensors to have the same layout. + /// As a result, we can cast as a cutlass array to use the fast numeric converters without + /// worrying about indexing into the layout. + constexpr int N = cosize_v; + + /// The inputs must be backed by registers & be statically sized. + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(is_static_v, "Tensor layout for the conversion must be static"); + static_assert(cosize_v == size(TensorLayout{}), "Cosize and size of the layout must be equal."); + static_assert( + N % ConversionVectorWidth == 0, "Conversion vector width must divide cosize of the tensor layout."); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + using Converter = std::conditional_t < cutlass::sizeof_bits_v, + cutlass::FastInterleavedAndBiasedNumericArrayConverter, + cutlass::NumericArrayConverter>; + + constexpr int NumIterations = N / ConversionVectorWidth; + + for (int ii = 0; ii < NumIterations; ++ii) { + SrcArray const* src_array_ptr = reinterpret_cast(raw_pointer_cast(in.data())) + ii; + DstArray* dst_array_ptr = reinterpret_cast(raw_pointer_cast(out.data())) + ii; + *dst_array_ptr = Converter::convert(*src_array_ptr); + } + } + + template + CUTLASS_DEVICE void transform_internal_A(Tensor&& in, + cute::Int a_vec_width, Tensor&& out) { + convert_tensor(in, out, a_vec_width); + } + + template + CUTLASS_DEVICE void transform_internal_A(Tensor&& in, + cute::Int a_vec_width, Tensor&& converted_inputs, + Tensor&& scales, Tensor&& out) { + static_assert(cute::is_same_v, + "Type of the engine input buffer must equal the scale buffer"); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, a_vec_width); + + // Apply scales and broadcast across inputs, store in converted_inputs + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(converted_inputs); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(converted_inputs); ++j) { + if constexpr (cute::is_same_v) { + converted_inputs(j, i) = bfloat16_t(__hmul(reinterpret_cast<__nv_bfloat16 const&>(converted_inputs(j, i)), + reinterpret_cast<__nv_bfloat16 const&>(scales(j, i)))); + } else { + converted_inputs(j, i) *= scales(j, i); + } + } + } + + // Finally, we convert the scaled inputs to the mma type. 
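+    // (illustrative) The scale multiply above is performed in the scales' element type; only this final
+    // step maps the dequantized values onto the type consumed by the tensor cores.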
+ convert_tensor(converted_inputs, out); + } + + template + CUTLASS_DEVICE void transform_internal_A(Tensor&& in, + cute::Int a_vec_width, Tensor&& converted_inputs, + Tensor&& scales, Tensor&& zeros, + Tensor&& out) { + static_assert(cute::is_same_v, + "Type of the engine input buffer must equal the scale buffer"); + + static_assert(cute::is_same_v, + "Type of the engine zero buffer must equal the scale buffer"); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, a_vec_width); + + // Apply scales and broadcast across inputs, store in converted_inputs + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(converted_inputs); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(converted_inputs); ++j) { + if constexpr (cute::is_same_v) { + converted_inputs(j, i) = bfloat16_t(__hfma(reinterpret_cast<__nv_bfloat16 const&>(converted_inputs(j, i)), + reinterpret_cast<__nv_bfloat16 const&>(scales(j, i)), + reinterpret_cast<__nv_bfloat16 const&>(zeros(j, i)))); + } else { + converted_inputs(j, i) = converted_inputs(j, i) * scales(j, i) + zeros(j, i); + } + } + } + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h new file mode 100644 index 0000000000000..c7f2a682323a0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -0,0 +1,370 @@ +/* + * Copyright (c) 2017-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. 
+*/ + +#pragma once + +// #include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. + */ + +template +class GemmUniversalBaseCompat { + public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + + protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + + protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + + public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. 
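+  /// Beyond delegating to GemmKernel::can_implement, this rejects problems whose swizzled launch grid
+  /// would exceed the 65535 limit on grid.y / grid.z (kGridYZMax below).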
+ static Status can_implement(Arguments const& args) { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) { + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } else { + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + return -1; + } + + smem_capacity = 
static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
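+  /// Illustrative usage (MyGemmKernel, args, workspace and stream are placeholders, not defined in this header):
+  ///   GemmUniversalBaseCompat<MyGemmKernel> gemm_op;
+  ///   size_t workspace_bytes = GemmUniversalBaseCompat<MyGemmKernel>::get_workspace_size(args);
+  ///   // ... allocate workspace_bytes of device memory into `workspace`, then:
+  ///   cutlass::Status status = gemm_op(args, workspace, stream);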
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h new file mode 100644 index 0000000000000..83ebe2191717b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -0,0 +1,149 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/half.h" +#include "cutlass/layout/matrix.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct MixedGemmArchTraits { + static_assert(dependent_false, "Unrecognized parameterization"); +}; + +template +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::ColumnMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. 
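+// The specialization below therefore accepts both half_t and bfloat16_t activations, accumulates in float,
+// and uses the Turing 16x8x8 tensor-op instruction shape.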
+template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ada Traits ============================== +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +// FP8 A/B = fp8, C/D = fp32 +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + // be careful, TypeC should align with TmaWarpSpecializedGroupedGemmInput::OutputTypeAdaptor_t + using TypeC = __nv_bfloat16; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h new file mode 100644 index 0000000000000..fe4bc0940d9e8 --- 
/dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +}; + +// ======================= Turing Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +}; + +// ======================= Ampere Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 0000000000000..a888ea3e71487 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,461 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +template +inline constexpr bool dependent_false_v = false; +} + +template +struct GemmFpAIntB { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* 
gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size), group_size(group_size), ref_A(ref_A), ref_B(ref_B), ref_scale(ref_scale), ref_zero(ref_zero), ref_C(ref_C), ref_D(ref_D), batch_count(serial_split_k_factor), output_op(output_op), gather_A_indices(gather_A_indices), gather_B_indices(gather_B_indices), scatter_D_indices(scatter_D_indices) { + } + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size), group_size(args.group_size), grid_tiled_shape(grid_tiled_shape), swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A(args.ref_A.layout()), ref_A(args.ref_A), params_B(args.ref_B.layout()), ref_B(args.ref_B), params_scale(args.ref_scale.layout()), ref_scale(args.ref_scale), ref_zero(args.ref_zero), params_C(args.ref_C.layout()), ref_C(args.ref_C), params_D(args.ref_D.layout()), ref_D(args.ref_D), output_op(args.output_op), semaphore(static_cast(workspace)), gemm_k_size(gemm_k_size), gather_A_indices(args.gather_A_indices), gather_B_indices(args.gather_B_indices), scatter_D_indices(args.scatter_D_indices) { + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const& args) { + static int const alignmentA = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const alignmentB = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const alignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const alignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 
64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, alignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, alignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, alignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_zero, alignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, alignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, alignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!args.ref_scale.good()) { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) { + if (!args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } else { + if (args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) { + if (args.group_size != 64 && args.group_size != 128) { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& /*args*/, cutlass::gemm::GemmCoord const& /*grid_tiled_shape*/) { + return 0; + } + + // Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) { + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) { + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename 
MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, + params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. 
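// A minimal sketch of the ordering the Semaphore enforces for serial split-K, reduced to
// a plain int flag in global memory: each K-slice waits until the flag equals its slice
// index, updates the tile, and then releases the flag to the next slice (the last slice
// resets it to zero, as the release logic below does). The function, its float tile and
// the spin-wait are hypothetical simplifications, not the real epilogue or Semaphore.
__device__ void serial_splitk_accumulate(float* tile_out, float partial, int* flag,
                                         int k_slice, int k_slices) {
  if (threadIdx.x == 0) {
    // Wait until it is this K-slice's turn to update the tile (cf. Semaphore::wait).
    while (atomicAdd(flag, 0) != k_slice) {
    }
  }
  __syncthreads();

  // Slice 0 overwrites, later slices accumulate on top of the previous result,
  // which is what set_k_partition() arranges for the real output operator.
  if (k_slice == 0) {
    tile_out[threadIdx.x] = partial;
  } else {
    tile_out[threadIdx.x] += partial;
  }

  __threadfence();  // each thread makes its part of the tile visible device-wide...
  __syncthreads();  // ...before lane 0 signals the next K-slice
  if (threadIdx.x == 0) {
    // cf. Semaphore::release: hand the tile to slice k+1, or reset to 0 after the last slice.
    atomicExch(flag, (k_slice + 1 == k_slices) ? 0 : k_slice + 1);
  }
}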
+ typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ == 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 1000) + // Use SM80 implementation for GB10x, GB20x. + run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h new file mode 100644 index 0000000000000..163a43238a425 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -0,0 +1,451 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. + + original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/trace.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" + +namespace tk = onnxruntime::llm::common; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + 
typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() + : mode(GemmUniversalMode::kGemm), batch_count(1) { + } + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, + TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, + int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_), problem_size(problem_size_), batch_count(batch_count_), ref_A(ref_A_), ref_B(ref_B_), quant_option(quant_option_), ref_alpha_col(ref_alpha_col_), ref_alpha_row(ref_alpha_row_), ref_C(ref_C_), ref_D(ref_D_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), batch_stride_D(0), epilogue_visitor(epilogue_visitor_) { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), params_A(0), params_B(0), params_alpha_col(0), params_C(0), params_D(0), batch_count(0), gemm_k_size(0), mode(cutlass::gemm::GemmUniversalMode::kGemm), ptr_A(nullptr), ptr_B(nullptr), ptr_alpha_col(nullptr), ptr_alpha_row(nullptr), ptr_C(nullptr), ptr_D(nullptr), batch_stride_A(0), batch_stride_B(0) { + } + + Params( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size), swizzle_log_tile(0), params_A(args.ref_A.layout()), params_B(args.ref_B.layout()), params_alpha_col(args.ref_alpha_col.layout()), params_alpha_row(args.ref_alpha_col.layout()), params_C(args.ref_C.layout()), params_D(args.ref_D.layout()), mode(args.mode), batch_count(args.batch_count), gemm_k_size(args.problem_size.k()), ptr_A(args.ref_A.data()), ptr_B(args.ref_B.data()), quant_option(args.quant_option), ptr_alpha_col(args.ref_alpha_col.data()), ptr_alpha_row(args.ref_alpha_row.data()), ptr_C(args.ref_C.data()), ptr_D(args.ref_D.data()), batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), epilogue_visitor(args.epilogue_visitor) { + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / 
sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + + struct + { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on 
mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, + params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, + params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, + params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. 
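// A worked example (made-up numbers) of the split-K slicing used above: Params rounds
// gemm_k_size up to a 128-bit boundary, and each K-slice then clamps its range via
// problem_size_k / offset_k. Assume K = 1000, batch_count = 3 slices, and kAlignK = 8
// (128 bits of fp16).
constexpr int kExampleK = 1000;
constexpr int kExampleSlices = 3;
constexpr int kExampleAlignK = 8;
constexpr int kExampleCeilDiv = (kExampleK + kExampleSlices - 1) / kExampleSlices;        // 334
constexpr int kExampleGemmKSize =
    ((kExampleCeilDiv + kExampleAlignK - 1) / kExampleAlignK) * kExampleAlignK;           // round_up -> 336
constexpr int kExampleGridK = (kExampleK + kExampleGemmKSize - 1) / kExampleGemmKSize;    // 3 K-slices
static_assert(kExampleGemmKSize == 336 && kExampleGridK == 3, "split-K slicing example");
// Slice 0 covers K in [0, 336), slice 1 covers [336, 672), and slice 2 covers [672, 1000):
// the last slice is shorter because problem_size_k is clamped to params.problem_size.k().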
+ epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. Only Ampere+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 0000000000000..c0656ac784830 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,112 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct LayoutDetailsB { +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. 
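// The specializations below pack ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK
// columns of the quantized B matrix back-to-back, so one 128-byte cache line feeds a full
// threadblock K-iteration. A small arithmetic sketch, assuming fp16 activations
// (so ThreadblockK = 128 * 8 / 16 = 64, matching the int4 note above):
constexpr int kCacheLineBits = 128 * 8;
constexpr int kExampleThreadblockK = kCacheLineBits / 16;                               // 64 for fp16
constexpr int kColumnsInterleavedInt8 = (kCacheLineBits / 8) / kExampleThreadblockK;    // 2 columns per line
constexpr int kColumnsInterleavedInt4 = (kCacheLineBits / 4) / kExampleThreadblockK;    // 4 columns per line
static_assert(kExampleThreadblockK == 64 && kColumnsInterleavedInt8 == 2 && kColumnsInterleavedInt4 == 4, "");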
+template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; + // for fast accumulation + // using Operator = cutlass::arch::OpMultiplyAddFastAccum; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, +// which signals that we want to dequantize after loading from smem. +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 0000000000000..ef28dcc46cd21 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,117 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters { +}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = NumericArrayConverter; + + using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 0000000000000..8d73329ed7713 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,289 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsMultistage; + +// Fine grained iterators +template +struct DefaultScaleIteratorsMultistage> { + using IteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsMultistage> { + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or 
predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && 
layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 0000000000000..ae0cee20d3575 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,270 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsPipelined; + +// Fine grained iterators +template +struct DefaultScaleIteratorsPipelined> { + private: + using SmemScaleType = half_t; + + public: + using IteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, + SmemScaleType, Layout, 0, Alignment>; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsPipelined> { + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + // ThreadMap for scale iterator + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + using SmemScaleType = half_t; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, + Layout, 0, IteratorScaleThreadMap, Alignment>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: 
GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the 
A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 0000000000000..dfe99c271f547 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,336 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename 
LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#ifdef ENABLE_FP8 +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: 
GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#endif + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 0000000000000..cb5ce0f72b362 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,336 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + private: + using MmaElementA = bfloat16_t; + using MmaElementB = bfloat16_t; + + public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, + AccessTypeA, GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + AccessTypeB, GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 
weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear 
option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 0000000000000..cad280febbe76 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// The dequantizing op to be performed. + WeightOnlyQuantOp DequantOp, + /// Used for partial specialization, + typename Enable = bool> +class DqMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. 
+ static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. + using ShapeZero = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { + } +}; + 
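For orientation only: the DqMma* classes in this patch implement a weight-only-quantized GEMM mainloop in which packed int8/int4 weights are dequantized after the shared-memory load using per-group scales (and, when present, zero points), i.e. w = scale * (q - zero). The following is a minimal host-side C++ sketch of that per-group mapping; the function name and the flat std::vector layout are hypothetical and are not part of these CUTLASS headers.

// Illustrative sketch only (not part of the patch): per-group affine
// dequantization of one quantized weight column, w = scale * (q - zero),
// with one (scale, zero) pair per `group_size` elements along K.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> dequantize_column_sketch(const std::vector<std::int8_t>& q,
                                            const std::vector<float>& scales,
                                            const std::vector<float>& zeros,
                                            int group_size) {
  std::vector<float> w(q.size());
  for (std::size_t k = 0; k < q.size(); ++k) {
    // Group index along the K dimension selects which scale/zero pair applies.
    const std::size_t g = k / static_cast<std::size_t>(group_size);
    w[k] = scales[g] * (static_cast<float>(q[k]) - zeros[g]);
  }
  return w;
}

In the actual kernels this mapping is applied warp-wide on register fragments by warp::MmaTensorOpDequantizer immediately after the B fragments are converted from shared memory, rather than element by element as in the sketch.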
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 0000000000000..78b6abb50513f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = void> +class DqMmaMultistage; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h new file mode 100644 index 0000000000000..5db74039469c4 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -0,0 +1,612 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applied immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + 
+ /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), 
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr = reinterpret_cast(this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + } + + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } else { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + 
this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. 
This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
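+        // The (warp_mma_k + 1) % 2 indexing below implements software double buffering at the
+        // warp level: while the current fragment feeds the tensor-core MMA, the next warp-level
+        // fragment is prefetched from shared memory.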
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + using FragmentOperandB = cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); + warp_mma( + accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for the B scales immediately. + if (group_start_iteration_B == 0) { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 0000000000000..e992915cafeea --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Used for partial specialization + typename Enable = void> +class DqMmaPipelined; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h new file mode 100644 index 0000000000000..b362195834c87 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h @@ -0,0 +1,431 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_> +class DqMmaPipelined> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using 
LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + + static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + using WarpFragmentZero = typename Dequantizer::FragmentZero; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< The group size for quantization + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), 
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale) { + using TransformScale = NumericArrayConverter; + + FragmentScale tb_frag_scales; + FragmentScale tb_frag_zeros; + tb_frag_scales.clear(); + tb_frag_zeros.clear(); + + TransformScale transformScale; + + using FragmentElement = typename FragmentScale::Element; + + auto gmem_scale_ptr = iterator_scale.get_scale(); + auto gmem_zero_ptr = iterator_scale.get_zero(); + + arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + arch::global_load( + tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid()); + } + + typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); + typename TransformScale::result_type tb_frag_zeros_fp16; + if (gmem_zero_ptr != nullptr) + tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); + + auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); + auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); + auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); + auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); + + if (iterator_scale.valid()) { + auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); + arch::shared_store(smem_offset, frag_scale_ptr_fp16); + + if (gmem_zero_ptr != nullptr) { + smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); + arch::shared_store(smem_offset, frag_zero_ptr_fp16); + } + } + + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } else { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using 
TransformA = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. + TransformA transformA; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + copy_scales_and_advance(iterator_scale); + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + WarpFragmentScale warp_frag_scales; + WarpFragmentZero warp_frag_zero; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + iterator_scale.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. 
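+        // On the last warp-level k-iteration of a threadblock tile, the global-memory fragments
+        // staged in registers at warp_mma_k == 0 (tb_frag_A / tb_frag_B) are written to the other
+        // half of the double-buffered shared memory before the write pointers wrap around.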
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + copy_scales_and_advance(iterator_scale); + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + iterator_scale.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); + warp_mma(accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + + // Load the scales needed for the next tile iteration + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + // Update internal pointer to the set of scales in shared memory + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 0000000000000..e680493cf060a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp { + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = 128 / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape; + + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 0000000000000..21c787e91be50 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" +#include "cutlass/arch/mma_sm89.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value && platform::is_same::value) || (platform::is_same::value && platform::is_same::value && ArchTag::kMinComputeCapability >= 80) || (platform::is_same::value && platform::is_same::value && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + + static_assert(platform::is_same::value || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert( + SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, + LayoutB, MatrixShape, Policy::OpDelta::kRow, + kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = Array; + + 
/// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + int const warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 0000000000000..47f1bb240e8b3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,393 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" + +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_, + /// + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 80 && platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. 
+ arch::device_breakpoint(); +#endif + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. 
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 75 && platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < 
MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h new file mode 100644 index 0000000000000..e48ef3f154883 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h @@ -0,0 +1,405 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + +#include "cute/tensor.hpp" + +namespace onnxruntime::llm { +namespace cutlass_extensions { + +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. 
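+// For example, CtaShape128x128x64_WarpShape64x32x64 pairs a 128x128x64 threadblock (CTA) tile
+// with 64x32x64 warp tiles, i.e. a 2x4 arrangement of warps per threadblock in the M and N
+// dimensions.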
+enum class CutlassTileConfig {
+  // Signals that we should run heuristics to choose a config
+  Undefined,
+
+  // Signals that we should run heuristics to choose a config
+  ChooseWithHeuristic,
+
+  // SIMT config
+  CtaShape128x128x8_WarpShape64x64x8,
+
+  // TensorCore configs CTA_N = 128, CTA_K = 64
+  // Warp configs for M=16
+  CtaShape16x128x64_WarpShape16x32x64,
+
+  // Warp configs for M=32
+  CtaShape32x128x64_WarpShape32x32x64,
+
+  // Warp configs for M=64
+  CtaShape64x128x64_WarpShape32x64x64,
+  CtaShape64x64x128_WarpShape32x64x64,
+  CtaShape64x128x64_WarpShape64x32x64,
+
+  // Warp configs for M=128
+  CtaShape128x64x64_WarpShape64x32x64,
+  CtaShape128x128x64_WarpShape64x32x64,
+  CtaShape128x128x64_WarpShape64x64x64,
+  CtaShape128x128x64_WarpShape128x32x64,
+  CtaShape128x256x64_WarpShape64x64x64,
+
+  // Warp configs for M=256
+  CtaShape256x128x64_WarpShape64x64x64,
+
+  // TensorCore config CTA_N = 64, CTA_K = 128
+  CtaShape128x64x128_WarpShape64x32x128,
+
+  // TensorCore config CTA_N = 256, CTA_K = 64
+  CtaShape16x256x64_WarpShape16x64x64,
+
+  // TensorCore config CTA_N = 256, CTA_K = 128
+  CtaShape16x256x128_WarpShape16x64x128
+
+};
+
+enum class SplitKStyle {
+  NO_SPLIT_K,
+  SPLIT_K_SERIAL,
+  STREAM_K,  // Sm80+
+  // SPLIT_K_PARALLEL // Not supported yet
+};
+
+enum class CutlassTileConfigSM90 {
+  // Signals that we should run heuristics to choose a config
+  Undefined,
+
+  // Signals that we should run heuristics to choose a config
+  ChooseWithHeuristic,
+
+  // CTA configs for M=64
+  CtaShape64x16x128B,
+  CtaShape64x32x128B,
+  CtaShape64x64x128B,
+  CtaShape64x128x128B,
+  CtaShape64x256x128B,
+
+  // CTA configs for M=128
+  CtaShape128x16x128B,
+  CtaShape128x32x128B,
+  CtaShape128x64x128B,
+  CtaShape128x128x128B,
+  CtaShape128x256x128B,
+
+  // CTA configs for M=256
+  CtaShape256x128x128B,
+};
+
+enum class CutlassTileConfigSM100 {
+  // Signals that we should run heuristics to choose a config
+  Undefined,
+
+  // Signals that we should run heuristics to choose a config
+  ChooseWithHeuristic,
+
+  /*
+   * Grouped GEMM
+   */
+  // M=64
+  CtaShape64x32x128B,
+  CtaShape64x64x128B,
+  CtaShape64x128x128B,
+  CtaShape64x256x128B,
+
+  // M=128
+  CtaShape128x8x256B,
+  CtaShape128x16x128B,
+  CtaShape128x32x128B,
+  CtaShape128x64x128B,
+  CtaShape128x128x128B,
+  CtaShape128x256x128B,
+  CtaShape128x128x256B,
+  CtaShape128x256x256B,
+
+  // M=256
+  CtaShape256x64x128B,
+  CtaShape256x128x128B,
+  CtaShape256x256x128B,
+};
+
+enum class MainloopScheduleType {
+  AUTO,  // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this
+         // defaults to the "legacy" main loop schedule.
+  PINGPONG,
+  COOPERATIVE,
+  WARPSPECIALIZED
+};
+
+#if 0
+static auto get_mainloop_schedule_name(MainloopScheduleType schedule) {
+  if (schedule == MainloopScheduleType::AUTO) {
+    return "auto";
+  } else if (schedule == MainloopScheduleType::PINGPONG) {
+    return "pingpong";
+  } else if (schedule == MainloopScheduleType::COOPERATIVE) {
+    return "cooperative";
+  } else if (schedule == MainloopScheduleType::WARPSPECIALIZED) {
+    return "warpspecialized";
+  }
+  return "unknown schedule";
+}
+#endif
+
+enum class EpilogueScheduleType {
+  AUTO,  // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
+         // architectures older than Hopper, the epilogue is always performed by the same thread block as the main
+         // loop.
+}; + +enum class TileShape { + TileShape_64x16x128, + TileShape_64x32x128, + TileShape_64x64x128, + TileShape_64x128x128, + TileShape_64x256x128, + TileShape_64x512x128, + TileShape_128x16x128, + TileShape_128x32x128, + TileShape_128x64x128, + TileShape_128x128x128, + TileShape_128x256x128 +}; + +template +constexpr auto get_tile_shape() { + using namespace cute; + if constexpr (Shape_MNK == TileShape::TileShape_64x16x128) { + return cute::Shape<_64, _16, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x32x128) { + return cute::Shape<_64, _32, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x64x128) { + return cute::Shape<_64, _64, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x128x128) { + return cute::Shape<_64, _128, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x256x128) { + return cute::Shape<_64, _256, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x512x128) { + return cute::Shape<_64, _512, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x16x128) { + return cute::Shape<_128, _16, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x32x128) { + return cute::Shape<_128, _32, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x64x128) { + return cute::Shape<_128, _64, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x128x128) { + return cute::Shape<_128, _128, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x256x128) { + return cute::Shape<_128, _256, _128>{}; + } +} + +#if 0 +static auto get_tile_shape_name(TileShape Shape_MNK) { + if (Shape_MNK == TileShape::TileShape_64x16x128) { + return "64x16x128"; + } else if (Shape_MNK == TileShape::TileShape_64x32x128) { + return "64x32x128"; + } else if (Shape_MNK == TileShape::TileShape_64x64x128) { + return "64x64x128"; + } else if (Shape_MNK == TileShape::TileShape_64x128x128) { + return "64x128x128"; + } else if (Shape_MNK == TileShape::TileShape_64x256x128) { + return "64x256x128"; + } else if (Shape_MNK == TileShape::TileShape_64x512x128) { + return "64x512x128"; + } else if (Shape_MNK == TileShape::TileShape_128x16x128) { + return "128x16x128"; + } else if (Shape_MNK == TileShape::TileShape_128x32x128) { + return "128x32x128"; + } else if (Shape_MNK == TileShape::TileShape_128x64x128) { + return "128x64x128"; + } else if (Shape_MNK == TileShape::TileShape_128x128x128) { + return "128x128x128"; + } else if (Shape_MNK == TileShape::TileShape_128x256x128) { + return "128x256x128"; + } + return "Unknown shape"; +} +#endif + +enum class ClusterShape { + ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1, + ClusterShape_1x4x1, + ClusterShape_4x2x1, + ClusterShape_2x4x1, + ClusterShape_4x4x1, + ClusterShape_1x8x1, + ClusterShape_8x1x1 +}; + +#if 0 +static auto get_cluster_shape_name(ClusterShape Shape_MNK) { + if (Shape_MNK == ClusterShape::ClusterShape_1x1x1) { + return "1x1x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_2x1x1) { + return "2x1x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_1x2x1) { + return "1x2x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { + return "2x2x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_1x8x1) { + return "1x8x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { + return "8x1x1"; + } + return "Unknown shape"; +} + +template +constexpr auto get_cluster_shape() { + using namespace cute; + if constexpr (Shape_MNK == 
ClusterShape::ClusterShape_1x1x1) { + return cute::Shape<_1, _1, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x1x1) { + return cute::Shape<_2, _1, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x2x1) { + return cute::Shape<_1, _2, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { + return cute::Shape<_2, _2, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x8x1) { + return cute::Shape<_1, _8, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { + return cute::Shape<_8, _1, _1>{}; + } +} +#endif + +struct CutlassGemmConfig { + enum CandidateConfigTypeParam : int { + NONE = 0, + WEIGHT_ONLY = 1u << 0, + SIMT_ONLY = 1u << 1, + INT8_ONLY = 1u << 2, + HOPPER = 1u << 3, + BLACKWELL = 1u << 4, + GROUPED_GEMM = 1u << 5, + FP8_ONLY = 1u << 6, + FP4_ONLY = 1u << 7 + }; + + CutlassTileConfig tile_config_sm80 = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; + CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + bool enableCudaKernel = false; + int sm_version = 80; // Use 80 as a catch all for <90 + bool is_tma_warp_specialized = false; + + CutlassGemmConfig() = default; + + CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) + : tile_config_sm80(tile_config), split_k_style(split_k_style), split_k_factor(split_k_factor), stages(stages), sm_version(80) { + } + + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm90(tile_config_sm90), mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape), sm_version(90), is_tma_warp_specialized(true) { + } + + CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm100(tile_config_sm100), mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape), sm_version(100), is_tma_warp_specialized(true) { + } + + int getTileConfigAsInt() const { + if (sm_version == 120) + return (int)tile_config_sm80; + if (sm_version >= 100) + return (int)tile_config_sm100; + if (sm_version == 90) + return (int)tile_config_sm90; + if (sm_version < 90) + return (int)tile_config_sm80; + assert(false && "Invalid SM version"); + return -1; + } + + std::string toString() const { + std::stringstream tactic; + tactic << "Cutlass GEMM Tactic"; + if (is_tma_warp_specialized && getTileConfigAsInt() != (int)CutlassTileConfigSM90::ChooseWithHeuristic) { + assert(sm_version >= 90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=TMA Warp Specialized" + << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt() + << "\n\tcluster shape ID: " << (int)cluster_shape + << "\n\tmainloop sched: " << (int)mainloop_schedule << "\n\tepi sched: " << (int)epilogue_schedule + << "\n\tenable cuda kernel: " << 
(enableCudaKernel ? "true" : "false"); + } else if (tile_config_sm80 != onnxruntime::llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { + assert(sm_version < 90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=compatible" + << "\n\ttile shape ID: " << (int)tile_config_sm80 << "\n\tstages: " << (int)stages + << "\n\tsplit k: " << (int)split_k_factor + << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + } else if (enableCudaKernel) { + tactic << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + } else { + tactic << "\n\tundefined"; + } + tactic << "\n"; + return tactic.str(); + } +}; + +inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) { + // clang-format off + if (config.is_tma_warp_specialized) + { + out << "tile_config_sm90_enum: " << config.getTileConfigAsInt() + << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) + << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) + << ", cluster_shape_enum: " << int(config.cluster_shape) + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); + } + else + { + out << "tile_config_enum: " << config.getTileConfigAsInt() + << ", split_k_style_enum: " << int(config.split_k_style) + << ", split_k_factor: " << config.split_k_factor + << ", stages: " << config.stages + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); + } + // clang-format on + return out; +} + +} // namespace cutlass_extensions +} // namespace onnxruntime::llm + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 0000000000000..86c45a865954e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,399 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass { + +// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low +// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. +// This converter will uninterleave the data and subtract the bias while converting to the result type. 
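A minimal scalar sketch of what this converter family computes for the uint8 -> half_t case, using plain float arithmetic in place of the packed fp16 PTX; the helper below is illustrative only and is not used by the kernels.

#include <cstdint>

// Reference semantics for one packed 32-bit source register holding four
// biased int8 values in interleaved order {e0, e2, e1, e3}: even elements in
// the low 16 bits, odd elements in the high 16 bits, each stored as the signed
// value plus a bias of 128 (2^(b-1) for b = 8).
inline void interleaved_biased_i8x4_to_float(uint32_t packed, float out[4]) {
  const uint8_t b0 = static_cast<uint8_t>(packed & 0xff);          // holds element 0
  const uint8_t b1 = static_cast<uint8_t>((packed >> 8) & 0xff);   // holds element 2
  const uint8_t b2 = static_cast<uint8_t>((packed >> 16) & 0xff);  // holds element 1
  const uint8_t b3 = static_cast<uint8_t>((packed >> 24) & 0xff);  // holds element 3
  // Uninterleave and undo the +128 bias while converting.
  out[0] = static_cast<float>(b0) - 128.0f;
  out[1] = static_cast<float>(b2) - 128.0f;
  out[2] = static_cast<float>(b1) - 128.0f;
  out[3] = static_cast<float>(b3) - 128.0f;
}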
+template +struct FastInterleavedAndBiasedNumericArrayConverter { +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. 
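    // (The __byte_perm selector 0x7632 keeps bytes 2-3 of the first operand and
    // bytes 2-3 of the second, i.e. the upper 16 bits of each fp32 intermediate,
    // which is its truncated bfloat16 encoding; intermediate 2*ii lands in the
    // low half of the output word and intermediate 2*ii+1 in the high half.)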
+ CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. 
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. 
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h new file mode 100644 index 0000000000000..30df05f24257e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass { +namespace layout { + +template +struct ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 0000000000000..cf5ebdaeec261 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM + quantization. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + using Fragment = cutlass::Array; + + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_(layout.stride(0)) { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), pointer_scale_(reinterpret_cast(const_cast(pointer_scale))), pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) { + 
pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); + } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + } + + // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on + // a given iteration. The same threads will be responsible for issues reads since the number of scales + // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ + // outside of the constructor. + int const global_row = threadblock_offset.row() + thread_row; + int const global_col = threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator( + params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) { + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) { + pointer_zero_ += row_byte_offset + col_byte_offset; + } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return is_valid_; + } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const { + return reinterpret_cast(pointer_zero_); + } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h new file mode 100644 index 0000000000000..cc54764c2be50 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2017-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +namespace cutlass { + +enum class WeightOnlyQuantOp { + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS +}; + +constexpr bool isFinegrained(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +} + +constexpr bool hasZero(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +} + +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc new file mode 100644 index 0000000000000..d53fb558ba1a1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc @@ -0,0 +1,479 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#pragma GCC diagnostic ignored "-Wsign-compare" +#endif // __GNUC__ + +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" + +#include + +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_types.h" +#include "core/common/common.h" + +#include +#include +#include +#include + +using namespace onnxruntime::llm::cutlass_extensions; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +struct TileShape { + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + return TileShape{16, 128}; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + return TileShape{16, 256}; + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + return TileShape{64, 64}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + return TileShape{128, 64}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + return TileShape{256, 128}; + case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + return TileShape{16, 256}; + default: + ORT_THROW("[get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape, + int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only) { + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } + + if ((k % split_k_factor) != 0) { + return false; + } + + int const k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + int const required_ws_bytes = split_k_factor == 1 ? 
0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; +} + +std::vector get_candidate_tiles( + int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + enum class CutlassGemmType : char { + Default, + WeightOnly, + Simt, + Int8, + Fp8 + }; + + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (config_type_param & CutlassGemmConfig::SIMT_ONLY) { + gemm_type = CutlassGemmType::Simt; + } else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + gemm_type = CutlassGemmType::WeightOnly; + } else if (config_type_param & CutlassGemmConfig::INT8_ONLY) { + gemm_type = CutlassGemmType::Int8; + } else if (config_type_param & CutlassGemmConfig::FP8_ONLY) { + gemm_type = CutlassGemmType::Fp8; + } + + std::vector base_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; + if (sm >= 75) { + base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); + } + + switch (gemm_type) { + case CutlassGemmType::Simt: + return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + case CutlassGemmType::WeightOnly: + if (sm >= 75) { + return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + } else { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; + } + case CutlassGemmType::Int8: + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + case CutlassGemmType::Fp8: + if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) { + if (sm == 89) { + return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + } else { + // no valid ampere style fp8 configs for sm90 + return {}; + } + } else { + if (sm == 89) { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x64x128_WarpShape64x32x128, + CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128}; + } else { + return {}; + } + } + default: + return base_configs; + } +} + +std::vector 
get_candidate_tiles_sm90(CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM90 + return {CutlassTileConfigSM90::CtaShape128x128x128B}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + } else { + return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B, + CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B}; + } +#endif +} + +// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve +// compilation speed. +bool sm90_supports_mcast_along_m(CutlassTileConfigSM90 const tile) { +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve +// compilation speed. +bool sm90_supports_mcast_along_n(CutlassTileConfigSM90 const tile) { +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +std::vector get_candidate_configs_sm90(CutlassGemmConfig::CandidateConfigTypeParam const config) { + auto tiles = get_candidate_tiles_sm90(config); + std::vector candidate_configs; + for (auto const& tile_config : tiles) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(config); + + bool const has_m_mcast = sm90_supports_mcast_along_m(tile_config); + bool const has_n_mcast = sm90_supports_mcast_along_n(tile_config); + if (has_m_mcast) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(config); + } + + if (has_n_mcast) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(config); + } + + if (has_m_mcast && has_n_mcast) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(config); + } + } + // add cuda kernel profiler to tactics for weight-only plugins + if (config & CutlassGemmConfig::WEIGHT_ONLY) { + if (tiles.size() > 0) { + CutlassGemmConfig CudaKernelConfig( + tiles[0], 
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); + CudaKernelConfig.enableCudaKernel = true; + candidate_configs.push_back(CudaKernelConfig); + } + } + return candidate_configs; +} + +std::vector get_candidate_configs_sm100(CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM100 + return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + std::vector candidate_configs; + if ((config & CutlassGemmConfig::FP4_ONLY) != 0) { + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + // candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, + // MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + return candidate_configs; + } + + for (int cluster_m = 1; cluster_m <= 2; cluster_m++) { + bool Is2SM = cluster_m == 2; + for (int cluster_n = 1; cluster_n <= 2; cluster_n++) { + std::vector base = {// M=128 + CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B}; + + if (Is2SM) { + if (cluster_n == 1) { + base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B); + base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B); + } + + std::vector twosm = {// M=256 + CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B}; + std::copy(twosm.begin(), twosm.end(), std::back_inserter(base)); + } else { + if (cluster_n == 1) { + base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B); + if ((config & CutlassGemmConfig::FP8_ONLY) != 0) { + base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B); + } + } + + if (cluster_n == 1 && cluster_m == 1 && ((config & CutlassGemmConfig::FP8_ONLY) != 0)) { + base.push_back(CutlassTileConfigSM100::CtaShape128x8x256B); + } + + std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B, + CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B, + CutlassTileConfigSM100::CtaShape128x64x128B}; + std::copy(onesm.begin(), onesm.end(), std::back_inserter(base)); + } + + constexpr std::array, 2> cluster_shapes = + {{std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1}, + std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}}}; + + auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1]; + for (auto tile : base) { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, 
EpilogueScheduleType::AUTO, cluster}; + candidate_configs.push_back(config); + } + } + } + return candidate_configs; + } else { + ORT_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); + } +#endif + +} // namespace kernels + +std::vector get_candidate_configs( + int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + if ((config_type_param & CutlassGemmConfig::FP4_ONLY) && !(config_type_param & CutlassGemmConfig::BLACKWELL)) { + // FP4 is only supported on blackwell + return {}; + } + + if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) { + return get_candidate_configs_sm90(config_type_param); + } + if (sm >= 100 && sm != 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { + return get_candidate_configs_sm100(config_type_param); + } + + std::vector tiles = get_candidate_tiles(sm, config_type_param); + + std::vector candidate_configs; + bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY; + int const min_stages = int8_configs_only ? 3 : 2; + int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); + for (auto const& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); + candidate_configs.push_back(config); + if (sm >= 75) { + for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) { + auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; + candidate_configs.push_back(config); + } + } + } + } + // add cuda kernel profiler to tactics for weight-only plugins + if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + if (tiles.size() > 0) { + CutlassGemmConfig CudaKernelConfig(tiles[0], SplitKStyle::NO_SPLIT_K, 1, min_stages); + CudaKernelConfig.enableCudaKernel = true; + candidate_configs.push_back(CudaKernelConfig); + } + } + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, + std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const /*num_experts*/, + int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + ORT_THROW( + "[estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config_sm80); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; + } + + // Keep small tile sizes when possible. 
+ if (best_config.tile_config_sm80 != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile && current_m_tile < tile_shape.m) { + continue; + } + + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { + int const ctas_per_wave = occupancy * multi_processor_count; + int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + float const current_score = float(num_waves_total) - num_waves_fractional; + + float const score_slack = 0.1f; + if (current_score < config_score || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig( + candidate_config.tile_config_sm80, split_style, split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + } else if (current_score == config_score && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor || current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig( + candidate_config.tile_config_sm80, split_style, split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config_sm80 == CutlassTileConfig::ChooseWithHeuristic) { + ORT_THROW("Heuristic failed to find a valid config."); + } + + return best_config; +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h new file mode 100644 index 0000000000000..b9b0301d78fc7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cute/tensor.hpp" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +struct should_filter_tma_warp_specialized_gemm_problem_shape { +#ifdef FAST_BUILD + using SupportedCtaShape = cute::Shape(TileShape{}))>; + using SupportedCgaShape = cute::Shape; + + constexpr static bool value = !cute::is_same_v || !cute::is_same_v; +#else + constexpr static bool value = false; +#endif +}; +template +constexpr static bool should_filter_tma_warp_specialized_gemm_problem_shape_v = should_filter_tma_warp_specialized_gemm_problem_shape::value; + +std::vector get_candidate_configs( + int sm, int const max_split_k, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const); + +onnxruntime::llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, + std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const /*num_experts*/, + int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc new file mode 100644 index 0000000000000..50ee944161538 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc @@ -0,0 +1,687 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/cutlass_preprocessors.h" + +#include + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +using namespace onnxruntime::llm::common; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +struct LayoutDetails { + enum class Layout { + UNKNOWN, + ROW_MAJOR, + COLUMN_MAJOR + }; + + Layout layoutB = Layout::UNKNOWN; + int rows_per_column_tile = 1; + int columns_interleaved = 1; + + bool uses_imma_ldsm = false; +}; + +template +struct getLayoutDetails { +}; + +template <> +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; + return layout_details; + } +}; + +template <> +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + return layout_details; + } +}; + +template +struct getLayoutDetails> { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + layout_details.rows_per_column_tile = RowsPerTile; + layout_details.columns_interleaved = ColumnsInterleaved; + return layout_details; + } +}; + +template +LayoutDetails getLayoutDetailsForArchAndQuantType() { + using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; + using LayoutB = typename CompileTraits::Layout; + using MmaOperator = typename CompileTraits::Operator; + LayoutDetails details = getLayoutDetails()(); + details.uses_imma_ldsm = std::is_same::value; + return details; +} + +template +LayoutDetails getLayoutDetailsForArch(QuantType quant_type) { + LayoutDetails details; + switch (quant_type) { + case QuantType::W8_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; + case QuantType::W4_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; + case QuantType::W4_AFP8: + details = getLayoutDetailsForArchAndQuantType(); + break; + default: + ORT_THROW("Unsupported quantization type"); + } + return details; +} + +LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) { + if (arch >= 75 && arch < 80) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 80 && arch < 90) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 90 && arch < 100) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 100) { + return getLayoutDetailsForArch(quant_type); + } else { + ORT_THROW("Unsupported Arch"); + return LayoutDetails(); + } +} + +// Permutes the rows of B in a way that is compatible with Turing+ architectures. +// +// Throws an error for other architectures. +// The data is permuted such that: +// For W8_A16, each group of 16 rows is permuted using the map below: +// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 +// For W4_A16, each group of 32 rows is permuted using the map below: +// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 +// For W4_A8, see the map in the code. The idea is similar to above. +// The goal of this permutation is to ensure data ends up in the correct threads after +// we execute LDSM. 
It counteracts the effect of the data being of different widths. +// For more information about the expected layouts, see the MMA section in the PTX docs. +std::vector get_permutation_map(QuantType quant_type) { + if (quant_type == QuantType::W8_A16) { + return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + } else if (quant_type == QuantType::W4_A16) { + return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, + 22, 23, 30, 31}; + } else if (quant_type == QuantType::W4_AFP8) { + return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, + 28, 29, 30, 31}; + } else { + ORT_THROW("Invalid quantization type for LDSM permutation"); + } +} + +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + // We only want to run this step for weight only quant. + std::vector row_permutation = get_permutation_map(quant_type); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const K = 16 / BITS_PER_ELT; + + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const num_vec_cols = num_cols / elts_in_int32; + + ORT_ENFORCE(num_rows % B_ROWS_PER_MMA == 0, + "Invalid shape for quantized tensor. Number of rows of quantized matrix must be a multiple of ", + B_ROWS_PER_MMA); + ORT_ENFORCE(num_cols % MMA_SHAPE_N == 0, + "Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of ", + MMA_SHAPE_N); + + ORT_ENFORCE(size_t(B_ROWS_PER_MMA) == row_permutation.size(), "Unexpected number of LDSM rows permuted."); + + for (int expert = 0; expert < static_cast(num_experts); ++expert) { + const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols); + for (int base_row = 0; base_row < static_cast(num_rows); base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + int const write_row = base_row + tile_row; + int const tile_read_row = row_permutation[tile_row]; + int const read_row = base_row + tile_read_row; + int const read_col = write_col; + + const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; + + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +// We need to use this transpose to correctly handle packed int4 and int8 data +// The reason this code is relatively complex is that the "trivial" loops took a substantial +// amount of time to transpose leading to long preprocessing times. This seemed to be a big +// issue for relatively large models. 
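As a plain-CPU illustration of the row permutation defined above (a sketch, not the shipped code), applying one of the maps returned by get_permutation_map to groups of rows of a row-major, byte-packed matrix looks like this:

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch: destination row (base + r) takes its data from source row (base + map[r]);
// row_bytes is the packed byte width of one row, map.size() is 16 (W8_A16) or 32 (W4_A16/W4_AFP8).
void permute_row_groups(uint8_t* dst, const uint8_t* src, size_t rows, size_t row_bytes,
                        const std::vector<int>& map) {
  const size_t group = map.size();
  for (size_t base = 0; base + group <= rows; base += group) {
    for (size_t r = 0; r < group; ++r) {
      const uint8_t* read = src + (base + map[r]) * row_bytes;
      uint8_t* write = dst + (base + r) * row_bytes;
      std::copy(read, read + row_bytes, write);
    }
  }
}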
+template +void subbyte_transpose_impl( + int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector const& shape) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + constexpr int bits_per_elt = get_weight_quant_bits(quant_type); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + + uint8_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); + + static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt; + + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + + // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples + // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it + // allows GCC to emit vector instructions. + ORT_ENFORCE(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), + "Number of bytes for rows and cols must be a multiple of ", VECTOR_WIDTH, ". However, num_rows_bytes = ", + col_bytes_trans, " and num_col_bytes = ", col_bytes); + + for (size_t expert = 0; expert < num_experts; ++expert) { + const size_t matrix_offset = expert * num_rows * col_bytes; + for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) { + int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start + ii; + + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte + jj; + + const size_t logical_src_offset = matrix_offset + row * col_bytes + col; + + if (row < row_limit && col < col_limit) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } + } + } + } + + if constexpr (bits_per_elt == 8) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + for (int jj = ii + 1; jj < N_TILE_L1; ++jj) { + std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); + } + } + } else if constexpr (bits_per_elt == 4) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + // Using M_TILE_L1 here is deliberate since we assume that the cache tile + // is square in the number of elements (not necessarily the number of bytes). 
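The nibble manipulation in the swap loop that follows (and in the packing code later in this file) is built from two primitives; a hedged helper sketch, assuming two elements per byte with the low nibble first:

#include <cstdint>

inline uint8_t get_int4(const uint8_t* buf, int idx) {
  int byte = idx / 2, shift = 4 * (idx % 2);
  return (buf[byte] >> shift) & 0xF;
}

inline void set_int4(uint8_t* buf, int idx, uint8_t v) {
  int byte = idx / 2, shift = 4 * (idx % 2);
  // Keep the other nibble (0xF0 >> shift), then OR in the new value at its position.
  buf[byte] = static_cast<uint8_t>((buf[byte] & (0xF0 >> shift)) | ((v & 0xF) << shift));
}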
+ for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { + int const ii_byte = ii / ELTS_PER_BYTE; + int const ii_bit_offset = ii % ELTS_PER_BYTE; + + int const jj_byte = jj / ELTS_PER_BYTE; + int const jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + } else { + ORT_THROW("Unsupported quantization type."); + } + + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + + int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); + int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte_trans + jj; + + const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; + + if (row < row_limit_trans && col < col_limit_trans) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } + } + } + } + } + } + } +} + +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + + if (quant_type == QuantType::W8_A16) { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_A16) { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_AFP8) { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else { + ORT_THROW("Invalid quant_type"); + } +} + +void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num_elts) { + for (size_t ii = 0; ii < num_elts; ++ii) { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no + // performance benefit and is purely so that int4 and int8 have the same layout. + // Pictorially, this does the following: + // bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + ORT_ENFORCE(num_elts % 4 == 0, "Dimensions of int8 tensor must be a multiple of 4 for register relayout"); + for (size_t base = 0; base < num_elts; base += 4) { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } +} + +void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) { + size_t const num_bytes = num_elts / 2; + + // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little + // instructions as possible in the CUDA code. 
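A tiny host-side check of the int8 transform implemented just above (the input values are an assumed example, not taken from the tests):

#include <cstdint>
#include <cstdio>
#include <utility>

int main() {
  int8_t w[4] = {-128, -1, 0, 127};  // elt_0..elt_3 of one 32-bit word
  for (auto& v : w) v = static_cast<int8_t>(static_cast<int>(v) + 128);  // bias into the unsigned range
  std::swap(w[1], w[2]);  // [elt_3 elt_2 elt_1 elt_0] -> [elt_3 elt_1 elt_2 elt_0]
  // Prints 0 128 127 255: biased values with elements 1 and 2 exchanged.
  std::printf("%d %d %d %d\n", (int)(uint8_t)w[0], (int)(uint8_t)w[1], (int)(uint8_t)w[2], (int)(uint8_t)w[3]);
  return 0;
}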
+ for (size_t ii = 0; ii < num_bytes; ++ii) { + int8_t transformed_packed_int4s = 0; + int8_t transformed_first_elt = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8; // The double shift here is to ensure sign extension + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; + + ORT_ENFORCE( + transformed_first_elt >= 0 && transformed_first_elt <= 15, "Illegal result for int4 transform (first elt)"); + ORT_ENFORCE(transformed_second_elt >= 0 && transformed_second_elt <= 15, + "Illegal result for int4 transform (second elt)"); + + // We don't need to mask in these ops since everything should be in the range 0-15 + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical + // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the + // following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) + + ORT_ENFORCE(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { + int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; + int const src_shift = 4 * src_idx; + int const dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + if (quant_type == QuantType::W8_A16) { + add_bias_and_interleave_int8s_inplace(tensor, num_elts); + } else if (quant_type == QuantType::W4_A16 || quant_type == QuantType::W4_AFP8) { + // W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must + // be converted to FP16 before the scales can be applied using CUDA cores. + // As a result, we still want permute the data so that it is well aligned + // for conversion to FP16. + add_bias_and_interleave_int4s_inplace(tensor, num_elts); + } else { + ORT_THROW("Invalid quantization type for interleaving."); + } +} + +void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type, LayoutDetails details) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const rows_per_tile = details.rows_per_column_tile; + + ORT_ENFORCE(!(num_rows % elts_in_int32), + "The number of rows must be a multiple of ", elts_in_int32, " but the number of rows is ", num_rows); + + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); + + ORT_ENFORCE(!(num_rows % rows_per_tile), + "The number of rows must be a multiple of ", rows_per_tile, " but the number of rows is ", num_rows); + + int const num_vec_rows = num_rows / elts_in_int32; + int const vec_rows_per_tile = rows_per_tile / elts_in_int32; + int const interleave = details.columns_interleaved; + + for (int expert = 0; expert < static_cast(num_experts); ++expert) { + const int64_t matrix_offset = expert * int64_t(num_vec_rows) * int64_t(num_cols); + for (int64_t read_col = 0; read_col < static_cast(num_cols); ++read_col) { + const int64_t write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) { + for (int vec_read_row = base_vec_row; + vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); ++vec_read_row) { + const int64_t vec_write_row = interleave * base_vec_row + vec_rows_per_tile * (read_col % interleave) + vec_read_row % vec_rows_per_tile; + + const int64_t read_offset = matrix_offset + read_col * num_vec_rows + vec_read_row; + const int64_t write_offset = matrix_offset + int64_t(write_col) * num_vec_rows * interleave + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave) { + int arch = getSMVersion(); + if (force_interleave && arch >= 90) { + // Workaround for MOE which doesn't have specialized Hopper/Blackwell kernels yet + arch = 80; + } + // Force use sm80 kernel for GB20x. + if (arch >= 100) { + arch = 80; + } + LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + + size_t num_elts = 1; + for (auto const& dim : shape) { + num_elts *= dim; + } + + const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8; + + std::vector src_buf(num_bytes); + std::vector dst_buf(num_bytes); + std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); + + // Works on row major data, so issue this permutation first. + if (details.uses_imma_ldsm) { + permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) { + subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.columns_interleaved > 1 && arch != 90) { + interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); + src_buf.swap(dst_buf); + } + + add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); + std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); +} + +/* + Arguments: + input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. 
+ + quant_type - the type of the output quantization weight. + + This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the + zero-point is zero and will automatically construct the scales. + + It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is + viewed as a stack of matrices and a scale is produced for each column of every matrix. + +Outputs + processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM + unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. + scale_ptr - scales for the quantized weight. + + Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data + layout may not make sense if printed. + + Shapes: + quant_type == int8: + If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] + quant_type == int4: + If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape + [b,n] + + The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the + reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind + of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors + must have a dimension of 1, which breaks the semantics we need for batched weights. + */ + +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, + bool force_interleave) { + ORT_ENFORCE(processed_quantized_weight, "Processed quantized tensor is NULL"); + ORT_ENFORCE(scale_ptr, "Scale output pointer is NULL"); + ORT_ENFORCE(input_weight_ptr, "Input weight pointer is NULL"); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const bits_in_type = get_weight_quant_bits(quant_type); + int const bytes_per_out_col = num_cols * bits_in_type / 8; + + int const bits_per_weigtht_element = get_weight_quant_bits(quant_type); + + std::vector weight_buf; + if (unprocessed_quantized_weight == nullptr) { + weight_buf.resize(num_experts * num_rows * num_cols); + unprocessed_quantized_weight = weight_buf.data(); + } + + int const input_mat_size = num_rows * num_cols; + int const quantized_mat_size = num_rows * bytes_per_out_col; + float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); + + std::vector per_col_max(num_cols); + + for (int expert = 0; expert < static_cast(num_experts); ++expert) { + WeightType const* current_weight = input_weight_ptr + expert * input_mat_size; + int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; + + // First we find the per column max for this expert weight. 
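The per-column scheme documented above, reduced to a reference loop for the int8 case (a sketch assuming a 2-D row-major weight; the real function below also packs int4 and preprocesses the layout for CUTLASS):

#include <algorithm>
#include <cmath>
#include <cstdint>

void symmetric_quantize_int8_reference(const float* w, size_t rows, size_t cols,
                                       float* scales, int8_t* q) {
  for (size_t j = 0; j < cols; ++j) scales[j] = 0.f;
  for (size_t i = 0; i < rows; ++i)
    for (size_t j = 0; j < cols; ++j)
      scales[j] = std::max(scales[j], std::fabs(w[i * cols + j]));
  for (size_t j = 0; j < cols; ++j) scales[j] *= 1.f / 128.f;  // quant_range_scale for 8 bits
  for (size_t i = 0; i < rows; ++i)
    for (size_t j = 0; j < cols; ++j) {
      float s = scales[j];
      float v = (s != 0.f) ? std::round(w[i * cols + j] / s) : 0.f;
      q[i * cols + j] = static_cast<int8_t>(std::max(-128.f, std::min(127.f, v)));
    }
}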
+ for (size_t jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = 0.f; + } + + for (size_t ii = 0; ii < num_rows; ++ii) { + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (size_t jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); + } + } + + // Then, we construct the scales + ComputeType* current_scales = scale_ptr + expert * num_cols; + for (size_t jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] *= quant_range_scale; + current_scales[jj] = ComputeType(per_col_max[jj]); + } + + // Finally, construct the weights. + for (size_t ii = 0; ii < num_rows; ++ii) { + int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < bytes_per_out_col; ++jj) { + if (bits_per_weigtht_element == 8) { + float const col_scale = per_col_max[jj]; + float const weight_elt = float(current_weight_row[jj]); + float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } else if (bits_per_weigtht_element == 4) { + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + int const input_idx = 2 * jj + packed_idx; + if (input_idx < static_cast(num_cols)) { + float const col_scale = per_col_max[input_idx]; + float const weight_elt = float(current_weight_row[input_idx]); + float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + int int_weight = int(scaled_weight); + const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits + // if packing the second int4 and or the bits into the final result. 
+ packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + current_quantized_weight_row[jj] = packed_int4s; + } else { + ORT_THROW("Unsupported quantization type"); + } + } + } + } + + preprocess_weights_for_mixed_gemm( + processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave); +} + +template void symmetric_quantize( + int8_t*, int8_t*, half*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, int8_t*, half*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, float>( + int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave) { + symmetric_quantize( + processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave); +} + +template void symmetric_quantize( + int8_t*, float*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, half*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize(int8_t*, half*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, half>( + int8_t*, __nv_bfloat16*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, half*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, float>( + int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h new file mode 100644 index 0000000000000..3e83852228e24 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +enum class QuantType { + W8_A16, + W4_A16, + W4_AFP8 +}; + +constexpr int get_weight_quant_bits(QuantType quant_type) { + switch (quant_type) { + case QuantType::W8_A16: + return 8; + case QuantType::W4_A16: + return 4; + case QuantType::W4_AFP8: + return 4; + default: + ORT_THROW("Invalid quant_type"); + return -1; + } +} + +// Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols] +// 3-D shapes are [num_experts, num_rows, num_cols] +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type); + +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type); + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave = false); + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave); + +// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight +// to implement a simple reference implementation. +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, + bool force_interleave); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h new file mode 100644 index 0000000000000..1fe8035cbcdae --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "contrib_ops/cuda/llm/nv_infer_datatype.h" + +#include "cutlass/half.h" +#include + +#include "cutlass/bfloat16.h" +#include + +#include "cutlass/float8.h" +#include + +#if defined(ENABLE_FP4) +#include "cutlass/float_subbyte.h" +#include +#endif + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +/////////////////////////////////////////////////////////////////////////////////////////////////// +// nvinfer::DataType to Cutlass +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct CutlassType { + using type = void; +}; + +template <> +struct CutlassType { + using type = cutlass::half_t; +}; + +template <> +struct CutlassType { + using type = cutlass::bfloat16_t; +}; + +template <> +struct CutlassType { + using type = cutlass::float_e4m3_t; +}; + +#if defined(ENABLE_FP4) +template <> +struct CutlassType { + using type = cutlass::float_e2m1_t; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// CUDA to Cutlass + +template +struct CudaToCutlassTypeAdapter { + using type = T; +}; + +template <> +struct CudaToCutlassTypeAdapter { + using type = cutlass::half_t; +}; + +template <> +struct CudaToCutlassTypeAdapter<__nv_bfloat16> { + using type = cutlass::bfloat16_t; +}; + +#if defined(ENABLE_FP8) +template <> +struct CudaToCutlassTypeAdapter<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; +}; + +template <> +struct CudaToCutlassTypeAdapter<__nv_fp8_e5m2> { + using type = cutlass::float_e5m2_t; +}; +#endif + +#if defined(ENABLE_FP4) +template <> +struct CudaToCutlassTypeAdapter<__nv_fp4_e2m1> { + using type = cutlass::float_e2m1_t; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Cutlass to CUDA + +template +struct CudaToCudaTypeAdapter { + using type = T; +}; + +template <> +struct CudaToCudaTypeAdapter { + using type = half; +}; + +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_bfloat16; +}; + +#if defined(ENABLE_FP8) +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_fp8_e4m3; +}; + +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_fp8_e5m2; +}; +#endif + +#if defined(ENABLE_FP4) +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_fp4_e2m1; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..47e662b9a88ba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..9452aa0e1fbe6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..4a22e0f1b2aac --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..9f4091be4cd07 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h new file mode 100644 index 0000000000000..0141c76bbc031 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" +#include +#include + +namespace tkc = onnxruntime::llm::cutlass_extensions; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +// TRT Activation Type does not have Gelu or Silu +enum class ActivationType { + Gelu, + Relu, + Silu, + Identity, + InvalidType +}; + +/* + This runner only supports: + T in {half, __nv_bfloat} WeightType in {int8_t, cutlass::uint4b_t} + + Activations, biases, scales and outputs are all assumed to be row-major. + + However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. + In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor + will instantiate the layout and preprocess based on the instantiation, so layout changes should only require + modifications to mix_gemm_B_layout.h. 
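As a hedged usage sketch of the preprocessing requirement described above (the helper name and shapes are assumptions; group-wise scaling and error handling are omitted), B can be produced in the expected layout with the preprocessors declared in cutlass_preprocessors.h:

#include <cstdint>
#include <vector>
#include <cuda_fp16.h>
#include "contrib_ops/cuda/llm/cutlass_preprocessors.h"

// Sketch: quantize a row-major [k, n] fp16 weight to packed int4 plus per-column scales.
void prepare_w4a16_weights(const half* fp16_weight, size_t k, size_t n,
                           std::vector<int8_t>& packed_B, std::vector<half>& scales) {
  namespace ck = onnxruntime::llm::kernels::cutlass_kernels;
  packed_B.resize(k * n / 2);  // two int4 values per byte (n assumed even)
  scales.resize(n);            // one scale per output column (per-column scaling mode)
  // One call quantizes, permutes rows for LDSM, transposes/interleaves, and applies the unsigned bias.
  ck::symmetric_quantize<half, half>(packed_B.data(), scales.data(), fp16_weight,
                                     {k, n}, ck::QuantType::W4_A16, /*force_interleave=*/false);
}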
+*/ + +class CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunnerInterface() {} + + virtual ~CutlassFpAIntBGemmRunnerInterface() {} + + virtual void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; + + virtual void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, + int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) = 0; + + virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; + + virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; + + // Returns desired workspace size in bytes. + virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; + + virtual std::vector getConfigs() const = 0; + + protected: + static constexpr int SPLIT_K_LIMIT = 7; + static constexpr int MIN_M_TILE = 16; + static constexpr int MIN_N_TILE = 64; + + static constexpr int MAX_M_TILE_SM90 = 128; + static constexpr int MAX_N_TILE_SM90 = 256; +}; + +template +class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunner(); + ~CutlassFpAIntBGemmRunner(); + + void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; + + void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + // Disabled since the fused GEMM, activation kernels will not be used in v1. + + // void gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int m, int n, + // int k, ActivationType activation_type, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t + // stream); + + // Returns desired workspace size in bytes. 
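A sketch of the intended call flow for the runner declared above (simplified and hypothetical: no error handling, buffers assumed to be valid device pointers, workspace assumed to be at least getWorkspaceSize(m, n, k) bytes, and the config chosen naively rather than profiled):

#include <cuda_fp16.h>
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h"

void run_w4a16_gemm(const half* A, const void* packed_B, const half* scales, const half* zeros,
                    half* C, int m, int n, int k, int group_size,
                    char* workspace, size_t workspace_bytes, cudaStream_t stream) {
  using Runner = onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<
      half, cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>;
  Runner runner;
  auto configs = runner.getConfigs();  // callers typically profile these and cache the fastest
  runner.gemm(A, packed_B, scales, zeros, /*biases=*/nullptr, C, m, n, k, group_size,
              configs.front(), workspace, workspace_bytes, stream);
}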
+ size_t getWorkspaceSize(int const m, int const n, int const k) override; + + std::vector getConfigs() const override; + + private: + template + void dispatch_to_arch(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); + + private: + int sm_; + int multi_processor_count_; +}; + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h new file mode 100644 index 0000000000000..715397270331b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -0,0 +1,489 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/gemm/kernel/default_gemm.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" +#include "contrib_ops/cuda/llm/cutlass_type_conversion.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +namespace tk = onnxruntime::llm::common; +namespace tkc = onnxruntime::llm::cutlass_extensions; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* 
occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + static_assert( +#ifdef ENABLE_FP8 + cutlass::platform::is_same::value || +#endif + cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); + + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using CutlassActivationType = typename CudaToCutlassTypeAdapter::type; + using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; + using CutlassScaleZeroType = typename CudaToCutlassTypeAdapter::type; + using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; + using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using EpilogueOp = + typename tkc::Epilogue::Op; + + using Operator = typename MixedGemmArchTraits::Operator; + using TaggedOperator = typename cutlass::arch::TagOperator::TaggedOperator; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm, Stages, true, + TaggedOperator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB; + + if (occupancy != nullptr) { + *occupancy = onnxruntime::llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; + + int const ldb = cutlass::platform::is_same::value + ? n + : k * GemmKernel::kInterleave; + + if (weight_scales == nullptr) { + ORT_THROW("Weight scales must always be set to a non-null value."); + } + + if constexpr (cutlass::isFinegrained(QuantOp)) { + if constexpr (cutlass::platform::is_same::value) { + if (group_size != 128) { + ORT_THROW("Only group size 128 supported for fine grained W4A(fp)8 kernels."); + } + } + if (group_size != 64 && group_size != 128) { + ORT_THROW("Only group size 64 and 128 supported for fine grained kernels."); + } + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero pointer must be a nullptr for scale only fine grained"); + } + } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { + if (weight_zero_points == nullptr) { + ORT_THROW("Weight zero pointer must be valid for scale and bias fine grained"); + } + } + } else { + if (group_size != k) { + ORT_THROW("Invalid group size for per column scaling kernels."); + } + + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero-points must be null when running per column scaling"); + } + } + + int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; + ElementAccumulator output_op_beta = (biases == nullptr) ? 
ElementAccumulator(0.f) : ElementAccumulator(1.f);
+  typename Gemm::Arguments args({m, n, k}, group_size,
+      {reinterpret_cast<CutlassActivationType*>(const_cast<ActivationType*>(A)), k},
+      {reinterpret_cast<CutlassWeightType*>(const_cast<WeightType*>(B)), ldb},
+      {reinterpret_cast<CutlassScaleZeroType*>(const_cast<ScaleZeroType*>(weight_scales)), ld_scale_zero},
+      {reinterpret_cast<CutlassScaleZeroType*>(const_cast<ScaleZeroType*>(weight_zero_points)), ld_scale_zero},
+      {reinterpret_cast<CutlassBiasType*>(const_cast<BiasType*>(biases)), 0},
+      {reinterpret_cast<CutlassOutputType*>(C), n}, gemm_config.split_k_factor,
+      {ElementAccumulator(alpha), output_op_beta});
+
+  // This assertion is enabled because, for the column interleaved layout, K MUST be a multiple of
+  // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the
+  // interleaved matrix. The way masking is handled in these does not map to the interleaved layout. We need to write
+  // our own predicated iterator in order to relax this limitation.
+  if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) {
+    ORT_THROW("Temp assertion: k must be multiple of threadblockK");
+  }
+
+  Gemm gemm;
+  if (gemm.get_workspace_size(args) > workspace_bytes) {
+    ORT_LLM_LOG_WARNING(
+        "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation.");
+    // If the requested split-k factor would require more workspace than is available, revert to the standard gemm.
+    args.batch_count = 1;
+  }
+
+  auto can_implement = gemm.can_implement(args);
+  if (can_implement != cutlass::Status::kSuccess) {
+    std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement));
+    ORT_THROW("[fpA_intB_gemm] Error:", err_msg);
+  }
+
+  auto init_status = gemm.initialize(args, workspace, stream);
+  if (init_status != cutlass::Status::kSuccess) {
+    std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status));
+    ORT_THROW("[fpA_intB_gemm] Error:", err_msg);
+  }
+
+  auto run_status = gemm.run(stream);
+  if (run_status != cutlass::Status::kSuccess) {
+    std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
+    ORT_THROW("[fpA_intB_gemm] Error:", err_msg);
+  }
+}
+
+// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example,
+// instantiating SM=75, Stages=3 is invalid, so we would need to filter that out. Fine-grained
+// quantization is only supported on Ampere+ GPUs. FP8 GEMM is only supported on Ada+ GPUs.
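The filtering described in the comment above and implemented in filter_and_run_mixed_gemm below relies on `if constexpr` so that impossible arch/stage combinations are never instantiated. A minimal sketch of the pattern, assuming only that the arch tag exposes kMinComputeCapability (as the cutlass arch tags do):

#include <stdexcept>
#include <string>

template <typename Arch, int Stages>
void run_or_reject() {
  if constexpr (Stages > 2 && Arch::kMinComputeCapability < 80) {
    // cp.async-based multistage pipelines need sm80+, so refuse with a clear message
    // instead of emitting a kernel that could never be valid.
    throw std::runtime_error("stages=" + std::to_string(Stages) + " unsupported below sm80");
  } else {
    // instantiate and launch the real kernel here
  }
}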
+template +void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if constexpr (Stages > 2 && arch::kMinComputeCapability < 80) { + // Multistage only supported on Ampere + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } else if constexpr (Stages == 2 && arch::kMinComputeCapability >= 89) { + // Multistage only supported on Ampere + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } else if constexpr (cutlass::platform::is_same::value && arch::kMinComputeCapability < 89) { + // FP8 activation type only supported on Ada+ GPUs + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8"; + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } else { + generic_mixed_gemm_kernelLauncher(A, B, weight_scales, weight_zero_points, biases, + alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + } +} + +template +void dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.stages) { + case 2: + filter_and_run_mixed_gemm(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, + n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case 3: + filter_and_run_mixed_gemm(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, + n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case 4: + filter_and_run_mixed_gemm(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, + n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + break; + } +} + +template +constexpr bool is_fp8() { + return std::is_same_v || std::is_same_v; +} + +template +void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + // Don't instantiate configs that are not supported pre-hopper. Produce a sensible error instead. 
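dispatch_gemm_config above uses the usual runtime-to-compile-time dispatch idiom: an exhaustive switch maps the runtime stage count onto a template parameter. A stripped-down sketch of the same idea, with placeholder names:

#include <stdexcept>
#include <string>

template <int Stages>
void launch_with_stages() { /* instantiate and launch a kernel built for Stages */ }

inline void dispatch_stages(int stages) {
  switch (stages) {
    case 2: launch_with_stages<2>(); break;
    case 3: launch_with_stages<3>(); break;
    case 4: launch_with_stages<4>(); break;
    default: throw std::invalid_argument("unsupported stage count: " + std::to_string(stages));
  }
}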
+ constexpr bool any_is_fp8 = is_fp8() || is_fp8() || is_fp8() || is_fp8() || is_fp8(); + + constexpr bool all_types_are_the_same = std::is_same_v && std::is_same_v && std::is_same_v; + + constexpr bool is_valid_pre_hopper = (all_types_are_the_same && !any_is_fp8) || (arch::kMinComputeCapability == 89); + + if constexpr (is_valid_pre_hopper) { + // Note that SIMT configs are omitted here since they are not supported for fpA_intB. + // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the + // best for mixed type gemms. + constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits::value; + switch (gemm_config.tile_config_sm80) { + case tkc::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 64, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::Undefined: + ORT_THROW("[fpA_intB_gemm] Error:[dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case tkc::CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW( + "[fpA_intB_gemm] Error:[dispatch_gemm_to_cutlass] gemm config should have already been set by " + "heuristic."); + break; + default: + ORT_THROW( + "[fpA_intB_gemm] Error:[dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); + break; + } + } else { + // This is not a limitation in CUTLASS. We just do not need to support this case. 
+ std::string err_msg = "The activation type must equal the scale, bias and output types on Ampere and earlier."; + ORT_THROW("[fpA_intB_gemm] Error: [dispatch_gemm_to_cutlass] ", err_msg); + } +} + +template +CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + sm_ = ::onnxruntime::llm::common::getSMVersion(); + multi_processor_count_ = ::onnxruntime::llm::common::getMultiProcessorCount(); +} + +template +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); +} + +template +template +void CutlassFpAIntBGemmRunner::dispatch_to_arch(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + // std::string config_str = gemm_config.toString(); + // printf("######## sm=%d, alpha: %f m:%d n:%d, k:%d, group_size:%d, workspace_bytes:%zu config:%s\n", sm_, alpha, m, n, k, group_size, workspace_bytes, config_str.c_str()); + + if (sm_ >= 75 && sm_ < 80) { + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 100) { + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if (sm_ == 89) { +#if ENABLE_FP8 && ((__CUDACC_VER_MAJOR__ < 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) + if constexpr (cutlass::platform::is_same::value) { + ORT_THROW( + "[fpA_intB_gemm] Error: INT4xFP8 GEMM for Ada needs CUDA>=12.4"); + } +#endif + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if (sm_ == 90) { + static_assert(!cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "ScaleZeroType must be half for activation=fp8"); + sm90_dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr, + workspace_bytes, gemm_config, stream, occupancy); + } else { + ORT_THROW("[fpA_intB_gemm] Error:Arch unsupported for CUTLASS mixed type GEMM"); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm( + void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases, + float const alpha, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) || (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)) { + dispatch_to_arch((ActivationType const*)A, (WeightType const*)B, + (ScaleZeroType const*)weight_scales, (ScaleZeroType const*)weight_zero_points, (BiasType const*)biases, + alpha, (OutputType*)C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr); + } else { + ORT_THROW("Overload with scale, zero and group size only supported for fine grained bias template."); + } +} + +template 
+void CutlassFpAIntBGemmRunner::gemm(
+ void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases,
+ void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr,
+ const size_t workspace_bytes, cudaStream_t stream) {
+ ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+ gemm(A, B, weight_scales, weight_zero_points, biases, 1.f, C, m, n, k, group_size, gemmConfig, workspace_ptr,
+ workspace_bytes, stream);
+}
+
+template
+void CutlassFpAIntBGemmRunner::gemm(
+ void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k,
+ tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) {
+ ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+
+ if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) {
+ dispatch_to_arch((ActivationType const*)A, (WeightType const*)B,
+ (ScaleZeroType const*)weight_scales, nullptr, nullptr, alpha, (OutputType*)C, m, n, k, k, gemmConfig,
+ workspace_ptr, workspace_bytes, stream, nullptr);
+ } else {
+ ORT_THROW("Overload with scale only (and no group size) only supported for per column scaling.");
+ }
+}
+
+template
+void CutlassFpAIntBGemmRunner::gemm(
+ void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
+ tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) {
+ ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+ gemm(A, B, weight_scales, 1.f, C, m, n, k, gemmConfig, workspace_ptr, workspace_bytes, stream);
+}
+
+template
+std::vector
+CutlassFpAIntBGemmRunner::getConfigs() const {
+ static constexpr bool is_weight_only = !std::is_same::value;
+ tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER;
+ if (is_weight_only) {
+ config_type_param = static_cast(
+ config_type_param | tkc::CutlassGemmConfig::CandidateConfigTypeParam::WEIGHT_ONLY);
+ }
+ std::vector candidateConfigs = get_candidate_configs(sm_, SPLIT_K_LIMIT, config_type_param);
+ return candidateConfigs;
+}
+
+template
+size_t
+CutlassFpAIntBGemmRunner::getWorkspaceSize(
+ int const m, int const n, int const /*k*/) {
+ ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+ // For Hopper, we have to allocate a large workspace in case stream-K is used
+ if (sm_ == 90) {
+ // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L878-L892
+ // The above lines say sk_tiles = output_tiles - (static_cast(output_tiles / ctas_per_wave) - 1) *
+ // ctas_per_wave. This means sk_tiles is at most 2 * ctas_per_wave, which is 2 * multi_processor_count_
+ int const max_sk_tiles = 2 * multi_processor_count_;
+
+ // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L939
+ // The above line says uint64_t sk_units = platform::min(ctas_per_sk_wave, min_sized_sk_units);
+ // That means sk_units is at most ctas_per_sk_wave, which is multi_processor_count_
+ int const max_sk_units = multi_processor_count_;
+
+ // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L505
+ // The above line scales sk_tiles by a factor of static_cast(sk_units / sk_tiles + 2).
+ // That means the final sk_tiles is at most 2 * max_sk_tiles + max_sk_units.
+ int const
max_sk_tiles_with_separate_reduction = 2 * max_sk_tiles + max_sk_units;
+
+ return static_cast(
+ max_sk_tiles_with_separate_reduction * MAX_M_TILE_SM90 * MAX_N_TILE_SM90 * sizeof(float));
+ }
+ // These are the min tile sizes for each config, which would launch the maximum number of blocks
+ int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE);
+ int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE);
+ // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
+ return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4);
+}
+
+} // namespace cutlass_kernels
+} // namespace kernels
+} // namespace onnxruntime::llm
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h
new file mode 100644
index 0000000000000..432adb20079b6
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h
@@ -0,0 +1,244 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "cute/numeric/integral_constant.hpp"
+#include "cutlass/gemm/dispatch_policy.hpp"
+
+#include "core/common/common.h"
+#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h"
+#include "contrib_ops/cuda/llm/common/logger.h"
+#include "contrib_ops/cuda/llm/cutlass_heuristic.h"
+
+#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h"
+
+namespace tkc = onnxruntime::llm::cutlass_extensions;
+
+using namespace cute;
+
+namespace onnxruntime::llm {
+namespace kernels {
+namespace cutlass_kernels {
+
+// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example,
+// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine grained
+// quantization is only supported on Ampere+ GPUs.
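As a concrete illustration of the two workspace bounds computed in getWorkspaceSize just before this new header, the short standalone sketch below plugs in sample values. The SM count, tile sizes and split-k limit are assumptions chosen for the example only; the real values come from multi_processor_count_, MAX_M_TILE_SM90, MAX_N_TILE_SM90, MIN_M_TILE, MIN_N_TILE and SPLIT_K_LIMIT.

#include <cstddef>
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  // SM90 stream-K path: sk tiles with separate reduction are bounded by
  // 2 * max_sk_tiles + max_sk_units = 2 * (2 * SMs) + SMs = 5 * SMs,
  // and each such tile needs an accumulator tile of floats.
  int sm_count = 132;                      // assumed SM count of an H100 SXM part
  int max_m_tile = 128, max_n_tile = 256;  // assumed largest SM90 output tile
  size_t sm90_bytes = static_cast<size_t>(5 * sm_count) * max_m_tile * max_n_tile * sizeof(float);

  // Pre-Hopper split-k path: 4 bytes per block, with up to split_k_limit blocks in z.
  int m = 1024, n = 4096;
  int min_m_tile = 16, min_n_tile = 64;    // assumed smallest tile of any candidate config
  int split_k_limit = 7;                   // assumed SPLIT_K_LIMIT
  size_t pre_hopper_bytes =
      static_cast<size_t>(ceil_div(m, min_m_tile)) * ceil_div(n, min_n_tile) * split_k_limit * 4;

  std::printf("SM90 worst case: %zu bytes, pre-Hopper worst case: %zu bytes\n", sm90_bytes, pre_hopper_bytes);
  return 0;
}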
+template
+void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
+ ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
+ int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
+ cudaStream_t stream, int* occupancy = nullptr) {
+ ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+ switch (gemm_config.epilogue_schedule) {
+ case tkc::EpilogueScheduleType::AUTO:
+ using EpilogueScheduleType = cute::conditional_t(CTAShape{}) == Int<64>{},
+ cutlass::epilogue::TmaWarpSpecialized, cutlass::epilogue::TmaWarpSpecializedCooperative>;
+ sm90_generic_mixed_gemm_kernelLauncher(A, B, weight_scales,
+ weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream,
+ occupancy);
+ break;
+ default:
+ ORT_THROW(
+ "[fpA_intB_gemm][sm90_dispatch_epilogue_schedules] epilogue schedule config is invalid for "
+ "mixed type GEMM.");
+ break;
+ }
+}
+
+/*
+ 1x1x1 cluster shape is supported for any tile shape.
+
+ 2x1x1 cluster shape is only supported when the M tile is at least 128.
+
+ 1x2x1 cluster shape is only supported when the N tile is at least 128.
+
+ 2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
+
+ We make the above restrictions to improve compilation speed in TRT-LLM, by pruning kernels
+ that may not be very useful in practice.
+ */
+template
+constexpr bool are_tile_shapes_supported() {
+ [[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{});
+ [[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{});
+ constexpr int cga_m = get<0>(ClusterShape{});
+ constexpr int cga_n = get<1>(ClusterShape{});
+
+ if constexpr (cga_m == _1{} && cga_n == _1{}) {
+ return true;
+ } else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{}) {
+ return true;
+ } else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{}) {
+ return true;
+ } else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{}) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+template
+void sm90_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
+ ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
+ int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
+ cudaStream_t stream, int* occupancy = nullptr) {
+ ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+
+ constexpr bool tile_shapes_supported = are_tile_shapes_supported();
+
+ if constexpr (tile_shapes_supported) {
+ switch (gemm_config.mainloop_schedule) {
+ case tkc::MainloopScheduleType::AUTO:
+ using KernelScheduleType = cute::conditional_t(CTAShape{}) == Int<64>{},
+ cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::gemm::KernelTmaWarpSpecializedCooperative>;
+ sm90_dispatch_epilogue_schedules(A, B, weight_scales, weight_zero_points,
+ biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
+ break;
+ default:
+ ORT_THROW(
+ "[fpA_intB_gemm][sm90_dispatch_mainloop_schedules] mainloop schedule config is invalid "
+ "for "
+ "mixed type GEMM.");
+ break;
+ }
+ } else {
+ ORT_THROW(
+ "[fpA_intB_gemm][sm90_dispatch_mainloop_schedules] Unsupported CTA and Cluster shapes for "
+ "mixed type GEMM.");
+ }
+}
+
+template
+void sm90_dispatch_gemm_config(ActivationType
const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.cluster_shape) { + case tkc::ClusterShape::ClusterShape_1x1x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_2x1x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_1x2x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_2x2x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + ORT_THROW("[fpA_intB_gemm][dispatch_CGA_config] Config is invalid for mixed type GEMM."); + break; + } +} + +template +void sm90_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + // Note that SIMT configs are omitted here since they are not supported for fpA_intB. + // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best + // for mixed type gemms. 
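+ // The trailing "...x128B" in the SM90 tile-config names appears to give the K extent of the tile in bytes;
+ // dividing 128 bytes by sizeof(ActivationType) below converts it to elements (64 for half/bfloat16 activations).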
+ + constexpr int Ktile = 128 / sizeof(ActivationType); + using _Ktile = Int; + switch (gemm_config.tile_config_sm90) { + case tkc::CutlassTileConfigSM90::CtaShape64x16x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x32x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x64x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x128x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x256x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x16x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x32x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x64x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x128x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x256x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::Undefined: + ORT_THROW("[fpA_intB_gemm][sm90_dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case tkc::CutlassTileConfigSM90::ChooseWithHeuristic: + ORT_THROW( + "[fpA_intB_gemm][sm90_dispatch_gemm_to_cutlass] gemm config should have already been set by " + "heuristic."); + break; + default: + ORT_THROW("[fpA_intB_gemm][sm90_dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); + break; + } +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu new file mode 100644 index 0000000000000..468d53f336e55 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu @@ -0,0 +1,264 @@ +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl" +namespace onnxruntime::llm +{ 
+namespace kernels +{ +namespace cutlass_kernels +{ + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, 
onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, 
+cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, 
onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, 
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu new file mode 100644 index 0000000000000..0156c83840b09 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu @@ -0,0 +1,516 @@ +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl" +namespace onnxruntime::llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, 
const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, 
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, 
cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void 
sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const 
__nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, 
const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, 
+__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, 
cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, 
__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h new file mode 100644 index 0000000000000..594ae1079c06e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h"
+#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h"
+#include <cuda_runtime_api.h>
+
+namespace onnxruntime::llm {
+namespace kernels {
+namespace cutlass_kernels {
+
+template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType,
+          typename OutputType, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape,
+          typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType>
+void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
+    ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
+    float const alpha, OutputType* C, int m, int n, int k, int const group_size,
+    onnxruntime::llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
+    cudaStream_t stream, int* occupancy = nullptr);
+
+} // namespace cutlass_kernels
+} // namespace kernels
+} // namespace onnxruntime::llm
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl
new file mode 100644
index 0000000000000..779ff88455703
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl
@@ -0,0 +1,282 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" + +#include "contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" +#include "contrib_ops/cuda/llm/cutlass_type_conversion.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h" + +namespace tk = onnxruntime::llm::common; +namespace tkc = onnxruntime::llm::cutlass_extensions; + +using namespace cute; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +#ifdef COMPILE_HOPPER_TMA_GEMMS +void sm90_generic_mixed_gemm_kernelLauncher( + ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig /*gemm_config*/, + char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + using CutlassActivationType = typename CudaToCutlassTypeAdapter::type; + + if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) { + using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; + + using CutlassScaleZeroType = typename CudaToCutlassTypeAdapter::type; + using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; + using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; + + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + "Activation type must be bfloat16, half, FP8"); + + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + "Weight type must be fp8, uint8_t or uint4_t"); + + static_assert(!std::is_same_v || + std::is_same_v, + "Scale/Zero type must be half for fp8 activation"); + + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + // This example manually swaps and transposes, so keep transpose of input layouts + using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; + using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + + using ElementZero = CutlassScaleZeroType; + using ElementScale = CutlassScaleZeroType; + + // C/D matrix configuration. 
We reuse the C operand for the bias and set the stride for broadcast. + using LayoutBias = cutlass::layout::RowMajor; + constexpr int AlignmentBias = 128 / cutlass::sizeof_bits::value; + + // D matrix configuration + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ElementCompute = float; // Element type for epilogue computation + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + using KernelSchedule = MainloopScheduleType; + using EpilogueSchedule = EpilogueScheduleType; + + // Shrink the N dimension to match CTA_N if needed + constexpr int epi_tile_M = cute::min(shape<0>(TileShape{}), 128); // 64 or 128 + constexpr int epi_tile_N = cute::min(shape<1>(TileShape{}), 32); // Allow this to be 16 for some small N tiles. + using EpilogueTileType = cute::Shape, cute::Int>; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + static_assert(std::is_same_v, ""); + using EVT_bias_addition = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute, // alpha * acc + bias + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // alpha + cutlass::epilogue::fusion::Sm90AccFetch, // acc + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, CutlassBiasType, CutlassBiasType, + Stride<_1, _0, _0>, + AlignmentBias> // bias + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementAccumulator, + // Transpose layout of D here since we use the explicit swap + transpose trick + // Void C since we don't use it. Prevents smem allocation. 
+ void, typename cutlass::layout::LayoutTranspose::type, AlignmentBias, CutlassOutputType, + typename cutlass::layout::LayoutTranspose::type, AlignmentOutput, EpilogueSchedule, + EVT_bias_addition>::CollectiveOp; + + using PackedScaleZero = cute::tuple; + using PackedScale = cute::tuple; + using ElementBCollectiveInfo = std::conditional_t; + + // We swap A and B operands to the builder here + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderInterleaved< + ArchTag, + OperatorClass, ElementBCollectiveInfo, LayoutB_Transpose, AlignmentB, CutlassActivationType, + LayoutA_Transpose, AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using TileScheduler = cute::conditional_t(CTAShape{}) == Int<64>{}, cutlass::gemm::PersistentScheduler, + cutlass::gemm::StreamKScheduler>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + + if (occupancy != nullptr) { + *occupancy = onnxruntime::llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename GemmKernel::StrideA; + using StrideB = typename GemmKernel::StrideB; + using StrideC = typename GemmKernel::StrideC; + using StrideD = typename GemmKernel::StrideD; + using StrideS = typename CollectiveMainloop::StrideScale; + + if (weight_scales == nullptr) { + ORT_THROW("Weight scales must always be set to a non-null value."); + } + + if constexpr (cutlass::isFinegrained(QuantOp)) { + int cta_shape_k = cute::size<2>(TileShape{}); + if (group_size % cta_shape_k != 0) { + std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k); + ORT_THROW("[fpA_intB_gemm] ", err_msg); + } + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero pointer must be a nullptr for scale only fine grained"); + } + } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { + if (weight_zero_points == nullptr) { + ORT_THROW("Weight zero pointer must be valid for scale and bias fine grained"); + } + } + } else { + if (group_size != k) { + ORT_THROW("Invalid group size for per column scaling kernels."); + } + + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero-points must be null when running per column scaling"); + } + } + + auto cutlass_scale_k = (k + group_size - 1) / group_size; + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1)); + + // Use the output as the bias to avoid making a tma descriptor with a nullptr. 
+ auto output_as_bias_type = reinterpret_cast(C); + + typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, 1}, + {reinterpret_cast(B), stride_B, + reinterpret_cast(A), stride_A, + reinterpret_cast(weight_scales), stride_S, + group_size, reinterpret_cast(weight_zero_points)}, + {{}, output_as_bias_type, stride_D, reinterpret_cast(C), stride_D}}; + + args.epilogue.thread = { + {alpha}, // alpha args + {}, // accumulator + {reinterpret_cast(biases), CutlassBiasType(0.f)}, // bias args + {} // end multiply_add + }; + + Gemm gemm; + if (gemm.get_workspace_size(args) > workspace_bytes) { + ORT_LLM_LOG_ERROR("[fpA_intB_gemm] given workspace size insufficient."); + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + ORT_THROW("[fpA_intB_gemm] ", err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); + ORT_THROW("[fpA_intB_gemm] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + ORT_THROW("[fpA_intB_gemm] " + err_msg); + } + } else { + std::stringstream ss; + ss << "[fpA_intB_gemm] Config (" << (int64_t)cute::size<0>(CTAShape{}) << "," + << (int64_t)cute::size<1>(CTAShape{}) << "," << (int64_t)cute::size<2>(CTAShape{}) << ") (" + << (int64_t)cute::size<0>(ClusterShape{}) << "," << (int64_t)cute::size<1>(ClusterShape{}) << "," + << (int64_t)cute::size<2>(ClusterShape{}) << ") not compiled with FAST_BUILD."; + + ORT_THROW(ss.str()); + } +} +#else // COMPILE_HOPPER_TMA_GEMMS +void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const*, WeightType const*, + ScaleZeroType const*, ScaleZeroType const*, BiasType const*, + float const, OutputType*, int, int, int, int const, tkc::CutlassGemmConfig, + char*, size_t, cudaStream_t, int*) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_THROW("[fpA_intB_gemm] Please recompile with support for hopper by passing 90a-real as an arch."); +} +#endif // COMPILE_HOPPER_TMA_GEMMS + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu new file mode 100644 index 0000000000000..55beb8b9ca029 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu @@ -0,0 +1,260 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
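The kernels added in this file convert MatMulNBits-style quantization parameters (row-major [n, k_blocks] scales and zero points) into the transposed, pre-scaled layout that the fpA_intB GEMM/GEMV kernels consume. For reference only, a minimal host-side C++ sketch of that transform follows; it is illustrative and not part of the patch, and the function name and float element type are assumptions:

#include <cstddef>
#include <vector>

// Reference (CPU) equivalent of transposeScaleKernel + computeScaledZeroPointAndTransposeKernel:
//   transposed_scale[kb][col]  = scale[col][kb]
//   scaled_zero_point[kb][col] = scale[col][kb] * (default_zero_point - zero_point[col][kb])
void prepare_scale_and_zero_point_reference(const std::vector<float>& scale,        // [n, k_blocks], row-major
                                            const std::vector<float>& zero_point,   // [n, k_blocks], row-major
                                            std::vector<float>& transposed_scale,   // [k_blocks, n], row-major
                                            std::vector<float>& scaled_zero_point,  // [k_blocks, n], row-major
                                            int n, int k_blocks, float default_zero_point) {
  for (int kb = 0; kb < k_blocks; ++kb) {
    for (int col = 0; col < n; ++col) {
      const std::size_t in_idx = static_cast<std::size_t>(col) * k_blocks + kb;
      const std::size_t out_idx = static_cast<std::size_t>(kb) * n + col;
      transposed_scale[out_idx] = scale[in_idx];
      scaled_zero_point[out_idx] = scale[in_idx] * (default_zero_point - zero_point[in_idx]);
    }
  }
}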
+
+#include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h"
+#include <cuda_fp16.h>
+#include "core/providers/cuda/cuda_common.h"
+
+namespace onnxruntime::llm {
+namespace kernels {
+namespace fpA_intB_gemv {
+
+template <typename T>
+__global__ void transposeScaleKernel(
+    const T* scale,
+    T* transposed_scale,
+    int n, int k_blocks) {
+  // Calculate the output matrix coordinates [row, col] for this thread
+  // The output matrix has dimensions [k_blocks, n]
+  int out_row = blockIdx.y * blockDim.y + threadIdx.y;
+  int out_col = blockIdx.x * blockDim.x + threadIdx.x;
+
+  // Check bounds to ensure we are within the output matrix dimensions [k_blocks, n]
+  if (out_row < k_blocks && out_col < n) {
+    int in_row = out_col;
+    int in_col = out_row;
+    int64_t input_offset = static_cast<int64_t>(in_row) * k_blocks + in_col;
+    int64_t output_offset = static_cast<int64_t>(out_row) * n + out_col;
+    T scale_val = scale[input_offset];
+    transposed_scale[output_offset] = scale_val;
+  }
+}
+
+template <typename T>
+void launch_transpose_scale_kernel(
+    cudaStream_t stream,
+    const T* scale,
+    T* transposed_scale,
+    int n, int k_blocks) {
+  constexpr int BLOCK_SIZE = 16;
+  dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE);
+  dim3 gridDim(
+      (n + blockDim.x - 1) / blockDim.x,        // Grid size in x covers output columns (n)
+      (k_blocks + blockDim.y - 1) / blockDim.y  // Grid size in y covers output rows (k_blocks)
+  );
+
+  transposeScaleKernel<<<gridDim, blockDim, 0, stream>>>(
+      scale,
+      transposed_scale,
+      n,
+      k_blocks);
+}
+
+// CUDA kernel to compute -scale * zero_point and transpose
+// Each thread computes one element of the OUTPUT matrix (shape [k_blocks, n])
+template <typename T, typename Z, bool is_zero_point_int4_packed>
+__global__ void computeScaledZeroPointAndTransposeKernel(
+    const Z* zero_point,        // Input zero_point matrix [n, k_blocks] or [n, (k_blocks + 1) / 2] if packed int4
+    const T* transposed_scale,  // transposed scale [k_blocks, n]
+    T* scaled_zero_point,       // Output matrix [k_blocks, n]
+    int n,                      // Rows of input matrices
+    int k_blocks,               // Columns of input matrices
+    float default_zero_point) {
+  // Calculate the output matrix coordinates [row, col] for this thread
+  // The output matrix has dimensions [k_blocks, n]
+  int out_row = blockIdx.y * blockDim.y + threadIdx.y;
+  int out_col = blockIdx.x * blockDim.x + threadIdx.x;
+
+  // Check bounds to ensure we are within the output matrix dimensions [k_blocks, n]
+  if (out_row < k_blocks && out_col < n) {
+    int in_row = out_col;
+    int in_col = out_row;
+    int64_t output_offset = static_cast<int64_t>(out_row) * n + out_col;
+
+    // Perform the computation: scaled_zero_point[out_row, out_col] = -scale[in_row, in_col] * zero_point[in_row, in_col]
+    T scale_val = transposed_scale[output_offset];
+    float zero_point_val;
+    if (zero_point != nullptr) {
+      if constexpr (is_zero_point_int4_packed) {  // zero point is 4 bit, and two elements are packed into one byte.
+        int64_t packed_row_size = (k_blocks + 1) / 2;
+        int64_t packed_zp_offset = static_cast<int64_t>(in_row) * packed_row_size + in_col / 2;
+        uint8_t packed_zp = zero_point[packed_zp_offset];
+        zero_point_val = static_cast<float>((in_col & 0x01) ?
(packed_zp >> 4) : (packed_zp & 0x0f)); + } else { + int64_t input_offset = static_cast(in_row) * k_blocks + in_col; + zero_point_val = static_cast(zero_point[input_offset]); + } + } else { + zero_point_val = default_zero_point; + } + + float result = static_cast(scale_val) * (-zero_point_val + default_zero_point); + scaled_zero_point[output_offset] = static_cast(result); + } +} + +template +void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const Z* zero_point, + const T* transposed_scale, + T* scaled_zero_point, + int n, int k_blocks, float default_zero_point) { + assert(zero_point != nullptr); + constexpr int BLOCK_SIZE = 16; + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + dim3 gridDim( + (n + blockDim.x - 1) / blockDim.x, // Grid size in x covers output columns (n) + (k_blocks + blockDim.y - 1) / blockDim.y // Grid size in y covers output rows (k_blocks) + ); + + computeScaledZeroPointAndTransposeKernel<<>>( + zero_point, + transposed_scale, + scaled_zero_point, + n, + k_blocks, + default_zero_point); +} + +// Explicit instantiations: +template void launch_transpose_scale_kernel( + cudaStream_t stream, + const half* scale, + half* transposed_scale, + int n, int k_blocks); + +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const half* zero_point, + const half* transposed_scale, + half* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const uint8_t* zero_point, + const half* transposed_scale, + half* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +// zero point is 4 bits packed. +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const uint8_t* zero_point, + const half* transposed_scale, + half* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +// CUDA kernel to unpack uint4, transpose, and pack into int8 directly +__global__ void unpack_transpose_pack_uint4_to_int8_kernel_v2( + const unsigned char* __restrict__ packed_weight, + signed char* __restrict__ packed_transposed_weight, + int n, // original matrix rows + int k) // original matrix columns +{ + // The output 'packed_transposed_weight' has dimensions k x (n/2) bytes. + // Each thread processes one byte in the output. 
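+  // For example, with n = 2, k = 2 and unpacked 4-bit weights w[r][c]:
+  //   input  (n x k/2 = 2 x 1 bytes, row-major): byte0 = (w00 low nibble, w01 high nibble)
+  //                                              byte1 = (w10 low nibble, w11 high nibble)
+  //   output (k x n/2 = 2 x 1 bytes, row-major): byte0 = (w00-8 low nibble, w10-8 high nibble)
+  //                                              byte1 = (w01-8 low nibble, w11-8 high nibble)
+  // i.e. the weight matrix is transposed, the default zero point (8) is subtracted, and pairs of
+  // values along the n dimension are repacked two-per-byte.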
+ int out_flat_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Total number of bytes in the output packed_transposed_weight matrix + int total_output_bytes = k * (n / 2); + + if (out_flat_idx < total_output_bytes) { + constexpr signed char default_zero_point = 8; + + // Calculate row and column in the output packed_transposed_weight matrix (k x n/2) + // out_row_packed: row in the k dimension of the output (0 to k-1) + // out_col_packed: column in the n/2 dimension of the output (0 to n/2 - 1) + const int out_row_packed = out_flat_idx / (n / 2); + const int out_col_packed = out_flat_idx % (n / 2); + + // These two int8 values will form the current output packed byte: + // val_0: corresponds to original_unpacked[2 * out_col_packed][out_row_packed] + // val_1: corresponds to original_unpacked[2 * out_col_packed + 1][out_row_packed] + + // --- Retrieve val_0 --- + // Its original (unpacked) row index was '2 * out_col_packed' + const int r_orig_0 = 2 * out_col_packed; + // Its original (unpacked) column index was 'out_row_packed' + const int c_orig_0 = out_row_packed; + + // Determine the flat index in the input 'packed_weight' (n x k/2) where val_0 resides + const int packed_weight_idx_0 = r_orig_0 * (k / 2) + c_orig_0 / 2; + + unsigned char packed_data_0 = packed_weight[packed_weight_idx_0]; + signed char val_0; + if ((c_orig_0 % 2) == 0) { // If original column is even, it's the lower 4 bits + val_0 = (signed char)(packed_data_0 & 0x0f) - default_zero_point; + } else { // If original column is odd, it's the upper 4 bits + val_0 = (signed char)(packed_data_0 >> 4) - default_zero_point; + } + + // --- Retrieve val_1 --- + // Its original (unpacked) row index was '2 * out_col_packed + 1' + const int r_orig_1 = 2 * out_col_packed + 1; + // Its original (unpacked) column index was 'out_row_packed' + const int c_orig_1 = out_row_packed; + + // Determine the flat index in the input 'packed_weight' (n x k/2) where val_1 resides + const int packed_weight_idx_1 = r_orig_1 * (k / 2) + c_orig_1 / 2; + + unsigned char packed_data_1 = packed_weight[packed_weight_idx_1]; + signed char val_1; + if ((c_orig_1 % 2) == 0) { // If original column is even, it's the lower 4 bits + val_1 = (signed char)(packed_data_1 & 0x0f) - default_zero_point; + } else { // If original column is odd, it's the upper 4 bits + val_1 = (signed char)(packed_data_1 >> 4) - default_zero_point; + } + + // Pack the two signed char values (now 8-bit, but we only care about their 4 LSBs) + // back into a single byte for the output. 
+ packed_transposed_weight[out_flat_idx] = (unsigned char)((val_0 & 0x0f) | ((val_1 & 0x0f) << 4)); + } +} + +void unpack_uint4_transposed_to_int8_direct_cuda( + cudaStream_t stream, void* packed_transposed_weight, const void* packed_weight, int n, int k) { + int total_output_bytes = k * (n / 2); + int threads_per_block = 256; + int num_blocks = (total_output_bytes + threads_per_block - 1) / threads_per_block; + + unpack_transpose_pack_uint4_to_int8_kernel_v2<<>>( + (const unsigned char*)packed_weight, + (signed char*)packed_transposed_weight, + n, + k); +} + +__global__ void transpose_uint8_matrix_and_convert_to_int8_kernel( + const uint8_t* __restrict__ input, // shape: (n, k) + int8_t* __restrict__ output, // shape: (k, n) + int n, int k) { + + int row = blockIdx.y * blockDim.y + threadIdx.y; // index in n + int col = blockIdx.x * blockDim.x + threadIdx.x; // index in k + + if (row < n && col < k) { + int input_idx = row * k + col; + int output_idx = col * n + row; + output[output_idx] = static_cast(static_cast(input[input_idx]) - 128); + } +} + +void transpose_uint8_matrix_and_convert_to_int8( + cudaStream_t stream, + int8_t* output, // shape: (k, n) + const uint8_t* input, // shape: (n, k) + int n, int k) { + + dim3 blockDim(16, 16); + dim3 gridDim((k + blockDim.x - 1) / blockDim.x, + (n + blockDim.y - 1) / blockDim.y); + + transpose_uint8_matrix_and_convert_to_int8_kernel<<>>(input, output, n, k); +} + + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h new file mode 100644 index 0000000000000..61023b62d8a49 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include + +// Convert scale and zero_point from MatMulNBits to the format required by fpA_intB_gemm or fpA_intB_gemv kernels. +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const Z* zero_point, + const T* transposed_scale, + T* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +template +void launch_transpose_scale_kernel( + cudaStream_t stream, + const T* scale, + T* transposed_scale, + int n, int k_blocks); + +// Transpose uint4 weight matrix and add default zero points then pack as int8. +void unpack_uint4_transposed_to_int8_direct_cuda(cudaStream_t stream, + void* packed_transposed_weight, + const void* packed_weight, + int n, + int k); + +// Transpose uint8 weight matrix and add default zero points as int8. +void transpose_uint8_matrix_and_convert_to_int8(cudaStream_t stream, + int8_t* output, // shape: (k, n) + const uint8_t* input, // shape: (n, k) + int n, int k); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc new file mode 100644 index 0000000000000..8112562623791 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc @@ -0,0 +1,100 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h" +#include "contrib_ops/cuda/llm/common/workspace.h" + +using namespace onnxruntime::llm::common; +using namespace onnxruntime::llm::kernels::cutlass_kernels; + +namespace onnxruntime::llm::kernels::weight_only { + +void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic( + int m, int n, int k, + WeightOnlyGroupwiseQuantGemmPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream) { + int const originalN = mQuantBits == 8 ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO; + half* actPtr = reinterpret_cast(workspace); + void* weightPtr = nextWorkspacePtr(reinterpret_cast(actPtr), m * k * sizeof(half)); + half* inputScalesPtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(weightPtr), n * k * sizeof(float))); + half* zerosPtr = reinterpret_cast( + nextWorkspacePtr(reinterpret_cast(inputScalesPtr), k * originalN * sizeof(half) / mGroupSize)); + half* biasesPtr = reinterpret_cast( + nextWorkspacePtr(reinterpret_cast(zerosPtr), k * originalN * sizeof(half) / mGroupSize)); + half* outputPtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(biasesPtr), n * sizeof(half))); + char* workspacePtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(outputPtr), m * originalN * sizeof(half))); + + if (!mHasZeros) { + zerosPtr = nullptr; + } + + if (!mHasBiases) { + biasesPtr = nullptr; + } + + if (tactic.enableCudaKernel) { + // run CUDA kernel + void const* pre_quant_scale_ptr = nullptr; + bool apply_alpha_in_advance = false; + float alpha = 1.0f; + onnxruntime::llm::kernels::fpA_intB_gemv::Params params( + actPtr, pre_quant_scale_ptr, weightPtr, + inputScalesPtr, zerosPtr, + biasesPtr, outputPtr, + alpha, m, originalN, k, mGroupSize, mCudaKernelType, apply_alpha_in_advance); + onnxruntime::llm::kernels::fpA_intB_gemv::kernel_launcher(mArch, params, stream); + } else { + // run CUTLASS kernel + int const wsSize = mRunner->getWorkspaceSize(m, originalN, k); + if (mQuantBits == 8) { + mRunner->gemm(actPtr, reinterpret_cast(weightPtr), inputScalesPtr, zerosPtr, biasesPtr, outputPtr, + m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream); + } else { + mRunner->gemm(actPtr, reinterpret_cast(weightPtr), inputScalesPtr, zerosPtr, biasesPtr, + outputPtr, m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream); + } + } +} + +void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { + // Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16) + int const originalN = mQuantBits == 8 ? 
n * FP16_INT8_RATIO : n * FP16_INT4_RATIO; + std::vector workspaces = { + maxM * k * sizeof(half), // A + k * n * sizeof(float), // B + k * originalN * sizeof(half) / mGroupSize, // scales + k * originalN * sizeof(half) / mGroupSize, // zeros + originalN * sizeof(half), // biases + maxM * originalN * sizeof(half), // C + mRunner->getWorkspaceSize(maxM, originalN, k) // workspace + }; + size_t bytes = calculateTotalWorkspaceSize(workspaces.data(), workspaces.size()); + setTmpWorkspaceSizeInBytes(bytes); +} + +std::vector WeightOnlyGroupwiseQuantGemmPluginProfiler::getTactics( + int /*m*/, int /*n*/, int /*k*/) const { + return mRunner->getConfigs(); +} + +bool WeightOnlyGroupwiseQuantGemmPluginProfiler::checkTactic(int m, int /*n*/, int /*k*/, Config const& tactic) const { + // stop to profile Cuda kernel for m >= 16 + if (tactic.enableCudaKernel) { + return m < 16; + } + return true; +} + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h new file mode 100644 index 0000000000000..7be77fa43d85d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h @@ -0,0 +1,86 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "contrib_ops/cuda/llm/gemm_profiler.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" + +#include +#include +#include +#include +#include +#include + +using WeightOnlyGemmRunner = onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface; +using WeightOnlyGemmRunnerPtr = std::shared_ptr; +using KernelType = onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + +namespace onnxruntime::llm::kernels::weight_only { +enum class WeightTypeId { + INT8 = 1, + INT4 = 2, +}; + +constexpr int32_t FP16_BITS = 16; +constexpr int32_t INT8_BITS = 8; +constexpr int32_t INT4_BITS = 4; +constexpr int32_t FP16_INT4_RATIO = FP16_BITS / INT4_BITS; +constexpr int32_t FP16_INT8_RATIO = FP16_BITS / INT8_BITS; + +class WeightOnlyGroupwiseQuantGemmPluginProfiler + : public GemmPluginProfiler { + public: + using Config = onnxruntime::llm::cutlass_extensions::CutlassGemmConfig; + + void setQuant(int bits, bool has_bias, bool has_zeros) { + mQuantBits = bits; + mHasBiases = has_bias; + mHasZeros = has_zeros; + } + + void setGroupSize(int groupSize) { + mGroupSize = groupSize; + } + + void setCudaKernelType(KernelType cudaKernelType, int arch) { + mCudaKernelType = cudaKernelType; + mArch = arch; + } + + protected: + void runTactic(int m, int n, int k, Config const& tactic, + char* workspace, cudaStream_t const& stream) override; + + void computeTmpSize(size_t maxM, size_t n, size_t k) override; + + std::vector getTactics(int m, int n, int k) const override; + + bool checkTactic(int m, int n, int k, Config const& tactic) const override; + + private: + bool mHasBiases; + bool mHasZeros; + int mQuantBits; + int mGroupSize; + KernelType mCudaKernelType; + int mArch; +}; + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h new file mode 100644 index 0000000000000..4fa64ef329c57 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +struct kernel_type_traits; +#define KERNEL_TYPE_TRAITS_REGISTRY(KT, _isGroupwise, _isInt4) \ + template <> \ + struct kernel_type_traits { \ + static constexpr bool isGroupwise = _isGroupwise; \ + static constexpr bool isInt4 = _isInt4; \ + }; + +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8Groupwise, true, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4Groupwise, true, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8PerChannel, false, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4PerChannel, false, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8Groupwise, true, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4Groupwise, true, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8PerChannel, false, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4PerChannel, false, true); +#undef KERNEL_TYPE_TRAITS_REGISTRY + +// A generic memory iterator used for coalesced global memory access with optional enablement. +// Template parameters: +// Enable: If false, disables loading/storing. +// TVec: Vectorized type (e.g., float4, half2). +// Strided: Number of rows in a tile. +// Continuous: Number of contiguous vector elements to load/store at once. +// Scalar type (e.g., half). +template +class GMemIterator { + public: + __device__ __forceinline__ GMemIterator(T* addr, int offset, int step, int stride) + : addr_(Enable ? (addr + offset) : nullptr), step_(step), stride_(stride) { + } + + __device__ __forceinline__ void load(void* dst, int iter, int ii = 0) { + if constexpr (Enable) { +#pragma unroll + for (int jj = 0; jj < Continuous; ++jj) { + reinterpret_cast(dst)[jj] = reinterpret_cast(addr_ + iter * step_ + ii * stride_)[jj]; + } + } + } + + private: + T* addr_; + int step_; + int stride_; +}; + +struct FP16DetailsA { + using Type = half; + using Type2 = half2; + static constexpr int kElemBits = 16; +}; + +struct BF16DetailsA { + using Type = __nv_bfloat16; + using Type2 = __nv_bfloat162; + static constexpr int kElemBits = 16; +}; + +struct Int8DetailsW { + static constexpr int kElemBits = 8; +}; + +struct Int4DetailsW { + static constexpr int kElemBits = 4; +}; + +template +struct ColumnMajor { + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsA::kElemBits; + static constexpr int kTileSize = TileSizeK; + static constexpr int kInterleave = 1; + + struct Mapper { + __device__ __forceinline__ int operator()(int i) { + return i; + } + }; +}; + +template +struct ColumnMajorInterleavedForHopper { + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int4; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsW::kElemBits; + static constexpr int kTileSize = TileSizeK; + static constexpr int kInterleave = 1; + + static constexpr int kTypeFactor = 128 * 8 / (TileSizeK * TypeDetailsW::kElemBits); + + // constants for mapper + static constexpr int kElementGroupSizeA = TileSizeK / 32; + static constexpr int kElementGroupSizeW = kTypeFactor * 
kElementGroupSizeA; + static constexpr int kGroupOffsetA = 4 * kElementGroupSizeA; + + struct Mapper { + __device__ __forceinline__ int operator()(int i) { + return i % kElementGroupSizeA + (i % kGroupOffsetA) / kElementGroupSizeA * kElementGroupSizeW + i / kGroupOffsetA * kElementGroupSizeA; + } + }; +}; + +template +struct ColumnMajorInterleaved { + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int4; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsW::kElemBits; + static constexpr int kTileSize = TileSizeK; + static constexpr int kInterleave = 128 * 8 / (TileSizeK * TypeDetailsW::kElemBits); + + // constants for mapper + static constexpr int kElementGroupSizeA = TileSizeK / 32; + static constexpr int kElementGroupSizeW = kInterleave * kElementGroupSizeA; + static constexpr int kGroupOffsetA = 4 * kElementGroupSizeA; + + struct Mapper { + __device__ __forceinline__ int operator()(int i) { + return i % kElementGroupSizeA + (i % kGroupOffsetA) / kElementGroupSizeA * kElementGroupSizeW + i / kGroupOffsetA * kElementGroupSizeA; + } + }; +}; + +template class LayoutDetails_, + bool UseInterleavedConverter, int TileSizeK> +struct KernelDetails { + using TypeDetailsA = TypeDetailsA_; + using TypeDetailsW = TypeDetailsW_; + using LayoutDetails = LayoutDetails_; + using AccessTypeA = typename LayoutDetails::AccessTypeA; + using AccessTypeW = typename LayoutDetails::AccessTypeW; + static constexpr int kWarpSize = 32; + static constexpr int kStepK = LayoutDetails::kStepK; + static constexpr int kAccessNumA = kStepK * TypeDetailsA::kElemBits / (sizeof(AccessTypeA) * 8); + static constexpr int kAccessNumW = kStepK * TypeDetailsW::kElemBits / (sizeof(AccessTypeW) * 8); + static constexpr int kInterleave = LayoutDetails::kInterleave; + static constexpr int kThreadsPerInterleavedTile = LayoutDetails::kTileSize / kStepK; + static constexpr int kElemsPerByteW = 8 / TypeDetailsW::kElemBits; + static constexpr bool kUseInterleavedConverter = UseInterleavedConverter; +}; + +template +struct I2FConverter; + +template +struct I2FConverter { + static_assert(std::is_same_v || std::is_same_v); + static_assert(WElemBits == 4 || WElemBits == 8); + using CutlassAType = std::conditional_t, cutlass::half_t, cutlass::bfloat16_t>; + using CutlassWType = std::conditional_t; + static constexpr int kConvertCount = 32 / WElemBits; + using Converter = cutlass::FastInterleavedAndBiasedNumericArrayConverter; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + + template + __device__ __forceinline__ static void convert(void* src, void* dst) { + static_assert(N % kConvertCount == 0); +#pragma unroll + for (int ii = 0; ii < N / kConvertCount; ++ii) { + reinterpret_cast(dst)[ii] = Converter::convert(reinterpret_cast(src)[ii]); + } + } +}; + +template +struct I2FConverter { + static_assert(std::is_same_v || std::is_same_v); + static_assert(WElemBits == 4 || WElemBits == 8); + using CutlassAType = std::conditional_t, cutlass::half_t, cutlass::bfloat16_t>; + using CutlassWType = std::conditional_t; + static constexpr int kConvertCount = 32 / WElemBits; + using Converter = cutlass::NumericArrayConverter; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + + template + __device__ __forceinline__ static void convert(void* src, void* dst) { + static_assert(N % kConvertCount == 0); +#pragma unroll + for (int 
ii = 0; ii < N / kConvertCount; ++ii) { + reinterpret_cast(dst)[ii] = Converter::convert(reinterpret_cast(src)[ii]); + } + } +}; + +template +struct ConverterWrapper { + using TypeDetailsA = typename Details::TypeDetailsA; + using TypeDetailsW = typename Details::TypeDetailsW; + static constexpr bool kUseInterleavedConverter = Details::kUseInterleavedConverter; + using Converter = I2FConverter; +}; + +template +void select_gs(Params& params, cudaStream_t s); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h new file mode 100644 index 0000000000000..ff1a28661184f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h @@ -0,0 +1,423 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/details.h" +#include "core/common/common.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +struct MathWrapper { +}; + +template <> +struct MathWrapper { + using Type = typename FP16DetailsA::Type; + using Type2 = typename FP16DetailsA::Type2; + + __device__ __forceinline__ static Type2 to_vec2(Type const& v) { + return __half2half2(v); + } + + __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c) { + return __hfma2(a, b, c); + } + + __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b) { + return __hmul2(a, b); + } + + // __device__ __forceinline__ static Type2 deq2(Type2 const& weight, Type2 const& scale, Type2 const& zero_point) { + // return __hmul2(__hsub2(weight, zero_point), scale); + // } +}; + +template <> +struct MathWrapper { + using Type = typename BF16DetailsA::Type; + using Type2 = typename BF16DetailsA::Type2; + + __device__ __forceinline__ static Type2 to_vec2(Type const& v) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __bfloat162bfloat162(v); +#else + uint32_t val = 0; + Type2 ret = reinterpret_cast(val); + return ret; +#endif + } + + __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hfma2(a, b, c); +#else + return to_vec2(static_cast(0.f)); +#endif + } + + __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hmul2(a, b); +#else + return to_vec2(static_cast(0.f)); +#endif + } +}; + +template +__device__ __forceinline__ void apply_scale(void* act, void* act_scale) { + using Type2 = typename MathWrapper::Type2; + static_assert(K % 2 == 0); + [[maybe_unused]] static constexpr int VecK = K / 2; + if constexpr (Enable) { + Type2* pa = reinterpret_cast(act); + Type2* pb = 
reinterpret_cast(act_scale); +#pragma unroll + for (int m = 0; m < M; ++m) { +#pragma unroll + for (int k = 0; k < VecK; ++k) { + pa[m * VecK + k] = MathWrapper::mul2(pa[m * VecK + k], pb[k]); + } + } + } +} + +template +__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros, float alpha) { + using Type = typename MathWrapper::Type; + using Type2 = typename MathWrapper::Type2; + using Converter = typename ConverterWrapper
::Converter; + static_assert(K % 2 == 0); + static constexpr int VecK = K / 2; +#pragma unroll + for (int n = 0; n < N; ++n) { + Converter::convert(reinterpret_cast(quantized_w) + n * K / Details::kElemsPerByteW, + reinterpret_cast(w) + n * K); + Type2 vec_scale, vec_zero; + if constexpr (ApplyAlphaInAdvance) { + // For W4A8, we assume scales/zero is always half data type, no matter activation dtype is bf16 or fp16 + Type scales_ = static_cast(reinterpret_cast(scales)[n]) * alpha; + vec_scale = MathWrapper::to_vec2(scales_); + vec_zero = MathWrapper::to_vec2(static_cast(0.f)); + if constexpr (EnableZero) { + vec_zero = MathWrapper::to_vec2( + static_cast(reinterpret_cast(zeros)[n]) * alpha); + } + } else { + vec_scale = MathWrapper::to_vec2(reinterpret_cast(scales)[n]); + vec_zero = MathWrapper::to_vec2(static_cast(0.f)); + if constexpr (EnableZero) { + vec_zero = MathWrapper::to_vec2(reinterpret_cast(zeros)[n]); + } + } +#pragma unroll + for (int k = 0; k < VecK; ++k) { + reinterpret_cast(w)[n * VecK + k] = MathWrapper::fma2( + reinterpret_cast(w)[n * VecK + k], vec_scale, vec_zero); + } + } +} + +template +__device__ __forceinline__ void pack_to_vec2(void* dst, void* src, int n) { + using Type = typename MathWrapper::Type; + typename Details::LayoutDetails::Mapper mapper; + int n0 = n & ~0x1, n1 = n & 0x1; + for (int k = 0; k < K; ++k) { + int physical_idx = mapper(k); + reinterpret_cast(dst)[n0 * K + k * 2 + n1] = reinterpret_cast(src)[physical_idx]; + } +} + +template +__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act) { + using Type = typename MathWrapper::Type; + using Type2 = typename MathWrapper::Type2; + static_assert(N % 2 == 0); + static constexpr int VecN = N / 2; +#pragma unroll + for (int m = 0; m < M; ++m) { +#pragma unroll + for (int n = 0; n < VecN; ++n) { +#pragma unroll + for (int k = 0; k < K; ++k) { + reinterpret_cast(acc)[m * VecN + n] = MathWrapper::fma2( + reinterpret_cast(w_pack2)[n * K + k], + MathWrapper::to_vec2(reinterpret_cast(act)[m * K + k]), + reinterpret_cast(acc)[m * VecN + n]); + } + } + } +} + +template +__device__ __forceinline__ T warp_reduce_sum(T& val) { + val += __shfl_xor_sync(~0, val, 16); + val += __shfl_xor_sync(~0, val, 8); + if (Interleave != 2 && Interleave != 4) + val += __shfl_xor_sync(~0, val, 4); + if (Interleave != 4) + val += __shfl_xor_sync(~0, val, 2); + val += __shfl_xor_sync(~0, val, 1); + return val; +} + +template +__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha) { + using Type = typename MathWrapper::Type; + static constexpr int Interleave = Details::kInterleave; + static constexpr int ThreadsPerInterleavedTile = Details::kThreadsPerInterleavedTile; + static constexpr int WarpSize = Details::kWarpSize; + static constexpr int WarpNum = Threads / WarpSize; + static_assert(Threads % WarpSize == 0); + __shared__ float shmem[CtaM * CtaN * Interleave * WarpNum]; + int tid = threadIdx.x; + int warp_id = tid / WarpSize, lane_id = tid % WarpSize; +#pragma unroll + for (int m = 0; m < CtaM; ++m) { +#pragma unroll + for (int n = 0; n < CtaN; ++n) { + float v = static_cast(reinterpret_cast(tile_acc)[m * CtaN + n]); + v = warp_reduce_sum(v); + if (lane_id < Interleave * ThreadsPerInterleavedTile && lane_id % ThreadsPerInterleavedTile == 0) { + shmem[warp_id * CtaM * CtaN * Interleave + m * CtaN * Interleave + n * Interleave + lane_id / ThreadsPerInterleavedTile] = v; + } + } + } + __syncthreads(); +#pragma unroll + for (int ii = tid; ii < CtaM * CtaN * Interleave; 
ii += Threads) { + int m = ii / (CtaN * Interleave), n = ii % (CtaN * Interleave); + float val = 0.f, v_bias = 0.f; + if constexpr (EnableBias) { + v_bias = static_cast(reinterpret_cast(bias)[n]); + } +#pragma unroll + for (int jj = 0; jj < WarpNum; ++jj) { + val += shmem[jj * CtaM * CtaN * Interleave + ii]; + } + if constexpr (ApplyAlphaInAdvance) { + reinterpret_cast(out)[m * stride + n] = static_cast(val + v_bias); + } else { + reinterpret_cast(out)[m * stride + n] = static_cast(alpha * val + v_bias); + } + } +} + +template +__device__ __forceinline__ void fill(void* tile, T v) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + reinterpret_cast(tile)[ii] = v; + } +} + +template +__global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* scales, TypeA* zeros, TypeA* bias, + TypeA* out, float alpha, int m, int n, int k) { + // ArgType ArgName DataType Shape Layout + // input act fp16/bf16 [m, k] RowMajor + // input act_scale fp16/bf16 [1, k] RowMajor + // input weight int4b/int8b [k, n] ColumnMajor or ColumnMajorInterleaved + // input scales fp16/bf16 [k / GroupSize, n] RowMajor + // input zeros fp16/bf16 [k / GroupSize, n] RowMajor + // input bias fp16/bf16 [1, n] RowMajor + // output out fp16/bf16 [m, n] RowMajor + + using AccessTypeA = typename Details::AccessTypeA; + using AccessTypeW = typename Details::AccessTypeW; + + static constexpr bool Mandatory = true; + static constexpr int StepK = Details::kStepK; + static constexpr int CtaK = StepK * Threads; + static_assert(CtaN % 2 == 0); + if constexpr (GroupSize != 0) { + static_assert((CtaK / Details::kInterleave) % GroupSize == 0); + } + + int const origin_k = k, interleaved_k = k * Details::kInterleave; + + int const tile_id_m = blockIdx.x, tile_id_n = blockIdx.y, tid = threadIdx.x; + int const offset_m = tile_id_m * CtaM, interleaved_offset_n = tile_id_n * CtaN; + int const real_offset_n = interleaved_offset_n * Details::kInterleave + ((tid * StepK / Details::LayoutDetails::kTileSize) % Details::kInterleave); + int const real_offset_k = (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * Details::LayoutDetails::kTileSize + ((tid * StepK) % Details::LayoutDetails::kTileSize); + + GMemIterator act_iterator( + act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); + GMemIterator act_scale_iterator( + act_scale, real_offset_k, CtaK / Details::kInterleave, 0); + GMemIterator weight_iterator( + weight, + (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, + interleaved_k / Details::kElemsPerByteW); + + GMemIterator scales_iterator( + scales, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + + GMemIterator zeros_iterator( + zeros, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? 
CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + + out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; + if constexpr (EnableBias) { + bias += tile_id_n * CtaN * Details::kInterleave; + } + + TypeA tile_acc[CtaM * CtaN]; + fill(tile_acc, static_cast(0.f)); + + for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { + TypeA vec_act_scale[StepK]; + TypeA vec_scale[CtaN], vec_zero[CtaN]; + TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; + uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + scales_iterator.load(vec_scale + i, iter, i); + zeros_iterator.load(vec_zero + i, iter, i); + } + act_scale_iterator.load(vec_act_scale, iter); +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + weight_iterator.load(tile_w_quantized, iter, i); + dequantize( + tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); + pack_to_vec2(tile_w_pack2, tile_w, i); + } +#pragma unroll + for (int i = 0; i < CtaM; ++i) { + act_iterator.load(tile_a, iter, i); + apply_scale(tile_a, vec_act_scale); + mma(tile_acc + i * CtaN, tile_w_pack2, tile_a); + } + } + epilogue(out, n, tile_acc, bias, alpha); +} + +template +void exec_kernel(Params& params, cudaStream_t s) { + using T = typename Details::TypeDetailsA::Type; + if (params.m % CtaM || params.n % (CtaN * Details::kInterleave)) { + throw std::runtime_error("launch failed"); + } + dim3 grid(params.m / CtaM, params.n / (CtaN * Details::kInterleave)); + dim3 block(Threads); + kernel<<>>( + reinterpret_cast(params.act), + reinterpret_cast(params.act_scale), + reinterpret_cast(params.weight), + reinterpret_cast(params.scales), + reinterpret_cast(params.zeros), + reinterpret_cast(params.bias), + reinterpret_cast(params.out), + params.alpha, + params.m, params.n, params.k); +} + +template +void dispatcher(Params& params, cudaStream_t s) { +#define DISPATCHER_FOR_M(target_m, CtaM, CtaN, Threads) \ + do { \ + if (params.m == target_m) { \ + exec_kernel(params, s); \ + return; \ + } \ + } while (0); + + if constexpr (EnableZero) { + DISPATCHER_FOR_M(1, 1, 4, 128); + DISPATCHER_FOR_M(2, 2, 4, 128); + DISPATCHER_FOR_M(3, 3, 4, 128); + DISPATCHER_FOR_M(4, 4, 4, 128); + DISPATCHER_FOR_M(5, 5, 4, 128); + DISPATCHER_FOR_M(6, 6, 4, 128); + DISPATCHER_FOR_M(7, 7, 4, 128); + DISPATCHER_FOR_M(8, 8, 4, 128); + DISPATCHER_FOR_M(9, 9, 4, 128); + DISPATCHER_FOR_M(10, 10, 4, 128); + DISPATCHER_FOR_M(11, 11, 4, 128); + DISPATCHER_FOR_M(12, 12, 4, 128); + DISPATCHER_FOR_M(13, 13, 4, 128); + DISPATCHER_FOR_M(14, 14, 4, 128); + DISPATCHER_FOR_M(15, 15, 4, 128); + } else { + DISPATCHER_FOR_M(1, 1, 8, 128); + DISPATCHER_FOR_M(2, 2, 8, 128); + DISPATCHER_FOR_M(3, 3, 8, 128); + DISPATCHER_FOR_M(4, 4, 8, 128); + DISPATCHER_FOR_M(5, 5, 8, 128); + DISPATCHER_FOR_M(6, 6, 8, 128); + DISPATCHER_FOR_M(7, 7, 8, 128); + DISPATCHER_FOR_M(8, 8, 8, 128); + DISPATCHER_FOR_M(9, 9, 8, 128); + DISPATCHER_FOR_M(10, 10, 8, 128); + DISPATCHER_FOR_M(11, 11, 8, 128); + DISPATCHER_FOR_M(12, 12, 8, 128); + DISPATCHER_FOR_M(13, 13, 8, 128); + DISPATCHER_FOR_M(14, 14, 8, 128); + DISPATCHER_FOR_M(15, 15, 8, 128); + } + throw std::runtime_error("unsupported m"); +#undef DISPATCHER_FOR_M +} + +template +void check_pointer(Params& params, cudaStream_t s) { + assert(!params.act_scale); // act_scale is not supported for now. + assert(!params.apply_alpha_in_advance); // apply_alpha_in_advance is not supported for now. 
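+  // When these are supported, act_scale provides the [1, k] activation scale consumed by
+  // apply_scale() in the kernel, and apply_alpha_in_advance selects the ApplyAlphaInAdvance
+  // path that folds alpha into the weight scales/zero points during dequantization (W4A8).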
+ + if (params.zeros && params.bias) { + dispatcher(params, s); + } else if (!params.zeros && params.bias) { + dispatcher(params, s); + } else if (params.zeros && !params.bias) { + dispatcher(params, s); + } else { + dispatcher(params, s); + } +} + +template +void select_gs(Params& params, cudaStream_t s) { + if constexpr (isGroupwise) { + if (params.groupsize == 64) { + check_pointer(params, s); + return; + } else if (params.groupsize == 128) { + check_pointer(params, s); + return; + } + } + + ORT_THROW("unsupported block_size: ", params.groupsize); +} + +#define INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(KType, A, B, Layout, ConverterInterleave, KTile) \ + template void select_gs::isGroupwise, \ + KernelDetails>(Params & params, cudaStream_t s); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu new file mode 100644 index 0000000000000..e2c008884c998 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 64); + +// KTile=128 for Ada w4a8 +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu new file mode 100644 index 0000000000000..8cd96c44421e5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu new file mode 100644 index 0000000000000..1eb5f51bdffdc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu new file mode 100644 index 0000000000000..f5872841e1acb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu new file mode 100644 index 0000000000000..f6b76e67b20ba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 64); + +// KTile=128 for Ada w4a8 +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu new file mode 100644 index 0000000000000..2ca88285d4cfe --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu new file mode 100644 index 0000000000000..7a00e1ba35f80 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu new file mode 100644 index 0000000000000..4a8506ca6bbde --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu new file mode 100644 index 0000000000000..32cd607d36480 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/details.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +void kernel_launcher(int arch, Params& params, cudaStream_t s) { +#define EXEC(KType, A, B, Layout, ConverterInterleave) \ + if (params.type == KType) { \ + select_gs::isGroupwise, KernelDetails>( \ + params, s); \ + return; \ + } + +// This is not used since there is no alpha for MatMulNBits currently. 
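+// (The W4A8 variant pre-applies alpha to the weight scales/zero points through the
+// ApplyAlphaInAdvance dequantization path.)
+//
+// For reference, a minimal call sketch (buffer names are illustrative; groupsize must be a
+// value handled by select_gs, currently 64 or 128):
+//   Params params(act, /*act_scale*/ nullptr, weight, scales, zeros, bias, out,
+//                 /*alpha*/ 1.f, m, n, k, /*groupsize*/ 64, KernelType::FP16Int4Groupwise);
+//   kernel_launcher(/*arch*/ sm, params, stream);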
+#define EXEC_W4A8(KType, A, B, Layout, ConverterInterleave) \ + if (params.type == KType && params.apply_alpha_in_advance) { \ + select_gs::isGroupwise, KernelDetails>( \ + params, s); \ + return; \ + } + + if (arch >= 75 && arch < 80) { + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + } else if (arch >= 80 && arch < 90 || arch >= 100) { + // if (arch == 89 || arch >= 120) + // { + // EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + // EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + // } + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + + EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + } else if (arch >= 90) { + // Dispatchers for W4A8 groupwise + // EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + // EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + + EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + } +#undef EXEC_W4A8 +#undef EXEC +} + +bool is_supported(int arch, KernelType kernel_type) { +#define SUPPORT(Type) \ + if (kernel_type == Type) \ + return true; + + if (arch >= 75 && arch < 80) { + SUPPORT(KernelType::FP16Int8Groupwise); + SUPPORT(KernelType::FP16Int4Groupwise); + } else if (arch >= 80) { + SUPPORT(KernelType::FP16Int8Groupwise); + SUPPORT(KernelType::FP16Int4Groupwise); + + SUPPORT(KernelType::BF16Int8Groupwise); + SUPPORT(KernelType::BF16Int4Groupwise); + } + return false; +#undef SUPPORT +} + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h new file mode 100644 index 0000000000000..db2860c6b265c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +enum class KernelType { + FP16Int8Groupwise, + FP16Int4Groupwise, + FP16Int8PerChannel, + FP16Int4PerChannel, + BF16Int8Groupwise, + BF16Int4Groupwise, + BF16Int8PerChannel, + BF16Int4PerChannel +}; + +struct Params { + using Pointer = void*; + using ConstPointer = void const*; + Pointer act; + Pointer act_scale; + Pointer weight; + Pointer scales; + Pointer zeros; + Pointer bias; + Pointer out; + float alpha; + int m; + int n; + int k; + int groupsize; + KernelType type; + bool apply_alpha_in_advance; + + Params(ConstPointer _act, ConstPointer _act_scale, ConstPointer _weight, ConstPointer _scales, ConstPointer _zeros, + ConstPointer _bias, Pointer _out, float _alpha, int _m, int _n, int _k, int _groupsize, KernelType _type, + bool _apply_alpha_in_advance = false) + : act(const_cast(_act)), + act_scale(const_cast(_act_scale)), + weight(const_cast(_weight)), + scales(const_cast(_scales)), + zeros(const_cast(_zeros)), + bias(const_cast(_bias)), + out(_out), + alpha(_alpha), + m(_m), + n(_n), + k(_k), + groupsize(_groupsize), + type(_type), + apply_alpha_in_advance(_apply_alpha_in_advance) { + } +}; + +void kernel_launcher(int arch, Params& params, cudaStream_t s); + +bool is_supported(int arch, KernelType kernel_type); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc new file mode 100644 index 0000000000000..893ff27c068f8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc @@ -0,0 +1,311 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/gemm_profiler.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +#include + +namespace onnxruntime::llm::kernels::weight_only { + +template +GemmPluginProfiler::GemmPluginProfiler() { + mMNKProfileMap = std::make_shared(); + + // set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings + auto const skipEnv = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS"); + mSkip = (skipEnv != NULL && std::stoi(skipEnv)); + if (mSkip) { + ORT_LLM_LOG_DEBUG( + "SKIP_GEMM_PLUGIN_PROFILINGS is set. Skipping GEMM plugin profilings. 
It could result in runtime error " + "if default tactic is not defined."); + } +} + +// template +// void GemmPluginProfiler::serialize( +// char*& buffer, GemmIdType const& gemmId) const +// { +// auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId); + +// // Save number of profiles for given GEMM ID +// write(buffer, static_cast(mProfileMap->size())); +// for (auto const& pair : *mProfileMap) +// { +// // Save pair of M to the best GEMM config +// write(buffer, pair); +// } +// } + +// template +// void GemmPluginProfiler::deserialize( +// char const*& data, GemmDims& dims, GemmIdType const& gemmId) +// { +// // NOTE: this mutex is not needed since each thread owns its private map, but will put here for +// // consistency +// writer_lock lock(mMNKProfileMap->mutex); + +// mDims = dims; + +// // GemmId gemmId(dims.n, dims.k); +// if (!mMNKProfileMap->existsMProfileMap(gemmId)) +// { +// // Create GEMM with GEMM ID if it does not exist +// mMNKProfileMap->createMProfileMap(gemmId); +// } +// // Populate map with profiles of GEMM ID +// auto profileMap = mMNKProfileMap->getMProfileMap(gemmId); +// int selectedMapSize; +// read(data, selectedMapSize); +// for (int ii = 0; ii < selectedMapSize; ++ii) +// { +// std::pair> config; +// read(data, config); +// profileMap->insert(config); +// } +// } + +// template +// size_t GemmPluginProfiler::getSerializationSize( +// GemmIdType const& gemmId) const +// { +// reader_lock lock(mMNKProfileMap->mutex); +// return sizeof(int) + // size of the tactics map +// mMNKProfileMap->getMProfileMap(gemmId)->size() +// * sizeof(std::pair>); // size of the tactics map +// } + +template +int GemmPluginProfiler::getMaxProfileM() const { + return 8192; +} + +template +void GemmPluginProfiler::initTmpData( + int /*m*/, int /*n*/, int /*k*/, char* /*workspace*/, size_t /*size*/, cudaStream_t /*stream*/) { + /* Do nothing */ +} + +template +void GemmPluginProfiler::profileTactics( + RunnerPtr const& runner, nvinfer::DataType const& type, GemmDims const& dims, GemmIdType const& gemmId, + bool hasWeightOnlyCudaKernel) { + writer_lock lock(mMNKProfileMap->mutex); + + if (!dims.isInitialized()) { + return; + } + + mRunner = runner; + mType = type; + + int const maxM = std::min(nextPowerOfTwo(dims.maxM), getMaxProfileM()); + computeTmpSize(maxM, dims.n, dims.k); + + if (!mMNKProfileMap->existsMProfileMap(gemmId)) { + // Create map for GEMM ID + mMNKProfileMap->createMProfileMap(gemmId); + } + + if (mSkip) { + return; + } + + auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId); + bool isAllocated{false}; + + auto profileTactics = [&mProfileMap, &isAllocated, this](int m, int n, int k) { + if (mProfileMap->count(m) == 0) { + if (!isAllocated) { + // Allocate tmp data to run GEMMs + allocateTmpData(); + isAllocated = true; + } + initTmpData(m, n, k, mWorkspaceTmp, mTmpWorkspaceSizeInBytes, mStream); + auto tactics = this->getTactics(m, n, k); + + // Profile different tactics for particular m and insert best config to the map + mProfileMap->insert({m, this->profileTacticsForProblem(m, n, k, tactics)}); + } + }; + + CUDA_CALL_THROW(cudaStreamCreate(&mStream)); + + int const startMinMRounded = nextPowerOfTwo(dims.minM); + + if (hasWeightOnlyCudaKernel) { + // Profile tactics for finer granularity of M, + // if CUDA kernel is enabled for weight-only plugins + int minM = dims.minM; + for (int m = std::max(1, minM); m < std::min(16, maxM); m += 1) { + profileTactics(m, dims.n, dims.k); + } + + for (int m = 16; m < maxM; m *= 2) { + profileTactics(m, dims.n, dims.k); + 
} + } else { + // Profile tactics for CUTLASS kernel only + for (int m = std::max(1, startMinMRounded); m < maxM; m *= 2) { + profileTactics(m, dims.n, dims.k); + } + } + + profileTactics(maxM, dims.n, dims.k); + + if (isAllocated) { + // Free tmp data + freeTmpData(); + } + CUDA_CALL_THROW(cudaStreamDestroy(mStream)); +} + +template +std::optional GemmPluginProfiler::getBestConfig( + int m, GemmIdType const& gemmId) const { + reader_lock lock(mMNKProfileMap->mutex); + + if (mSkip) { + ORT_LLM_LOG_TRACE("Skip is set, no best config is set for this instance"); + return std::nullopt; + } + + int const mRounded = std::min(std::max(1, nextPowerOfTwo(m)), getMaxProfileM()); + fflush(stdout); + + if (mMNKProfileMap->getMProfileMap(gemmId)->count(m) > 0) { + return mMNKProfileMap->getMProfileMap(gemmId)->at(m); + } else if (mMNKProfileMap->getMProfileMap(gemmId)->count(mRounded) > 0) { + return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded); + } else { + std::ostringstream msg; + msg << "Cannot find best tactic for m=" << m << " and GEMM ID " << gemmId; + ORT_LLM_LOG_WARNING(msg.str()); + return std::nullopt; + } +} + +template +void GemmPluginProfiler::allocateTmpData() { + ORT_ENFORCE(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0"); + auto const status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes); + ORT_ENFORCE(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling."); +} + +template +void GemmPluginProfiler::freeTmpData() { + auto const status = cudaFree(mWorkspaceTmp); + ORT_ENFORCE(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling."); +} + +template +std::optional GemmPluginProfiler::profileTacticsForProblem( + int m, int n, int k, std::vector const& tactics) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + float bestTime = std::numeric_limits::max(); + Config bestConfig; + bool foundOne = false; + + // Iterate over all tactics for given M, N and K + for (size_t ii = 0; ii < tactics.size(); ++ii) { + Config const& candidateConfig = tactics[ii]; + float time = std::numeric_limits::max(); + try { + if (!checkTactic(m, n, k, candidateConfig)) { + continue; + } + // Profile particular tactic for given M, N and K + time = profileTacticForProblem(m, n, k, candidateConfig); + foundOne = true; + } catch (std::exception const& e) { + std::ostringstream msg; + msg << "Cannot profile configuration " << ii; + if constexpr (std::is_same_v) { + msg << ": " << candidateConfig.toString(); + } + msg << "\n (for" + << " m=" << m << ", n=" << n << ", k=" << k << ")" + << ", reason: \"" << e.what() << "\". Skipped"; + ORT_LLM_LOG_TRACE(msg.str()); + cudaGetLastError(); // Reset the last cudaError to cudaSuccess. + continue; + } + + // Choose the fastest tactic + if (time < bestTime) { + bestConfig = candidateConfig; + bestTime = time; + } + } + + if (!foundOne) { + std::ostringstream msg; + msg << "Have not found any valid GEMM config for shape (" + << "m=" << m << ", n=" << n << ", k=" << k << "). 
Will try to use default or fail at runtime"; + ORT_LLM_LOG_WARNING(msg.str()); + return std::nullopt; + } + + return {bestConfig}; +} + +template +float GemmPluginProfiler::profileTacticForProblem( + int m, int n, int k, Config const& tactic) { + constexpr int warmup = 5; + constexpr int runs = 10; + + cudaStream_t stream = mStream; + + // Warmup the execution + for (int i = 0; i < warmup; ++i) { + runTactic(m, n, k, tactic, mWorkspaceTmp, stream); + } + + cudaEvent_t start; + cudaEvent_t stop; + CUDA_CALL_THROW(cudaEventCreate(&start)); + CUDA_CALL_THROW(cudaEventCreate(&stop)); + CUDA_CALL_THROW(cudaStreamSynchronize(stream)); + CUDA_CALL_THROW(cudaEventRecord(start, stream)); + + // Profile GEMM + for (int i = 0; i < runs; ++i) { + runTactic(m, n, k, tactic, mWorkspaceTmp, stream); + } + + CUDA_CALL_THROW(cudaEventRecord(stop, stream)); + + CUDA_CALL_THROW(cudaEventSynchronize(stop)); + + float elapsed; + CUDA_CALL_THROW(cudaEventElapsedTime(&elapsed, start, stop)); + + CUDA_CALL_THROW(cudaEventDestroy(start)); + CUDA_CALL_THROW(cudaEventDestroy(stop)); + + return elapsed / runs; +} + +template class GemmPluginProfiler, GemmIdCore, + GemmIdCoreHash>; + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h new file mode 100644 index 0000000000000..0ab9b91e7f43c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h @@ -0,0 +1,283 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "contrib_ops/cuda/llm/nv_infer_datatype.h" +#include "core/common/common.h" + +namespace onnxruntime::llm::kernels::weight_only { + +struct GemmDims { + int64_t minM; + int64_t maxM; + int64_t n; + int64_t k; + + GemmDims() + : minM(-1), maxM(-1), n(-1), k(-1) { + } + + GemmDims(int64_t minM_, int64_t maxM_, int64_t n_, int64_t k_) + : minM(minM_), maxM(maxM_), n(n_), k(k_) { + } + + [[nodiscard]] bool isInitialized() const { + return minM >= 0 && maxM >= 0 && n >= 0 && k >= 0; + } +}; + +// Unique ID of GEMM +// In our case GEMM is uniqly identified by N and K +class GemmIdCore { + public: + int n; + int k; + nvinfer::DataType dtype; + + GemmIdCore(int n_, int k_, nvinfer::DataType const& dtype_) + : n(n_), k(k_), dtype(dtype_) { + } + + GemmIdCore() + : n(-1), k(-1), dtype(nvinfer::DataType::kFLOAT) // dtype does not matter here + { + } + + bool operator==(GemmIdCore const& id) const { + return isEqual(id); + } + + friend std::ostream& operator<<(std::ostream& out, GemmIdCore const& id) { + out << "(N;K)=(" << id.n << ";" << id.k << "),"; + out << " type=" << static_cast(id.dtype); + return out; + } + + protected: + bool isEqual(GemmIdCore const& id) const { + return n == id.n && k == id.k && dtype == id.dtype; + } +}; + +// Hash of GemmId +struct GemmIdCoreHash { + std::size_t operator()(GemmIdCore const& id) const { + auto h1 = std::hash{}(id.n); + auto h2 = std::hash{}(id.k); + auto h3 = std::hash{}(static_cast(id.dtype)); + return h1 ^ h2 ^ h3; + } +}; + +// class GemmIdCublas : public GemmIdCore { +// public: +// bool transA{}; +// bool transB{}; +// nvinfer::DataType outputDtype; + +// GemmIdCublas(int n_, int k_, nvinfer::DataType const& dtype_, bool transA_, bool transB_, +// nvinfer::DataType const& output_dtype_) +// : GemmIdCore(n_, k_, dtype_), transA(transA_), transB(transB_), outputDtype(output_dtype_) { +// } + +// GemmIdCublas() {} + +// bool operator==(GemmIdCublas const& id) const { +// return isEqual(id) && transA == id.transA && transB == id.transB && outputDtype == id.outputDtype; +// } + +// friend std::ostream& operator<<(std::ostream& out, GemmIdCublas const& id) { +// out << "(N;K)=(" << id.n << ";" << id.k << "),"; +// out << " type=" << static_cast(id.dtype); +// out << " transA=" << id.transA; +// out << " transB=" << id.transB; +// out << " outputDtype=" << static_cast(id.outputDtype); +// return out; +// } +// }; + +// // Hash of GemmIdCublas +// struct GemmIdCublasHash { +// std::size_t operator()(GemmIdCublas const& id) const { +// auto h1 = std::hash{}(id.n); +// auto h2 = std::hash{}(id.k); +// auto h3 = std::hash{}(static_cast(id.dtype)); +// auto h4 = std::hash{}(id.transA); +// auto h5 = std::hash{}(id.transB); +// auto h6 = std::hash{}(static_cast(id.outputDtype)); +// return h1 ^ h2 ^ h3 ^ h4 ^ h5 ^ h6; +// } +// }; + +template +class GemmPluginProfiler { + public: + // Map for single GEMM for different Ms (GEMM dimension) to the best config for particular M + using MProfileMap = std::unordered_map>; + using MProfileMapPtr = std::shared_ptr; + + // requires exclusive ownership to write to *this + using reader_lock = std::unique_lock; + // requires shared ownership to read from other + using writer_lock = std::shared_lock; + + // Struct of continuing map if GEMMs to the best profiles for different Ms + struct MNKProfileMap { + // Mutex guarding map + std::shared_timed_mutex mutex; + // Map from GEMM Id to profile for 
particular GEMM + std::unordered_map profileMap; + + bool existsMProfileMap(GemmIdType const& id) { + auto const iter = profileMap.find(id); + return iter != profileMap.end(); + } + + void createMProfileMap(GemmIdType const& id) { + profileMap[id] = std::make_shared(); + } + + MProfileMapPtr getMProfileMap(GemmIdType const& id) { + auto const iter = profileMap.find(id); + if (iter == profileMap.end()) { + ORT_THROW("Cannot find ID (", id, ") in the profile map. Abort."); + } + return iter->second; + } + }; + + using MNKProfileMapPtr = std::shared_ptr; + + GemmPluginProfiler(); + + virtual ~GemmPluginProfiler() = default; + + // void serialize(char*& buffer, GemmIdType const& gemmId) const; + + // void deserialize(char const*& data, GemmDims& dims, GemmIdType const& gemmId); + // size_t getSerializationSize(GemmIdType const& gemmId) const; + + void profileTactics(RunnerPtr const& runner, nvinfer::DataType const& type, GemmDims const& dims, + GemmIdType const& gemmId, bool hasWeightOnlyCudaKernel = false); + + void setSelectionTactics(MNKProfileMapPtr const& map) { + mMNKProfileMap = map; + } + + void setTmpWorkspaceSizeInBytes(size_t bytes) { + mTmpWorkspaceSizeInBytes = bytes; + } + + void setSkip(bool skip) { + mSkip = mSkip || skip; + } + + std::optional getBestConfig(int m, GemmIdType const& gemmId) const; + + virtual int getMaxProfileM() const; + + protected: + virtual void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) = 0; + + virtual void computeTmpSize(size_t maxM, size_t n, size_t k) = 0; + + virtual bool checkTactic(int /*m*/, int /*n*/, int /*k*/, Config const& /*tactic*/) const { + return true; + } + + virtual std::vector getTactics(int m, int n, int k) const = 0; + + virtual void initTmpData(int m, int n, int k, char* workspace, size_t size, cudaStream_t stream); + + private: + void allocateTmpData(); + + void freeTmpData(); + + std::optional profileTacticsForProblem(int m, int n, int k, std::vector const& tactics); + + float profileTacticForProblem(int m, int n, int k, Config const& tactic); + + int nextPowerOfTwo(int v) const { + --v; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + return ++v; + } + + protected: + RunnerPtr mRunner{nullptr}; + + nvinfer::DataType mType{}; + + private: + MNKProfileMapPtr mMNKProfileMap{}; + + size_t mTmpWorkspaceSizeInBytes{0}; + + char* mWorkspaceTmp{nullptr}; + + cudaStream_t mStream; + + GemmDims mDims{}; + + bool mSkip{false}; +}; + +template +class GemmPluginProfilerManager { + public: + using MNKProfileMap = typename GemmPluginProfilerType::MNKProfileMap; + using MNKProfileMapPtr = typename GemmPluginProfilerType::MNKProfileMapPtr; + using GemmPluginProfilerPtr = std::shared_ptr; + + GemmPluginProfilerManager() { + mMNKProfileMap = std::make_shared(); + } + + GemmPluginProfilerPtr createGemmPluginProfiler(bool inference, bool skip = false) { + auto profiler = std::make_shared(); + profiler->setSkip(skip); + // If the profiler is created during the engine build, + // mMNKProfileMap is shared between different profilers to minimize the time spent on the profiling + // and do not repeat profiling for the GEMMs of the same shape. 
+ if (!inference) { + profiler->setSelectionTactics(mMNKProfileMap); + } + return profiler; + } + + private: + MNKProfileMapPtr mMNKProfileMap{}; +}; + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py new file mode 100644 index 0000000000000..678102c809b63 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py @@ -0,0 +1,397 @@ +# Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generate fpA intB GEMM kernels: +# pip install nvidia-cutlass +# python generate_kernels.py -a "90" -o ./fpA_intB_gemm/launchers + +import argparse +import enum +import os +from itertools import product + +from cutlass_library import ( + DataType, + DataTypeNames, + DataTypeSize, + DataTypeTag, + EpilogueScheduleSuffixes, + EpilogueScheduleTag, + EpilogueScheduleType, + GemmKind, + GemmKindNames, + KernelScheduleSuffixes, + KernelScheduleTag, + KernelScheduleType, +) + + +################################################################################ +# Epilogue Tag enum and string utils +class LlmEpilogueTag(enum.Enum): + epilogue_op_default = enum.auto() + epilogue_op_bias = enum.auto() + epilogue_op_silu = enum.auto() + epilogue_op_gelu = enum.auto() + + +class LlmEpilogueFusion(enum.Enum): + epilogue_fusion_none = enum.auto() + epilogue_fusion_finalize = enum.auto() + + +EpiTagNames = { + LlmEpilogueTag.epilogue_op_default: "lc", # linear combination + LlmEpilogueTag.epilogue_op_bias: "lc_bias", # linear combination with bias addition + LlmEpilogueTag.epilogue_op_silu: "silu", # silu or swiglu + LlmEpilogueTag.epilogue_op_gelu: "gelu", # gelu or geglu +} + +EpiTag = { + LlmEpilogueTag.epilogue_op_default: "onnxruntime::llm::cutlass_extensions::EpilogueOpDefault", + LlmEpilogueTag.epilogue_op_bias: "onnxruntime::llm::cutlass_extensions::EpilogueOpBias", + LlmEpilogueTag.epilogue_op_silu: "onnxruntime::llm::cutlass_extensions::EpilogueOpDefaultSilu", + LlmEpilogueTag.epilogue_op_gelu: "onnxruntime::llm::cutlass_extensions::EpilogueOpDefaultFtGelu", +} + +EpiFusion = { + LlmEpilogueFusion.epilogue_fusion_none: "onnxruntime::llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE", + LlmEpilogueFusion.epilogue_fusion_finalize: "onnxruntime::llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE", +} + +EpiFusionSuffixes = { + None: "", + LlmEpilogueFusion.epilogue_fusion_none: "EpilogueFusion_NONE", + LlmEpilogueFusion.epilogue_fusion_finalize: "EpilogueFusion_FINALIZE", +} + + +################################################################################ +# Quantization Operation and string utils +class LlmQuantOp(enum.Enum): + per_column_scale_only = enum.auto() + finegrained_scale_only = enum.auto() + finegrained_scale_and_zeros = enum.auto() + none = enum.auto() + + +QuantOpNames = { + LlmQuantOp.per_column_scale_only: "cs", + LlmQuantOp.finegrained_scale_only: "fgs", + 
LlmQuantOp.finegrained_scale_and_zeros: "fgsz", + LlmQuantOp.none: "noquant", +} + +QuantOpTag = { + LlmQuantOp.per_column_scale_only: "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY", + LlmQuantOp.finegrained_scale_only: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY", + LlmQuantOp.finegrained_scale_and_zeros: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS", + LlmQuantOp.none: "void", +} + +################################################################################ +# The activations, biases, scales and zeros are instantiated using CUDA types, +# not CUTLASS types. This map materializes the name of the CUDA type. + + +def get_data_type_bits(type): + return DataTypeSize[type] + + +def get_data_type_names(type): + return DataTypeNames[type] + + +CudaTypeName = { + DataType.e4m3: "__nv_fp8_e4m3", + DataType.bf16: "__nv_bfloat16", + DataType.f16: "half", + DataType.f32: "float", +} + + +################################################################################ +# A data structure holding all info to instantiate gemm launchers in TRT LLM. +class LlmGemmLauncher: + def __init__( + self, + gemm_kind, + arch, + act_type, + weight_type, + scalezero_type, + bias_type, + output_type, + quant_op, + epi_tag, + cta_shape, + warp_shape, + stages, + cga_shape, + mainloop_schedule, + epi_schedule, + epi_fusion=None, + ): + self.gemm_kind = gemm_kind + self.arch = arch + self.act_type = act_type + self.weight_type = weight_type + self.scalezero_type = scalezero_type + self.bias_type = bias_type + self.output_type = output_type + self.quant_op = quant_op + self.epi_tag = epi_tag + self.cta_shape = cta_shape + self.warp_shape = warp_shape + self.stages = stages + self.cga_shape = cga_shape + self.mainloop_schedule = mainloop_schedule + self.epi_schedule = epi_schedule + self.epi_fusion = epi_fusion + + def __repr__(self): + kernel_prefix = f"{GemmKindNames[self.gemm_kind]}_sm{self.arch}_{get_data_type_names(self.act_type)}_{get_data_type_names(self.weight_type)}_{get_data_type_names(self.scalezero_type)}_{get_data_type_names(self.bias_type)}_{get_data_type_names(self.output_type)}_{QuantOpNames[self.quant_op]}_{EpiTagNames[self.epi_tag]}_{self.cta_shape[0]}x{self.cta_shape[1]}x{self.cta_shape[2]}_{self.warp_shape[0]}x{self.warp_shape[1]}x{self.warp_shape[2]}_{self.stages}" + + hopper_suffix = f"_{self.cga_shape[0]}x{self.cga_shape[1]}x{self.cga_shape[2]}{KernelScheduleSuffixes[self.mainloop_schedule]}{EpilogueScheduleSuffixes[self.epi_schedule]}{EpiFusionSuffixes[self.epi_fusion]}" + + if self.arch >= 90: + return kernel_prefix + hopper_suffix + elif self.arch > 100: + raise ValueError(f"SM{self.arch} not supported yet.") + return kernel_prefix + + +################################################################################ +def tuple_to_cute_shape(shape): + return f"cute::Shape, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>" + + +def instantiate_operation_tma_warp_specialized(operation): + act_tag = CudaTypeName[operation.act_type] + scale_zero_tag = CudaTypeName[operation.scalezero_type] + bias_tag = CudaTypeName[operation.bias_type] + out_tag = CudaTypeName[operation.output_type] + + quant_op = QuantOpTag[operation.quant_op] + epi_tag = EpiTag[operation.epi_tag] + + cute_cta_shape = tuple_to_cute_shape(operation.cta_shape) + cute_cga_shape = tuple_to_cute_shape(operation.cga_shape) + + kernel_sched = KernelScheduleTag[operation.mainloop_schedule] + epi_sched = EpilogueScheduleTag[operation.epi_schedule] + + assert operation.gemm_kind == GemmKind.Gemm + weight_tag = 
DataTypeTag[operation.weight_type] + + return f""" +template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag}, +{quant_op}, {epi_tag}, +{cute_cta_shape}, {cute_cga_shape}, +{kernel_sched}, {epi_sched}> ( +const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float, +{out_tag}*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); +""" + + +def instantiate_operation(insts_list, operation): + if operation.arch >= 90: + insts_list.append(instantiate_operation_tma_warp_specialized(operation)) + + +def get_file_content(launcher_inl_files, operations): + assert operations + include_list = list() + for file in launcher_inl_files: + include_list.append(f'#include "{file}"') + includes = "\n".join(include_list) + + insts_list = list() + for op in operations: + instantiate_operation(insts_list, op) + instantiations = "\n".join(insts_list) + + file_content = f"""{includes} +namespace onnxruntime::llm +{{ +namespace kernels +{{ +namespace cutlass_kernels +{{ + +{instantiations} + +}} // namespace cutlass_kernels +}} // namespace kernels +}} // namespace onnxruntime::llm +""" + return file_content + + +def write_file(launcher_inl_files, operations, output_file): + os.makedirs(os.path.dirname(output_file), exist_ok=True) + # Avoid changing modified time if file content is up to date + content = get_file_content(launcher_inl_files, operations) + if os.path.exists(output_file): + with open(output_file) as f: + if f.read() == content: + return + with open(output_file, mode="w") as f: + f.write(content) + + +def elementwise(x, y, f): + return tuple(f(a, b) for (a, b) in zip(x, y, strict=False)) + + +def is_gemm_op_valid(op): + tile_m, tile_n, _ = op.cta_shape + cga_m, cga_n, _ = op.cga_shape + + if cga_m == 1 and cga_n == 1: + return True + + if cga_m == 2 and cga_n == 1 and tile_m >= 128: + return True + + if cga_m == 1 and cga_n == 2 and tile_n >= 128: + return True + + if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128: + return True + + return False + + +################################################################################ +def generate_sm90_mixed_gemm_operations(enable_fp8=False, enable_scale_only=False): + arch = 90 + + # For legacy reasons, we use unsigned types for the weights. The instanitated template + # will remap those back to the signed type. 
+ # Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type) + supported_dtypes = [ + (DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16), + (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16), + (DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16), + (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16, DataType.bf16), + ] + + if enable_fp8: + supported_dtypes = [ + *supported_dtypes, + (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16), + (DataType.e4m3, DataType.u4, DataType.f16, DataType.bf16, DataType.bf16), + ] + + quant_ops = [LlmQuantOp.finegrained_scale_and_zeros] + + if enable_scale_only: + quant_ops = [ + *quant_ops, + LlmQuantOp.finegrained_scale_only, + ] + + epi_tags = [LlmEpilogueTag.epilogue_op_bias] + + m_tiles = [64, 128] + n_tiles = [16, 32, 64, 128, 256] + cta_shapes_mn = product(m_tiles, n_tiles) + + warp_shape = [4, 1, 1] + stages = 0 # auto + + cga_shapes = product([1, 2], [1, 2], [1]) + + partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn, cga_shapes) + + operations = list() + for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args: + max_k_bits = 128 * 8 + cta_shape_k = max_k_bits // get_data_type_bits(dtype_combo[0]) + cta_shape_mnk = (*cta_shape_mn, cta_shape_k) + + use_coop = cta_shape_mn[0] == 128 + mainloop_schedule = ( + KernelScheduleType.TmaWarpSpecializedCooperative + if use_coop + else KernelScheduleType.TmaWarpSpecializedPingpong + ) + epi_schedule = ( + EpilogueScheduleType.TmaWarpSpecializedCooperative if use_coop else EpilogueScheduleType.TmaWarpSpecialized + ) + + mixed_gemm_operation = LlmGemmLauncher( + GemmKind.Gemm, + arch, + *dtype_combo, + quant_op, + epi_tag, + cta_shape_mnk, + warp_shape, + stages, + cga_shape, + mainloop_schedule, + epi_schedule, + ) + + if is_gemm_op_valid(mixed_gemm_operation): + operations.append(mixed_gemm_operation) + + return operations + + +def generate_sm90_operations(is_arch_enabled): + operations = generate_sm90_mixed_gemm_operations() + return operations + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Print the output directory") + + parser.add_argument("-o", "--output_dir", type=str, required=True, help="Path to the output directory") + parser.add_argument("-a", "--architectures", type=str, required=True, help="Architectures to generate kernels for") + + args = parser.parse_args() + + arches = args.architectures.split(";") + + output_dir = os.path.abspath(args.output_dir) + + include_map = { + (GemmKind.Gemm, 90): ["contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"], + } + + def has_arch(sm): + return f"{sm}" in arches or f"{sm}-real" in arches + + # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads. + # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve. 
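+    # As a concrete sketch of the effect (counter values depend on the grouping): running
+    #   python generate_kernels.py -a "90" -o ./fpA_intB_gemm/launchers
+    # buckets the SM90 launchers by (gemm_kind, arch, CTA tile M) and writes one
+    # fpA_intB_gemm_launcher_<counter>.generated.cu per bucket, each including
+    # fpA_intB_launcher_sm90.inl together with that bucket's explicit template instantiations.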
+ operations = [] + operations += generate_sm90_operations(has_arch(90)) + + op_groups = dict() + for op in operations: + dict_key = (op.gemm_kind, op.arch, op.cta_shape[0]) + op_group = op_groups.get(dict_key, list()) + op_group.append(op) + op_groups[dict_key] = op_group + + file_counter = 1 + for key, value in op_groups.items(): + gemm_kind, _, _ = key + out_file = os.path.join(output_dir, f"fpA_intB_gemm_launcher_{file_counter}.generated.cu") + write_file(include_map[key[:2]], value, out_file) + file_counter += 1 diff --git a/onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h b/onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h new file mode 100644 index 0000000000000..52e8eb225c79c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// This is corresponding to nvinfer1 namespace used by TensorRT. Add it to avoid dependency on TensorRT. +namespace onnxruntime::llm::nvinfer { + +enum class DataType : int32_t { + //! 32-bit floating point format. + kFLOAT = 0, + + //! IEEE 16-bit floating-point format -- has a 5 bit exponent and 11 bit significand. + kHALF = 1, + + //! Signed 8-bit integer representing a quantized floating-point value. + kINT8 = 2, + + //! Signed 32-bit integer format. + kINT32 = 3, + + //! 8-bit boolean. 0 = false, 1 = true, other values undefined. + kBOOL = 4, + + //! Unsigned 8-bit integer format. + //! Cannot be used to represent quantized floating-point values. + kUINT8 = 5, + + //! Signed 8-bit floating point with + //! 1 sign bit, 4 exponent bits, 3 mantissa bits, and exponent-bias 7. + kFP8 = 6, + + //! Brain float -- has an 8 bit exponent and 8 bit significand. + kBF16 = 7, + + //! Signed 64-bit integer type. + kINT64 = 8, + + //! Signed 4-bit integer type. 
+ kINT4 = 9, + + kFP4 = 10, +}; +} // namespace onnxruntime::llm::nvinfer diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h index da8cb6d294efd..644caa950e5a4 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h @@ -67,27 +67,6 @@ __forceinline__ __device__ float tanh_opt(float x) { #endif } -///////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct GELU_taylor { - static bool const kIsHeavy = true; - - CUTLASS_DEVICE - float operator()(float const& z) const { - float k0 = static_cast(0.7978845608028654); - float k1 = static_cast(0.044715); - - return static_cast( - cutlass::constants::half() * z * - (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } - - using Params = LinearCombinationGenericParams; - - CUTLASS_DEVICE - float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); } -}; - } // namespace thread } // namespace epilogue } // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 33265744f3a7d..3f485f0abdcb1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -9,40 +9,253 @@ #include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" -#include "matmul_nbits.cuh" -#include "dequantize_blockwise.cuh" +#include "contrib_ops/cpu/utils/dump_tensor.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h" +#include "contrib_ops/cuda/llm/cutlass_preprocessors.h" +#include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" + +constexpr int MatMulNBits_Input_B = 1; +constexpr int MatMulNBits_Input_Scale = 2; +constexpr int MatMulNBits_Input_ZeroPoint = 3; namespace onnxruntime { namespace contrib { namespace cuda { using namespace onnxruntime::cuda; +using onnxruntime::llm::kernels::weight_only::GemmPluginProfilerManager; +using onnxruntime::llm::kernels::weight_only::WeightOnlyGroupwiseQuantGemmPluginProfiler; +using onnxruntime::llm::kernels::weight_only::WeightTypeId; +static GemmPluginProfilerManager s_profilerManager; + +template +void MatMulNBits::InitGemmProfiler(int sm) { + gemmProfiler_ = s_profilerManager.createGemmPluginProfiler(/*inference*/ false); + + if constexpr (std::is_same_v) { + if (nbits_ == 8) { + weightOnlyGemmRunner_ = std::make_shared>(); + } else if (nbits_ == 4) { + weightOnlyGemmRunner_ = std::make_shared>(); + } + } else if constexpr (std::is_same_v) { + if (nbits_ == 8) { + weightOnlyGemmRunner_ = std::make_shared>(); + } else if (nbits_ == 4) { + weightOnlyGemmRunner_ = std::make_shared>(); + } + } + + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type = nbits_ == 8 ? 
KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + gemmProfiler_->setCudaKernelType(cuda_kernel_type, sm); + gemmProfiler_->setQuant(nbits_, has_bias_, has_zero_points_); + gemmProfiler_->setGroupSize(block_size_); +} + +template +void MatMulNBits::RunGemmProfile(bool hasWeightOnlyCudaKernel, int min_m, int max_m) { + // Number of 16-bit elements after casting int8/int4 to fp16. + int n_16b = N_ / (nbits_ == 8 ? 2 : 4); + + gemmId_ = GemmIdCore(n_16b, K_, onnxruntime::llm::nvinfer::DataType::kHALF); + + GemmDims dims = {min_m, max_m, n_16b, K_}; + gemmProfiler_->profileTactics(weightOnlyGemmRunner_, gemmId_.dtype, dims, gemmId_, hasWeightOnlyCudaKernel); +} + +template +Status MatMulNBits::PrePack(const Tensor& /* tensor */, int /* input_idx */, AllocatorPtr /*alloc*/, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + +template <> +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, + PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + if (has_fpA_intB_gemm_) { + cudaStream_t stream = cudaStreamLegacy; // Use default stream for prepacking. + if (input_idx == MatMulNBits_Input_B) { + ORT_RETURN_IF_ERROR(PrePack_B(tensor, alloc, stream)); + is_packed = true; + } else if (input_idx == MatMulNBits_Input_Scale) { + ORT_RETURN_IF_ERROR(PrePack_Scale(tensor, alloc, stream)); + is_packed = true; + } else if (input_idx == MatMulNBits_Input_ZeroPoint) { + if (has_zero_points_) { + ORT_RETURN_IF_ERROR(PrePack_ZeroPoint(tensor, alloc, stream)); + is_packed = true; + } + } + } + + return Status::OK(); +} + +template +Status MatMulNBits::PrePack_B([[maybe_unused]] const Tensor& tensor, + [[maybe_unused]] AllocatorPtr alloc, + [[maybe_unused]] cudaStream_t stream) { + if constexpr (std::is_same_v) { + size_t n = static_cast(N_); + size_t k = static_cast(K_); + + size_t packed_weight_bytes = n * k / (8 / nbits_); + + // uint8 does not need to be packed so we do not need to allocate extra space. + IAllocatorUniquePtr packed_transposed_weight_space = this->GetTransientScratchBuffer(packed_weight_bytes); + int8_t* packed_transposed_weight = reinterpret_cast(packed_transposed_weight_space.get()); + + fpA_intB_weight_buffer_ = IAllocator::MakeUniquePtr(alloc, packed_weight_bytes, true); // Transient buffer. + + int8_t* preprocessed_weight = reinterpret_cast(fpA_intB_weight_buffer_.get()); + + const uint8_t* blob_data = tensor.Data(); + if (nbits_ == 4) { + // Transpose the weight and add default zero point. + onnxruntime::llm::kernels::fpA_intB_gemv::unpack_uint4_transposed_to_int8_direct_cuda( + stream, packed_transposed_weight, blob_data, n, k); + } else { + onnxruntime::llm::kernels::fpA_intB_gemv::transpose_uint8_matrix_and_convert_to_int8( + stream, packed_transposed_weight, blob_data, n, k); + } + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + auto tranpose_weight_buffer = this->AllocateBufferOnCPUPinned(packed_weight_bytes); + CUDA_RETURN_IF_ERROR(cudaMemcpy(tranpose_weight_buffer.get(), packed_transposed_weight, packed_weight_bytes, cudaMemcpyDeviceToHost)); + + auto processed_weight_buffer = this->AllocateBufferOnCPUPinned(n * k / (8 / nbits_)); + bool force_interleave = false; + + using onnxruntime::llm::kernels::cutlass_kernels::QuantType; + QuantType quant_type = nbits_ == 4 ? QuantType::W4_A16 : QuantType::W8_A16; + + // TODO: Add a cuda kernle for preprocessing so that we can avoid copying the data back to CPU. 
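For reference, a small NumPy sketch of the 4-bit packing convention that this prepack path has to undo (two weights per byte along K, low nibble first, as described in the revised operator documentation later in this change). This covers only the packing arithmetic; the transposed/interleaved intermediate layouts produced by the CUDA helpers above are implementation details and are not modeled here.

```python
import numpy as np

def unpack_uint4(packed: np.ndarray) -> np.ndarray:
    """Unpack uint8 bytes holding two 4-bit values each (low nibble first)."""
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return np.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)

def pack_uint4(vals: np.ndarray) -> np.ndarray:
    """Inverse of unpack_uint4; the last dimension must be even, values in [0, 15]."""
    v = vals.reshape(*vals.shape[:-1], -1, 2).astype(np.uint8)
    return v[..., 0] | (v[..., 1] << 4)

packed = np.array([[0x21, 0x43]], dtype=np.uint8)        # encodes 1, 2, 3, 4 along K
assert unpack_uint4(packed).tolist() == [[1, 2, 3, 4]]
assert np.array_equal(pack_uint4(unpack_uint4(packed)), packed)
```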
+ onnxruntime::llm::kernels::cutlass_kernels::preprocess_weights_for_mixed_gemm( + reinterpret_cast(processed_weight_buffer.get()), + reinterpret_cast(tranpose_weight_buffer.get()), + {static_cast(k), static_cast(n)}, + quant_type, + force_interleave); + + CUDA_RETURN_IF_ERROR(cudaMemcpy(preprocessed_weight, processed_weight_buffer.get(), n * k / (8 / nbits_), cudaMemcpyHostToDevice)); + CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize()); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed transposed_weight in GPU", packed_transposed_weight, k, n * nbits_ / 8); + DUMP_TENSOR_D("preprocessed_weight", reinterpret_cast(preprocessed_weight), k, n * nbits_ / 8); + } + + return Status::OK(); +} + +template +Status MatMulNBits::PrePack_Scale([[maybe_unused]] const Tensor& tensor, + [[maybe_unused]] AllocatorPtr alloc, + [[maybe_unused]] cudaStream_t stream) { + if constexpr (std::is_same_v) { + size_t n = static_cast(N_); + size_t k = static_cast(K_); + + size_t k_blocks = (k + block_size_ - 1) / block_size_; + size_t scale_bytes = n * k_blocks * sizeof(T); + + fpA_intB_scale_buffer_ = IAllocator::MakeUniquePtr(alloc, scale_bytes, true); // Transient buffer. + + typedef typename ToCudaType::MappedType CudaT; + CudaT* transposed_scales = reinterpret_cast(fpA_intB_scale_buffer_.get()); + + onnxruntime::llm::kernels::fpA_intB_gemv::launch_transpose_scale_kernel(stream, reinterpret_cast(tensor.Data()), transposed_scales, n, k_blocks); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("transposed_scales", transposed_scales, k_blocks, n); + } + return Status::OK(); +} + +template +Status MatMulNBits::PrePack_ZeroPoint([[maybe_unused]] const Tensor& tensor, + [[maybe_unused]] AllocatorPtr alloc, + [[maybe_unused]] cudaStream_t stream) { + if constexpr (std::is_same_v) { + size_t n = static_cast(N_); + size_t k = static_cast(K_); + + size_t k_blocks = (k + block_size_ - 1) / block_size_; + size_t scale_bytes = n * k_blocks * sizeof(T); + + typedef typename ToCudaType::MappedType CudaT; + const CudaT* transposed_scales = reinterpret_cast(fpA_intB_scale_buffer_.get()); + + fpA_intB_zero_buffer_ = IAllocator::MakeUniquePtr(alloc, scale_bytes, true); // Transient buffer. + CudaT* scaled_zero_points = reinterpret_cast(fpA_intB_zero_buffer_.get()); + + constexpr float kDefaultZeroPoint4Bit = 8.0f; + constexpr float kDefaultZeroPoint8Bit = 128.0f; + const float default_zero_point = nbits_ == 4 ? kDefaultZeroPoint4Bit : kDefaultZeroPoint8Bit; + const auto* zero_points_data = tensor.DataRaw(); + + // The scaled zero point will be zero for the default zero point, so there is no need to scale when it is nullptr. 
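A short note on the algebra behind the "scaled zero point" mentioned in the comment above, under the assumption that dequantization is (q − zp) · scale with a per-block zero point and default zero point d = 2^(bits − 1): rewriting it as (q − d) · s + (d − zp) · s lets the constant term (d − zp) · s be folded ahead of time, and that term is exactly zero when a block uses the default zero point, which is why nothing needs to be scaled in the nullptr case. The helper below is an illustrative sketch only (the actual kernels may fold the sign differently):

```python
import numpy as np

def scaled_zero_points(zero_points: np.ndarray, scales: np.ndarray, bits: int) -> np.ndarray:
    """Precompute (default_zp - zp) * scale per block; zero wherever zp equals the default."""
    default_zp = float(1 << (bits - 1))              # 8 for 4-bit, 128 for 8-bit
    return (default_zp - zero_points.astype(np.float32)) * scales.astype(np.float32)

# One output column with two K-blocks: default zero point in block 0, custom in block 1.
zp = np.array([8, 5], dtype=np.uint8)
scale = np.array([0.1, 0.2], dtype=np.float32)
print(scaled_zero_points(zp, scale, bits=4))         # ~[0.0, 0.6]
```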
+ if (!tensor.IsDataType()) { // zero point is uint8_t type + if (nbits_ == 4) { + onnxruntime::llm::kernels::fpA_intB_gemv::launch_scaled_zero_point_kernel( + stream, reinterpret_cast(zero_points_data), + transposed_scales, scaled_zero_points, n, k_blocks, default_zero_point); + } else { + onnxruntime::llm::kernels::fpA_intB_gemv::launch_scaled_zero_point_kernel( + stream, reinterpret_cast(zero_points_data), + transposed_scales, scaled_zero_points, n, k_blocks, default_zero_point); + } + } else { // zero point is not uint8_t type + onnxruntime::llm::kernels::fpA_intB_gemv::launch_scaled_zero_point_kernel( + stream, reinterpret_cast(zero_points_data), + transposed_scales, scaled_zero_points, n, k_blocks, default_zero_point); + } + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("scaled_zero_points", scaled_zero_points, k_blocks, n); + } + return Status::OK(); +} template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { + const bool is_prepacked = has_fpA_intB_gemm_; const Tensor* a = ctx->Input(0); - const Tensor* b = ctx->Input(1); - const Tensor* scales = ctx->Input(2); - const Tensor* zero_points = ctx->Input(3); + const Tensor* b = is_prepacked ? nullptr : ctx->Input(1); + const Tensor* scales = is_prepacked ? nullptr : ctx->Input(2); + const Tensor* zero_points = is_prepacked ? nullptr : ctx->Input(3); const Tensor* reorder_idx = ctx->Input(4); const Tensor* bias = ctx->Input(5); + if (bias != nullptr) { ORT_THROW("MatMulNBits does not support bias in CUDA kernel"); } + ORT_RETURN_IF_ERROR(matmul_nbits_helper::CheckInputs( + a, b, scales, zero_points, reorder_idx, bias, N_, K_, block_size_, nbits_)); + const auto* a_data = a->Data(); - const uint8_t* blob_data = b->Data(); - const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const uint8_t* blob_data = is_prepacked ? nullptr : b->Data(); + const auto* scales_data = is_prepacked ? nullptr : scales->Data(); + const auto* zero_points_data = (is_prepacked || zero_points == nullptr) ? nullptr : zero_points->DataRaw(); const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); - - typedef typename ToCudaType::MappedType CudaT; + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); constexpr bool transa = false; constexpr bool transb = true; MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); - ORT_RETURN_IF_ERROR( - helper.Compute(a->Shape(), b_shape, transa, transb)); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); Tensor* Y = ctx->Output(0, helper.OutputShape()); @@ -50,6 +263,61 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if (Y->Shape().Size() == 0) return Status::OK(); + cudaStream_t stream = static_cast(ctx->GetComputeStream()->GetHandle()); + + typedef typename ToCudaType::MappedType CudaT; + CudaT* out_data = reinterpret_cast(Y->MutableData()); + + int m = SafeInt(helper.M()); + int n = SafeInt(helper.N()); + int k = SafeInt(helper.K()); + + DUMP_TENSOR_INIT(); + + if constexpr (std::is_same::value) { + if (has_fpA_intB_gemm_) { + auto const& bestTactic = gemmProfiler_->getBestConfig(m, gemmId_); + + DUMP_STRING("Best tactic: m=", m, " n=", n, " k=", k, " group_size=", block_size_, bestTactic->toString()); + + if (bestTactic->enableCudaKernel) { + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type = (nbits_ == 8) ? 
KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + + void const* pre_quant_scale_ptr = nullptr; + bool apply_alpha_in_advance = false; + float alpha = 1.0f; + onnxruntime::llm::kernels::fpA_intB_gemv::Params params( + a_data, pre_quant_scale_ptr, fpA_intB_weight_buffer_.get(), + fpA_intB_scale_buffer_.get(), has_zero_points_ ? fpA_intB_zero_buffer_.get() : nullptr, + bias_data, out_data, + alpha, m, n, k, block_size_, cuda_kernel_type, apply_alpha_in_advance); + + onnxruntime::llm::kernels::fpA_intB_gemv::kernel_launcher(sm_, params, stream); + } else { + const size_t workspace_size = weightOnlyGemmRunner_->getWorkspaceSize(m, n, k); + auto workspace_buffer = GetScratchBuffer(workspace_size, ctx->GetComputeStream()); + + weightOnlyGemmRunner_->gemm( + a_data, + fpA_intB_weight_buffer_.get(), + fpA_intB_scale_buffer_.get(), + has_zero_points_ ? fpA_intB_zero_buffer_.get() : nullptr, + bias_data, + 1.f, + out_data, + m, n, k, + block_size_, + *bestTactic, + reinterpret_cast(workspace_buffer.get()), + workspace_size, + stream); + } + + return Status::OK(); + } + } + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { bool done = (nbits_ == 8) ? TryMatMul8Bits( reinterpret_cast(Y->MutableData()), @@ -57,24 +325,24 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { blob_data, reinterpret_cast(scales_data), static_cast(zero_points_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), + m, + n, + k, SafeInt(block_size_), GetDeviceProp().sharedMemPerBlock, - static_cast(ctx->GetComputeStream()->GetHandle())) + stream) : TryMatMul4Bits( reinterpret_cast(Y->MutableData()), reinterpret_cast(a_data), blob_data, reinterpret_cast(scales_data), static_cast(zero_points_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), + m, + n, + k, SafeInt(block_size_), GetDeviceProp().sharedMemPerBlock, - static_cast(ctx->GetComputeStream()->GetHandle())); + stream); if (done) { return Status::OK(); } @@ -99,7 +367,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } else { ORT_RETURN_IF_ERROR(Dequantize8Bits( reinterpret_cast(b_data), @@ -110,7 +378,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } else { // row-wise block ORT_RETURN_IF_ERROR(DequantizeBlockwise8b( @@ -122,7 +390,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { column_wise_quant_blk_, SafeInt(K_), SafeInt(N_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } else { // 4 bits if (column_wise_quant_blk_) { @@ -140,7 +408,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } else { ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), @@ -151,7 +419,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } else { // row-wise block @@ -166,11 +434,10 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { column_wise_quant_blk_, SafeInt(K_), SafeInt(N_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } - 
DUMP_TENSOR_INIT(); DUMP_TENSOR_D("DeQuantized", b_data, N_, K_padded); const CudaT alpha = ToCudaType::FromFloat(1.f); @@ -207,7 +474,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MatMulNBits); ONNX_OPERATOR_TYPED_KERNEL_EX( @@ -218,7 +486,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), MatMulNBits); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index f5c2c6c4e4fdf..02740d905c7c7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -10,11 +10,27 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h" +#include "core/platform/env_var_utils.h" namespace onnxruntime { namespace contrib { namespace cuda { using namespace onnxruntime::cuda; +using onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner; +using onnxruntime::llm::kernels::weight_only::GemmDims; +using onnxruntime::llm::kernels::weight_only::GemmIdCore; +using onnxruntime::llm::kernels::weight_only::GemmPluginProfilerManager; +using onnxruntime::llm::kernels::weight_only::WeightOnlyGroupwiseQuantGemmPluginProfiler; +using GemmProfilerPtr = std::shared_ptr; +using WeightOnlyGemmRunnerPtr = std::shared_ptr; + +// Environment variable to configure fpA_intB_gemm for experiments. Set it to 0 to disable, 1 to eanble all. 
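A hedged reading of how the bit flags declared just below are combined, based on the constructor logic further down in this header: bit 0x01 turns everything on, 0x04 / 0x08 opt in the fpA_intB path for 4-bit / 8-bit weights respectively, and 0x02 additionally allows the GEMV CUDA kernel. The helper name is made up for illustration, and the other eligibility checks (block size, SM version, zero points, bias) are omitted:

```python
ALL, GEMV, INT4, INT8 = 0x01, 0x02, 0x04, 0x08   # mirrors the kFpAIntBGemmOption_* constants below

def fpA_intB_enabled(option: int, nbits: int) -> tuple:
    """Return (gemm_enabled, gemv_enabled) roughly the way the constructor combines the flags."""
    gemm = (option & (nbits | ALL)) != 0          # nbits is 4 or 8, overlapping the INT4/INT8 bits
    gemv = gemm and (option & (GEMV | ALL)) != 0
    return gemm, gemv

print(fpA_intB_enabled(0x01, 4))   # (True, True)   -- ORT_FPA_INTB_GEMM=1 enables all
print(fpA_intB_enabled(0x04, 4))   # (True, False)  -- 4-bit GEMM path, no GEMV kernel
print(fpA_intB_enabled(0x04, 8))   # (False, False)
```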
+constexpr const char* kFpAIntBGemmOption = "ORT_FPA_INTB_GEMM"; +constexpr int kFpAIntBGemmOption_All = 0x01; +constexpr int kFpAIntBGemmOption_Gemv = 0x02; +constexpr int kFpAIntBGemmOption_Int4 = 0x04; +constexpr int kFpAIntBGemmOption_Int8 = 0x08; template class MatMulNBits final : public CudaKernel { @@ -24,16 +40,91 @@ class MatMulNBits final : public CudaKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + + constexpr size_t kInputIndexScale = 2; + constexpr size_t kInputIndexZeroPoints = 3; + constexpr size_t kInputIndexGroupIndex = 4; + constexpr size_t kInputIndexBias = 5; + + has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints && info.node().InputDefs()[kInputIndexZeroPoints]->Exists(); + has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex && info.node().InputDefs()[kInputIndexGroupIndex]->Exists(); + has_bias_ = info.GetInputCount() > kInputIndexBias && info.node().InputDefs()[kInputIndexBias]->Exists(); + sm_ = this->GetDeviceProp().major * 10 + this->GetDeviceProp().minor; + + if (has_zero_points_) { + int32_t zero_point_type = info.node().InputDefs()[kInputIndexZeroPoints]->TypeAsProto()->tensor_type().elem_type(); + int32_t scale_type = info.node().InputDefs()[kInputIndexScale]->TypeAsProto()->tensor_type().elem_type(); + is_zero_points_scale_same_type_ = (zero_point_type == scale_type); + } + + if constexpr (std::is_same::value) { + int option = ParseEnvironmentVariableWithDefault(kFpAIntBGemmOption, 0); + if ((option & (static_cast(nbits_) | kFpAIntBGemmOption_All)) != 0 && + (block_size_ == 64 || block_size_ == 128) && + (nbits_ == 4 || nbits_ == 8) && + !has_g_idx_ && has_zero_points_ && !has_bias_ && + N_ % (nbits_ == 8 ? 32 : 64) == 0 && + K_ % block_size_ == 0 && + sm_ >= 75) { + if ((option & (kFpAIntBGemmOption_Gemv | kFpAIntBGemmOption_All)) != 0) { + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type = (nbits_ == 8) ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + if (onnxruntime::llm::kernels::fpA_intB_gemv::is_supported(sm_, cuda_kernel_type)) { + has_fpA_intB_gemv_ = true; + } + } + + InitGemmProfiler(sm_); + + constexpr int max_m = 8291; + RunGemmProfile(has_fpA_intB_gemv_, 1, max_m); + has_fpA_intB_gemm_ = true; + } + } + +#ifndef NDEBUG + printf("n=%d, k=%d, block_size=%d, bits=%d, zp_bits=%d, g_idx=%d, bias=%d, gemv=%d, gemm=%d\n", + int(N_), int(K_), int(block_size_), int(nbits_), + has_zero_points_ ? (is_zero_points_scale_same_type_ ? int(sizeof(T)) * 8 : int(nbits_)) : int(0), + int(has_g_idx_ ? 1 : 0), int(has_bias_ ? 
1 : 0), + int(has_fpA_intB_gemv_), int(has_fpA_intB_gemm_)); +#endif } Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: + void InitGemmProfiler(int sm); + void RunGemmProfile(bool hasWeightOnlyCudaKernel, int min_m, int max_m); + + Status PrePack_B(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); + Status PrePack_Scale(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); + Status PrePack_ZeroPoint(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); + int64_t K_; int64_t N_; int64_t block_size_; int64_t nbits_; + int sm_{0}; bool column_wise_quant_blk_{true}; + + bool has_g_idx_{false}; + bool has_bias_{false}; + bool has_zero_points_{false}; + bool is_zero_points_scale_same_type_{false}; + bool has_fpA_intB_gemv_{false}; + bool has_fpA_intB_gemm_{false}; + + WeightOnlyGemmRunnerPtr weightOnlyGemmRunner_{nullptr}; + mutable GemmProfilerPtr gemmProfiler_{nullptr}; + GemmIdCore gemmId_{}; + + IAllocatorUniquePtr fpA_intB_weight_buffer_; + IAllocatorUniquePtr fpA_intB_scale_buffer_; + IAllocatorUniquePtr fpA_intB_zero_buffer_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index 5c39cf56dfd92..b986f0ae3edad 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -146,6 +146,31 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di } } +void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2, int dim3) { + MLDataType dataType = tensor.DataType(); + bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU); + if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, dim3, is_gpu_tensor); + } else { + std::cout << std::string(name) << std::endl; + std::cout << "The data type is not supported in DumpGpuTensor" << std::endl; + } +} + void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) { MLDataType dataType = tensor.DataType(); bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU); @@ -157,8 +182,17 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, i DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); + } 
else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); } else { - assert(0); + std::cout << std::string(name) << std::endl; + std::cout << "The data type is not supported in DumpGpuTensor" << std::endl; } } @@ -173,11 +207,24 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) { DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); + } else if (dataType == DataTypeImpl::GetType()) { + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); } else { - assert(0); + std::cout << std::string(name) << std::endl; + std::cout << "The data type is not supported in DumpGpuTensor" << std::endl; } } +void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0) { + DumpGpuTensor(name, tensor, 1, dim0); +} + void DumpGpuTensor(const char* name, const Tensor& tensor) { const auto& shape = tensor.Shape(); @@ -188,21 +235,33 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { std::cout << tensor.Location().ToString() << std::endl; size_t num_dims = shape.NumDimensions(); - if (num_dims >= 3) { - int dim0 = static_cast(shape.SizeToDimension(num_dims - 2)); - int dim1 = static_cast(shape[num_dims - 2]); - int dim2 = static_cast(shape[num_dims - 1]); + if (num_dims >= 4) { + int dim0 = static_cast(shape.SizeToDimension(num_dims - 4)); + int dim1 = static_cast(shape[num_dims - 3]); + int dim2 = static_cast(shape[num_dims - 2]); + int dim3 = static_cast(shape[num_dims - 1]); + DumpGpuTensor(nullptr, tensor, dim0, dim1, dim2, dim3); + return; + } + + if (num_dims == 3) { + int dim0 = static_cast(shape[0]); + int dim1 = static_cast(shape[1]); + int dim2 = static_cast(shape[2]); DumpGpuTensor(nullptr, tensor, dim0, dim1, dim2); return; } - auto num_items = shape.Size(); - size_t num_rows = 1; - if (num_dims > 1) { - num_rows = static_cast(shape[0]); + if (num_dims == 2) { + int dim0 = static_cast(shape[0]); + int dim1 = static_cast(shape[1]); + DumpGpuTensor(nullptr, tensor, dim0, dim1); + return; + } + + if (num_dims == 1) { + DumpGpuTensor(nullptr, tensor, static_cast(shape[0])); } - size_t row_size = num_items / num_rows; - DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } CudaTensorConsoleDumper::CudaTensorConsoleDumper() { @@ -213,98 +272,6 @@ void CudaTensorConsoleDumper::Print(const std::string& value) const { std::cout << value << std::endl; } -void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { - if 
(is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { - Print(name, reinterpret_cast(tensor), dim0, dim1); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); -} - void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { if (is_enabled_) DumpGpuTensor(name, tensor); @@ -315,48 +282,35 @@ void CudaTensorConsoleDumper::Print(const char* name, 
const OrtValue& value) con Print(name, tensor); } -void CudaTensorConsoleDumper::Print(const char* name, int index, bool end_line) const { - if (!is_enabled_) - return; - - std::cout << std::string(name) << "[" << index << "]"; - if (end_line) { - std::cout << std::endl; - } -} - -void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, bool end_line) const { - if (!is_enabled_) - return; - - std::cout << std::string(name) << "=" << value; - if (end_line) { - std::cout << std::endl; +#define CUDA_DUMPER_PRINT_TYPE(dtype, dtype2) \ + void CudaTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1) const { \ + if (is_enabled_) \ + DumpGpuTensor(name, reinterpret_cast(tensor), dim0, dim1, true); \ + } \ + void CudaTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const { \ + if (is_enabled_) \ + DumpGpuTensor(name, reinterpret_cast(tensor), dim0, dim1, dim2, true); \ + } \ + void CudaTensorConsoleDumper::Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const { \ + if (is_enabled_) \ + DumpGpuTensor(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3, true); \ + } \ + void CudaTensorConsoleDumper::Print(const char* name, const dtype* tensor, gsl::span& dims) const { \ + PrintTensorByDims(this, name, reinterpret_cast(tensor), dims); \ } -} -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} - -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} - -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} +CUDA_DUMPER_PRINT_TYPE(int8_t, int8_t) +CUDA_DUMPER_PRINT_TYPE(uint8_t, uint8_t) +CUDA_DUMPER_PRINT_TYPE(int32_t, int32_t) +CUDA_DUMPER_PRINT_TYPE(int64_t, int64_t) +CUDA_DUMPER_PRINT_TYPE(float, float) +CUDA_DUMPER_PRINT_TYPE(MLFloat16, MLFloat16) +CUDA_DUMPER_PRINT_TYPE(BFloat16, BFloat16) +CUDA_DUMPER_PRINT_TYPE(UInt4x2, UInt4x2) +CUDA_DUMPER_PRINT_TYPE(Int4x2, Int4x2) -void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, gsl::span& dims) const { - PrintTensorByDims(this, name, tensor, dims); -} +CUDA_DUMPER_PRINT_TYPE(half, MLFloat16) +#undef DUMPER_PRINT_TYPE #else CudaTensorConsoleDumper::CudaTensorConsoleDumper() { @@ -365,92 +319,33 @@ CudaTensorConsoleDumper::CudaTensorConsoleDumper() { void CudaTensorConsoleDumper::Print(const std::string&) const { } -void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const 
char*, const int64_t*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int, int) const { -} - void CudaTensorConsoleDumper::Print(const char*, const Tensor&) const { } void CudaTensorConsoleDumper::Print(const char*, const OrtValue&) const { } -void CudaTensorConsoleDumper::Print(const char*, int, bool) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const std::string&, bool) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span&) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span&) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const float*, gsl::span&) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const half*, gsl::span&) const { -} - -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span&) const { -} +#define CUDA_DUMPER_PRINT_TYPE(dtype) \ + void CudaTensorConsoleDumper::Print(const char*, const dtype*, int, int) const { \ + } \ + void CudaTensorConsoleDumper::Print(const char*, const dtype*, int, int, int) const { \ + } \ + void CudaTensorConsoleDumper::Print(const char*, const dtype*, int, int, int, int) const { \ + } \ + void CudaTensorConsoleDumper::Print(const char*, const dtype*, gsl::span&) const { \ + } -void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, gsl::span&) const { -} +CUDA_DUMPER_PRINT_TYPE(int8_t) +CUDA_DUMPER_PRINT_TYPE(uint8_t) +CUDA_DUMPER_PRINT_TYPE(int32_t) +CUDA_DUMPER_PRINT_TYPE(int64_t) +CUDA_DUMPER_PRINT_TYPE(float) +CUDA_DUMPER_PRINT_TYPE(MLFloat16) +CUDA_DUMPER_PRINT_TYPE(BFloat16) +CUDA_DUMPER_PRINT_TYPE(UInt4x2) +CUDA_DUMPER_PRINT_TYPE(Int4x2) +CUDA_DUMPER_PRINT_TYPE(half) +#undef DUMPER_PRINT_TYPE #endif diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 631421b1623be..ec034bc15341e 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -16,44 +16,28 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { CudaTensorConsoleDumper(); virtual ~CudaTensorConsoleDumper() {} - void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; - - void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; - void 
Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const int32_t* tensor, gsl::span& dims) const override; - - void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const int64_t* tensor, gsl::span& dims) const override; - - void Print(const char* name, const float* tensor, int dim0, int dim1) const override; - void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const float* tensor, gsl::span& dims) const override; - - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const override; - void Print(const char* name, const MLFloat16* tensor, gsl::span& dims) const override; - - void Print(const char* name, const half* tensor, int dim0, int dim1) const; - void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; - void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const half* tensor, gsl::span& dims) const; - - void Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const; - void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const; - void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const BFloat16* tensor, gsl::span& dims) const; - void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; - void Print(const char* name, int index, bool end_line) const override; - void Print(const char* name, const std::string& value, bool end_line) const override; - void Print(const std::string& value) const override; + +#define CUDA_DUMPER_PRINT_TYPE(dtype) \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1) const; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const; \ + void Print(const char* name, const dtype* tensor, gsl::span& dims) const; + + CUDA_DUMPER_PRINT_TYPE(int8_t) + CUDA_DUMPER_PRINT_TYPE(uint8_t) + CUDA_DUMPER_PRINT_TYPE(int32_t) + CUDA_DUMPER_PRINT_TYPE(int64_t) + CUDA_DUMPER_PRINT_TYPE(float) + CUDA_DUMPER_PRINT_TYPE(MLFloat16) + CUDA_DUMPER_PRINT_TYPE(BFloat16) + CUDA_DUMPER_PRINT_TYPE(UInt4x2) + CUDA_DUMPER_PRINT_TYPE(Int4x2) + CUDA_DUMPER_PRINT_TYPE(half) + +#undef CUDA_DUMPER_PRINT_TYPE }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 65ecdff44acd6..c384b216f049a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -49,7 +49,6 @@ fn mm_read_zero(row : u32, col : 
u32, r_dim: u32, c_dim: u32) -> output_element_ ss << "const default_zero_point = " << (nbits == 4 ? 8 : 128) << ";\n"; ss << R"( fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t { - // The default zero point is 8. return output_element_t(default_zero_point); } )"; @@ -433,6 +432,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context } // zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits. + // For bits==4, this is counted by elements of uint4. Need add 1 if not divisible by 2. uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1; // WideTileProgram diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 91961bf22ce1e..00ff896bf6749 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -3,9 +3,12 @@ #include "core/common/cpuid_info.h" #include "core/common/logging/logging.h" #include "core/common/logging/severity.h" +#include "core/platform/check_intel.h" #ifdef __linux__ - +#if (defined(_M_AMD64) || defined(__x86_64__)) && !defined(__ANDROID__) +#include +#endif #include #include #if !defined(__NR_getcpu) @@ -133,6 +136,17 @@ void CPUIDInfo::X86Init() { // avx512_skylake = avx512f | avx512vl | avx512cd | avx512bw | avx512dq has_avx512_skylake_ = has_avx512 && (data[1] & ((1 << 16) | (1 << 17) | (1 << 28) | (1 << 30) | (1 << 31))); is_hybrid_ = (data[3] & (1 << 15)); + // Check for TPAUSE + CheckIntelResult check_intel = CheckIntel(); + if (check_intel.is_intel) { +#ifdef __linux__ +#if !defined(__ANDROID__) + has_tpause_ = __builtin_cpu_supports("waitpkg") != 0; +#endif +#else + has_tpause_ = (data[2] & (1 << 5)) != 0; +#endif + } if (max_SubLeaves >= 1) { GetCPUID(7, 1, data); has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5)); diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index b820fa2ab1af7..9c67ebbffa260 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -33,6 +33,7 @@ class CPUIDInfo { bool HasSSE3() const { return has_sse3_; } bool HasSSE4_1() const { return has_sse4_1_; } bool IsHybrid() const { return is_hybrid_; } + bool HasTPAUSE() const { return has_tpause_; } // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } @@ -112,6 +113,7 @@ class CPUIDInfo { bool has_sse3_{false}; bool has_sse4_1_{false}; bool is_hybrid_{false}; + bool has_tpause_{false}; std::vector core_uarchs_; // micro-arch of each core diff --git a/onnxruntime/core/common/spin_pause.cc b/onnxruntime/core/common/spin_pause.cc new file mode 100644 index 0000000000000..9bada0841c162 --- /dev/null +++ b/onnxruntime/core/common/spin_pause.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
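For context on the detection above: WAITPKG (the feature that provides TPAUSE/UMWAIT) is reported in CPUID leaf 7, sub-leaf 0, ECX bit 5, which is the bit tested in the non-Linux branch. A tiny decode sketch, taking a raw ECX value as input (obtaining the CPUID output itself is platform-specific and not shown):

```python
WAITPKG_BIT = 1 << 5   # CPUID.(EAX=7, ECX=0):ECX[5]

def has_waitpkg(ecx_leaf7: int) -> bool:
    """Decode the WAITPKG feature flag from the ECX register of CPUID leaf 7, sub-leaf 0."""
    return (ecx_leaf7 & WAITPKG_BIT) != 0

print(has_waitpkg(0x20))  # True  -- bit 5 set
print(has_waitpkg(0x00))  # False
```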
+ +#include "core/common/spin_pause.h" + +#if defined(_M_AMD64) +#include +#endif + +#if defined(__x86_64__) +#include +#endif + +#if defined(_M_AMD64) || defined(__x86_64__) +#include "core/common/cpuid_info.h" +#if defined(__linux__) +#include +#include +#endif +#endif + +namespace onnxruntime { +namespace concurrency { + +// Intrinsic to use in spin-loops +void SpinPause() { +#if (defined(_M_AMD64) || defined(__x86_64__)) && \ + !defined(__ANDROID__) && \ + !defined(__APPLE__) + + static const bool has_tpause = CPUIDInfo::GetCPUIDInfo().HasTPAUSE(); + static constexpr uint64_t tpause_spin_delay_cycles = 1000; + if (has_tpause) { +#if defined(_WIN32) + _tpause(0x0, __rdtsc() + tpause_spin_delay_cycles); +#elif defined(__linux__) + __builtin_ia32_tpause(0x0, __rdtsc() + tpause_spin_delay_cycles); +#else + _mm_pause(); +#endif + } else { + _mm_pause(); + } +#endif +} + +} // namespace concurrency +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/config_options.h b/onnxruntime/core/framework/config_options.h index 028220d15fc8a..1c356d8cfca56 100644 --- a/onnxruntime/core/framework/config_options.h +++ b/onnxruntime/core/framework/config_options.h @@ -18,7 +18,7 @@ struct ConfigOptions { // Maximum key/value string lengths specified in // core/session/onnxruntime_session_options_config_keys.h static constexpr size_t kMaxKeyLength = 1024; - static constexpr size_t kMaxValueLength = 2048; + static constexpr size_t kMaxValueLength = 4096; std::unordered_map configurations; diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 8ed5eeaa8d44f..9a2991ab02730 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -767,10 +767,10 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_path, const std::filesystem::path& model_path, std::filesystem::path& context_cache_path, - bool allow_overwrite_output_model = false) { + bool error_if_output_file_exists = true) { if (!ep_context_path.empty()) { context_cache_path = ep_context_path; - if (!context_cache_path.has_filename()) { + if (!(context_cache_path.has_filename() && context_cache_path.extension() != "")) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "context_file_path should not point to a folder."); } } else if (!model_path.empty()) { @@ -784,9 +784,9 @@ static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Both ep_context_path and model_path are empty."); } - if (std::filesystem::exists(context_cache_path) && !allow_overwrite_output_model) { + if (std::filesystem::exists(context_cache_path) && error_if_output_file_exists) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to generate EP context model since the file '", - context_cache_path, "' exist already. Please remove the EP context model if you want to re-generate it."); + context_cache_path, "' exists already. Please remove the EP context model if you want to re-generate it."); } return Status::OK(); @@ -803,16 +803,25 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } if (all_ep_context_nodes.size() < 1) { - ORT_RETURN_IF(ep_context_gen_options.error_if_no_compiled_nodes, - "Compiled model does not contain any EPContext nodes. 
" - "Check that the session EPs support compilation and can execute at least one model subgraph."); - - LOGS(logger, WARNING) << "Compiled model does not contain any EPContext nodes. " - "Either the session EPs do not support compilation or " - "no subgraphs were able to be compiled."; + auto action_if_no_compiled_nodes = ep_context_gen_options.action_if_no_compiled_nodes; + + ORT_RETURN_IF(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError, + "Unable to compile any nodes. Check that the session EPs support compilation and can execute " + "at least one subgraph in the model."); + + if (action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kDontGenerateModel) { + LOGS(logger, WARNING) << "Unable to compile any nodes. ONNX Runtime will not generate a compiled model. " + "Either the session EPs do not support compilation or the model is already compiled."; + // Note: this path is only taken if a model is compiled with the original compilation approach that uses + // session options configs only. The explicit compile API instead only chooses between + // kReturnError and kGenerateModel. + return Status::OK(); + } - // we continue on to generate the compiled model which may benefit from L1 optimizations even if there are not - // EPContext nodes. + // Assert so that this is caught in a test in DEBUG builds (in case a new enum value is added) + assert(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel); + LOGS(logger, INFO) << "Unable to compile any nodes but will still generate an output model. " + "Either the session EPs do not support compilation or the model is already compiled."; } auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { @@ -833,9 +842,21 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(), context_cache_path, - ep_context_gen_options.overwrite_existing_output_file)); + ep_context_gen_options.error_if_output_file_exists)); } + // Utility function to detect a fused node with an unsupported domain. + // Ex: when compiling an already compiled model, an EPContext node in the input model would be wrapped + // into a fused node with a domain like "QNN". Such fused nodes do not pass ONNX correctness checks, so + // we should detect them here and return a better error message. Otherwise, an ORT_INVALID_GRAPH error is raised + // with a confusing error message *after* the invalid model has been saved/generated. + // Note: This only applies to the explicit compile API. The original compilation approach (via session options), + // early exits above (without error) if the model is already compiled. 
+ auto is_invalid_fused_node = [&graph](const Node& node) { + const std::unordered_map& supported_domains = graph.DomainToVersionMap(); + return (node.NodeType() == Node::Type::Fused) && (supported_domains.find(node.Domain()) == supported_domains.end()); + }; + Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), graph.GetModel().ModelPath(), // use source model path so that external initializers can find the data file path IOnnxRuntimeOpSchemaRegistryList{graph.GetSchemaRegistry()}, @@ -872,6 +893,9 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers // Use EpContext node created by the EPs if name matched, otherwise use node from original model if (ep_context_node.first) { ep_graph.AddNode(*ep_context_node.second); + } else if (is_invalid_fused_node(node)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Encountered an invalid node while compiling a model. ", + "Please ensure the input model is not already compiled."); } else { ep_graph.AddNode(node); } @@ -1216,7 +1240,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::filesystem::path context_cache_path; ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(), context_cache_path, - ep_context_gen_options.overwrite_existing_output_file)); + ep_context_gen_options.error_if_output_file_exists)); } // We use this only if Resource Aware Partitioning is enabled for any of the EPs diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 89a43c4f71ee6..b95b38d007fbb 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -70,15 +70,42 @@ struct FreeDimensionOverride { using CheckLoadCancellationFn = std::function; +/// +/// Options that configure the generation of a compiled model (i.e., a model with EPContext nodes). +/// There are two ways to compile a model: +/// 1. By specifying the correct session option configurations and creating an inference session. +/// The compiled model is generated as a side-effect of session creation. +/// 2. Using an explicit compile API (see OrtCompileApi struct in onnxruntime_c_api.h). +/// +/// The default values in this struct are set to match the current/default behavior of approach 1 to maintain +/// compatibility with the older way of compiling. The explicit compile API overrides some of these values to +/// provide its own defaults (see core/session/model_compilation_options.h/cc). +/// struct EpContextModelGenerationOptions { + // Action to take if the output model does not have compiled (EPContext) nodes. + enum class ActionIfNoCompiledNodes { + // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior + // to maintain compatibility. The explicit compile API does *not* use this action. + kDontGenerateModel = 0, + + // Generate an output model even if it doesn't have compiled nodes. + // The explicit Compile API defaults to this value. + kGenerateModel, + + // Return an error if the model does not have compiled nodes. + // The explicit Compile API can be configured to this value. + kReturnError, + }; + EpContextModelGenerationOptions() = default; // Initializes from string key/value pairs in session config options. + // This initializes this struct from options set via the older compiling approach #1 above. 
explicit EpContextModelGenerationOptions(const ConfigOptions& config_options); bool enable = false; - bool overwrite_existing_output_file = false; - bool error_if_no_compiled_nodes = false; + bool error_if_output_file_exists = true; + ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel; bool embed_ep_context_in_model = false; std::string output_model_file_path; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7eea7d218e278..238dd8d4573de 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1669,7 +1669,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "with shape (batch_size, sequence_length, hidden_size) or (token_count, hidden_size).", "T", OpSchema::Optional) - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index f9f7be60a9bd6..96a1ad91a7f17 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3428,39 +3428,33 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, except t }); static const char* MatMulNBits_ver1_doc = R"DOC( -MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: - 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. - 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. - And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. - 3. Input B's scale and zero point are specified by input scales and zero_points. - - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) - For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. - - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. - 4bit example: - |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) - - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. - 3bit example: - |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. - The last uint_8 may have some bits unused. - - -Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] -Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. - - [N * CeilDiv(n_blocks_per_col * bits, 8)] - If zero_points has same type as A, it's not packed and has the same shape as Scales. 
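To make the storage layout concrete, a small sketch of the shape arithmetic implied by the revised documentation that follows (k_blocks, blob_size, and the packed zero-point width); purely illustrative:

```python
from math import ceil

def matmul_nbits_shapes(K: int, N: int, block_size: int, bits: int) -> dict:
    """Tensor shapes implied by the MatMulNBits schema below."""
    k_blocks = ceil(K / block_size)
    blob_size = block_size * bits // 8
    return {
        "B (packed weights, uint8)": (N, k_blocks, blob_size),
        "scales (same type as A)": (N, k_blocks),
        "zero_points (packed uint8)": (N, ceil(k_blocks * bits / 8)),
        "zero_points (unpacked, type of A)": (N, k_blocks),
    }

# Example: a 4096x4096 layer quantized to 4 bits with block_size 128.
for name, shape in matmul_nbits_shapes(K=4096, N=4096, block_size=128, bits=4).items():
    print(f"{name}: {shape}")   # e.g. B -> (4096, 32, 64)
```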
+MatMulNBits performs a matrix multiplication where the right-hand-side matrix (weights) is quantized to N bits. + +It is a fusion of two operations: +1. Linear dequantization of the quantized weights using scale and (optionally) zero-point with formula: + dequantized_weight = (quantized_weight - zero_point) * scale +2. Matrix multiplication between the input matrix A and the dequantized weight matrix. + +The weight matrix is a 2D constant matrix with the input feature count and output feature count specified by attributes 'K' and 'N'. +It is quantized block-wise along the K dimension with a block size specified by the 'block_size' attribute. +The block size must be a power of 2 and not smaller than 16 (e.g., 16, 32, 64, 128). Each block has its own scale and zero-point. +The quantization is performed using a bit-width specified by the 'bits' attribute, which can take values from 2 to 8. + +The quantized weights are stored in a bit-packed format along the K dimension, with each block being represented by a blob of uint8. +For example, for 4 bits, the first 4 bits are stored in the lower 4 bits of a byte, and the second 4 bits are stored in the higher 4 bits of a byte. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) .SetDomain(kMSDomain) .SinceVersion(1) .SetDoc(MatMulNBits_ver1_doc) - .Attr("K", "size of each input feature", AttributeProto::INT) - .Attr("N", "size of each output feature", AttributeProto::INT) - .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) - .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("K", "Input feature dimension of the weight matrix.", AttributeProto::INT) + .Attr("N", "Output feature dimension of the weight matrix.", AttributeProto::INT) + .Attr("bits", "Bit-width used to quantize the weights (valid range: 2~8)", AttributeProto::INT, static_cast(4)) + .Attr("block_size", + "Size of each quantization block along the K (input feature) dimension. " + "Must be a power of two and ≥ 16 (e.g., 16, 32, 64, 128).", + AttributeProto::INT) .Attr("accuracy_level", "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) " "(default unset). It is used to control how input A is quantized or downcast internally while " @@ -3468,16 +3462,27 @@ Input zero_points is stored as uint8_t or same as type(A). It has the same packi "computation. 4 means input A can be quantized with the same block_size to int8 internally from " "type T1.", AttributeProto::INT, static_cast(0)) - .Input(0, "A", "The input tensor, not quantized", "T1") - .Input(1, "B", "1 or 2 dimensional data blob", "T2") - .Input(2, "scales", "quantization scale", "T1") - .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional) - .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) + .Input(0, "A", "The input tensor, not quantized.", "T1") + .Input(1, "B", + "Packed uint8 tensor of shape (N, k_blocks, blob_size), " + "where k_blocks = ceil(K / block_size) and blob_size = (block_size * bits / 8). " + "The quantized weights are stored in a bit-packed format along the K dimension, packed within each block_size.", + "T2") + .Input(2, "scales", "Per-block scaling factors for dequantization with shape (N, k_blocks) and same data type as input A.", "T1") + .Input(3, "zero_points", + "Per-block zero point for dequantization. 
It can be either packed or unpacked: " + "Packed (uint8) format has shape (N, ceil(k_blocks * bits / 8)), and it uses same bit-packing method as Input B. " + "Unpacked (same type as A) format has shape (N, k_blocks). " + "If not provided, a default zero point is used: 2^(bits - 1) (e.g., 8 for 4-bit quantization, 128 for 8-bit). ", + "T3", OpSchema::Optional) + .Input(4, "g_idx", "group_idx. This input is deprecated", "T4", OpSchema::Optional) .Input(5, "bias", "Bias to add to result. It should have shape [N].", "T1", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") - .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.") - .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(float16)", "tensor(float)", "tensor(bfloat16)"}, + "Constrain quantized zero point types to uint8 or float tensors.") .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference @@ -3569,9 +3574,9 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 static const char* GatherBlockQuantized_ver1_doc = R"DOC( GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (https://github.com/onnx/onnx/blob/main/docs/Operators.md#gather) with differences: 1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`. - `block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, .. + `block_size` must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ... 2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants. - If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8. + If `zero_points` is not provided, the default value is 0 for int4/uint4, or 2^(bits-1) for uint8. 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. diff --git a/onnxruntime/core/optimizer/bias_softmax_fusion.cc b/onnxruntime/core/optimizer/bias_softmax_fusion.cc index bcbb70ba8fac5..2bbc70db16cde 100644 --- a/onnxruntime/core/optimizer/bias_softmax_fusion.cc +++ b/onnxruntime/core/optimizer/bias_softmax_fusion.cc @@ -135,7 +135,7 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node new_axis = (int)HandleNegativeAxis(axis, rank); // The axis attribute for Softmax in OpSet-11 and OpSet-13 are different. - // Details in function documentatin. + // Details in function documentation. 
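To make the packing arithmetic in the new B and zero_points descriptions concrete, here is a small stand-alone sketch (illustrative only, not part of this change; PackRow4Bit is a made-up name) that computes k_blocks and blob_size and packs one row of 4-bit values with the first element of each pair in the low nibble:

    #include <cstdint>
    #include <vector>

    // Illustrative sketch of the 4-bit packing described for input B above
    // (one row of K quantized values, each already in [0, 15]).
    std::vector<uint8_t> PackRow4Bit(const std::vector<uint8_t>& quantized_row, int64_t block_size) {
      const int64_t bits = 4;
      const int64_t K = static_cast<int64_t>(quantized_row.size());
      const int64_t k_blocks = (K + block_size - 1) / block_size;  // ceil(K / block_size)
      const int64_t blob_size = block_size * bits / 8;             // bytes per block
      std::vector<uint8_t> packed(k_blocks * blob_size, 0);
      for (int64_t k = 0; k < K; ++k) {
        const int64_t block = k / block_size;
        const int64_t offset_in_block = k % block_size;
        uint8_t& byte = packed[block * blob_size + offset_in_block / 2];
        const uint8_t v = quantized_row[k] & 0x0F;
        // First value of each pair goes to the low nibble, the second to the high nibble.
        byte |= (offset_in_block % 2 == 0) ? v : static_cast<uint8_t>(v << 4);
      }
      return packed;
    }

For K = 96 and block_size = 32 this yields k_blocks = 3 and blob_size = 16, i.e. the (N, 3, 16) layout the schema text describes per output column.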
 if (is_since_opset_13 && new_axis != rank - 1) return false;

 int singlebatch_rank = rank - new_axis;
diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
index 93c7efc9ca167..ac128011c0b9f 100644
--- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
+++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
@@ -3308,7 +3308,7 @@ const std::unordered_set& GetLayoutSensitiveOps() {
     "BatchNormalization", "InstanceNormalization",

     // convolutions
-    "Conv", "QLinearConv", "ConvTranspose",
+    "Conv", "ConvInteger", "QLinearConv", "ConvTranspose",

     // pooling
     "AveragePool", "LpPool", "MaxPool", "MaxUnpool",
diff --git a/onnxruntime/core/platform/check_intel.cc b/onnxruntime/core/platform/check_intel.cc
new file mode 100644
index 0000000000000..d773ae2d2be2f
--- /dev/null
+++ b/onnxruntime/core/platform/check_intel.cc
@@ -0,0 +1,97 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/platform/check_intel.h"
+
+#if (defined(_M_AMD64) || defined(__x86_64__))
+#if defined(__linux__)
+#include <cpuid.h>
+#elif defined(_WIN32)
+#include <intrin.h>
+#endif
+#endif
+
+namespace onnxruntime {
+
+CheckIntelResult CheckIntel() {
+  CheckIntelResult intel_check = {false, false};
+  bool is_intel = false;
+  bool is_intel_specified_platform = false;
+
+#if (defined(_M_AMD64) || defined(__x86_64__))
+#if defined(_WIN32)
+  constexpr unsigned int kVendorID_Intel[] = {0x756e6547, 0x6c65746e, 0x49656e69};  // "GenuntelineI"
+  constexpr unsigned int kVendorID_IntelSpecifiedPlatformIDs[] = {
+      // ExtendedModel, ExtendedFamily, Family Code, and Model Number
+      0xa06a,  // MTL
+      0xc065,  // ARL-H
+      0xb065   // ARL-U
+  };
+
+  int regs_leaf0[4];
+  int regs_leaf1[4];
+  __cpuid(regs_leaf0, 0);
+  __cpuid(regs_leaf1, 0x1);
+
+  is_intel =
+      (kVendorID_Intel[0] == static_cast<unsigned int>(regs_leaf0[1])) &&
+      (kVendorID_Intel[1] == static_cast<unsigned int>(regs_leaf0[2])) &&
+      (kVendorID_Intel[2] == static_cast<unsigned int>(regs_leaf0[3]));
+
+  if (!is_intel) {
+    return intel_check;  // if not an Intel CPU, return early
+  }
+
+  for (auto intel_specified_platform : kVendorID_IntelSpecifiedPlatformIDs) {
+    if ((static_cast<unsigned int>(regs_leaf1[0]) >> 4) == intel_specified_platform) {
+      is_intel_specified_platform = true;
+      break;
+    }
+  }
+
+#elif defined(__linux__)
+  constexpr unsigned int kVendorID_Intel[] = {0x756e6547, 0x6c65746e, 0x49656e69};  // "GenuntelineI"
+  unsigned int regs[4] = {0};
+  __get_cpuid(0, &regs[0], &regs[1], &regs[2], &regs[3]);
+
+  is_intel = (regs[1] == kVendorID_Intel[0] &&
+              regs[2] == kVendorID_Intel[1] &&
+              regs[3] == kVendorID_Intel[2]);
+  if (!is_intel) {
+    return intel_check;  // if not an Intel CPU, return early
+  }
+
+  __get_cpuid(1, &regs[0], &regs[1], &regs[2], &regs[3]);
+
+  unsigned int base_family = (regs[0] >> 8) & 0xF;
+  unsigned int base_model = (regs[0] >> 4) & 0xF;
+  unsigned int extended_model = (regs[0] >> 16) & 0xF;
+
+  unsigned int model =
+      (base_family == 0x6 || base_family == 0xF)
+          ?
(base_model + (extended_model << 4)) + : base_model; + + constexpr unsigned int kVendorID_IntelSpecifiedPlatformIDs[] = { + // ExtendedModel, ExtendedFamily, Family Code, and Model Number + 170, // MTL (0xAA) + 197, // ARL-H (0xC5) + 198 // ARL-U (0xC6) + }; + + for (auto id : kVendorID_IntelSpecifiedPlatformIDs) { + if (model == id) { + is_intel_specified_platform = true; + break; + } + } +#endif //__linux__ +#endif // (_M_AMD64) || (__x86_64__) + + intel_check.is_intel = is_intel; + intel_check.is_intel_specified_platform = is_intel_specified_platform; + + return intel_check; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/check_intel.h b/onnxruntime/core/platform/check_intel.h new file mode 100644 index 0000000000000..1b82940489171 --- /dev/null +++ b/onnxruntime/core/platform/check_intel.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +typedef struct { + bool is_intel; + bool is_intel_specified_platform; +} CheckIntelResult; + +CheckIntelResult CheckIntel(); +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc index 7464ab4c57d01..40a2fb780878c 100644 --- a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc @@ -3,6 +3,7 @@ #include "hardware_core_enumerator.h" #include "core/platform/windows/env.h" +#include "core/platform/check_intel.h" #include #include #include @@ -85,30 +86,11 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. 
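As a quick sanity check of the family/model decoding in CheckIntel() above (illustrative only, not part of the change): a hypothetical Meteor Lake leaf-1 EAX of 0x000A06A0 has base family 6, base model 0xA and extended model 0xA, so the effective model is 0xA + (0xA << 4) = 0xAA = 170, which matches the MTL entry in the Linux table and the 0xa06a value compared after the >> 4 shift on Windows.

    #include <cstdio>

    // Sketch of the decoding used in CheckIntel() (assumed EAX layout:
    // base model = bits 4-7, base family = bits 8-11, extended model = bits 16-19).
    unsigned EffectiveModel(unsigned eax) {
      unsigned base_family = (eax >> 8) & 0xF;
      unsigned base_model = (eax >> 4) & 0xF;
      unsigned extended_model = (eax >> 16) & 0xF;
      return (base_family == 0x6 || base_family == 0xF) ? base_model + (extended_model << 4) : base_model;
    }

    int main() {
      printf("0x%X\n", EffectiveModel(0x000A06A0));  // prints 0xAA (170), the MTL entry above
    }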
auto cores = GetCoreInfo(); #if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) - const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" - bool isIntelSpecifiedPlatform = false; - const int kVendorID_IntelSpecifiedPlatformIDs[3] = { - // ExtendedModel, ExtendedFamily, Family Code, and Model Number - 0xa06a, // MTL - 0xc065, // ARL-H - 0xb065 // ARL-U - }; - - int regs_leaf0[4]; - int regs_leaf1[4]; - __cpuid(regs_leaf0, 0); - __cpuid(regs_leaf1, 0x1); - - auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && (kVendorID_Intel[2] == regs_leaf0[3]); - - for (int intelSpecifiedPlatform : kVendorID_IntelSpecifiedPlatformIDs) { - if ((regs_leaf1[0] >> 4) == intelSpecifiedPlatform) { - isIntelSpecifiedPlatform = true; - } - } - if (isIntel) { - if (isIntelSpecifiedPlatform) { + CheckIntelResult check_intel = CheckIntel(); + + if (check_intel.is_intel) { + if (check_intel.is_intel_specified_platform) { // We want to exclude cores without an LLC return cores.LLCCores; } else { diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 0775e19c5654b..2385bae65d491 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -183,6 +183,7 @@ void WindowsTelemetry::LogProcessInfo() const { // Telemetry info TraceLoggingUInt8(0, "schemaVersion"), TraceLoggingString(ORT_VERSION, "runtimeVersion"), + TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"), TraceLoggingBool(isRedist, "isRedist")); process_info_logged = true; diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h index 2b2b726e62c79..63e2ab8e9cb9b 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h @@ -15,30 +15,30 @@ std::conditional_t CudaCall( ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, const char* file, const int line); -#define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) -#define CUBLAS_CALL(expr) (CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDA_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) +#define CUBLAS_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUSPARSE_CALL(expr) (CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CURAND_CALL(expr) (CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUDNN_CALL(expr) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUDNN_CALL2(expr, m) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, m, __FILE__, __LINE__)) +#define CUSPARSE_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CURAND_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDNN_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDNN_CALL2(expr, m) (::onnxruntime::CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, m, __FILE__, __LINE__)) -#define CUFFT_CALL(expr) (CudaCall((expr), 
#expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) +#define CUFFT_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) -#define CUDA_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) -#define CUBLAS_CALL_THROW(expr) (CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDA_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) +#define CUBLAS_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUSPARSE_CALL_THROW(expr) (CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CURAND_CALL_THROW(expr) (CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUSPARSE_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CURAND_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) // the cudnn configuration call that doesn't need set stream -#define CUDNN_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDNN_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUFFT_CALL_THROW(expr) (CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) +#define CUFFT_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) #ifdef ORT_USE_NCCL -#define NCCL_CALL(expr) (CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) -#define NCCL_CALL_THROW(expr) (CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) +#define NCCL_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) +#define NCCL_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) #endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc index 4e8179d86fd73..a44ab93ccca8b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. 
#include "nv_allocator.h" diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 6a7ff63dbc0ed..0fb44fe4eda85 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -5,6 +5,7 @@ #include #include #include "core/providers/shared_library/provider_api.h" +#include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/common.h" @@ -20,6 +21,7 @@ #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" +#include "core/common/parse_string.h" #include #include #include @@ -743,6 +745,7 @@ Status BindContextInput(Ort::KernelContext& ctx, switch (tensor_type) { CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -829,6 +832,7 @@ Status BindContextOutput(Ort::KernelContext& ctx, switch (output_type) { CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -892,6 +896,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx, switch (output_type) { CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -1140,6 +1145,7 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) } cuda_graph_enable_ = info.cuda_graph_enable; + multi_profile_enable_ = info.multi_profile_enable; op_types_to_exclude_ = info.op_types_to_exclude; // Validate setting @@ -1321,7 +1327,12 @@ std::unique_ptr NvExecutionProvider::GetDataTransfer() const { return std::make_unique(); } -Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { + if (multi_profile_enable_ == true) { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(nv::run_option_names::kProfileIndex); + TryParseStringWithClassicLocale(*graph_annotation_str, nv_profile_index_); + } return Status::OK(); } @@ -1989,10 +2000,6 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, if (exclude_ops_set.find(node->OpType()) != exclude_ops_set.end()) { supported_node = false; } - // Exclude contrib ops - if (node->Domain() == kMSDomain) { - supported_node = false; - } if (supported_node) { if (new_subgraph) { @@ -2687,6 +2694,11 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const 
GraphViewer& gr Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); cudaStream_t stream = static_cast(cuda_stream); + if (multi_profile_enable_ == true) { + if (!trt_context->setOptimizationProfileAsync(nv_profile_index_, stream)) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP select an optimization profile for the current context failed"); + } + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity // Prepare cache name diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 76044b4fc2017..6c5e1a1f0a8d3 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once @@ -261,8 +262,10 @@ class NvExecutionProvider : public IExecutionProvider { int (*engine_encryption_)(const char*, char*, size_t) = nullptr; bool detailed_build_log_ = false; bool cuda_graph_enable_ = false; + bool multi_profile_enable_ = false; std::string cache_prefix_; std::string op_types_to_exclude_; + int nv_profile_index_ = 0; // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH int32_t trt_version_; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc index 0806ae3638036..c8df7c9437adf 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h index 897c2ce0e0b98..81c0d49239ec8 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc index cd50f1e6b2d48..8728558006fc5 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. 
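For context on how the new multi-profile path is driven at run time: OnRunStart above reads the profile index from the RunOptions config and the compute function applies it with setOptimizationProfileAsync before execution. A hedged usage sketch follows; the literal config key is whatever nv::run_option_names::kProfileIndex is defined as in nv_provider_options.h, so "nv.profile_index" below is only a placeholder:

    #include "onnxruntime_cxx_api.h"

    // Illustrative only: ask the NV TensorRT RTX EP to use optimization profile 1 for this Run call.
    std::vector<Ort::Value> RunWithProfile(Ort::Session& session,
                                           const char* const* input_names, const Ort::Value* inputs, size_t input_count,
                                           const char* const* output_names, size_t output_count) {
      Ort::RunOptions run_options;
      run_options.AddConfigEntry("nv.profile_index", "1");  // placeholder for nv::run_option_names::kProfileIndex
      return session.Run(run_options, input_names, inputs, input_count, output_names, output_count);
    }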
#include "core/providers/shared_library/provider_api.h" diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index f5ba66746c3c4..78f2723a20118 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h" @@ -46,6 +47,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi .AddAssignmentToReference(nv::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) .AddAssignmentToReference(nv::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) .AddAssignmentToReference(nv::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) + .AddAssignmentToReference(nv::provider_option_names::kMultiProfileEnable, info.multi_profile_enable) .AddValueParser( nv::provider_option_names::kONNXBytestream, [&onnx_bytestream](const std::string& value_str) -> Status { diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index 626039e5ef7c8..e70e70bf05eb9 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once @@ -42,6 +43,7 @@ struct NvExecutionProviderInfo { std::string profile_max_shapes{""}; std::string profile_opt_shapes{""}; bool cuda_graph_enable{false}; + bool multi_profile_enable{false}; bool dump_ep_context_model{false}; std::string ep_context_file_path{""}; int ep_context_embed_mode{0}; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h index 046010deedf62..22e5eea6924de 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_includes.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_includes.h index 047f325f49b70..a4e3777008560 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_includes.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_includes.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. 
#pragma once diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index 1f4eed7db7203..0fc3e5443bc28 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.h index 928874475735f..5672c5dda632e 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "onnxruntime_c_api.h" diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h index 616f5f1fbe754..6b2e516211257 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #pragma once diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 25decd8f2ce8f..21d964b0c341f 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h index ccd06750692fc..f0a05c42414e5 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. 
#pragma once diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index efb4afcb88c85..e4d768093aa37 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -181,6 +181,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { { CreateMatMulOpBuilder("MatMul", *this); } + + { + CreateLSTMOpBuilder("LSTM", *this); + } } const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index aa1039f857f8e..c1cc61ad19341 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -102,5 +102,7 @@ void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistratio void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + +void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index d7432f35e61cf..74518e2fcb7a2 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -138,6 +138,10 @@ Status BaseOpBuilder::ProcessInt64Tensors(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } for (size_t i = 0; i < input_names.size(); i++) { + if (input_names[i].size() == 0) { + // For optional inputs, the input_name is empty + continue; + } auto& input_tensorwrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[i]); // Insert cast to int32 if input dtype is int64 if (input_tensorwrapper.GetTensorDataType() == QNN_DATATYPE_INT_64) { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 5474db0590f92..5b3fa6ed3b950 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -107,6 +107,35 @@ class BaseOpBuilder : public IOpBuilder { const logging::Logger& logger, std::vector& input_names) const ORT_MUST_USE_RESULT; + template + Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, + const NodeIndex& node_index, + const std::string& node_name, + const T& scalar, + const std::string& qnn_scalar_param_name, + std::vector& param_names) const { + Qnn_Scalar_t qnn_scalar = QNN_SCALAR_INIT; + if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; + qnn_scalar.floatValue = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_UINT_32; + qnn_scalar.uint32Value = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_INT_32; + qnn_scalar.int32Value = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_BOOL_8; + qnn_scalar.bool8Value = static_cast(scalar); + } else { + ORT_RETURN_IF(true, "QNN EP: Unsupported scalar dtype"); + } + QnnParamWrapper qnn_param_wrapper(node_index, node_name, qnn_scalar_param_name, 
qnn_scalar); + param_names.push_back(qnn_param_wrapper.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); + return Status::OK(); + } + Status SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, @@ -140,6 +169,7 @@ class BaseOpBuilder : public IOpBuilder { {"Less", QNN_OP_ELEMENT_WISE_LESS}, {"LessOrEqual", QNN_OP_ELEMENT_WISE_LESS_EQUAL}, {"Log", QNN_OP_ELEMENT_WISE_LOG}, + {"LSTM", QNN_OP_LSTM}, {"Max", QNN_OP_ELEMENT_WISE_MAXIMUM}, {"Min", QNN_OP_ELEMENT_WISE_MINIMUM}, {"Neg", QNN_OP_ELEMENT_WISE_NEG}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 23811c200213a..fbf4cbe53a812 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -280,6 +280,50 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, std::move(input_info.quant_param), std::move(actual_shape), std::move(unpacked_tensor)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + + // Workaround that inserts a QNN Convert op before input[1] (converts from quantized uint16 to signed symmetric int16) + // to avoid a QNN validation failure. + // + // QNN graph WITHOUT workaround (fails validation): + // input_0_uint16 ---> Conv ---> output_uint16 + // ^ + // | + // input_1_uint16 -----+ + // + // QNN graph WITH workaround (passes validation): + // input_0_uint16 ----------------------> Conv ---> output_uint16 + // ^ + // | + // input_1_uint16 --> Convert(to int16) --+ + + std::string weight_input_name = input_names.back(); + const auto& weight_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(weight_input_name); + + if (weight_tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_UFIXED_POINT_16) { + const auto& quant_param_wrapper = weight_tensor_wrapper.GetQnnQuantParams(); + const Qnn_QuantizeParams_t& quant_param = quant_param_wrapper.Get(); + const auto& transformed_input1_shape = weight_tensor_wrapper.GetTensorDims(); + + ORT_RETURN_IF_NOT(quant_param_wrapper.IsPerTensor(), + "Conv's INT16 weight inputs only support INT16 per-tensor quantization"); + + // Pop Conv weight. Insert Convert op after Weight + input_names.pop_back(); + const std::string& conv_output_name = node_unit.Outputs()[0].node_arg.Name(); + std::string convert_output_name = weight_input_name + "_convert_" + conv_output_name; + + ORT_RETURN_IF_ERROR(utils::InsertConvertOp(qnn_model_wrapper, + weight_input_name, + convert_output_name, + QNN_DATATYPE_UFIXED_POINT_16, + QNN_DATATYPE_SFIXED_POINT_16, + quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, + transformed_input1_shape, + true, // Symmetric + do_op_validation)); + input_names.push_back(convert_output_name); + } } // diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc new file mode 100644 index 0000000000000..f131d58277038 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc @@ -0,0 +1,807 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
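A note on the AddQnnScalar helper introduced above: it maps the C++ scalar type to a QNN datatype with a run-time std::is_same chain. An equivalent compile-time formulation (sketch only, assuming the QNN SDK's QnnTypes.h is on the include path; MakeQnnScalar is a made-up name) would use if constexpr so unsupported types fail at compile time rather than returning an error status:

    #include <cstdint>
    #include <type_traits>
    #include "QnnTypes.h"  // assumed QNN SDK header providing Qnn_Scalar_t / QNN_DATATYPE_* / QNN_SCALAR_INIT

    // Sketch: same type-to-datatype mapping as AddQnnScalar, resolved at compile time.
    template <typename T>
    Qnn_Scalar_t MakeQnnScalar(T value) {
      Qnn_Scalar_t s = QNN_SCALAR_INIT;
      if constexpr (std::is_same_v<T, float>) {
        s.dataType = QNN_DATATYPE_FLOAT_32;
        s.floatValue = value;
      } else if constexpr (std::is_same_v<T, uint32_t>) {
        s.dataType = QNN_DATATYPE_UINT_32;
        s.uint32Value = value;
      } else if constexpr (std::is_same_v<T, int32_t>) {
        s.dataType = QNN_DATATYPE_INT_32;
        s.int32Value = value;
      } else if constexpr (std::is_same_v<T, bool>) {
        s.dataType = QNN_DATATYPE_BOOL_8;
        s.bool8Value = value;
      } else {
        static_assert(!std::is_same_v<T, T>, "unsupported scalar type for MakeQnnScalar");
      }
      return s;
    }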
+ +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { + +class LSTMOpBuilder : public BaseOpBuilder { + public: + LSTMOpBuilder() : BaseOpBuilder("LSTMOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LSTMOpBuilder); + + protected: + /* + ONNX LSTM inputs: + in[0]: X [seq_length, batch_size, input_size], the input sequences packed + in[1]: W [num_directions, 4*hidden_size, input_size], the weight tensor for the gates. Concatenation of W[iofc] and WB[iofc] + in[2]: R [num_directions, 4*hidden_size, hidden_size], the recurrence weight tensor. Concatenation of R[iofc] and RB[iofc] + + ONNX LSTM optional inputs: + in[3]: B [num_directions, 8*hidden_size], the bias tensor for input gate. Concatenation of [Wb[iofc], Rb[iofc]], and [WBb[iofc], RBb[iofc]] (if bidirectional) + in[4]: sequence_lens + in[5]: initial_h [num_directions, batch_size, hidden_size]. + in[6]: initial_c [num_directions, batch_size, hidden_size]. + in[7]: P [num_directions, 3*hidde_size], the weight tensor for peepholes. Concatenation of P[iof] and PB[iof] + + ONNX LSTM Parameters: + - activation_alpha ---> Not supported by QNN. + - activation_beta ---> Not supported by QNN. + - activations ---> Not supported by QNN. + - clip ---> Not supported by QNN since the clip in ONNX applied to iofc while QNN only apply to c. Refer + https://github.com/microsoft/onnxruntime/blob/v1.21.0/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc + - direction + - hidden_size + - input_forget ---> Not supported by QNN + - layout: The shape format of inputs X, initial_h, initial_c and outputs Y, Y_h, Y_c. + If 0, the following shapes are expected: + X.shape = [seq_length, batch_size, input_size], + Y.shape = [seq_length, num_directions, batch_size, hidden_size], + initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = [num_directions, batch_size, hidden_size]. + If 1, the following shapes are expected: + X.shape = [batch_size, seq_length, input_size], + Y.shape = [batch_size, seq_length, num_directions, hidden_size], + initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = [batch_size, num_directions, hidden_size]. 
+
+  ONNX LSTM optional outputs:
+    out[0]: Y [seq_length, num_directions, batch_size, hidden_size] = stack of out[0] from the per-direction QNN_LSTM nodes
+    out[1]: Y_h [num_directions, batch_size, hidden_size] = stack of out[2] from the per-direction QNN_LSTM nodes
+    out[2]: Y_c [num_directions, batch_size, hidden_size] = stack of out[1] from the per-direction QNN_LSTM nodes
+
+  QNN LSTM inputs:
+    in[0]: x_t: 2D of shape [batch_size, input_size] or
+                3D of shape [time_steps, batch_size, input_size] if time_major
+                            [batch_size, time_steps, input_size] else
+    in[1]: W_xf: input-to-forget weights [num_units, input_size] = ONNX in[1][direction, 2*hidden_size:3*hidden_size, :]
+    in[2]: W_xc: input-to-cell weights [num_units, input_size] = ONNX in[1][direction, 3*hidden_size:4*hidden_size, :]
+    in[3]: W_xo: input-to-output weights [num_units, input_size] = ONNX in[1][direction, 1*hidden_size:2*hidden_size, :]
+    in[4]: W_hf: recurrent-to-forget weights [num_units, output_size] = ONNX in[2][direction, 2*hidden_size:3*hidden_size, :]
+    in[5]: W_hc: recurrent-to-cell weights [num_units, output_size] = ONNX in[2][direction, 3*hidden_size:4*hidden_size, :]
+    in[6]: W_ho: recurrent-to-output weights [num_units, output_size] = ONNX in[2][direction, 1*hidden_size:2*hidden_size, :]
+    in[7]: b_f: forget gate bias [num_units] = ONNX in[3][direction, 2*hidden_size:3*hidden_size] + in[3][direction, 6*hidden_size:7*hidden_size]
+    in[8]: b_c: cell bias [num_units] = ONNX in[3][direction, 3*hidden_size:4*hidden_size] + in[3][direction, 7*hidden_size:8*hidden_size]
+    in[9]: b_o: output gate bias [num_units] = ONNX in[3][direction, 1*hidden_size:2*hidden_size] + in[3][direction, 5*hidden_size:6*hidden_size]
+
+    # optional inputs
+    in[10]: h_t_init: hidden state init [batch_size, output_size] = ONNX in[5][direction]
+    in[11]: c_t_init: cell state init [batch_size, num_units] = ONNX in[6][direction]
+    in[12]: The input layer normalization weights ---> not supported on fp16 yet.
+    in[13]: The forget layer normalization weights ---> not supported on fp16 yet.
+    in[14]: The cell layer normalization weights ---> not supported on fp16 yet.
+    in[15]: The output layer normalization weights ---> not supported on fp16 yet.
+    in[16]: W_xi: input-to-input weights [num_units, input_size] = ONNX in[1][direction, 0*hidden_size:1*hidden_size, :]
+    in[17]: W_hi: recurrent-to-input weights [num_units, output_size] = ONNX in[2][direction, 0*hidden_size:1*hidden_size, :]
+    in[18]: W_ci: cell-to-input weights [num_units] = ONNX in[7][direction, 0*hidden_size:1*hidden_size]
+    in[19]: W_cf: cell-to-forget weights [num_units] = ONNX in[7][direction, 2*hidden_size:3*hidden_size]
+    in[20]: W_co: cell-to-output weights [num_units] = ONNX in[7][direction, 1*hidden_size:2*hidden_size]
+    in[21]: b_i: input gate bias [num_units] = ONNX in[3][direction, 0*hidden_size:1*hidden_size] + in[3][direction, 4*hidden_size:5*hidden_size]
+    in[22]: W_proj: projection weights [output_size, num_units] ---> not used
+    in[23]: b_proj: projection bias [output_size] ---> not used
+    in[24]: reset: Determines if the internal state should be reset ---> not used
+
+  QNN LSTM Parameters:
+    - direction
+    - cell_clip_threshold ---> not used
+    - output_clip_threshold ---> not used
+    - time_major
+    - input_gate_qscale ---> not used since we fallback to fp16.
+    - forget_gate_qscale ---> not used since we fallback to fp16.
+    - cell_gate_qscale ---> not used since we fallback to fp16.
+    - output_gate_qscale ---> not used since we fallback to fp16.
+ - hidden_state_offset ---> not used since we fallback to fp16. + - hidden_state_qscale ---> not used since we fallback to fp16. + + QNN LSTM outputs: + out[0]: h_t 2D of shape [batch_size, output_size] or + 3D of shape [time_steps, batch_size, output_size] if time_major + [batch_size, time_steps, output_size] else + out[1]: c_t [batch_size, num_unit] + out[2]: o_t [batch_size, output_size] + + QNN LSTM optional outputs: + out[3]: input_gate [batch_size, num_unit] ---> not used + out[4]: forget_gate [batch_size, num_unit] ---> not used + out[5]: cell_gate [batch_size, num_unit] ---> not used + out[6]: output_gate [batch_size, num_unit] ---> not used + out[7]: hidden_state [batch_size, output_size] ---> not used + */ + + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + private: + Status AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::string& direction, + const std::vector& input_names, + const logging::Logger& logger, + const bool& do_op_validation, + const bool& is_bidirection, + std::vector& uni_lstm_output_names) const; + Status AddStridedSliceOrReshape(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::string& input_name, + const std::string& output_name, + const std::vector& input_shape, + const std::vector& output_shape, + const std::vector>& ranges, + const uint32_t& begin_mask, + const uint32_t& end_mask, + const uint32_t& shrink_axes, + const uint32_t& new_axes_mask, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input, + bool is_for_output) const; +}; + +Status LSTMOpBuilder::AddStridedSliceOrReshape(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::string& input_name, + const std::string& output_name, + const std::vector& input_shape, + const std::vector& output_shape, + const std::vector>& ranges, + const uint32_t& begin_mask, + const uint32_t& end_mask, + const uint32_t& shrink_axes, + const uint32_t& new_axes_mask, + const Qnn_DataType_t& tensor_data_type, + const QnnQuantParamsWrapper& quantize_param, + bool do_op_validation, + bool is_for_input, + bool is_for_output) const { + if (qnn_model_wrapper.IsQnnTensorWrapperExist(output_name)) { + return Status::OK(); + } + // add strided_slice or reshape + // this is not general condition, only limited to caller in this builder + size_t minSize = std::min(input_shape.size(), output_shape.size()); + if (input_shape[0] == 1 && std::equal(output_shape.rbegin(), output_shape.rbegin() + minSize, input_shape.rbegin())) { + // add Reshape + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(input_name, + output_name, + input_shape, + output_shape, + tensor_data_type, + quantize_param.Copy(), + quantize_param.Copy(), + do_op_validation, + is_for_input, + is_for_output)); + } else { + // add StridedSlice + // inputs + QnnTensorWrapper input_tensorwrapper(input_name, is_for_input ? 
QNN_TENSOR_TYPE_APP_WRITE : QNN_TENSOR_TYPE_NATIVE, + tensor_data_type, quantize_param.Copy(), + std::vector(input_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), + "Failed to add input tensor for inserted StridedSlice or Reshape."); + + // params + const std::string& node_name = output_name; + + // ranges + std::vector ranges_data; + for (size_t i = 0; i < ranges.size(); i++) { + for (size_t j = 0; j < 3; j++) { + ranges_data.emplace_back(SafeInt(ranges[i][j])); + } + } + QnnParamWrapper ranges_param_wrapper(node_unit.Index(), node_name, QNN_OP_STRIDED_SLICE_PARAM_RANGES, {static_cast(ranges.size()), 3}, std::move(ranges_data), true); + std::vector param_names = { + ranges_param_wrapper.GetParamTensorName(), + }; + qnn_model_wrapper.AddParamWrapper(std::move(ranges_param_wrapper)); + + // begin_mask + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, begin_mask, QNN_OP_STRIDED_SLICE_PARAM_BEGIN_MASK, param_names)); + + // end_mask + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, end_mask, QNN_OP_STRIDED_SLICE_PARAM_END_MASK, param_names)); + + // shrink_axes + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, shrink_axes, QNN_OP_STRIDED_SLICE_PARAM_SHRINK_AXES, param_names)); + + // new_axes_mask + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, new_axes_mask, QNN_OP_STRIDED_SLICE_PARAM_NEW_AXES_MASK, param_names)); + + // outputs + QnnTensorWrapper output_tensorwrapper(output_name, + is_for_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE, + tensor_data_type, + quantize_param.Copy(), + std::vector(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), + "Failed to add output tensor for inserted StridedSlice."); + // addNode + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_STRIDED_SLICE, {input_name}, + {output_name}, std::move(param_names), do_op_validation), + "Failed to create manually inserted Qnn StridedSlice node."); + } + + return Status::OK(); +} + +Status LSTMOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(qnn_model_wrapper); + ORT_UNUSED_PARAMETER(node_unit); + ORT_UNUSED_PARAMETER(logger); + if (node_unit.Inputs().size() > 4 && node_unit.Inputs()[4].node_arg.Exists()) { + TensorInfo tensor_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[4], tensor_info)); + + ORT_RETURN_IF_NOT(tensor_info.is_initializer, "QNN EP: dynamic sequence_length is not supported."); + + std::vector sequence_lens_bytes; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*tensor_info.initializer_tensor, sequence_lens_bytes)); + const size_t num_elems = sequence_lens_bytes.size() / sizeof(int32_t); + gsl::span sequence_lens{reinterpret_cast(sequence_lens_bytes.data()), num_elems}; + ORT_RETURN_IF(std::any_of(sequence_lens.begin(), + sequence_lens.end(), + [sequence_lens](int i) { return i != sequence_lens[0]; }), + "QNN EP: Only support LSTM with same sequence length."); + } + + NodeAttrHelper node_helper(node_unit); + const float clip = node_helper.Get("clip", (float)0.0); + ORT_RETURN_IF(clip != 0, + "QNN EP doesn't support non-default clip for LSTM."); + const std::vector activations = node_helper.Get("activations", std::vector{}); + 
ORT_RETURN_IF((activations.size() >= 3 && (activations[0] != "sigmoid" || activations[1] != "tanh" || activations[2] != "tanh")) ||
+                    (activations.size() == 6 && (activations[3] != "sigmoid" || activations[4] != "tanh" || activations[5] != "tanh")),
+                "QNN EP doesn't support non-default activations for LSTM.");
+  // TODO: Add support for layout==1
+  const int64_t layout = node_helper.Get("layout", static_cast<int64_t>(0));
+  ORT_RETURN_IF_NOT(layout == 0,
+                    "QNN EP: Unsupported layout mode %ld for %s.", layout, node_unit.Name().c_str());
+  return Status::OK();
+}
+
+Status LSTMOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
+                                    const NodeUnit& node_unit,
+                                    const logging::Logger& logger,
+                                    std::vector<std::string>& input_names,
+                                    bool do_op_validation) const {
+  ORT_UNUSED_PARAMETER(do_op_validation);
+  const auto& onnx_inputs = node_unit.Inputs();
+  for (size_t i = 0; i < onnx_inputs.size(); i++) {
+    if (onnx_inputs[i].node_arg.Exists()) {
+      ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, onnx_inputs[i], logger, input_names));
+    } else {
+      input_names.emplace_back("");
+    }
+  }
+  return Status::OK();
+}
+
+Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper,
+                                          const NodeUnit& node_unit,
+                                          const std::string& direction,
+                                          const std::vector<std::string>& input_names,
+                                          const logging::Logger& logger,
+                                          const bool& do_op_validation,
+                                          const bool& is_bidirection,
+                                          std::vector<std::string>& uni_lstm_output_names) const {
+  ORT_UNUSED_PARAMETER(logger);
+
+  const auto& onnx_inputs = node_unit.Inputs();
+  const auto& onnx_outputs = node_unit.Outputs();
+  const std::string& node_name = node_unit.Name();
+  std::vector<TensorInfo> input_tensor_infos(onnx_inputs.size());
+  for (size_t i = 0; i < onnx_inputs.size(); i++) {
+    if (onnx_inputs[i].node_arg.Exists()) {
+      ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(onnx_inputs[i], input_tensor_infos[i]));
+    }
+  }
+  // because the three QNN LSTM outputs are mandatory, we should provide their tensor info
+  std::vector<TensorInfo> output_tensor_infos(3);
+  for (size_t i = 0; i < 3; i++) {
+    if (onnx_outputs.size() > i && onnx_outputs[i].node_arg.Exists()) {
+      ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(onnx_outputs[i], output_tensor_infos[i]));
+    } else {
+      output_tensor_infos[i].qnn_data_type = input_tensor_infos[0].qnn_data_type;
+    }
+  }
+
+  NodeAttrHelper node_helper(node_unit);
+  const uint32_t hidden_size = node_helper.Get("hidden_size", 0);
+  const int32_t hidden_size_sign = SafeInt<int32_t>(hidden_size);
+  ORT_RETURN_IF_NOT(hidden_size > 0, "hidden size is not set for LSTM");
+  const int64_t layout = node_helper.Get("layout", static_cast<int64_t>(0));
+
+  const uint32_t input_size = input_tensor_infos[0].shape[2];
+  const uint32_t batch_size = layout == 0 ? input_tensor_infos[0].shape[1] : input_tensor_infos[0].shape[0];
+  const uint32_t seq_length = layout == 0 ? input_tensor_infos[0].shape[0] : input_tensor_infos[0].shape[1];
+  const int32_t direction_idx = input_tensor_infos[1].shape[0] < 2 || direction == "forward" ? 0 : 1;
+
+  // params
+  std::vector<std::string> param_names;
+
+  // direction
+  ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), direction == "forward" ?
QNN_OP_LSTM_DIRECTION_FORWARD : QNN_OP_LSTM_DIRECTION_REVERSE, QNN_OP_LSTM_PARAM_DIRECTION, param_names)); + + // cell_clip_threshold + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_CELL_CLIP_THRESHOLD, param_names)); + + // output_clip_threshold + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_OUTPUT_CLIP_THRESHOLD, param_names)); + + // time_major + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, QNN_OP_LSTM_PARAM_TIME_MAJOR, param_names)); + + // // input_gate_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_INPUT_GATE_QSCALE, param_names)); + + // // forget_gate_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_FORGET_GATE_QSCALE, param_names)); + + // // cell_gate_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_CELL_GATE_QSCALE, param_names)); + + // // output_gate_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_OUTPUT_GATE_QSCALE, param_names)); + + // // hidden_state_offset + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_HIDDEN_STATE_OFFSET, param_names)); + + // // hidden_state_qscale + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), 0.0, QNN_OP_LSTM_PARAM_HIDDEN_STATE_QSCALE, param_names)); + + // Common LSTM cell inputs + const std::string null_tensor_name = "null_tensor"; + QnnTensorWrapper null_tensor_wrapper(null_tensor_name, QNN_TENSOR_TYPE_NULL, QNN_DATATYPE_UNDEFINED, + QnnQuantParamsWrapper(), std::vector{0}); + + qnn_model_wrapper.AddTensorWrapper(std::move(null_tensor_wrapper)); + std::vector qnn_lstm_input_names(24, null_tensor_name); + + // input W + { + // QNN in[1] = ONNX in[1][direction, 2*hidden_size:3*hidden_size, :] + // QNN in[2] = ONNX in[1][direction, 3*hidden_size:4*hidden_size, :] + // QNN in[3] = ONNX in[1][direction, 1*hidden_size:2*hidden_size, :] + // QNN in[16] = ONNX in[1][direction, 0*hidden_size:1*hidden_size, :] + uint32_t begin_mask = 0b000U; + uint32_t end_mask = 0b000U; + uint32_t shrink_axes = 0b001U; + uint32_t new_axes_mask = 0b000U; + std::vector qnn_input_indices = {1, 2, 3, 16}; + std::vector begins = {2, 3, 1, 0}; + std::vector qnn_lstm_weight_name = { + input_names[1] + "_input_to_forget_gate_weight_" + direction, + input_names[1] + "_input_to_cell_gate_weight_" + direction, + input_names[1] + "_input_to_output_gate_weight_" + direction, + input_names[1] + "_input_to_input_gate_weight_" + direction, + }; + for (size_t i = 0; i < 4; i++) { + std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, + {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}, + {0, SafeInt(input_size), 1}}; + std::vector output_shape = {hidden_size, input_size}; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[1], + /*output_name=*/qnn_lstm_weight_name[i], + /*input_shape=*/input_tensor_infos[1].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + 
/*tensor_data_type=*/input_tensor_infos[1].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[1].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_weight_name[i]; + } + } + + // input R + { + // QNN in[4] = ONNX in[2][direction, 2*hidden_size:3*hidden_size, :] + // QNN in[5] = ONNX in[2][direction, 3*hidden_size:4*hidden_size, :] + // QNN in[6] = ONNX in[2][direction, 1*hidden_size:2*hidden_size, :] + // QNN in[17] = ONNX in[2][direction, 0*hidden_size:1*hidden_size, :] + uint32_t begin_mask = 0b000U; + uint32_t end_mask = 0b000U; + uint32_t shrink_axes = 0b001U; + uint32_t new_axes_mask = 0b000U; + std::vector qnn_input_indices = {4, 5, 6, 17}; + std::vector begins = {2, 3, 1, 0}; + std::vector qnn_lstm_weight_name = { + input_names[2] + "_recurrent_to_forget_gate_weight_" + direction, + input_names[2] + "_recurrent_to_cell_gate_weight_" + direction, + input_names[2] + "_recurrent_to_output_gate_weight_" + direction, + input_names[2] + "_recurrent_to_input_gate_weight_" + direction}; + for (size_t i = 0; i < 4; i++) { + std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, + {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}, + {0, hidden_size_sign, 1}}; + std::vector output_shape = {hidden_size, hidden_size}; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[2], + /*output_name=*/qnn_lstm_weight_name[i], + /*input_shape=*/input_tensor_infos[2].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[2].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[2].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_weight_name[i]; + } + } + + // input B + { + // QNN in[7] = ONNX in[3][direction, 2*hidden_size:3*hidden_size] + ONNX in[3][direction, 6*hidden_size:7*hidden_size] + // QNN in[8] = ONNX in[3][direction, 3*hidden_size:4*hidden_size] + ONNX in[3][direction, 7*hidden_size:8*hidden_size] + // QNN in[9] = ONNX in[3][direction, 1*hidden_size:2*hidden_size] + ONNX in[3][direction, 5*hidden_size:6*hidden_size] + // QNN in[21] = ONNX in[3][direction, 0*hidden_size:1*hidden_size] + ONNX in[3][direction, 4*hidden_size:5*hidden_size] + uint32_t begin_mask = 0b00U; + uint32_t end_mask = 0b00U; + uint32_t shrink_axes = 0b01U; + uint32_t new_axes_mask = 0b00U; + std::vector output_shape = {hidden_size}; + std::vector qnn_lstm_bias_name = { + node_name + "_forget_gate_bias_" + direction, + node_name + "_cell_gate_bias_" + direction, + node_name + "_output_gate_bias_" + direction, + node_name + "_input_gate_bias_" + direction}; + std::vector qnn_input_indices = {7, 8, 9, 21}; + if (onnx_inputs.size() > 3 && onnx_inputs[3].node_arg.Exists()) { + std::vector begins = {2, 3, 1, 0, 6, 7, 5, 4}; + std::vector onnx_lstm_bias_name = { + input_names[3] + "_input_to_forget_gate_bias_" + direction, + input_names[3] + "_input_to_cell_gate_bias_" + direction, + input_names[3] + "_input_to_output_gate_bias_" + direction, + input_names[3] + "_input_to_input_gate_bias_" + direction, + input_names[3] + "_recurrent_to_forget_gate_bias_" + direction, + input_names[3] + 
"_recurrent_to_cell_gate_bias_" + direction, + input_names[3] + "_recurrent_to_output_gate_bias_" + direction, + input_names[3] + "_recurrent_to_input_gate_bias_" + direction}; + for (size_t i = 0; i < 8; i++) { + std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, + {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}}; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[3], + /*output_name=*/onnx_lstm_bias_name[i], + /*input_shape=*/input_tensor_infos[3].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[3].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[3].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + } + for (size_t i = 0; i < 4; i++) { + std::vector add_input_names = {onnx_lstm_bias_name[i], onnx_lstm_bias_name[i + 4]}; + // TODO: The quantize_param should not be used directly, we should calculate an approximate quant_param here. + QnnTensorWrapper add_output_tensorwrapper(qnn_lstm_bias_name[i], QNN_TENSOR_TYPE_NATIVE, input_tensor_infos[3].qnn_data_type, + input_tensor_infos[3].quant_param.Copy(), std::vector(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(add_output_tensorwrapper)), + "QNN EP: Failed to add output tensor for inserted ElementWiseAdd node."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_ADD, + std::move(add_input_names), {qnn_lstm_bias_name[i]}, {}, do_op_validation), + "Failed to create manually inserted ElementWiseAdd node."); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_bias_name[i]; + } + } else { + // prepare zero bias + std::string zero_bias_name = node_name + "_zero_bias"; + QnnTensorWrapper zero_bias_tensor_wrapper(zero_bias_name, + QNN_TENSOR_TYPE_STATIC, + input_tensor_infos[0].qnn_data_type, + QnnQuantParamsWrapper(), + std::vector(output_shape), + std::vector(utils::GetElementSizeByType(input_tensor_infos[0].qnn_data_type) * hidden_size, 0)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(zero_bias_tensor_wrapper)), + "Failed to add additional zero bias for QNN LSTM node."); + for (size_t i = 0; i < 4; i++) { + qnn_lstm_input_names[qnn_input_indices[i]] = zero_bias_name; + } + } + } + + // input P + if (onnx_inputs.size() > 7 && onnx_inputs[7].node_arg.Exists()) { + // QNN in[18] = ONNX in[7][direction, 0*hidden_size:1*hidden_size] + // QNN in[19] = ONNX in[7][direction, 2*hidden_size:1*hidden_size] + // QNN in[20] = ONNX in[7][direction, 1*hidden_size:1*hidden_size] + uint32_t begin_mask = 0b00U; + uint32_t end_mask = 0b00U; + uint32_t shrink_axes = 0b01U; + uint32_t new_axes_mask = 0b00U; + std::vector output_shape = {hidden_size}; + std::vector qnn_input_indices = {18, 19, 20}; + std::vector begins = {0, 2, 1}; + std::vector qnn_lstm_weight_name = { + input_names[7] + "_cell_to_input_gate_weight_" + direction, + input_names[7] + "_cell_to_forget_gate_weight_" + direction, + input_names[7] + "_cell_to_output_gate_weight_" + direction}; + for (size_t i = 0; i < 3; i++) { + std::vector> ranges = { + {direction_idx, direction_idx + 1, 1}, + {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}, + }; + 
ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[7], + /*output_name=*/qnn_lstm_weight_name[i], + /*input_shape=*/input_tensor_infos[7].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[7].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[7].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_weight_name[i]; + } + } + + // input initial h, c + { + // QNN in[10] = ONNX in[5][direction_idx, :, :] + // QNN in[11] = ONNX in[6][direction_idx, :, :] + uint32_t begin_mask = 0b000U; + uint32_t end_mask = 0b000U; + uint32_t shrink_axes = 0b001U; + uint32_t new_axes_mask = 0b000U; + std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, + {0, SafeInt(batch_size), 1}, + {0, hidden_size_sign, 1}}; + std::vector src_indices = {5, 6}; + std::vector qnn_input_indices = {10, 11}; + std::vector output_shape = {batch_size, hidden_size}; + for (size_t i = 0; i < 2; i++) { + if (onnx_inputs.size() > src_indices[i] && onnx_inputs[src_indices[i]].node_arg.Exists()) { + std::string qnn_lstm_input_name = input_names[src_indices[i]] + "_" + direction; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[src_indices[i]], + /*output_name=*/qnn_lstm_input_name, + /*input_shape=*/input_tensor_infos[src_indices[i]].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[src_indices[i]].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[src_indices[i]].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_input_name; + } else { + // prepare zero initial values + std::string zero_initial_values_name = node_name + "_LSTM_initial_values_" + (i == 0 ? "h" : "c"); + QnnTensorWrapper zero_bias_tensor_wrapper(zero_initial_values_name, + QNN_TENSOR_TYPE_STATIC, + input_tensor_infos[0].qnn_data_type, + QnnQuantParamsWrapper(), + std::vector(output_shape), + std::vector(utils::GetElementSizeByType(input_tensor_infos[0].qnn_data_type) * batch_size * hidden_size, 0)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(zero_bias_tensor_wrapper)), + "Failed to add additional initial values for QNN LSTM node."); + qnn_lstm_input_names[qnn_input_indices[i]] = zero_initial_values_name; + } + } + } + + // add QNN LSTM + // since HTP doesn't not support 3d yet, add #sequence_length LSTM node + std::vector qnn_all_hidden_state_names; + qnn_all_hidden_state_names.resize(seq_length); + for (uint32_t i = 0; i < seq_length; i++) { + uint32_t sequence_idx = direction == "forward" ? 
i : seq_length - i - 1; + // Add LSTM inputs + std::vector qnn_lstm_input_names_i = qnn_lstm_input_names; + + // input X + { + // QNN in[0] = ONNX in[0][sequence_idx, :, :] + uint32_t begin_mask = 0b000U; + uint32_t end_mask = 0b000U; + uint32_t shrink_axes = 0b001U; + uint32_t new_axes_mask = 0b000U; + std::vector> ranges = {{SafeInt(sequence_idx), SafeInt(sequence_idx + 1), 1}, + {0, SafeInt(batch_size), 1}, + {0, SafeInt(input_size), 1}}; + std::string qnn_lstm_input_name = input_names[0] + "_cell_" + std::to_string(sequence_idx) + "_input"; + std::vector output_shape = {batch_size, input_size}; + ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_name=*/input_names[0], + /*output_name=*/qnn_lstm_input_name, + /*input_shape=*/input_tensor_infos[0].shape, + /*output_shape=*/output_shape, + /*ranges=*/ranges, + /*begin_mask=*/begin_mask, + /*end_mask=*/end_mask, + /*shrink_axes=*/shrink_axes, + /*new_axes_mask=*/new_axes_mask, + /*tensor_data_type=*/input_tensor_infos[0].qnn_data_type, + /*QnnQuantParamsWrapper=*/input_tensor_infos[0].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/false)); + qnn_lstm_input_names_i[0] = qnn_lstm_input_name; + } + + // outputs + std::vector qnn_lstm_output_shape = {batch_size, hidden_size}; + + std::vector qnn_lstm_output_names = { + node_name + "_QNN_LSTM_output_all_hidden_state_" + std::to_string(sequence_idx) + "_" + direction, + node_name + "_QNN_LSTM_output_cell_state_" + std::to_string(sequence_idx) + "_" + direction, + node_name + "_QNN_LSTM_output_hidden_state_" + std::to_string(sequence_idx) + "_" + direction}; + qnn_lstm_input_names[10] = qnn_lstm_output_names[2]; // update initial_h + qnn_lstm_input_names[11] = qnn_lstm_output_names[1]; // update initial_c + qnn_all_hidden_state_names[sequence_idx] = qnn_lstm_output_names[2]; + + for (size_t j = 0; j < 3; j++) { + QnnTensorWrapper output_tensorwrapper(qnn_lstm_output_names[j], + QNN_TENSOR_TYPE_NATIVE, + output_tensor_infos[j].qnn_data_type, + output_tensor_infos[j].quant_param.Copy(), + std::vector(qnn_lstm_output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), + "QNN EP: Failed to add %ldth output tensor for QNN LSTM.", j); + } + std::string lstm_node_name = node_name + "_cell_" + std::to_string(sequence_idx) + "_" + direction; + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(lstm_node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_LSTM, + std::move(qnn_lstm_input_names_i), std::move(qnn_lstm_output_names), + std::vector(param_names), do_op_validation), + "QNN EP: Failed to create Qnn LSTM node."); + } + + // pack all timestamp outputs together for onnx output[0] + std::string qnn_pack_output_name = node_name + "_QNN_LSTM_output_hidden_state_all_" + direction; + + // add pack for output[0] + std::vector pack_param_names; + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), qnn_pack_output_name, 0, QNN_OP_PACK_PARAM_AXIS, pack_param_names)); + + QnnTensorWrapper pack_output_tensorwrapper(qnn_pack_output_name, + QNN_TENSOR_TYPE_NATIVE, + output_tensor_infos[0].qnn_data_type, + output_tensor_infos[0].quant_param.Copy(), + {seq_length, batch_size, hidden_size}); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(pack_output_tensorwrapper)), + "QNN EP: Failed to add output tensor for QNN Pack."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(qnn_pack_output_name, 
QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_PACK, + std::move(qnn_all_hidden_state_names), {qnn_pack_output_name}, + std::move(pack_param_names), do_op_validation), + "QNN EP: Failed to create Qnn Pack node."); + + // add reshape for all outputs to align onnx output shape for unidirection + std::vector qnn_reshape_input_names = { + qnn_pack_output_name, + qnn_lstm_input_names[10], + qnn_lstm_input_names[11]}; + std::vector> qnn_lstm_output_shapes = { + {seq_length, batch_size, hidden_size}, + {batch_size, hidden_size}, + {batch_size, hidden_size}}; + // in the output shapes below, the value of 1 indicates unidirectional + std::vector> onnx_lstm_output_shapes = { + {seq_length, 1, batch_size, hidden_size}, + {1, batch_size, hidden_size}, + {1, batch_size, hidden_size}}; + for (size_t i = 0; i < 3; i++) { + if (onnx_outputs.size() > i && onnx_outputs[i].node_arg.Exists()) { + const std::string reshape_output_name = is_bidirection ? qnn_reshape_input_names[i] + "_unsqueeze_" + direction : onnx_outputs[i].node_arg.Name(); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(/*input_name=*/qnn_reshape_input_names[i], + /*output_name=*/reshape_output_name, + /*input_shape=*/qnn_lstm_output_shapes[i], + /*output_shape=*/onnx_lstm_output_shapes[i], + /*tensor_data_type=*/output_tensor_infos[i].qnn_data_type, + /*quantize_param=*/output_tensor_infos[i].quant_param, + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/false, + /*is_for_output=*/qnn_model_wrapper.IsGraphOutput(reshape_output_name))); + uni_lstm_output_names.emplace_back(reshape_output_name); + } else { + uni_lstm_output_names.emplace_back(""); + } + } + return Status::OK(); +} + +Status LSTMOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + const auto& inputs = node_unit.Inputs(); + + NodeAttrHelper node_helper(node_unit); + std::string direction = node_helper.Get("direction", "forward"); + ORT_RETURN_IF_NOT(inputs.size() >= 3 && inputs.size() <= 8, "LSTM should receive inputs ranging from 3 to 8!"); + + if (direction == "bidirectional") { + std::vector uni_lstm_output_names_forward, uni_lstm_output_names_reverse; + ORT_RETURN_IF_ERROR(AddUnidirectionLSTM(qnn_model_wrapper, node_unit, "forward", input_names, logger, do_op_validation, true, uni_lstm_output_names_forward)); + ORT_RETURN_IF_ERROR(AddUnidirectionLSTM(qnn_model_wrapper, node_unit, "reverse", input_names, logger, do_op_validation, true, uni_lstm_output_names_reverse)); + + // Concat forward and reverse output + for (size_t i = 0; i < 3; i++) { + TensorInfo output_info = {}; + if (node_unit.Outputs().size() > i && node_unit.Outputs()[i].node_arg.Exists()) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[i], output_info)); + std::string onnx_output_name = node_unit.Outputs()[i].node_arg.Name(); + + // param + std::vector concat_param_names; + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), onnx_output_name, static_cast(output_info.shape.size() - 3), QNN_OP_CONCAT_PARAM_AXIS, concat_param_names)); + + // create tensor and add op + Qnn_TensorType_t output_tensor_type = qnn_model_wrapper.IsGraphOutput(onnx_output_name) ? 
QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper concat_output_tensorwrapper(onnx_output_name, + output_tensor_type, + output_info.qnn_data_type, + output_info.quant_param.Copy(), + std::vector(output_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(concat_output_tensorwrapper)), + "QNN EP: Failed to add output tensor for QNN Concat."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_unit.Name(), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CONCAT, + {uni_lstm_output_names_forward[i], uni_lstm_output_names_reverse[i]}, {onnx_output_name}, + std::move(concat_param_names), do_op_validation), + "QNN EP: Failed to create Qnn Concat node."); + } + } + } else { + std::vector uni_lstm_output_names; + ORT_RETURN_IF_ERROR(AddUnidirectionLSTM(qnn_model_wrapper, node_unit, direction, input_names, logger, do_op_validation, false, uni_lstm_output_names)); + } + return Status::OK(); +} + +void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index ba9f7baa4c1ee..f932858eb2fd9 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -103,6 +103,36 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +static std::vector AmendOutputShapeForRank3Pool( + gsl::span input_shape, // {N, H, W, C} + gsl::span kernel_shape, // {k_h, k_w} + gsl::span strides, // {s_h, s_w} + gsl::span pads) { + assert(input_shape.size() == 4 && + kernel_shape.size() == 2 && + strides.size() == 2 && + pads.size() == 4); + + const uint32_t N = input_shape[0]; + const uint32_t H = input_shape[1]; + const uint32_t W = input_shape[2]; + const uint32_t C = input_shape[3]; + + // pad the spatial dims + uint32_t padded_H = H + pads[0] + pads[2]; + uint32_t padded_W = W + pads[1] + pads[3]; + + // floor-mode on NHWC + uint32_t out_H = (padded_H < kernel_shape[0]) + ? 0 + : (padded_H - kernel_shape[0]) / strides[0] + 1; + uint32_t out_W = (padded_W < kernel_shape[1]) + ? 
0 + : (padded_W - kernel_shape[1]) / strides[1] + 1; + + return {N, out_H, out_W, C}; +} + Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, std::vector& filter_size, std::vector& pad_amount, std::vector& strides, @@ -153,6 +183,14 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, dilations = raw_dilations; } + // Max Pool rank 3 input + if (output_shape.size() == 3) { + // Calculate MaxPool output for rank-4 when input is rank 3 + output_shape = AmendOutputShapeForRank3Pool(input_shape, + filter_size, + strides, + pad_amount); + } auto total_pads_0 = (output_shape[1] - 1) * strides[0] + (filter_size[0] - 1) * dilations[0] + 1 - input_shape[1]; auto total_pads_1 = (output_shape[2] - 1) * strides[1] + (filter_size[1] - 1) * dilations[1] + 1 - input_shape[2]; if (auto_pad.compare("SAME_LOWER") != 0) { @@ -189,36 +227,6 @@ void SetPoolParam(const NodeUnit& node_unit, qnn_model_wrapper.AddParamWrapper(std::move(qnn_param)); } -std::vector ComputePoolOutputShape( - const std::vector& input_shape, // {N, H, W, C} - const std::vector& kernel_shape, // {k_h, k_w} - const std::vector& strides, // {s_h, s_w} - const std::vector& pads) { - assert(input_shape.size() == 4 && - kernel_shape.size() == 2 && - strides.size() == 2 && - pads.size() == 4); - - const uint32_t N = input_shape[0]; - const uint32_t H = input_shape[1]; - const uint32_t W = input_shape[2]; - const uint32_t C = input_shape[3]; - - // pad the spatial dims - uint32_t padded_H = H + pads[0] + pads[2]; - uint32_t padded_W = W + pads[1] + pads[3]; - - // floor-mode on NHWC - uint32_t out_H = (padded_H < kernel_shape[0]) - ? 0 - : (padded_H - kernel_shape[0]) / strides[0] + 1; - uint32_t out_W = (padded_W < kernel_shape[1]) - ? 0 - : (padded_W - kernel_shape[1]) / strides[1] + 1; - - return {N, out_H, out_W, C}; -} - Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -316,10 +324,10 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } // Calculate MaxPool output for rank-4 when input is rank 3 - auto pooled_shape = ComputePoolOutputShape(onnx_in_shape, - filter_size, - stride, - pad_amount); + auto pooled_shape = AmendOutputShapeForRank3Pool(onnx_in_shape, + filter_size, + stride, + pad_amount); SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE, std::move(filter_size_dim), std::move(filter_size), param_tensor_names, qnn_model_wrapper); SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT, std::move(pad_amount_dim), std::move(pad_amount), param_tensor_names, qnn_model_wrapper); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index ab022df063c96..2650316dd07ac 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -40,6 +40,7 @@ class SimpleOpBuilder : public BaseOpBuilder { static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; + static constexpr std::array scatternd_supported_reduction = {"none", "add", "mul"}; }; Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, @@ -101,6 +102,14 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, } } + // QNN ScatterND 
doesn't support MAX, MIN reduction + if (op_type == "ScatterND") { + NodeAttrHelper node_helper(node_unit); + std::string reduction = node_helper.Get("reduction", "none"); + ORT_RETURN_IF_NOT(utils::ArrayHasString(scatternd_supported_reduction, reduction), "ScatterND does not support reduction ", + reduction.c_str()); + } + return Status::OK(); } @@ -254,6 +263,31 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +// Process Reduction attribute of ScatterND op +Status ProcessScatterNDReductionAttribute(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector& param_tensor_names) { + NodeAttrHelper node_helper(node_unit); + std::string reduction = node_helper.Get("reduction", "none"); + Qnn_Scalar_t reduction_qnn_scalar = QNN_SCALAR_INIT; + reduction_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; + if ("none" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_NONE; + } else if ("add" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_ADD; + } else if ("mul" == reduction) { + reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ND_REDUCTION_MUL; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ScatterND support only reduction:{none, add, mul}."); + } + QnnParamWrapper reduction_param(node_unit.Index(), node_unit.Name(), QNN_OP_SCATTER_ND_PARAM_REDUCTION, + reduction_qnn_scalar); + param_tensor_names.push_back(reduction_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(reduction_param)); + + return Status::OK(); +} + Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -358,6 +392,11 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names)); } + if (op_type == "ScatterND") { + // Process reduction attribute + ORT_RETURN_IF_ERROR(ProcessScatterNDReductionAttribute(qnn_model_wrapper, node_unit, param_tensor_names)); + } + return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc index 48214f92b1a61..cba0eb350992f 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc @@ -55,18 +55,6 @@ class UpsampleOpBuilder : public BaseOpBuilder { const OnnxAttrInfo onnx_mode_attr = {"mode", "nearest"}; }; -static Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector& param_tensor_names, - const Qnn_Scalar_t& qnn_scalar, - const std::string& qnn_scalar_param_name) { - QnnParamWrapper qnn_param_wrapper(node_unit.Index(), node_unit.Name(), qnn_scalar_param_name, qnn_scalar); - param_tensor_names.push_back(qnn_param_wrapper.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); - - return Status::OK(); -} - Status UpsampleOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const { @@ -161,72 +149,40 @@ Status UpsampleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model qnn_op_type = (interp_mode == "nearest") ? 
QNN_OP_RESIZE_NEAREST_NEIGHBOR : QNN_OP_RESIZE_BILINEAR; // Parameter 'align_corners' - Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT; - qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8; - qnn_align_corners.bool8Value = false; const std::string align_corners_param_name = (qnn_op_type == QNN_OP_RESIZE_BILINEAR) ? QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS : QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_ALIGN_CORNERS; - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_align_corners, align_corners_param_name)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, align_corners_param_name, param_tensor_names)); // Parameter 'half_pixel_centers' - Qnn_Scalar_t qnn_half_pixel_centers = QNN_SCALAR_INIT; - qnn_half_pixel_centers.dataType = QNN_DATATYPE_BOOL_8; - qnn_half_pixel_centers.bool8Value = false; const std::string half_pixel_centers_param_name = (qnn_op_type == QNN_OP_RESIZE_BILINEAR) ? QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS : QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_HALF_PIXEL_CENTERS; - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_half_pixel_centers, half_pixel_centers_param_name)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, half_pixel_centers_param_name, param_tensor_names)); if (qnn_op_type == QNN_OP_RESIZE_BILINEAR) { // Parameter 'antialias' - Qnn_Scalar_t qnn_antialias = QNN_SCALAR_INIT; - qnn_antialias.dataType = QNN_DATATYPE_BOOL_8; - qnn_antialias.bool8Value = false; - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_antialias, QNN_OP_RESIZE_BILINEAR_PARAM_ANTIALIAS)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, QNN_OP_RESIZE_BILINEAR_PARAM_ANTIALIAS, param_tensor_names)); } } else { // Remain as QNN's Resize. // Parameter 'exclude_outside' - Qnn_Scalar_t qnn_exclude_outside = QNN_SCALAR_INIT; - qnn_exclude_outside.dataType = QNN_DATATYPE_BOOL_8; - qnn_exclude_outside.bool8Value = false; - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_exclude_outside, QNN_OP_RESIZE_PARAM_EXCLUDE_OUTSIDE)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, QNN_OP_RESIZE_PARAM_EXCLUDE_OUTSIDE, param_tensor_names)); // Parameter 'transformation_mode' - Qnn_Scalar_t qnn_transformation_mode = QNN_SCALAR_INIT; - qnn_transformation_mode.dataType = QNN_DATATYPE_UINT_32; - qnn_transformation_mode.uint32Value = (supported_modes.at(interp_mode) == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) - ? static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL) - : static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC); - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_transformation_mode, QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE)); + uint32_t transformation_mode = (supported_modes.at(interp_mode) == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) + ? 
static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL) + : static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), transformation_mode, QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE, param_tensor_names)); // Parameter 'interpolation_mode' - Qnn_Scalar_t qnn_interp_mode = QNN_SCALAR_INIT; - qnn_interp_mode.dataType = QNN_DATATYPE_UINT_32; - qnn_interp_mode.uint32Value = static_cast(supported_modes.at(interp_mode)); - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_interp_mode, QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE)); + uint32_t qnn_interp_mode = static_cast(supported_modes.at(interp_mode)); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), qnn_interp_mode, QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE, param_tensor_names)); // Parameter 'nearest_mode'. Process only when 'interpolation_mode' is NEAREST. - if (qnn_interp_mode.uint32Value == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) { - Qnn_Scalar_t qnn_nearest_mode = QNN_SCALAR_INIT; - qnn_nearest_mode.dataType = QNN_DATATYPE_UINT_32; - qnn_nearest_mode.uint32Value = static_cast(QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR); - - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, - qnn_nearest_mode, QNN_OP_RESIZE_PARAM_NEAREST_MODE)); + if (qnn_interp_mode == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) { + uint32_t qnn_nearest_mode = static_cast(QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), qnn_nearest_mode, QNN_OP_RESIZE_PARAM_NEAREST_MODE, param_tensor_names)); } } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 0009dab837525..901569b54e049 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1,3 +1,4 @@ +// // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. @@ -14,7 +15,10 @@ #include "HTP/QnnHtpContext.h" #include "HTP/QnnHtpPerfInfrastructure.h" #include "HTP/QnnHtpSystemContext.h" +#include "IR/QnnIrCommon.h" +#include "IR/QnnIrGraph.h" #include "Saver/QnnSaver.h" +#include "Saver/QnnSaverCommon.h" #include #include "core/providers/qnn/ort_api.h" @@ -51,6 +55,93 @@ static const char* DlError() { #endif } +// Workaround for a missing comma in QNN_IR_GRAPH_CUSTOM_CONFIG_INIT. 
+static QnnIrGraph_CustomConfig_t EmptyIrGraphConfig() { + return { + QNN_IR_GRAPH_CONFIG_OPTION_SERIALIZATION, {QNN_IR_GRAPH_SERIALIZATION_TYPE_FLAT_BUFFER, ""}}; +} + +class QnnIrConfig : public QnnSerializerConfig { + public: + QnnIrConfig(std::string backend_path, std::string dlc_dir) + : QnnSerializerConfig(std::move(backend_path)), dlc_dir_(std::move(dlc_dir)), configs_builder_(MakeConfigsBuilder()) { + } + + const QnnGraph_Config_t** Configure() override { + auto configs_builder = MakeConfigsBuilder(); + + std::filesystem::path dlc_path = (dlc_dir_ / (GetGraphName() + ".dlc")); + std::string dlc_path_str = dlc_path.string(); + gsl::not_null dlc_path_config = configs_builder.PushCustomConfig(); + dlc_path_config->option = QNN_IR_GRAPH_CONFIG_OPTION_SERIALIZATION; + dlc_path_config->serializationOption.serializationType = QNN_IR_GRAPH_SERIALIZATION_TYPE_FLAT_BUFFER; + dlc_path_config->serializationOption.outputPath = dlc_path_str.c_str(); + + gsl::not_null dlc_path_custom_config = configs_builder.PushConfig(); + dlc_path_custom_config->option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; + dlc_path_custom_config->customConfig = dlc_path_config; + + std::filesystem::create_directories(dlc_path); + + // Keep the pointer to dlc_path_str's null-terminated string alive. + std::swap(dlc_path_str, dlc_path_str_); + + std::swap(configs_builder, configs_builder_); + return configs_builder_.GetQnnConfigs(); + } + + bool SupportsArbitraryGraphConfigs() const override { + return false; + } + + private: + static QnnConfigsBuilder MakeConfigsBuilder() { + return QnnConfigsBuilder(QNN_GRAPH_CONFIG_INIT, EmptyIrGraphConfig()); + } + + std::filesystem::path dlc_dir_; + std::string dlc_path_str_; + QnnConfigsBuilder configs_builder_; +}; + +class QnnSaverConfig : public QnnSerializerConfig { + public: + QnnSaverConfig(std::string backend_path) : QnnSerializerConfig(std::move(backend_path)) {} + + const QnnGraph_Config_t** Configure() override { + return nullptr; + } + + bool SupportsArbitraryGraphConfigs() const override { + return true; + } +}; + +QnnSerializerConfig::~QnnSerializerConfig() = default; + +QnnSerializerConfig::QnnSerializerConfig(std::string backend_path) + : backend_path_(std::move(backend_path)) {} + +std::unique_ptr QnnSerializerConfig::CreateIr(std::string backend_path, std::string dlc_dir) { + return std::make_unique(std::move(backend_path), std::move(dlc_dir)); +} + +std::unique_ptr QnnSerializerConfig::CreateSaver(std::string backend_path) { + return std::make_unique(std::move(backend_path)); +} + +const std::string& QnnSerializerConfig::GetBackendPath() const { + return backend_path_; +} + +const std::string& QnnSerializerConfig::GetGraphName() const { + return graph_name_; +} + +void QnnSerializerConfig::SetGraphName(std::string graph_name) { + graph_name_ = std::move(graph_name); +} + Status ReadBinaryFromFile(const std::string& file_path, uint8_t* buffer, size_t buffer_size) { ORT_RETURN_IF(nullptr == buffer, "Binary buffer is nullptr"); std::ifstream in(file_path, std::ifstream::binary); @@ -179,6 +270,10 @@ void QnnBackendManager::SetQnnBackendType(uint32_t backend_id) { case QNN_BACKEND_ID_HTP: qnn_backend_type_ = QnnBackendType::HTP; break; + case QNN_BACKEND_ID_IR: + case QNN_BACKEND_ID_SAVER: + qnn_backend_type_ = QnnBackendType::SERIALIZER; + break; default: qnn_backend_type_ = QnnBackendType::CPU; break; @@ -209,13 +304,19 @@ Status QnnBackendManager::LoadBackend() { return Status::OK(); } +QnnSerializerConfig* QnnBackendManager::GetQnnSerializerConfig() { + return 
qnn_serializer_config_.get(); +} + // Loads the intended backend (e.g., HTP, CPU, etc) to get its type, and then -// sets QNN Saver as the active backend. QNN op builders will still see the intended backend (e.g., HTP) -// as the backend type to ensure they emit the expected QNN API calls. +// sets QnnSaver or QnnIr as the active backend. QNN op builders will still see the intended backend +// (e.g., HTP) as the backend type to ensure they emit the expected QNN API calls. Note, however, that +// calls to QnnBackend_validateOpConfig will be to the saver backend, not the "intended" one. // -// QNN Saver is a "debugging" backend that serializes all QNN API calls (and weights) into local files. +// QnnSaver and QnnIr are "debugging" backends that serialize all QNN API calls (and weights) into +// local files: Saver dumps to C++ sources and Ir to .dlc archives. // This information can be used to debug issues by replaying QNN API calls with another backend. -Status QnnBackendManager::LoadQnnSaverBackend() { +Status QnnBackendManager::LoadQnnSerializerBackend() { void* backend_lib_handle = nullptr; // Helper that unloads the intended backend library handle when the `unload_backend_lib` variable @@ -245,25 +346,25 @@ Status QnnBackendManager::LoadQnnSaverBackend() { auto backend_id = backend_interface_provider->backendId; SetQnnBackendType(backend_id); - // Load the QNN Saver backend and set it as the activate backend. - QnnInterface_t* saver_interface_provider{nullptr}; + // Load the serializer backend and set it as the active backend. + QnnInterface_t* serializer_interface_provider{nullptr}; auto saver_rt = GetQnnInterfaceProvider(qnn_saver_path_.c_str(), + QnnInterface_t>(qnn_serializer_config_->GetBackendPath().c_str(), "QnnInterface_getProviders", - &backend_lib_handle_, // NOTE: QNN Saver library handle is set + &backend_lib_handle_, // NOTE: QnnSaver/Ir library handle is set {QNN_API_VERSION_MAJOR, QNN_API_VERSION_MINOR, QNN_API_VERSION_PATCH}, - &saver_interface_provider); + &serializer_interface_provider); ORT_RETURN_IF_ERROR(saver_rt); - qnn_interface_ = saver_interface_provider->QNN_INTERFACE_VER_NAME; // NOTE: QNN Saver will provide the interfaces + qnn_interface_ = serializer_interface_provider->QNN_INTERFACE_VER_NAME; // NOTE: QnnSaver/Ir will provide the interfaces Qnn_Version_t backend_interface_version = GetQnnInterfaceApiVersion(backend_interface_provider); - Qnn_Version_t saver_interface_version = GetQnnInterfaceApiVersion(saver_interface_provider); + Qnn_Version_t serializer_interface_version = GetQnnInterfaceApiVersion(serializer_interface_provider); - LOGS_DEFAULT(INFO) << "Using QNN Saver version: " << saver_interface_version.major << "." - << saver_interface_version.minor << "." << saver_interface_version.patch - << " provider name : " << saver_interface_provider->providerName; + LOGS_DEFAULT(INFO) << "Using QnnSaver/Ir version: " << serializer_interface_version.major << "." + << serializer_interface_version.minor << "." << serializer_interface_version.patch + << " provider name : " << serializer_interface_provider->providerName; LOGS_DEFAULT(INFO) << "Intended backend provider name: " << backend_interface_provider->providerName << " backend id: " << backend_id @@ -636,7 +737,7 @@ Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { configs = npu_context_configs; break; case QnnBackendType::GPU: - // Currently only this works with QnnGpu.
+ case QnnBackendType::SERIALIZER: configs = nullptr; break; default: @@ -644,6 +745,11 @@ Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { break; } + // Not all serialization backends allow for hardware configs to be applied. + if (qnn_serializer_config_ && !qnn_serializer_config_->SupportsArbitraryGraphConfigs()) { + configs = nullptr; + } + Qnn_ContextHandle_t context = nullptr; Qnn_ErrorHandle_t result = qnn_interface_.contextCreate(backend_handle_, device_handle_, @@ -904,10 +1010,10 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, } Status status = Status::OK(); - if (qnn_saver_path_.empty()) { + if (!qnn_serializer_config_) { status = LoadBackend(); } else { - status = LoadQnnSaverBackend(); + status = LoadQnnSerializerBackend(); } if (status.IsOK()) { LOGS(logger, VERBOSE) << "LoadBackend succeed."; @@ -1287,7 +1393,7 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { const QnnProfile_EventId_t* profile_events{nullptr}; uint32_t num_events{0}; Qnn_ErrorHandle_t result = qnn_interface_.profileGetEvents(profile_backend_handle_, &profile_events, &num_events); - if (!qnn_saver_path_.empty()) { // Using QNN Saver backend + if (qnn_serializer_config_) { // Using QNN Saver or IR backend // QNN SDK 2.28.2 returns QNN_SAVER_ERROR_DUMMY_RETVALUE, but previous QNN versions return QNN_PROFILE_NO_ERROR. // We accept both values. ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result && QNN_SAVER_ERROR_DUMMY_RETVALUE != result, @@ -1709,13 +1815,9 @@ Status QnnBackendManager::UnloadLib(void* handle) { #ifdef _WIN32 HMODULE mod = static_cast(handle); -// TODO: QNN SDK 2.17 crashes for some models/tests on Windows x64 when unloading library. -// Example: ReductionOpTest.ArgMax -#if !defined(_M_AMD64) if (FreeLibrary(mod) == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to free library."); } -#endif // !defined(_M_AMD64) mod_handles_.erase(mod); #else auto rt = ::dlclose(handle); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 137b3856d431d..b8e8081f77f27 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -32,6 +32,62 @@ namespace qnn { class QnnModel; +class QnnSerializerConfig { + public: + virtual ~QnnSerializerConfig(); + + /** + * Create a config to write a DLC file for each graph using the Ir backend. + */ + static std::unique_ptr CreateIr(std::string backend_path, std::string dlc_dir); + + /** + * Create a config to write C++ source files using the Saver backend. + */ + static std::unique_ptr CreateSaver(std::string backend_path); + + /** + * Get the path to the serializer backend. + */ + const std::string& GetBackendPath() const; + + /** + * Set the name of the graph being serialized. This value may be used to determine the name + * of the output files. + * + * \param graph_name The name of the graph being serialized. + */ + void SetGraphName(std::string graph_name); + + /** + * Get any QNN Graph configs required to configure this serializer and perform any + * preparation, such as creating output directories. + * + * \return nullptr or a null-terminated list of QnnGraph_Config_t*. + */ + virtual const QnnGraph_Config_t** Configure() = 0; + + /** + * Some serializers allow for GraphConfigs that are unrelated to serialization to be + * specified at context creation time, while others raise an error. 
If true, this + * serializer should be configured with graph configs for any applicable real (e.g., HTP) + * backend. + * + * \return true if the backend can be configured with non-serialization graph configs. + */ + virtual bool SupportsArbitraryGraphConfigs() const = 0; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnSerializerConfig); + + protected: + QnnSerializerConfig(std::string backend_path); + const std::string& GetGraphName() const; + + private: + std::string backend_path_; + std::string graph_name_{"graph"}; +}; + // configuration values for QnnBackendManager creation struct QnnBackendManagerConfig { std::string backend_path; @@ -39,7 +95,7 @@ struct QnnBackendManagerConfig { ProfilingLevel profiling_level; std::string profiling_file_path; ContextPriority context_priority; - std::string qnn_saver_path; + std::shared_ptr qnn_serializer_config; uint32_t device_id; QnnHtpDevice_Arch_t htp_arch; uint32_t soc_model; @@ -63,7 +119,7 @@ class QnnBackendManager : public std::enable_shared_from_this profiling_level_(config.profiling_level), profiling_file_path_(config.profiling_file_path), context_priority_(config.context_priority), - qnn_saver_path_(config.qnn_saver_path), + qnn_serializer_config_(config.qnn_serializer_config), device_id_(config.device_id), htp_arch_(config.htp_arch), soc_model_(config.soc_model) { @@ -141,6 +197,8 @@ class QnnBackendManager : public std::enable_shared_from_this Status ParseLoraConfig(std::string lora_config); + QnnSerializerConfig* GetQnnSerializerConfig(); + private: Status LoadBackend(); @@ -176,7 +234,7 @@ class QnnBackendManager : public std::enable_shared_from_this Status LoadQnnSystemLib(); - Status LoadQnnSaverBackend(); + Status LoadQnnSerializerBackend(); Status UnloadLib(void* handle); @@ -295,7 +353,7 @@ class QnnBackendManager : public std::enable_shared_from_this #ifdef _WIN32 std::set mod_handles_; #endif - const std::string qnn_saver_path_; + const std::shared_ptr qnn_serializer_config_; uint32_t device_id_ = 0; QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h index b581cd90537d9..74919b2bcd259 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h @@ -44,6 +44,15 @@ class QnnConfigsBuilder { return config_ptrs_.data(); } + /** + * Returns the number of configs that have been added to this builder, excluding any null terminator. + * + * \return The number of configs in this builder. + */ + size_t GetSize() const { + return IsNullTerminated() ? config_ptrs_.size() - 1 : config_ptrs_.size(); + } + /** * Creates and returns a reference to a new custom QNN configuration object. The object is initialized to * the QNN recommended default value. The caller is meant to override fields in this object. 
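The QnnSerializerConfig class above and the QnnConfigsBuilder::GetSize() helper are easiest to read together, so here is a minimal usage sketch (not part of the patch). It calls only the methods declared in qnn_backend_manager.h; the backend library names, the DLC output directory, the include paths, and the MakeSerializerConfig/ConfigureForGraph helpers are illustrative assumptions.

#include <memory>
#include <string>

#include "QnnGraph.h"  // QnnGraph_Config_t
#include "core/providers/qnn/builder/qnn_backend_manager.h"

using onnxruntime::qnn::QnnSerializerConfig;

// Choose a serializer backend: QnnIr writes one .dlc archive per graph into dlc_dir,
// while QnnSaver dumps the recorded QNN API calls as C++ sources.
// The .so names below are assumptions for a Linux build.
std::shared_ptr<QnnSerializerConfig> MakeSerializerConfig(bool use_ir) {
  if (use_ir) {
    return QnnSerializerConfig::CreateIr(/*backend_path=*/"libQnnIr.so",
                                         /*dlc_dir=*/"qnn_dlc_out");
  }
  return QnnSerializerConfig::CreateSaver(/*backend_path=*/"libQnnSaver.so");
}

// Called once per graph before graph creation: SetGraphName() feeds the output file name
// (e.g. "<dlc_dir>/<graph_name>.dlc" for the Ir backend), and Configure() returns either
// nullptr (Saver needs no graph configs) or a null-terminated QnnGraph_Config_t* list.
const QnnGraph_Config_t** ConfigureForGraph(QnnSerializerConfig& serializer_config,
                                            const std::string& graph_name) {
  serializer_config.SetGraphName(graph_name);
  return serializer_config.Configure();
}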
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 0d7bc0ba9f4c7..a95628ae9cc7f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -70,7 +70,8 @@ enum class QnnBackendType : uint8_t { GPU, DSP, HTP, - HTP_FP16 + HTP_FP16, + SERIALIZER, }; bool IsCpuBackend(QnnBackendType backend_type); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index ec84820bb7896..175a76b590895 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -331,7 +331,15 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, bool is_input) { size_t tensor_count = tensor_wrappers.size(); ORT_RETURN_IF(0 == tensor_count, "Zero tensor size!"); - qnn_tensor_infos.resize(tensor_count); + if (is_input) { + // Resize qnn_tensor_infos according to the number of graph inputs. + auto input_count = GetGraphInputCount(); + ORT_RETURN_IF(input_count < tensor_count, + "The count of graph inputs should be at least the count of tensor_wrapper!"); + qnn_tensor_infos.resize(input_count); + } else { + qnn_tensor_infos.resize(tensor_count); + } for (auto& tensor_wrapper : tensor_wrappers) { ORT_RETURN_IF(utils::QnnTensorHasDynamicShape(tensor_wrapper.GetQnnTensor()), @@ -348,6 +356,18 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, qnn_tensor_info.tensor_byte_size = static_cast(length); qnn_tensor_info.ort_index = ort_index; } + // The number of graph inputs and the number of tensor wrappers may not match. + // - For example, for ResizeNearestNeighbor op, Qnn only cares about the 1st input, + // so the rest of the inputs are not converted to tensor wrappers. + // - However, these remaining inputs still appear in the graph inputs, resulting in + // a discrepancy in the input quantities. + // If not all inputs are used, erase the empty allocations in qnn_tensor_infos. 
+ if (is_input) { + qnn_tensor_infos.erase(std::remove_if(qnn_tensor_infos.begin(), + qnn_tensor_infos.end(), + [](QnnTensorInfo qnn_tensor_info) { return qnn_tensor_info.tensor_wrapper == nullptr; }), + qnn_tensor_infos.end()); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 6f7738f554ef0..9f10b319f1a57 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -77,6 +77,11 @@ class QnnModel { return it->second; } + // Return the number of graph inputs + size_t GetGraphInputCount() const { + return model_input_index_map_.size(); + } + size_t GetOutputIndex(const std::string& name) const { return GetInputOutputIndex(name, outputs_info_); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 0f0b42bf754d7..bd22aec89102c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -53,7 +53,7 @@ bool QnnModelWrapper::IsQnnTensorWrapperExist(const std::string& name) const { return model_tensors_map_.find(name) != model_tensors_map_.end(); } -bool QnnModelWrapper::IsQnnParamExit(const std::string& param_tensor_name) const { +bool QnnModelWrapper::QnnParamExists(const std::string& param_tensor_name) const { return model_params_map_.find(param_tensor_name) != model_params_map_.end(); } @@ -121,14 +121,14 @@ bool QnnModelWrapper::AddTensorWrapper(QnnTensorWrapper&& tensor_wrapper) { } bool QnnModelWrapper::AddParamWrapper(QnnParamWrapper&& param_wrapper) { - // Keep a copy of tensor name sine it will be moved with the wrapper into model_params_map_ + // Keep a copy of tensor name since it will be moved with the wrapper into model_params_map_ std::string param_tensor_name = param_wrapper.GetParamTensorName(); if (param_tensor_name.length() == 0) { LOGS(logger_, ERROR) << "Invalid parameter encountered empty name."; return false; } - if (IsQnnParamExit(param_tensor_name) == true) { + if (QnnParamExists(param_tensor_name) == true) { return true; } @@ -159,7 +159,7 @@ bool QnnModelWrapper::CreateQnnInputOutputTensors(const std::string& qnn_node_na } // During graph patitioning, we only need to do op validation, it's not required to create Qnn graph tensor - // We only need to creat the Qnn graph tensor during Compile to create Qnn graph + // We only need to create the Qnn graph tensor during Compile to create Qnn graph if (!do_op_validation) { std::string error_string; auto rt = it->second.CreateQnnGraphTensor(qnn_interface_, graph_, qnn_node_name, tensor_created_map_, error_string); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 9ec6f470af9fd..745dfde7bfac8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -279,7 +279,7 @@ class QnnModelWrapper { std::vector& tensor_wrappers, bool do_op_validation = false); - bool IsQnnParamExit(const std::string& param_tensor_name) const; + bool QnnParamExists(const std::string& param_tensor_name) const; bool CreateQnnParamTensors(const std::string& qnn_node_name, const std::vector& param_tensor_names, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc 
b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 0390a305b2df9..839079e6c1a8e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -15,6 +15,7 @@ #include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/ort_api.h" @@ -90,6 +91,7 @@ static std::unique_ptr TryQnnFusions( {"DequantizeLinear", DQQFusion::TryFusion}, {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, {"Gemm", ReshapeGemmFusion::TryFusion}, + {"Mul", ScaleSoftmaxFusion::TryFusion}, }; // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). @@ -116,7 +118,7 @@ static Status GetQnnNodeGroupsImpl(/*out*/ std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); + const std::vector& sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); sorted_qnn_node_group_indices.reserve(num_node_units); qnn_node_groups.reserve(num_node_units); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc new file mode 100644 index 0000000000000..5c7091b3be3cc --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" + +namespace onnxruntime { +namespace qnn { +namespace { + +constexpr char kOpMul[] = "Mul"; +constexpr char kOpSoftmax[] = "Softmax"; + +/// @brief Get the index of the scalar input in the mul node +/// @param mul Multiply node unit +/// @return The index of the scalar input (0 or 1) if found, otherwise std::nullopt +std::optional GetMulScalarInputIndex(const NodeUnit* mul) { + const NodeArg* mul_y = mul->GetNode().InputDefs()[1]; + const NodeArg* mul_x = mul->GetNode().InputDefs()[0]; + auto y_shape_proto = mul_y->Shape(); + auto x_shape_proto = mul_x->Shape(); + bool is_y_scalar = false; + if (y_shape_proto != nullptr) { + auto y_shape = utils::GetTensorProtoShape(*y_shape_proto); + is_y_scalar = y_shape.NumDimensions() == 0; + } + bool is_x_scalar = false; + if (x_shape_proto != nullptr) { + auto x_shape = utils::GetTensorProtoShape(*x_shape_proto); + is_x_scalar = x_shape.NumDimensions() == 0; + } + if (is_y_scalar) { + return 1U; + } else if (is_x_scalar) { + return 0U; + } + return std::nullopt; +} + +/// @brief Get the axis for softmax +/// @param mul Multiply node unit +/// @param softmax Softmax node unit +/// @return The axis for softmax +std::optional GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUnit* softmax) { + NodeAttrHelper softmax_attr_helper(softmax->GetNode()); + std::optional 
param_axis = softmax_attr_helper.GetInt64(QNN_OP_SOFTMAX_PARAM_AXIS); + if (!param_axis.has_value()) { + return std::nullopt; + } + int64_t axis_value = param_axis.value(); + if (axis_value < 0) { + size_t input_scale_index = GetMulScalarInputIndex(mul).value(); + size_t input_other_index = 1U - input_scale_index; + int rank = mul->GetNode().InputDefs()[input_other_index]->Shape()->dim_size(); + axis_value += static_cast(rank); + } + return static_cast(axis_value); +} + +/// @brief Identify scalar input from mul node if present +/// @param mul Multiply node unit +/// @return The scalar input float value if found, otherwise std::nullopt +std::optional ExtractScalarValueFromMul(const GraphViewer& graph_viewer, const NodeUnit* mul) { + std::optional input_scale_index = GetMulScalarInputIndex(mul); + if (!input_scale_index.has_value()) { + return std::nullopt; + } + const NodeArg* scalar_arg = mul->GetNode().InputDefs()[input_scale_index.value()]; + if (!graph_viewer.IsConstantInitializer(scalar_arg->Name(), true)) { + return std::nullopt; + } + const auto* scalar_tensor = graph_viewer.GetConstantInitializer(scalar_arg->Name()); + if (!scalar_tensor) { + return std::nullopt; + } + if (scalar_tensor->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return std::nullopt; + } + const auto& raw_data = scalar_tensor->raw_data(); + if (raw_data.size() != sizeof(float) || reinterpret_cast(raw_data.data()) % alignof(float) != 0) { + return std::nullopt; + } + return *reinterpret_cast(raw_data.data()); +} + +/// @brief Create or validate the QNN node +/// @param qnn_model_wrapper QNN model wrapper +/// @param node_units The node units containing the softmax and mul nodes +/// @param validate Whether to validate the QNN node +/// @return Status +Status CreateOrValidateOnQnn( + QnnModelWrapper* qnn_model_wrapper, + gsl::span node_units, + bool validate) { + const NodeUnit* mul = node_units[0]; + const NodeUnit* softmax = node_units[1]; + ORT_RETURN_IF_NOT(mul->OpType() == kOpMul, + "Expected scale node to be of type Mul, got ", mul->OpType()); + ORT_RETURN_IF_NOT(softmax->OpType() == kOpSoftmax, + "Expected softmax node to be of type Softmax, got ", softmax->OpType()); + size_t input_scale_index = GetMulScalarInputIndex(mul).value(); + size_t input_other_index = 1U - input_scale_index; + const NodeUnitIODef& mul_input_other = mul->Inputs()[input_other_index]; + const NodeUnitIODef& softmax_output = softmax->Outputs()[0]; + + std::vector param_tensor_names; + { // axis + std::optional axis = GetPositiveSoftmaxAxis(mul, softmax); + if (axis.has_value()) { + Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT; + axis_scalar.dataType = QNN_DATATYPE_UINT_32; + axis_scalar.uint32Value = axis.value(); + QnnParamWrapper param_wrapper(softmax->Index(), + softmax->Name(), + QNN_OP_SOFTMAX_PARAM_AXIS, + axis_scalar); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); + } + } + { // beta + NodeAttrHelper softmax_attr_helper(softmax->GetNode()); + std::optional beta = softmax_attr_helper.GetFloat(QNN_OP_SOFTMAX_PARAM_BETA); + float scale = ExtractScalarValueFromMul(qnn_model_wrapper->GetGraphViewer(), mul).value_or(1.0f); + Qnn_Scalar_t beta_scalar = QNN_SCALAR_INIT; + beta_scalar.dataType = QNN_DATATYPE_FLOAT_32; + beta_scalar.floatValue = scale * beta.value_or(1.0f); + QnnParamWrapper param_wrapper(softmax->Index(), + softmax->Name(), + QNN_OP_SOFTMAX_PARAM_BETA, + beta_scalar); + 
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); + } + + QnnTensorWrapper fused_softmax_input; + QnnTensorWrapper fused_softmax_output; + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(mul_input_other, fused_softmax_input)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(softmax_output, fused_softmax_output)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper->ValidateQnnNode(softmax->Name(), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_SOFTMAX, + {fused_softmax_input.GetQnnTensor()}, + {fused_softmax_output.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_input)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_output)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(softmax->Name(), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_SOFTMAX, + {mul_input_other.node_arg.Name()}, + {softmax_output.node_arg.Name()}, + std::move(param_tensor_names), + validate), + "Failed to add fused " + std::string(kOpSoftmax) + " node."); + } + return Status::OK(); +} + +} // namespace + +std::unique_ptr ScaleSoftmaxFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& mul_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + [[maybe_unused]] const logging::Logger& logger) { + if (mul_node_unit.OpType() != kOpMul || mul_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + // Check if the mul node has a scalar input that can fold into the softmax's beta + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + std::optional scalar = ExtractScalarValueFromMul(graph_viewer, &mul_node_unit); + if (!scalar.has_value()) { + return nullptr; + } + + // Mul node must have a single Softmax node as child + const std::array child_op_types{kOpSoftmax}; + const NodeUnit* softmax = GetOnlyChildOfType(graph_viewer, mul_node_unit, child_op_types, + node_to_node_unit, node_unit_to_qnn_node_group); + if (softmax == nullptr) { + return nullptr; + } + + std::array node_unit_array{&mul_node_unit, softmax}; + auto node_units = gsl::make_span(node_unit_array.data(), 2); + if (CreateOrValidateOnQnn(&qnn_model_wrapper, node_units, /*validate=*/true) != Status::OK()) { + return nullptr; + } + return std::make_unique(node_units); +} + +gsl::span ScaleSoftmaxFusion::GetNodeUnits() const { + return gsl::span{node_units_.data(), node_units_.size()}; +} + +Status ScaleSoftmaxFusion::IsSupported( + QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const { + return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/true); +} + +Status ScaleSoftmaxFusion::AddToModelBuilder( + QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const { + return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/false); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h new file mode 100644 index 0000000000000..66eb892e7a884 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft 
Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of pattern: Softmax(Mul(x, scalar_scale)) => QnnSoftmax(x, beta=scalar_scale) +/// +class ScaleSoftmaxFusion : public IQnnNodeGroup { + public: + explicit ScaleSoftmaxFusion(gsl::span node_units) { + ORT_ENFORCE(node_units.size() == 2, "Pattern expect exactly 2 NodeUnits."); + node_units_[0] = node_units[0]; + node_units_[1] = node_units[1]; + } + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ScaleSoftmaxFusion); + + Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override { return node_units_[1]; } + std::string_view Type() const override { return "ScaleSoftmaxFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid Softmax -> Mul sequence. + /// If so, returns a IQnnNodeGroup that contains the Softmax and Mul NodeUnits. + /// + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& mul_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 4fe223d821f1c..f869f33847bbf 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -40,7 +40,7 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type) { {QNN_DATATYPE_UFIXED_POINT_8, 1}, {QNN_DATATYPE_UFIXED_POINT_16, 2}, {QNN_DATATYPE_UFIXED_POINT_32, 4}, - }; + {QNN_DATATYPE_UNDEFINED, 1}}; auto pos = data_type_to_size.find(data_type); ORT_ENFORCE(pos != data_type_to_size.end(), "Unknown QNN data type", data_type); @@ -228,6 +228,9 @@ std::ostream& operator<<(std::ostream& out, const Qnn_DataType_t& data_type) { case QNN_DATATYPE_UFIXED_POINT_4: out << "QNN_DATATYPE_UFIXED_POINT_4"; break; + case QNN_DATATYPE_UNDEFINED: + out << "QNN_DATATYPE_UNDEFINED"; + break; default: ORT_THROW("Unknown Qnn Data type"); } @@ -1250,6 +1253,52 @@ Status TransposeFromCnhwToHwcn(std::vector&& original_input_shape_dims, output_buffer); } +// Inserts a QNN Convert operator to convert from one quantization type (e.g., uint16) to another (e.g., uint8). +// (OR) Convert from Asymmetric (e.g., UINT16) to Symmetric (e.g., INT16) quantization type +Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, + const std::string& convert_input_name, + const std::string& convert_output_name, + Qnn_DataType_t input_qnn_data_type, + Qnn_DataType_t output_qnn_data_type, + int32_t input_offset, + float input_scale, + const std::vector& output_shape, + bool output_symmetric, + bool do_op_validation) { + // Assume input is already handled. 
+ float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(qnn::utils::GetQminQmax(input_qnn_data_type, qmin, qmax)); + double value_min = qnn::utils::Dequantize(input_offset, input_scale, qmin); + double value_max = qnn::utils::Dequantize(input_offset, input_scale, qmax); + float scale = 0.0f; + int32_t offset = 0; + ORT_RETURN_IF_ERROR(qnn::utils::GetQuantParams(static_cast(value_min), + static_cast(value_max), + output_qnn_data_type, + scale, + offset, + output_symmetric)); + + std::vector output_shape_copy = output_shape; + QnnTensorWrapper convert_output_tensorwrapper(convert_output_name, + QNN_TENSOR_TYPE_NATIVE, + output_qnn_data_type, + QnnQuantParamsWrapper(scale, offset), + std::move(output_shape_copy)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(convert_output_tensorwrapper)), "Failed to add tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(convert_output_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + "Convert", + {convert_input_name}, + {convert_output_name}, + {}, + do_op_validation), + "Failed to add node."); + return Status::OK(); +} + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index 7065a4b31f77e..eefde87630077 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -374,6 +374,17 @@ Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, std::vector& transposed_data); +Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, + const std::string& convert_input_name, + const std::string& convert_output_name, + Qnn_DataType_t input_qnn_data_type, + Qnn_DataType_t output_qnn_data_type, + int32_t input_offset, + float input_scale, + const std::vector& output_shape, + bool output_symmetric, + bool do_op_validation); + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/ort_api.cc b/onnxruntime/core/providers/qnn/ort_api.cc index 809593b409dad..aec09d043d2bc 100644 --- a/onnxruntime/core/providers/qnn/ort_api.cc +++ b/onnxruntime/core/providers/qnn/ort_api.cc @@ -102,6 +102,18 @@ const std::string& NodeAttrHelper::Get(const std::string& key, const std::string return def_val; } +std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + std::vector res; + for (int i = 0; i < NODE_ATTR_ITER_VAL(entry).strings_size(); i++) { + res.emplace_back(NODE_ATTR_ITER_VAL(entry).strings(i)); + } + return res; + } + + return def_val; +} + std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { const auto& values = NODE_ATTR_ITER_VAL(entry).ints(); diff --git a/onnxruntime/core/providers/qnn/ort_api.h b/onnxruntime/core/providers/qnn/ort_api.h index d25269be075de..2cb4d5c2003bc 100644 --- a/onnxruntime/core/providers/qnn/ort_api.h +++ b/onnxruntime/core/providers/qnn/ort_api.h @@ -151,6 +151,7 @@ class NodeAttrHelper { std::vector Get(const std::string& key, const std::vector& def_val) const; const std::string& Get(const std::string& key, const std::string& def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; // Convert the i() or ints() of the attribute from int64_t to 
int32_t int32_t Get(const std::string& key, int32_t def_val) const; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 65ef19f0b6c0e..c085ef7c31f0e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -36,18 +36,21 @@ const std::string kDefaultCpuBackendPath = MakeSharedLibraryPath("QnnCpu"); const std::string kDefaultGpuBackendPath = MakeSharedLibraryPath("QnnGpu"); const std::string kDefaultHtpBackendPath = MakeSharedLibraryPath("QnnHtp"); const std::string kDefaultSaverBackendPath = MakeSharedLibraryPath("QnnSaver"); +const std::string kDefaultIrBackendPath = MakeSharedLibraryPath("QnnIr"); static bool ParseBackendTypeName(std::string_view backend_type_name, std::string& backend_path) { constexpr std::string_view kCpuBackendTypeName{"cpu"}; constexpr std::string_view kGpuBackendTypeName{"gpu"}; constexpr std::string_view kHtpBackendTypeName{"htp"}; constexpr std::string_view kSaverBackendTypeName{"saver"}; + constexpr std::string_view kIrBackendTypeName{"ir"}; constexpr std::array kAllowedBackendTypeNames{ kCpuBackendTypeName, kGpuBackendTypeName, kHtpBackendTypeName, kSaverBackendTypeName, + kIrBackendTypeName, }; std::optional associated_backend_path{}; @@ -59,6 +62,8 @@ static bool ParseBackendTypeName(std::string_view backend_type_name, std::string associated_backend_path = kDefaultHtpBackendPath; } else if (backend_type_name == kSaverBackendTypeName) { associated_backend_path = kDefaultSaverBackendPath; + } else if (backend_type_name == kIrBackendTypeName) { + associated_backend_path = kDefaultIrBackendPath; } if (associated_backend_path.has_value()) { @@ -204,6 +209,51 @@ qnn::ProfilingLevel QNNExecutionProvider::GetProfilingLevelFromETWLevel(unsigned } } +static std::unique_ptr ParseSerializerBackendOptions(const ProviderOptions& provider_options_map) { + // Enable use of QNN Saver if the user provides a path the QNN Saver backend library. 
+ static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path"; + auto qnn_saver_path_pos = provider_options_map.find(QNN_SAVER_PATH_KEY); + if (qnn_saver_path_pos != provider_options_map.end()) { + LOGS_DEFAULT(VERBOSE) << "User specified QNN Saver path: " << qnn_saver_path_pos->second; + return qnn::QnnSerializerConfig::CreateSaver(qnn_saver_path_pos->second); + } + + static const std::string DUMP_QNN_IR_DLC = "dump_qnn_ir_dlc"; + auto dump_qnn_ir_dlc = ParseBoolOption(DUMP_QNN_IR_DLC, false, provider_options_map); + + static const std::string DUMP_QNN_IR_DLC_DIR = "dump_qnn_ir_dlc_dir"; + std::string qnn_ir_dlc_dir; + auto qnn_ir_dlc_dir_pos = provider_options_map.find(DUMP_QNN_IR_DLC_DIR); + if (qnn_ir_dlc_dir_pos != provider_options_map.end()) { + qnn_ir_dlc_dir = qnn_ir_dlc_dir_pos->second; + if (dump_qnn_ir_dlc) { + LOGS_DEFAULT(INFO) << "IR DLC directory: " << qnn_ir_dlc_dir; + } else { + LOGS_DEFAULT(WARNING) << "Provided a directory for dumping QNN graphs to DLC, " + << "but did not set dump_qnn_ir_dlc to 1."; + } + } + + static const std::string QNN_IR_BACKEND_PATH = "qnn_ir_backend_path"; + std::string qnn_ir_backend_path = kDefaultIrBackendPath; + auto qnn_ir_backend_path_pos = provider_options_map.find(QNN_IR_BACKEND_PATH); + if (qnn_ir_backend_path_pos != provider_options_map.end()) { + qnn_ir_backend_path = qnn_ir_backend_path_pos->second; + if (dump_qnn_ir_dlc) { + LOGS_DEFAULT(INFO) << "IR backend path: " << qnn_ir_backend_path; + } else { + LOGS_DEFAULT(WARNING) << "Provided a path to the Ir backend for dumping QNN graphs to DLC, " + << "but did not set dump_qnn_ir_dlc to 1."; + } + } + + if (dump_qnn_ir_dlc) { + return qnn::QnnSerializerConfig::CreateIr(std::move(qnn_ir_backend_path), std::move(qnn_ir_dlc_dir)); + } + + return nullptr; +} + QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map, const ConfigOptions* config_options) : IExecutionProvider{onnxruntime::kQnnExecutionProvider} { @@ -283,6 +333,8 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "Using backend path: " << backend_path; } + std::unique_ptr qnn_serializer_config = ParseSerializerBackendOptions(provider_options_map); + std::string profiling_file_path; static const std::string PROFILING_LEVEL = "profiling_level"; qnn::ProfilingLevel profiling_level = qnn::ProfilingLevel::OFF; @@ -337,15 +389,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio ParseHtpGraphFinalizationOptimizationMode(htp_graph_finalization_opt_mode_pos->second); } - // Enable use of QNN Saver if the user provides a path the QNN Saver backend library. 
- static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path"; - std::string qnn_saver_path; - auto qnn_saver_path_pos = provider_options_map.find(QNN_SAVER_PATH_KEY); - if (qnn_saver_path_pos != provider_options_map.end()) { - qnn_saver_path = qnn_saver_path_pos->second; - LOGS_DEFAULT(VERBOSE) << "User specified QNN Saver path: " << qnn_saver_path; - } - static const std::string QNN_CONTEXT_PRIORITY = "qnn_context_priority"; qnn::ContextPriority context_priority = qnn::ContextPriority::NORMAL; auto qnn_context_priority_pos = provider_options_map.find(QNN_CONTEXT_PRIORITY); @@ -464,7 +507,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio profiling_level, profiling_file_path, context_priority, - qnn_saver_path, + std::move(qnn_serializer_config), device_id_, htp_arch, soc_model}); @@ -912,7 +955,7 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod return Status::OK(); } -void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const { +void QNNExecutionProvider::InitQnnHtpGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const { if (qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP) { if (htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) { gsl::not_null htp_graph_opt_config = configs_builder.PushCustomConfig(); @@ -956,9 +999,39 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector qnn_model = std::make_unique(qnn_backend_manager_.get()); - qnn::QnnConfigsBuilder graph_configs_builder(QNN_GRAPH_CONFIG_INIT, - QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); - InitQnnGraphConfigs(graph_configs_builder); + std::vector all_graph_configs; + + qnn::QnnConfigsBuilder htp_graph_configs_builder(QNN_GRAPH_CONFIG_INIT, + QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); + InitQnnHtpGraphConfigs(htp_graph_configs_builder); + + const QnnGraph_Config_t** htp_configs = htp_graph_configs_builder.GetQnnConfigs(); + if (htp_configs) { + // Reserve enough for configs + nullptr + all_graph_configs.reserve(htp_graph_configs_builder.GetSize() + 1); + for (const QnnGraph_Config_t** config = htp_configs; *config; ++config) { + all_graph_configs.push_back(*config); + } + } + + qnn::QnnSerializerConfig* qnn_serializer_config = qnn_backend_manager_->GetQnnSerializerConfig(); + if (qnn_serializer_config) { + // We don't bother reserving here to keep the API simpler. Also note that if we're here, + // we're likely debugging and not waiting for inference. 
+ qnn_serializer_config->SetGraphName(fused_node.Name()); + const QnnGraph_Config_t** serializer_configs = qnn_serializer_config->Configure(); + if (serializer_configs) { + for (const QnnGraph_Config_t** config = serializer_configs; *config; ++config) { + all_graph_configs.push_back(*config); + } + } + } + + const QnnGraph_Config_t** all_graph_configs_ptr = nullptr; + if (!all_graph_configs.empty()) { + all_graph_configs.push_back(nullptr); + all_graph_configs_ptr = all_graph_configs.data(); + } std::string json_graph_filepath; @@ -969,7 +1042,7 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vectorComposeGraph(graph_viewer, fused_node, model_settings_, logger, - graph_configs_builder.GetQnnConfigs(), json_graph_filepath)); + all_graph_configs_ptr, json_graph_filepath)); ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs(logger)); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput(logger)); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 7769a4a453c1b..4ccb1554f8b15 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -68,7 +68,7 @@ class QNNExecutionProvider : public IExecutionProvider { void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string); - void InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const; + void InitQnnHtpGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const; qnn::ProfilingLevel GetProfilingLevelFromETWLevel(unsigned char level); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 72eb2579e9d42..fc8281ce51a1b 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -538,9 +538,18 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(j)][i].push_back(opt_value); } +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 10) || NV_TENSORRT_MAJOR > 10 + std::vector shapes_min_64(shapes_min.begin(), shapes_min.end()); + std::vector shapes_opt_64(shapes_opt.begin(), shapes_opt.end()); + std::vector shapes_max_64(shapes_max.begin(), shapes_max.end()); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min_64[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_opt_64[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_max_64[0], shape_size); +#else trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); +#endif } // Execution tensor else { @@ -627,6 +636,17 @@ Status ApplyProfileShapesFromInputTensorValue(std::vectorisShapeTensor()) { // shape tensor int shape_size = nb_dims == 0 ? 
1 : static_cast(tensor_shapes[0]); +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 10) || NV_TENSORRT_MAJOR > 10 + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + for (int j = 0; j < shape_size; j++) { + shapes_min[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN)); + shapes_max[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX)); + shapes_opt[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT)); + } + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); +#else std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); for (int j = 0; j < shape_size; j++) { shapes_min[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN)); @@ -636,6 +656,7 @@ Status ApplyProfileShapesFromInputTensorValue(std::vectorsetShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); +#endif } else { // execution tensor nvinfer1::Dims dims_min, dims_opt, dims_max; @@ -733,10 +754,18 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector 10) || NV_TENSORRT_MAJOR > 10 + std::vector shapes_min_64(shapes_min.begin(), shapes_min.end()); + std::vector shapes_opt_64(shapes_opt.begin(), shapes_opt.end()); + std::vector shapes_max_64(shapes_max.begin(), shapes_max.end()); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min_64[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_opt_64[0], shape_size); + trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_max_64[0], shape_size); +#else trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); +#endif } else { // Execution tensor nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims); for (int j = 0, end = nb_dims; j < end; ++j) { @@ -958,6 +987,7 @@ Status BindContextInput(Ort::KernelContext& ctx, switch (tensor_type) { CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -1050,6 +1080,7 @@ Status BindContextOutput(Ort::KernelContext& ctx, switch (output_type) { CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, 
uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -1119,6 +1150,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx, switch (output_type) { CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) @@ -1336,6 +1368,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv min_subgraph_size_ = info.min_subgraph_size; max_workspace_size_ = info.max_workspace_size; fp16_enable_ = info.fp16_enable; + bf16_enable_ = info.bf16_enable; + // BF16 support is primarily available on NVIDIA GPUs with the Ampere and later architectures with compute capability of 8.0 or higher. + if (bf16_enable_ && prop.major < 8) { + bf16_enable_ = false; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_bf16_enable is set, but platform doesn't support bf16."; + } int8_enable_ = info.int8_enable; if (int8_enable_) { int8_calibration_cache_name_ = info.int8_calibration_table_name; @@ -1382,7 +1420,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } force_sequential_engine_build_ = info.force_sequential_engine_build; context_memory_sharing_enable_ = info.context_memory_sharing_enable; - if (fp16_enable_) { + if (fp16_enable_ || bf16_enable_) { layer_norm_fp32_fallback_ = info.layer_norm_fp32_fallback; } build_heuristics_enable_ = info.build_heuristics_enable; @@ -1419,6 +1457,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } + const std::string bf16_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kBF16Enable); + if (!bf16_enable_env.empty()) { + bf16_enable_ = (std::stoi(bf16_enable_env) == 0 ? false : true); + } + const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); @@ -1760,6 +1803,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_min_subgraph_size: " << min_subgraph_size_ << ", trt_max_workspace_size: " << max_workspace_size_ << ", trt_fp16_enable: " << fp16_enable_ + << ", trt_bf16_enable: " << bf16_enable_ << ", trt_int8_enable: " << int8_enable_ << ", trt_int8_calibration_cache_name: " << int8_calibration_cache_name_ << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ @@ -1818,6 +1862,10 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { } ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); + if (context_memory_) { + context_memory_.reset(); + } + if (alloc_ != nullptr) { // This code is same as OrtApis::ReleaseAllocator defined in allocator_adapters.cc. // We can't get api inside destructor so that's why we duplicate the code here. 
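Usage note for the new BF16 switch above: it is driven by the "trt_bf16_enable" provider option (or the ORT_TENSORRT_BF16_ENABLE environment variable read above) and is disabled with a warning on GPUs below compute capability 8.0. A minimal sketch of enabling it through the V2 TensorRT provider options follows; the surrounding session setup is assumed and only the option key itself comes from this change:

#include <onnxruntime_cxx_api.h>

// Sketch: turn on the TensorRT BF16 path for a session (illustrative, assumes the
// usual OrtTensorRTProviderOptionsV2 flow; error handling kept minimal).
void EnableTrtBf16(Ort::SessionOptions& session_options) {
  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

  // "trt_bf16_enable" is the provider-option key this patch registers.
  const char* keys[] = {"trt_bf16_enable"};
  const char* values[] = {"1"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 1));

  session_options.AppendExecutionProvider_TensorRT_V2(*trt_options);
  api.ReleaseTensorRTProviderOptions(trt_options);
}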
@@ -2295,7 +2343,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 - network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #else network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); #endif @@ -2908,7 +2956,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 - network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #else network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); #endif @@ -2921,7 +2969,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow - if (fp16_enable_ && layer_norm_fp32_fallback_) { + if ((fp16_enable_ || bf16_enable_) && layer_norm_fp32_fallback_) { for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { auto layer = trt_network->getLayer(idx); auto next_layer = trt_network->getLayer(idx + 1); @@ -3070,7 +3118,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // Check platform availability for low precision - if (fp16_enable_) { + if (fp16_enable_ || bf16_enable_) { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) @@ -3080,7 +3128,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView #pragma warning(pop) #endif fp16_enable_ = false; - LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16"; + bf16_enable_ = false; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but platform doesn't support fast native fp16/bf16"; } } @@ -3109,15 +3158,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Set precision flags std::string trt_node_name_with_precision = fused_node.Name(); - if (fp16_enable_ && int8_enable_) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - trt_node_name_with_precision += "_fp16_int8"; - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled"; - } else if (fp16_enable_) { + if (fp16_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); trt_node_name_with_precision += "_fp16"; LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; - } else if (int8_enable_) { + } + if (bf16_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + trt_node_name_with_precision += "_bf16"; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; + } + if (int8_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); trt_node_name_with_precision += "_int8"; LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; @@ -3448,17 +3499,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // 
Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading if (context_memory_sharing_enable_) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - size_t mem_size = trt_engine->getDeviceMemorySize(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > max_ctx_mem_size_) { - max_ctx_mem_size_ = mem_size; - } + // Reset the max_ctx_mem_size_ and context_memory_ since we don't have access to the allocator here. + max_ctx_mem_size_ = 0; + context_memory_ = nullptr; #if NV_TENSORRT_MAJOR < 10 trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); #else @@ -3545,10 +3588,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, bf16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], - context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, + context_memory_sharing_enable_, &max_ctx_mem_size_, &context_memory_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix, engine_hw_compatible_, @@ -3587,6 +3630,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; + auto context_memory = trt_state->context_memory; auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; int num_inputs = static_cast(input_indexes.size()); int num_outputs = static_cast(output_indexes.size()); @@ -3746,12 +3790,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // Set precision - if (trt_state->fp16_enable && trt_state->int8_enable) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - } else if (trt_state->fp16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - } else if (trt_state->int8_enable) { + if (trt_state->int8_enable) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled"; + } + if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled"; + } + if (trt_state->bf16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] BF16 mode is enabled"; } // Set DLA (DLA can only run with FP16 or INT8) @@ -4031,8 +4080,9 @@ Status 
TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView #endif if (mem_size > *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; + *context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true /*use_reserve*/); } - trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + trt_context->setDeviceMemory((*context_memory).get()); } // Start CUDA graph capture. @@ -4231,6 +4281,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con output_info_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, + &context_memory_, &tensorrt_mu_}; *state = p.release(); return 0; @@ -4259,6 +4310,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + auto context_memory = trt_state->context_memory; int num_outputs = static_cast(output_indexes.size()); std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input @@ -4356,8 +4408,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con #endif if (mem_size > *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; + *context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true /*use_reserve*/); } - trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + trt_context->setDeviceMemory((*context_memory).get()); } // Start CUDA graph capture. 
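The context-memory changes above replace the per-call device-memory allocation with a grow-only buffer (context_memory_) shared by all execution contexts: the EP remembers the largest getDeviceMemorySize() seen so far, reallocates only when a larger engine appears, and hands the same block to setDeviceMemory(). A simplified sketch of that pattern; the types here are placeholders, not the EP's actual IAllocatorUniquePtr plumbing:

#include <cstddef>
#include <memory>

// Grow-only scratch block shared across execution contexts (illustrative only).
struct SharedContextMemory {
  size_t max_size = 0;
  std::unique_ptr<std::byte[]> block;

  void* Ensure(size_t required) {
    if (required > max_size) {
      max_size = required;
      // The real code reserves device memory through the ORT allocator instead.
      block = std::make_unique<std::byte[]>(max_size);
    }
    return block.get();
  }
};

// Per-run usage mirroring the diff's control flow:
//   trt_context->setDeviceMemory(shared.Ensure(trt_engine->getDeviceMemorySize()));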
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index f6c8f7d7dd46b..b00c800999f3b 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -23,6 +23,7 @@ static const std::string kMaxPartitionIterations = "ORT_TENSORRT_MAX_PARTITION_I static const std::string kMinSubgraphSize = "ORT_TENSORRT_MIN_SUBGRAPH_SIZE"; static const std::string kMaxWorkspaceSize = "ORT_TENSORRT_MAX_WORKSPACE_SIZE"; static const std::string kFP16Enable = "ORT_TENSORRT_FP16_ENABLE"; +static const std::string kBF16Enable = "ORT_TENSORRT_BF16_ENABLE"; static const std::string kINT8Enable = "ORT_TENSORRT_INT8_ENABLE"; static const std::string kINT8CalibrationTableName = "ORT_TENSORRT_INT8_CALIBRATION_TABLE_NAME"; static const std::string kINT8UseNativeTensorrtCalibrationTable = "ORT_TENSORRT_INT8_USE_NATIVE_CALIBRATION_TABLE"; @@ -172,6 +173,7 @@ struct TensorrtFuncState { std::unordered_map>>> input_shape_ranges; std::mutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; + bool bf16_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; bool dla_enable = false; @@ -183,6 +185,7 @@ struct TensorrtFuncState { std::vector profiles; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; + IAllocatorUniquePtr* context_memory = nullptr; std::unordered_map dynamic_range_map; bool engine_decryption_enable = false; int (*engine_decryption)(const char*, char*, size_t*) = nullptr; @@ -216,6 +219,7 @@ struct TensorrtShortFuncState { std::vector> output_info; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; + IAllocatorUniquePtr* context_memory = nullptr; std::mutex* tensorrt_mu_ptr = nullptr; }; @@ -295,6 +299,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { size_t min_subgraph_size_ = 1; size_t max_workspace_size_ = 0; bool fp16_enable_ = false; + bool bf16_enable_ = false; bool int8_enable_ = false; bool dla_enable_ = false; int dla_core_ = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index ace5bbe65fc24..1a515c37f7ecb 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -19,6 +19,7 @@ constexpr const char* kMaxPartitionIterations = "trt_max_partition_iterations"; constexpr const char* kMinSubgraphSize = "trt_min_subgraph_size"; constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kBf16Enable = "trt_bf16_enable"; constexpr const char* kInt8Enable = "trt_int8_enable"; constexpr const char* kInt8CalibTable = "trt_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "trt_int8_use_native_calibration_table"; @@ -93,6 +94,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size) .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kBf16Enable, 
info.bf16_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table) @@ -155,6 +157,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.min_subgraph_size)}, {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.max_workspace_size)}, {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {tensorrt::provider_option_names::kBf16Enable, MakeStringWithClassicLocale(info.bf16_enable)}, {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {tensorrt::provider_option_names::kInt8CalibTable, MakeStringWithClassicLocale(info.int8_calibration_table_name)}, {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.int8_use_native_calibration_table)}, @@ -222,6 +225,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kMinSubgraphSize, MakeStringWithClassicLocale(info.trt_min_subgraph_size)}, {tensorrt::provider_option_names::kMaxWorkspaceSize, MakeStringWithClassicLocale(info.trt_max_workspace_size)}, {tensorrt::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.trt_fp16_enable)}, + {tensorrt::provider_option_names::kBf16Enable, MakeStringWithClassicLocale(info.trt_bf16_enable)}, {tensorrt::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.trt_int8_enable)}, {tensorrt::provider_option_names::kInt8CalibTable, kInt8CalibTable_}, {tensorrt::provider_option_names::kInt8UseNativeCalibTable, MakeStringWithClassicLocale(info.trt_int8_use_native_calibration_table)}, @@ -319,6 +323,7 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_min_subgraph_size = internal_options.min_subgraph_size; trt_provider_options_v2.trt_max_workspace_size = internal_options.max_workspace_size; trt_provider_options_v2.trt_fp16_enable = internal_options.fp16_enable; + trt_provider_options_v2.trt_bf16_enable = internal_options.bf16_enable; trt_provider_options_v2.trt_int8_enable = internal_options.int8_enable; trt_provider_options_v2.trt_int8_calibration_table_name = copy_string_if_needed(internal_options.int8_calibration_table_name); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 139319829c210..a7c3624674dc6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -24,6 +24,7 @@ struct TensorrtExecutionProviderInfo { int min_subgraph_size{1}; size_t max_workspace_size{0}; bool fp16_enable{false}; + bool bf16_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 0d2e88d17519c..da1c2514bf6a2 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ 
b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -82,6 +82,7 @@ struct Tensorrt_Provider : Provider { info.min_subgraph_size = options.trt_min_subgraph_size; info.max_workspace_size = options.trt_max_workspace_size; info.fp16_enable = options.trt_fp16_enable != 0; + info.bf16_enable = options.trt_bf16_enable != 0; info.int8_enable = options.trt_int8_enable != 0; info.int8_calibration_table_name = options.trt_int8_calibration_table_name == nullptr ? "" : options.trt_int8_calibration_table_name; info.int8_use_native_calibration_table = options.trt_int8_use_native_calibration_table != 0; diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index fc5cce4257ebe..6849bcfc21f88 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -52,6 +52,8 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider(const for (const auto& [key, value] : config_options_map) { if (key.rfind(key_prefix, 0) == 0) { provider_options[key.substr(key_prefix.size())] = value; + } else { + provider_options["ort_session_config." + key] = value; } } diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index 178ca0b9e0515..2f34aa21c8309 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -141,7 +141,9 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { // Calculate the final value for each element in the row << " for (var col = lindex; col < cols; col += wg) {\n" - << " let value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n" + << " var value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n" + << " // max operation protects against NaN since all values should be >=0\n" + << " value = max(value, x_value_t(0.0));\n" << " setValue(row, col, row_stride, value);\n" << " }\n"; diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 6556c293f81bf..f124e90580353 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -9,8 +9,8 @@ #include "core/common/inlined_containers.h" #include #include "core/optimizer/initializer.h" -#include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" +#include "map_info.h" #include #include @@ -201,183 +201,27 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe const emscripten::val& wnn_limits, const logging::Logger& logger); -// Some ONNX ops are supported by decomposed WebNN ops. 
-const std::map> decomposed_op_map = { - {"GroupQueryAttention", - {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND", - "softmax", "transpose", "where"}}, - {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}}, - {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}}, - {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}}, - {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}}, - {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}}, - {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, - {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, -}; -// ONNX op type to WebNN op type mapping. -const std::map op_map = { - {"Abs", "abs"}, - {"Add", "add"}, - {"And", "logicalAnd"}, - {"ArgMax", "argMax"}, - {"ArgMin", "argMin"}, - {"AveragePool", "averagePool2d"}, - {"BatchNormalization", "batchNormalization"}, - {"Cast", "cast"}, - {"Ceil", "ceil"}, - {"Clip", "clamp"}, - {"Concat", "concat"}, - {"Conv", "conv2d"}, - {"ConvInteger", "conv2dInteger"}, - {"ConvTranspose", "convTranspose2d"}, - {"Cos", "cos"}, - {"CumSum", "cumulativeSum"}, - {"Div", "div"}, - {"DequantizeLinear", "dequantizeLinear"}, - {"Dropout", "identity"}, - {"DynamicQuantizeLinear", "dynamicQuantizeLinear"}, - {"Einsum", "matmul"}, - {"Elu", "elu"}, - {"Equal", "equal"}, - {"Erf", "erf"}, - {"Exp", "exp"}, - {"Expand", "expand"}, - {"Flatten", "reshape"}, - {"Floor", "floor"}, - {"Gather", "gather"}, - {"GatherElements", "gatherElements"}, - {"GatherND", "gatherND"}, - {"Gelu", "gelu"}, - {"Gemm", "gemm"}, - {"GlobalAveragePool", "averagePool2d"}, - {"GlobalMaxPool", "maxPool2d"}, - {"GlobalLpPool", "l2Pool2d"}, - {"Greater", "greater"}, - {"GreaterOrEqual", "greaterOrEqual"}, - {"GRU", "gru"}, - {"HardSigmoid", "hardSigmoid"}, - {"HardSwish", "hardSwish"}, - {"Identity", "identity"}, - {"InstanceNormalization", "instanceNormalization"}, - {"LayerNormalization", "layerNormalization"}, - {"LeakyRelu", "leakyRelu"}, - {"Less", "lesser"}, - {"LessOrEqual", "lesserOrEqual"}, - {"Log", "log"}, - {"LpPool", "l2Pool2d"}, - {"LSTM", "lstm"}, - {"MatMul", "matmul"}, - {"Max", "max"}, - {"MaxPool", "maxPool2d"}, - {"Min", "min"}, - {"Mul", "mul"}, - {"Neg", "neg"}, - {"Not", "logicalNot"}, - {"Or", "logicalOr"}, - {"Pad", "pad"}, - {"Pow", "pow"}, - {"PRelu", "prelu"}, - {"QuantizeLinear", "quantizeLinear"}, - {"Reciprocal", "reciprocal"}, - {"ReduceL1", "reduceL1"}, - {"ReduceL2", "reduceL2"}, - {"ReduceLogSum", "reduceLogSum"}, - {"ReduceLogSumExp", "reduceLogSumExp"}, - {"ReduceMax", "reduceMax"}, - {"ReduceMean", "reduceMean"}, - {"ReduceMin", "reduceMin"}, - {"ReduceProd", "reduceProduct"}, - {"ReduceSum", "reduceSum"}, - {"ReduceSumSquare", "reduceSumSquare"}, - {"Relu", "relu"}, - {"Reshape", "reshape"}, - {"Resize", "resample2d"}, - {"ScatterElements", "scatterElements"}, - {"ScatterND", "scatterND"}, - {"Shape", "slice"}, - {"Sigmoid", "sigmoid"}, - {"Sign", "sign"}, - {"Softplus", "softplus"}, - {"Softsign", "softsign"}, - {"Sin", "sin"}, - {"Slice", "slice"}, - {"Softmax", "softmax"}, - {"Split", "split"}, - {"Sqrt", "sqrt"}, - {"Squeeze", "reshape"}, - {"Sub", "sub"}, - {"Tan", "tan"}, - {"Tanh", "tanh"}, - {"Tile", "tile"}, - {"Transpose", "transpose"}, - {"Trilu", "triangular"}, - {"Unsqueeze", "reshape"}, - {"Where", "where"}, - {"Xor", 
"logicalXor"}, -}; - -// WebNN op name to its first input name mapping, only record the name that is different from "input". -// This map is used to determine the first input name of a WebNN op and is utilized by OpSupportLimits. -const std::map webnn_op_first_input_name_map = { - {"add", "a"}, - {"concat", "inputs"}, - {"div", "a"}, - {"equal", "a"}, - {"gemm", "a"}, - {"greater", "a"}, - {"greaterOrEqual", "a"}, - {"lesser", "a"}, - {"lesserOrEqual", "a"}, - {"logicalAnd", "a"}, - {"logicalNot", "a"}, - {"logicalOr", "a"}, - {"logicalXor", "a"}, - {"matmul", "a"}, - {"max", "a"}, - {"min", "a"}, - {"mul", "a"}, - {"pow", "a"}, - {"sub", "a"}, - {"where", "condition"}, -}; - // Retrieve the first input name of a WebNN op used for validating supported input data types. // WebNN ops have various first input names such as 'a', 'input', 'inputs', etc. -// Special names other than 'input' are recorded in the webnn_op_first_input_name_map. +// All WebNN op inputs are recorded in op_inputs_map. inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) { - auto it = webnn_op_first_input_name_map.find(webnn_op_type); - return (it != webnn_op_first_input_name_map.end()) ? it->second : "input"; + auto it = op_inputs_map.find(webnn_op_type); + if (it != op_inputs_map.end()) { + for (const auto& input : it->second.inputs) { + if (input.index == 0) { + return input.name; + } + } + } + return "input"; } inline std::string_view GetWebNNOpType(const std::string_view op_type) { - auto it = op_map.find(op_type); - // Return an empty string if the op_type is not listed in the op_map. - return (it != op_map.end()) ? it->second : ""; + auto it = op_inputs_map.find(op_type); + // Return an empty string if the op_type is not listed in the op_inputs_map. + return (it != op_inputs_map.end()) ? it->second.opType : ""; } -const std::map onnx_to_webnn_data_type_map = { - {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"}, - {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, - {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, - {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, - {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, -}; - -// This array contains the input/output data types of a WebNN graph that are allowed to be fallback to int32. 
-constexpr std::array supported_fallback_integer_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, -}; - bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 1924c3cb5e698..b9383a63fe307 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -30,6 +30,8 @@ class ConvOpBuilder : public BaseOpBuilder { const WebnnDeviceType device_type, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -52,18 +54,19 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, const logging::Logger& logger) { NodeAttrHelper helper(node); const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); // Add Padding. AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); std::vector pads_out; - if (node.OpType() == "Conv" || node.OpType() == "ConvInteger") { + if (op_type == "Conv" || op_type == "ConvInteger") { // Calculate explicit padding for autoPad. if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc)); pads = pads_out; } - } else if (node.OpType() == "ConvTranspose") { + } else if (op_type == "ConvTranspose") { std::vector output_shape = helper.Get("output_shape", std::vector{-1, -1}); // Appending 1's if it is ConvTranspose 1d and output shape is provided. if (output_shape.size() == 1 && is_conv1d && output_shape[0] != -1) { @@ -103,7 +106,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, options.set("padding", emscripten::val::array(GetNarrowedIntfromInt64(padding))); // Add bias if present. - if (input_defs.size() > 2) { + if (input_defs.size() > 2 && op_type != "ConvInteger") { options.set("bias", model_builder.GetOperand(input_defs[2]->Name())); } @@ -219,6 +222,8 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3; const bool is_constant_weight = Contains(initializers, weight_name); + + emscripten::val common_options = emscripten::val::object(); // Support conv1d by prepending a 1 or 2 size dimensions. if (is_conv1d) { // Reshape input. 
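With the op tables moved into map_info.h, GetWebNNOpFirstInputName above now derives the name of input 0 from op_inputs_map instead of a dedicated first-input table, falling back to "input". A self-contained sketch of that lookup; the struct and map below are simplified stand-ins for the real map_info.h types, not their actual definitions:

#include <map>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

// Simplified stand-in for an op_inputs_map entry.
struct OpInfo {
  std::string opType;                               // WebNN op name, e.g. "conv2d"
  std::vector<std::pair<int, std::string>> inputs;  // {input index, WebNN input name}
};

std::string_view FirstInputName(const std::map<std::string, OpInfo, std::less<>>& op_inputs_map,
                                std::string_view webnn_op_type) {
  if (auto it = op_inputs_map.find(webnn_op_type); it != op_inputs_map.end()) {
    for (const auto& [index, name] : it->second.inputs) {
      if (index == 0) return name;  // e.g. "a" for binary ops, "inputs" for concat
    }
  }
  return "input";  // default used by most single-input ops
}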
@@ -230,7 +235,9 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N input_shape.push_back(1); } std::vector new_shape = GetNarrowedIntfromInt64(input_shape); - input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + common_options.set("label", node.Name() + "_reshape_input"); + input = model_builder.GetBuilder().call("reshape", input, + emscripten::val::array(new_shape), common_options); weight_shape.resize(4, 1); // Ensure 4D by appending 1's if needed. strides.resize(2, 1); // Ensure 2D by appending 1's if needed. @@ -277,16 +284,14 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (!is_nhwc || !is_constant_weight) { // The weight_shape has been appended 1's, reshape weight operand. std::vector new_shape = GetNarrowedIntfromInt64(weight_shape); - emscripten::val reshape_options = emscripten::val::object(); - reshape_options.set("label", node.Name() + "_reshape_filter"); + common_options.set("label", node.Name() + "_reshape_filter"); filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape), - reshape_options); + common_options); } } - emscripten::val transpose_options = emscripten::val::object(); if (is_nhwc && !is_constant_weight) { // For NHWC preferred layout, if the weight is input: // - Transpose it from iohw -> ohwi for convTranspose. @@ -298,6 +303,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } else { perm = {0, 2, 3, 1}; // L_0231 } + emscripten::val transpose_options = emscripten::val::object(); transpose_options.set("permutation", emscripten::val::array(perm)); transpose_options.set("label", node.Name() + "_transpose_filter"); filter = model_builder.GetBuilder().call("transpose", filter, transpose_options); @@ -306,20 +312,48 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (op_type == "Conv") { output = model_builder.GetBuilder().call("conv2d", input, filter, options); } else if (op_type == "ConvInteger") { - emscripten::val x_zero_point = emscripten::val::null(); - emscripten::val w_zero_point = emscripten::val::null(); - if (input_defs.size() >= 3) { + // WebNN doesn't provide a dedicated op for ConvInteger, it can be simply decomposed by + // DequantizeLinear x, w -> Conv -> Cast (to int32) + int32_t x_type; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], x_type, logger), "Cannot get data type of input x"); + + emscripten::val x_zero_point, w_zero_point, x_scale, w_scale; + if (TensorExists(input_defs, 2)) { x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - x_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); + x_zero_point = model_builder.CreateOrGetConstant(x_type, 0); } - if (input_defs.size() >= 4) { + + // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // The x_zero_point must be a scalar and the scale input should have the same shape as the zero point input. + // So the x_scale must be a scalar too. 
+ x_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); + // Dequantize x to Float32 + common_options.set("label", node.Name() + "_dequantized_x"); + input = model_builder.GetBuilder().call("dequantizeLinear", input, x_scale, x_zero_point, + common_options); + + if (TensorExists(input_defs, 3)) { w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); + std::vector w_zero_point_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[3], w_zero_point_shape, logger), "Cannot get shape of w_zero_point"); + w_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, + GetNarrowedIntfromInt64(w_zero_point_shape)); } else { - w_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); + w_zero_point = model_builder.CreateOrGetConstant(x_type, 0); + w_scale = x_scale; } - output = model_builder.GetBuilder().call("conv2dInteger", - input, x_zero_point, filter, w_zero_point, options); + // Dequantize w to Float32 + common_options.set("label", node.Name() + "_dequantized_w"); + filter = model_builder.GetBuilder().call("dequantizeLinear", filter, w_scale, w_zero_point, + common_options); + // Conv with dequantized x and w + options.set("label", node.Name() + "_conv_dequantized_inputs"); + output = model_builder.GetBuilder().call("conv2d", input, filter, options); + + // Cast the result to int32 + common_options.set("label", node.Name() + "_cast_output"); + output = model_builder.GetBuilder().call("cast", output, emscripten::val("int32"), common_options); } else { output = model_builder.GetBuilder().call("convTranspose2d", input, filter, options); } @@ -330,12 +364,11 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N std::vector output_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape"); std::vector new_shape = GetNarrowedIntfromInt64(output_shape); - emscripten::val reshape_options = emscripten::val::object(); - reshape_options.set("label", node.Name() + "_reshape_output"); + common_options.set("label", node.Name() + "_reshape_output"); output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape), - reshape_options); + common_options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); @@ -410,7 +443,31 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + if (op_type == "ConvInteger") { + // The first decomposed op of ConvInteger is DequantizeLinear, and so + // we only need to ensure it supports the input0_type. + return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); + } else { + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + } +} + +bool ConvOpBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& output = *node.OutputDefs()[0]; + const std::string_view op_type = node.OpType(); + int32_t output_type; + if (!GetType(output, output_type, logger)) { + return false; + } + + if (op_type == "ConvInteger") { + // The last decomposed op of ConvInteger is Cast, and so + // we only need to ensure it supports the output_type. 
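The decomposition above hinges on the identity ConvInteger(x, w, x_zero_point, w_zero_point) == Cast(Conv(DequantizeLinear(x), DequantizeLinear(w)), int32) when both scales are fixed at 1.0. The following standalone NumPy sketch of that identity is illustrative only (the naive_conv2d helper and the random tensors are not part of this change; stride 1 and no padding are assumed):

import numpy as np

def naive_conv2d(x, w, out_dtype):
    # x: (N, C, H, W), w: (M, C, kH, kW); stride 1, no padding, no bias.
    n, c, h, wd = x.shape
    m, _, kh, kw = w.shape
    out = np.zeros((n, m, h - kh + 1, wd - kw + 1), dtype=out_dtype)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = x[:, :, i:i + kh, j:j + kw]
            out[:, :, i, j] = np.einsum("nchw,mchw->nm", patch, w)
    return out

rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(1, 2, 5, 5), dtype=np.uint8)
w = rng.integers(0, 256, size=(3, 2, 3, 3), dtype=np.uint8)
x_zp, w_zp = 3, 7

# ConvInteger semantics: integer convolution of the zero-point-adjusted inputs, int32 output.
ref = naive_conv2d(x.astype(np.int32) - x_zp, w.astype(np.int32) - w_zp, np.int32)

# Decomposition used above: dequantizeLinear (scale = 1.0) -> conv2d -> cast to int32.
deq_x = (x.astype(np.float32) - np.float32(x_zp)) * np.float32(1.0)
deq_w = (w.astype(np.float32) - np.float32(w_zp)) * np.float32(1.0)
dec = naive_conv2d(deq_x, deq_w, np.float32).astype(np.int32)

assert np.array_equal(ref, dec)  # exact: every intermediate sum fits in float32's integer range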
+ return IsDataTypeSupportedByOp("Cast", output_type, wnn_limits, "output", "Output", logger); + } else { + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger); + } } void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h new file mode 100644 index 0000000000000..59408ba244842 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -0,0 +1,205 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/common.h" + +/** + * This file defines mappings and structures to facilitate the translation of ONNX operations + * and data types to their corresponding WebNN representations. + * + * It includes: + * - Data type mappings between ONNX and WebNN. + * - Lists of supported fallback integer types for WebNN. + * - Decomposition of certain ONNX operations into sequences of WebNN operations. + * - Structures and maps for input index-to-name translation for ONNX to WebNN ops. + */ +namespace onnxruntime { +namespace webnn { +const std::map onnx_to_webnn_data_type_map = { + {ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"}, + {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, +}; + +// This array contains the input/output data types of a WebNN graph that are allowed to be fallback to int32. +constexpr std::array supported_fallback_integer_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_BOOL, + ONNX_NAMESPACE::TensorProto_DataType_INT8, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_UINT32, + ONNX_NAMESPACE::TensorProto_DataType_INT64, +}; + +// Some ONNX ops are supported by decomposed WebNN ops. +const std::map> decomposed_op_map = { + {"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}}, + {"GroupQueryAttention", + {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND", + "softmax", "transpose", "where"}}, + {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}}, + {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}}, + {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}}, + {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}}, + {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}}, + {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, + {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, +}; + +/** + * Represents information about an input to a WebNN operation. + * + * This structure is used to map ONNX operation inputs to their corresponding + * WebNN operation inputs. 
It contains the index of the input as specified + * in the ONNX operation and the name of the input in the WebNN operation. + * + * InputInfo::index + * The index of this input as defined in the ONNX operation specification. + * + * InputInfo::name + * The name of this input in the WebNN operation. + */ +struct InputInfo { + int index; + std::string_view name; +}; + +struct WebnnOpInfo { + std::string_view opType; + std::vector inputs; + WebnnOpInfo(std::string_view op, std::initializer_list in) + : opType(op), inputs(in) {} +}; + +/** + * Maps ONNX operation type to their corresponding WebNN operation type and input mappings. + * + * This unordered map provides a mapping between ONNX operation names (keys) and their corresponding + * WebNN operation information (values). Each value is a `WebnnOpInfo` structure that contains: + * - The WebNN operation name (`opType`). + * - A vector of `InputInfo` structures, where each `InputInfo` specifies: + * - The index of the input in the ONNX operation (`index`). + * - The corresponding input name in the WebNN operation (`name`). + * + * For the ONNX operation "Abs", it has only one "input", which is at index 0 in the "Node.InputDefs" array. + * The corresponding WebNN operation is "abs", and the input name is "input". + * + * This mapping is used to translate ONNX operations and their inputs into WebNN operations + * and their respective input names. + * + * Order: + * The sorting rule is based on character length in ascending order (for better formatting), + * and for items with the same length, they are sorted alphabetically. + */ +const std::unordered_map op_inputs_map = { + {"Cos", {"cos", {{0, "input"}}}}, + {"Abs", {"abs", {{0, "input"}}}}, + {"Elu", {"elu", {{0, "input"}}}}, + {"Erf", {"erf", {{0, "input"}}}}, + {"Exp", {"exp", {{0, "input"}}}}, + {"Log", {"log", {{0, "input"}}}}, + {"Neg", {"neg", {{0, "input"}}}}, + {"Pad", {"pad", {{0, "input"}}}}, + {"Sin", {"sin", {{0, "input"}}}}, + {"Tan", {"tan", {{0, "input"}}}}, + {"Cast", {"cast", {{0, "input"}}}}, + {"Ceil", {"ceil", {{0, "input"}}}}, + {"Gelu", {"gelu", {{0, "input"}}}}, + {"Relu", {"relu", {{0, "input"}}}}, + {"Sign", {"sign", {{0, "input"}}}}, + {"Sqrt", {"sqrt", {{0, "input"}}}}, + {"Tanh", {"tanh", {{0, "input"}}}}, + {"Tile", {"tile", {{0, "input"}}}}, + {"Clip", {"clamp", {{0, "input"}}}}, + {"Floor", {"floor", {{0, "input"}}}}, + {"Shape", {"slice", {{0, "input"}}}}, + {"Slice", {"slice", {{0, "input"}}}}, + {"Split", {"split", {{0, "input"}}}}, + {"Sub", {"sub", {{0, "a"}, {1, "b"}}}}, + {"Add", {"add", {{0, "a"}, {1, "b"}}}}, + {"ArgMax", {"argMax", {{0, "input"}}}}, + {"ArgMin", {"argMin", {{0, "input"}}}}, + {"Div", {"div", {{0, "a"}, {1, "b"}}}}, + {"Expand", {"expand", {{0, "input"}}}}, + {"Max", {"max", {{0, "a"}, {1, "b"}}}}, + {"Min", {"min", {{0, "a"}, {1, "b"}}}}, + {"Mul", {"mul", {{0, "a"}, {1, "b"}}}}, + {"Pow", {"pow", {{0, "a"}, {1, "b"}}}}, + {"Concat", {"concat", {{0, "inputs"}}}}, + {"Not", {"logicalNot", {{0, "input"}}}}, + {"Flatten", {"reshape", {{0, "input"}}}}, + {"LpPool", {"l2Pool2d", {{0, "input"}}}}, + {"Reshape", {"reshape", {{0, "input"}}}}, + {"Sigmoid", {"sigmoid", {{0, "input"}}}}, + {"Softmax", {"softmax", {{0, "input"}}}}, + {"Squeeze", {"reshape", {{0, "input"}}}}, + {"Dropout", {"identity", {{0, "input"}}}}, + {"Trilu", {"triangular", {{0, "input"}}}}, + {"Equal", {"equal", {{0, "a"}, {1, "b"}}}}, + {"Identity", {"identity", {{0, "input"}}}}, + {"Less", {"lesser", {{0, "a"}, {1, "b"}}}}, + {"MaxPool", {"maxPool2d", {{0, "input"}}}}, + 
{"ReduceL1", {"reduceL1", {{0, "input"}}}}, + {"ReduceL2", {"reduceL2", {{0, "input"}}}}, + {"Resize", {"resample2d", {{0, "input"}}}}, + {"Softplus", {"softplus", {{0, "input"}}}}, + {"Softsign", {"softsign", {{0, "input"}}}}, + {"Unsqueeze", {"reshape", {{0, "input"}}}}, + {"Or", {"logicalOr", {{0, "a"}, {1, "b"}}}}, + {"Einsum", {"matmul", {{0, "a"}, {1, "b"}}}}, + {"HardSwish", {"hardSwish", {{0, "input"}}}}, + {"LeakyRelu", {"leakyRelu", {{0, "input"}}}}, + {"MatMul", {"matmul", {{0, "a"}, {1, "b"}}}}, + {"ReduceMax", {"reduceMax", {{0, "input"}}}}, + {"ReduceMin", {"reduceMin", {{0, "input"}}}}, + {"ReduceSum", {"reduceSum", {{0, "input"}}}}, + {"Transpose", {"transpose", {{0, "input"}}}}, + {"And", {"logicalAnd", {{0, "a"}, {1, "b"}}}}, + {"CumSum", {"cumulativeSum", {{0, "input"}}}}, + {"Xor", {"logicalXor", {{0, "a"}, {1, "b"}}}}, + {"GlobalLpPool", {"l2Pool2d", {{0, "input"}}}}, + {"Greater", {"greater", {{0, "a"}, {1, "b"}}}}, + {"Reciprocal", {"reciprocal", {{0, "input"}}}}, + {"ReduceMean", {"reduceMean", {{0, "input"}}}}, + {"GlobalMaxPool", {"maxPool2d", {{0, "input"}}}}, + {"HardSigmoid", {"hardSigmoid", {{0, "input"}}}}, + {"ReduceProd", {"reduceProduct", {{0, "input"}}}}, + {"AveragePool", {"averagePool2d", {{0, "input"}}}}, + {"Gemm", {"gemm", {{0, "a"}, {1, "b"}, {2, "c"}}}}, + {"PRelu", {"prelu", {{0, "input"}, {1, "slope"}}}}, + {"ReduceLogSum", {"reduceLogSum", {{0, "input"}}}}, + {"Gather", {"gather", {{0, "input"}, {1, "indices"}}}}, + {"LessOrEqual", {"lesserOrEqual", {{0, "a"}, {1, "b"}}}}, + {"GlobalAveragePool", {"averagePool2d", {{0, "input"}}}}, + {"ReduceLogSumExp", {"reduceLogSumExp", {{0, "input"}}}}, + {"ReduceSumSquare", {"reduceSumSquare", {{0, "input"}}}}, + {"GatherND", {"gatherND", {{0, "input"}, {1, "indices"}}}}, + {"GreaterOrEqual", {"greaterOrEqual", {{0, "a"}, {1, "b"}}}}, + {"Conv", {"conv2d", {{0, "input"}, {1, "filter"}, {2, "bias"}}}}, + {"DynamicQuantizeLinear", {"dynamicQuantizeLinear", {{0, "input"}}}}, + {"GatherElements", {"gatherElements", {{0, "input"}, {1, "indices"}}}}, + {"ScatterND", {"scatterND", {{0, "input"}, {1, "indices"}, {2, "updates"}}}}, + {"Where", {"where", {{0, "condition"}, {1, "trueValue"}, {2, "falseValue"}}}}, + {"ConvTranspose", {"convTranspose2d", {{0, "input"}, {1, "filter"}, {2, "bias"}}}}, + {"QuantizeLinear", {"quantizeLinear", {{0, "input"}, {1, "scale"}, {2, "zeroPoint"}}}}, + {"ScatterElements", {"scatterElements", {{0, "input"}, {1, "indices"}, {2, "updates"}}}}, + {"LayerNormalization", {"layerNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}}}}, + {"DequantizeLinear", {"dequantizeLinear", {{0, "input"}, {1, "scale"}, {2, "zeroPoint"}}}}, + {"InstanceNormalization", {"instanceNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}}}}, + {"GRU", {"gru", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}}}}, + {"BatchNormalization", {"batchNormalization", {{0, "input"}, {1, "scale"}, {2, "bias"}, {3, "input_mean"}, {4, "input_var"}}}}, + {"LSTM", {"lstm", {{0, "input"}, {1, "weight"}, {2, "recurrentWeight"}, {3, "bias"}, {5, "initialHiddenState"}, {6, "initialCellState"}, {7, "peepholeWeight"}}}}, +}; +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index ad128fee6cc3d..d910e3ea74b57 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -201,6 +201,21 @@ 
ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags, + _In_ OrtModelCompilationOptions* ort_model_compile_options, size_t flags) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetFlags(flags)); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(flags); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* ort_model_compile_options) { API_IMPL_BEGIN @@ -217,8 +232,9 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, } static constexpr OrtCompileApi ort_compile_api = { - // NOTE: The C# bindings depend on the Api order within this struct so all additions must be at the end, - // and no functions can be removed (the implementation needs to change to return an error). + // NOTE: Application compatibility with newer versions of ORT depends on the Api order within this struct so + // all new functions must be added at the end, and no functions that already exist in an officially released version + // of ORT can be reordered or removed. &OrtCompileAPI::ReleaseModelCompilationOptions, &OrtCompileAPI::CreateModelCompilationOptionsFromSessionOptions, @@ -229,6 +245,9 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, &OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, &OrtCompileAPI::CompileModel, + // End of Version 22 - DO NOT MODIFY ABOVE + + &OrtCompileAPI::ModelCompilationOptions_SetFlags, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index b8c5211526b9d..5f11b894f2004 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -28,5 +28,7 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelBuffer, _In_ OrtModelC ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* model_compile_options, bool embed_ep_context_in_model); ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options, + size_t flags); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index d0cb092f78843..5de0f03fafc08 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -18,10 +18,11 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& session_options_.value.has_explicit_ep_context_gen_options = true; session_options_.value.ep_context_gen_options = session_options.value.GetEpContextGenerationOptions(); session_options_.value.ep_context_gen_options.enable = true; - session_options_.value.ep_context_gen_options.overwrite_existing_output_file = true; - // defaulting to false to support wider usage. 
will log WARNING if compiling model with no context nodes. - // TODO: Add ability for user to explicitly set this. - session_options_.value.ep_context_gen_options.error_if_no_compiled_nodes = false; + session_options_.value.ep_context_gen_options.error_if_output_file_exists = false; + + // defaulting to kGenerateModel to support wider usage. + session_options_.value.ep_context_gen_options.action_if_no_compiled_nodes = + EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; // Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions. ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK()); @@ -104,6 +105,15 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m return Status::OK(); } +Status ModelCompilationOptions::SetFlags(size_t flags) { + EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options; + options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS; + options.action_if_no_compiled_nodes = + (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError + : EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; + return Status::OK(); +} + const OrtSessionOptions& ModelCompilationOptions::GetSessionOptions() const { return session_options_; } diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 9238264003645..f96f0317cdaca 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -80,6 +80,13 @@ class ModelCompilationOptions { /// Status indicating potential error Status SetEpContextEmbedMode(bool embed_ep_context_in_model); + /// + /// Sets flags representing enabled boolean options defined in OrtCompileApiFlags. + /// + /// unsigned integer set to the bitwise OR of enabled flags. + /// Status indicating success or an error + Status SetFlags(size_t flags); + /// /// Returns a reference to the session options object. /// diff --git a/onnxruntime/core/util/shape_checker.h b/onnxruntime/core/util/shape_checker.h new file mode 100644 index 0000000000000..9c975275c45b9 --- /dev/null +++ b/onnxruntime/core/util/shape_checker.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/framework/tensor_shape.h" + +namespace onnxruntime { + +template +TensorShape make_shape(Args... args) { + std::initializer_list dims = {args...}; + return TensorShape(dims); +} + +// This assumes the tensor is optional, and checks whether its shape is as expected. +#define ASSERT_TENSOR_DIMS(tensor, ...) \ + if (tensor != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + const TensorShape& tensor_shape = tensor->Shape(); \ + const TensorShape& expected_shape = make_shape(__VA_ARGS__); \ + if (tensor_shape != expected_shape) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, "Input '" #tensor "' is expected to have shape ", expected_shape, \ + ", got ", tensor_shape); \ + } \ + } + +// This assumes the tensor is optional, and checks whether its shape is as expected.
+#define ASSERT_TENSOR_SHAPE(tensor, shape) \ + if (tensor != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + static_assert(std::is_same>, \ + TensorShape>::value, \ + "shape must be or refer to a TensorShape"); \ + const TensorShape& tensor_shape = tensor->Shape(); \ + if (tensor_shape != shape) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, "Input '" #tensor "' is expected to have shape ", shape, \ + ", got ", tensor_shape); \ + } \ + } + +// This assumes the tensor is optional, and checks whether its shape is shape_1 or shape_2 when it is not null. +#define ASSERT_TENSOR_SHAPE_2(tensor, shape_1, shape_2) \ + if (tensor != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + static_assert(std::is_same>, \ + TensorShape>::value, \ + "shape_1 must be or refer to a TensorShape"); \ + static_assert(std::is_same>, \ + TensorShape>::value, \ + "shape_2 must be or refer to a TensorShape"); \ + const TensorShape& tensor_shape = tensor->Shape(); \ + if (tensor_shape != shape_1 && tensor_shape != shape_2) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, "Input '" #tensor "' is expected to have shape ", shape_1, \ + " or ", shape_2, ", got ", tensor_shape); \ + } \ + } + +} // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 15c423d7285bc..e8e51db13bcd3 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -644,6 +644,7 @@ def __init__( embed_compiled_data_into_model: bool = False, external_initializers_file_path: str | os.PathLike | None = None, external_initializers_size_threshold: int = 1024, + flags: int = C.OrtCompileApiFlags.NONE, ): """ Creates a ModelCompiler instance. @@ -658,6 +659,8 @@ def __init__( initializers for non-compiled nodes. :param external_initializers_size_threshold: Defaults to 1024. Ignored if `external_initializers_file_path` is None or empty. Initializers larger than this threshold are stored in the external initializers file. + :param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of + flags in onnxruntime.OrtCompileApiFlags. """ input_model_path: str | os.PathLike | None = None input_model_bytes: bytes | None = None @@ -688,6 +691,7 @@ def __init__( embed_compiled_data_into_model, external_initializers_file_path, external_initializers_size_threshold, + flags, ) else: self._model_compiler = C.ModelCompiler( @@ -697,6 +701,7 @@ def __init__( embed_compiled_data_into_model, external_initializers_file_path, external_initializers_size_threshold, + flags, ) def compile_to_file(self, output_model_path: str | None = None): diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index 8bb7ee2098caf..4676efa13440b 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -19,7 +19,8 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr(env, sess_options, PrivateConstructorTag{}); ModelCompilationOptions& compile_options = model_compiler->model_compile_options_; @@ -38,6 +39,10 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptrTrue to embed compiled binary data into EPContext nodes. /// The file into which to store initializers for non-compiled /// nodes.
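A rough usage sketch for the new flags plumbing, assuming the wrapper above is exposed as onnxruntime.ModelCompiler and the enum as onnxruntime.OrtCompileApiFlags (the docstring above refers to the latter by that name); the constructor argument order and the model file names are assumptions, not taken from the diff:

import onnxruntime as ort

# Turn both new boolean options on: fail if no EPContext nodes were compiled,
# and fail instead of overwriting an existing output file.
flags = (
    ort.OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED
    | ort.OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS
)

compiler = ort.ModelCompiler(
    ort.SessionOptions(),
    "model.onnx",  # hypothetical input model path (the wrapper above also accepts bytes)
    flags=flags,
)
compiler.compile_to_file("model.compiled.onnx")  # hypothetical output path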
+ /// Flags from OrtCompileApiFlags /// Ignored if 'external_initializers_file_path' is empty. /// Initializers with a size greater than this threshold are dumped into the external file. /// A Status indicating error or success. @@ -44,7 +45,8 @@ class PyModelCompiler { std::string&& input_model_path_or_bytes, bool input_model_is_path, bool embed_compiled_data_into_model = false, const std::string& external_initializers_file_path = {}, - size_t external_initializers_size_threshold = 1024); + size_t external_initializers_size_threshold = 1024, + size_t flags = 0); // Note: Creation should be done via Create(). This constructor is public so that it can be called from // std::make_shared(). diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index aa2c0cc6a0f86..5c389a85e5316 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -624,6 +624,14 @@ static std::shared_ptr CreateExecutionProviderFactory } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_fp16_enable' should be 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "trt_bf16_enable") { + if (option.second == "True" || option.second == "true") { + params.trt_bf16_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.trt_bf16_enable = false; + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_bf16_enable' should be 'True' or 'False'. Default value is 'False'.\n"); + } } else if (option.first == "trt_int8_enable") { if (option.second == "True" || option.second == "true") { params.trt_int8_enable = true; @@ -2782,6 +2790,11 @@ including arg name, arg type (contains both type and shape).)pbdoc") .value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested) .export_values(); + py::enum_(m, "OrtCompileApiFlags", py::arithmetic()) + .value("NONE", OrtCompileApiFlags_NONE) + .value("ERROR_IF_NO_NODES_COMPILED", OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) + .value("ERROR_IF_OUTPUT_FILE_EXISTS", OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS); + py::class_(m, "ModelCompiler", R"pbdoc(This is the class used to compile an ONNX model.)pbdoc") .def(py::init([](const PySessionOptions& sess_options, @@ -2789,14 +2802,16 @@ including arg name, arg type (contains both type and shape).)pbdoc") bool is_path, bool embed_compiled_data_into_model = false, std::string external_initializers_file_path = {}, - size_t external_initializers_size_threshold = 1024) { + size_t external_initializers_size_threshold = 1024, + size_t flags = OrtCompileApiFlags_NONE) { #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr result; OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options, std::move(path_or_bytes), is_path, embed_compiled_data_into_model, external_initializers_file_path, - external_initializers_size_threshold)); + external_initializers_size_threshold, + flags)); return result; #else ORT_UNUSED_PARAMETER(sess_options); @@ -2805,6 +2820,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_UNUSED_PARAMETER(embed_compiled_data_into_model); ORT_UNUSED_PARAMETER(external_initializers_file_path); ORT_UNUSED_PARAMETER(external_initializers_size_threshold); + ORT_UNUSED_PARAMETER(flags); ORT_THROW("Compile API is not supported in this build."); #endif })) diff --git a/onnxruntime/python/tools/qnn/preprocess.py b/onnxruntime/python/tools/qnn/preprocess.py new file mode 100644 index 
0000000000000..b7ddf1de9dc34 --- /dev/null +++ b/onnxruntime/python/tools/qnn/preprocess.py @@ -0,0 +1,139 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""Provide entry point to preprocess ONNX model especially for QNN.""" + +import argparse +import pathlib + +import onnx + +from onnxruntime.quantization.execution_providers import qnn + + +def _parse_arguments(): + """Parse cmdline arguments.""" + parser = argparse.ArgumentParser(description="Arguments for QNN model preprocess.") + + parser.add_argument("--input_model_path", "-i", required=True, help="Path to the input ONNX model.") + parser.add_argument("--output_model_path", "-o", required=True, help="Path to the output ONNX model.") + + # Save preprocessed model with external data. + parser.add_argument( + "--save_as_external_data", + action="store_true", + help="Whether the output model would be saved with external data.", + ) + parser.add_argument( + "--all_tensors_to_one_file", + action="store_true", + help="Whether to save all external data in one file or save each tensor to a file named with the tensor name.", + ) + parser.add_argument( + "--external_data_location", + help="Filename of the external file where all tensors are saved. The path is relative to the model path.", + ) + parser.add_argument( + "--external_data_size_threshold", + default=1024, + type=int, + help="Tensors with data size larger than this threshold are converted to external data.", + ) + parser.add_argument( + "--external_data_convert_attribute", + action="store_true", + help="Whether to save all tensors, including attribute tensors, to external data.", + ) + + # Preprocess options. + parser.add_argument( + "--fuse_layernorm", + action="store_true", + help="Whether to fuse matched sequences into LayerNormalization nodes if possible.", + ) + + # I/O layouts. + parser.add_argument( + "--inputs_to_make_channel_last", + nargs="+", + default=None, + help="List of graph input names to be transposed into channel-last.", + ) + + parser.add_argument( + "--outputs_to_make_channel_last", + nargs="+", + default=None, + help="List of graph output names to be transposed into channel-last.", + ) + + return parser.parse_args() + + +def qnn_preprocess_model( + model_input: str | pathlib.Path | onnx.ModelProto, + model_output: str | pathlib.Path, + fuse_layernorm: bool = False, + save_as_external_data: bool = False, + all_tensors_to_one_file: bool = False, + external_data_location: str | None = None, + external_data_size_threshold: int = 1024, + external_data_convert_attribute: bool = False, + inputs_to_make_channel_last: list[str] | None = None, + outputs_to_make_channel_last: list[str] | None = None, +) -> bool: + """Preprocess ONNX model for QNN. + + Args: + model_input: A path or ONNX ModelProto specifying the model to be preprocessed. + model_output: A path specifying where the preprocessed model is to be saved. + fuse_layernorm: A bool specifying whether to fuse the matched sequence into a single LayerNormalization node. + Defaults to False. + save_as_external_data: A bool specifying whether to save model with external data. Defaults to False. + all_tensors_to_one_file: A bool specifying whether to save all external data in one file or save each tensor to + a file named with the tensor name.
This argument is effective only when `save_as_external_data` is True. + Defaults to False. + external_data_location: A str specifying where to save the external data. The path is relative to the model + path. This argument is effective only when `save_as_external_data` is True. Defaults to the model name. + external_data_size_threshold: An int specifying the threshold of data size for tensors to be saved as external + data. This argument is effective only when `save_as_external_data` is True. Defaults to 1024. + external_data_convert_attribute: A bool specifying whether to save all tensors including attributes as external + data. This argument is effective only when `save_as_external_data` is True. Defaults to False. + inputs_to_make_channel_last: A list of strs specifying graph input names to be transposed into channel-last. + Defaults to None. + outputs_to_make_channel_last: A list of strs specifying graph output names to be transposed into channel-last. + Defaults to None. + + Returns: + A bool indicating whether the model is modified. + """ + return qnn.qnn_preprocess_model( + model_input, + model_output, + fuse_layernorm=fuse_layernorm, + save_as_external_data=save_as_external_data, + all_tensors_to_one_file=all_tensors_to_one_file, + external_data_location=external_data_location, + external_data_size_threshold=external_data_size_threshold, + external_data_convert_attribute=external_data_convert_attribute, + inputs_to_make_channel_last=inputs_to_make_channel_last, + outputs_to_make_channel_last=outputs_to_make_channel_last, + ) + + +if __name__ == "__main__": + args = _parse_arguments() + qnn_preprocess_model( + args.input_model_path, + args.output_model_path, + fuse_layernorm=args.fuse_layernorm, + save_as_external_data=args.save_as_external_data, + all_tensors_to_one_file=args.all_tensors_to_one_file, + external_data_location=args.external_data_location, + external_data_size_threshold=args.external_data_size_threshold, + external_data_convert_attribute=args.external_data_convert_attribute, + inputs_to_make_channel_last=args.inputs_to_make_channel_last, + outputs_to_make_channel_last=args.outputs_to_make_channel_last, + ) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_spacetodepth.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_spacetodepth.py new file mode 100644 index 0000000000000..ce92b3e2a1d76 --- /dev/null +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_spacetodepth.py @@ -0,0 +1,162 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""Define SpaceToDepth fusion.""" + +import onnx + +from ... import fusions, onnx_model + + +class FusionSpaceToDepth(fusions.Fusion): + """Fusion for SpaceToDepth.""" + + def __init__(self, model: onnx_model.ONNXModel): + """Initialize. + + Args: + model: An onnx_model.ONNXModel instance. + """ + super().__init__(model, "SpaceToDepth", "Reshape") + + def _fuse_yolo( + self, + node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """Fuse for an early version of YOLO.
+ + Pattern: + + | [N, C, H, W] + Reshape + | [N, C, H/blk, blk, W/blk, blk] + Transpose + | [N, C, H/blk, W/blk, blk, blk] + Reshape + | [N, C, H/blk * W/blk, blk * blk] + Transpose + | [N, C, blk * blk, H/blk * W/blk] + Reshape + | [N, C, blk * blk, H/blk, W/blk] + Transpose + | [N, blk * blk, C, H/blk, W/blk] + Reshape + | [N, blk * blk * C, H/blk, W/blk] + + This sequence can be fused into a single SpaceToDepth with blocksize `blk`. Note that unlike DepthToSpace + supporting DCR or CRD mode, SpaceToDepth only supports DCR mode in its latest opset version (13), which matches + the pattern here. + """ + reshape_node1 = node + + def get_target_child(parent_node, target_op_type): + """Get target child of given node.""" + if parent_node.output[0] not in input_name_to_nodes: + return None + + children = input_name_to_nodes[parent_node.output[0]] + if len(children) > 1 or children[0].op_type != target_op_type: + return None + + return children[0] + + if ( + (transpose_node1 := get_target_child(reshape_node1, "Transpose")) is None + or (reshape_node2 := get_target_child(transpose_node1, "Reshape")) is None + or (transpose_node2 := get_target_child(reshape_node2, "Transpose")) is None + or (reshape_node3 := get_target_child(transpose_node2, "Reshape")) is None + or (transpose_node3 := get_target_child(reshape_node3, "Transpose")) is None + or (reshape_node4 := get_target_child(transpose_node3, "Reshape")) is None + ): + return False + + def get_tensor_shape(tensor_name): + """Get shape for given tensor name.""" + tensor_type = self.model.get_tensor_type(tensor_name) + if not tensor_type: + return None + + tensor_shape = self.tensor_shape_to_list(tensor_type) + if not tensor_shape: + return None + + return tensor_shape + + if ( + (input_shape := get_tensor_shape(reshape_node1.input[0])) is None + or (reshape_shape1 := get_tensor_shape(reshape_node1.output[0])) is None + or (reshape_shape2 := get_tensor_shape(reshape_node2.output[0])) is None + or (reshape_shape3 := get_tensor_shape(reshape_node3.output[0])) is None + or (reshape_shape4 := get_tensor_shape(reshape_node4.output[0])) is None + ): + return False + + transpose_perm1 = self.get_node_attribute(transpose_node1, "perm") + transpose_perm2 = self.get_node_attribute(transpose_node2, "perm") + transpose_perm3 = self.get_node_attribute(transpose_node3, "perm") + + # Check rank. + if ( + len(input_shape) != 4 + or len(reshape_shape1) != 6 + or len(reshape_shape2) != 4 + or len(reshape_shape3) != 5 + or len(reshape_shape4) != 4 + ): + return False + + # Check shape and perm. 
+ batch, channel, height, width = input_shape + blocksize = reshape_shape1[3] + if ( + reshape_shape1 != [batch, channel, height // blocksize, blocksize, width // blocksize, blocksize] + or transpose_perm1 != [0, 1, 2, 4, 3, 5] + or reshape_shape2 != [batch, channel, (height // blocksize) * (width // blocksize), blocksize**2] + or transpose_perm2 != [0, 1, 3, 2] + or reshape_shape3 != [batch, channel, blocksize**2, height // blocksize, width // blocksize] + or transpose_perm3 != [0, 2, 1, 3, 4] + or reshape_shape4 != [batch, blocksize**2 * channel, height // blocksize, width // blocksize] + ): + return False + + self.nodes_to_remove.extend( + [ + reshape_node1, + transpose_node1, + reshape_node2, + transpose_node2, + reshape_node3, + transpose_node3, + reshape_node4, + ] + ) + + s2d_node = onnx.helper.make_node( + self.fused_op_type, + name=self.create_unique_node_name(), + inputs=[reshape_node1.input[0]], + outputs=[reshape_node4.output[0]], + blocksize=blocksize, + ) + self.nodes_to_add.append(s2d_node) + + return True + + def fuse( + self, + node: onnx.NodeProto, + input_name_to_nodes: dict[str, list[onnx.NodeProto]], + output_name_to_node: dict[str, onnx.NodeProto], + ): + """Fuse a sequence of Reshape and Transpose nodes into a single SpaceToDepth node. + + Args: + node: An onnx.NodeProto matching the specified search type (i.e., Reshape). + input_name_to_nodes: A dict mapping tensor name to consumed nodes. + output_name_to_node: A dict mapping tensor name to produced node. + """ + self._fuse_yolo(node, input_name_to_nodes, output_name_to_node) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index 85f5d967f9ee3..44ff7e4aba10b 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -12,7 +12,9 @@ from ...fusions import FusionGelu, FusionLayerNormalization from ...onnx_model import ONNXModel +from ...quant_utils import save_and_reload_model_with_shape_infer from .fusion_lpnorm import FusionLpNormalization +from .fusion_spacetodepth import FusionSpaceToDepth def qnn_preprocess_model( @@ -83,6 +85,7 @@ def qnn_preprocess_model( """ modified = False model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input) + model = save_and_reload_model_with_shape_infer(model) onnx_model = ONNXModel(model) # Fuse Erf sequence into a single Gelu @@ -95,6 +98,11 @@ def qnn_preprocess_model( if fusion_lpnorm.apply(): modified = True + # Fuse Reshape/Transpose sequence into a single SpaceToDepth. + fusion_s2d = FusionSpaceToDepth(onnx_model) + if fusion_s2d.apply(): + modified = True + # Optionally, fuse ReduceMean sequence into a single LayerNormalization node. 
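As a sanity check of the pattern documented in FusionSpaceToDepth._fuse_yolo above, the following standalone NumPy sketch (illustrative only, not part of the patch) replays the matched Reshape/Transpose chain and compares it with the ONNX SpaceToDepth reference computation for blocksize blk:

import numpy as np

n, c, h, w, blk = 1, 2, 4, 6, 2
x = np.arange(n * c * h * w, dtype=np.float32).reshape(n, c, h, w)

# The matched 7-node chain from the fusion docstring.
y = x.reshape(n, c, h // blk, blk, w // blk, blk)
y = y.transpose(0, 1, 2, 4, 3, 5)
y = y.reshape(n, c, (h // blk) * (w // blk), blk * blk)
y = y.transpose(0, 1, 3, 2)
y = y.reshape(n, c, blk * blk, h // blk, w // blk)
y = y.transpose(0, 2, 1, 3, 4)
y = y.reshape(n, blk * blk * c, h // blk, w // blk)

# ONNX SpaceToDepth reference computation (blocks-outer, channels-inner layout).
ref = x.reshape(n, c, h // blk, blk, w // blk, blk)
ref = ref.transpose(0, 3, 5, 1, 2, 4)
ref = ref.reshape(n, blk * blk * c, h // blk, w // blk)

assert np.array_equal(y, ref)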
if fuse_layernorm: onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx") diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index 5428898b1c642..c7a832420203d 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -904,7 +904,9 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis kwargs["N"] = cols kwargs["bits"] = bits kwargs["block_size"] = self.config.block_size - if self.config.accuracy_level is not None: + + # Do not output accuracy_level if it is 0 since the attribute is optional and is not supported by most EPs. + if self.config.accuracy_level: kwargs["accuracy_level"] = self.config.accuracy_level matmul_qbit_node = onnx.helper.make_node( diff --git a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py index 1180945d5b5dc..5183ae9a72246 100644 --- a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py +++ b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py @@ -114,6 +114,8 @@ def trt_data_type_to_onnx_data_type(self, trt_data_type): return TensorProto.FLOAT elif trt_data_type == trt.DataType.HALF: return TensorProto.FLOAT16 + elif trt_data_type == trt.DataType.BF16: + return TensorProto.BFLOAT16 elif trt_data_type == trt.DataType.INT8: return TensorProto.INT8 elif trt_data_type == trt.DataType.INT32: @@ -122,6 +124,8 @@ def trt_data_type_to_onnx_data_type(self, trt_data_type): return TensorProto.BOOL elif trt_data_type == trt.DataType.UINT8: return TensorProto.UINT8 + elif trt_data_type == trt.DataType.INT64: + return TensorProto.INT64 else: return TensorProto.UNDEFINED diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 8eb2afb3db896..ed89d00bdc069 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1447,7 +1447,7 @@ def add_output_qk_to_mha(model: OnnxModel, dtype: int = 0, skip_node_idxs: list[ return model -def fix_past_sequence_length(model: ModelProto): +def fix_past_sequence_length(model: OnnxModel): # Modify total_sequence_length = past_sequence_length + curr_sequence_length subgraph to calculate # past_sequence_length from the new `past_sequence_length` input of size 1D and type int32 instead of # from `past_key_self_0` since DecoderMaskedMultiHeadAttention (DMMHA) uses buffer sharing and @@ -1480,56 +1480,119 @@ def fix_past_sequence_length(model: ModelProto): # | # Add + # Constant names to be used + past_seq_len_name = "past_sequence_length" + past_seq_len_int32 = "past_seq_len_int32" + past_seq_len_int64 = "past_seq_len_int64" + node = list(filter(lambda n: n.op_type == "LayerNormalization", model.model.graph.node))[0] # noqa: RUF015 - base_path = model.match_parent_path( + base_path_hf = model.match_parent_path( + node, + ["Add", "Gather", "Tile", "Expand", "Unsqueeze", "Range"], + [0, 1, 1, 0, 0, 0], + ) + base_path_oai = model.match_parent_path( node, ["Add", "Slice"], [0, 1], ) - if base_path is None: + if base_path_hf is not None: + base_path = base_path_hf + elif base_path_oai is not None: + base_path = base_path_oai + else: + logger.info("Cannot identify base path for fixing past_sequence_length subgraph") 
return + base_node = base_path[-1] - left_path = model.match_parent_path( - base_path[-1], - ["Unsqueeze", "Add", "Gather", "Shape"], - [2, 0, 0, 0], - ) - right_path = model.match_parent_path( - base_path[-1], - ["Unsqueeze", "Gather", "Shape"], - [1, 0, 0], - ) - long_right_path = model.match_parent_path( - base_path[-1], - ["Unsqueeze", "Gather", "Shape", "Reshape", "Transpose"], - [1, 0, 0, 0, 0], - ) - if left_path is None or right_path is None or left_path[-2:] != right_path[-2:]: - return + if base_node.op_type == "Range": + # Hugging Face implementation + range_node = base_path[-1] + + gather_path = model.match_parent_path( + range_node, + ["Gather", "Shape"], + [0, 0], + ) + if gather_path is None: + logger.info("Cannot identify gather path for fixing past_sequence_length subgraph") + return + + add_path = model.match_parent_path( + range_node, + ["Add", "Gather", "Shape"], + [1, 0, 0], + ) + if add_path is None: + logger.info("Cannot identify add path for fixing past_sequence_length subgraph") + return + add_node = add_path[0] + + if gather_path != add_path[1:]: + logger.info("Gather path and add path do not share the same nodes for calculating the past_sequence_length") + return + + # Remove `past_key_self_0 --> Shape --> Gather` connection + constant_in_gather = list(filter(lambda n: n.output[0] == gather_path[0].input[1], model.model.graph.node))[0] # noqa: RUF015 + model.model.graph.node.remove(constant_in_gather) + model.model.graph.node.remove(gather_path[0]) + model.model.graph.node.remove(gather_path[1]) + + # Add `past_seq_len_int64` as an input name to existing nodes + range_node.input[0] = past_seq_len_int64 + add_node.input[0] = past_seq_len_int64 - # Remove `past_key_self_0 --> [Transpose --> Reshape] --> Shape --> Gather` connection - # where `Transpose --> Reshape` part may or may not exist. The OpenAI implementation of - # Whisper has an extra `Transpose --> Reshape` connection to remove. 
- constant_node = list(filter(lambda n: n.output[0] == left_path[-2].input[1], model.model.graph.node))[0] # noqa: RUF015 - model.model.graph.node.remove(left_path[-2]) - model.model.graph.node.remove(left_path[-1]) - model.model.graph.node.remove(constant_node) - if long_right_path is not None: - # Remove `Transpose --> Reshape` part - model.model.graph.node.remove(long_right_path[-2]) - model.model.graph.node.remove(long_right_path[-1]) + else: + # OpenAI implementation + input_ids_path = model.match_parent_path( + base_node, + ["Unsqueeze", "Add", "Gather", "Shape", "Reshape", "Transpose"], + [2, 0, 0, 0, 0, 0], + ) + if input_ids_path is None: + logger.info("Cannot identify input_ids path for fixing past_sequence_length subgraph") + return + add_node = input_ids_path[1] + + past_key_path = model.match_parent_path( + base_node, + ["Unsqueeze", "Gather", "Shape", "Reshape", "Transpose"], + [1, 0, 0, 0, 0], + ) + if past_key_path is None: + logger.info("Cannot identify past_key path for fixing past_sequence_length subgraph") + return + unsqueeze_node = past_key_path[0] + + if input_ids_path[2:] != past_key_path[1:]: + logger.info( + "The input_ids path and past_key path do not share the same nodes for calculating the past_sequence_length" + ) + return + + # Remove `past_key_self_0 --> Transpose --> Reshape --> Shape --> Gather` connection + constant_in_gather = list(filter(lambda n: n.output[0] == past_key_path[1].input[1], model.model.graph.node))[0] # noqa: RUF015 + model.model.graph.node.remove(constant_in_gather) + constant_in_reshape = list(filter(lambda n: n.output[0] == past_key_path[-2].input[1], model.model.graph.node))[ # noqa: RUF015 + 0 + ] + model.model.graph.node.remove(constant_in_reshape) + model.model.graph.node.remove(past_key_path[1]) + model.model.graph.node.remove(past_key_path[2]) + model.model.graph.node.remove(past_key_path[3]) + model.model.graph.node.remove(past_key_path[4]) + + # Add `past_seq_len_int64` as an input name to existing nodes + unsqueeze_node.input[0] = past_seq_len_int64 + add_node.input[0] = past_seq_len_int64 # Add `past_sequence_length` as model input - past_seq_len_name = "past_sequence_length" model.model.graph.input.append( onnx.helper.make_tensor_value_info(past_seq_len_name, TensorProto.INT32, shape=[1]), ) # Add `past_sequence_length --> Squeeze --> Cast` connection - past_seq_len_int32 = "past_seq_len_int32" - past_seq_len_int64 = "past_seq_len_int64" - squeeze_node = onnx.helper.make_node( "Squeeze", inputs=[past_seq_len_name], @@ -1546,14 +1609,9 @@ def fix_past_sequence_length(model: ModelProto): ) cast_output = onnx.helper.make_tensor_value_info(past_seq_len_int64, TensorProto.INT64, shape=[]) - model.model.graph.value_info.extend([squeeze_output, cast_output]) - - # Add `past_seq_len_int64` as an input name to existing nodes - left_path[1].input[0] = past_seq_len_int64 - right_path[0].input[0] = past_seq_len_int64 - # Add new nodes to graph model.model.graph.node.extend([squeeze_node, cast_node]) + model.model.graph.value_info.extend([squeeze_output, cast_output]) model.topological_sort() return model, past_seq_len_name diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 5e1d491daae23..08f8691d8b2b5 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -663,12 +663,12 @@ def create_attention_node( first_input: str, output: str, add_qk_str: str = "", + causal: bool = 
False, past_k: str = "", past_v: str = "", present_k: str = "", present_v: str = "", scale: float | None = None, - causal: bool = False, ) -> NodeProto | None: """Create an Attention node. @@ -685,12 +685,12 @@ def create_attention_node( first_input (str): first input name output (str): output name add_qk_str (str): name of Add node after Q x K' + causal: whether it is uni-directional mask. past_k (str): name of input for past K value past_v (str): name of input for past V value present_k (str): name of output to store present K value present_v (str): name of output to store present V value scale: scale before softmax - causal: whether it is uni-directional mask. Returns: Union[NodeProto, None]: the node created or None if failed. diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py index 45bbfa94f6aa2..76dfeb76e4e8d 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py @@ -6,7 +6,7 @@ import numpy as np from fusion_attention import AttentionMask, FusionAttention -from onnx import TensorProto, helper +from onnx import helper from onnx_model import OnnxModel logger = logging.getLogger(__name__) @@ -26,115 +26,9 @@ def __init__( ): super().__init__(model, hidden_size, num_heads, attention_mask) - def check_runtime_shape_path( - self, - reshape_qkv_2, - reshape_qkv_1, - reshape_q_2, - reshape_k_2, - reshape_v_2, - root_input, - ): - concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) - if concat_qkv_2_path is None: - return False - concat_qkv_2 = concat_qkv_2_path[0] - - reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) - reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) - if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None: - return False - - _, gather_1, shape_1 = reshape_qkv_2_path_1 - _, gather_2, shape_2 = reshape_qkv_2_path_2 - - if shape_1.input[0] != root_input or shape_2.input[0] != root_input: - return False - - reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0]) - reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0]) - if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None: - return False - if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name: - return False - - reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) - reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) - reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) - if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None: - return False - - mul_q = reshape_q_2_path[-1] - mul_k = reshape_k_2_path[-1] - mul_v = reshape_v_2_path[-1] - - gather_1_out = gather_1.output[0] - if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: - return False - - return True - - def check_runtime_shape_path_openai( - self, - reshape_qkv_2, - matmul_qkv, - add_qk, - matmul_qk, - add_q, - ): - reshape_qkv_path = self.model.match_parent_path( - reshape_qkv_2, ["Concat", "Slice", "Shape", "Transpose"], [1, 0, 0, 0] - 
) - if reshape_qkv_path is None or reshape_qkv_path[-1].input[0] != matmul_qkv.output[0]: - return False - - matmul_qk_path_1 = self.model.match_parent_path( - matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0] - ) - matmul_qk_path_2 = self.model.match_parent_path( - matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0] - ) - if matmul_qk_path_1 is None or matmul_qk_path_2 is None: - return False - - mul_1 = matmul_qk_path_1[0] - mul_2 = matmul_qk_path_2[0] - if mul_1.input[1] != mul_2.input[1]: - return False - if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]: - return False - - # For decoder attentions only - if add_qk is not None: - add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1]) - if add_qk_path is None: - return False - slice_q_path_1 = self.model.match_parent_path( - add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0] - ) - slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) - if slice_q_path_1 is None and slice_q_path_2 is None: - return False - _, unsqueeze_1, _, _ = slice_q_path_1 - unsqueeze_2, _, _ = slice_q_path_2 - if unsqueeze_1.input[0] != unsqueeze_2.input[0]: - return False - if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]: - return False - - return True - def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): - # Track if fusion is occurring for OpenAI implementation of Whisper - model_impl_openai = False - # SkipLayerNormalization has two inputs, and one of them is the root input for attention. qkv_nodes = self.model.match_parent_path( - normalize_node, - ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], - [1, 1, 0, 0, 0, 0], - ) - qkv_nodes_openai = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, 1, 0, 0, 0], @@ -143,32 +37,21 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ( add_out, matmul_out, - reshape_qkv_2, - transpose_qkv, - reshape_qkv_1, - matmul_qkv, - ) = qkv_nodes - elif qkv_nodes_openai is not None: - qkv_nodes = qkv_nodes_openai - ( - add_out, - matmul_out, - reshape_qkv_2, + reshape_qkv, transpose_qkv, matmul_qkv, ) = qkv_nodes - # Set model implementation to openai - model_impl_openai = True else: + logger.debug("fuse_attention: failed to match qkv path") return other_inputs = [] - for input in normalize_node.input: - if input not in output_name_to_node: + for input_ in normalize_node.input: + if input_ not in output_name_to_node: continue - if input == qkv_nodes[0].output[0]: + if input_ == qkv_nodes[0].output[0]: continue - other_inputs.append(input) + other_inputs.append(input_) if len(other_inputs) != 1: return root_input = other_inputs[0] @@ -185,9 +68,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization """ skip_layernorm = output_name_to_node[root_input] - # For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose + # For some attention blocks, the end SkipLayerNormalization node may point to another node whose # child is the LayerNormalization node. 
- if skip_layernorm.op_type == "Add": + if skip_layernorm.op_type in {"Add", "Clip"}: skip_layernorm = self.model.get_children(skip_layernorm)[0] for output in skip_layernorm.output: if not output: @@ -201,304 +84,203 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): graph_input_names = {node.name for node in self.model.graph().input} graph_output_names = {node.name for node in self.model.graph().output} - v_nodes = self.model.match_parent_path( - matmul_qkv, - ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 0, 0, None], - ) - v_nodes_openai = self.model.match_parent_path( + v_nodes_past_or_present = self.model.match_parent_path( matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None], ) - v_nodes_with_past_self_attn = self.model.match_parent_path( - # Decoder attention with past value concatenated before MatMul + v_nodes_with_past = self.model.match_parent_path( matmul_qkv, - ["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 1, 0, 0, None], + ["Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 1, 0, 0, None], ) - v_nodes_with_past_cross_attn = self.model.match_parent_path( - # Decoder attention with past value directly used in MatMul - matmul_qkv, - ["Reshape"], - [1], - ) - v_nodes_with_past_cross_attn_openai = self.model.match_parent_path( + v_nodes_past_only_oai = self.model.match_parent_path( matmul_qkv, ["Transpose", "Reshape", "Reshape", "Transpose"], [1, 0, 0, 0], ) past_v, present_v = "", "" - reshape_v_2, add_v = None, None - if v_nodes is not None: - (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes - # For initial pass through encoder-decoder_with_past to get starting past values (beam search) - present_v = transpose_v.output[0] - elif v_nodes_openai is not None: - v_nodes = v_nodes_openai - (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes - # For initial pass through encoder-decoder_with_past to get starting past values (beam search) - - # Find the child path to access the correct present_v values - # Openai impl provides present/past v values in 3D format - # whereas ort MultiHeadAttention expects v values in 4D, hence the - # additional Reshape and Transpose nodes are added - # For encoder attention types - # Add -> Reshape -> Transpose -> Present_V - reshape_path = self.model.match_child_path( - add_v, - ["Reshape", "Transpose"], - exclude=[reshape_v_1], - ) - # For decoder attention types - # add_v_node Reshape <- Transpose <-Past_V - # \ / - # \ / - # -> Concat <- - # | - # |--> Reshape -> Transpose -> Present_V - concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"]) - if reshape_path is not None: - (_, transpose_add_v) = reshape_path - if transpose_add_v.output[0] in graph_output_names: - present_v = transpose_add_v.output[0] - if concat_path is not None: - (concat_v, _, transpose_concat_v) = concat_path - if transpose_concat_v.output[0] in graph_output_names: - present_v = transpose_concat_v.output[0] - concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0]) - _, transpose_concat_v_in = concat_nodes - past_v = transpose_concat_v_in.input[0] - elif v_nodes_with_past_self_attn is not None: - (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn - v_nodes = v_nodes_with_past_self_attn + v_nodes, add_v, matmul_v = [], None, None + if v_nodes_past_or_present is not None: + v_nodes = v_nodes_past_or_present + (transpose_v, reshape_v, add_v, matmul_v) = v_nodes + + 
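For readers less familiar with the matching helpers this fusion leans on, here is a small self-contained sketch of OnnxModel.match_parent_path. The toy graph is invented for illustration, and the packaged import path is an assumption (inside this directory the module is imported simply as onnx_model):

from onnx import TensorProto, helper

from onnxruntime.transformers.onnx_model import OnnxModel  # assumed packaged import path

# Toy graph MatMul -> Add -> Softmax, shaped like the qk path matched above.
matmul = helper.make_node("MatMul", ["q", "k"], ["qk"], name="matmul_qk")
add = helper.make_node("Add", ["qk", "mask"], ["qk_masked"], name="add_qk")
softmax = helper.make_node("Softmax", ["qk_masked"], ["probs"], name="softmax_qk")
graph = helper.make_graph(
    [matmul, add, softmax],
    "toy",
    [helper.make_tensor_value_info(name, TensorProto.FLOAT, ["d", "d"]) for name in ("q", "k", "mask")],
    [helper.make_tensor_value_info("probs", TensorProto.FLOAT, ["d", "d"])],
)
model = OnnxModel(helper.make_model(graph))

# Starting from Softmax, follow input 0 to an Add, then the Add's input 0 to a MatMul.
# The second list gives the input index to follow at each hop; None is returned on any mismatch.
path = model.match_parent_path(softmax, ["Add", "MatMul"], [0, 0])
assert path is not None and [node.op_type for node in path] == ["Add", "MatMul"]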
# Find past_v input name + start_child_nodes = input_name_to_nodes[add_v.output[0]] + for start_child_node in start_child_nodes: + if start_child_node.op_type == "Concat": + concat_v_nodes = self.model.match_parent_path( + start_child_node, + ["Reshape", "Transpose"], + [0, 0], + ) + if concat_v_nodes is not None: + past_v = concat_v_nodes[-1].input[0] + start_child_nodes = input_name_to_nodes[start_child_node.output[0]] + break + + # Find present_v output name + for start_child_node in start_child_nodes: + start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]] + for start_grandchild_node in start_grandchild_nodes: + if start_grandchild_node.output[0] in graph_output_names: + present_v = start_grandchild_node.output[0] + break + if present_v != "": + break + elif v_nodes_with_past is not None: + v_nodes = v_nodes_with_past + (concat_v, transpose_v, reshape_v, add_v, matmul_v) = v_nodes past_v = concat_v.input[0] present_v = concat_v.output[0] - elif ( - v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names - ): - v_nodes = v_nodes_with_past_cross_attn - past_v = v_nodes[-1].input[0] - present_v = v_nodes[-1].output[0] - if present_v not in graph_output_names: - identity_node_v = list( - filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) - ) - present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" - elif ( - v_nodes_with_past_cross_attn_openai is not None - and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names - ): - v_nodes = v_nodes_with_past_cross_attn_openai + elif matmul_qkv.input[1] in graph_input_names: + # Hugging Face's cross-attention where past_v is used directly as value + past_v = matmul_qkv.input[1] + elif v_nodes_past_only_oai is not None: + # OpenAI's cross-attention where past_v is used directly as value + v_nodes = v_nodes_past_only_oai past_v = v_nodes[-1].input[0] - present_v = v_nodes[-1].output[0] - if present_v not in graph_output_names: - identity_node_v = list( - filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) - ) - present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" else: logger.debug("fuse_attention: failed to match v path") return past_v = past_v if past_v in graph_input_names else "" present_v = present_v if present_v in graph_output_names else "" - qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) - qk_nodes_2 = self.model.match_parent_path( - matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0] - ) - qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) - add_qk = None - if qk_nodes_1 is not None: - _, matmul_qk = qk_nodes_1 - qk_nodes = qk_nodes_1 - elif qk_nodes_2 is not None: - _, _, add_qk, _, matmul_qk = qk_nodes_2 - qk_nodes = qk_nodes_2 - elif qk_nodes_2_openai is not None: - _, add_qk, matmul_qk = qk_nodes_2_openai - qk_nodes = qk_nodes_2_openai + qk_nodes_no_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) + qk_nodes_with_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + qk_nodes, add_qk = [], None + if qk_nodes_no_mask is not None: + _, matmul_qk = qk_nodes_no_mask + qk_nodes = qk_nodes_no_mask + elif qk_nodes_with_mask is not None: + _, add_qk, matmul_qk = qk_nodes_with_mask + qk_nodes = qk_nodes_with_mask else: + logger.debug("fuse_attention: failed to 
match qk path") return - q_nodes = self.model.match_parent_path( + q_nodes_hf = self.model.match_parent_path( matmul_qk, - ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], - [0, 0, 0, 0, 0, 1], + ["Transpose", "Reshape", "Mul", "Add", "MatMul"], + [0, 0, 0, 0, 1], ) - q_nodes_openai = self.model.match_parent_path( + q_nodes_oai = self.model.match_parent_path( matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, 1], ) - reshape_q_2 = None - if q_nodes is not None: - reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes - elif q_nodes_openai is not None: - q_nodes = q_nodes_openai - mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes + q_nodes = [] + if q_nodes_hf is not None: + q_nodes = q_nodes_hf + (transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes + elif q_nodes_oai is not None: + q_nodes = q_nodes_oai + (mul_q, transpose_q, reshape_q, add_q, matmul_q) = q_nodes else: + logger.debug("fuse_attention: failed to match q path") return - k_nodes_with_bias = self.model.match_parent_path( - matmul_qk, - ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 0, 0, 0, 1], - ) - k_nodes_no_bias_openai = self.model.match_parent_path( + k_nodes_no_past_hf = self.model.match_parent_path( matmul_qk, - ["Mul", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 0], - ) - k_nodes_no_bias = self.model.match_parent_path( - matmul_qk, - ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 0, 0], + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], ) - k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path( - # Decoder attention with past key concatenated before MatMul + k_nodes_with_past_hf = self.model.match_parent_path( matmul_qk, - ["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 1, 0, 0], + ["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 0, 0], ) - k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path( - # Decoder attention with past key directly used in MatMul + k_nodes_past_or_present_oai = self.model.match_parent_path( matmul_qk, - ["Transpose", "Reshape"], - [1, 0], + ["Mul", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0], ) - k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path( - # Decoder attention with past key directly used in MatMul + k_nodes_past_only_oai = self.model.match_parent_path( matmul_qk, ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], [1, 0, 0, 0, 0], ) past_k, present_k = "", "" - reshape_k_2, reshape_k_1, matmul_k = None, None, None - if k_nodes_with_bias is not None: - _, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias - k_nodes = k_nodes_with_bias - elif k_nodes_no_bias_openai is not None: - mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias_openai - k_nodes = k_nodes_no_bias_openai - present_k = matmul_k.output[0] - - # Find the child path to access the correct present_k values - # Openai impl provides present/past k values in 3D format - # whereas ort MultiHeadAttention expects k values in 4D, hence the - # additional Reshape and Transpose nodes are added - # For encoder attention types - # Matmul -> Reshape -> Transpose -> Present_K - reshape_path = self.model.match_child_path( - matmul_k, - ["Reshape", "Transpose"], - exclude=[reshape_k_1], - ) - # For decoder attention types - # matmul_k_node Reshape <- Transpose <- Past_K - # \ / - # \ / - # -> Concat <- - # | - # +--> Reshape -> Transpose -> Present_K - concat_path = 
self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"]) - if reshape_path is not None: - (_, transpose_matmul_k) = reshape_path - if transpose_matmul_k.output[0] in graph_output_names: - present_k = transpose_matmul_k.output[0] - if concat_path is not None: - (concat_k, _, transpose_concat_k) = concat_path - if transpose_concat_k.output[0] in graph_output_names: - present_k = transpose_concat_k.output[0] - concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0]) - _, transpose_concat_k_in = concat_nodes - past_k = transpose_concat_k_in.input[0] - elif k_nodes_no_bias is not None: - _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias - k_nodes = k_nodes_no_bias - # For initial pass through encoder-decoder_with_past to get starting past values (beam search) - present_k = transpose_k_1.output[0] - elif k_nodes_no_bias_with_past_self_attn is not None: - _, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn - k_nodes = k_nodes_no_bias_with_past_self_attn + k_nodes, add_k, matmul_k = [], None, None + if k_nodes_no_past_hf is not None: + k_nodes = k_nodes_no_past_hf + (transpose_k, reshape_k, matmul_k) = k_nodes + + # Find present_k output name + transpose_k_nodes = input_name_to_nodes[reshape_k.output[0]] + for transpose_k_node in transpose_k_nodes: + if transpose_k_node.output[0] in graph_output_names: + present_k = transpose_k_node.output[0] + break + elif k_nodes_with_past_hf is not None: + k_nodes = k_nodes_with_past_hf + (_, concat_k, transpose_k, reshape_k, matmul_k) = k_nodes past_k = concat_k.input[0] present_k = concat_k.output[0] - elif ( - k_nodes_no_bias_with_past_cross_attn is not None - and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names - ): - k_nodes = k_nodes_no_bias_with_past_cross_attn - past_k = k_nodes[-1].input[0] - present_k = k_nodes[-1].output[0] - if present_k not in graph_output_names: - identity_node_k = list( - filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) - ) - present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" - elif ( - k_nodes_no_bias_with_past_cross_attn_openai is not None - and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names - ): - k_nodes = k_nodes_no_bias_with_past_cross_attn_openai + elif output_name_to_node[matmul_qk.input[1]].input[0] in graph_input_names: + # Hugging Face's cross-attention where past_k is used directly as key + k_nodes = [output_name_to_node[matmul_qk.input[1]]] + past_k = k_nodes[0].input[0] + elif k_nodes_past_or_present_oai is not None: + k_nodes = k_nodes_past_or_present_oai + (_, transpose_k, reshape_k, matmul_k) = k_nodes + + # Find past_k input name + start_child_nodes = input_name_to_nodes[matmul_k.output[0]] + for start_child_node in start_child_nodes: + if start_child_node.op_type == "Concat": + concat_k_nodes = self.model.match_parent_path( + start_child_node, + ["Reshape", "Transpose"], + [0, 0], + ) + if concat_k_nodes is not None: + past_k = concat_k_nodes[-1].input[0] + start_child_nodes = input_name_to_nodes[start_child_node.output[0]] + break + + # Find present_k output name + for start_child_node in start_child_nodes: + start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]] + for start_grandchild_node in start_grandchild_nodes: + if start_grandchild_node.output[0] in graph_output_names: + present_k = start_grandchild_node.output[0] + break + if present_k != "": + break + elif 
k_nodes_past_only_oai is not None: + # OpenAI's cross-attention where past_k is used directly as key + k_nodes = k_nodes_past_only_oai past_k = k_nodes[-1].input[0] - present_k = k_nodes[-1].output[0] - if present_k not in graph_output_names: - identity_node_k = list( - filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) - ) - present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" else: + logger.debug("fuse_attention: failed to match k path") return past_k = past_k if past_k in graph_input_names else "" present_k = present_k if present_k in graph_output_names else "" - if k_nodes in (k_nodes_no_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + if matmul_k is not None and add_k is None: # Create empty Add node for attention graph - bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] + add_v_tensor = self.model.get_initializer(add_v.input[0]) + bias_dim = add_v_tensor.dims[0] + dtype = add_v_tensor.data_type empty_bias_name = "empty_bias" empty_tensor = self.model.get_initializer(empty_bias_name) if empty_tensor is None: self.add_initializer( empty_bias_name, - TensorProto.FLOAT, + dtype, dims=[bias_dim], - vals=np.array([0.0] * bias_dim, dtype=np.float32), + vals=np.array([0.0] * bias_dim, dtype=helper.tensor_dtype_to_np_dtype(dtype)), ) add_name = self.model.create_node_name("Add") - add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) + add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k.name], add_name) - if ( - model_impl_openai - and not bool(past_k) - and not self.check_runtime_shape_path_openai( - reshape_qkv_2, - matmul_qkv, - add_qk, - matmul_qk, - add_q, - ) - ): - return - elif ( - not model_impl_openai - and not bool(past_k) - and not self.check_runtime_shape_path( - reshape_qkv_2, - reshape_qkv_1, - reshape_q_2, - reshape_k_2, - reshape_v_2, - root_input, - ) - ): - return - - three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and "matmul_v" not in locals() + three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and matmul_v is None one_root_input = ( not three_root_inputs - and matmul_k.input[0] == root_input and matmul_q.input[0] == root_input + and matmul_k.input[0] == root_input and matmul_v.input[0] == root_input ) two_root_inputs = ( @@ -509,84 +291,97 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ) # There are 5 types of attention: - # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1 - # 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2 - # 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value - # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1 - # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 - encoder_attention = one_root_input and qk_nodes == qk_nodes_1 - decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai) - decoder_attention_with_past = ( - (encoder_attention if not model_impl_openai else decoder_attention) and bool(past_k) and bool(past_v) - ) - decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 - decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1 - - # For decoder_attention, the attention mask needs to be included in the attention node - mask_index, mask_nodes = 
None, [] - if decoder_attention: + # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_no_mask + # 2) Decoder self attention with one_root_input=True and qk_nodes=qk_nodes_with_mask + # 3) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_no_mask + # 4) Decoder self attention with past with one_root_input=True and qk_nodes=qk_nodes_with_mask and past_k=past_decoder_key and past_v=past_decoder_value + # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_no_mask + encoder_attention = one_root_input and qk_nodes == qk_nodes_no_mask + decoder_self_attention = one_root_input and qk_nodes == qk_nodes_with_mask + decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_no_mask + decoder_self_attention_with_past = decoder_self_attention and bool(past_k) and bool(past_v) + decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_no_mask + + # For decoder self-attentions, the attention mask needs to be included in the attention node + causal_mask = qk_nodes == qk_nodes_with_mask + mask_nodes = [] + if causal_mask: mask_nodes_bart = self.model.match_parent_path( add_qk, ["Where"], [1], ) - mask_nodes_whisper = self.model.match_parent_path( + mask_nodes_whisper_hf = self.model.match_parent_path( + add_qk, + ["Slice", "Expand", "Where"], + [1, 0, 1], + ) + mask_nodes_whisper_oai = self.model.match_parent_path( + add_qk, + ["Slice", "Unsqueeze", "Gather", "Shape", "Add"], + [1, 2, 0, 0, 0], + ) + mask_nodes_whisper_oai_unit_test = self.model.match_parent_path( add_qk, - ["Expand", "Unsqueeze", "Unsqueeze", "Where"], - [1, 0, 0, 0], + ["Slice", "Slice"], + [1, 0], ) - if mask_nodes_whisper is not None: - mask_index = mask_nodes_whisper[0].output[-1] - mask_nodes = mask_nodes_whisper + if mask_nodes_whisper_hf is not None: + mask_nodes = mask_nodes_whisper_hf + elif mask_nodes_whisper_oai is not None: + mask_nodes = mask_nodes_whisper_oai + elif mask_nodes_whisper_oai_unit_test is not None: + mask_nodes = mask_nodes_whisper_oai_unit_test elif mask_nodes_bart is not None: - mask_index = mask_nodes_bart[0].output[-1] mask_nodes = mask_nodes_bart + else: + logger.debug("fuse_attention: failed to match mask nodes") + return + assert len(mask_nodes) > 0 if ( encoder_attention - or decoder_attention - or decoder_attention_with_past + or decoder_self_attention or decoder_cross_attention + or decoder_self_attention_with_past or decoder_cross_attention_with_past ): - attention_last_node = reshape_qkv_2 - num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1) + attention_last_node = reshape_qkv + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: logger.debug("fuse_attention: failed to detect num_heads or hidden_size") return new_node = None - if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: - # Note: Decoder attention with past key and past value is fused as multihead attention - # rather than attention because multihead attention supports separate past key and past + if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + # Note: Decoder attention with past key and past value is fused as multi-head attention + # rather than attention because multi-head attention supports separate past key and past # value whereas attention supports concatenated past key and past value. 
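+        # Illustrative aside (not part of the fusion code): the layout difference the
+        # note above refers to can be sketched with numpy. It assumes the stacked-past
+        # convention for Attention, i.e. one past tensor of shape
+        # (2, batch, num_heads, past_len, head_size), versus separate past_key and
+        # past_value tensors for MultiHeadAttention; the sizes below are made up.
+        #
+        #     import numpy as np
+        #
+        #     batch, num_heads, past_len, head_size = 1, 8, 16, 64
+        #
+        #     past_key = np.zeros((batch, num_heads, past_len, head_size), dtype=np.float32)
+        #     past_value = np.zeros_like(past_key)
+        #
+        #     # MultiHeadAttention consumes the two tensors as separate inputs, so the
+        #     # graph's past key/value inputs can be wired in directly.
+        #     mha_past_inputs = (past_key, past_value)
+        #
+        #     # Attention instead expects a single stacked past tensor, which separate
+        #     # graph inputs cannot provide without an extra stack/concat.
+        #     attention_past = np.stack([past_key, past_value], axis=0)
+        #     assert attention_past.shape == (2, batch, num_heads, past_len, head_size)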
new_node = ( self.create_multihead_attention_node( q_matmul=matmul_q, - k_matmul=matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k, - v_matmul=matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v, + k_matmul=matmul_k if decoder_cross_attention or decoder_self_attention_with_past else past_k, + v_matmul=matmul_v if decoder_cross_attention or decoder_self_attention_with_past else past_v, q_add=add_q, - k_add=add_k if decoder_cross_attention or decoder_attention_with_past else None, - v_add=add_v if decoder_cross_attention or decoder_attention_with_past else None, + k_add=add_k if decoder_cross_attention or decoder_self_attention_with_past else None, + v_add=add_v if decoder_cross_attention or decoder_self_attention_with_past else None, num_heads=num_heads, hidden_size=hidden_size, output=attention_last_node.output[0], - unidirectional=decoder_attention_with_past, - past_k=past_k if decoder_attention_with_past else "", - past_v=past_v if decoder_attention_with_past else "", + unidirectional=causal_mask, + past_k=past_k if decoder_self_attention_with_past else "", + past_v=past_v if decoder_self_attention_with_past else "", present_k=present_k, present_v=present_v, - packed_qkv=decoder_attention_with_past, ) if self.use_multi_head_attention else None ) else: - # Temporarily set multihead attention flag to false + # Temporarily set multi-head attention flag to false use_multi_head_attention_ground_truth = self.use_multi_head_attention self.use_multi_head_attention = False - add_qk_str = mask_index if decoder_attention and mask_index else "" new_node = self.create_attention_node( mask_index=None, q_matmul=matmul_q, @@ -599,17 +394,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): hidden_size=hidden_size, first_input=root_input, output=attention_last_node.output[0], - add_qk_str=( - None if len(mask_nodes) > 1 else add_qk_str - ), # deprecate and use is_unidirectional attr instead for Whisper + causal=causal_mask, past_k=past_k, past_v=past_v, present_k=present_k, present_v=present_v, - causal=decoder_attention, ) self.use_multi_head_attention = use_multi_head_attention_ground_truth if new_node is None: + logger.debug("fuse_attention: failed to create fused node") return self.nodes_to_add.append(new_node) @@ -618,22 +411,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) self.nodes_to_remove.extend(qk_nodes) - # When using multihead attention, keep MatMul nodes in original graph - if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: - if q_nodes[-1].op_type == "MatMul": + # When using multi-head attention, keep MatMul nodes in original graph + if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past: + if len(q_nodes) > 0 and q_nodes[-1].op_type == "MatMul": q_nodes.pop() - if k_nodes[-1].op_type == "MatMul": + if len(k_nodes) > 0 and k_nodes[-1].op_type == "MatMul": k_nodes.pop() - if v_nodes[-1].op_type == "MatMul": + if len(v_nodes) > 0 and v_nodes[-1].op_type == "MatMul": v_nodes.pop() - if self.disable_multi_head_attention_bias and ( - decoder_cross_attention or decoder_cross_attention_with_past - ): - if q_nodes[-1].op_type == "Add": + if self.disable_multi_head_attention_bias: + if len(q_nodes) > 0 and q_nodes[-1].op_type == "Add": q_nodes.pop() - if k_nodes[-1].op_type == "Add": + if len(k_nodes) > 0 and 
k_nodes[-1].op_type == "Add": k_nodes.pop() - if v_nodes[-1].op_type == "Add": + if len(v_nodes) > 0 and v_nodes[-1].op_type == "Add": v_nodes.pop() self.nodes_to_remove.extend(q_nodes) diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 2b19ae5029ecc..072bb9bb39a79 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -53,6 +53,7 @@ def ort_type_to_torch_type(ort_type: str): "tensor(int32)": torch.int32, "tensor(float)": torch.float32, "tensor(float16)": torch.float16, + "tensor(bfloat16)": torch.bfloat16, "tensor(bool)": torch.bool, "tensor(uint8)": torch.uint8, } diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 29a08b5ccd220..f1758cc52280f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,13 +1,13 @@ -torch>=1.13.0 -transformers>=4.36.0,<= 4.42.4 -openai-whisper>=20231117,<=20240927 +torch>=2.7.0 +transformers>=4.52.3 +openai-whisper==20240927 ffmpeg-python datasets soundfile librosa -optimum<=1.21.2 +optimum onnxruntime-extensions>=0.9.0 -onnx==1.17.0 +onnx protobuf==3.20.2 numpy==1.23.3 psutil diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py index 4765616ec2b6f..a7c0d3538b8da 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py @@ -196,7 +196,7 @@ def create_torch_ops(self): # Set torch extensions directory to cache directory os.environ["TORCH_EXTENSIONS_DIR"] = self.cache_dir - # Try to import `jinja` pip package + # Try to import `ninja` pip package try: assert torch.utils.cpp_extension.verify_ninja_availability() except Exception as e: diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 8add38b5a7d07..89c2d5e7cc259 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -1349,6 +1349,8 @@ def has_same_value( tensor2: TensorProto, signature_cache1: dict | None = None, signature_cache2: dict | None = None, + rtol: float = 1e-05, + atol: float = 1e-08, ) -> bool: """Returns True when two tensors have same value. Note that name can be different. @@ -1358,6 +1360,8 @@ def has_same_value( tensor2 (TensorProto): initializer 2 signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison. signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison. + rtol (float): Optional relative difference threshold for minor precision differences + atol (float): Optional absolute difference threshold for minor precision differences Returns: bool: True when two initializers has same value. 
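        Example (illustrative sketch only; assumes `np` is numpy, `numpy_helper` is
        onnx.numpy_helper, and the method is invoked via the OnnxModel class):
            >>> t1 = numpy_helper.from_array(np.array([1.0, 2.0], dtype=np.float32), name="t1")
            >>> t2 = numpy_helper.from_array(np.array([1.0, 2.00001], dtype=np.float32), name="t2")
            >>> OnnxModel.has_same_value(t1, t2)  # bitwise different, but within default rtol/atol
            True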
""" @@ -1375,9 +1379,17 @@ def has_same_value( signature_cache1[tensor1.name] = sig1 if signature_cache2 is not None: signature_cache2[tensor2.name] = sig2 - if sig1 == sig2 and tensor1.data_type == tensor2.data_type and tensor1.dims == tensor2.dims: - # Same signature, now do the expensive check to confirm the data is the same - return (numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)).all() + if tensor1.data_type == tensor2.data_type and tensor1.dims == tensor2.dims: + n1 = numpy_helper.to_array(tensor1) + n2 = numpy_helper.to_array(tensor2) + if sig1 == sig2: + # Same signature, now do the expensive check to confirm the data is the same + return (n1 == n2).all() + else: + # Check if tensors are allclose + from numpy import allclose + + return allclose(n1, n2, rtol=rtol, atol=atol) return False diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h index acb520f894569..0ab3b693d59d9 100644 --- a/onnxruntime/test/common/tensor_op_test_utils.h +++ b/onnxruntime/test/common/tensor_op_test_utils.h @@ -133,7 +133,8 @@ inline std::vector ValueRange(size_t count, BFloat16 start, return result; } -inline std::pair MeanStdev(gsl::span v) { +template +inline std::pair MeanStdev(const T& v) { float sum = std::accumulate(v.begin(), v.end(), 0.0f); float mean = sum / v.size(); diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index bb2bfab585da8..f8739b859bef5 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -20,6 +20,7 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/util/include/scoped_env_vars.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" @@ -81,6 +82,8 @@ struct TestOptions { bool has_g_idx{false}; bool has_bias{false}; + bool legacy_shape{false}; // for backward compatibility + std::optional output_abs_error{}; std::optional output_rel_error{}; }; @@ -107,28 +110,20 @@ void RunTest(const TestOptions& opts, const bool zp_is_4bit = opts.zp_is_4bit || opts.has_g_idx; - const int64_t M = opts.M, - K = opts.K, - N = opts.N; + const int64_t M = opts.M; + const int64_t K = opts.K; + const int64_t N = opts.N; RandomValueGenerator random{1234}; std::vector input0_vals(random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f)); std::vector input1_f_vals(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); -#if 0 // for Debugging - std::vector input1_f_vals_trans(N * K); - MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N); -#endif - - int q_rows, q_cols; - MlasBlockwiseQuantizedShape(static_cast(opts.block_size), /* columnwise */ true, - static_cast(K), static_cast(N), - q_rows, q_cols); - - size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(static_cast(opts.block_size), /* columnwise */ true, - static_cast(K), static_cast(N), - q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + int64_t k_blocks = (K + opts.block_size - 1) / opts.block_size; + int64_t blob_size = (opts.block_size * QBits + 7) / 8; + size_t q_scale_size = static_cast(N * k_blocks); + size_t q_data_size_in_bytes = static_cast(N * k_blocks * blob_size); // packed as UInt4x2 + const int64_t zero_point_blob_size = (k_blocks * QBits + 7) / 8; + size_t q_zp_size_in_bytes = static_cast(N * 
zero_point_blob_size); // packed as UInt4x2 std::vector input1_vals(q_data_size_in_bytes); std::vector scales(q_scale_size); @@ -142,16 +137,6 @@ void RunTest(const TestOptions& opts, static_cast(K), static_cast(opts.block_size)); -#if 0 - for (int i = 0; i < input1_vals.size(); i++) - { - uint8_t byte = input1_vals[i]; - uint8_t val_lo = byte & 0x0f; - uint8_t val_hi = byte >> 4; - std::cout << (int)val_lo << ", " << (int)val_hi << ", "; - } -#endif - const std::vector bias_shape = {N}; const auto bias = [&]() -> std::optional> { if (opts.has_bias) { @@ -184,17 +169,22 @@ void RunTest(const TestOptions& opts, test.AddInput("A", {M, K}, input0_vals, false); } - test.AddInput("B", {q_cols, q_rows}, input1_vals, true); + test.AddInput("B", {N, k_blocks, blob_size}, input1_vals, true); + + auto scales_shape = opts.legacy_shape ? std::vector{N * k_blocks} + : std::vector{N, k_blocks}; if constexpr (use_float16) { - test.AddInput("scales", {static_cast(q_scale_size)}, ToFloat16(scales), true); + test.AddInput("scales", scales_shape, ToFloat16(scales), true); } else { - test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); + test.AddInput("scales", scales_shape, scales, true); } if (opts.has_zero_point) { if (zp_is_4bit) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + auto zp_shape = opts.legacy_shape ? std::vector{N * zero_point_blob_size} + : std::vector{N, zero_point_blob_size}; + test.AddInput("zero_points", zp_shape, zp, true); } else { std::vector zp_f; zp_f.reserve(q_zp_size_in_bytes * 2); @@ -209,9 +199,9 @@ void RunTest(const TestOptions& opts, } if constexpr (use_float16) { - test.AddInput("zero_points", {static_cast(q_scale_size)}, ToFloat16(zp_f), true); + test.AddInput("zero_points", scales_shape, ToFloat16(zp_f), true); } else { - test.AddInput("zero_points", {static_cast(q_scale_size)}, zp_f, true); + test.AddInput("zero_points", scales_shape, zp_f, true); } } } else { @@ -267,7 +257,7 @@ void RunTest(const TestOptions& opts, } // namespace -template +template void TestMatMulNBitsTyped() { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; @@ -464,6 +454,13 @@ TEST(MatMulNBits, Float16_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); } + +TEST(MatMulNBits, LegacyShape) { + constexpr bool legacy_shape = true; + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); +} + #endif #endif #endif @@ -490,6 +487,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura if (use_float16) { opts.output_abs_error = fp16_abs_error; + opts.output_rel_error = use_float16 ? 0.001f : 0.0005f; } std::vector> execution_providers; @@ -552,11 +550,8 @@ TEST(MatMulNBits, Float16Large) { // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances. 
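  // As an aside, the ULP-based comparison suggested in the comment above could look like
  // the following minimal sketch (written in Python/numpy for brevity; illustrative only,
  // and the helper name is made up):
  //
  //     import numpy as np
  //
  //     def ulp_distance_fp16(a, b):
  //         """Distance in float16 units-in-the-last-place between two values/arrays."""
  //         def ordered(x):
  //             # Reinterpret the half-precision bit pattern, then map the sign-magnitude
  //             # encoding onto a monotonically increasing integer scale.
  //             bits = np.asarray(x, dtype=np.float16).view(np.uint16).astype(np.int32)
  //             return np.where(bits & 0x8000, 0x8000 - bits, bits)
  //         return np.abs(ordered(a) - ordered(b))
  //
  //     # 1.0 and the next representable float16 value (1 + 2**-10) differ by exactly 1 ULP.
  //     assert ulp_distance_fp16(np.float16(1.0), np.float16(1.0009765625)) == 1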
float abs_error = 0.3f; -#elif USE_WEBGPU - // Use absolute error of 0.1 for WebGPU with subgroup implementation - float abs_error = 0.1f; #else - float abs_error = 0.05f; + float abs_error = 0.1f; #endif for (auto block_size : {16, 32, 64, 128}) { @@ -568,6 +563,53 @@ TEST(MatMulNBits, Float16Large) { } } +#ifdef USE_CUDA +TEST(MatMulNBits, Fp16_Int4_Int4ZeroPoint) { + float abs_error = 0.1f; + constexpr bool use_float16 = true; + constexpr bool has_g_idx = false; + constexpr bool zp_is_4bit = true; + constexpr bool has_zeropoint = true; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + } +} + +TEST(MatMulNBits, Fp16_Int4_Fp16ZeroPoint) { + float abs_error = 0.1f; + constexpr bool use_float16 = true; + constexpr bool has_g_idx = false; + constexpr bool zp_is_4bit = false; + constexpr bool has_zeropoint = true; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + } +} + +TEST(MatMulNBits, Fp16_Int4_NoZeroPoint) { + float abs_error = 0.1f; + constexpr bool use_float16 = true; + constexpr bool has_g_idx = false; + constexpr bool zp_is_4bit = true; + constexpr bool has_zeropoint = false; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + } +} +#endif + #endif // defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index 257d3b3efdf9c..39f6958d47a12 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -21,6 +21,7 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/util/include/scoped_env_vars.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" @@ -143,16 +144,17 @@ void RunTest8Bits(const TestOptions8Bits& opts) { test.AddInput("A", {M, K}, FloatsToMLFloat16s(input0_fp32_vals), false); } - test.AddInput("B", {q_cols, q_rows}, input1_vals, true); + int64_t k_blocks = (K + opts.block_size - 1) / opts.block_size; + test.AddInput("B", {q_cols, k_blocks, q_rows / k_blocks}, input1_vals, true); if constexpr (std::is_same::value) { - test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); + test.AddInput("scales", {N, static_cast(q_scale_size) / N}, scales, true); } else { - test.AddInput("scales", {static_cast(q_scale_size)}, FloatsToMLFloat16s(scales), true); + test.AddInput("scales", {N, static_cast(q_scale_size) / N}, FloatsToMLFloat16s(scales), true); } if (opts.has_zero_point) { - test.AddInput("zero_points", 
{static_cast(q_zp_size_in_bytes)}, zp, true); + test.AddInput("zero_points", {N, static_cast(q_zp_size_in_bytes) / N}, zp, true); } else { test.AddOptionalInputEdge(); } @@ -205,28 +207,26 @@ void RunTest8Bits(const TestOptions8Bits& opts) { } template -void TestMatMul8BitsTyped() { +void TestMatMul8BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) { TestOptions8Bits base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; base_opts.block_size = block_size; base_opts.accuracy_level = accuracy_level; - if (base_opts.accuracy_level == 4) { - base_opts.output_abs_error = 0.1f; - base_opts.output_rel_error = 0.02f; - } else if constexpr (std::is_same::value) { - base_opts.output_abs_error = 0.055f; - base_opts.output_rel_error = 0.02f; - } + base_opts.output_abs_error = abs_error; + base_opts.output_rel_error = rel_error; { TestOptions8Bits opts = base_opts; + opts.has_zero_point = false; + opts.has_bias = false; RunTest8Bits(opts); } { TestOptions8Bits opts = base_opts; opts.has_zero_point = true; + opts.has_bias = false; RunTest8Bits(opts); } @@ -234,6 +234,7 @@ void TestMatMul8BitsTyped() { #if !defined(USE_CUDA) && !defined(USE_WEBGPU) { TestOptions8Bits opts = base_opts; + opts.has_zero_point = false; opts.has_bias = true; RunTest8Bits(opts); } @@ -248,7 +249,7 @@ void TestMatMul8BitsTyped() { } } // namespace -TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { +TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); @@ -284,9 +285,25 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { } #if defined(USE_CUDA) || defined(USE_WEBGPU) -TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float16) { - TestMatMul8BitsTyped(); - TestMatMul8BitsTyped(); +TEST(MatMulNBits, Float16_8b_AccuracyLevel4) { + constexpr float abs_error = 0.055f; + constexpr float rel_error = 0.02f; + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); +} +#endif + +#if defined(USE_CUDA) +TEST(MatMulNBits, Fp16_Int8_Cuda) { + constexpr float abs_error = 0.5f; + constexpr float rel_error = 0.05f; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); } #endif diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index f56f9ffcc7858..6dca258601339 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1455,6 +1455,12 @@ std::unique_ptr> GetBrokenTests(const std::string& provider // Fails with QNN SDK 2.17.0: // expected 7.70947 (40f6b3f3), got 7.84096 (40fae920), diff: 0.131491, tol=0.00870947 idx=419. 100 of 1715 differ broken_tests->insert({"facedetection_op8_qdq", "result differs"}); + // Fails with QNN SDK 2.34.0: + // expected 2.18661 (400bf164), got 1.48898 (3fbe96ce), diff: 0.697631, tol=0.00318661 idx=0. 
8 of 8 differ + broken_tests->insert({"gemm_default_vector_bias", "result differs with 2.34"}); + // expected 0.0505495 (3d4f0d00), got 0.0506369 (3d4f68ae), diff: 8.74326e-05, tol=6.05495e-05 idx=448 + broken_tests->insert({"mobilenetv2-1.0", "result differs with 2.34"}); + broken_tests->insert({"facedetection_op8", "segfault with CPU backend, will be fixed by QNN 2.36"}); #if defined(_WIN32) && defined(_M_AMD64) // Fails with QNN SDK 2.17.0 on Windows x64: diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 4e50881ad4f90..26df588eab73f 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -147,6 +147,14 @@ class ModelTestBuilder { } } + // Make optional tensor + NodeArg* MakeOptionalTensor() { + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); + std::string name; + return &graph_.GetOrCreateNodeArg(name, &type_proto); + } + template NodeArg* MakeSymbolicInput(const std::vector>& shape) { ONNX_NAMESPACE::TypeProto type_proto; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 103da5f534ea7..d409032b4ebb3 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -61,6 +61,7 @@ namespace perftest { "\t-u [optimized_model_path]: Specify the optimized model path for saving.\n" "\t-d [CUDA only][cudnn_conv_algorithm]: Specify CUDNN convolution algorithms: 0(benchmark), 1(heuristic), 2(default). \n" "\t-q [CUDA only] use separate stream for copy. \n" + "\t-g [TensorRT RTX | TensorRT | CUDA] Enable tensor input and output bindings on CUDA before session run \n" "\t-z: Set denormal as zero. When turning on this option reduces latency dramatically, a model may have denormals.\n" "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" @@ -160,6 +161,9 @@ namespace perftest { "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n" "\t-l Provide file as binary in memory by using fopen before session creation.\n" "\t-R [Register custom op]: allow user to register custom op by .so or .dll file.\n" + "\t-X [Enable onnxruntime-extensions custom ops]: Registers custom ops from onnxruntime-extensions. " + "onnxruntime-extensions must have been built in to onnxruntime. 
This can be done with the build.py " + "'--use_extensions' option.\n" "\t-h: help\n"); } #ifdef _WIN32 @@ -189,7 +193,7 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlR:"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlgR:X"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -389,6 +393,12 @@ static bool ParseDimensionOverride(std::basic_string& dim_identifier, case 'R': test_config.run_config.register_custom_op_path = optarg; break; + case 'g': + test_config.run_config.enable_cuda_io_binding = true; + break; + case 'X': + test_config.run_config.use_extensions = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 8257cbfaa7f95..05136ec0750a1 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -19,6 +19,10 @@ #include "TestCase.h" #include "strings_helper.h" +#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_NV) +#include +#endif + #ifdef USE_OPENVINO #include "nlohmann/json.hpp" #endif @@ -145,6 +149,9 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device "\nSupported options are:\n", options); } session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); + if (performance_test_config.run_config.enable_cuda_io_binding) { + device_memory_name_ = CUDA; + } #else ORT_THROW("CUDA is not supported in this build\n"); #endif @@ -188,12 +195,18 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device cuda_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream; // TODO: Support arena configuration for users of perf test session_options.AppendExecutionProvider_CUDA(cuda_options); + if (performance_test_config.run_config.enable_cuda_io_binding) { + device_memory_name_ = CUDA; + } #else ORT_THROW("TensorRT is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kNvTensorRTRTXExecutionProvider) { #ifdef USE_NV session_options.AppendExecutionProvider("NvTensorRtRtx", provider_options); + if (performance_test_config.run_config.enable_cuda_io_binding) { + device_memory_name_ = CUDA; + } #else ORT_THROW("NV TensorRT RTX is not supported in this build\n"); #endif @@ -813,6 +826,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } + if (performance_test_config.run_config.use_extensions) { + session_options.EnableOrtCustomOps(); + } + if (!performance_test_config.model_info.load_via_path) { session_ = Ort::Session(env, performance_test_config.model_info.model_file_path.c_str(), session_options); } else { @@ -855,7 +872,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); return Ort::Value(nullptr); }; } else { - Ort::MemoryInfo memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); + Ort::MemoryInfo memory_info(nullptr); // Default initialize, will be overwritten + if (device_memory_name_ == CUDA) { + memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeDefault); + } else { + memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); + } custom_allocator_ = Ort::Allocator(session_, memory_info); allocator_ = custom_allocator_; @@ -956,6 +978,7 @@ static void InitializeTensorWithSeed(int32_t seed, Ort::Value& tensor) { } bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { + Ort::AllocatorWithDefaultOptions default_allocator; // iterate over all input nodes for (size_t i = 0; i < static_cast(input_length_); i++) { Ort::TypeInfo type_info = session_.GetInputTypeInfo(i); @@ -967,10 +990,37 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { auto transform_fcn = [](int64_t input) { return (input == -1) ? -input : input; }; std::transform(input_node_dim.begin(), input_node_dim.end(), input_node_dim.begin(), transform_fcn); - Ort::Value input_tensor = Ort::Value::CreateTensor(allocator_, (const int64_t*)input_node_dim.data(), - input_node_dim.size(), tensor_info.GetElementType()); - InitializeTensorWithSeed(seed, input_tensor); - PreLoadTestData(0, i, std::move(input_tensor)); + if (device_memory_name_ != CUDA) { + Ort::Value input_tensor = Ort::Value::CreateTensor(allocator_, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, input_tensor); + PreLoadTestData(0, i, std::move(input_tensor)); + } +// Create tensor on CPU, initialize and copy to CUDA tensor +#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_NV) + else { + Ort::Value default_tensor = Ort::Value::CreateTensor(default_allocator, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, default_tensor); + + // Get pointer to CPU tensor data + const void* default_ptr = default_tensor.GetTensorRawData(); + + size_t total_bytes = default_tensor.GetTensorSizeInBytes(); + + Ort::Value cuda_tensor = Ort::Value::CreateTensor(allocator_, input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + + void* cuda_ptr = cuda_tensor.GetTensorMutableData(); + + // Copy the initialized data from CPU to GPU + cudaError_t cuda_err = cudaMemcpy(cuda_ptr, default_ptr, total_bytes, cudaMemcpyHostToDevice); + if (cuda_err != cudaSuccess) { + ORT_THROW("Failed to copy tensor data from CPU to CUDA device. 
CUDA Error: ", cudaGetErrorString(cuda_err)); + } + PreLoadTestData(0, i, std::move(cuda_tensor)); + } +#endif } } return true; diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 90759a4d2f65a..8145f5f35c3b3 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -66,6 +66,8 @@ struct RunConfig { bool disable_spinning_between_run = false; bool exit_after_session_creation = false; std::basic_string register_custom_op_path; + bool enable_cuda_io_binding{false}; + bool use_extensions = false; }; struct PerformanceTestConfig { diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 400b5ab20930c..6abb3d62848f2 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -430,6 +430,7 @@ TYPED_TEST(GemmOpTypedTests, TestGemm2DBroadcast_2) { {static_cast(11.0f), static_cast(12.0f), static_cast(13.0f), static_cast(-9.0f), static_cast(-8.0f), static_cast(-7.0f)}); test.Config(run_with_tunable_op) + .ConfigExcludeEps({kQnnExecutionProvider}) // Accuracy issues with QNN CPU backend since QNN 2.34 .RunWithConfig(); } @@ -518,10 +519,8 @@ TYPED_TEST(GemmOpTypedTests, TestGemmBroadcast) { excluded_providers.insert(kOpenVINOExecutionProvider); // OpenVINO: Temporarily disabled due to accuracy issues #endif - if (b_is_initializer && !c_is_initializer) { - // Accuracy issues on QNN's CPU backend with QNN SDK version 2.17 - excluded_providers.insert(kQnnExecutionProvider); - } + // Accuracy issues with QNN CPU backend since QNN 2.34 + excluded_providers.insert(kQnnExecutionProvider); test.ConfigExcludeEps(excluded_providers) .Config(run_with_tunable_op) @@ -553,10 +552,16 @@ TYPED_TEST(GemmOpTypedTests, TestGemmTrans) { test.AddOutput("Y", {2, 3}, {static_cast(11.0f), static_cast(11.0f), static_cast(11.0f), static_cast(-9.0f), static_cast(-9.0f), static_cast(-9.0f)}); + + std::unordered_set excluded_providers; #if defined(OPENVINO_CONFIG_GPU) - test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues + excluded_providers.insert(kOpenVINOExecutionProvider); // OpenVINO: Temporarily disabled due to accuracy issues #endif - test.Config(run_with_tunable_op) + // Accuracy issues with QNN CPU backend since QNN 2.34 + excluded_providers.insert(kQnnExecutionProvider); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) .RunWithConfig(); } @@ -579,10 +584,15 @@ TYPED_TEST(GemmOpTypedTests, TestGemmTransB) { test.AddOutput("Y", {2, 3}, {static_cast(11.0f), static_cast(11.0f), static_cast(11.0f), static_cast(-9.0f), static_cast(-9.0f), static_cast(-9.0f)}); + + std::unordered_set excluded_providers; #if defined(OPENVINO_CONFIG_GPU) - test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues + excluded_providers.insert(kOpenVINOExecutionProvider); // OpenVINO: Temporarily disabled due to accuracy issues #endif - test.Config(run_with_tunable_op) + excluded_providers.insert(kQnnExecutionProvider); // Accuracy issues with QNN CPU backend since QNN 2.34 + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) .RunWithConfig(); }; run_test(false, false); diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 1c6375ebdb0b1..d97873c21983f 100644 --- 
a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -49,6 +49,22 @@ TEST(SoftmaxOperator, Simple) { RunTest(x_vals, expected_vals, dimensions); } +#ifdef USE_WEBGPU +TEST(SoftmaxOperator, webgpu_nan) { + OpTester test("Softmax", 13); // axis default is -1 + + std::vector x_vals = {-INFINITY, -INFINITY, -INFINITY}; + std::vector expected_result = {0.0f, 0.0f, 0.0f}; + std::vector dimensions = {1, 3}; + + test.AddInput("X", dimensions, x_vals); + test.AddOutput("Y", dimensions, expected_result); + + // explicitly disable CPU EP for this test since CPU implementation does not handle NaN + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider}); +} +#endif + #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_XNNPACK) TEST(SoftmaxOperator, Simple_fp16) { #ifdef USE_CUDA diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 36150d03a7d36..1df640a84a64d 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -229,6 +229,7 @@ TEST(PoolTest, MaxPool1D_case2) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index f3a963ce47eda..c04cbc7d4924e 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. 
#include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" @@ -24,6 +25,35 @@ namespace onnxruntime { namespace test { +template +class NvExecutionProviderTest : public ::testing::Test { + protected: + std::string getTypeAsName() { + std::string dtype_name = ""; + if constexpr (std::is_same::value) { + dtype_name = "fp64"; + } else if constexpr (std::is_same::value) { + dtype_name = "fp32"; + } else if constexpr (std::is_same::value) { + dtype_name = "fp16"; + } else if constexpr (std::is_same::value) { + dtype_name = "bf16"; + } else if constexpr (std::is_same::value) { + dtype_name = "int8"; + } else if constexpr (std::is_same::value) { + dtype_name = "uint8"; + } else if constexpr (std::is_same::value) { + dtype_name = "int32"; + } else if constexpr (std::is_same::value) { + dtype_name = "int64"; + } + return dtype_name; + } +}; + +using NvExecutionProviderTestTypes = ::testing::Types; // double, +TYPED_TEST_SUITE(NvExecutionProviderTest, NvExecutionProviderTestTypes); + std::string PathToUTF8(const PathString& path) { #ifdef WIN32 std::wstring_convert> converter; @@ -89,7 +119,8 @@ void VerifyOutputs(const std::vector& fetches, const std::vector dims, - bool add_fast_gelu = false) { + bool add_fast_gelu = false, + ONNX_NAMESPACE::TensorProto_DataType dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; @@ -97,13 +128,13 @@ static void CreateBaseModel(const PathString& model_name, // FLOAT tensor ONNX_NAMESPACE::TypeProto float_tensor; - float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + float_tensor.mutable_tensor_type()->set_elem_type(dtype); for (auto dim : dims) { float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); } ONNX_NAMESPACE::TypeProto dyn_float_tensor; - dyn_float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + dyn_float_tensor.mutable_tensor_type()->set_elem_type(dtype); auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); @@ -139,7 +170,7 @@ static void CreateBaseModel(const PathString& model_name, } ONNX_NAMESPACE::TypeProto float_scalar; - float_scalar.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + float_scalar.mutable_tensor_type()->set_elem_type(dtype); float_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); auto& input_scalar = graph.GetOrCreateNodeArg("S", &float_scalar); inputs.push_back(&input_scalar); @@ -331,5 +362,30 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { } } +TYPED_TEST(NvExecutionProviderTest, IOTypeTests) { + std::string dtype_name = this->getTypeAsName(); + ASSERT_FALSE(dtype_name.empty()); + PathString model_name = ORT_TSTR("nv_execution_provider_" + dtype_name + ".onnx"); + std::string graph_name = "test" + dtype_name; + std::vector dims = {1, -1, -1}; + + CreateBaseModel(model_name, graph_name, dims, true); + + auto env = Ort::Env(); + auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; + env.UpdateEnvWithCustomLogLevel(logging_level); + + // AOT time + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + Ort::Session session_object(env, model_name.c_str(), so); + + auto io_binding = generate_io_binding(session_object); + 
session_object.Run(run_options, io_binding); + } +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index b15042a808c37..8232742f35a31 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -854,6 +854,34 @@ TEST_F(QnnHTPBackendTests, ConvU16U8_PerTensor_NoBias) { 21); // opset } +#ifndef __linux__ +// Test per-channel QDQ Conv with uint16 input[0], uint8 weights, and no bias. +// in0: u16, in1 (weight): s4, out: u8 +// Tests bug in QNN SDK 2.25 when validating Conv without a bias (QNN EP adds a dummy bias). +TEST_F(QnnHTPBackendTests, ConvU16U16_PerTensor_NoBias) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + + RunHTPConvOpTest("Conv", + input_def, + weight_def, + TestInputDef(), + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + false, // use_qdq_contrib_ops + 21); // opset +} +#endif + TEST_F(QnnHTPBackendTests, ConvU16S4_PerChannel_NoBias_LargeINT4Weight) { std::vector input_shape = {1, 3072, 1, 512}; std::vector weight_shape = {9216, 3072, 1, 1}; @@ -1309,6 +1337,36 @@ TEST_F(QnnHTPBackendTests, ConvTranspose3D_U8S8S32_PerChannel) { 13); } +#ifndef __linux__ +// Test per-channel QDQ Conv. in0: u16, in1 (weight): s8, in2 (bias): s32, out: u16 +TEST_F(QnnHTPBackendTests, ConvU16S16S32_PerChannel) { + std::vector input_shape = {1, 2, 4, 4}; + std::vector weight_shape = {3, 2, 2, 2}; + std::vector bias_shape = {3}; + + TestInputDef input_def(input_shape, false, + GetFloatDataInRange(-10.0f, 10.0f, TensorShape(input_shape).Size())); + TestInputDef weight_def(weight_shape, true, + GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size())); + TestInputDef bias_def(bias_shape, true, + GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size())); + + RunHTPConvOpPerChannelTest("Conv", + input_def, + weight_def, + bias_def, + 0, // weight quant axis + {1, 1}, // Strides + {0, 0, 0, 0}, // Pads + {1, 1}, // Dilations + 1, // default group + "NOTSET", + ExpectedEPNodeAssignment::All, + true, // use_qdq_contrib_ops + 13); // opset +} +#endif + // Test per-channel QDQ Conv. in0: u16, in1 (weight): s8, in2 (bias): s32, out: u16 TEST_F(QnnHTPBackendTests, ConvU16S8S32_PerChannel) { std::vector input_shape = {1, 2, 4, 4}; diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc index a7c86806bf426..fbaf997b476da 100644 --- a/onnxruntime/test/providers/qnn/gemm_op_test.cc +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -73,8 +73,9 @@ TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) { ExpectedEPNodeAssignment::All); // Assigned to QNN EP. } +// since Qnn v2.34 value pair (120.73912, 121.73912) at index #0 don't match, which is 1 from 120.739 // Test Gemm with dynamic (i.e., not initializer) inputs (A, B, Bias). 
-TEST_F(QnnCPUBackendTests, Gemm_Dynamic_A_B_Bias) { +TEST_F(QnnCPUBackendTests, DISABLED_Gemm_Dynamic_A_B_Bias) { std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); @@ -110,8 +111,9 @@ TEST_F(QnnCPUBackendTests, Gemm_TransAB_Static_B_And_Bias) { ExpectedEPNodeAssignment::All); } +// Since Qnn 2.34 value pair (29.4347763, 30.4347763) at index #0 don't match, which is 1 from 29.4348 // Test Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. -TEST_F(QnnCPUBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { +TEST_F(QnnCPUBackendTests, DISABLED_Gemm_TransAB_Dynamic_B_And_Bias) { std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); @@ -123,7 +125,8 @@ TEST_F(QnnCPUBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnCPUBackendTests, Gemm_Broadcast_Bias_DynamicInputs) { +// Since Qnn 2.34 value pair (11, 10) at index #0 don't match, which is -1 from 11 +TEST_F(QnnCPUBackendTests, DISABLED_Gemm_Broadcast_Bias_DynamicInputs) { std::vector input_a_data = {1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; std::vector input_b_data(12, 1.0f); std::vector input_c_data = {1.0f, 2.0f, 3.0f}; diff --git a/onnxruntime/test/providers/qnn/lstm_test.cc b/onnxruntime/test/providers/qnn/lstm_test.cc new file mode 100644 index 0000000000000..4b011b9bf1108 --- /dev/null +++ b/onnxruntime/test/providers/qnn/lstm_test.cc @@ -0,0 +1,1177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "test/providers/tester_types.h" + +#include "core/graph/onnx_protobuf.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +/* + ONNX LSTM inputs: + in[0]: X [seq_length, batch_size, input_size] + in[1]: W [num_directions, 4*hidden_size, input_size] + in[2]: R [num_directions, 4*hidden_size, hidden_size] + + ONNX LSTM optional inputs: + in[3]: B [num_directions, 8*hidden_size] + in[4]: + in[5]: initial_h [num_directions, batch_size, hidden_size]. + in[6]: initial_c [num_directions, batch_size, hidden_size]. + in[7]: P [num_directions, 3*hidde_size] + + ONNX LSTM Parameters: + - activation_alpha ---> Not supported by QNN. + - activation_beta ---> Not supported by QNN. + - activations ---> Not supported by QNN. + - clip ---> Not supported by QNN since the clip in ONNX applied to iofc while QNN only apply to c. Refer + https://github.com/microsoft/onnxruntime/blob/v1.21.0/onnxruntime/core/providers/cpu/rnn/uni_directional_lstm.cc + - direction + - hidden_size + - input_forget ---> Not supported by QNN + - layout: The shape format of inputs X, initial_h, initial_c and outputs Y, Y_h, Y_c. + If 0, the following shapes are expected: + X.shape = [seq_length, batch_size, input_size], + Y.shape = [seq_length, num_directions, batch_size, hidden_size], + initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = [num_directions, batch_size, hidden_size]. 
+ If 1, the following shapes are expected: + X.shape = [batch_size, seq_length, input_size], + Y.shape = [batch_size, seq_length, num_directions, hidden_size], + initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = [batch_size, num_directions, hidden_size]. + + ONNX LSTM optional outputs: + out[0]: Y [seq_length, num_directions, batch_size, hidden_size] + out[1]: Y_h [num_directions, batch_size, hidden_size] + out[2]: Y_c [num_directions, batch_size, hidden_size] + +*/ + +template +void _BuildLSTMTestCase(ModelTestBuilder& builder, + const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout, + const std::vector>& output_qparams) { + auto convert_input = [](ModelTestBuilder& builder, const TestInputDef& def) { + if (std::is_same::value) { + TestInputDef Fp16_def = ConvertToFP16InputDef(def); + return MakeTestInput(builder, Fp16_def); + } else if (std::is_same::value) { + NodeArg* input = MakeTestInput(builder, def); + QuantParams qparams = GetTestInputQuantParams(def); + return AddQDQNodePair(builder, input, qparams.scale, qparams.zero_point); + } else { + return MakeTestInput(builder, def); + } + }; + + NodeArg* inputX = convert_input(builder, X_def); + NodeArg* inputW = convert_input(builder, W_def); + NodeArg* inputR = convert_input(builder, R_def); + std::vector input_args = {inputX, inputW, inputR}; + + // optional inputs + // B + if (B_def) { + input_args.push_back(convert_input(builder, B_def->get())); + } else { + input_args.push_back(builder.MakeOptionalTensor()); + } + + // sequence length + input_args.push_back(builder.MakeOptionalTensor()); + + // H + if (H_def) { + input_args.push_back(convert_input(builder, H_def->get())); + } else { + input_args.push_back(builder.MakeOptionalTensor()); + } + + // C + if (C_def) { + input_args.push_back(convert_input(builder, C_def->get())); + } else { + input_args.push_back(builder.MakeOptionalTensor()); + } + + // P + if (P_def) { + input_args.push_back(convert_input(builder, P_def->get())); + } else { + input_args.push_back(builder.MakeOptionalTensor()); + } + + NodeArg *lstm_output_Y, *lstm_output_Y_h, *lstm_output_Y_c; + if (has_Y) { + if (std::is_same::value || std::is_same::value) { + lstm_output_Y = builder.MakeOutput(); + } else { + lstm_output_Y = builder.MakeIntermediate(); + } + } else { + lstm_output_Y = builder.MakeOptionalTensor(); + } + + if (has_Y_h) { + if (std::is_same::value || std::is_same::value) { + lstm_output_Y_h = builder.MakeOutput(); + } else { + lstm_output_Y_h = builder.MakeIntermediate(); + } + } else { + lstm_output_Y_h = builder.MakeOptionalTensor(); + } + if (has_Y_c) { + if (std::is_same::value || std::is_same::value) { + lstm_output_Y_c = builder.MakeOutput(); + } else { + lstm_output_Y_c = builder.MakeIntermediate(); + } + } else { + lstm_output_Y_c = builder.MakeOptionalTensor(); + } + + Node& lstm_node = builder.AddNode("LSTM", + input_args, + {lstm_output_Y, lstm_output_Y_h, lstm_output_Y_c}); + lstm_node.AddAttribute("direction", direction); + lstm_node.AddAttribute("hidden_size", hidden_size); + lstm_node.AddAttribute("layout", layout); + ORT_UNUSED_PARAMETER(output_qparams); + if (std::is_same::value) { + size_t i = 0; + if (has_Y) { + AddQDQNodePairWithOutputAsGraphOutput(builder, lstm_output_Y, 
output_qparams[i].scale, + output_qparams[i].zero_point); + i++; + } + if (has_Y_h) { + AddQDQNodePairWithOutputAsGraphOutput(builder, lstm_output_Y_h, output_qparams[i].scale, + output_qparams[i].zero_point); + i++; + } + if (has_Y_c) { + AddQDQNodePairWithOutputAsGraphOutput(builder, lstm_output_Y_c, output_qparams[i].scale, + output_qparams[i].zero_point); + i++; + } + } +} + +template +static GetTestModelFn BuildLSTMTestCase(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout) { + return [X_def, W_def, R_def, B_def, + H_def, C_def, P_def, + has_Y, has_Y_h, has_Y_c, + direction, hidden_size, layout](ModelTestBuilder& builder) { + _BuildLSTMTestCase(builder, X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout, {}); + }; +} + +template +static GetTestQDQModelFn BuildQDQLSTMTestCase(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout) { + return [X_def, W_def, R_def, B_def, + H_def, C_def, P_def, + has_Y, has_Y_h, has_Y_c, + direction, hidden_size, layout](ModelTestBuilder& builder, + std::vector>& output_qparams) { + _BuildLSTMTestCase(builder, X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout, output_qparams); + }; +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +// Runs an LSTM model on the QNN HTP backend. Checks the graph node assignment, and that inference +// outputs for QNN EP and CPU EP match. 
+// Note: There are accuracy on HTP in fixed point, to avoid the issue, we don't register QDQ selector for LSTM and it +// is running on HTP fp16 +template +static void RunHtpQDQLSTMOpTest(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 22, + QDQTolerance tolerance = QDQTolerance()) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + TestQDQModelAccuracy(BuildLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + BuildQDQLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + provider_options, + opset, + expected_ep_assignment, + tolerance); +} + +static void RunHtpFp16LSTMOpTest(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 22, + float tolerance = 0.004f) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + + TestFp16ModelAccuracy(BuildLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + BuildLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + provider_options, + opset, + expected_ep_assignment, + tolerance); +} + +static void RunCpuFP32LSTMOpTest(const TestInputDef& X_def, + const TestInputDef& W_def, + const TestInputDef& R_def, + const std::optional>> B_def, + const std::optional>> H_def, + const std::optional>> C_def, + const std::optional>> P_def, + const bool has_Y, + const bool has_Y_h, + const bool has_Y_c, + const std::string direction, + const int64_t hidden_size, + const int64_t layout, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 22, + float tolerance = 0.004f) { + ProviderOptions provider_options; + provider_options["backend_type"] = "cpu"; + + RunQnnModelTest(BuildLSTMTestCase(X_def, W_def, R_def, B_def, H_def, C_def, P_def, has_Y, has_Y_h, has_Y_c, direction, hidden_size, layout), + provider_options, + opset, + expected_ep_assignment, + tolerance); +} + +// QNN failed to finalize when P is provided +// TODO: Add P to unit test below once finalize issue is resolved + +// HTP QDQ +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_forward) { + std::string direction = "forward"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, 
false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_reverse) { + std::string direction = "reverse"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_B) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::nullopt, // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + 
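+// Shape reference for the sanity tests in this file (derived from the ONNX LSTM spec comment above,
+// using the common configuration seq_len=6, batch_size=3, input_size=5, hidden_size=4):
+//   forward/reverse (num_directions=1): X=[6,3,5], W=[1,16,5], R=[1,16,4], B=[1,32],
+//                                       initial_h=initial_c=[1,3,4], Y=[6,1,3,4], Y_h=Y_c=[1,3,4]
+//   bidirectional   (num_directions=2): X=[6,3,5], W=[2,16,5], R=[2,16,4], B=[2,32],
+//                                       initial_h=initial_c=[2,3,4], Y=[6,2,3,4], Y_h=Y_c=[2,3,4]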
+TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_H) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::nullopt, // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_C) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::nullopt, // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_all_initializer) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, true, -0.5f, 0.5f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -0.5f, 0.5f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -0.5f, 0.5f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -0.5f, 0.5f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, true, -0.5f, 0.5f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, true, -0.5f, 0.5f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All, + 22, + QDQTolerance(0.008f)); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + 
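+  // Only Y is requested below (has_Y_h and has_Y_c are false), covering the optional-output path.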
RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + false, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_h_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + false, // has_Y + true, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_c_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpQDQLSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + false, // has_Y + false, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +// HTP Fp16 +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_forward) { + std::string direction = "forward"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P 
+ true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_reverse) { + std::string direction = "reverse"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_B) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::nullopt, // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_H) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, 
-1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::nullopt, // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_C) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::nullopt, // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_all_initializer) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, true, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, true, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, true, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // 
initial_c + std::nullopt, // P + true, // has_Y + false, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_h_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + false, // has_Y + true, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_c_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunHtpFp16LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + false, // has_Y + false, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +// CPU FP32 +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_forward) { + std::string direction = "forward"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_reverse) { + std::string direction = 
"reverse"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_B) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::nullopt, // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_H) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, 
false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::nullopt, // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_C) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::nullopt, // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_HC) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::nullopt, // initial_h + std::nullopt, // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_wo_P) { + std::string direction = "forward"; + uint32_t num_direction = 1; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest(TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // 
B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::nullopt, // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_all_initializer) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, true, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, true, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, true, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, true, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, true, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + true, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_Y_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + true, // has_Y + false, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_Y_h_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + false, // 
has_Y + true, // has_Y_h + false, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, LSTM_FP32_sanity_bidirectional_Y_c_only) { + std::string direction = "bidirectional"; + uint32_t num_direction = 2; + uint32_t batch_size = 3; + uint32_t hidden_size = 4; + uint32_t input_size = 5; + uint32_t seq_len = 6; + auto B_def = TestInputDef({num_direction, 8 * hidden_size}, false, -1.0f, 1.0f); + auto H_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto C_def = TestInputDef({num_direction, batch_size, hidden_size}, false, -1.0f, 1.0f); + auto P_def = TestInputDef({num_direction, 3 * hidden_size}, false, -1.0f, 1.0f); + RunCpuFP32LSTMOpTest( + TestInputDef({seq_len, batch_size, input_size}, false, -1.0f, 1.0f), // X + TestInputDef({num_direction, 4 * hidden_size, input_size}, false, -1.0f, 1.0f), // W + TestInputDef({num_direction, 4 * hidden_size, hidden_size}, false, -1.0f, 1.0f), // R + std::ref(B_def), // B + std::ref(H_def), // initial_h + std::ref(C_def), // initial_c + std::ref(P_def), // P + false, // has_Y + false, // has_Y_h + true, // has_Y_c + direction, // direction + hidden_size, // hidden_size + 0, // layout + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/pool_op_test.cpp b/onnxruntime/test/providers/qnn/pool_op_test.cpp index d777b1134d060..9284df6f8a4a8 100644 --- a/onnxruntime/test/providers/qnn/pool_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pool_op_test.cpp @@ -262,6 +262,51 @@ TEST_F(QnnHTPBackendTests, MaxPool_Rank3_Ceil_HTP_u8) { ExpectedEPNodeAssignment::All); } +// 1-D MaxPool HTP test for rank-3 with ceil_mode=1 and auto_pad='VALID' +TEST_F(QnnHTPBackendTests, MaxPool_Rank3_Ceil_HTP_u8_auto_pad_VALID) { + RunQDQPoolOpTest( + "MaxPool", + TestInputDef({1, 3, 3}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{3}), + utils::MakeAttribute("strides", std::vector{3}), + utils::MakeAttribute("pads", std::vector{0, 0}), + utils::MakeAttribute("dilations", std::vector{1}), + utils::MakeAttribute("ceil_mode", static_cast(1)), + utils::MakeAttribute("storage_order", static_cast(0)), + utils::MakeAttribute("auto_pad", "VALID")}, + ExpectedEPNodeAssignment::All); +} + +// 1-D MaxPool HTP test for rank-3 with ceil_mode=1 and auto_pad='SAME_UPPER' +TEST_F(QnnHTPBackendTests, MaxPool_Rank3_Ceil_HTP_u8_auto_pad_SAME_UPPER) { + RunQDQPoolOpTest( + "MaxPool", + TestInputDef({1, 3, 3}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{3}), + utils::MakeAttribute("strides", std::vector{3}), + utils::MakeAttribute("pads", std::vector{0, 0}), + utils::MakeAttribute("dilations", std::vector{1}), + utils::MakeAttribute("ceil_mode", static_cast(1)), + utils::MakeAttribute("storage_order", static_cast(0)), + utils::MakeAttribute("auto_pad", "SAME_UPPER")}, + ExpectedEPNodeAssignment::All); +} + +// 1-D MaxPool HTP test for rank-3 with ceil_mode=1 and auto_pad='SAME_LOWER' +TEST_F(QnnHTPBackendTests, MaxPool_Rank3_Ceil_HTP_u8_auto_pad_SAME_LOWER) { + RunQDQPoolOpTest( + "MaxPool", + TestInputDef({1, 3, 3}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{3}), + utils::MakeAttribute("strides", std::vector{3}), + utils::MakeAttribute("pads", std::vector{0, 0}), + 
utils::MakeAttribute("dilations", std::vector{1}), + utils::MakeAttribute("ceil_mode", static_cast(1)), + utils::MakeAttribute("storage_order", static_cast(0)), + utils::MakeAttribute("auto_pad", "SAME_LOWER")}, + ExpectedEPNodeAssignment::All); +} + TEST_F(QnnHTPBackendTests, MaxPool_Ceil_HTP_u8) { RunQDQPoolOpTest("MaxPool", TestInputDef({1, 2, 3, 3}, false, -10.0f, 10.0f), // Dynamic input with range [-10, 10] diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 0212dacadbced..a206644bc945e 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -6,6 +6,7 @@ #include #include "core/graph/constants.h" +#include "core/graph/node_attr_utils.h" #include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU #if BUILD_QNN_EP_STATIC_LIB #include "core/providers/qnn/qnn_allocator.h" // Used by QnnHTPBackendTests.UseHtpSharedMemoryAllocatorForInputs @@ -26,6 +27,8 @@ using namespace onnxruntime::logging; #define ORT_MODEL_FOLDER ORT_TSTR("testdata/") +constexpr std::string_view kDlcOutputDir("dlc_output"); + // in test_main.cc extern std::unique_ptr ort_env; extern "C" void ortenv_setup(); @@ -334,19 +337,61 @@ TEST_F(QnnHTPBackendTests, RunConvInt4Model) { } #endif // #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Helper function that runs an ONNX model with a NHWC Resize operator to test that -// type/shape inference succeeds during layout transformation. -// Refer to onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h. -// -// The models passed to this function are subgraphs extracted from a larger model that exhibited -// shape inferencing issues on QNN. Thus, the models are expected to have a specific input/output -// types and shapes. 
-static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bool enable_qnn_saver = false, - std::string htp_graph_finalization_opt_mode = "", - std::string qnn_context_priority = "", - std::string soc_model = "", - std::string htp_arch = "", - std::string device_id = "") { +enum class TestBackend { + Cpu, + Htp, + Saver, + Ir, +}; + +static std::string ToBackendLibName(TestBackend backend) { + switch (backend) { + case TestBackend::Cpu: + return "Cpu"; + case TestBackend::Htp: + return "Htp"; + case TestBackend::Saver: + return "Saver"; + case TestBackend::Ir: + return "Ir"; + default: + assert(false && "Invalid TestBackend value."); + return ""; + } +} + +static void AddSerializerConfigs(TestBackend serializer_backend, onnxruntime::ProviderOptions& options) { + std::string serializer_lib = ToBackendLibName(serializer_backend); + std::string serializer_path_key; + + switch (serializer_backend) { + case TestBackend::Ir: + serializer_path_key = "qnn_ir_backend_path"; + options["dump_qnn_ir_dlc"] = "1"; + options["dump_qnn_ir_dlc_dir"] = kDlcOutputDir; + break; + case TestBackend::Saver: + serializer_path_key = "qnn_saver_path"; + break; + default: + assert(false && "Invalid serializer backend."); + return; + } + +#if defined(_WIN32) + options[serializer_path_key] = "Qnn" + serializer_lib + ".dll"; +#else + options[serializer_path_key] = "libQnn" + serializer_lib + ".so"; +#endif +} + +static Ort::Session InitNHWCResizeModel(const ORTCHAR_T* ort_model_path, TestBackend backend, + std::optional serializer_backend = std::nullopt, + std::string htp_graph_finalization_opt_mode = "", + std::string qnn_context_priority = "", + std::string soc_model = "", + std::string htp_arch = "", + std::string device_id = "") { Ort::SessionOptions so; // Ensure all type/shape inference warnings result in errors! @@ -356,18 +401,18 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo onnxruntime::ProviderOptions options; options["offload_graph_io_quantization"] = "0"; + std::string backend_lib = ToBackendLibName(backend); + #if defined(_WIN32) - options["backend_path"] = use_htp ? "QnnHtp.dll" : "QnnCpu.dll"; - if (enable_qnn_saver) { - options["qnn_saver_path"] = "QnnSaver.dll"; - } + options["backend_path"] = "Qnn" + backend_lib + ".dll"; #else - options["backend_path"] = use_htp ? "libQnnHtp.so" : "libQnnCpu.so"; - if (enable_qnn_saver) { - options["qnn_saver_path"] = "libQnnSaver.so"; - } + options["backend_path"] = "libQnn" + backend_lib + ".so"; #endif + if (serializer_backend) { + AddSerializerConfigs(*serializer_backend, options); + } + if (!htp_graph_finalization_opt_mode.empty()) { options["htp_graph_finalization_optimization_mode"] = std::move(htp_graph_finalization_opt_mode); } @@ -392,6 +437,25 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo Ort::Session session(*ort_env, ort_model_path, so); + return session; +} + +// Helper function that runs an ONNX model with a NHWC Resize operator to test that +// type/shape inference succeeds during layout transformation. +// Refer to onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h. +// +// The models passed to this function are subgraphs extracted from a larger model that exhibited +// shape inferencing issues on QNN. Thus, the models are expected to have a specific input/output +// types and shapes. 
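+//
+// Example calls (values mirror the tests below; the serializer argument is optional):
+//   RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", TestBackend::Cpu);
+//   RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx",
+//                      TestBackend::Htp, TestBackend::Saver);  // also dump QNN Saver output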
+static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, TestBackend backend, + std::optional serializer_backend = std::nullopt, + std::string htp_graph_finalization_opt_mode = "", + std::string qnn_context_priority = "", + std::string soc_model = "", + std::string htp_arch = "", + std::string device_id = "") { + Ort::Session session = InitNHWCResizeModel(ort_model_path, backend, serializer_backend, htp_graph_finalization_opt_mode, qnn_context_priority, soc_model, htp_arch, device_id); + // Input can be all zeros since we're testing for correct shape inference. std::array input0_data = {}; std::array input1_data = {}; @@ -433,25 +497,25 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo // Test shape inference of NHWC Resize operator (opset 11) that uses // the scales input. Use the QNN CPU backend. TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_scales_opset11) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_scales_opset11.onnx", false); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_scales_opset11.onnx", TestBackend::Cpu); } // Test shape inference of NHWC Resize operator (opset 18) that uses // the scales input. Use the QNN CPU backend. TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_scales_opset18) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_scales_opset18.onnx", false); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_scales_opset18.onnx", TestBackend::Cpu); } // Test shape inference of NHWC Resize operator (opset 11) that uses // the sizes input. Use the QNN CPU backend. TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_sizes_opset11) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset11.onnx", false); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset11.onnx", TestBackend::Cpu); } // Test shape inference of NHWC Resize operator (opset 18) that uses // the sizes input. Use the QNN CPU backend. TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_sizes_opset18) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", false); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", TestBackend::Cpu); } // Test that QNN Saver generates the expected files for a model meant to run on the QNN CPU backend. @@ -463,8 +527,8 @@ TEST_F(QnnCPUBackendTests, QnnSaver_OutputFiles) { ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", - false, // use_htp - true); // enable_qnn_saver + TestBackend::Cpu, // backend + TestBackend::Saver); // serializer_backend // Check that QNN Saver output files exist. EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); @@ -856,7 +920,42 @@ TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgDefaultAndRunOption) { // the sizes input. Use the QNN HTP backend. // Maps to QNN's ResizeBilinear operator. TEST_F(QnnHTPBackendTests, TestNHWCResizeShapeInference_qdq_sizes_opset18) { - RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", true); + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", TestBackend::Htp); +} + +// Test that QNN Ir generates the expected file for a model meant to run on the QNN HTP backend. + +TEST_F(QnnHTPBackendTests, QnnIr_OutputFiles) { + const auto& logger = DefaultLoggingManager().DefaultLogger(); + if (IsIRBackendSupported() == BackendSupport::UNSUPPORTED) { + LOGS(logger, WARNING) << "QNN IR backend is not available! 
Skipping test."; + GTEST_SKIP(); + } else if (IsIRBackendSupported() == BackendSupport::SUPPORT_ERROR) { + LOGS(logger, ERROR) << "Failed to check if QNN IR backend is available."; + FAIL(); + } + + const std::filesystem::path qnn_dlc_dir = kDlcOutputDir; + + // Remove pre-existing QNN Ir output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_dlc_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_dlc_dir)); + + InitNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + TestBackend::Htp, // backend + TestBackend::Ir); // serializer backend + + // File names are taken from graph node names. Just make sure that we got one .dlc + // in the expected directory. + ASSERT_TRUE(std::filesystem::exists(qnn_dlc_dir)); + + int file_count = 0; + for (const auto& entry : std::filesystem::directory_iterator(qnn_dlc_dir)) { + EXPECT_TRUE(entry.is_regular_file()); + EXPECT_EQ(entry.path().extension(), ".dlc"); + ++file_count; + } + EXPECT_EQ(file_count, 1); } // Test that QNN Saver generates the expected files for a model meant to run on the QNN HTP backend. @@ -868,8 +967,8 @@ TEST_F(QnnHTPBackendTests, QnnSaver_OutputFiles) { ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", - true, // use_htp - true); // enable_qnn_saver + TestBackend::Htp, // backend + TestBackend::Saver); // serializer_backend // Check that QNN Saver output files exist. EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); @@ -885,9 +984,9 @@ TEST_F(QnnHTPBackendTests, HTPGraphFinalizationOptimizationModes) { "3"}; // Mode 3 for (auto mode : graph_opt_modes) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", - true, // use_htp - false, // enable_qnn_saver - mode); // htp_graph_finalization_opt_mode + TestBackend::Htp, // backend + std::nullopt, // serializer_backend + mode); // htp_graph_finalization_opt_mode } } @@ -905,10 +1004,10 @@ TEST_F(QnnHTPBackendTests, HTPSocModels) { for (auto soc_model : soc_models) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", - true, // use_htp - false, // enable_qnn_saver - "", // htp_graph_finalization_opt_mode - "", // qnn_context_priority + TestBackend::Htp, // backend + std::nullopt, // serializer_backend + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority soc_model); } } @@ -920,23 +1019,23 @@ TEST_F(QnnHTPBackendTests, HTPArchValues) { "68"}; // v68 for (auto htp_arch : htp_archs) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", - true, // use_htp - false, // enable_qnn_saver - "", // htp_graph_finalization_opt_mode - "", // qnn_context_priority - "", // soc_model - htp_arch, // htp_arch - "0"); // device_id + TestBackend::Htp, // backend + std::nullopt, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority + "", // soc_model + htp_arch, // htp_arch + "0"); // device_id } } // Test that models run with high QNN context priority. 
TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", - true, // use_htp - false, // enable_qnn_saver - "", // htp_graph_finalization_opt_mode - "high"); // qnn_context_priority + TestBackend::Htp, // use_htp + std::nullopt, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "high"); // qnn_context_priority } // Create a model with Cast + Add (quantized) @@ -1286,7 +1385,78 @@ TEST_F(QnnHTPBackendTests, AutoEp_PreferNpu) { } #endif // defined(WIN32) && !BUILD_QNN_EP_STATIC_LIB +// Test whether QNN EP can handle the case where the number of graph inputs and +// the number of tensor wrappers do not match. +// Take Resize op as an example. +// - Qnn only cares about the 1st input, so the rest of the inputs are not converted +// to tensor wrappers. +// - However, these remaining inputs still appear in the graph inputs, +// resulting in a discrepancy in the input quantities. +TEST_F(QnnHTPBackendTests, TestMismatchedGraphInputAndTensorWrapperCount) { + onnxruntime::ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + + auto input_defs = {TestInputDef({1, 3, 10, 10}, false, -10.0f, 10.0f), + TestInputDef({0}, false, {}), + TestInputDef({4}, true, {1.0f, 1.0f, 2.0f, 2.0f})}; + auto attrs = {utils::MakeAttribute("mode", "nearest"), + utils::MakeAttribute("coordinate_transformation_mode", "asymmetric"), + utils::MakeAttribute("nearest_mode", "floor")}; + RunQnnModelTest(BuildOpTestCase("Resize", + input_defs, + {}, + attrs, + kOnnxDomain), + provider_options, + 11, + ExpectedEPNodeAssignment::All, + 0.008f); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +// Test that QNN Ir generates the expected files for a model meant to run on any QNN backend. +TEST_F(QnnIRBackendTests, QnnIr_OutputFiles) { + const std::filesystem::path qnn_dlc_dir = kDlcOutputDir; + + // Remove pre-existing QNN Ir output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_dlc_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_dlc_dir)); + + InitNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + TestBackend::Ir, // backend + TestBackend::Ir); // serializer backend + + // File names are taken from graph node names. Just make sure that we got one .dlc + // in the expected directory. + ASSERT_TRUE(std::filesystem::exists(qnn_dlc_dir)); + + int file_count = 0; + for (const auto& entry : std::filesystem::directory_iterator(qnn_dlc_dir)) { + EXPECT_TRUE(entry.is_regular_file()); + EXPECT_EQ(entry.path().extension(), ".dlc"); + ++file_count; + } + EXPECT_EQ(file_count, 1); +} + +// Test that QNN Saver generates the expected files for a model meant to run on any QNN backend. +TEST(QnnSaverBackendTests, QnnSaver_OutputFiles) { + const std::filesystem::path qnn_saver_output_dir = "saver_output"; + + // Remove pre-existing QNN Saver output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_saver_output_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); + + InitNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + TestBackend::Saver, // backend + TestBackend::Saver); // serializer_backend + + // Check that QNN Saver output files exist. 
+ EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "params.bin")); +} + #endif // !defined(ORT_MINIMAL_BUILD) } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 8d840b1a3d45f..6ef831c8ecd6f 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -14,6 +14,8 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" +#define ORT_MODEL_FOLDER ORT_TSTR("testdata/") + using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; @@ -361,6 +363,9 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelFromPath) { // Make sure the compiled model was generated and has the expected number of EPContext nodes. ASSERT_TRUE(std::filesystem::exists(output_model_file)); CheckEpContextNodeCounts(output_model_file, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_file, so))); } // Test using the CompileModel() API with settings: @@ -396,6 +401,9 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelAsBuffer_Embe // Make sure the compiled model was generated and has the expected number of EPContext nodes. ASSERT_TRUE(std::filesystem::exists(output_model_file)); CheckEpContextNodeCounts(output_model_file, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_file, so))); } // Test using the CompileModel() API with settings: @@ -436,6 +444,12 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) { // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + + { + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, so))); + } + allocator.Free(output_model_buffer); } @@ -479,6 +493,10 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options))); + allocator.Free(output_model_buffer); } @@ -503,6 +521,10 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options))); + allocator.Free(output_model_buffer); } } @@ -554,9 +576,164 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu // Check that the compiled model has the expected number of EPContext nodes. 
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + + // Should be able to create a session with the compiled model and the original session options. + EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, so))); + allocator.Free(output_model_buffer); } +// Test that the explicit compile API can be configured to return an error if the output model does not +// have EPContext nodes. +TEST_F(QnnHTPBackendTests, CompileApi_SetFlags_ErrorIfNoCompiledNodes) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const ORTCHAR_T* output_model_file = ORT_TSTR("should_not_be_generated.onnx"); + std::filesystem::remove(output_model_file); + + // Initialize session options with only CPU EP, which will not be able to compile any nodes. + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + + // Call CompileModel() but expect an error status. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_EQ(status.GetErrorCode(), ORT_FAIL); + ASSERT_THAT(status.GetErrorMessage(), testing::HasSubstr("Unable to compile any nodes")); + + // Make sure that the output file was *NOT* generated. + ASSERT_FALSE(std::filesystem::exists(output_model_file)); +} + +// Test that the explicit compile API can be configured to return an error if the output model already exists and +// would have been overwritten. +TEST_F(QnnHTPBackendTests, CompileApi_SetFlags_ErrorIfOutputFileAlreadyExists) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const ORTCHAR_T* output_model_file = ORT_TSTR("mul_1_ctx_.onnx"); + std::filesystem::remove(output_model_file); + + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kQnnExecutionProvider, ProviderOptions{{"backend_type", "htp"}}); + + // Compile with QNN EP. Should succeed the first time. + { + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << "CompileModel() should succeed the first time a model is compiled."; + ASSERT_TRUE(std::filesystem::exists(output_model_file)) << "compiled model should exist"; + } + + // Compiling the input model again should fail if we disallow overwriting the output file. + { + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_EQ(status.GetErrorCode(), ORT_FAIL); + ASSERT_THAT(status.GetErrorMessage(), testing::HasSubstr("exists already")); + ASSERT_TRUE(std::filesystem::exists(output_model_file)) << "original compiled model should still exist"; + } +} + +// Tests that the explicit compile API returns an error if user tries to compile a compiled model. +// This scenario is silently ignored in the original compilation approach with session option configs. 
+TEST_F(QnnHTPBackendTests, CompileApi_ErrorIfCompilingACompiledModel) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const ORTCHAR_T* output_model_file = ORT_TSTR("mul_1_ctx_.onnx"); + std::filesystem::remove(output_model_file); + + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider(kQnnExecutionProvider, ProviderOptions{{"backend_type", "htp"}}); + + // Compile with QNN EP. Should succeed the first time. + { + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << "CompileModel() should succeed the first time a model is compiled."; + ASSERT_TRUE(std::filesystem::exists(output_model_file)) << "compiled model should exist"; + } + + // Compiling the compiled model should always fail: it's already compiled! + { + const ORTCHAR_T* new_output_model_file = ORT_TSTR("should_not_be_generated.onnx"); // Should not be generated. + std::filesystem::remove(new_output_model_file); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(output_model_file); // Set the compiled model as the input! + compile_options.SetOutputModelPath(new_output_model_file); + + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_EQ(status.GetErrorCode(), ORT_INVALID_GRAPH); + ASSERT_THAT(status.GetErrorMessage(), testing::HasSubstr("ensure the input model is not already compiled")); + ASSERT_FALSE(std::filesystem::exists(new_output_model_file)) << "new compiled model should not be generated"; + ASSERT_TRUE(std::filesystem::exists(output_model_file)) << "original compiled model should still exist"; + } +} + +// Uses the original compiling approach with session option configs (instead of explicit compile API). +// Test that ORT does not generate an output model if the model does not contain EPContext nodes. +// Also, ORT should not return an error. +TEST_F(QnnHTPBackendTests, QnnContextBinary_OriginalCompileApproach_NoCompiledNodesDoesntGenerateOutput) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const char* output_model_file = "should_not_be_generated.onnx"; + + // Initialize session options with only CPU EP, which will not be able to compile any nodes. + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_file); + Ort::Session session(*ort_env, input_model_file, so); // Should not throw an error. + + // Make sure that the output file was *NOT* generated. + ASSERT_FALSE(std::filesystem::exists(output_model_file)); +} + +// Uses the original compiling approach with session option configs (instead of explicit compile API). +// Test that ORT does not generate an output model if the input model is already compiled. +// Also, ORT should not return an error. +TEST_F(QnnHTPBackendTests, QnnContextBinary_OriginalCompileApproach_IgnoreCompilingOfCompiledModel) { + const ORTCHAR_T* input_model_file = ORT_MODEL_FOLDER "mul_1.onnx"; + const char* output_model_file = "mul_1_ctx.onnx"; + std::filesystem::remove(output_model_file); + + ProviderOptions qnn_options = {{"backend_type", "htp"}}; + + // Compile a model with QNN. This should succeed. 
+ { + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_file); + so.AppendExecutionProvider(kQnnExecutionProvider, qnn_options); + + Ort::Session session(*ort_env, input_model_file, so); + ASSERT_TRUE(std::filesystem::exists(output_model_file)); // check compiled model was generated. + } + + // Try compiling the compiled model again. ORT should basically ignore it. + { + const char* new_output_model_file = "should_not_be_generated.onnx"; // will not be generated! + std::filesystem::remove(new_output_model_file); + + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, new_output_model_file); + so.AppendExecutionProvider(kQnnExecutionProvider, qnn_options); + + Ort::Session session(*ort_env, ToPathString(output_model_file).c_str(), so); + + // Session creation should not throw an error. And a new output model should not have been generated. + ASSERT_FALSE(std::filesystem::exists(new_output_model_file)); + } +} + // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary // The generated Onnx model has 1 FusedMatMul node and 1 EPContext node TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) { @@ -681,6 +858,49 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected) { } } +// Set ep.context_file_path to invalid file path, check the error message +TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationFolderPathNotExpected2) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + bool single_ep_node = true; + BuildGraphWithQAndNonQ(single_ep_node)(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string ep_context_onnx_file = "./ep_context_folder_not_expected/invalid_file"; + std::remove(ep_context_onnx_file.c_str()); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_onnx_file.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + try { + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); + FAIL(); // Should not get here! 
+ } catch (const Ort::Exception& excpt) { + ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT); + ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder.")); + } +} + // Create session 1 to generate context binary file // Create session 2 to do same thing, make sure session 2 failed because file exist already // Make sure no new file over write from session 2 @@ -728,7 +948,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGenerationNoOverWrite) { FAIL(); // Should not get here! } catch (const Ort::Exception& excpt) { ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL); - ASSERT_THAT(excpt.what(), testing::HasSubstr("exist already.")); + ASSERT_THAT(excpt.what(), testing::HasSubstr("exists already.")); auto modify_time_2 = std::filesystem::last_write_time(ep_context_binary_file); ASSERT_EQ(modify_time_1, modify_time_2); } diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc new file mode 100644 index 0000000000000..aa8dc492a95c9 --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" + +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +namespace { + +GetTestModelFn BuildTestCaseScalar( + const TestInputDef& input_def, + float scale_value, + bool use_constant, + bool reverse_input_order, + std::optional softmax_axis = std::nullopt) { + return [&](ModelTestBuilder& builder) -> void { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* scale{nullptr}; + if (use_constant) { + onnx::TensorProto scale_value_proto; + scale_value_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + utils::SetRawDataInTensorProto(scale_value_proto, reinterpret_cast(&scale_value), sizeof(float)); + scale = builder.MakeIntermediate(); + builder.AddNode("Constant", {}, {scale}).AddAttribute("value", scale_value_proto); + } else { + scale = builder.MakeScalarInitializer(scale_value); + } + NodeArg* intermediate = builder.MakeIntermediate(); + auto mul_inputs = reverse_input_order ? 
std::vector{scale, input} : std::vector{input, scale}; + builder.AddNode("Mul", mul_inputs, {intermediate}); + Node& softmax = builder.AddNode("Softmax", {intermediate}, {builder.MakeOutput()}); + if (softmax_axis.has_value()) { + softmax.AddAttribute("axis", softmax_axis.value()); + } + }; +} + +GetTestModelFn BuildTestCaseNoScalar(const TestInputDef& input_def1, const TestInputDef& input_def2) { + return [&input_def1, input_def2](ModelTestBuilder& builder) -> void { + NodeArg* input = MakeTestInput(builder, input_def1); + NodeArg* scale = MakeTestInput(builder, input_def2); + NodeArg* intermediate = builder.MakeIntermediate(); + builder.AddNode("Mul", {input, scale}, {intermediate}); + builder.AddNode("Softmax", {intermediate}, {builder.MakeOutput()}); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarInitializer) { + ProviderOptions provider_options = GetProviderOptions(); + + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.125f, /*use_constant=*/false, /*reverse_input_order=*/false), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarConstant) { + ProviderOptions provider_options = GetProviderOptions(); + + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.375f, /*use_constant=*/true, /*reverse_input_order=*/false), + provider_options, + /*opset_version=*/14, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarInitializerReversed) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.375f, /*use_constant=*/false, /*reverse_input_order=*/true), + provider_options, + /*opset_version=*/15, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarConstantReversed) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.125f, /*use_constant=*/true, /*reverse_input_order=*/true), + provider_options, + /*opset_version=*/16, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionSoftmaxNegativeAxis) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseScalar(input_def, 0.125f, + /*use_constant=*/true, /*reverse_input_order=*/true, /*softmax_axis=*/-1), + provider_options, + /*opset_version=*/22, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionSkipNoScalar4d) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def1 = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + auto input_def2 = TestInputDef({1, 3, 5,
5}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseNoScalar(input_def1, input_def2), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionSkipNoScalar1d) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def1 = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); + auto input_def2 = TestInputDef({1}, false, -0.5f, 0.5f); + RunQnnModelTest(BuildTestCaseNoScalar(input_def1, input_def2), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-2f); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 6f8a7a9ecb602..cd163b044911c 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -309,11 +309,23 @@ void QnnHTPBackendTests::SetUp() { } } +static BackendSupport GetIRSupport(const onnxruntime::logging::Logger& logger); + +BackendSupport QnnHTPBackendTests::IsIRBackendSupported() const { + const auto& logger = DefaultLoggingManager().DefaultLogger(); + + if (cached_ir_support_ == BackendSupport::SUPPORT_UNKNOWN) { + cached_ir_support_ = test::GetIRSupport(logger); + } + + return cached_ir_support_; +} + // Testing helper function that calls QNN EP's GetCapability() function with a mock graph to check // if the QNN CPU backend is available. // TODO: Remove once the QNN CPU backend works on Windows ARM64 pipeline VM. -static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger) { - onnxruntime::Model model("Check if CPU is available", false, logger); +static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger, const std::string& backend_type = "cpu") { + onnxruntime::Model model("Check if " + backend_type + " is available", false, logger); Graph& graph = model.MainGraph(); ModelTestBuilder helper(graph); @@ -343,7 +355,7 @@ static BackendSupport GetCPUSupport(const onnxruntime::logging::Logger& logger) MockKernelLookup kernel_lookup; onnxruntime::GraphViewer graph_viewer(graph); std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( - {{"backend_type", "cpu"}, {"offload_graph_io_quantization", "0"}}); + {{"backend_type", backend_type}, {"offload_graph_io_quantization", "0"}}); GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability qnn_ep->SetLogger(&logger); @@ -373,6 +385,33 @@ void QnnCPUBackendTests::SetUp() { } } +static BackendSupport GetIRSupport(const onnxruntime::logging::Logger& logger) { + // QnnIr should be able to serialize any model supported by the QNN reference spec. + // Use a model that works on QnnCpu to verify QnnIr availability. + return GetCPUSupport(logger, "ir"); +} + +void QnnIRBackendTests::SetUp() { + if (cached_ir_support_ == BackendSupport::SUPPORTED) { + return; + } + + const auto& logger = DefaultLoggingManager().DefaultLogger(); + + // Determine if IR backend is supported only if we haven't done so before. + if (cached_ir_support_ == BackendSupport::SUPPORT_UNKNOWN) { + cached_ir_support_ = GetIRSupport(logger); + } + + if (cached_ir_support_ == BackendSupport::UNSUPPORTED) { + LOGS(logger, WARNING) << "QNN IR backend is not available!
Skipping test."; + GTEST_SKIP(); + } else if (cached_ir_support_ == BackendSupport::SUPPORT_ERROR) { + LOGS(logger, ERROR) << "Failed to check if QNN IR backend is available."; + FAIL(); + } +} + #if defined(_WIN32) // TODO: Remove or set to SUPPORTED once HTP emulation is supported on win arm64. BackendSupport QnnHTPBackendTests::cached_htp_support_ = BackendSupport::SUPPORT_UNKNOWN; @@ -384,6 +423,9 @@ BackendSupport QnnHTPBackendTests::cached_htp_support_ = BackendSupport::SUPPORT BackendSupport QnnCPUBackendTests::cached_cpu_support_ = BackendSupport::SUPPORTED; #endif // defined(_WIN32) +BackendSupport QnnHTPBackendTests::cached_ir_support_ = BackendSupport::SUPPORT_UNKNOWN; +BackendSupport QnnIRBackendTests::cached_ir_support_ = BackendSupport::SUPPORT_UNKNOWN; + bool ReduceOpHasAxesInput(const std::string& op_type, int opset_version) { static const std::unordered_map opset_with_axes_as_input = { {"ReduceMax", 18}, diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 676460e108b0e..9fe48ddabd427 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -1100,7 +1100,11 @@ class QnnHTPBackendTests : public ::testing::Test { protected: void SetUp() override; + // Some tests need the Ir backend, which is not always available. + [[nodiscard]] BackendSupport IsIRBackendSupported() const; + static BackendSupport cached_htp_support_; // Set by the first test using this fixture. + static BackendSupport cached_ir_support_; }; // Testing fixture class for tests that require the QNN CPU backend. Checks if QNN CPU is available before the test @@ -1113,6 +1117,15 @@ class QnnCPUBackendTests : public ::testing::Test { static BackendSupport cached_cpu_support_; // Set by the first test using this fixture. }; +// Testing fixture class for tests that require the QNN Ir backend. Checks if QNN IR is available before the test +// begins. The test is skipped if the IR backend is unavailable (may occur with certain QNN versions). +class QnnIRBackendTests : public ::testing::Test { + protected: + void SetUp() override; + + static BackendSupport cached_ir_support_; // Set by the first test using this fixture. +}; + /** * Returns true if the given reduce operator type (e.g., "ReduceSum") and opset version (e.g., 13) * supports "axes" as an input (instead of an attribute). 
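Note on the qnn_test_utils additions above: both QnnHTPBackendTests::IsIRBackendSupported() and the new QnnIRBackendTests fixture probe backend availability at most once per process and cache the result in a static member, so subsequent tests run, skip, or fail fast without repeating the (potentially slow) GetCapability() probe. The following is a minimal, self-contained sketch of that probe-once-and-cache pattern, not code from this change; the names Support, BackendFixture, and ProbeBackend are illustrative placeholders.

#include <gtest/gtest.h>

enum class Support { Unknown, Supported, Unsupported, Error };

class BackendFixture : public ::testing::Test {
 protected:
  void SetUp() override {
    // Probe only once per process; reuse the cached result for every later test.
    if (cached_support_ == Support::Unknown) {
      cached_support_ = ProbeBackend();
    }
    if (cached_support_ == Support::Unsupported) {
      GTEST_SKIP() << "Backend is not available; skipping test.";
    } else if (cached_support_ == Support::Error) {
      FAIL() << "Failed to determine whether the backend is available.";
    }
  }

  // Stand-in for the real availability check (e.g., a GetCapability() call on a mock graph).
  static Support ProbeBackend() { return Support::Supported; }

  static Support cached_support_;  // Shared by all tests that use this fixture.
};

Support BackendFixture::cached_support_ = Support::Unknown;

TEST_F(BackendFixture, RunsOnlyWhenBackendIsAvailable) {
  SUCCEED();
}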
diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index bfdb1a1a6afdd..b441af4a0efe9 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1017,6 +1017,78 @@ TEST_F(QnnHTPBackendTests, BinaryOp_HTP_Or_Unsupported) { ExpectedEPNodeAssignment::All); } +// Test ScatterND with reduction ADD on HTP +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_add) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "add"), + }, + 17, + ExpectedEPNodeAssignment::All); +} + +// Test ScatterND with reduction Mul on HTP +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_mul) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "mul"), + }, + 17, + ExpectedEPNodeAssignment::All); +} + +// Test ScatterND with reduction Max on CPU Fallback +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_max) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "max"), + }, + 17, + ExpectedEPNodeAssignment::None); +} + +// Test ScatterND with reduction Min on CPU Fallback +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64_reduction_min) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + { + utils::MakeAttribute("reduction", "min"), + }, + 17, + ExpectedEPNodeAssignment::None); +} + // Test 8-bit QDQ GridSample with bilinear TEST_F(QnnHTPBackendTests, GridSample_Bilinear) { RunQDQOpTest("GridSample", diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index 7a410d4bbeb6a..b102676860444 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -13,7 +13,7 @@ from helper import get_name import onnxruntime as onnxrt -from onnxruntime.capi.onnxruntime_pybind11_state import ModelRequiresCompilation +from onnxruntime.capi.onnxruntime_pybind11_state import Fail, ModelRequiresCompilation # handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. 
if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 @@ -120,6 +120,55 @@ def test_compile_with_input_and_output_files(self): model_compiler.compile_to_file(output_model_path) self.assertTrue(os.path.exists(output_model_path)) + def test_compile_flags_error_if_no_compiled_nodes(self): + """ + Tests specifying an additional flag (OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED) that + makes compiling return an error if no compiled nodes are generated (e.g., by using CPU EP). + """ + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled1.onnx") + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + flags=onnxrt.OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED, + ) + + # Compiling should raise a Fail exception and the output model should not be generated + with self.assertRaises(Fail) as context: + model_compiler.compile_to_file(output_model_path) + self.assertIn("Unable to compile any nodes", str(context.exception)) + self.assertFalse(os.path.exists(output_model_path)) + + def test_compile_flags_error_if_output_file_exists(self): + """ + Tests specifying an additional flag (OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS) that + makes compiling return an error if the output model file already exists. + """ + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled1.onnx") + + # Compile the first time (should be fine) + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + flags=onnxrt.OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS, + ) + + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) # Output model was generated + + # Compiling again should raise a Fail exception saying that the model already exists. + with self.assertRaises(Fail) as context: + model_compiler.compile_to_file(output_model_path) + self.assertIn("exists already", str(context.exception)) + def test_compile_to_file_with_input_model_in_buffer(self): """ Tests compiling an input model that is stored in a buffer. The output is saved to a file.
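The Python tests above exercise the same OrtCompileApiFlags that the C++ tests earlier in this change drive through Ort::ModelCompilationOptions. For cross-reference, here is a hedged C++ sketch of that usage; it assumes a live Ort::Env named env, uses placeholder model paths, and sets a single flag exactly as the tests do. CompileWithFlags is an illustrative helper, not an ONNX Runtime API.

#include <iostream>
#include "onnxruntime_cxx_api.h"

// Sketch only: mirrors the CompileApi_SetFlags_* tests in this change; not part of the patch.
// `env` is assumed to be a live Ort::Env owned by the caller; the model paths are placeholders.
void CompileWithFlags(Ort::Env& env) {
  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider("QNN", {{"backend_type", "htp"}});

  Ort::ModelCompilationOptions compile_options(env, session_options);
  compile_options.SetInputModelPath(ORT_TSTR("model.onnx"));
  compile_options.SetOutputModelPath(ORT_TSTR("model_ctx.onnx"));
  // Fail instead of silently overwriting an existing compiled model; the tests use
  // OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED in the same way.
  compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS);

  Ort::Status status = Ort::CompileModel(env, compile_options);
  if (!status.IsOK()) {
    // e.g., the output file already exists, or no nodes could be compiled.
    std::cerr << status.GetErrorMessage() << std::endl;
  }
}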
diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx deleted file mode 100644 index a0e65a0023612..0000000000000 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx and /dev/null differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx deleted file mode 100644 index a159efd0cc45c..0000000000000 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx and /dev/null differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx deleted file mode 100644 index 1da242e19e711..0000000000000 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx and /dev/null differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx deleted file mode 100644 index 552839e7234e2..0000000000000 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx and /dev/null differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx deleted file mode 100644 index bc72c9b350087..0000000000000 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx and /dev/null differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx deleted file mode 100644 index e51215bff7d30..0000000000000 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx and /dev/null differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx deleted file mode 100644 index c50162eb5bf8e..0000000000000 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx and /dev/null differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx deleted file mode 100644 index 751416b47d2eb..0000000000000 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx and /dev/null differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past.onnx new file mode 100644 index 
0000000000000..122981061479a Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past_split_bias.onnx new file mode 100644 index 0000000000000..e7127357d0818 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_no_past_split_bias.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past.onnx new file mode 100644 index 0000000000000..daba9a7015969 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past_split_bias.onnx new file mode 100644 index 0000000000000..4213079e30c92 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_decoder_attention_with_past_split_bias.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_encoder_self_attention.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_encoder_self_attention.onnx new file mode 100644 index 0000000000000..e0606ce3237a2 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp16_encoder_self_attention.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past.onnx new file mode 100644 index 0000000000000..5a5d7dc388cba Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past_split_bias.onnx new file mode 100644 index 0000000000000..677dcf062a80d Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_no_past_split_bias.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past.onnx new file mode 100644 index 0000000000000..0f08dd828b31a Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past_split_bias.onnx new file mode 100644 index 0000000000000..2f00c55afe9a5 Binary files /dev/null and 
b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_decoder_attention_with_past_split_bias.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_encoder_self_attention.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_encoder_self_attention.onnx new file mode 100644 index 0000000000000..578ff6ea32bac Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/hf_fp32_encoder_self_attention.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past.onnx new file mode 100644 index 0000000000000..5ec93d3cc3b53 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past_split_bias.onnx new file mode 100644 index 0000000000000..11b324293fb37 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_no_past_split_bias.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past.onnx new file mode 100644 index 0000000000000..1211a2f29caa3 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past_split_bias.onnx new file mode 100644 index 0000000000000..67f569265a1a3 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_decoder_attention_with_past_split_bias.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_encoder_self_attention.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_encoder_self_attention.onnx new file mode 100644 index 0000000000000..3ca5fc563eca7 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp16_encoder_self_attention.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past.onnx new file mode 100644 index 0000000000000..8177298b24fb6 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past_split_bias.onnx new file mode 100644 index 0000000000000..5b387854f4b18 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_no_past_split_bias.onnx differ diff --git 
a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past.onnx new file mode 100644 index 0000000000000..5b492da1185a1 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past_split_bias.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past_split_bias.onnx new file mode 100644 index 0000000000000..5798ee37ce128 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_decoder_attention_with_past_split_bias.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_encoder_self_attention.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_encoder_self_attention.onnx new file mode 100644 index 0000000000000..9162f2f20a2ce Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/whisper/oai_fp32_encoder_self_attention.onnx differ diff --git a/onnxruntime/test/python/transformers/test_whisper.py b/onnxruntime/test/python/transformers/test_whisper.py index ceda5a88c3925..3ac723e166084 100644 --- a/onnxruntime/test/python/transformers/test_whisper.py +++ b/onnxruntime/test/python/transformers/test_whisper.py @@ -8,14 +8,10 @@ import unittest import onnx +import torch +from parameterized import parameterized from parity_utilities import find_transformers_source -from whisper_model_generator import ( - create_whisper_decoder_attention, - create_whisper_decoder_multihead_attention, - create_whisper_decoder_with_past_multihead_cross_attention, - create_whisper_decoder_with_past_multihead_self_attention, - create_whisper_encoder_attention, -) +from transformers import EncoderDecoderCache if find_transformers_source(): from fusion_options import FusionOptions @@ -27,6 +23,422 @@ from onnxruntime.transformers.optimizer import optimize_model +# Dummy constants smaller than openai/whisper-tiny +class WhisperConfig: + def __init__(self): + # Hugging Face attribute names + self.hidden_size = 10 + self.num_heads = 2 + self.head_dim = self.hidden_size // self.num_heads + self.d_model = self.embed_dim = self.hidden_size + self.encoder_sequence_length = 20 + self.encoder_ffn_dim = 10 + self.decoder_ffn_dim = 10 + + # OpenAI attribute names + self.n_state = self.hidden_size + self.n_head = self.num_heads + self.n_mlp = self.encoder_ffn_dim + + +# From https://github.com/huggingface/transformers/blob/31f8a0fe8a7e2db1ee30bf32ed5976cd11f3283c/src/transformers/models/whisper/modeling_whisper.py#L222 +class WhisperHFAttention(torch.nn.Module): + def __init__(self): + super().__init__() + config = WhisperConfig() + + self.embed_dim = config.embed_dim + self.num_heads = config.num_heads + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim**-0.5 + self.layer_idx = 0 + + self.q_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.k_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, 
self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor | None = None, + past_key_value: tuple[tuple[torch.Tensor]] | None = None, + attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + """Input shape: Batch x Time x Channel""" + is_updated = past_key_value is not None + past_key_value = EncoderDecoderCache.from_legacy_cache(past_key_value) + past_key_value.is_updated[self.layer_idx] = is_updated + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim) + query_states = query_states.transpose(1, 2).contiguous() + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + past_key_value = past_key_value.to_legacy_cache() + return attn_output, past_key_value + + +# From https://github.com/huggingface/transformers/blob/31f8a0fe8a7e2db1ee30bf32ed5976cd11f3283c/src/transformers/models/whisper/modeling_whisper.py#L583 +class WhisperHFEncoderLayer(torch.nn.Module): + def __init__(self): + super().__init__() + config = WhisperConfig() + self.embed_dim = config.d_model + + self.self_attn = WhisperHFAttention() + self.self_attn_layer_norm = torch.nn.LayerNorm(self.embed_dim) + self.activation_fn = torch.nn.GELU() + self.fc1 = torch.nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = torch.nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + """ + hidden_states += 1 # Add fake add to help with fusion testing + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + return outputs + + +# From https://github.com/huggingface/transformers/blob/31f8a0fe8a7e2db1ee30bf32ed5976cd11f3283c/src/transformers/models/whisper/modeling_whisper.py#L651 +class WhisperHFDecoderLayer(torch.nn.Module): + def __init__(self): + super().__init__() + config = WhisperConfig() + self.embed_dim = config.d_model + + self.self_attn = WhisperHFAttention() + self.activation_fn = torch.nn.GELU() + + self.self_attn_layer_norm = torch.nn.LayerNorm(self.embed_dim) + self.encoder_attn = WhisperHFAttention() + self.encoder_attn_layer_norm = torch.nn.LayerNorm(self.embed_dim) + self.fc1 = torch.nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = torch.nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = torch.nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, + cross_attn_layer_head_mask: torch.Tensor | None = None, + past_key_value: tuple[tuple[torch.Tensor]] | None = None, + use_cache: bool | None = True, + cache_position: torch.LongTensor | None = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of 
shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + """ + hidden_states += 1 # Add fake add to help with fusion testing + batch_size, target_length = attention_mask.shape # Get shape to create 4D attention mask + sequence_length = hidden_states.size(1) # Get shape to create 4D attention mask + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask[:, None, None, :].expand(batch_size, 1, sequence_length, target_length), + layer_head_mask=layer_head_mask, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + ) + hidden_states = residual + hidden_states + + # add cross-attn to positions 1 of present_key_value tuple + if past_key_value is None: + # Skip if cross-attention has past KV cache inputs since the outputs are identical + present_key_value = (present_key_value, cross_attn_present_key_value) + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +# From https://github.com/openai/whisper/blob/dd985ac4b90cafeef8712f2998d62c59c3e62d22/whisper/model.py#L44 +class WhisperOAILinear(torch.nn.Linear): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.linear( + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), + ) + + +# From https://github.com/openai/whisper/blob/423492dda7806206abe56bdfe427c1096473a020/whisper/model.py#L62 +class WhisperOAIAttention(torch.nn.Module): + def __init__(self): + super().__init__() + config = WhisperConfig() + self.n_head = config.n_head + self.query = WhisperOAILinear(config.n_state, config.n_state) + self.key = WhisperOAILinear(config.n_state, config.n_state, bias=False) + self.value = WhisperOAILinear(config.n_state, config.n_state) + self.out = WhisperOAILinear(config.n_state, config.n_state) + + def forward( + self, + x: torch.Tensor, + xa: torch.Tensor | None = None, + 
mask: torch.Tensor | None = None, + kv_cache: tuple[torch.Tensor] | None = None, + ): + q = self.query(x) + present_k, present_v = None, None + + if kv_cache is None or xa is None: + # If xa == None: self-attention without KV cache inputs + # If xa != None: cross-attention without KV cache inputs + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + + if mask is not None and kv_cache is not None: + # Self-attention with KV cache inputs and outputs + past_k = kv_cache[0] + past_k = past_k.transpose(1, 2) + past_k = past_k.reshape(past_k.shape[:2] + (-1,)) + past_v = kv_cache[1] + past_v = past_v.transpose(1, 2) + past_v = past_v.reshape(past_v.shape[:2] + (-1,)) + + present_k = torch.cat([past_k, k], dim=1) + present_v = torch.cat([past_v, v], dim=1) + + present_k = present_k.reshape(present_k.shape[:2] + (-1, self.n_head)).transpose(1, 2) + present_v = present_v.reshape(v.shape[:2] + (-1, self.n_head)).transpose(1, 2) + else: + # Cross-attention with KV cache inputs + past_k = kv_cache[0] + past_k = past_k.transpose(1, 2) + past_k = past_k.reshape(past_k.shape[:2] + (-1,)) + past_v = kv_cache[1] + past_v = past_v.transpose(1, 2) + past_v = past_v.reshape(past_v.shape[:2] + (-1,)) + k = past_k + v = past_v + + n_batch, n_ctx, n_state = q.shape + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + wv, qk = self.qkv_attention(q, k, v, mask, n_ctx, n_state) + o = self.out(wv) + + if mask is None and kv_cache is not None: + # Cross-attention with KV cache inputs + return o, None, None + + if mask is not None and kv_cache is not None: + # Self-attention with KV cache inputs and outputs + return o, present_k, present_v + + return o, k, v + + def qkv_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.Tensor | None, + n_ctx: int, + n_state: int, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + scale = (n_state // self.n_head) ** -0.25 + + qk = (q * scale) @ (k * scale) + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + + w = torch.nn.functional.softmax(qk, dim=-1) + out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) + qk = qk.detach() + + return out, qk + + +# From https://github.com/openai/whisper/blob/dd985ac4b90cafeef8712f2998d62c59c3e62d22/whisper/model.py#L142 +class WhisperOAIResidualAttentionBlock(torch.nn.Module): + def __init__(self, cross_attention: bool = False): + super().__init__() + config = WhisperConfig() + + self.attn = WhisperOAIAttention() + self.attn_ln = torch.nn.LayerNorm(config.n_state) + + self.cross_attn = WhisperOAIAttention() if cross_attention else None + self.cross_attn_ln = torch.nn.LayerNorm(config.n_state) if cross_attention else None + + self.mlp = torch.nn.Sequential( + WhisperOAILinear(config.n_state, config.n_mlp), + torch.nn.GELU(), + WhisperOAILinear(config.n_mlp, config.n_state), + ) + self.mlp_ln = torch.nn.LayerNorm(config.n_state) + + def forward( + self, + x: torch.Tensor, + xa: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + kv_cache: tuple[torch.Tensor] | None = None, + ): + x += 1 # Add fake add to help with fusion testing + + self_attn_output, self_k, self_v = self.attn( + self.attn_ln(x), mask=mask, kv_cache=(kv_cache[:2] if kv_cache is not None else kv_cache) + ) + x = x + self_attn_output + if self.cross_attn: + cross_attn_output, cross_k, cross_v = self.cross_attn( + self.cross_attn_ln(x), xa, 
kv_cache=(kv_cache[2:] if kv_cache is not None else kv_cache) + ) + x = x + cross_attn_output + else: + self_k = self_v = cross_k = cross_v = None # Set to none when creating encoder model's attention block + x = x + self.mlp(self.mlp_ln(x)) + return x, (self_k, self_v, cross_k, cross_v) + + class TestFusion(unittest.TestCase): def verify_fusion(self, optimized_model, expected_model_filename): optimized_model.topological_sort(is_deterministic=True) @@ -50,144 +462,457 @@ def verify_fusion(self, optimized_model, expected_model_filename): ) ) - # Attention type #1 in fusion_bart_attention.py - def test_encoder_attention_fusion_with_skiplayernorm(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_encoder_attention( - num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False - ) - dir = "." - model_path = os.path.join(dir, "whisper_encoder_attention_sln.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "encoder_attention_with_sln_fused.onnx") - - # Attention type #2 in fusion_bart_attention.py - def test_decoder_attention_fusion_with_skiplayernorm(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_attention( - num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False - ) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_attention_sln.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_attention_with_sln_fused.onnx") - - # Attention type #4 in fusion_bart_attention.py - def test_decoder_multihead_attention_fusion(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_multihead_attention(num_heads=num_heads, hidden_size=hidden_size) - dir = "." 
- model_path = os.path.join(dir, "whisper_decoder_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True + def export(self, model, inputs, input_names, output_names, dynamic_axes): + torch.onnx.export( + model, + args=inputs, + f=os.path.join(os.path.dirname(__file__), "export.onnx"), + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=17, + do_constant_folding=True, + verbose=False, + ) + + def setUp(self): + # Reset the seed to 0 so that the tensor weights stay the same for each test case + # whether FP16 or FP32 is tested in a CI + torch.manual_seed(0) + + self.config = WhisperConfig() + self.optimization_options = FusionOptions("bart") + self.optimization_options.use_multi_head_attention = True + + self.batch_size = 2 + self.sequence_length = 10 + + def postSetUp(self, precision, split_bias=False): # noqa: N802 + use_fp16 = precision == "fp16" + self.device = torch.device("cuda" if use_fp16 else "cpu") + self.torch_dtype = torch.float16 if use_fp16 else torch.float32 + self.optimization_options.disable_multi_head_attention_bias = split_bias + + def tearDown(self): + path = os.path.join(os.path.dirname(__file__), "export.onnx") + if os.path.exists(path): + os.remove(path) + + @parameterized.expand( + [ + ("fp16", "cuda"), + ("fp32", "cpu"), + ] + ) + def test_hf_whisper_encoder_self_attention(self, precision, ep): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision) + model = WhisperHFEncoderLayer().to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, self.sequence_length, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + inputs = (hidden_states,) + self.export( + model, inputs, input_names=["input_hidden_states"], output_names=["output_hidden_states"], dynamic_axes={} + ) + + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_mha_fused.onnx") - - # Attention type #3 in fusion_bart_attention.py - def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_with_past_multihead_self_attention( - num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False - ) - dir = "." 
- model_path = os.path.join(dir, "whisper_decoder_with_past_self_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, + ) + name = f"hf_{precision}_encoder_self_attention.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) + + @parameterized.expand( + [ + ("fp16", "cuda", False), + ("fp16", "cuda", True), + ("fp32", "cpu", False), + ("fp32", "cpu", True), + ] + ) + def test_hf_whisper_decoder_no_past(self, precision, ep, split_bias): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision, split_bias) + model = WhisperHFDecoderLayer().to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, self.sequence_length, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + attention_mask = torch.ones(self.batch_size, self.sequence_length, device=self.device, dtype=self.torch_dtype) + encoder_hidden_states = torch.randn( + self.batch_size, + self.config.encoder_sequence_length, + self.config.embed_dim, + device=self.device, + dtype=self.torch_dtype, + ) + inputs = ( + hidden_states, + attention_mask, + encoder_hidden_states, + ) + self.export( + model, + inputs, + input_names=["input_hidden_states", "attention_mask", "encoder_hidden_states"], + output_names=[ + "output_hidden_states", + "present_key_self", + "present_value_self", + "present_key_cross", + "present_value_cross", + ], + dynamic_axes={}, + ) + + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_with_past_self_mha_fused.onnx") - - # Attention type #5 in fusion_bart_attention.py - def test_decoder_with_past_multihead_cross_attention_fusion(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_with_past_multihead_cross_attention(num_heads=num_heads, hidden_size=hidden_size) - dir = "." 
- model_path = os.path.join(dir, "whisper_decoder_with_past_cross_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, + ) + name = f"hf_{precision}_decoder_attention_no_past{'_split_bias' if split_bias else ''}.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) + + @parameterized.expand( + [ + ("fp16", "cuda", False), + ("fp16", "cuda", True), + ("fp32", "cpu", False), + ("fp32", "cpu", True), + ] + ) + def test_hf_whisper_decoder_with_past(self, precision, ep, split_bias): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision, split_bias) + model = WhisperHFDecoderLayer().to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, 1, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + attention_mask = torch.ones( + self.batch_size, self.sequence_length + 1, device=self.device, dtype=self.torch_dtype + ) + encoder_hidden_states = torch.randn( + self.batch_size, + self.config.encoder_sequence_length, + self.config.embed_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_key_self = torch.randn( + self.batch_size, + self.config.num_heads, + self.sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_value_self = torch.randn( + self.batch_size, + self.config.num_heads, + self.sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_key_cross = torch.randn( + self.batch_size, + self.config.num_heads, + self.config.encoder_sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_value_cross = torch.randn( + self.batch_size, + self.config.num_heads, + self.config.encoder_sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + + # past_key_values is of shape (num_layers) where each element is of shape (4) + # + # Ex: + # past_key_values = (layer_0_tuple, layer_1_tuple,) + # layer_0_tuple = (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0,) + # layer_1_tuple = (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1,) + past_key_values = ( + ( + past_key_self, + past_value_self, + past_key_cross, + past_value_cross, + ), + ) + + inputs = ( + hidden_states, + attention_mask, + encoder_hidden_states, + None, + None, + None, + past_key_values, + ) + self.export( + model, + inputs, + input_names=[ + "input_hidden_states", + "attention_mask", + "encoder_hidden_states", + "past_key_self", + "past_value_self", + "past_key_cross", + "past_value_cross", + ], + output_names=["output_hidden_states", "present_key_self", "present_value_self"], + dynamic_axes={}, + ) + + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_with_past_cross_mha_fused.onnx") - - # Attention type #4 in fusion_bart_attention.py - def test_decoder_multihead_attention_split_bias_fusion(self): - num_heads = 4 - hidden_size = 
64 - model = create_whisper_decoder_multihead_attention(num_heads=num_heads, hidden_size=hidden_size) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True - options.disable_multi_head_attention_bias = True + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, + ) + name = f"hf_{precision}_decoder_attention_with_past{'_split_bias' if split_bias else ''}.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) + + @parameterized.expand( + [ + ("fp16", "cuda"), + ("fp32", "cpu"), + ] + ) + def test_oai_whisper_encoder_self_attention(self, precision, ep): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision) + model = WhisperOAIResidualAttentionBlock().to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, self.sequence_length, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + inputs = (hidden_states,) + self.export( + model, inputs, input_names=["input_hidden_states"], output_names=["output_hidden_states"], dynamic_axes={} + ) + + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_mha_split_bias_fused.onnx") + name = f"oai_{precision}_encoder_self_attention.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) - # Attention type #3 in fusion_bart_attention.py - def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skiplayernorm(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_with_past_multihead_self_attention( - num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False + @parameterized.expand( + [ + ("fp16", "cuda", False), + ("fp16", "cuda", True), + ("fp32", "cpu", False), + ("fp32", "cpu", True), + ] + ) + def test_oai_whisper_decoder_no_past(self, precision, ep, split_bias): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision, split_bias) + model = WhisperOAIResidualAttentionBlock(cross_attention=True).to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, self.sequence_length, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + encoder_hidden_states = torch.randn( + self.batch_size, + self.config.encoder_sequence_length, + self.config.embed_dim, + device=self.device, + dtype=self.torch_dtype, + ) + attention_mask = torch.ones( + self.sequence_length, self.sequence_length, device=self.device, dtype=self.torch_dtype + ) + inputs = ( + hidden_states, + encoder_hidden_states, + attention_mask, + ) + self.export( + model, + inputs, + input_names=[ + "input_hidden_states", + "encoder_hidden_states", + "attention_mask", + ], + output_names=[ + 
"output_hidden_states", + "present_key_self", + "present_value_self", + "present_key_cross", + "present_value_cross", + ], + dynamic_axes={}, ) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_with_past_self_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True - options.disable_multi_head_attention_bias = True + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options - ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_with_past_self_mha_split_bias_fused.onnx") - - # Attention type #5 in fusion_bart_attention.py - def test_decoder_with_past_multihead_cross_attention_split_bias_fusion(self): - num_heads = 4 - hidden_size = 64 - model = create_whisper_decoder_with_past_multihead_cross_attention(num_heads=num_heads, hidden_size=hidden_size) - dir = "." - model_path = os.path.join(dir, "whisper_decoder_with_past_cross_mha.onnx") - onnx.save(model, model_path) - options = FusionOptions("bart") - options.use_multi_head_attention = True - options.disable_multi_head_attention_bias = True + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, + ) + name = f"oai_{precision}_decoder_attention_no_past{'_split_bias' if split_bias else ''}.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) + + @parameterized.expand( + [ + ("fp16", "cuda", False), + ("fp16", "cuda", True), + ("fp32", "cpu", False), + ("fp32", "cpu", True), + ] + ) + def test_oai_whisper_decoder_with_past(self, precision, ep, split_bias): + if ep == "cuda" and not torch.cuda.is_available(): + return + self.postSetUp(precision, split_bias) + model = WhisperOAIResidualAttentionBlock(cross_attention=True).to(dtype=self.torch_dtype, device=self.device) + + hidden_states = torch.randn( + self.batch_size, 1, self.config.embed_dim, device=self.device, dtype=self.torch_dtype + ) + encoder_hidden_states = torch.randn( + self.batch_size, + self.config.encoder_sequence_length, + self.config.embed_dim, + device=self.device, + dtype=self.torch_dtype, + ) + attention_mask = torch.ones(1, 1, device=self.device, dtype=self.torch_dtype) + past_key_self = torch.randn( + self.batch_size, + self.config.num_heads, + self.sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_value_self = torch.randn( + self.batch_size, + self.config.num_heads, + self.sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_key_cross = torch.randn( + self.batch_size, + self.config.num_heads, + self.config.encoder_sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + past_value_cross = torch.randn( + self.batch_size, + self.config.num_heads, + self.config.encoder_sequence_length, + self.config.head_dim, + device=self.device, + dtype=self.torch_dtype, + ) + + # past_key_values is of shape (num_layers) where each element is a past key/value + # + # Ex: + # past_key_values = (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0,) + past_key_values = ( + past_key_self, + past_value_self, + past_key_cross, + past_value_cross, + ) + + 
inputs = ( + hidden_states, + encoder_hidden_states, + attention_mask, + past_key_values, + ) + self.export( + model, + inputs, + input_names=[ + "input_hidden_states", + "encoder_hidden_states", + "attention_mask", + "past_key_self", + "past_value_self", + "past_key_cross", + "past_value_cross", + ], + output_names=["output_hidden_states", "present_key_self", "present_value_self"], + dynamic_axes={}, + ) + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) optimized_model = optimize_model( - model_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size, optimization_options=options + original_model, + model_type="bart", + num_heads=self.config.num_heads, + hidden_size=self.config.embed_dim, + optimization_options=self.optimization_options, + opt_level=0, + use_gpu=True, + only_onnxruntime=False, ) - os.remove(model_path) - self.verify_fusion(optimized_model, "decoder_with_past_cross_mha_split_bias_fused.onnx") + name = f"oai_{precision}_decoder_attention_with_past{'_split_bias' if split_bias else ''}.onnx" + # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes + self.verify_fusion(optimized_model, name) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/whisper_model_generator.py b/onnxruntime/test/python/transformers/whisper_model_generator.py deleted file mode 100644 index 5527df489b846..0000000000000 --- a/onnxruntime/test/python/transformers/whisper_model_generator.py +++ /dev/null @@ -1,2021 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - - -import numpy as np -import onnx -from bert_model_generator import float_tensor -from onnx import TensorProto, helper, numpy_helper - - -# Adapted from bert_model_generator.py -def get_tensor_and_weight(name: str, shape: list[int], random=False, zeros=False): - low = 0.0 - high = 1.0 - total_elements = 1 - for x in shape: - total_elements *= x - weights = ( - [np.random.uniform(low, high) for _ in range(total_elements)] - if random - else [0.0] * total_elements - if zeros - else [1.0] * total_elements - ) - return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights - - -def create_whisper_encoder_attention( - hidden_size=768, - num_heads=12, - epsilon=0.000009999999747378752, - add_before_layernorm=False, - add_k=False, - fused=False, -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - ] - outputs = [ - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - ] - - nodes = [] - # Create layernorm (Add + LayerNorm or SkipLayerNorm) - if add_before_layernorm: - nodes.extend( - [ - helper.make_node( - "Add", ["input_0", "input_0"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" - ), - helper.make_node( - "LayerNormalization", - ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul"], - "layernorm", - epsilon=epsilon, - 
), - ] - ) - else: - nodes.append( - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - if fused: - nodes.append( - helper.make_node( - "Attention", - ["layernorm_output_to_matmul", "Attention_0_qkv_weight", "Attention_0_qkv_bias", ""], - ["attn_output"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - ), - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", "q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "k_weight"], ["k_matmul_output"], "k_path_matmul" - ), - ] - if add_k: - k_nodes.extend( - [ - helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), - helper.make_node("Reshape", ["k_add_output", "bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ] - ) - else: - k_nodes.append( - helper.make_node("Reshape", ["k_matmul_output", "kv_bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ) - k_nodes.extend( - [ - helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["k_4d_bnsh", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - ) - v_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "v_weight"], ["v_matmul_output"], "v_path_matmul" - ), - helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), - helper.make_node("Reshape", ["v_add_output", "kv_bsnh_reshape"], ["v_4d_bsnh"], "v_reshape_to_4d"), - helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["v_4d_bnsh", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_attn_heads_output"], "mul_num_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - 
["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["kv_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, -1, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with Q x K' and softmax(Q x K') x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", ["gather_0_output", "unsqueeze_axes_input"], ["unsqueeze_0_output"], "unsqueeze_0" - ), - helper.make_node( - "Unsqueeze", ["gather_1_output", "unsqueeze_axes_input"], ["unsqueeze_1_output"], "unsqueeze_1" - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], ["bsd_format"], axis=0 - ), - ] - ) - - # Create nodes for computing softmax(Q x K') x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Softmax", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", - "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - 
q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=(not add_k)) - v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) - qkv_weight = helper.make_tensor( - "Attention_0_qkv_weight", - TensorProto.FLOAT, - [hidden_size, 3 * hidden_size], - q_weight_data + k_weight_data + v_weight_data, - ) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", TensorProto.FLOAT, [3 * hidden_size], q_bias_data + k_bias_data + v_bias_data - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - ] - if fused: - initializers.extend([qkv_weight, qkv_bias]) - else: - initializers.extend( - [ - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - if add_k: - initializers.extend([q_weight, q_bias, k_weight, k_bias, v_weight, v_bias]) - else: - initializers.extend([q_weight, q_bias, k_weight, v_weight, v_bias]) - - # Construct graph - graph = helper.make_graph( - nodes, "whisper_encoder_attention_graph", inputs, outputs, initializers, doc_string="whisper" - ) - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -def create_whisper_decoder_attention( - hidden_size=768, - num_heads=12, - epsilon=0.000009999999747378752, - add_before_layernorm=False, - add_k=False, - fused=False, -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - # Dummy inputs are used to prevent the nodes in the path for the decoder attention mask to be fused together - # before attention is fused - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - ] - if not fused: - inputs.extend( - [ - helper.make_tensor_value_info("dummy_input_int64", TensorProto.INT64, ["dummy_input_1d_int64"]), - helper.make_tensor_value_info("dummy_input_fp32", TensorProto.FLOAT, ["dummy_input_1d_fp32"]), - ] - ) - outputs = [ - helper.make_tensor_value_info( - "present.0.decoder.key", TensorProto.FLOAT, ["batch_size", num_heads, 1500, head_size] - ), - helper.make_tensor_value_info( - "present.0.decoder.value", TensorProto.FLOAT, ["batch_size", num_heads, 1500, head_size] - ), - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, 
["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - ] - - nodes = [] - # Create layernorm (Add + LayerNorm or SkipLayerNorm) - if add_before_layernorm: - nodes.extend( - [ - helper.make_node( - "Add", ["input_0", "input_0"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" - ), - helper.make_node( - "LayerNormalization", - ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul"], - "layernorm", - epsilon=epsilon, - ), - ] - ) - else: - nodes.append( - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - if fused: - nodes.extend( - [ - helper.make_node( - "Attention", - [ - "layernorm_output_to_matmul", - "Attention_0_qkv_weight", - "Attention_0_qkv_bias", - "", - ], - ["attn_output", "present_0_decoder"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - unidirectional=1, - ), - helper.make_node( - "Gather", - ["present_0_decoder", "index_0"], - ["present.0.decoder.key"], - "Gather_0", - axis=0, - ), - helper.make_node( - "Gather", - ["present_0_decoder", "index_1"], - ["present.0.decoder.value"], - "Gather_1", - axis=0, - ), - ] - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", "q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "k_weight"], ["k_matmul_output"], "k_path_matmul" - ), - ] - if add_k: - k_nodes.extend( - [ - helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), - helper.make_node("Reshape", ["k_add_output", "bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ] - ) - else: - k_nodes.append( - helper.make_node("Reshape", ["k_matmul_output", "kv_bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ) - k_nodes.extend( - [ - helper.make_node( - "Transpose", - ["k_4d_bsnh"], - ["present.0.decoder.key"], - "k_transpose_to_bnsh", - perm=[0, 2, 1, 3], - ), - helper.make_node( - "Reshape", - ["present.0.decoder.key", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - ) - v_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "v_weight"], ["v_matmul_output"], "v_path_matmul" - ), - helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), - helper.make_node("Reshape", ["v_add_output", "kv_bsnh_reshape"], ["v_4d_bsnh"], "v_reshape_to_4d"), - helper.make_node( - "Transpose", 
["v_4d_bsnh"], ["present.0.decoder.value"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3] - ), - helper.make_node( - "Reshape", - ["present.0.decoder.value", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_attn_heads_output"], "mul_num_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["kv_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, -1, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with mask - nodes.extend( - [ - helper.make_node( - "Shape", ["k_output_(num_heads*batch_size,seq_len,head_size)"], ["mask_shape_output"], "mask_shape" - ), - helper.make_node( - "Gather", ["mask_shape_output", "idx_1"], ["mask_gather_1_output"], "mask_gather_1", axis=0 - ), - helper.make_node( - "Unsqueeze", - ["mask_gather_1_output", "unsqueeze_axes_input"], - ["mask_unsqueeze_1_output"], - "mask_unsqueeze_1", - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "mask_unsqueeze_1_output"], - ["mask_concat_output"], - "mask_concat", - axis=0, - ), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_mask_heads_output"], "mul_mask_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_mask_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_mask_heads_output"], - "unsqueeze_mask_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_mask_heads_output", "unsqueeze_1_output", "mask_unsqueeze_1_output"], - ["concat_input_for_reshape_after_add"], - "concat_for_reshape_after_add", - axis=0, - ), - ] - ) - - # Create nodes used with Q x K' + mask and softmax(Q x K' + mask) x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", ["gather_0_output", "unsqueeze_axes_input"], ["unsqueeze_0_output"], "unsqueeze_0" - ), - helper.make_node( - "Unsqueeze", ["gather_1_output", "unsqueeze_axes_input"], ["unsqueeze_1_output"], "unsqueeze_1" - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], ["bsd_format"], axis=0 - ), - ] - ) - - # 
Create nodes for computing softmax(Q x K' + mask) x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Reshape", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)", "mask_concat_output"], - ["qk_output_(batch_size,num_heads,seq_len,seq_len)"], - "reshape_qk_to_bnsh", - ), - helper.make_node( - "Add", - ["qk_output_(batch_size,num_heads,seq_len,seq_len)", "attention_add_qk"], - ["add_qk_output_(batch_size,num_heads_seq_len,seq_len)"], - "add_qk", - ), - helper.make_node( - "Reshape", - ["add_qk_output_(batch_size,num_heads_seq_len,seq_len)", "concat_input_for_reshape_after_add"], - ["add_qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "reshape_add_qk_before_softmax", - ), - helper.make_node( - "Softmax", - ["add_qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create nodes that make attention mask - if not fused: - nodes.extend( - [ - # "attention_mask" is (decoder_seq_len, decoder_seq_len) but is assumed to be (1, 1) for this test. - # There are other nodes that automatically set the attention mask size correctly but those nodes do not - # impact the attention fusion. Hence, this assumption is made in order to simplify the inputs for the - # following nodes. 
- helper.make_node( - "Where", - ["all_ones", "where_filter_constant", "dummy_input_fp32"], - ["where_output"], - "mask_filter_where", - ), - helper.make_node( - "Unsqueeze", - ["where_output", "dummy_input_int64"], - ["unsqueeze_mask_output_1"], - "unsqueeze_attn_mask_1", - ), - helper.make_node( - "Unsqueeze", - ["unsqueeze_mask_output_1", "dummy_input_int64"], - ["unsqueeze_mask_output_2"], - "unsqueeze_attn_mask_2", - ), - helper.make_node( - "Expand", - inputs=["unsqueeze_mask_output_2", "dummy_input_int64"], - outputs=["attention_add_qk"], - name="expand_mask_from_(b,1,m,m)_to_(b,n,m,m)", - ), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", - "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=(not add_k)) - v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) - qkv_weight = helper.make_tensor( - "Attention_0_qkv_weight", - TensorProto.FLOAT, - [hidden_size, 3 * hidden_size], - q_weight_data + k_weight_data + v_weight_data, - ) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", TensorProto.FLOAT, [3 * hidden_size], q_bias_data + k_bias_data + v_bias_data - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - ] - - if fused: - initializers.extend( - [ - qkv_weight, - qkv_bias, - numpy_helper.from_array(np.array(0, dtype="int64"), name="index_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="index_1"), - ] - ) - else: - initializers.extend( - [ - numpy_helper.from_array(np.array([[1]], dtype=bool), name="all_ones"), - numpy_helper.from_array(np.array([1], dtype="float32"), name="where_filter_constant"), - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), 
name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - - if add_k: - initializers.extend([q_weight, q_bias, k_weight, k_bias, v_weight, v_bias]) - else: - initializers.extend([q_weight, q_bias, k_weight, v_weight, v_bias]) - - # Construct graph - graph = helper.make_graph( - nodes, "whisper_decoder_attention_graph", inputs, outputs, initializers, doc_string="whisper" - ) - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -def create_whisper_decoder_multihead_attention( - hidden_size=768, num_heads=12, epsilon=0.000009999999747378752, add_k=False, fused=False -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("encoder_hidden_states", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - ] - outputs = [ - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - helper.make_tensor_value_info( - "present.0.encoder.key", TensorProto.FLOAT, ["batch_size", num_heads, 1500, head_size] - ), - helper.make_tensor_value_info( - "present.0.encoder.value", TensorProto.FLOAT, ["batch_size", num_heads, 1500, head_size] - ), - ] - - # Create SkipLayerNorm (since there's no Add + LayerNorm variant for this attention subgraph) - nodes = [ - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ] - - if fused: - nodes.extend( - [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("MatMul", ["encoder_hidden_states", "k_weight"], ["k_matmul_output"], "k_path_matmul"), - helper.make_node("MatMul", ["encoder_hidden_states", "v_weight"], ["v_matmul_output"], "v_path_matmul"), - helper.make_node( - "MultiHeadAttention", - ["q_matmul_output", "k_matmul_output", "v_matmul_output", "Attention_0_qkv_bias"], - ["attn_output", "present.0.encoder.key", "present.0.encoder.value"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - ), - ] - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", - ["layernorm_output_to_matmul", "q_weight"], - ["q_matmul_output"], - "q_path_matmul", - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", 
"q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node("MatMul", ["encoder_hidden_states", "k_weight"], ["k_matmul_output"], "k_path_matmul"), - ] - if add_k: - k_nodes.extend( - [ - helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), - helper.make_node("Reshape", ["k_add_output", "bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ] - ) - else: - k_nodes.append( - helper.make_node("Reshape", ["k_matmul_output", "kv_bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ) - k_nodes.extend( - [ - helper.make_node( - "Transpose", ["k_4d_bsnh"], ["present.0.encoder.key"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3] - ), - helper.make_node( - "Reshape", - ["present.0.encoder.key", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - ) - v_nodes = [ - helper.make_node("MatMul", ["encoder_hidden_states", "v_weight"], ["v_matmul_output"], "v_path_matmul"), - helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), - helper.make_node("Reshape", ["v_add_output", "kv_bsnh_reshape"], ["v_4d_bsnh"], "v_reshape_to_4d"), - helper.make_node( - "Transpose", ["v_4d_bsnh"], ["present.0.encoder.value"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3] - ), - helper.make_node( - "Reshape", - ["present.0.encoder.value", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_attn_heads_output"], "mul_num_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["kv_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, -1, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with Q x K' and softmax(Q x K') x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", - ["gather_0_output", "unsqueeze_axes_input"], - ["unsqueeze_0_output"], - "unsqueeze_0", - ), - helper.make_node( - "Unsqueeze", - 
["gather_1_output", "unsqueeze_axes_input"], - ["unsqueeze_1_output"], - "unsqueeze_1", - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], - ["bsd_format"], - axis=0, - ), - ] - ) - - # Create nodes for computing softmax(Q x K') x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Softmax", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", - "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=(not add_k)) - v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", TensorProto.FLOAT, [3 * hidden_size], q_bias_data + k_bias_data + v_bias_data - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - ] - - # Add Q/K/V weight tensors as initializers - initializers.extend([q_weight, k_weight, v_weight]) - - if fused: - initializers.append(qkv_bias) - else: - if add_k: - 
initializers.extend([q_bias, k_bias, v_bias]) - else: - initializers.extend([q_bias, v_bias]) - - initializers.extend( - [ - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - - # Construct graph - graph = helper.make_graph(nodes, "whisper_decoder_mha_graph", inputs, outputs, initializers, doc_string="whisper") - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -def create_whisper_decoder_with_past_multihead_self_attention( - hidden_size=768, - num_heads=12, - epsilon=0.000009999999747378752, - add_before_layernorm=False, - add_k=False, - fused=False, -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info( - "past_key_values.0.decoder.key", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len", head_size] - ), - helper.make_tensor_value_info( - "past_key_values.0.decoder.value", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len", head_size] - ), - ] - outputs = [ - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - helper.make_tensor_value_info( - "present.0.decoder.key", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len + 1", head_size] - ), - helper.make_tensor_value_info( - "present.0.decoder.value", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len + 1", head_size] - ), - ] - nodes = [] - - # Create layernorm (Add + LayerNorm or SkipLayerNorm) - if add_before_layernorm: - nodes.extend( - [ - helper.make_node( - "Add", ["input_0", "input_0"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" - ), - helper.make_node( - "LayerNormalization", - ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul"], - "layernorm", - epsilon=epsilon, - ), - ] - ) - else: - nodes.append( - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - if fused: - nodes.extend( - [ - helper.make_node( - "MatMul", - ["layernorm_output_to_matmul", "MatMul_0_qkv_weight"], - ["MatMul_0_qkv_out"], - "MatMul_0", - ), - helper.make_node( - "Slice", - ["MatMul_0_qkv_out", "MatMul_0_q_start_index", "MatMul_0_k_start_index", "MatMul_0_qkv_last_axis"], - ["MatMul_0_q_out"], - "Slice_0", - ), - helper.make_node( - "Slice", - 
["MatMul_0_qkv_out", "MatMul_0_k_start_index", "MatMul_0_v_start_index", "MatMul_0_qkv_last_axis"], - ["MatMul_0_k_out"], - "Slice_1", - ), - helper.make_node( - "Slice", - [ - "MatMul_0_qkv_out", - "MatMul_0_v_start_index", - "MatMul_0_end_of_qkv_index", - "MatMul_0_qkv_last_axis", - ], - ["MatMul_0_v_out"], - "Slice_2", - ), - helper.make_node( - "MultiHeadAttention", - [ - "MatMul_0_q_out", - "MatMul_0_k_out", - "MatMul_0_v_out", - "Attention_0_qkv_bias", - "", - "", - "past_key_values.0.decoder.key", - "past_key_values.0.decoder.value", - ], - ["attn_output", "present.0.decoder.key", "present.0.decoder.value"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - unidirectional=1, - ), - ] - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", "q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node( - "MatMul", - ["layernorm_output_to_matmul", "k_weight"], - ["k_matmul_output"], - "k_path_matmul", - ), - ] - if add_k: - k_nodes.extend( - [ - helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), - helper.make_node("Reshape", ["k_add_output", "bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ] - ) - else: - k_nodes.append( - helper.make_node("Reshape", ["k_matmul_output", "kv_bsnh_reshape"], ["k_4d_bsnh"], "k_reshape_to_4d"), - ) - k_nodes.extend( - [ - helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Concat", - ["past_key_values.0.decoder.key", "k_4d_bnsh"], - ["present.0.decoder.key"], - "concat_past_k_and_curr_k", - axis=2, - ), - helper.make_node( - "Reshape", - ["present.0.decoder.key", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - ) - v_nodes = [ - helper.make_node( - "MatMul", - ["layernorm_output_to_matmul", "v_weight"], - ["v_matmul_output"], - "v_path_matmul", - ), - helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), - helper.make_node("Reshape", ["v_add_output", "kv_bsnh_reshape"], ["v_4d_bsnh"], "v_reshape_to_4d"), - helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Concat", - ["past_key_values.0.decoder.value", "v_4d_bnsh"], - ["present.0.decoder.value"], - "concat_past_v_and_curr_v", - axis=2, - ), - helper.make_node( - "Reshape", - ["present.0.decoder.value", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - 
helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", - ["gather_0_output", "num_heads_int"], - ["mul_attn_heads_output"], - "mul_num_heads", - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["kv_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, -1, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with Q x K' and softmax(Q x K') x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", - ["gather_0_output", "unsqueeze_axes_input"], - ["unsqueeze_0_output"], - "unsqueeze_0", - ), - helper.make_node( - "Unsqueeze", - ["gather_1_output", "unsqueeze_axes_input"], - ["unsqueeze_1_output"], - "unsqueeze_1", - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], - ["bsd_format"], - axis=0, - ), - ] - ) - - # Create nodes for computing softmax(Q x K') x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Softmax", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - 
helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", - "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=(not add_k)) - v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) - qkv_weight = helper.make_tensor( - "MatMul_0_qkv_weight", - TensorProto.FLOAT, - [hidden_size, 3 * hidden_size], - q_weight_data + k_weight_data + v_weight_data, - ) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", - TensorProto.FLOAT, - [3 * hidden_size], - q_bias_data + k_bias_data + v_bias_data, - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - ] - - if fused: - # Add packed QKV weight tensor as initializer - initializers.append(qkv_weight) - - # Add Slice indices as initializers - initializers.extend( - [ - helper.make_tensor(name="MatMul_0_q_start_index", data_type=TensorProto.INT64, dims=[1], vals=[0]), - helper.make_tensor( - name="MatMul_0_k_start_index", data_type=TensorProto.INT64, dims=[1], vals=[hidden_size] - ), - helper.make_tensor( - name="MatMul_0_v_start_index", data_type=TensorProto.INT64, dims=[1], vals=[2 * hidden_size] - ), - helper.make_tensor( - name="MatMul_0_end_of_qkv_index", data_type=TensorProto.INT64, dims=[1], vals=[3 * hidden_size] - ), - helper.make_tensor(name="MatMul_0_qkv_last_axis", data_type=TensorProto.INT64, dims=[1], vals=[-1]), - ] - ) - - # Add packed QKV bias tensor as initializer - initializers.append(qkv_bias) - else: - # Add Q/K/V weight tensors as initializers - initializers.extend([q_weight, k_weight, v_weight]) - - if add_k: - initializers.extend([q_bias, k_bias, v_bias]) - else: - initializers.extend([q_bias, v_bias]) - - initializers.extend( - [ - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - - # Construct graph - graph = helper.make_graph( - nodes, 
"whisper_decoder_with_past_self_mha_graph", inputs, outputs, initializers, doc_string="whisper" - ) - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -def create_whisper_decoder_with_past_multihead_cross_attention( - hidden_size=768, num_heads=12, epsilon=0.000009999999747378752, fused=False -): - # Get head size and ensure head size is an integer - assert hidden_size % num_heads == 0 - head_size = hidden_size // num_heads - - # Construct input and output nodes - inputs = [ - helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info( - "past_key_values.0.encoder.key", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len", head_size] - ), - helper.make_tensor_value_info( - "past_key_values.0.encoder.value", TensorProto.FLOAT, ["batch_size", num_heads, "past_seq_len", head_size] - ), - ] - outputs = [ - helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 1500, hidden_size]), - helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]), - ] - - # Create SkipLayerNorm (since there's no Add + LayerNorm variant for this attention subgraph) - nodes = [ - helper.make_node( - "SkipLayerNormalization", - ["input_0", "input_0", "layernorm_weight", "layernorm_bias"], - ["layernorm_output_to_matmul", "", "", "layernorm_output_to_skiplayernorm"], - "skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ) - ] - - if fused: - nodes.extend( - [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node( - "MultiHeadAttention", - [ - "q_matmul_output", - "past_key_values.0.encoder.key", - "past_key_values.0.encoder.value", - "Attention_0_qkv_bias", - ], - ["attn_output"], - "Attention_0", - domain="com.microsoft", - num_heads=num_heads, - ), - ] - ) - else: - # Create nodes for Q/K/V paths - q_nodes = [ - helper.make_node( - "MatMul", ["layernorm_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" - ), - helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), - helper.make_node("Mul", ["q_add_output", "q_scale"], ["q_mul_output"], "q_path_mul"), - helper.make_node("Reshape", ["q_mul_output", "q_bsnh_reshape"], ["q_4d_bsnh"], "q_reshape_to_4d"), - helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), - helper.make_node( - "Reshape", - ["q_4d_bnsh", "q_attn_heads_output"], - ["q_output_(num_heads*batch_size,seq_len,head_size)"], - "q_reshape_to_3d", - ), - ] - k_nodes = [ - helper.make_node( - "Reshape", - ["past_key_values.0.encoder.key", "k_attn_heads_output"], - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - "k_reshape_to_3d", - ), - helper.make_node( - "Transpose", - ["k_output_(num_heads*batch_size,seq_len,head_size)"], - ["k_output_(num_heads*batch_size,head_size,seq_len)"], - "k_transpose_last_two_dims", - perm=[0, 2, 1], - ), - ] - v_nodes = [ - helper.make_node( - "Reshape", - ["past_key_values.0.encoder.value", "v_attn_heads_output"], - ["v_output_(num_heads*batch_size,seq_len,head_size)"], - "v_reshape_to_3d", - ), - ] - nodes.extend(q_nodes) - nodes.extend(k_nodes) - nodes.extend(v_nodes) - - # Create nodes used with qkv concats, reshapes, and transposes - nodes.extend( - [ - helper.make_node("Shape", ["layernorm_output_to_matmul"], ["shape_output"], "shape"), - 
helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), - helper.make_node( - "Mul", ["gather_0_output", "num_heads_int"], ["mul_attn_heads_output"], "mul_num_heads" - ), - helper.make_node( - "Unsqueeze", - ["mul_attn_heads_output", "unsqueeze_axes_input"], - ["unsqueeze_attn_heads_output"], - "unsqueeze_num_heads", - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["q_attn_heads_output"], - "q_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["k_attn_heads_output"], - "k_num_heads", - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_attn_heads_output", "neg_one", "head_size"], - ["v_attn_heads_output"], - "v_num_heads", - axis=0, - ), - helper.make_node( - "Constant", - inputs=[], - outputs=["q_bsnh_reshape"], - value=numpy_helper.from_array( - np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" - ), - ), - ] - ) - - # Create nodes used with Q x K' and softmax(Q x K') x V - nodes.extend( - [ - helper.make_node("Gather", ["shape_output", "idx_1"], ["gather_1_output"], "gather_1", axis=0), - helper.make_node( - "Unsqueeze", ["gather_0_output", "unsqueeze_axes_input"], ["unsqueeze_0_output"], "unsqueeze_0" - ), - helper.make_node( - "Unsqueeze", ["gather_1_output", "unsqueeze_axes_input"], ["unsqueeze_1_output"], "unsqueeze_1" - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "num_heads", "unsqueeze_1_output", "head_size"], - ["bnsh_format"], - axis=0, - ), - helper.make_node( - "Concat", - ["unsqueeze_0_output", "unsqueeze_1_output", "hidden_size"], - ["bsd_format"], - axis=0, - ), - ] - ) - - # Create nodes for computing softmax(Q x K') x V - nodes.extend( - [ - helper.make_node( - "MatMul", - [ - "q_output_(num_heads*batch_size,seq_len,head_size)", - "k_output_(num_heads*batch_size,head_size,seq_len)", - ], - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - "matmul_qk", - ), - helper.make_node( - "Softmax", - ["qk_output_(num_heads*batch_size,seq_len,seq_len)"], - ["softmax_output"], - "softmax_qk", - axis=2, - ), - helper.make_node( - "MatMul", - ["softmax_output", "v_output_(num_heads*batch_size,seq_len,head_size)"], - ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], - "matmul_qkv", - ), - helper.make_node( - "Reshape", - ["qkv_output_(num_heads*batch_size,seq_len,head_size)", "bnsh_format"], - ["qkv_bnsh"], - "reshape_qkv_to_bnsh", - ), - helper.make_node("Transpose", ["qkv_bnsh"], ["qkv_bsnh"], "transpose_bnsh_to_bsnh", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), - ] - ) - - # Create final nodes to conclude attention - nodes.append( - helper.make_node( - "MatMul", - ["attn_output", "matmul_after_attn_initializer"], - ["matmul_after_attn_output"], - "matmul_after_attn", - ), - ) - if not fused: - next_sln_inputs = [ - "layernorm_output_to_skiplayernorm", - "add_after_attn_output", - "layernorm_weight", - "layernorm_bias", - ] - nodes.extend( - [ - helper.make_node( - "Add", - ["add_after_attn_initializer", "matmul_after_attn_output"], - ["add_after_attn_output"], - "add_after_attn", - ), - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "next_skiplayernorm", - domain="com.microsoft", - epsilon=epsilon, - ), - ] - ) - else: - next_sln_inputs = [ - "matmul_after_attn_output", - "layernorm_output_to_skiplayernorm", - "layernorm_weight", - "layernorm_bias", 
- "add_after_attn_initializer", - ] - nodes.append( - helper.make_node( - "SkipLayerNormalization", - next_sln_inputs, - ["output_0", "", "", "output_1"], - "SkipLayerNorm_AddBias_0", - domain="com.microsoft", - epsilon=epsilon, - ) - ) - - # Create initializers - q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) - q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) - k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size], zeros=True) - v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size], zeros=True) - qkv_bias = helper.make_tensor( - "Attention_0_qkv_bias", TensorProto.FLOAT, [3 * hidden_size], q_bias_data + k_bias_data + v_bias_data - ) - initializers = [ - float_tensor("layernorm_weight", [hidden_size]), - float_tensor("layernorm_bias", [hidden_size]), - float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), - float_tensor("add_after_attn_initializer", [hidden_size]), - q_weight, - ] - - if fused: - # Add packed QKV bias tensor as initializer - initializers.append(qkv_bias) - else: - # Add Q bias tensor as initializer - initializers.append(q_bias) - - initializers.extend( - [ - numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), - numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), - numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), - numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), - numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), - numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), - numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), - numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), - numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), - ] - ) - - # Construct graph - graph = helper.make_graph( - nodes, "whisper_decoder_with_past_cross_mha_graph", inputs, outputs, initializers, doc_string="whisper" - ) - opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) - return helper.make_model(graph, opset_imports=(opsetid,)) - - -if __name__ == "__main__": - np.random.seed(2) - num_heads = 4 - hidden_size = 64 - - model = create_whisper_encoder_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_encoder_attention_sln.onnx") - - model = create_whisper_encoder_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) - onnx.save(model, "./test_data/models/whisper/encoder_attention_with_sln_fused.onnx") - - model = create_whisper_decoder_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_decoder_attention_sln.onnx") - - model = create_whisper_decoder_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) - onnx.save(model, "./test_data/models/whisper/decoder_attention_with_sln_fused.onnx") - - model = create_whisper_decoder_multihead_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_decoder_mha.onnx") - - model = create_whisper_decoder_multihead_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) - onnx.save(model, "./test_data/models/whisper/decoder_mha_fused.onnx") - - model = create_whisper_decoder_with_past_multihead_self_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_decoder_with_past_self_mha.onnx") - - model = 
create_whisper_decoder_with_past_multihead_self_attention( - num_heads=num_heads, hidden_size=hidden_size, fused=True - ) - onnx.save(model, "./test_data/models/whisper/decoder_with_past_self_mha_fused.onnx") - - model = create_whisper_decoder_with_past_multihead_cross_attention(num_heads=num_heads, hidden_size=hidden_size) - onnx.save(model, "whisper_decoder_with_past_cross_mha.onnx") - - model = create_whisper_decoder_with_past_multihead_cross_attention( - num_heads=num_heads, hidden_size=hidden_size, fused=True - ) - onnx.save(model, "./test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx") diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 6460e3cb3aec4..b49c0bad711e1 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3931,7 +3931,7 @@ TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { * The TensorrtExecutionProviderOptionsTest can be used to test TRT options */ INSTANTIATE_TEST_SUITE_P(CApiTensorRTTest, CApiTensorRTTest, - ::testing::Values("trt_build_heuristics_enable=1", "trt_sparsity_enable=1", "trt_builder_optimization_level=0", "trt_tactic_sources=-CUDNN,+CUBLAS", "trt_auxiliary_streams=2")); + ::testing::Values("trt_build_heuristics_enable=1", "trt_sparsity_enable=1", "trt_builder_optimization_level=0", "trt_tactic_sources=-CUDNN,+CUBLAS", "trt_auxiliary_streams=2", "trt_bf16_enable=1")); #endif #ifdef USE_CUDA diff --git a/requirements-dev.txt b/requirements-dev.txt index b95b85781a398..e89edaa33e98e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ cerberus flatbuffers jinja2 +markupsafe numpy onnx onnxmltools diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 37df56216bde9..ed3471bdb47a9 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -3,6 +3,6 @@ lintrunner==0.12.7 lintrunner-adapters==0.12.4 # RUFF -ruff==0.11.9 +ruff==0.11.11 # CLANGFORMAT clang-format==19.1.7 diff --git a/setup.py b/setup.py index c45657c0c2873..6a1f126476158 100644 --- a/setup.py +++ b/setup.py @@ -324,6 +324,7 @@ def finalize_options(self): if platform.system() == "Linux": providers_cuda_or_rocm = "lib" + providers_cuda_or_rocm + ".so" providers_tensorrt_or_migraphx = "lib" + providers_tensorrt_or_migraphx + ".so" + providers_nv_tensorrt_rtx = "lib" + providers_nv_tensorrt_rtx + ".so" providers_openvino = "lib" + providers_openvino + ".so" providers_cann = "lib" + providers_cann + ".so" providers_qnn = "lib" + providers_qnn + ".so" @@ -361,6 +362,7 @@ def finalize_options(self): libs.extend(["libonnxruntime_providers_openvino.so"]) libs.extend(["libonnxruntime_providers_vitisai.so"]) libs.append(providers_cuda_or_rocm) + libs.append(providers_nv_tensorrt_rtx) libs.append(providers_tensorrt_or_migraphx) libs.append(providers_cann) libs.append(providers_qnn) @@ -512,6 +514,7 @@ def finalize_options(self): "onnxruntime.tools.ort_format_model.ort_flatbuffers_py", "onnxruntime.tools.ort_format_model.ort_flatbuffers_py.fbs", "onnxruntime.tools.qdq_helpers", + "onnxruntime.tools.qnn", "onnxruntime.quantization", "onnxruntime.quantization.operators", "onnxruntime.quantization.CalTableFlatBuffers", diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 8dce6be731402..82372645d364f 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -318,7 +318,6 @@ def generate_vcpkg_install_options(build_dir, args): elif "RUNNER_TEMP" in os.environ: temp_dir = 
os.environ["RUNNER_TEMP"] vcpkg_install_options.append(f"--x-buildtrees-root={temp_dir}") - vcpkg_install_options.append("--binarysource=clear\\;x-gha,readwrite") # Config asset cache if args.use_vcpkg_ms_internal_asset_cache: diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 215ad77335083..807c8b327c780 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -342,7 +342,7 @@ def add_webassembly_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for WebAssembly (WASM) platform builds.""" parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly.") parser.add_argument("--build_wasm_static_lib", action="store_true", help="Build WebAssembly static library.") - parser.add_argument("--emsdk_version", default="4.0.4", help="Specify version of emsdk.") + parser.add_argument("--emsdk_version", default="4.0.8", help="Specify version of emsdk.") parser.add_argument("--enable_wasm_simd", action="store_true", help="Enable WebAssembly SIMD.") parser.add_argument("--enable_wasm_relaxed_simd", action="store_true", help="Enable WebAssembly Relaxed SIMD.") parser.add_argument("--enable_wasm_threads", action="store_true", help="Enable WebAssembly multi-threading.") diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index ba6a33b07e765..ab10bdfba0e0f 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index cf213c47195c4..7f8039d237731 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.33.0.250327 + default: 2.34.0.250424 resources: repositories: @@ -155,7 +155,7 @@ extends: IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-onnxruntime-nodejs-win-x64' StageName: 'Windows_Nodejs_Packaging_x64' - BuildCommand: --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" + BuildCommand: --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" --enable_generic_interface BuildArch: 'x64' EnvSetupScript: 'setup_env.bat' sln_platform: 'x64' @@ -167,7 +167,7 @@ extends: IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-onnxruntime-nodejs-win-arm64' StageName: 'Windows_Nodejs_Packaging_arm64' - BuildCommand: --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" + BuildCommand: --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs 
--cmake_generator "Visual Studio 17 2022" --enable_generic_interface BuildArch: 'x64' EnvSetupScript: 'setup_env.bat' sln_platform: 'arm64' diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index b1a7c92dc3529..6ee64e4870fd5 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,7 +6,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index f08fd70d6d6cf..580f565310661 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 2a09eba776353..a87b85eaac256 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -151,7 +151,7 @@ stages: - ${{if or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: - - template: publish-symbolrequestprod-api.yml + - template: ../../templates/publish-symbolrequestprod-api.yml parameters: ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: symbolExpiryTime: 60 diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index d19f9bde7ad75..035b4b6c17222 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.33.2.250410 + default: 2.34.0.250424 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 722a3162cfed8..63fb41ab24c68 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 9928a68b6df06..84445b117b495 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' 
- default: 2.33.2.250410 + default: 2.34.0.250424 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index 0a88391dd4ad6..de0a8f10b82be 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -140,7 +140,6 @@ stages: ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: symbolExpiryTime: 60 includePublicSymbolServer: true - symbolsFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' symbolsArtifactName: onnxruntime_gpu_win_x64_${{ parameters.PYTHON_VERSION }} symbolsVersion: $(Build.BuildId) symbolProject: 'ONNX Runtime' diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index d1fa72d7e4413..0c70a4f82c566 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.33.0.250327' + default: '2.34.0.250424' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 4474a6b45ef58..c94969d9e9d41 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.33.0.250327' + default: '2.34.0.250424' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index adf9c91e602a0..72343613d6b26 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -119,6 +119,12 @@ steps: DoEsrp: ${{parameters.DoEsrp}} Pattern: '*.dll,*.exe' + - task: DeleteFiles@1 + displayName: 'Delete CodeSignSummary*.md' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\${{parameters.artifactName}}' + Contents: 'CodeSignSummary*.md' + - task: ArchiveFiles@2 inputs: rootFolderOrFile: '$(Build.BinariesDirectory)\${{parameters.artifactName}}' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 9f65fc8891e94..a6cf5b9a7713e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.33.0.250327 + default: 2.34.0.250424 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index e00e40b80b723..01dbfc5292aa9 100644 
--- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.33.2.250410' + default: '2.34.0.250424' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 3b27060b3fcec..13cc9314caf77 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.33.2.250410' + default: '2.34.0.250424' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 0e60bf8e2e26d..aa434699fbe02 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -88,15 +88,15 @@ jobs: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.4 ccache-git-emscripten-64bit - ./emsdk activate 4.0.4 ccache-git-emscripten-64bit + ./emsdk install 4.0.8 ccache-git-emscripten-64bit + ./emsdk activate 4.0.8 ccache-git-emscripten-64bit displayName: 'emsdk install and activate ccache for emscripten' - ${{if eq(parameters.WithCache, false)}}: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.4 - ./emsdk activate 4.0.4 + ./emsdk install 4.0.8 + ./emsdk activate 4.0.8 displayName: 'emsdk install and activate ccache for emscripten' - template: build-linux-wasm-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml b/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml index b2a3eaca0280f..9f0230c4b1141 100644 --- a/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml +++ b/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml @@ -1,4 +1,3 @@ -# This file was copied from https://github.com/microsoft/devhome/blob/main/build/templates/publish-symbolrequestprod-api.yml#L71 parameters: - name: includePublicSymbolServer type: boolean @@ -32,18 +31,6 @@ steps: Install-Module -Verbose -AllowClobber -Force Az.Accounts, Az.Storage, Az.Network, Az.Resources, Az.Compute displayName: Install Azure Module Dependencies - # Transit the Azure token from the Service Connection into a secret variable for the rest of the pipeline to use. 
- - task: AzurePowerShell@5 - displayName: Generate an Azure Token - inputs: - azureSubscription: ${{ parameters.subscription }} - azurePowerShellVersion: LatestVersion - pwsh: true - ScriptType: InlineScript - Inline: |- - $AzToken = (Get-AzAccessToken -ResourceUrl api://30471ccf-0966-45b9-a979-065dbedb24c1).Token - Write-Host "##vso[task.setvariable variable=SymbolAccessToken;issecret=true]$AzToken" - - task: PublishSymbols@2 displayName: Publish Symbols (to current Azure DevOps tenant) continueOnError: True @@ -60,28 +47,76 @@ steps: env: LIB: $(Build.SourcesDirectory) - - pwsh: |- - # Prepare the defaults for IRM - $PSDefaultParameterValues['Invoke-RestMethod:Headers'] = @{ Authorization = "Bearer $(SymbolAccessToken)" } - $PSDefaultParameterValues['Invoke-RestMethod:ContentType'] = "application/json" - $PSDefaultParameterValues['Invoke-RestMethod:Method'] = "POST" + - task: AzurePowerShell@5 + displayName: Generate Token and Publish Symbols via REST API + inputs: + azureSubscription: ${{ parameters.subscription }} + azurePowerShellVersion: LatestVersion + pwsh: true + ScriptType: InlineScript + Inline: | + # Part 1: Generate an Azure Token + Write-Host "Attempting to retrieve Azure Access Token for symbol publishing API." + $apiResourceUrl = "api://30471ccf-0966-45b9-a979-065dbedb24c1" + try { + $secureTokenObject = (Get-AzAccessToken -ResourceUrl $apiResourceUrl).Token + Write-Host "Successfully retrieved a token object." + } + catch { + Write-Error "Failed to retrieve Azure Access Token. Error: $($_.Exception.Message)" + throw "Failed to retrieve Azure Access Token." # Fail the task + } + + # Convert the SecureString token to a plain text string for the HTTP header + # This is done just-in-time before its use. + $plainTextToken = $secureTokenObject | ConvertFrom-SecureString -AsPlainText + Write-Host "Token converted to plain text for API call (will not be logged)." + + # Part 2: Publish Symbols using internal REST API + Write-Host "Preparing to publish symbols using internal REST API." + + # Prepare the defaults for Invoke-RestMethod for this scope + $PSDefaultParameterValues = @{} # Initialize to ensure a clean state for default parameters + $PSDefaultParameterValues['Invoke-RestMethod:Headers'] = @{ Authorization = "Bearer $plainTextToken" } + $PSDefaultParameterValues['Invoke-RestMethod:ContentType'] = "application/json" + $PSDefaultParameterValues['Invoke-RestMethod:Method'] = "POST" # Default method for symbol request creation/update + + $baseUri = "https://symbolrequestprod.trafficmanager.net/projects/${{ parameters.symbolProject }}/requests" + + # Prepare and submit the symbol request creation + $expirationDate = (Get-Date).Add([TimeSpan]::FromDays(${{ parameters.symbolExpiryTime }})) + $createRequestBody = @{ + requestName = "${{ parameters.symbolsArtifactName }}_${{ parameters.symbolsVersion }}"; + expirationTime = $expirationDate.ToString(); + } + $requestNameForUri = $createRequestBody.requestName # Store for use in the next URI + + Write-Host "##[debug]Creating symbol request: Name '$($createRequestBody.requestName)', Expiration '$($createRequestBody.expirationTime)'. URI: '$baseUri'" + try { + Invoke-RestMethod -Uri $baseUri -Body ($createRequestBody | ConvertTo-Json -Compress) -Verbose + Write-Host "Successfully initiated symbol request '$($createRequestBody.requestName)'." + } + catch { + Write-Error "Failed to create symbol request. 
Error: $($_.Exception.Message)" + # Optionally inspect response: $_.ErrorDetails.Message or $_.Exception.Response + throw "Failed to create symbol request." + } - $BaseUri = "https://symbolrequestprod.trafficmanager.net/projects/${{ parameters.symbolProject }}/requests" + # Prepare and submit the symbol publication details + $publishRequestBody = @{ + publishToInternalServer = $true; + publishToPublicServer = [System.Convert]::ToBoolean("${{ parameters.includePublicSymbolServer }}"); # Ensure YAML boolean is correctly PowerShell boolean + } + $publishUri = "$baseUri/$requestNameForUri" - # Prepare the request - $expiration = (Get-Date).Add([TimeSpan]::FromDays(${{ parameters.symbolExpiryTime }})) - $createRequestBody = @{ - requestName = "${{ parameters.symbolsArtifactName }}_${{ parameters.symbolsVersion }}"; - expirationTime = $expiration.ToString(); - } - Write-Host "##[debug]Starting request $($createRequestBody.requestName) with expiration date of $($createRequestBody.expirationTime)" - Invoke-RestMethod -Uri "$BaseUri" -Body ($createRequestBody | ConvertTo-Json -Compress) -Verbose + Write-Host "##[debug]Submitting symbol publication details for request '$requestNameForUri'. URI: '$publishUri'. Payload: $($publishRequestBody | ConvertTo-Json -Compress)" + try { + Invoke-RestMethod -Uri $publishUri -Body ($publishRequestBody | ConvertTo-Json -Compress) -Verbose + Write-Host "Successfully submitted symbol publication details for '$requestNameForUri'." + } + catch { + Write-Error "Failed to submit symbol publication details. Error: $($_.Exception.Message)" + throw "Failed to submit symbol publication details." + } - # Request symbol publication - $publishRequestBody = @{ - publishToInternalServer = $true; - publishToPublicServer = $${{ parameters.includePublicSymbolServer }}; - } - Write-Host "##[debug]Submitting request $($createRequestBody.requestName) ($($publishRequestBody | ConvertTo-Json -Compress))" - Invoke-RestMethod -Uri "$BaseUri/$($createRequestBody.requestName)" -Body ($publishRequestBody | ConvertTo-Json -Compress) -Verbose - displayName: Publish Symbols using internal REST API + Write-Host "Symbol publishing process via REST API completed for '$requestNameForUri'." 
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index c361fe678699e..a0bfd6a46a43c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: is1ES displayName: 'Whether the pipeline is running in 1ES' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index c1f47de63c38c..d28b3e9604c5d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 1a00d67bdbb2a..f300d845579bf 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 72c8323d032ed..ce22142e6c5bd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index d739724f8744a..0b8c493ae124d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.33.2.250410' + QnnSdk: '2.34.0.250424' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false @@ -125,4 +125,4 @@ stages: displayName: 'Publish Pipeline Qnn NuGet Artifact' inputs: artifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.ArtifactStagingDirectory)' \ No newline at end of file + targetPath: '$(Build.ArtifactStagingDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 93a9909e529f8..9c06edb4d03e8 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 jobs: - job: 'BUILD_QNN_EP' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index b83621d285f9a..3b41394b97bd3 100644 --- 
a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.33.2.250410 + default: 2.34.0.250424 jobs: - job: 'BUILD_QNN_EP' diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index ee5cedb73ff04..7c1731aef992d 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -3,7 +3,7 @@ numpy==1.24.4 ; python_version < '3.9' numpy==2.1.2; python_version >= '3.9' mypy pytest -setuptools==69.0.3 +setuptools==78.1.1 wheel==0.42.0 onnx==1.17.0 ; python_version < '3.13' argparse diff --git a/tools/ci_build/github/linux/python/requirements.txt b/tools/ci_build/github/linux/python/requirements.txt index 3ca025514ea3d..f499dae947b4f 100644 --- a/tools/ci_build/github/linux/python/requirements.txt +++ b/tools/ci_build/github/linux/python/requirements.txt @@ -9,3 +9,5 @@ sympy==1.12 flatbuffers psutil onnxscript==0.2.3 ; python_version < '3.13' +jinja2 +markupsafe diff --git a/tools/ci_build/github/windows/python/requirements.txt b/tools/ci_build/github/windows/python/requirements.txt index 2b222c4b1d4a4..d292f6edacde2 100644 --- a/tools/ci_build/github/windows/python/requirements.txt +++ b/tools/ci_build/github/windows/python/requirements.txt @@ -9,3 +9,5 @@ sympy==1.12 flatbuffers psutil onnxscript==0.2.3 +jinja2 +markupsafe
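
Note (not part of the patch): the fused attention paths in the removed whisper test-model generator above split a packed QKV projection with three Slice nodes whose start indices are 0, hidden_size, 2*hidden_size and whose end index is 3*hidden_size along the last axis. A minimal NumPy sketch of that same slicing pattern, with illustrative names that are not taken from the patch:

    import numpy as np

    # Illustrative sizes only; the generator's defaults use hidden_size=64, num_heads=4.
    hidden_size = 64
    batch, seq = 2, 3

    # Packed projection output, analogous to MatMul_0_qkv_out in the graph above.
    qkv = np.random.rand(batch, seq, 3 * hidden_size).astype(np.float32)

    # Boundaries mirror MatMul_0_q_start_index, MatMul_0_k_start_index,
    # MatMul_0_v_start_index, and MatMul_0_end_of_qkv_index.
    q = qkv[..., 0:hidden_size]
    k = qkv[..., hidden_size:2 * hidden_size]
    v = qkv[..., 2 * hidden_size:3 * hidden_size]

    assert q.shape == k.shape == v.shape == (batch, seq, hidden_size)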